Pass group_id to generate_response in extraction operations

Thread group_id parameter through all extraction-related generate_response()
calls where it's naturally available (via episode.group_id or node.group_id).
This enables consumers to override get_extraction_language_instruction() with
group-specific language preferences.
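For example, a consumer could key the instruction off the threaded group_id by
subclassing its LLM client. A minimal sketch, assuming the hook accepts the
group_id and that OpenAIClient is importable as shown; the subclass name and
language table are hypothetical:

    # Hypothetical consumer-side override; not part of this commit.
    from graphiti_core.llm_client import OpenAIClient

    GROUP_LANGUAGE = {
        'tenant-jp': 'Respond in Japanese.',
        'tenant-de': 'Respond in German.',
    }

    class PerGroupLanguageClient(OpenAIClient):
        def get_extraction_language_instruction(self, group_id: str | None = None) -> str:
            # Fall back to the base instruction when the group has no preference.
            default = super().get_extraction_language_instruction(group_id)
            return GROUP_LANGUAGE.get(group_id or '', default)

Because the calls below now forward a group_id, the override applies to node,
edge, reflexion, attribute, and summary extraction alike.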

Changes:
- edge_operations.py: Pass group_id in extract_edges()
- node_operations.py: Pass episode.group_id in extract_nodes() and
  node.group_id in extract_attributes_from_node()
- node_operations.py: Add group_id and ensure_ascii parameters to
  extract_nodes_reflexion(), and pass both from extract_nodes() (see the toy
  sketch after this list)
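A self-contained toy of that pattern (all names invented, no graphiti
imports), showing why the parameter must be forwarded at every call site for
an override to take effect:

    # Toy model of the group_id threading pattern; illustrative only.
    class ToyClient:
        LANGUAGES = {'g1': 'Respond in French.'}

        def get_extraction_language_instruction(self, group_id: str = '') -> str:
            return self.LANGUAGES.get(group_id, 'Respond in the input language.')

        def generate_response(self, prompt: str, group_id: str = '') -> str:
            # A call site that drops group_id always gets the default instruction.
            return f'{prompt} [{self.get_extraction_language_instruction(group_id)}]'

    client = ToyClient()
    print(client.generate_response('extract nodes'))                 # default
    print(client.generate_response('extract nodes', group_id='g1'))  # French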

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Author: Daniel Chalef
Date: 2025-09-30 11:26:30 -07:00
Parent: d430ab82b4
Commit: a20643470f
2 changed files with 12 additions and 1 deletion

edge_operations.py

@@ -139,6 +139,7 @@ async def extract_edges(
         prompt_library.extract_edges.edge(context),
         response_model=ExtractedEdges,
         max_tokens=extract_edges_max_tokens,
+        group_id=group_id,
     )
     edges_data = ExtractedEdges(**llm_response).edges
@@ -150,6 +151,7 @@
         prompt_library.extract_edges.reflexion(context),
         response_model=MissingFacts,
         max_tokens=extract_edges_max_tokens,
+        group_id=group_id,
     )
     missing_facts = reflexion_response.get('missing_facts', [])

node_operations.py

@@ -64,6 +64,8 @@ async def extract_nodes_reflexion(
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
     node_names: list[str],
+    ensure_ascii: bool = False,
+    group_id: str = '',
 ) -> list[str]:
     # Prepare context for LLM
     context = {
@@ -73,7 +75,7 @@
     }
     llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.reflexion(context), MissedEntities
+        prompt_library.extract_nodes.reflexion(context), MissedEntities, group_id=group_id
     )
     missed_entities = llm_response.get('missed_entities', [])
@@ -129,16 +131,19 @@ async def extract_nodes(
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_message(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
     elif episode.source == EpisodeType.text:
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_text(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
     elif episode.source == EpisodeType.json:
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_json(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
     response_object = ExtractedEntities(**llm_response)
@@ -152,6 +157,8 @@ async def extract_nodes(
             episode,
             previous_episodes,
             [entity.name for entity in extracted_entities],
+            clients.ensure_ascii,
+            episode.group_id,
         )
     entities_missed = len(missing_entities) != 0
@@ -510,6 +517,7 @@ async def extract_attributes_from_node(
                 prompt_library.extract_nodes.extract_attributes(attributes_context),
                 response_model=entity_type,
                 model_size=ModelSize.small,
+                group_id=node.group_id,
             )
         )
         if has_entity_attributes
@@ -527,6 +535,7 @@
         prompt_library.extract_nodes.extract_summary(summary_context),
         response_model=EntitySummary,
         model_size=ModelSize.small,
+        group_id=node.group_id,
     )
     node.summary = summary_response.get('summary', '')