Pass group_id to generate_response in extraction operations
Thread group_id parameter through all extraction-related generate_response() calls where it's naturally available (via episode.group_id or node.group_id). This enables consumers to override get_extraction_language_instruction() with group-specific language preferences. Changes: - edge_operations.py: Pass group_id in extract_edges() - node_operations.py: Pass episode.group_id in extract_nodes() and node.group_id in extract_attributes_from_node() - node_operations.py: Add group_id parameter to extract_nodes_reflexion() 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
d430ab82b4
commit
a20643470f
2 changed files with 12 additions and 1 deletions
|
|
@ -139,6 +139,7 @@ async def extract_edges(
|
|||
prompt_library.extract_edges.edge(context),
|
||||
response_model=ExtractedEdges,
|
||||
max_tokens=extract_edges_max_tokens,
|
||||
group_id=group_id,
|
||||
)
|
||||
edges_data = ExtractedEdges(**llm_response).edges
|
||||
|
||||
|
|
@ -150,6 +151,7 @@ async def extract_edges(
|
|||
prompt_library.extract_edges.reflexion(context),
|
||||
response_model=MissingFacts,
|
||||
max_tokens=extract_edges_max_tokens,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
missing_facts = reflexion_response.get('missing_facts', [])
|
||||
|
|
|
|||
|
|
@ -64,6 +64,8 @@ async def extract_nodes_reflexion(
|
|||
episode: EpisodicNode,
|
||||
previous_episodes: list[EpisodicNode],
|
||||
node_names: list[str],
|
||||
ensure_ascii: bool = False,
|
||||
group_id: str = '',
|
||||
) -> list[str]:
|
||||
# Prepare context for LLM
|
||||
context = {
|
||||
|
|
@ -73,7 +75,7 @@ async def extract_nodes_reflexion(
|
|||
}
|
||||
|
||||
llm_response = await llm_client.generate_response(
|
||||
prompt_library.extract_nodes.reflexion(context), MissedEntities
|
||||
prompt_library.extract_nodes.reflexion(context), MissedEntities, group_id=group_id
|
||||
)
|
||||
missed_entities = llm_response.get('missed_entities', [])
|
||||
|
||||
|
|
@ -129,16 +131,19 @@ async def extract_nodes(
|
|||
llm_response = await llm_client.generate_response(
|
||||
prompt_library.extract_nodes.extract_message(context),
|
||||
response_model=ExtractedEntities,
|
||||
group_id=episode.group_id,
|
||||
)
|
||||
elif episode.source == EpisodeType.text:
|
||||
llm_response = await llm_client.generate_response(
|
||||
prompt_library.extract_nodes.extract_text(context),
|
||||
response_model=ExtractedEntities,
|
||||
group_id=episode.group_id,
|
||||
)
|
||||
elif episode.source == EpisodeType.json:
|
||||
llm_response = await llm_client.generate_response(
|
||||
prompt_library.extract_nodes.extract_json(context),
|
||||
response_model=ExtractedEntities,
|
||||
group_id=episode.group_id,
|
||||
)
|
||||
|
||||
response_object = ExtractedEntities(**llm_response)
|
||||
|
|
@ -152,6 +157,8 @@ async def extract_nodes(
|
|||
episode,
|
||||
previous_episodes,
|
||||
[entity.name for entity in extracted_entities],
|
||||
clients.ensure_ascii,
|
||||
episode.group_id,
|
||||
)
|
||||
|
||||
entities_missed = len(missing_entities) != 0
|
||||
|
|
@ -510,6 +517,7 @@ async def extract_attributes_from_node(
|
|||
prompt_library.extract_nodes.extract_attributes(attributes_context),
|
||||
response_model=entity_type,
|
||||
model_size=ModelSize.small,
|
||||
group_id=node.group_id,
|
||||
)
|
||||
)
|
||||
if has_entity_attributes
|
||||
|
|
@ -527,6 +535,7 @@ async def extract_attributes_from_node(
|
|||
prompt_library.extract_nodes.extract_summary(summary_context),
|
||||
response_model=EntitySummary,
|
||||
model_size=ModelSize.small,
|
||||
group_id=node.group_id,
|
||||
)
|
||||
node.summary = summary_response.get('summary', '')
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue