From 859d7aee5e59fc75304c7c8f492821653bebc0ed Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Wed, 24 Sep 2025 21:17:08 -0700
Subject: [PATCH] refactor string formatting to use single quotes in node
 operations

---
 .../utils/maintenance/node_operations.py | 160 ++++++++----------
 1 file changed, 70 insertions(+), 90 deletions(-)

diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 94773557..693609d8 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -65,16 +65,16 @@ async def extract_nodes_reflexion(
 ) -> list[str]:
     # Prepare context for LLM
     context = {
-        "episode_content": episode.content,
-        "previous_episodes": [ep.content for ep in previous_episodes],
-        "extracted_entities": node_names,
-        "ensure_ascii": ensure_ascii,
+        'episode_content': episode.content,
+        'previous_episodes': [ep.content for ep in previous_episodes],
+        'extracted_entities': node_names,
+        'ensure_ascii': ensure_ascii,
     }
 
     llm_response = await llm_client.generate_response(
         prompt_library.extract_nodes.reflexion(context), MissedEntities
     )
-    missed_entities = llm_response.get("missed_entities", [])
+    missed_entities = llm_response.get('missed_entities', [])
 
     return missed_entities
 
@@ -89,24 +89,24 @@ async def extract_nodes(
     start = time()
     llm_client = clients.llm_client
     llm_response = {}
-    custom_prompt = ""
+    custom_prompt = ''
     entities_missed = True
     reflexion_iterations = 0
 
     entity_types_context = [
         {
-            "entity_type_id": 0,
-            "entity_type_name": "Entity",
-            "entity_type_description": "Default entity classification. Use this entity type if the entity is not one of the other listed types.",
+            'entity_type_id': 0,
+            'entity_type_name': 'Entity',
+            'entity_type_description': 'Default entity classification. Use this entity type if the entity is not one of the other listed types.',
         }
     ]
 
     entity_types_context += (
         [
             {
-                "entity_type_id": i + 1,
-                "entity_type_name": type_name,
-                "entity_type_description": type_model.__doc__,
+                'entity_type_id': i + 1,
+                'entity_type_name': type_name,
+                'entity_type_description': type_model.__doc__,
             }
             for i, (type_name, type_model) in enumerate(entity_types.items())
         ]
@@ -115,13 +115,13 @@ async def extract_nodes(
     )
 
     context = {
-        "episode_content": episode.content,
-        "episode_timestamp": episode.valid_at.isoformat(),
-        "previous_episodes": [ep.content for ep in previous_episodes],
-        "custom_prompt": custom_prompt,
-        "entity_types": entity_types_context,
-        "source_description": episode.source_description,
-        "ensure_ascii": clients.ensure_ascii,
+        'episode_content': episode.content,
+        'episode_timestamp': episode.valid_at.isoformat(),
+        'previous_episodes': [ep.content for ep in previous_episodes],
+        'custom_prompt': custom_prompt,
+        'entity_types': entity_types_context,
+        'source_description': episode.source_description,
+        'ensure_ascii': clients.ensure_ascii,
     }
 
     while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
@@ -157,48 +157,42 @@ async def extract_nodes(
         entities_missed = len(missing_entities) != 0
 
-        custom_prompt = "Make sure that the following entities are extracted: "
+        custom_prompt = 'Make sure that the following entities are extracted: '
         for entity in missing_entities:
-            custom_prompt += f"\n{entity},"
+            custom_prompt += f'\n{entity},'
 
-    filtered_extracted_entities = [
-        entity for entity in extracted_entities if entity.name.strip()
-    ]
+    filtered_extracted_entities = [entity for entity in extracted_entities if entity.name.strip()]
     end = time()
-    logger.debug(
-        f"Extracted new nodes: {filtered_extracted_entities} in {(end - start) * 1000} ms"
-    )
+    logger.debug(f'Extracted new nodes: {filtered_extracted_entities} in {(end - start) * 1000} ms')
 
     # Convert the extracted data into EntityNode objects
     extracted_nodes = []
     for extracted_entity in filtered_extracted_entities:
         type_id = extracted_entity.entity_type_id
         if 0 <= type_id < len(entity_types_context):
-            entity_type_name = entity_types_context[
-                extracted_entity.entity_type_id
-            ].get("entity_type_name")
+            entity_type_name = entity_types_context[extracted_entity.entity_type_id].get(
+                'entity_type_name'
+            )
         else:
-            entity_type_name = "Entity"
+            entity_type_name = 'Entity'
 
         # Check if this entity type should be excluded
         if excluded_entity_types and entity_type_name in excluded_entity_types:
-            logger.debug(
-                f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"'
-            )
+            logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
             continue
 
-        labels: list[str] = list({"Entity", str(entity_type_name)})
+        labels: list[str] = list({'Entity', str(entity_type_name)})
 
         new_node = EntityNode(
             name=extracted_entity.name,
             group_id=episode.group_id,
             labels=labels,
-            summary="",
+            summary='',
             created_at=utc_now(),
         )
         extracted_nodes.append(new_node)
-        logger.debug(f"Created new node: {new_node.name} (UUID: {new_node.uuid})")
+        logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
 
-    logger.debug(f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}")
+    logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
     return extracted_nodes
 
@@ -221,9 +215,7 @@ async def _collect_candidate_nodes(
         ]
     )
 
-    candidate_nodes: list[EntityNode] = [
-        node for result in search_results for node in result.nodes
-    ]
+    candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]
 
     if existing_nodes_override is not None:
         candidate_nodes.extend(existing_nodes_override)
 
@@ -253,21 +245,19 @@ async def _resolve_with_llm(
     if not state.unresolved_indices:
         return
 
-    entity_types_dict: dict[str, type[BaseModel]] = (
-        entity_types if entity_types is not None else {}
-    )
+    entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
 
     llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
     extracted_nodes_context = [
         {
-            "id": i,
-            "name": node.name,
-            "entity_type": node.labels,
-            "entity_type_description": entity_types_dict.get(
-                next((item for item in node.labels if item != "Entity"), "")
+            'id': i,
+            'name': node.name,
+            'entity_type': node.labels,
+            'entity_type_description': entity_types_dict.get(
+                next((item for item in node.labels if item != 'Entity'), '')
             ).__doc__
-            or "Default Entity Type",
+            or 'Default Entity Type',
         }
         for i, node in enumerate(llm_extracted_nodes)
     ]
@@ -275,9 +265,9 @@ async def _resolve_with_llm(
     existing_nodes_context = [
         {
             **{
-                "idx": i,
-                "name": candidate.name,
-                "entity_types": candidate.labels,
+                'idx': i,
+                'name': candidate.name,
+                'entity_types': candidate.labels,
             },
             **candidate.attributes,
         }
@@ -285,15 +275,13 @@ async def _resolve_with_llm(
     ]
 
     context = {
-        "extracted_nodes": extracted_nodes_context,
-        "existing_nodes": existing_nodes_context,
-        "episode_content": episode.content if episode is not None else "",
-        "previous_episodes": (
-            [ep.content for ep in previous_episodes]
-            if previous_episodes is not None
-            else []
+        'extracted_nodes': extracted_nodes_context,
+        'existing_nodes': existing_nodes_context,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': (
+            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
         ),
-        "ensure_ascii": ensure_ascii,
+        'ensure_ascii': ensure_ascii,
     }
 
@@ -301,9 +289,7 @@ async def _resolve_with_llm(
     llm_response = await llm_client.generate_response(
         prompt_library.dedupe_nodes.nodes(context),
         response_model=NodeResolutions,
     )
 
-    node_resolutions: list[NodeDuplicate] = NodeResolutions(
-        **llm_response
-    ).entity_resolutions
+    node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
 
     for resolution in node_resolutions:
         relative_id: int = resolution.id
 
@@ -367,13 +353,13 @@ async def resolve_extracted_nodes(
             state.uuid_map[node.uuid] = node.uuid
 
     logger.debug(
-        "Resolved nodes: %s",
+        'Resolved nodes: %s',
         [(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
     )
 
-    new_node_duplicates: list[tuple[EntityNode, EntityNode]] = (
-        await filter_existing_duplicate_of_edges(driver, node_duplicates)
-    )
+    new_node_duplicates: list[
+        tuple[EntityNode, EntityNode]
+    ] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
 
     return (
         [node for node in state.resolved_nodes if node is not None],
@@ -399,9 +385,7 @@ async def extract_attributes_from_nodes(
                 episode,
                 previous_episodes,
                 (
-                    entity_types.get(
-                        next((item for item in node.labels if item != "Entity"), "")
-                    )
+                    entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
                     if entity_types is not None
                     else None
                 ),
@@ -425,32 +409,28 @@ async def extract_attributes_from_node(
     ensure_ascii: bool = False,
 ) -> EntityNode:
     node_context: dict[str, Any] = {
-        "name": node.name,
-        "summary": node.summary,
-        "entity_types": node.labels,
-        "attributes": node.attributes,
+        'name': node.name,
+        'summary': node.summary,
+        'entity_types': node.labels,
+        'attributes': node.attributes,
     }
 
     attributes_context: dict[str, Any] = {
-        "node": node_context,
-        "episode_content": episode.content if episode is not None else "",
-        "previous_episodes": (
-            [ep.content for ep in previous_episodes]
-            if previous_episodes is not None
-            else []
+        'node': node_context,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': (
+            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
         ),
-        "ensure_ascii": ensure_ascii,
+        'ensure_ascii': ensure_ascii,
     }
 
     summary_context: dict[str, Any] = {
-        "node": node_context,
-        "episode_content": episode.content if episode is not None else "",
-        "previous_episodes": (
-            [ep.content for ep in previous_episodes]
-            if previous_episodes is not None
-            else []
+        'node': node_context,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': (
+            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
        ),
-        "ensure_ascii": ensure_ascii,
+        'ensure_ascii': ensure_ascii,
     }
 
     has_entity_attributes: bool = bool(
@@ -478,7 +458,7 @@ async def extract_attributes_from_node(
     if has_entity_attributes and entity_type is not None:
         entity_type(**llm_response)
 
-    node.summary = summary_response.get("summary", "")
+    node.summary = summary_response.get('summary', '')
 
     node_attributes = {key: value for key, value in llm_response.items()}
     node.attributes.update(node_attributes)