Node dedupe efficiency (#490)

* update resolve extracted edge

* updated edge resolution

* dedupe nodes update

* single pass node resolution

* updates

* mypy updates

* Update graphiti_core/prompts/dedupe_nodes.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* remove unused imports

* mypy

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Preston Rasmussen 2025-05-15 13:56:33 -04:00 committed by GitHub
parent f096c8770c
commit 9422b6f5fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 236 additions and 141 deletions

View file

@ -41,6 +41,7 @@ from graphiti_core.search.search_config_recipes import (
from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import ( from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT, RELEVANT_SCHEMA_LIMIT,
get_edge_invalidation_candidates,
get_mentioned_nodes, get_mentioned_nodes,
get_relevant_edges, get_relevant_edges,
) )
@ -62,9 +63,8 @@ from graphiti_core.utils.maintenance.community_operations import (
) )
from graphiti_core.utils.maintenance.edge_operations import ( from graphiti_core.utils.maintenance.edge_operations import (
build_episodic_edges, build_episodic_edges,
dedupe_extracted_edge,
extract_edges, extract_edges,
resolve_edge_contradictions, resolve_extracted_edge,
resolve_extracted_edges, resolve_extracted_edges,
) )
from graphiti_core.utils.maintenance.graph_data_operations import ( from graphiti_core.utils.maintenance.graph_data_operations import (
@ -77,7 +77,6 @@ from graphiti_core.utils.maintenance.node_operations import (
extract_nodes, extract_nodes,
resolve_extracted_nodes, resolve_extracted_nodes,
) )
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -681,17 +680,15 @@ class Graphiti:
updated_edge = resolve_edge_pointers([edge], uuid_map)[0] updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8) related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
existing_edges = (
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
)[0]
resolved_edge = await dedupe_extracted_edge( resolved_edge, invalidated_edges = await resolve_extracted_edge(
self.llm_client, self.llm_client, updated_edge, related_edges, existing_edges
updated_edge,
related_edges[0],
) )
contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0])
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
await add_nodes_and_edges_bulk( await add_nodes_and_edges_bulk(
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
) )

View file

@ -87,8 +87,8 @@ def normalize_l2(embedding: list[float]) -> NDArray:
# Use this instead of asyncio.gather() to bound coroutines # Use this instead of asyncio.gather() to bound coroutines
async def semaphore_gather( async def semaphore_gather(
*coroutines: Coroutine, *coroutines: Coroutine,
max_coroutines: int = SEMAPHORE_LIMIT, max_coroutines: int = SEMAPHORE_LIMIT,
): ):
semaphore = asyncio.Semaphore(max_coroutines) semaphore = asyncio.Semaphore(max_coroutines)

View file

@ -27,6 +27,10 @@ class EdgeDuplicate(BaseModel):
..., ...,
description='id of the duplicate fact. If no duplicate facts are found, default to -1.', description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
) )
contradicted_facts: list[int] = Field(
...,
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
)
class UniqueFact(BaseModel): class UniqueFact(BaseModel):
@ -41,11 +45,13 @@ class UniqueFacts(BaseModel):
class Prompt(Protocol): class Prompt(Protocol):
edge: PromptVersion edge: PromptVersion
edge_list: PromptVersion edge_list: PromptVersion
resolve_edge: PromptVersion
class Versions(TypedDict): class Versions(TypedDict):
edge: PromptFunction edge: PromptFunction
edge_list: PromptFunction edge_list: PromptFunction
resolve_edge: PromptFunction
def edge(context: dict[str, Any]) -> list[Message]: def edge(context: dict[str, Any]) -> list[Message]:
@ -106,4 +112,41 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
] ]
def resolve_edge(context: dict[str, Any]) -> list[Message]:
    """Build the prompt that de-duplicates a new fact and flags contradicted facts.

    Args:
        context: Prompt inputs. The keys ``new_edge``, ``existing_edges`` and
            ``edge_invalidation_candidates`` are interpolated into the user
            message as-is.

    Returns:
        A two-element list: the system message followed by the user message.
    """
    system_message = Message(
        role='system',
        content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
        'facts are contradicted by the new fact.',
    )

    # The prompt text starts at column 0 on purpose: any leading whitespace
    # inside the triple-quoted literal would become part of the prompt.
    user_message = Message(
        role='user',
        content=f"""
<NEW FACT>
{context['new_edge']}
</NEW FACT>
<EXISTING FACTS>
{context['existing_edges']}
</EXISTING FACTS>
<FACT INVALIDATION CANDIDATES>
{context['edge_invalidation_candidates']}
</FACT INVALIDATION CANDIDATES>
Task:
If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact.
If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1.
Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
If there are no contradicted facts, return an empty list.
Guidelines:
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
""",
    )

    return [system_message, user_message]
versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}

View file

@ -23,21 +23,31 @@ from .models import Message, PromptFunction, PromptVersion
class NodeDuplicate(BaseModel): class NodeDuplicate(BaseModel):
duplicate_node_id: int = Field( id: int = Field(..., description='integer id of the entity')
duplicate_idx: int = Field(
..., ...,
description='id of the duplicate node. If no duplicate nodes are found, default to -1.', description='idx of the duplicate node. If no duplicate nodes are found, default to -1.',
) )
name: str = Field(..., description='Name of the entity.') name: str = Field(
...,
description='Name of the entity. Should be the most complete and descriptive name possible.',
)
# Response schema holding a list of NodeDuplicate entries under the key
# `entity_resolutions`, i.e. the resolved nodes returned for one request.
class NodeResolutions(BaseModel):
entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
class Prompt(Protocol): class Prompt(Protocol):
node: PromptVersion node: PromptVersion
node_list: PromptVersion node_list: PromptVersion
nodes: PromptVersion
class Versions(TypedDict): class Versions(TypedDict):
node: PromptFunction node: PromptFunction
node_list: PromptFunction node_list: PromptFunction
nodes: PromptFunction
def node(context: dict[str, Any]) -> list[Message]: def node(context: dict[str, Any]) -> list[Message]:
@ -89,6 +99,67 @@ def node(context: dict[str, Any]) -> list[Message]:
] ]
def nodes(context: dict[str, Any]) -> list[Message]:
    """Build the prompt that resolves all extracted entities in one request.

    Each entity carries its own list of duplication candidates, and the model
    is asked to return one resolution per entity in a list called
    ``entity_resolutions``.

    Args:
        context: Prompt inputs. Reads ``previous_episodes`` and
            ``extracted_nodes`` (both serialised with ``json.dumps``) and
            ``episode_content`` (interpolated as-is).

    Returns:
        A two-element list: the system message followed by the user message.

    Raises:
        KeyError: If one of the three context keys is missing.
        TypeError: If ``previous_episodes`` or ``extracted_nodes`` is not
            JSON-serialisable.
    """
    return [
        Message(
            role='system',
            # Adjacent literals are concatenated verbatim, so the first one
            # must end with a space to keep "duplicates of" as two words.
            content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates '
            'of existing entities.',
        ),
        Message(
            role='user',
            # The prompt text starts at column 0 on purpose: any leading
            # whitespace inside the literal would become part of the prompt.
            content=f"""
<PREVIOUS MESSAGES>
{json.dumps(list(context['previous_episodes']), indent=2)}
</PREVIOUS MESSAGES>
<CURRENT MESSAGE>
{context['episode_content']}
</CURRENT MESSAGE>
Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
Each entity in ENTITIES is represented as a JSON object with the following structure:
{{
id: integer id of the entity,
name: "name of the entity",
entity_type: "ontological classification of the entity",
entity_type_description: "Description of what the entity type represents",
duplication_candidates: [
{{
idx: integer index of the candidate entity,
name: "name of the candidate entity",
entity_type: "ontological classification of the candidate entity",
...<additional attributes>
}}
]
}}
<ENTITIES>
{json.dumps(context['extracted_nodes'], indent=2)}
</ENTITIES>
For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
Do NOT mark entities as duplicates if:
- They are related but distinct.
- They have similar names or purposes but refer to separate instances or concepts.
Task:
Your response will be a list called entity_resolutions which contains one entry for each entity.
For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
as an integer.
- If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
duplicate of.
- If an entity is not a duplicate of one of its duplication candidates, return -1 as the duplicate_idx
""",
        ),
    ]
def node_list(context: dict[str, Any]) -> list[Message]: def node_list(context: dict[str, Any]) -> list[Message]:
return [ return [
Message( Message(
@ -126,4 +197,4 @@ def node_list(context: dict[str, Any]) -> list[Message]:
] ]
versions: Versions = {'node': node, 'node_list': node_list} versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}

View file

@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion
class InvalidatedEdges(BaseModel): class InvalidatedEdges(BaseModel):
contradicted_facts: list[int] = Field( contradicted_facts: list[int] = Field(
..., ...,
description='List of ids of facts that be should invalidated. If no facts should be invalidated, the list should be empty.', description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
) )

View file

@ -35,9 +35,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
from graphiti_core.utils.maintenance.temporal_operations import (
get_edge_contradictions,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -245,7 +242,7 @@ async def resolve_extracted_edges(
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather( search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
get_relevant_edges(driver, extracted_edges, SearchFilters()), get_relevant_edges(driver, extracted_edges, SearchFilters()),
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()), get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
) )
related_edges_lists, edge_invalidation_candidates = search_results related_edges_lists, edge_invalidation_candidates = search_results
@ -325,11 +322,52 @@ async def resolve_extracted_edge(
extracted_edge: EntityEdge, extracted_edge: EntityEdge,
related_edges: list[EntityEdge], related_edges: list[EntityEdge],
existing_edges: list[EntityEdge], existing_edges: list[EntityEdge],
episode: EpisodicNode, episode: EpisodicNode | None = None,
) -> tuple[EntityEdge, list[EntityEdge]]: ) -> tuple[EntityEdge, list[EntityEdge]]:
resolved_edge, invalidation_candidates = await semaphore_gather( if len(related_edges) == 0 and len(existing_edges) == 0:
dedupe_extracted_edge(llm_client, extracted_edge, related_edges, episode), return extracted_edge, []
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
start = time()
# Prepare context for LLM
related_edges_context = [
{'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
]
invalidation_edge_candidates_context = [
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
]
context = {
'existing_edges': related_edges_context,
'new_edge': extracted_edge.fact,
'edge_invalidation_candidates': invalidation_edge_candidates_context,
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_edges.resolve_edge(context),
response_model=EdgeDuplicate,
model_size=ModelSize.small,
)
duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
resolved_edge = (
related_edges[duplicate_fact_id]
if 0 <= duplicate_fact_id < len(related_edges)
else extracted_edge
)
if duplicate_fact_id >= 0 and episode is not None:
resolved_edge.episodes.append(episode.uuid)
contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
end = time()
logger.debug(
f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
) )
now = utc_now() now = utc_now()

View file

@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.prompts import prompt_library from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
from graphiti_core.prompts.extract_nodes import ( from graphiti_core.prompts.extract_nodes import (
ExtractedEntities, ExtractedEntities,
ExtractedEntity, ExtractedEntity,
@ -243,28 +243,65 @@ async def resolve_extracted_nodes(
existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results] existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
resolved_nodes: list[EntityNode] = await semaphore_gather( entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
*[
resolve_extracted_node( # Prepare context for LLM
llm_client, extracted_nodes_context = [
extracted_node, {
existing_nodes, 'id': i,
episode, 'name': node.name,
previous_episodes, 'entity_type': node.labels,
entity_types.get( 'entity_type_description': entity_types_dict.get(
next((item for item in extracted_node.labels if item != 'Entity'), '') next((item for item in node.labels if item != 'Entity'), '')
) ).__doc__
if entity_types is not None or 'Default Entity Type',
else None, 'duplication_candidates': [
) {
for extracted_node, existing_nodes in zip( **{
extracted_nodes, existing_nodes_lists, strict=True 'idx': j,
) 'name': candidate.name,
] 'entity_types': candidate.labels,
},
**candidate.attributes,
}
for j, candidate in enumerate(existing_nodes_lists[i])
],
}
for i, node in enumerate(extracted_nodes)
]
context = {
'extracted_nodes': extracted_nodes_context,
'episode_content': episode.content if episode is not None else '',
'previous_episodes': [ep.content for ep in previous_episodes]
if previous_episodes is not None
else [],
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_nodes.nodes(context),
response_model=NodeResolutions,
) )
node_resolutions: list = llm_response.get('entity_resolutions', [])
resolved_nodes: list[EntityNode] = []
uuid_map: dict[str, str] = {} uuid_map: dict[str, str] = {}
for extracted_node, resolved_node in zip(extracted_nodes, resolved_nodes, strict=True): for resolution in node_resolutions:
resolution_id = resolution.get('id', -1)
duplicate_idx = resolution.get('duplicate_idx', -1)
extracted_node = extracted_nodes[resolution_id]
resolved_node = (
existing_nodes_lists[resolution_id][duplicate_idx]
if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
else extracted_node
)
resolved_node.name = resolution.get('name')
resolved_nodes.append(resolved_node)
uuid_map[extracted_node.uuid] = resolved_node.uuid uuid_map[extracted_node.uuid] = resolved_node.uuid
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}') logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
@ -410,6 +447,7 @@ async def extract_attributes_from_node(
llm_response = await llm_client.generate_response( llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_attributes(summary_context), prompt_library.extract_nodes.extract_attributes(summary_context),
response_model=entity_attributes_model, response_model=entity_attributes_model,
model_size=ModelSize.small,
) )
node.summary = llm_response.get('summary', node.summary) node.summary = llm_response.get('summary', node.summary)

View file

@ -1,7 +1,7 @@
[project] [project]
name = "graphiti-core" name = "graphiti-core"
description = "A temporal graph building library" description = "A temporal graph building library"
version = "0.11.6pre9" version = "0.11.6"
authors = [ authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },

View file

@ -1,12 +1,10 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock from unittest.mock import MagicMock
import pytest import pytest
from pytest import MonkeyPatch
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodicNode from graphiti_core.nodes import EpisodicNode
from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge
@pytest.fixture @pytest.fixture
@ -91,96 +89,6 @@ def mock_previous_episodes():
] ]
@pytest.mark.asyncio
async def test_resolve_extracted_edge_no_changes(
mock_llm_client,
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
mock_previous_episodes,
monkeypatch: MonkeyPatch,
):
# Mock the function calls
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
get_contradictions_mock = AsyncMock(return_value=[])
# Patch the function calls
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
get_contradictions_mock,
)
resolved_edge, invalidated_edges = await resolve_extracted_edge(
mock_llm_client,
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
)
assert resolved_edge.uuid == mock_extracted_edge.uuid
assert invalidated_edges == []
dedupe_mock.assert_called_once()
get_contradictions_mock.assert_called_once()
@pytest.mark.asyncio
async def test_resolve_extracted_edge_with_invalidation(
mock_llm_client,
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
mock_previous_episodes,
monkeypatch: MonkeyPatch,
):
valid_at = datetime.now(timezone.utc) - timedelta(days=1)
mock_extracted_edge.valid_at = valid_at
invalidation_candidate = EntityEdge(
source_node_uuid='source_uuid_4',
target_node_uuid='target_uuid_4',
name='invalidation_candidate',
group_id='group_1',
fact='Invalidation candidate fact',
episodes=['episode_4'],
created_at=datetime.now(timezone.utc),
valid_at=datetime.now(timezone.utc) - timedelta(days=2),
invalid_at=None,
)
# Mock the function calls
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
get_contradictions_mock = AsyncMock(return_value=[invalidation_candidate])
# Patch the function calls
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
get_contradictions_mock,
)
resolved_edge, invalidated_edges = await resolve_extracted_edge(
mock_llm_client,
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
)
assert resolved_edge.uuid == mock_extracted_edge.uuid
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == invalidation_candidate.uuid
assert invalidated_edges[0].invalid_at == valid_at
assert invalidated_edges[0].expired_at is not None
# Run the tests # Run the tests
if __name__ == '__main__': if __name__ == '__main__':
pytest.main([__file__]) pytest.main([__file__])