diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index a6760a40..83ae8bfe 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -139,6 +139,7 @@ async def extract_edges(
         prompt_library.extract_edges.edge(context),
         response_model=ExtractedEdges,
         max_tokens=extract_edges_max_tokens,
+        group_id=group_id,
     )
 
     edges_data = ExtractedEdges(**llm_response).edges
@@ -150,6 +151,7 @@ async def extract_edges(
             prompt_library.extract_edges.reflexion(context),
             response_model=MissingFacts,
             max_tokens=extract_edges_max_tokens,
+            group_id=group_id,
         )
 
         missing_facts = reflexion_response.get('missing_facts', [])
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 7f85b52c..509ed012 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -64,6 +64,8 @@ async def extract_nodes_reflexion(
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
     node_names: list[str],
+    ensure_ascii: bool = False,
+    group_id: str = '',
 ) -> list[str]:
     # Prepare context for LLM
     context = {
@@ -73,7 +75,7 @@
     }
 
     llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.reflexion(context), MissedEntities
+        prompt_library.extract_nodes.reflexion(context), MissedEntities, group_id=group_id
     )
 
     missed_entities = llm_response.get('missed_entities', [])
@@ -129,16 +131,19 @@
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_message(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
     elif episode.source == EpisodeType.text:
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_text(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
     elif episode.source == EpisodeType.json:
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_json(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
 
     response_object = ExtractedEntities(**llm_response)
@@ -152,6 +157,8 @@
             episode,
             previous_episodes,
             [entity.name for entity in extracted_entities],
+            clients.ensure_ascii,
+            episode.group_id,
         )
 
         entities_missed = len(missing_entities) != 0
@@ -510,6 +517,7 @@
                 prompt_library.extract_nodes.extract_attributes(attributes_context),
                 response_model=entity_type,
                 model_size=ModelSize.small,
+                group_id=node.group_id,
             )
         )
         if has_entity_attributes
@@ -527,6 +535,7 @@
         prompt_library.extract_nodes.extract_summary(summary_context),
         response_model=EntitySummary,
         model_size=ModelSize.small,
+        group_id=node.group_id,
     )
 
     node.summary = summary_response.get('summary', '')
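
Note: the common thread in this patch is forwarding the caller's group_id into every llm_client.generate_response call, so the LLM client can attribute each request to the graph partition that triggered it. The sketch below is a minimal, self-contained illustration of that pattern under stated assumptions, not graphiti_core code; DemoLLMClient, run_reflexion, and their signatures are hypothetical stand-ins.

# Hypothetical sketch: threading a per-partition group_id through an async
# LLM call, mirroring what this patch does at each generate_response call site.
# DemoLLMClient and run_reflexion are illustrative names, not graphiti_core APIs.
import asyncio
from typing import Any


class DemoLLMClient:
    async def generate_response(
        self,
        messages: list[dict[str, str]],
        response_model: type | None = None,
        group_id: str = '',
        **kwargs: Any,
    ) -> dict[str, Any]:
        # A real client might attach group_id to request metadata for
        # per-tenant tracing, routing, or cost attribution.
        print(f'LLM call tagged with group_id={group_id!r}')
        return {'missed_entities': []}


async def run_reflexion(client: DemoLLMClient, group_id: str) -> list[str]:
    # As in the patched extract_nodes_reflexion: the episode's group_id is
    # forwarded on every call, including reflexion retries.
    response = await client.generate_response(
        [{'role': 'user', 'content': 'Which entities were missed?'}],
        group_id=group_id,
    )
    return response.get('missed_entities', [])


if __name__ == '__main__':
    asyncio.run(run_reflexion(DemoLLMClient(), 'tenant-42'))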