add IS_DUPLICATE_OF edges (#599)

* add IS_DUPLICATE_OF edges

* cypher query update

* robust handling
Preston Rasmussen 2025-06-17 11:56:55 -04:00 committed by GitHub
parent 0d6a76d891
commit e8bf81fc6b
4 changed files with 86 additions and 9 deletions

graphiti_core/graphiti.py

@@ -63,6 +63,7 @@ from graphiti_core.utils.maintenance.community_operations import (
     update_community,
 )
 from graphiti_core.utils.maintenance.edge_operations import (
+    build_duplicate_of_edges,
     build_episodic_edges,
     extract_edges,
     resolve_extracted_edge,
@@ -375,7 +376,7 @@ class Graphiti:
             )
 
             # Extract edges and resolve nodes
-            (nodes, uuid_map), extracted_edges = await semaphore_gather(
+            (nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
                 resolve_extracted_nodes(
                     self.clients,
                     extracted_nodes,
@@ -404,7 +405,9 @@ class Graphiti:
                 ),
             )
 
-            entity_edges = resolved_edges + invalidated_edges
+            duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
+
+            entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
 
             episodic_edges = build_episodic_edges(nodes, episode, now)
 
@@ -691,7 +694,7 @@ class Graphiti:
         if edge.fact_embedding is None:
             await edge.generate_embedding(self.embedder)
 
-        resolved_nodes, uuid_map = await resolve_extracted_nodes(
+        resolved_nodes, uuid_map, _ = await resolve_extracted_nodes(
             self.clients,
             [source_node, target_node],
         )
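
Taken together, the graphiti.py hunks thread a third return value, node_duplicates, out of resolve_extracted_nodes and hand it to build_duplicate_of_edges, so duplicate links are persisted alongside the resolved and invalidated edges. A minimal sketch of that flow with toy stand-ins (Node, resolve_nodes, and the case-insensitive matching rule are illustrative assumptions, not the library's API):

from dataclasses import dataclass, field
from uuid import uuid4


@dataclass
class Node:  # toy stand-in for graphiti_core's EntityNode
    name: str
    uuid: str = field(default_factory=lambda: str(uuid4()))


def resolve_nodes(extracted: list[Node], existing: list[Node]):
    """Toy resolver: a case-insensitive name match marks a duplicate pair."""
    by_name = {n.name.lower(): n for n in existing}
    resolved, uuid_map, duplicates = [], {}, []
    for node in extracted:
        match = by_name.get(node.name.lower())
        resolved_node = match or node
        resolved.append(resolved_node)
        uuid_map[node.uuid] = resolved_node.uuid
        if match is not None:
            duplicates.append((node, match))  # (new node, canonical node)
    return resolved, uuid_map, duplicates


existing = [Node('Alice Smith')]
nodes, uuid_map, node_duplicates = resolve_nodes([Node('alice smith'), Node('Bob')], existing)
assert len(node_duplicates) == 1  # this list is what build_duplicate_of_edges consumes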

graphiti_core/prompts/dedupe_nodes.py

@@ -26,12 +26,16 @@ class NodeDuplicate(BaseModel):
     id: int = Field(..., description='integer id of the entity')
     duplicate_idx: int = Field(
         ...,
-        description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
+        description='idx of the duplicate entity. If no duplicate entities are found, default to -1.',
     )
     name: str = Field(
         ...,
         description='Name of the entity. Should be the most complete and descriptive name possible.',
     )
+    additional_duplicates: list[int] = Field(
+        ...,
+        description='idx of additional duplicate entities. Use this list if the entity has multiple duplicates among existing entities.',
+    )
 
 
 class NodeResolutions(BaseModel):
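
The schema change above lets the LLM report several duplicates for one extracted entity instead of a single duplicate_idx. A quick sketch of a response validating against the updated model (a stand-in copy with abbreviated field descriptions, not the library's import):

from pydantic import BaseModel, Field


class NodeDuplicate(BaseModel):  # stand-in copy of the updated response model
    id: int = Field(..., description='integer id of the entity')
    duplicate_idx: int = Field(..., description='idx of the duplicate entity, or -1')
    name: str = Field(..., description='most complete name of the entity')
    additional_duplicates: list[int] = Field(..., description='idx of additional duplicates')


# An LLM naming one primary duplicate (idx 2) plus two more (idx 4 and 7):
resolution = NodeDuplicate.model_validate(
    {'id': 0, 'duplicate_idx': 2, 'name': 'Alice Smith', 'additional_duplicates': [4, 7]}
)
assert resolution.additional_duplicates == [4, 7]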

graphiti_core/utils/maintenance/edge_operations.py

@@ -19,7 +19,9 @@ from datetime import datetime
 from time import time
 
 from pydantic import BaseModel
+from typing_extensions import LiteralString
 
+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.edges import (
     CommunityEdge,
     EntityEdge,
@@ -27,7 +29,7 @@ from graphiti_core.edges import (
     create_entity_edge_embeddings,
 )
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
+from graphiti_core.helpers import DEFAULT_DATABASE, MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
@@ -61,6 +63,28 @@ def build_episodic_edges(
     return episodic_edges
 
 
+def build_duplicate_of_edges(
+    episode: EpisodicNode,
+    created_at: datetime,
+    duplicate_nodes: list[tuple[EntityNode, EntityNode]],
+) -> list[EntityEdge]:
+    is_duplicate_of_edges: list[EntityEdge] = [
+        EntityEdge(
+            source_node_uuid=source_node.uuid,
+            target_node_uuid=target_node.uuid,
+            name='IS_DUPLICATE_OF',
+            group_id=episode.group_id,
+            fact=f'{source_node.name} is a duplicate of {target_node.name}',
+            episodes=[episode.uuid],
+            created_at=created_at,
+            valid_at=created_at,
+        )
+        for source_node, target_node in duplicate_nodes
+    ]
+
+    return is_duplicate_of_edges
+
+
 def build_community_edges(
     entity_nodes: list[EntityNode],
     community_node: CommunityNode,
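
Note that build_duplicate_of_edges is a pure function over (episode, created_at, pairs): one IS_DUPLICATE_OF edge per pair, with valid_at pinned to the creation time. A hedged usage sketch with SimpleNamespace stand-ins (the real EntityNode and EpisodicNode models carry more required fields):

from datetime import datetime, timezone
from types import SimpleNamespace

now = datetime.now(timezone.utc)
episode = SimpleNamespace(uuid='ep-1', group_id='g1')
new_node = SimpleNamespace(uuid='n-2', name='alice smith')
canonical = SimpleNamespace(uuid='n-1', name='Alice Smith')

# build_duplicate_of_edges(episode, now, [(new_node, canonical)]) yields one
# EntityEdge whose fields are equivalent to:
expected = dict(
    source_node_uuid='n-2',
    target_node_uuid='n-1',
    name='IS_DUPLICATE_OF',
    group_id='g1',
    fact='alice smith is a duplicate of Alice Smith',
    episodes=['ep-1'],
    created_at=now,
    valid_at=now,
)
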
@@ -570,3 +594,34 @@ async def dedupe_edge_list(
             unique_edges.append(edge)
 
     return unique_edges
+
+
+async def filter_existing_duplicate_of_edges(
+    driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
+) -> list[tuple[EntityNode, EntityNode]]:
+    query: LiteralString = """
+        UNWIND $duplicate_node_uuids AS duplicate_tuple
+        MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
+        RETURN DISTINCT
+            n.uuid AS source_uuid,
+            m.uuid AS target_uuid
+    """
+
+    duplicate_nodes_map = {
+        (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
+    }
+
+    records, _, _ = await driver.execute_query(
+        query,
+        duplicate_node_uuids=list(duplicate_nodes_map.keys()),
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    # Remove duplicates that already have the IS_DUPLICATE_OF edge
+    for record in records:
+        duplicate_tuple = (record.get('source_uuid'), record.get('target_uuid'))
+        if duplicate_nodes_map.get(duplicate_tuple):
+            duplicate_nodes_map.pop(duplicate_tuple)
+
+    return list(duplicate_nodes_map.values())
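
The Cypher query only reports which candidate pairs already carry an IS_DUPLICATE_OF edge; the rest is dictionary bookkeeping. A minimal sketch of that filtering step with the database call stubbed out (filter_pairs and existing_pairs are illustrative names, the latter standing in for the query result):

def filter_pairs(candidate_pairs, existing_pairs):
    """candidate_pairs: (source, target) objects with a .uuid attribute;
    existing_pairs: set of (source_uuid, target_uuid) already in the graph."""
    pair_map = {(s.uuid, t.uuid): (s, t) for s, t in candidate_pairs}
    for key in existing_pairs:
        pair_map.pop(key, None)  # edge already exists; do not recreate it
    return list(pair_map.values())

Keying the map by uuid pairs also collapses repeated candidate pairs to a single entry before the query even runs.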

graphiti_core/utils/maintenance/node_operations.py

@@ -40,6 +40,7 @@ from graphiti_core.search.search_config import SearchResults
 from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.utils.datetime_utils import utc_now
+from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges
 
 
 logger = logging.getLogger(__name__)
@@ -225,8 +226,9 @@ async def resolve_extracted_nodes(
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
     entity_types: dict[str, BaseModel] | None = None,
-) -> tuple[list[EntityNode], dict[str, str]]:
+) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
     llm_client = clients.llm_client
+    driver = clients.driver
 
     search_results: list[SearchResults] = await semaphore_gather(
         *[
@@ -295,9 +297,10 @@
     resolved_nodes: list[EntityNode] = []
     uuid_map: dict[str, str] = {}
+    node_duplicates: list[tuple[EntityNode, EntityNode]] = []
 
     for resolution in node_resolutions:
-        resolution_id = resolution.get('id', -1)
-        duplicate_idx = resolution.get('duplicate_idx', -1)
+        resolution_id: int = resolution.get('id', -1)
+        duplicate_idx: int = resolution.get('duplicate_idx', -1)
 
         extracted_node = extracted_nodes[resolution_id]
 
@@ -312,9 +315,21 @@
         resolved_nodes.append(resolved_node)
         uuid_map[extracted_node.uuid] = resolved_node.uuid
 
+        additional_duplicates: list[int] = resolution.get('additional_duplicates', [])
+        for idx in additional_duplicates:
+            existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
+            if existing_node == resolved_node:
+                continue
+
+            node_duplicates.append((resolved_node, existing_nodes[idx]))
+
     logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
 
-    return resolved_nodes, uuid_map
+    new_node_duplicates: list[
+        tuple[EntityNode, EntityNode]
+    ] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
+
+    return resolved_nodes, uuid_map, new_node_duplicates
 
 
 async def extract_attributes_from_nodes(
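
The additional_duplicates loop guards against model hallucination: an out-of-range idx falls back to resolved_node itself, which the equality check then skips, so no IndexError can escape. A standalone sketch of that guard (the function name and argument shapes are illustrative):

def collect_additional_duplicates(resolution: dict, resolved_node, existing_nodes: list):
    pairs = []
    for idx in resolution.get('additional_duplicates', []):
        # out-of-range idx falls back to resolved_node and is skipped below
        existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
        if existing_node == resolved_node:
            continue  # never pair a node with itself
        pairs.append((resolved_node, existing_nodes[idx]))
    return pairs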