diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 371c520f..a5a12c09 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -63,6 +63,7 @@ from graphiti_core.utils.maintenance.community_operations import (
     update_community,
 )
 from graphiti_core.utils.maintenance.edge_operations import (
+    build_duplicate_of_edges,
     build_episodic_edges,
     extract_edges,
     resolve_extracted_edge,
@@ -375,7 +376,7 @@ class Graphiti:
             )
 
             # Extract edges and resolve nodes
-            (nodes, uuid_map), extracted_edges = await semaphore_gather(
+            (nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
                 resolve_extracted_nodes(
                     self.clients,
                     extracted_nodes,
@@ -404,7 +405,9 @@
                 ),
             )
 
-            entity_edges = resolved_edges + invalidated_edges
+            duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
+
+            entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
 
             episodic_edges = build_episodic_edges(nodes, episode, now)
 
@@ -691,7 +694,7 @@ class Graphiti:
         if edge.fact_embedding is None:
             await edge.generate_embedding(self.embedder)
 
-        resolved_nodes, uuid_map = await resolve_extracted_nodes(
+        resolved_nodes, uuid_map, _ = await resolve_extracted_nodes(
             self.clients,
             [source_node, target_node],
         )
diff --git a/graphiti_core/prompts/dedupe_nodes.py b/graphiti_core/prompts/dedupe_nodes.py
index 16fee8d9..ed932355 100644
--- a/graphiti_core/prompts/dedupe_nodes.py
+++ b/graphiti_core/prompts/dedupe_nodes.py
@@ -26,12 +26,16 @@ class NodeDuplicate(BaseModel):
     id: int = Field(..., description='integer id of the entity')
     duplicate_idx: int = Field(
         ...,
-        description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
+        description='idx of the duplicate entity. If no duplicate entities are found, default to -1.',
     )
     name: str = Field(
         ...,
         description='Name of the entity. Should be the most complete and descriptive name possible.',
     )
+    additional_duplicates: list[int] = Field(
+        ...,
+        description='idx of additional duplicate entities. Use this list if the entity has multiple duplicates among existing entities.',
+    )
 
 
 class NodeResolutions(BaseModel):
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index 171c757b..729ae240 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -19,7 +19,9 @@ from datetime import datetime
 from time import time
 
 from pydantic import BaseModel
+from typing_extensions import LiteralString
 
+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.edges import (
     CommunityEdge,
     EntityEdge,
@@ -27,7 +29,7 @@ from graphiti_core.edges import (
     create_entity_edge_embeddings,
 )
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
+from graphiti_core.helpers import DEFAULT_DATABASE, MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
@@ -61,6 +63,28 @@ def build_episodic_edges(
     return episodic_edges
 
 
+def build_duplicate_of_edges(
+    episode: EpisodicNode,
+    created_at: datetime,
+    duplicate_nodes: list[tuple[EntityNode, EntityNode]],
+) -> list[EntityEdge]:
+    is_duplicate_of_edges: list[EntityEdge] = [
+        EntityEdge(
+            source_node_uuid=source_node.uuid,
+            target_node_uuid=target_node.uuid,
+            name='IS_DUPLICATE_OF',
+            group_id=episode.group_id,
+            fact=f'{source_node.name} is a duplicate of {target_node.name}',
+            episodes=[episode.uuid],
+            created_at=created_at,
+            valid_at=created_at,
+        )
+        for source_node, target_node in duplicate_nodes
+    ]
+
+    return is_duplicate_of_edges
+
+
 def build_community_edges(
     entity_nodes: list[EntityNode],
     community_node: CommunityNode,
@@ -570,3 +594,34 @@ async def dedupe_edge_list(
             unique_edges.append(edge)
 
     return unique_edges
+
+
+async def filter_existing_duplicate_of_edges(
+    driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
+) -> list[tuple[EntityNode, EntityNode]]:
+    query: LiteralString = """
+        UNWIND $duplicate_node_uuids AS duplicate_tuple
+        MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
+        RETURN DISTINCT
+            n.uuid AS source_uuid,
+            m.uuid AS target_uuid
+    """
+
+    duplicate_nodes_map = {
+        (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
+    }
+
+    records, _, _ = await driver.execute_query(
+        query,
+        duplicate_node_uuids=list(duplicate_nodes_map.keys()),
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    # Remove duplicates that already have the IS_DUPLICATE_OF edge
+    for record in records:
+        duplicate_tuple = (record.get('source_uuid'), record.get('target_uuid'))
+        if duplicate_nodes_map.get(duplicate_tuple):
+            duplicate_nodes_map.pop(duplicate_tuple)
+
+    return list(duplicate_nodes_map.values())
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 55e5d515..df466a56 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -40,6 +40,7 @@ from graphiti_core.search.search_config import SearchResults
 from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.utils.datetime_utils import utc_now
+from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges
 
 logger = logging.getLogger(__name__)
 
@@ -225,8 +226,9 @@ async def resolve_extracted_nodes(
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
     entity_types: dict[str, BaseModel] | None = None,
-) -> tuple[list[EntityNode], dict[str, str]]:
+) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
     llm_client = clients.llm_client
+    driver = clients.driver
 
     search_results: list[SearchResults] = await semaphore_gather(
         *[
@@ -295,9 +297,10 @@
 
     resolved_nodes: list[EntityNode] = []
     uuid_map: dict[str, str] = {}
+    node_duplicates: list[tuple[EntityNode, EntityNode]] = []
     for resolution in node_resolutions:
-        resolution_id = resolution.get('id', -1)
-        duplicate_idx = resolution.get('duplicate_idx', -1)
+        resolution_id: int = resolution.get('id', -1)
+        duplicate_idx: int = resolution.get('duplicate_idx', -1)
 
         extracted_node = extracted_nodes[resolution_id]
 
@@ -312,9 +315,21 @@
         resolved_nodes.append(resolved_node)
         uuid_map[extracted_node.uuid] = resolved_node.uuid
 
+        additional_duplicates: list[int] = resolution.get('additional_duplicates', [])
+        for idx in additional_duplicates:
+            existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
+            if existing_node == resolved_node:
+                continue
+
+            node_duplicates.append((resolved_node, existing_node))
+
     logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
 
-    return resolved_nodes, uuid_map
+    new_node_duplicates: list[
+        tuple[EntityNode, EntityNode]
+    ] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
+
+    return resolved_nodes, uuid_map, new_node_duplicates
 
 
 async def extract_attributes_from_nodes(
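Reviewer note: below is a minimal sketch of how the new pieces compose, mirroring the patched `add_episode` flow in `graphiti.py` above. The wrapper name `resolve_and_build_duplicate_edges` and its bare argument list are illustrative only; the three-element return of `resolve_extracted_nodes`, `build_duplicate_of_edges`, and the pre-filtering via `filter_existing_duplicate_of_edges` all come from this diff.

```python
from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.edge_operations import build_duplicate_of_edges
from graphiti_core.utils.maintenance.node_operations import resolve_extracted_nodes


async def resolve_and_build_duplicate_edges(
    clients: GraphitiClients,
    extracted_nodes: list[EntityNode],
    episode: EpisodicNode,
) -> tuple[list[EntityNode], dict[str, str], list[EntityEdge]]:
    # resolve_extracted_nodes now returns a third element: (resolved, existing)
    # node pairs the LLM flagged as duplicates. The pairs are already filtered
    # by filter_existing_duplicate_of_edges, so any pair that has an
    # IS_DUPLICATE_OF edge in the graph is not returned again.
    nodes, uuid_map, node_duplicates = await resolve_extracted_nodes(
        clients,
        extracted_nodes,
        episode=episode,
    )

    # Materialize one IS_DUPLICATE_OF entity edge per remaining pair; in
    # add_episode these are appended to entity_edges alongside the resolved
    # and invalidated edges.
    duplicate_of_edges = build_duplicate_of_edges(episode, utc_now(), node_duplicates)

    return nodes, uuid_map, duplicate_of_edges
```

Because the existence check runs inside `resolve_extracted_nodes`, callers only ever see duplicate pairs that still need an `IS_DUPLICATE_OF` edge, which keeps repeated ingestion of the same episode idempotent with respect to these edges.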