Pass group_id to generate_response in extraction operations
Thread the group_id parameter through all extraction-related generate_response() calls where it is naturally available (via episode.group_id or node.group_id). This enables consumers to override get_extraction_language_instruction() with group-specific language preferences.

Changes:
- edge_operations.py: pass group_id in extract_edges()
- node_operations.py: pass episode.group_id in extract_nodes() and node.group_id in extract_attributes_from_node()
- node_operations.py: add a group_id parameter to extract_nodes_reflexion()

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
d430ab82b4
commit
a20643470f
2 changed files with 12 additions and 1 deletions
|
|
@@ -139,6 +139,7 @@ async def extract_edges(
|
||||||
prompt_library.extract_edges.edge(context),
|
prompt_library.extract_edges.edge(context),
|
||||||
response_model=ExtractedEdges,
|
response_model=ExtractedEdges,
|
||||||
max_tokens=extract_edges_max_tokens,
|
max_tokens=extract_edges_max_tokens,
|
||||||
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
edges_data = ExtractedEdges(**llm_response).edges
|
edges_data = ExtractedEdges(**llm_response).edges
|
||||||
|
|
||||||
|
|
@@ -150,6 +151,7 @@ async def extract_edges(
|
||||||
prompt_library.extract_edges.reflexion(context),
|
prompt_library.extract_edges.reflexion(context),
|
||||||
response_model=MissingFacts,
|
response_model=MissingFacts,
|
||||||
max_tokens=extract_edges_max_tokens,
|
max_tokens=extract_edges_max_tokens,
|
||||||
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
missing_facts = reflexion_response.get('missing_facts', [])
|
missing_facts = reflexion_response.get('missing_facts', [])
|
||||||
|
|
|
||||||
|
|
@@ -64,6 +64,8 @@ async def extract_nodes_reflexion(
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
node_names: list[str],
|
node_names: list[str],
|
||||||
|
ensure_ascii: bool = False,
|
||||||
|
group_id: str = '',
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
context = {
|
context = {
|
||||||
|
|
@@ -73,7 +75,7 @@ async def extract_nodes_reflexion(
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.extract_nodes.reflexion(context), MissedEntities
|
prompt_library.extract_nodes.reflexion(context), MissedEntities, group_id=group_id
|
||||||
)
|
)
|
||||||
missed_entities = llm_response.get('missed_entities', [])
|
missed_entities = llm_response.get('missed_entities', [])
|
||||||
|
|
||||||
|
|
@@ -129,16 +131,19 @@ async def extract_nodes(
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.extract_nodes.extract_message(context),
|
prompt_library.extract_nodes.extract_message(context),
|
||||||
response_model=ExtractedEntities,
|
response_model=ExtractedEntities,
|
||||||
|
group_id=episode.group_id,
|
||||||
)
|
)
|
||||||
elif episode.source == EpisodeType.text:
|
elif episode.source == EpisodeType.text:
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.extract_nodes.extract_text(context),
|
prompt_library.extract_nodes.extract_text(context),
|
||||||
response_model=ExtractedEntities,
|
response_model=ExtractedEntities,
|
||||||
|
group_id=episode.group_id,
|
||||||
)
|
)
|
||||||
elif episode.source == EpisodeType.json:
|
elif episode.source == EpisodeType.json:
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.extract_nodes.extract_json(context),
|
prompt_library.extract_nodes.extract_json(context),
|
||||||
response_model=ExtractedEntities,
|
response_model=ExtractedEntities,
|
||||||
|
group_id=episode.group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
response_object = ExtractedEntities(**llm_response)
|
response_object = ExtractedEntities(**llm_response)
|
||||||
|
|
@@ -152,6 +157,8 @@ async def extract_nodes(
|
||||||
episode,
|
episode,
|
||||||
previous_episodes,
|
previous_episodes,
|
||||||
[entity.name for entity in extracted_entities],
|
[entity.name for entity in extracted_entities],
|
||||||
|
clients.ensure_ascii,
|
||||||
|
episode.group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
entities_missed = len(missing_entities) != 0
|
entities_missed = len(missing_entities) != 0
|
||||||
|
|
@@ -510,6 +517,7 @@ async def extract_attributes_from_node(
|
||||||
prompt_library.extract_nodes.extract_attributes(attributes_context),
|
prompt_library.extract_nodes.extract_attributes(attributes_context),
|
||||||
response_model=entity_type,
|
response_model=entity_type,
|
||||||
model_size=ModelSize.small,
|
model_size=ModelSize.small,
|
||||||
|
group_id=node.group_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if has_entity_attributes
|
if has_entity_attributes
|
||||||
|
|
@@ -527,6 +535,7 @@ async def extract_attributes_from_node(
|
||||||
prompt_library.extract_nodes.extract_summary(summary_context),
|
prompt_library.extract_nodes.extract_summary(summary_context),
|
||||||
response_model=EntitySummary,
|
response_model=EntitySummary,
|
||||||
model_size=ModelSize.small,
|
model_size=ModelSize.small,
|
||||||
|
group_id=node.group_id,
|
||||||
)
|
)
|
||||||
node.summary = summary_response.get('summary', '')
|
node.summary = summary_response.get('summary', '')
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue