From 9422b6f5fb38178bd41dce49fb41e0c752c746ef Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Thu, 15 May 2025 13:56:33 -0400
Subject: [PATCH] Node dedupe efficiency (#490)
* update resolve extracted edge
* updated edge resolution
* dedupe nodes update
* single pass node resolution
* updates
* mypy updates
* Update graphiti_core/prompts/dedupe_nodes.py
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* remove unused imports
* mypy
---------
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
---
graphiti_core/graphiti.py | 19 ++--
graphiti_core/helpers.py | 4 +-
graphiti_core/prompts/dedupe_edges.py | 45 ++++++++-
graphiti_core/prompts/dedupe_nodes.py | 79 +++++++++++++++-
graphiti_core/prompts/invalidate_edges.py | 2 +-
.../utils/maintenance/edge_operations.py | 54 +++++++++--
.../utils/maintenance/node_operations.py | 78 +++++++++++----
pyproject.toml | 2 +-
.../utils/maintenance/test_edge_operations.py | 94 +------------------
9 files changed, 236 insertions(+), 141 deletions(-)
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 4153a684..5417728f 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -41,6 +41,7 @@ from graphiti_core.search.search_config_recipes import (
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
+ get_edge_invalidation_candidates,
get_mentioned_nodes,
get_relevant_edges,
)
@@ -62,9 +63,8 @@ from graphiti_core.utils.maintenance.community_operations import (
)
from graphiti_core.utils.maintenance.edge_operations import (
build_episodic_edges,
- dedupe_extracted_edge,
extract_edges,
- resolve_edge_contradictions,
+ resolve_extracted_edge,
resolve_extracted_edges,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
@@ -77,7 +77,6 @@ from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
resolve_extracted_nodes,
)
-from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
logger = logging.getLogger(__name__)
@@ -681,17 +680,15 @@ class Graphiti:
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
- related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
+ related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
+ existing_edges = (
+ await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
+ )[0]
- resolved_edge = await dedupe_extracted_edge(
- self.llm_client,
- updated_edge,
- related_edges[0],
+ resolved_edge, invalidated_edges = await resolve_extracted_edge(
+ self.llm_client, updated_edge, related_edges, existing_edges
)
- contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
- invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
-
await add_nodes_and_edges_bulk(
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
)
diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py
index 7c6175d4..21c388c2 100644
--- a/graphiti_core/helpers.py
+++ b/graphiti_core/helpers.py
@@ -87,8 +87,8 @@ def normalize_l2(embedding: list[float]) -> NDArray:
# Use this instead of asyncio.gather() to bound coroutines
async def semaphore_gather(
- *coroutines: Coroutine,
- max_coroutines: int = SEMAPHORE_LIMIT,
+ *coroutines: Coroutine,
+ max_coroutines: int = SEMAPHORE_LIMIT,
):
semaphore = asyncio.Semaphore(max_coroutines)
diff --git a/graphiti_core/prompts/dedupe_edges.py b/graphiti_core/prompts/dedupe_edges.py
index 5354f3cc..f63011d4 100644
--- a/graphiti_core/prompts/dedupe_edges.py
+++ b/graphiti_core/prompts/dedupe_edges.py
@@ -27,6 +27,10 @@ class EdgeDuplicate(BaseModel):
...,
description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
)
+ contradicted_facts: list[int] = Field(
+ ...,
+ description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
+ )
class UniqueFact(BaseModel):
@@ -41,11 +45,13 @@ class UniqueFacts(BaseModel):
class Prompt(Protocol):
edge: PromptVersion
edge_list: PromptVersion
+ resolve_edge: PromptVersion
class Versions(TypedDict):
edge: PromptFunction
edge_list: PromptFunction
+ resolve_edge: PromptFunction
def edge(context: dict[str, Any]) -> list[Message]:
@@ -106,4 +112,41 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
]
-versions: Versions = {'edge': edge, 'edge_list': edge_list}
+def resolve_edge(context: dict[str, Any]) -> list[Message]:
+ return [
+ Message(
+ role='system',
+ content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
+ 'facts are contradicted by the new fact.',
+ ),
+ Message(
+ role='user',
+ content=f"""
+
+ {context['new_edge']}
+
+
+
+ {context['existing_edges']}
+
+
+ {context['edge_invalidation_candidates']}
+
+
+
+ Task:
+ If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the id of the duplicate fact.
+ If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
+
+ Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
+ Return a list containing the ids of all facts that are contradicted by the NEW FACT.
+ If there are no contradicted facts, return an empty list.
+
+ Guidelines:
+ 1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
+ """,
+ ),
+ ]
+
+
+versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}
diff --git a/graphiti_core/prompts/dedupe_nodes.py b/graphiti_core/prompts/dedupe_nodes.py
index 1cac6b79..318d4c9f 100644
--- a/graphiti_core/prompts/dedupe_nodes.py
+++ b/graphiti_core/prompts/dedupe_nodes.py
@@ -23,21 +23,31 @@ from .models import Message, PromptFunction, PromptVersion
class NodeDuplicate(BaseModel):
- duplicate_node_id: int = Field(
+ id: int = Field(..., description='integer id of the entity')
+ duplicate_idx: int = Field(
...,
- description='id of the duplicate node. If no duplicate nodes are found, default to -1.',
+ description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
)
- name: str = Field(..., description='Name of the entity.')
+ name: str = Field(
+ ...,
+ description='Name of the entity. Should be the most complete and descriptive name possible.',
+ )
+
+
+class NodeResolutions(BaseModel):
+ entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
class Prompt(Protocol):
node: PromptVersion
node_list: PromptVersion
+ nodes: PromptVersion
class Versions(TypedDict):
node: PromptFunction
node_list: PromptFunction
+ nodes: PromptFunction
def node(context: dict[str, Any]) -> list[Message]:
@@ -89,6 +99,67 @@ def node(context: dict[str, Any]) -> list[Message]:
]
+def nodes(context: dict[str, Any]) -> list[Message]:
+ return [
+ Message(
+ role='system',
+ content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates '
+ 'of existing entities.',
+ ),
+ Message(
+ role='user',
+ content=f"""
+
+ {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
+
+
+ {context['episode_content']}
+
+
+
+ Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
+ Each entity in ENTITIES is represented as a JSON object with the following structure:
+ {{
+ id: integer id of the entity,
+ name: "name of the entity",
+ entity_type: "ontological classification of the entity",
+ entity_type_description: "Description of what the entity type represents",
+ duplication_candidates: [
+ {{
+ idx: integer index of the candidate entity,
+ name: "name of the candidate entity",
+ entity_type: "ontological classification of the candidate entity",
+ ...
+ }}
+ ]
+ }}
+
+
+ {json.dumps(context['extracted_nodes'], indent=2)}
+
+
+ For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
+
+ Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
+
+ Do NOT mark entities as duplicates if:
+ - They are related but distinct.
+ - They have similar names or purposes but refer to separate instances or concepts.
+
+ Task:
+ Your response will be a list called entity_resolutions which contains one entry for each entity.
+
+ For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
+ as an integer.
+
+ - If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
+ duplicate of.
+ - If an entity is not a duplicate of any of its duplication candidates, return -1 as the duplicate_idx
+ """,
+ ),
+ ]
+
+
def node_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
@@ -126,4 +197,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
]
-versions: Versions = {'node': node, 'node_list': node_list}
+versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
diff --git a/graphiti_core/prompts/invalidate_edges.py b/graphiti_core/prompts/invalidate_edges.py
index f30048a5..f5342ed3 100644
--- a/graphiti_core/prompts/invalidate_edges.py
+++ b/graphiti_core/prompts/invalidate_edges.py
@@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion
class InvalidatedEdges(BaseModel):
contradicted_facts: list[int] = Field(
...,
- description='List of ids of facts that be should invalidated. If no facts should be invalidated, the list should be empty.',
+ description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
)
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index 64973015..d90fba52 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -35,9 +35,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
-from graphiti_core.utils.maintenance.temporal_operations import (
- get_edge_contradictions,
-)
logger = logging.getLogger(__name__)
@@ -245,7 +242,7 @@ async def resolve_extracted_edges(
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
get_relevant_edges(driver, extracted_edges, SearchFilters()),
- get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
+ get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
)
related_edges_lists, edge_invalidation_candidates = search_results
@@ -325,11 +322,52 @@ async def resolve_extracted_edge(
extracted_edge: EntityEdge,
related_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
- episode: EpisodicNode,
+ episode: EpisodicNode | None = None,
) -> tuple[EntityEdge, list[EntityEdge]]:
- resolved_edge, invalidation_candidates = await semaphore_gather(
- dedupe_extracted_edge(llm_client, extracted_edge, related_edges, episode),
- get_edge_contradictions(llm_client, extracted_edge, existing_edges),
+ if len(related_edges) == 0 and len(existing_edges) == 0:
+ return extracted_edge, []
+
+ start = time()
+
+ # Prepare context for LLM
+ related_edges_context = [
+ {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
+ ]
+
+ invalidation_edge_candidates_context = [
+ {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
+ ]
+
+ context = {
+ 'existing_edges': related_edges_context,
+ 'new_edge': extracted_edge.fact,
+ 'edge_invalidation_candidates': invalidation_edge_candidates_context,
+ }
+
+ llm_response = await llm_client.generate_response(
+ prompt_library.dedupe_edges.resolve_edge(context),
+ response_model=EdgeDuplicate,
+ model_size=ModelSize.small,
+ )
+
+ duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
+
+ resolved_edge = (
+ related_edges[duplicate_fact_id]
+ if 0 <= duplicate_fact_id < len(related_edges)
+ else extracted_edge
+ )
+
+ if duplicate_fact_id >= 0 and episode is not None:
+ resolved_edge.episodes.append(episode.uuid)
+
+ contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
+
+ invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
+
+ end = time()
+ logger.debug(
+ f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
)
now = utc_now()
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index f25746bb..2b3de99e 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
+from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
from graphiti_core.prompts.extract_nodes import (
ExtractedEntities,
ExtractedEntity,
@@ -243,28 +243,65 @@ async def resolve_extracted_nodes(
existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
- resolved_nodes: list[EntityNode] = await semaphore_gather(
- *[
- resolve_extracted_node(
- llm_client,
- extracted_node,
- existing_nodes,
- episode,
- previous_episodes,
- entity_types.get(
- next((item for item in extracted_node.labels if item != 'Entity'), '')
- )
- if entity_types is not None
- else None,
- )
- for extracted_node, existing_nodes in zip(
- extracted_nodes, existing_nodes_lists, strict=True
- )
- ]
+ entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
+
+ # Prepare context for LLM
+ extracted_nodes_context = [
+ {
+ 'id': i,
+ 'name': node.name,
+ 'entity_type': node.labels,
+ 'entity_type_description': entity_types_dict.get(
+ next((item for item in node.labels if item != 'Entity'), '')
+ ).__doc__
+ or 'Default Entity Type',
+ 'duplication_candidates': [
+ {
+ **{
+ 'idx': j,
+ 'name': candidate.name,
+ 'entity_types': candidate.labels,
+ },
+ **candidate.attributes,
+ }
+ for j, candidate in enumerate(existing_nodes_lists[i])
+ ],
+ }
+ for i, node in enumerate(extracted_nodes)
+ ]
+
+ context = {
+ 'extracted_nodes': extracted_nodes_context,
+ 'episode_content': episode.content if episode is not None else '',
+ 'previous_episodes': [ep.content for ep in previous_episodes]
+ if previous_episodes is not None
+ else [],
+ }
+
+ llm_response = await llm_client.generate_response(
+ prompt_library.dedupe_nodes.nodes(context),
+ response_model=NodeResolutions,
)
+ node_resolutions: list = llm_response.get('entity_resolutions', [])
+
+ resolved_nodes: list[EntityNode] = []
uuid_map: dict[str, str] = {}
- for extracted_node, resolved_node in zip(extracted_nodes, resolved_nodes, strict=True):
+ for resolution in node_resolutions:
+ resolution_id = resolution.get('id', -1)
+ duplicate_idx = resolution.get('duplicate_idx', -1)
+
+ extracted_node = extracted_nodes[resolution_id]
+
+ resolved_node = (
+ existing_nodes_lists[resolution_id][duplicate_idx]
+ if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
+ else extracted_node
+ )
+
+ resolved_node.name = resolution.get('name')
+
+ resolved_nodes.append(resolved_node)
uuid_map[extracted_node.uuid] = resolved_node.uuid
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
@@ -410,6 +447,7 @@ async def extract_attributes_from_node(
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_attributes(summary_context),
response_model=entity_attributes_model,
+ model_size=ModelSize.small,
)
node.summary = llm_response.get('summary', node.summary)
diff --git a/pyproject.toml b/pyproject.toml
index 2da24413..6623627a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
-version = "0.11.6pre9"
+version = "0.11.6"
authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
diff --git a/tests/utils/maintenance/test_edge_operations.py b/tests/utils/maintenance/test_edge_operations.py
index 3145b74d..cdb1de9f 100644
--- a/tests/utils/maintenance/test_edge_operations.py
+++ b/tests/utils/maintenance/test_edge_operations.py
@@ -1,12 +1,10 @@
from datetime import datetime, timedelta, timezone
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import MagicMock
import pytest
-from pytest import MonkeyPatch
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodicNode
-from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge
@pytest.fixture
@@ -91,96 +89,6 @@ def mock_previous_episodes():
]
-@pytest.mark.asyncio
-async def test_resolve_extracted_edge_no_changes(
- mock_llm_client,
- mock_extracted_edge,
- mock_related_edges,
- mock_existing_edges,
- mock_current_episode,
- mock_previous_episodes,
- monkeypatch: MonkeyPatch,
-):
- # Mock the function calls
- dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
- get_contradictions_mock = AsyncMock(return_value=[])
-
- # Patch the function calls
- monkeypatch.setattr(
- 'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
- )
- monkeypatch.setattr(
- 'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
- get_contradictions_mock,
- )
-
- resolved_edge, invalidated_edges = await resolve_extracted_edge(
- mock_llm_client,
- mock_extracted_edge,
- mock_related_edges,
- mock_existing_edges,
- mock_current_episode,
- )
-
- assert resolved_edge.uuid == mock_extracted_edge.uuid
- assert invalidated_edges == []
- dedupe_mock.assert_called_once()
- get_contradictions_mock.assert_called_once()
-
-
-@pytest.mark.asyncio
-async def test_resolve_extracted_edge_with_invalidation(
- mock_llm_client,
- mock_extracted_edge,
- mock_related_edges,
- mock_existing_edges,
- mock_current_episode,
- mock_previous_episodes,
- monkeypatch: MonkeyPatch,
-):
- valid_at = datetime.now(timezone.utc) - timedelta(days=1)
- mock_extracted_edge.valid_at = valid_at
-
- invalidation_candidate = EntityEdge(
- source_node_uuid='source_uuid_4',
- target_node_uuid='target_uuid_4',
- name='invalidation_candidate',
- group_id='group_1',
- fact='Invalidation candidate fact',
- episodes=['episode_4'],
- created_at=datetime.now(timezone.utc),
- valid_at=datetime.now(timezone.utc) - timedelta(days=2),
- invalid_at=None,
- )
-
- # Mock the function calls
- dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
- get_contradictions_mock = AsyncMock(return_value=[invalidation_candidate])
-
- # Patch the function calls
- monkeypatch.setattr(
- 'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
- )
- monkeypatch.setattr(
- 'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
- get_contradictions_mock,
- )
-
- resolved_edge, invalidated_edges = await resolve_extracted_edge(
- mock_llm_client,
- mock_extracted_edge,
- mock_related_edges,
- mock_existing_edges,
- mock_current_episode,
- )
-
- assert resolved_edge.uuid == mock_extracted_edge.uuid
- assert len(invalidated_edges) == 1
- assert invalidated_edges[0].uuid == invalidation_candidate.uuid
- assert invalidated_edges[0].invalid_at == valid_at
- assert invalidated_edges[0].expired_at is not None
-
-
# Run the tests
if __name__ == '__main__':
pytest.main([__file__])