add IS_DUPLICATE_OF edges (#599)
* add IS_DUPLICATE_OF edges
* cypher query update
* robust handling
This commit is contained in:
parent 0d6a76d891
commit e8bf81fc6b
4 changed files with 86 additions and 9 deletions

graphiti_core/graphiti.py

@@ -63,6 +63,7 @@ from graphiti_core.utils.maintenance.community_operations import (
     update_community,
 )
 from graphiti_core.utils.maintenance.edge_operations import (
+    build_duplicate_of_edges,
     build_episodic_edges,
     extract_edges,
     resolve_extracted_edge,

@@ -375,7 +376,7 @@ class Graphiti:
         )

         # Extract edges and resolve nodes
-        (nodes, uuid_map), extracted_edges = await semaphore_gather(
+        (nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
             resolve_extracted_nodes(
                 self.clients,
                 extracted_nodes,

@@ -404,7 +405,9 @@ class Graphiti:
             ),
         )

-        entity_edges = resolved_edges + invalidated_edges
+        duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
+
+        entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges

         episodic_edges = build_episodic_edges(nodes, episode, now)

@@ -691,7 +694,7 @@ class Graphiti:
         if edge.fact_embedding is None:
             await edge.generate_embedding(self.embedder)

-        resolved_nodes, uuid_map = await resolve_extracted_nodes(
+        resolved_nodes, uuid_map, _ = await resolve_extracted_nodes(
             self.clients,
             [source_node, target_node],
         )

graphiti_core/prompts/dedupe_nodes.py

@@ -26,12 +26,16 @@ class NodeDuplicate(BaseModel):
     id: int = Field(..., description='integer id of the entity')
     duplicate_idx: int = Field(
         ...,
-        description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
+        description='idx of the duplicate entity. If no duplicate entities are found, default to -1.',
     )
     name: str = Field(
         ...,
         description='Name of the entity. Should be the most complete and descriptive name possible.',
     )
+    additional_duplicates: list[int] = Field(
+        ...,
+        description='idx of additional duplicate entities. Use this list if the entity has multiple duplicates among existing entities.',
+    )


 class NodeResolutions(BaseModel):
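
For context, a minimal sketch of the resolution entry this updated model describes, once parsed into the plain dict that resolve_extracted_nodes consumes (values are made up for illustration):

# Hypothetical LLM output for one extracted entity (illustration only):
resolution = {
    'id': 0,                          # index of the extracted entity
    'duplicate_idx': 2,               # best duplicate among existing entities, or -1
    'name': 'Alice Smith',            # most complete name for the entity
    'additional_duplicates': [4, 7],  # further duplicate indices, [] if none
}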

graphiti_core/utils/maintenance/edge_operations.py

@@ -19,7 +19,9 @@ from datetime import datetime
 from time import time

 from pydantic import BaseModel
+from typing_extensions import LiteralString

+from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.edges import (
     CommunityEdge,
     EntityEdge,

@@ -27,7 +29,7 @@ from graphiti_core.edges import (
     create_entity_edge_embeddings,
 )
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
+from graphiti_core.helpers import DEFAULT_DATABASE, MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode

@@ -61,6 +63,28 @@ def build_episodic_edges(
     return episodic_edges


+def build_duplicate_of_edges(
+    episode: EpisodicNode,
+    created_at: datetime,
+    duplicate_nodes: list[tuple[EntityNode, EntityNode]],
+) -> list[EntityEdge]:
+    is_duplicate_of_edges: list[EntityEdge] = [
+        EntityEdge(
+            source_node_uuid=source_node.uuid,
+            target_node_uuid=target_node.uuid,
+            name='IS_DUPLICATE_OF',
+            group_id=episode.group_id,
+            fact=f'{source_node.name} is a duplicate of {target_node.name}',
+            episodes=[episode.uuid],
+            created_at=created_at,
+            valid_at=created_at,
+        )
+        for source_node, target_node in duplicate_nodes
+    ]
+
+    return is_duplicate_of_edges
+
+
 def build_community_edges(
     entity_nodes: list[EntityNode],
     community_node: CommunityNode,
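
Illustrative usage from the caller's side (not part of the diff; the variable names are hypothetical, with episode and now assumed to come from the add_episode flow shown above):

# (source, target) pairs come from node resolution:
duplicate_of_edges = build_duplicate_of_edges(
    episode,
    now,
    [(alice_resolved, alice_existing)],
)
# Each edge is named 'IS_DUPLICATE_OF' and carries a fact such as
# "Alice Smith is a duplicate of Alice S."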

@@ -570,3 +594,34 @@ async def dedupe_edge_list(
             unique_edges.append(edge)

     return unique_edges
+
+
+async def filter_existing_duplicate_of_edges(
+    driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
+) -> list[tuple[EntityNode, EntityNode]]:
+    query: LiteralString = """
+        UNWIND $duplicate_node_uuids AS duplicate_tuple
+        MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
+        RETURN DISTINCT
+            n.uuid AS source_uuid,
+            m.uuid AS target_uuid
+    """
+
+    duplicate_nodes_map = {
+        (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
+    }
+
+    records, _, _ = await driver.execute_query(
+        query,
+        duplicate_node_uuids=list(duplicate_nodes_map.keys()),
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    # Remove duplicates that already have the IS_DUPLICATE_OF edge
+    for record in records:
+        duplicate_tuple = (record.get('source_uuid'), record.get('target_uuid'))
+        if duplicate_nodes_map.get(duplicate_tuple):
+            duplicate_nodes_map.pop(duplicate_tuple)
+
+    return list(duplicate_nodes_map.values())
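
The filtering itself is plain dict bookkeeping: pairs the query reports as already linked are dropped from the map, and whatever remains still needs an edge. A self-contained sketch with made-up data:

# Stand-ins for (EntityNode, EntityNode) tuples keyed by their uuids:
duplicate_nodes_map = {('u1', 'u2'): 'pair-1', ('u3', 'u4'): 'pair-2'}
records = [{'source_uuid': 'u1', 'target_uuid': 'u2'}]  # already linked in the graph
for record in records:
    duplicate_nodes_map.pop((record['source_uuid'], record['target_uuid']), None)
assert list(duplicate_nodes_map.values()) == ['pair-2']  # only unlinked pairs remain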

graphiti_core/utils/maintenance/node_operations.py

@@ -40,6 +40,7 @@ from graphiti_core.search.search_config import SearchResults
 from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.utils.datetime_utils import utc_now
+from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges

 logger = logging.getLogger(__name__)

@@ -225,8 +226,9 @@ async def resolve_extracted_nodes(
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
     entity_types: dict[str, BaseModel] | None = None,
-) -> tuple[list[EntityNode], dict[str, str]]:
+) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
     llm_client = clients.llm_client
+    driver = clients.driver

     search_results: list[SearchResults] = await semaphore_gather(
         *[

@@ -295,9 +297,10 @@ async def resolve_extracted_nodes(

     resolved_nodes: list[EntityNode] = []
     uuid_map: dict[str, str] = {}
+    node_duplicates: list[tuple[EntityNode, EntityNode]] = []
     for resolution in node_resolutions:
-        resolution_id = resolution.get('id', -1)
-        duplicate_idx = resolution.get('duplicate_idx', -1)
+        resolution_id: int = resolution.get('id', -1)
+        duplicate_idx: int = resolution.get('duplicate_idx', -1)

         extracted_node = extracted_nodes[resolution_id]


@@ -312,9 +315,21 @@ async def resolve_extracted_nodes(
         resolved_nodes.append(resolved_node)
         uuid_map[extracted_node.uuid] = resolved_node.uuid

+        additional_duplicates: list[int] = resolution.get('additional_duplicates', [])
+        for idx in additional_duplicates:
+            existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
+            if existing_node == resolved_node:
+                continue
+
+            node_duplicates.append((resolved_node, existing_nodes[idx]))
+
     logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')

-    return resolved_nodes, uuid_map
+    new_node_duplicates: list[
+        tuple[EntityNode, EntityNode]
+    ] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
+
+    return resolved_nodes, uuid_map, new_node_duplicates


 async def extract_attributes_from_nodes(
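
Taken together, a caller-side sketch (not part of the diff; argument names follow the add_episode hunks above):

resolved_nodes, uuid_map, node_duplicates = await resolve_extracted_nodes(
    clients, extracted_nodes
)
# node_duplicates holds only pairs without an existing IS_DUPLICATE_OF edge,
# so building edges from it cannot re-create duplicates:
entity_edges += build_duplicate_of_edges(episode, utc_now(), node_duplicates)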