Refactor deduplication logic to enhance node resolution and track duplicate pairs (#929)

* Simplify deduplication process in bulk_utils by reusing canonical nodes.
* Update dedup_helpers to store duplicate pairs during resolution.
* Modify node_operations to append duplicate pairs when resolving nodes.
* Add tests to verify deduplication behavior and ensure correct state updates.
Daniel Chalef 2025-09-25 14:32:03 -07:00
parent 1e56019027
commit 18550c84ac
5 changed files with 155 additions and 75 deletions
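
For orientation, a minimal sketch of the flow the reworked dedupe_nodes_bulk follows: resolve each episode's extracted nodes against the canonical nodes accepted so far, collect (duplicate, canonical) uuid pairs, and map every uuid onto its final canonical uuid. The toy resolve_episode() and the bare (uuid, name) tuples below are illustrative stand-ins, not the library API.

def resolve_episode(nodes, existing):
    # Toy resolution: a name already seen collapses onto the earlier canonical entry.
    resolved, pairs = [], []
    for uuid, name in nodes:
        match = next((c for c in existing if c[1] == name), None)
        if match is None:
            resolved.append((uuid, name))
        else:
            resolved.append(match)
            pairs.append((uuid, match[0]))   # (duplicate uuid, canonical uuid)
    return resolved, pairs

episodes = {'ep1': [('u1', 'Alice Smith')], 'ep2': [('u2', 'Alice Smith')]}
canonical, duplicate_pairs, nodes_by_episode = [], [], {}
for episode, nodes in episodes.items():
    resolved, pairs = resolve_episode(nodes, canonical)
    canonical += [node for node in resolved if node not in canonical]
    duplicate_pairs += pairs
    nodes_by_episode[episode] = resolved

print(nodes_by_episode)   # both episodes reference ('u1', 'Alice Smith')
print(duplicate_pairs)    # [('u2', 'u1')]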


@@ -43,7 +43,7 @@ from graphiti_core.models.nodes.node_db_queries import (
get_entity_node_save_bulk_query,
get_episode_node_save_bulk_query,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
from graphiti_core.utils.maintenance.edge_operations import (
extract_edges,
@@ -266,83 +266,66 @@ async def dedupe_nodes_bulk(
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
entity_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
embedder = clients.embedder
min_score = 0.8
# generate embeddings
await semaphore_gather(
*[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
)
# Find similar results
dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
for i, nodes_i in enumerate(extracted_nodes):
existing_nodes: list[EntityNode] = []
for j, nodes_j in enumerate(extracted_nodes):
if i == j:
continue
existing_nodes += nodes_j
candidates_i: list[EntityNode] = []
for node in nodes_i:
for existing_node in existing_nodes:
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
# This approach will cast a wider net than BM25, which is ideal for this use case
node_words = set(node.name.lower().split())
existing_node_words = set(existing_node.name.lower().split())
has_overlap = not node_words.isdisjoint(existing_node_words)
if has_overlap:
candidates_i.append(existing_node)
continue
# Check for semantic similarity even if there is no overlap
similarity = np.dot(
normalize_l2(node.name_embedding or []),
normalize_l2(existing_node.name_embedding or []),
)
if similarity >= min_score:
candidates_i.append(existing_node)
dedupe_tuples.append((nodes_i, candidates_i))
# Determine Node Resolutions
bulk_node_resolutions: list[
tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
] = await semaphore_gather(
*[
resolve_extracted_nodes(
clients,
dedupe_tuple[0],
episode_tuples[i][0],
episode_tuples[i][1],
entity_types,
existing_nodes_override=dedupe_tuples[i][1],
)
for i, dedupe_tuple in enumerate(dedupe_tuples)
]
)
# Collect all duplicate pairs sorted by uuid
canonical_nodes: dict[str, EntityNode] = {}
episode_resolutions: list[tuple[str, list[EntityNode]]] = []
per_episode_uuid_maps: list[dict[str, str]] = []
duplicate_pairs: list[tuple[str, str]] = []
for _, _, duplicates in bulk_node_resolutions:
for duplicate in duplicates:
n, m = duplicate
duplicate_pairs.append((n.uuid, m.uuid))
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
for nodes, (episode, previous_episodes) in zip(extracted_nodes, episode_tuples, strict=True):
existing_override = list(canonical_nodes.values()) if canonical_nodes else None
node_uuid_map: dict[str, EntityNode] = {
node.uuid: node for nodes in extracted_nodes for node in nodes
}
resolved_nodes, uuid_map, duplicates = await resolve_extracted_nodes(
clients,
nodes,
episode,
previous_episodes,
entity_types,
existing_nodes_override=existing_override,
)
per_episode_uuid_maps.append(uuid_map)
episode_resolutions.append((episode.uuid, resolved_nodes))
for node in resolved_nodes:
canonical_nodes[node.uuid] = node
duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
uuid_chains: dict[str, str] = {}
for uuid_map in per_episode_uuid_maps:
uuid_chains.update(uuid_map)
for duplicate_uuid, canonical_uuid in duplicate_pairs:
uuid_chains[duplicate_uuid] = canonical_uuid
def _resolve_uuid(uuid: str) -> str:
seen: set[str] = set()
current = uuid
while True:
target = uuid_chains.get(current)
if target is None or target == current:
return current if target is None else target
if current in seen:
return current
seen.add(current)
current = target
compressed_map: dict[str, str] = {uuid: _resolve_uuid(uuid) for uuid in uuid_chains}
for canonical_uuid in canonical_nodes:
compressed_map.setdefault(canonical_uuid, canonical_uuid)
nodes_by_episode: dict[str, list[EntityNode]] = {}
for i, nodes in enumerate(extracted_nodes):
episode = episode_tuples[i][0]
for episode_uuid, resolved_nodes in episode_resolutions:
deduped_nodes: list[EntityNode] = []
seen: set[str] = set()
for node in resolved_nodes:
canonical_uuid = compressed_map.get(node.uuid, node.uuid)
if canonical_uuid in seen:
continue
seen.add(canonical_uuid)
canonical_node = canonical_nodes.get(canonical_uuid, node)
deduped_nodes.append(canonical_node)
nodes_by_episode[episode.uuid] = [
node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
]
nodes_by_episode[episode_uuid] = deduped_nodes
return nodes_by_episode, compressed_map
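
To make the chain compression above concrete, here is a small standalone rendition of the same follow-the-links logic with a cycle guard (sample data only; the real code operates on uuid strings gathered from the per-episode uuid maps and duplicate pairs):

uuid_chains = {'c': 'b', 'b': 'a', 'x': 'x'}

def resolve_uuid(uuid: str) -> str:
    seen: set[str] = set()
    current = uuid
    while True:
        target = uuid_chains.get(current)
        if target is None or target == current:
            return current          # terminal uuid reached
        if current in seen:
            return current          # defensive cycle guard
        seen.add(current)
        current = target

compressed = {uuid: resolve_uuid(uuid) for uuid in uuid_chains}
print(compressed)   # {'c': 'a', 'b': 'a', 'x': 'x'}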


@@ -20,7 +20,7 @@ import math
import re
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import lru_cache
from hashlib import blake2b
from typing import TYPE_CHECKING
@@ -164,6 +164,7 @@ class DedupResolutionState:
resolved_nodes: list[EntityNode | None]
uuid_map: dict[str, str]
unresolved_indices: list[int]
duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
@@ -213,6 +214,8 @@ def _resolve_with_similarity(
match = existing_matches[0]
state.resolved_nodes[idx] = match
state.uuid_map[node.uuid] = match.uuid
if match.uuid != node.uuid:
state.duplicate_pairs.append((node, match))
continue
if len(existing_matches) > 1:
state.unresolved_indices.append(idx)
@@ -236,6 +239,8 @@ def _resolve_with_similarity(
if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
state.resolved_nodes[idx] = best_candidate
state.uuid_map[node.uuid] = best_candidate.uuid
if best_candidate.uuid != node.uuid:
state.duplicate_pairs.append((node, best_candidate))
continue
state.unresolved_indices.append(idx)
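
The new duplicate_pairs attribute uses field(default_factory=list) because dataclasses disallow mutable defaults; each DedupResolutionState must start with its own empty list. A minimal illustration (the State class below exists only for this example):

from dataclasses import dataclass, field

@dataclass
class State:
    pairs: list[tuple[str, str]] = field(default_factory=list)   # fresh list per instance
    # pairs: list[tuple[str, str]] = []   # would raise ValueError: mutable default

a, b = State(), State()
a.pairs.append(('duplicate-uuid', 'canonical-uuid'))
assert b.pairs == []   # instances do not share state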


@@ -306,6 +306,8 @@ async def _resolve_with_llm(
state.resolved_nodes[original_index] = resolved_node
state.uuid_map[extracted_node.uuid] = resolved_node.uuid
if resolved_node.uuid != extracted_node.uuid:
state.duplicate_pairs.append((extracted_node, resolved_node))
async def resolve_extracted_nodes(
@@ -332,7 +334,6 @@ async def resolve_extracted_nodes(
uuid_map={},
unresolved_indices=[],
)
node_duplicates: list[tuple[EntityNode, EntityNode]] = []
_resolve_with_similarity(extracted_nodes, indexes, state)
@@ -359,7 +360,7 @@ async def resolve_extracted_nodes(
new_node_duplicates: list[
tuple[EntityNode, EntityNode]
] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
return (
[node for node in state.resolved_nodes if node is not None],
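
For reference, the shape resolve_extracted_nodes returns, which dedupe_nodes_bulk unpacks and the test double in the new test file mimics (the alias name below is descriptive only, not part of the library):

from graphiti_core.nodes import EntityNode

# (resolved canonical nodes, extracted-uuid -> canonical-uuid map, (duplicate, canonical) node pairs)
ResolveExtractedNodesResult = tuple[
    list[EntityNode],
    dict[str, str],
    list[tuple[EntityNode, EntityNode]],
]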


@@ -0,0 +1,87 @@
from collections import deque
from unittest.mock import MagicMock
import pytest
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils import bulk_utils
from graphiti_core.utils.datetime_utils import utc_now
def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
return EpisodicNode(
name=f'episode-{uuid_suffix}',
group_id=group_id,
labels=[],
source=EpisodeType.message,
content='content',
source_description='test',
created_at=utc_now(),
valid_at=utc_now(),
)
def _make_clients() -> GraphitiClients:
driver = MagicMock()
embedder = MagicMock()
cross_encoder = MagicMock()
llm_client = MagicMock()
return GraphitiClients.model_construct( # bypass validation to allow test doubles
driver=driver,
embedder=embedder,
cross_encoder=cross_encoder,
llm_client=llm_client,
ensure_ascii=False,
)
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
clients = _make_clients()
episode_one = _make_episode('1')
episode_two = _make_episode('2')
extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
canonical = extracted_one
call_queue = deque()
async def fake_resolve(
clients_arg,
nodes_arg,
episode_arg,
previous_episodes_arg,
entity_types_arg,
existing_nodes_override=None,
):
call_queue.append(existing_nodes_override)
if nodes_arg == [extracted_one]:
return [canonical], {canonical.uuid: canonical.uuid}, []
assert nodes_arg == [extracted_two]
assert existing_nodes_override is not None
assert existing_nodes_override[0] is canonical
return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
clients,
[[extracted_one], [extracted_two]],
[(episode_one, []), (episode_two, [])],
)
assert len(call_queue) == 2
assert call_queue[0] is None
assert list(call_queue[1]) == [canonical]
assert nodes_by_episode[episode_one.uuid] == [canonical]
assert nodes_by_episode[episode_two.uuid] == [canonical]
assert compressed_map.get(extracted_two.uuid) == canonical.uuid


@@ -257,6 +257,7 @@ def test_resolve_with_similarity_exact_match_updates_state():
assert state.resolved_nodes[0].uuid == candidate.uuid
assert state.uuid_map[extracted.uuid] == candidate.uuid
assert state.unresolved_indices == []
assert state.duplicate_pairs == [(extracted, candidate)]
def test_resolve_with_similarity_low_entropy_defers_resolution():
@@ -274,6 +275,7 @@ def test_resolve_with_similarity_low_entropy_defers_resolution():
assert state.resolved_nodes[0] is None
assert state.unresolved_indices == [0]
assert state.duplicate_pairs == []
def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
@@ -288,6 +290,7 @@ def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
assert state.resolved_nodes[0] is None
assert state.unresolved_indices == [0]
assert state.duplicate_pairs == []
@pytest.mark.asyncio
@@ -339,3 +342,4 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
assert state.uuid_map[extracted.uuid] == candidate.uuid
assert captured_context['existing_nodes'][0]['idx'] == 0
assert isinstance(captured_context['existing_nodes'], list)
assert state.duplicate_pairs == [(extracted, candidate)]