add IS_DUPLICATE_OF edges (#599)
* Add IS_DUPLICATE_OF edges; update the Cypher query; add robust handling.
This commit is contained in:
parent
0d6a76d891
commit
e8bf81fc6b
4 changed files with 86 additions and 9 deletions
|
|
@ -63,6 +63,7 @@ from graphiti_core.utils.maintenance.community_operations import (
|
||||||
update_community,
|
update_community,
|
||||||
)
|
)
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
|
build_duplicate_of_edges,
|
||||||
build_episodic_edges,
|
build_episodic_edges,
|
||||||
extract_edges,
|
extract_edges,
|
||||||
resolve_extracted_edge,
|
resolve_extracted_edge,
|
||||||
|
|
@ -375,7 +376,7 @@ class Graphiti:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract edges and resolve nodes
|
# Extract edges and resolve nodes
|
||||||
(nodes, uuid_map), extracted_edges = await semaphore_gather(
|
(nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
|
||||||
resolve_extracted_nodes(
|
resolve_extracted_nodes(
|
||||||
self.clients,
|
self.clients,
|
||||||
extracted_nodes,
|
extracted_nodes,
|
||||||
|
|
@ -404,7 +405,9 @@ class Graphiti:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
entity_edges = resolved_edges + invalidated_edges
|
duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
|
||||||
|
|
||||||
|
entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
|
||||||
|
|
||||||
episodic_edges = build_episodic_edges(nodes, episode, now)
|
episodic_edges = build_episodic_edges(nodes, episode, now)
|
||||||
|
|
||||||
|
|
@ -691,7 +694,7 @@ class Graphiti:
|
||||||
if edge.fact_embedding is None:
|
if edge.fact_embedding is None:
|
||||||
await edge.generate_embedding(self.embedder)
|
await edge.generate_embedding(self.embedder)
|
||||||
|
|
||||||
resolved_nodes, uuid_map = await resolve_extracted_nodes(
|
resolved_nodes, uuid_map, _ = await resolve_extracted_nodes(
|
||||||
self.clients,
|
self.clients,
|
||||||
[source_node, target_node],
|
[source_node, target_node],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -26,12 +26,16 @@ class NodeDuplicate(BaseModel):
|
||||||
id: int = Field(..., description='integer id of the entity')
|
id: int = Field(..., description='integer id of the entity')
|
||||||
duplicate_idx: int = Field(
|
duplicate_idx: int = Field(
|
||||||
...,
|
...,
|
||||||
description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
|
description='idx of the duplicate entity. If no duplicate entities are found, default to -1.',
|
||||||
)
|
)
|
||||||
name: str = Field(
|
name: str = Field(
|
||||||
...,
|
...,
|
||||||
description='Name of the entity. Should be the most complete and descriptive name possible.',
|
description='Name of the entity. Should be the most complete and descriptive name possible.',
|
||||||
)
|
)
|
||||||
|
additional_duplicates: list[int] = Field(
|
||||||
|
...,
|
||||||
|
description='idx of additional duplicate entities. Use this list if the entity has multiple duplicates among existing entities.',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NodeResolutions(BaseModel):
|
class NodeResolutions(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,9 @@ from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import LiteralString
|
||||||
|
|
||||||
|
from graphiti_core.driver.driver import GraphDriver
|
||||||
from graphiti_core.edges import (
|
from graphiti_core.edges import (
|
||||||
CommunityEdge,
|
CommunityEdge,
|
||||||
EntityEdge,
|
EntityEdge,
|
||||||
|
|
@ -27,7 +29,7 @@ from graphiti_core.edges import (
|
||||||
create_entity_edge_embeddings,
|
create_entity_edge_embeddings,
|
||||||
)
|
)
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
from graphiti_core.helpers import DEFAULT_DATABASE, MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.llm_client.config import ModelSize
|
from graphiti_core.llm_client.config import ModelSize
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||||
|
|
@ -61,6 +63,28 @@ def build_episodic_edges(
|
||||||
return episodic_edges
|
return episodic_edges
|
||||||
|
|
||||||
|
|
||||||
|
def build_duplicate_of_edges(
    episode: EpisodicNode,
    created_at: datetime,
    duplicate_nodes: list[tuple[EntityNode, EntityNode]],
) -> list[EntityEdge]:
    """Build an IS_DUPLICATE_OF entity edge for each (source, target) duplicate pair.

    Every edge is attributed to the given episode and uses *created_at* for both
    its creation timestamp and its validity start.
    """
    duplicate_edges: list[EntityEdge] = []
    for source_node, target_node in duplicate_nodes:
        duplicate_edges.append(
            EntityEdge(
                source_node_uuid=source_node.uuid,
                target_node_uuid=target_node.uuid,
                name='IS_DUPLICATE_OF',
                group_id=episode.group_id,
                fact=f'{source_node.name} is a duplicate of {target_node.name}',
                episodes=[episode.uuid],
                created_at=created_at,
                valid_at=created_at,
            )
        )

    return duplicate_edges
|
||||||
|
|
||||||
|
|
||||||
def build_community_edges(
|
def build_community_edges(
|
||||||
entity_nodes: list[EntityNode],
|
entity_nodes: list[EntityNode],
|
||||||
community_node: CommunityNode,
|
community_node: CommunityNode,
|
||||||
|
|
@ -570,3 +594,34 @@ async def dedupe_edge_list(
|
||||||
unique_edges.append(edge)
|
unique_edges.append(edge)
|
||||||
|
|
||||||
return unique_edges
|
return unique_edges
|
||||||
|
|
||||||
|
|
||||||
|
async def filter_existing_duplicate_of_edges(
    driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
) -> list[tuple[EntityNode, EntityNode]]:
    """Filter out duplicate-node pairs that already have an IS_DUPLICATE_OF edge.

    Args:
        driver: Graph database driver used to run the read-only lookup query.
        duplicates_node_tuples: Candidate (source, target) duplicate node pairs.

    Returns:
        The subset of input pairs for which no IS_DUPLICATE_OF edge exists yet
        in the graph.
    """
    # Avoid a pointless database round-trip when there are no candidates.
    if not duplicates_node_tuples:
        return []

    query: LiteralString = """
        UNWIND $duplicate_node_uuids AS duplicate_tuple
        MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
        RETURN DISTINCT
            n.uuid AS source_uuid,
            m.uuid AS target_uuid
    """

    # Index candidate pairs by their uuid tuple so query rows map back to nodes.
    duplicate_nodes_map = {
        (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
    }

    records, _, _ = await driver.execute_query(
        query,
        duplicate_node_uuids=list(duplicate_nodes_map.keys()),
        database_=DEFAULT_DATABASE,
        routing_='r',  # read-only routing
    )

    # Remove pairs that already have the IS_DUPLICATE_OF edge. pop() with a
    # default replaces the previous get-then-pop double lookup and tolerates
    # rows that do not match any candidate key.
    for record in records:
        duplicate_nodes_map.pop((record.get('source_uuid'), record.get('target_uuid')), None)

    return list(duplicate_nodes_map.values())
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ from graphiti_core.search.search_config import SearchResults
|
||||||
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
|
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
|
||||||
from graphiti_core.search.search_filters import SearchFilters
|
from graphiti_core.search.search_filters import SearchFilters
|
||||||
from graphiti_core.utils.datetime_utils import utc_now
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
|
from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -225,8 +226,9 @@ async def resolve_extracted_nodes(
|
||||||
episode: EpisodicNode | None = None,
|
episode: EpisodicNode | None = None,
|
||||||
previous_episodes: list[EpisodicNode] | None = None,
|
previous_episodes: list[EpisodicNode] | None = None,
|
||||||
entity_types: dict[str, BaseModel] | None = None,
|
entity_types: dict[str, BaseModel] | None = None,
|
||||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
|
||||||
llm_client = clients.llm_client
|
llm_client = clients.llm_client
|
||||||
|
driver = clients.driver
|
||||||
|
|
||||||
search_results: list[SearchResults] = await semaphore_gather(
|
search_results: list[SearchResults] = await semaphore_gather(
|
||||||
*[
|
*[
|
||||||
|
|
@ -295,9 +297,10 @@ async def resolve_extracted_nodes(
|
||||||
|
|
||||||
resolved_nodes: list[EntityNode] = []
|
resolved_nodes: list[EntityNode] = []
|
||||||
uuid_map: dict[str, str] = {}
|
uuid_map: dict[str, str] = {}
|
||||||
|
node_duplicates: list[tuple[EntityNode, EntityNode]] = []
|
||||||
for resolution in node_resolutions:
|
for resolution in node_resolutions:
|
||||||
resolution_id = resolution.get('id', -1)
|
resolution_id: int = resolution.get('id', -1)
|
||||||
duplicate_idx = resolution.get('duplicate_idx', -1)
|
duplicate_idx: int = resolution.get('duplicate_idx', -1)
|
||||||
|
|
||||||
extracted_node = extracted_nodes[resolution_id]
|
extracted_node = extracted_nodes[resolution_id]
|
||||||
|
|
||||||
|
|
@ -312,9 +315,21 @@ async def resolve_extracted_nodes(
|
||||||
resolved_nodes.append(resolved_node)
|
resolved_nodes.append(resolved_node)
|
||||||
uuid_map[extracted_node.uuid] = resolved_node.uuid
|
uuid_map[extracted_node.uuid] = resolved_node.uuid
|
||||||
|
|
||||||
|
additional_duplicates: list[int] = resolution.get('additional_duplicates', [])
|
||||||
|
for idx in additional_duplicates:
|
||||||
|
existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
|
||||||
|
if existing_node == resolved_node:
|
||||||
|
continue
|
||||||
|
|
||||||
|
node_duplicates.append((resolved_node, existing_nodes[idx]))
|
||||||
|
|
||||||
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
|
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
|
||||||
|
|
||||||
return resolved_nodes, uuid_map
|
new_node_duplicates: list[
|
||||||
|
tuple[EntityNode, EntityNode]
|
||||||
|
] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
|
||||||
|
|
||||||
|
return resolved_nodes, uuid_map, new_node_duplicates
|
||||||
|
|
||||||
|
|
||||||
async def extract_attributes_from_nodes(
|
async def extract_attributes_from_nodes(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue