Node dedupe efficiency (#490)

* update resolve extracted edge

* updated edge resolution

* dedupe nodes update

* single pass node resolution

* updates

* mypy updates

* Update graphiti_core/prompts/dedupe_nodes.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* remove unused imports

* mypy

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Preston Rasmussen 2025-05-15 13:56:33 -04:00 committed by GitHub
parent f096c8770c
commit 9422b6f5fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 236 additions and 141 deletions

View file

@ -41,6 +41,7 @@ from graphiti_core.search.search_config_recipes import (
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_edge_invalidation_candidates,
get_mentioned_nodes,
get_relevant_edges,
)
@ -62,9 +63,8 @@ from graphiti_core.utils.maintenance.community_operations import (
)
from graphiti_core.utils.maintenance.edge_operations import (
build_episodic_edges,
dedupe_extracted_edge,
extract_edges,
resolve_edge_contradictions,
resolve_extracted_edge,
resolve_extracted_edges,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
@ -77,7 +77,6 @@ from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
logger = logging.getLogger(__name__)
@ -681,17 +680,15 @@ class Graphiti:
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8)
related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
existing_edges = (
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
)[0]
resolved_edge = await dedupe_extracted_edge(
self.llm_client,
updated_edge,
related_edges[0],
resolved_edge, invalidated_edges = await resolve_extracted_edge(
self.llm_client, updated_edge, related_edges, existing_edges
)
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
await add_nodes_and_edges_bulk(
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
)

View file

@ -87,8 +87,8 @@ def normalize_l2(embedding: list[float]) -> NDArray:
# Use this instead of asyncio.gather() to bound coroutines
async def semaphore_gather(
*coroutines: Coroutine,
max_coroutines: int = SEMAPHORE_LIMIT,
*coroutines: Coroutine,
max_coroutines: int = SEMAPHORE_LIMIT,
):
semaphore = asyncio.Semaphore(max_coroutines)

View file

@ -27,6 +27,10 @@ class EdgeDuplicate(BaseModel):
...,
description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
)
contradicted_facts: list[int] = Field(
...,
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
)
class UniqueFact(BaseModel):
@ -41,11 +45,13 @@ class UniqueFacts(BaseModel):
class Prompt(Protocol):
edge: PromptVersion
edge_list: PromptVersion
resolve_edge: PromptVersion
class Versions(TypedDict):
edge: PromptFunction
edge_list: PromptFunction
resolve_edge: PromptFunction
def edge(context: dict[str, Any]) -> list[Message]:
@ -106,4 +112,41 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
]
versions: Versions = {'edge': edge, 'edge_list': edge_list}
# Prompt builder for single-pass edge resolution. The LLM is asked to do two
# things in one call: (a) pick the duplicate among EXISTING FACTS (or -1) and
# (b) list which FACT INVALIDATION CANDIDATES the new fact contradicts.
# The response is parsed elsewhere via the EdgeDuplicate model, which carries
# `duplicate_fact_id` and `contradicted_facts` (see the dedupe_edges hunk above).
# context requires keys: 'new_edge' (str), 'existing_edges' (list of
# {'id', 'fact'} dicts), 'edge_invalidation_candidates' (list of {'id', 'fact'}).
# NOTE(review): the prompt tells the model to "return the idx", but the caller
# (resolve_extracted_edge) builds existing_edges entries with 'id': edge.uuid,
# not an integer index, and later uses the returned value as a list index —
# confirm the context ids are integers, this looks like a bug in the caller.
# NOTE(review): leading indentation of this span was lost in the diff
# rendering; code tokens below are preserved byte-identical. No comments are
# inserted past this point because the f-string body is runtime text.
def resolve_edge(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
'facts are contradicted by the new fact.',
),
Message(
role='user',
content=f"""
<NEW FACT>
{context['new_edge']}
</NEW FACT>
<EXISTING FACTS>
{context['existing_edges']}
</EXISTING FACTS>
<FACT INVALIDATION CANDIDATES>
{context['edge_invalidation_candidates']}
</FACT INVALIDATION CANDIDATES>
Task:
If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
If there are no contradicted facts, return an empty list.
Guidelines:
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
""",
),
]
versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}

View file

@ -23,21 +23,31 @@ from .models import Message, PromptFunction, PromptVersion
class NodeDuplicate(BaseModel):
duplicate_node_id: int = Field(
id: int = Field(..., description='integer id of the entity')
duplicate_idx: int = Field(
...,
description='id of the duplicate node. If no duplicate nodes are found, default to -1.',
description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
)
name: str = Field(..., description='Name of the entity.')
name: str = Field(
...,
description='Name of the entity. Should be the most complete and descriptive name possible.',
)
class NodeResolutions(BaseModel):
entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
class Prompt(Protocol):
node: PromptVersion
node_list: PromptVersion
nodes: PromptVersion
class Versions(TypedDict):
node: PromptFunction
node_list: PromptFunction
nodes: PromptFunction
def node(context: dict[str, Any]) -> list[Message]:
@ -89,6 +99,67 @@ def node(context: dict[str, Any]) -> list[Message]:
]
def nodes(context: dict[str, Any]) -> list[Message]:
    """Build the single-pass bulk node-dedupe prompt.

    Expects context keys: 'previous_episodes' (list of episode content
    strings), 'episode_content' (str), and 'extracted_nodes' (list of dicts
    with id, name, entity_type, entity_type_description and
    duplication_candidates). The model answers with a NodeResolutions
    payload: one entry per entity carrying id, name, and duplicate_idx
    (-1 when the entity matches no candidate).
    """
    return [
        Message(
            role='system',
            # Fix: the two fragments previously concatenated to
            # "duplicatesof existing entities" (missing separator space).
            content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates '
            'of existing entities.',
        ),
        Message(
            role='user',
            # The embedded prompt text is runtime output; its lines are kept
            # byte-identical except the duplicate_idx wording fix flagged below.
            content=f"""
<PREVIOUS MESSAGES>
{json.dumps(list(context['previous_episodes']), indent=2)}
</PREVIOUS MESSAGES>
<CURRENT MESSAGE>
{context['episode_content']}
</CURRENT MESSAGE>
Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
Each entity in ENTITIES is represented as a JSON object with the following structure:
{{
id: integer id of the entity,
name: "name of the entity",
entity_type: "ontological classification of the entity",
entity_type_description: "Description of what the entity type represents",
duplication_candidates: [
{{
idx: integer index of the candidate entity,
name: "name of the candidate entity",
entity_type: "ontological classification of the candidate entity",
...<additional attributes>
}}
]
}}
<ENTITIES>
{json.dumps(context['extracted_nodes'], indent=2)}
</ENTITIES>
For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
Do NOT mark entities as duplicates if:
- They are related but distinct.
- They have similar names or purposes but refer to separate instances or concepts.
Task:
Your response will be a list called entity_resolutions which contains one entry for each entity.
For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
as an integer.
- If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
duplicate of.
- If an entity is not a duplicate of one of its duplication candidates, return -1 as the duplicate_idx
""",
        ),
    ]
def node_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
@ -126,4 +197,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
]
versions: Versions = {'node': node, 'node_list': node_list}
versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}

View file

@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion
class InvalidatedEdges(BaseModel):
contradicted_facts: list[int] = Field(
...,
description='List of ids of facts that be should invalidated. If no facts should be invalidated, the list should be empty.',
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
)

View file

@ -35,9 +35,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
from graphiti_core.utils.maintenance.temporal_operations import (
get_edge_contradictions,
)
logger = logging.getLogger(__name__)
@ -245,7 +242,7 @@ async def resolve_extracted_edges(
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
get_relevant_edges(driver, extracted_edges, SearchFilters()),
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
)
related_edges_lists, edge_invalidation_candidates = search_results
@ -325,11 +322,52 @@ async def resolve_extracted_edge(
extracted_edge: EntityEdge,
related_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
episode: EpisodicNode,
episode: EpisodicNode | None = None,
) -> tuple[EntityEdge, list[EntityEdge]]:
resolved_edge, invalidation_candidates = await semaphore_gather(
dedupe_extracted_edge(llm_client, extracted_edge, related_edges, episode),
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
if len(related_edges) == 0 and len(existing_edges) == 0:
return extracted_edge, []
start = time()
# Prepare context for LLM
related_edges_context = [
{'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
]
invalidation_edge_candidates_context = [
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
]
context = {
'existing_edges': related_edges_context,
'new_edge': extracted_edge.fact,
'edge_invalidation_candidates': invalidation_edge_candidates_context,
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_edges.resolve_edge(context),
response_model=EdgeDuplicate,
model_size=ModelSize.small,
)
duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
resolved_edge = (
related_edges[duplicate_fact_id]
if 0 <= duplicate_fact_id < len(related_edges)
else extracted_edge
)
if duplicate_fact_id >= 0 and episode is not None:
resolved_edge.episodes.append(episode.uuid)
contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
end = time()
logger.debug(
f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
)
now = utc_now()

View file

@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
from graphiti_core.prompts.extract_nodes import (
ExtractedEntities,
ExtractedEntity,
@ -243,28 +243,65 @@ async def resolve_extracted_nodes(
existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
resolved_nodes: list[EntityNode] = await semaphore_gather(
*[
resolve_extracted_node(
llm_client,
extracted_node,
existing_nodes,
episode,
previous_episodes,
entity_types.get(
next((item for item in extracted_node.labels if item != 'Entity'), '')
)
if entity_types is not None
else None,
)
for extracted_node, existing_nodes in zip(
extracted_nodes, existing_nodes_lists, strict=True
)
]
entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
# Prepare context for LLM
extracted_nodes_context = [
{
'id': i,
'name': node.name,
'entity_type': node.labels,
'entity_type_description': entity_types_dict.get(
next((item for item in node.labels if item != 'Entity'), '')
).__doc__
or 'Default Entity Type',
'duplication_candidates': [
{
**{
'idx': j,
'name': candidate.name,
'entity_types': candidate.labels,
},
**candidate.attributes,
}
for j, candidate in enumerate(existing_nodes_lists[i])
],
}
for i, node in enumerate(extracted_nodes)
]
context = {
'extracted_nodes': extracted_nodes_context,
'episode_content': episode.content if episode is not None else '',
'previous_episodes': [ep.content for ep in previous_episodes]
if previous_episodes is not None
else [],
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_nodes.nodes(context),
response_model=NodeResolutions,
)
node_resolutions: list = llm_response.get('entity_resolutions', [])
resolved_nodes: list[EntityNode] = []
uuid_map: dict[str, str] = {}
for extracted_node, resolved_node in zip(extracted_nodes, resolved_nodes, strict=True):
for resolution in node_resolutions:
resolution_id = resolution.get('id', -1)
duplicate_idx = resolution.get('duplicate_idx', -1)
extracted_node = extracted_nodes[resolution_id]
resolved_node = (
existing_nodes_lists[resolution_id][duplicate_idx]
if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
else extracted_node
)
resolved_node.name = resolution.get('name')
resolved_nodes.append(resolved_node)
uuid_map[extracted_node.uuid] = resolved_node.uuid
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
@ -410,6 +447,7 @@ async def extract_attributes_from_node(
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_attributes(summary_context),
response_model=entity_attributes_model,
model_size=ModelSize.small,
)
node.summary = llm_response.get('summary', node.summary)

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.11.6pre9"
version = "0.11.6"
authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },

View file

@ -1,12 +1,10 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock
import pytest
from pytest import MonkeyPatch
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodicNode
from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge
@pytest.fixture
@ -91,96 +89,6 @@ def mock_previous_episodes():
]
# Verifies that resolve_extracted_edge returns the extracted edge untouched
# and no invalidations when dedupe finds no duplicate and no contradictions
# exist.
# NOTE(review): this test monkeypatches dedupe_extracted_edge and
# get_edge_contradictions inside edge_operations; this commit removes those
# call sites (resolve_extracted_edge now drives the LLM directly), which is
# why the diff deletes this test. Kept byte-identical below as a record.
@pytest.mark.asyncio
async def test_resolve_extracted_edge_no_changes(
mock_llm_client,
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
mock_previous_episodes,
monkeypatch: MonkeyPatch,
):
# Mock the function calls
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
get_contradictions_mock = AsyncMock(return_value=[])
# Patch the function calls
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
get_contradictions_mock,
)
resolved_edge, invalidated_edges = await resolve_extracted_edge(
mock_llm_client,
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
)
# No duplicate and no contradictions: the input edge passes through unchanged.
assert resolved_edge.uuid == mock_extracted_edge.uuid
assert invalidated_edges == []
# Both collaborators must have been consulted exactly once.
dedupe_mock.assert_called_once()
get_contradictions_mock.assert_called_once()
# Verifies contradiction handling: when get_edge_contradictions reports a
# candidate, resolve_extracted_edge must return it invalidated — invalid_at
# set from the new edge's valid_at and expired_at stamped.
# NOTE(review): like the test above, this patches functions this commit
# removes from edge_operations, hence its deletion. Kept byte-identical.
@pytest.mark.asyncio
async def test_resolve_extracted_edge_with_invalidation(
mock_llm_client,
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
mock_previous_episodes,
monkeypatch: MonkeyPatch,
):
# New fact became valid yesterday; the candidate it contradicts has been
# valid since two days ago, so the candidate should be closed at valid_at.
valid_at = datetime.now(timezone.utc) - timedelta(days=1)
mock_extracted_edge.valid_at = valid_at
invalidation_candidate = EntityEdge(
source_node_uuid='source_uuid_4',
target_node_uuid='target_uuid_4',
name='invalidation_candidate',
group_id='group_1',
fact='Invalidation candidate fact',
episodes=['episode_4'],
created_at=datetime.now(timezone.utc),
valid_at=datetime.now(timezone.utc) - timedelta(days=2),
invalid_at=None,
)
# Mock the function calls
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
get_contradictions_mock = AsyncMock(return_value=[invalidation_candidate])
# Patch the function calls
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
get_contradictions_mock,
)
resolved_edge, invalidated_edges = await resolve_extracted_edge(
mock_llm_client,
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
)
assert resolved_edge.uuid == mock_extracted_edge.uuid
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == invalidation_candidate.uuid
# Contradicted edge is closed at the new fact's valid_at and marked expired
# (behavior of the non-mocked invalidation step — confirm in edge_operations).
assert invalidated_edges[0].invalid_at == valid_at
assert invalidated_edges[0].expired_at is not None
# Run the tests
if __name__ == '__main__':
pytest.main([__file__])