Refactor batch deduplication logic to enhance node resolution and track duplicate pairs (#929) (#936)

* Refactor deduplication logic to enhance node resolution and track duplicate pairs (#929)

* Simplify deduplication process in bulk_utils by reusing canonical nodes.
* Update dedup_helpers to store duplicate pairs during resolution.
* Modify node_operations to append duplicate pairs when resolving nodes.
* Add tests to verify deduplication behavior and ensure correct state updates.

* revert to concurrent dedup with fanout and then reconciliation

* add performance note for deduplication loop in bulk_utils

* enhance deduplication logic in bulk_utils to handle missing canonical nodes gracefully

* Update graphiti_core/utils/bulk_utils.py

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>

* refactor deduplication logic in bulk_utils to use directed union-find for canonical UUID resolution

* implement _build_directed_uuid_map for efficient UUID resolution in bulk_utils

* document directed union-find lookup in bulk_utils for clarity

---------

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Daniel Chalef 2025-09-26 08:40:18 -07:00 committed by GitHub
parent 1e56019027
commit 9aee3174bd
5 changed files with 371 additions and 63 deletions

graphiti_core/utils/bulk_utils.py

@@ -43,8 +43,14 @@ from graphiti_core.models.nodes.node_db_queries import (
     get_entity_node_save_bulk_query,
     get_episode_node_save_bulk_query,
 )
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
+from graphiti_core.utils.maintenance.dedup_helpers import (
+    DedupResolutionState,
+    _build_candidate_indexes,
+    _normalize_string_exact,
+    _resolve_with_similarity,
+)
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edge,
@@ -63,6 +69,38 @@ logger = logging.getLogger(__name__)

 CHUNK_SIZE = 10

+
+def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
+    """Collapse alias -> canonical chains while preserving direction.
+
+    The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
+    union-find with iterative path compression to ensure every source UUID resolves to its ultimate
+    canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
+    """
+    parent: dict[str, str] = {}
+
+    def find(uuid: str) -> str:
+        """Directed union-find lookup using iterative path compression."""
+        parent.setdefault(uuid, uuid)
+        root = uuid
+        while parent[root] != root:
+            root = parent[root]
+        while parent[uuid] != root:
+            next_uuid = parent[uuid]
+            parent[uuid] = root
+            uuid = next_uuid
+        return root
+
+    for source_uuid, target_uuid in pairs:
+        parent.setdefault(source_uuid, source_uuid)
+        parent.setdefault(target_uuid, target_uuid)
+        parent[find(source_uuid)] = find(target_uuid)
+
+    return {uuid: find(uuid) for uuid in parent}
+
+
 class RawEpisode(BaseModel):
     name: str
     uuid: str | None = Field(default=None)
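
A quick illustration of the docstring's direction-preserving claim (this snippet is not part of the diff; it exercises the helper the same way the new tests further down do):

    from graphiti_core.utils.bulk_utils import _build_directed_uuid_map

    # A naive dict(pairs) would leave 'a' -> 'b' even though 'b' was itself
    # remapped; the union-find walks every alias to its final root instead.
    assert _build_directed_uuid_map([('a', 'b'), ('b', 'c')]) == {'a': 'c', 'b': 'c', 'c': 'c'}

    # Direction is kept even when the alias sorts after the canonical UUID,
    # which a lexicographic "smallest UUID wins" scheme would invert.
    assert _build_directed_uuid_map([('z-alias', 'a-canon')])['z-alias'] == 'a-canon'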
@@ -266,83 +304,111 @@ async def dedupe_nodes_bulk(
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
-    embedder = clients.embedder
-    min_score = 0.8
-
-    # generate embeddings
-    await semaphore_gather(
-        *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
-    )
-
-    # Find similar results
-    dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
-    for i, nodes_i in enumerate(extracted_nodes):
-        existing_nodes: list[EntityNode] = []
-        for j, nodes_j in enumerate(extracted_nodes):
-            if i == j:
-                continue
-            existing_nodes += nodes_j
-
-        candidates_i: list[EntityNode] = []
-        for node in nodes_i:
-            for existing_node in existing_nodes:
-                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
-                # This approach will cast a wider net than BM25, which is ideal for this use case
-                node_words = set(node.name.lower().split())
-                existing_node_words = set(existing_node.name.lower().split())
-                has_overlap = not node_words.isdisjoint(existing_node_words)
-                if has_overlap:
-                    candidates_i.append(existing_node)
-                    continue
-
-                # Check for semantic similarity even if there is no overlap
-                similarity = np.dot(
-                    normalize_l2(node.name_embedding or []),
-                    normalize_l2(existing_node.name_embedding or []),
-                )
-                if similarity >= min_score:
-                    candidates_i.append(existing_node)
-
-        dedupe_tuples.append((nodes_i, candidates_i))
-
-    # Determine Node Resolutions
-    bulk_node_resolutions: list[
-        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
-    ] = await semaphore_gather(
+    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.
+
+    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
+       reconciled against the live graph just like the non-batch flow.
+    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
+       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
+       can apply to edges and persistence.
+    """
+    first_pass_results = await semaphore_gather(
         *[
             resolve_extracted_nodes(
                 clients,
-                dedupe_tuple[0],
+                nodes,
                 episode_tuples[i][0],
                 episode_tuples[i][1],
                 entity_types,
-                existing_nodes_override=dedupe_tuples[i][1],
             )
-            for i, dedupe_tuple in enumerate(dedupe_tuples)
+            for i, nodes in enumerate(extracted_nodes)
         ]
     )

-    # Collect all duplicate pairs sorted by uuid
+    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
+    per_episode_uuid_maps: list[dict[str, str]] = []
     duplicate_pairs: list[tuple[str, str]] = []
-    for _, _, duplicates in bulk_node_resolutions:
-        for duplicate in duplicates:
-            n, m = duplicate
-            duplicate_pairs.append((n.uuid, m.uuid))
-
-    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
-    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
+    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
+        first_pass_results, episode_tuples, strict=True
+    ):
+        episode_resolutions.append((episode.uuid, resolved_nodes))
+        per_episode_uuid_maps.append(uuid_map)
+        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)

-    node_uuid_map: dict[str, EntityNode] = {
-        node.uuid: node for nodes in extracted_nodes for node in nodes
-    }
+    canonical_nodes: dict[str, EntityNode] = {}
+    for _, resolved_nodes in episode_resolutions:
+        for node in resolved_nodes:
+            # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
+            # the MinHash index for the accumulated canonical pool each time. The LRU-backed
+            # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
+            # but if batches grow significantly we should switch to an incremental index or chunked
+            # processing.
+            if not canonical_nodes:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            existing_candidates = list(canonical_nodes.values())
+            normalized = _normalize_string_exact(node.name)
+            exact_match = next(
+                (
+                    candidate
+                    for candidate in existing_candidates
+                    if _normalize_string_exact(candidate.name) == normalized
+                ),
+                None,
+            )
+            if exact_match is not None:
+                if exact_match.uuid != node.uuid:
+                    duplicate_pairs.append((node.uuid, exact_match.uuid))
+                continue
+
+            indexes = _build_candidate_indexes(existing_candidates)
+            state = DedupResolutionState(
+                resolved_nodes=[None],
+                uuid_map={},
+                unresolved_indices=[],
+            )
+            _resolve_with_similarity([node], indexes, state)
+
+            resolved = state.resolved_nodes[0]
+            if resolved is None:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            canonical_uuid = resolved.uuid
+            canonical_nodes.setdefault(canonical_uuid, resolved)
+            if canonical_uuid != node.uuid:
+                duplicate_pairs.append((node.uuid, canonical_uuid))
+
+    union_pairs: list[tuple[str, str]] = []
+    for uuid_map in per_episode_uuid_maps:
+        union_pairs.extend(uuid_map.items())
+    union_pairs.extend(duplicate_pairs)
+
+    compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)

     nodes_by_episode: dict[str, list[EntityNode]] = {}
-    for i, nodes in enumerate(extracted_nodes):
-        episode = episode_tuples[i][0]
+    for episode_uuid, resolved_nodes in episode_resolutions:
+        deduped_nodes: list[EntityNode] = []
+        seen: set[str] = set()
+        for node in resolved_nodes:
+            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
+            if canonical_uuid in seen:
+                continue
+            seen.add(canonical_uuid)
+            canonical_node = canonical_nodes.get(canonical_uuid)
+            if canonical_node is None:
+                logger.error(
+                    'Canonical node %s missing during batch dedupe; falling back to %s',
+                    canonical_uuid,
+                    node.uuid,
+                )
+                canonical_node = node
+            deduped_nodes.append(canonical_node)

-        nodes_by_episode[episode.uuid] = [
-            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
-        ]
+        nodes_by_episode[episode_uuid] = deduped_nodes

     return nodes_by_episode, compressed_map
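
Taken together, the returned `nodes_by_episode` and `compressed_map` are meant to be applied by the bulk ingestion caller. A minimal sketch of that call pattern (not part of the diff), assuming `clients`, `extracted_nodes`, `episode_tuples`, `entity_types`, and `extracted_edges` already exist in the surrounding async flow; `resolve_edge_pointers` is the existing helper exercised by the new tests:

    nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
        clients, extracted_nodes, episode_tuples, entity_types
    )

    # Edges extracted before dedupe may still reference alias UUIDs;
    # re-point them at the canonical nodes before persisting.
    resolve_edge_pointers(extracted_edges, uuid_map)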

graphiti_core/utils/maintenance/dedup_helpers.py

@@ -20,7 +20,7 @@ import math
 import re
 from collections import defaultdict
 from collections.abc import Iterable
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from hashlib import blake2b
 from typing import TYPE_CHECKING
@@ -164,6 +164,7 @@ class DedupResolutionState:
     resolved_nodes: list[EntityNode | None]
     uuid_map: dict[str, str]
     unresolved_indices: list[int]
+    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)


 def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
@@ -213,6 +214,8 @@ def _resolve_with_similarity(
             match = existing_matches[0]
             state.resolved_nodes[idx] = match
             state.uuid_map[node.uuid] = match.uuid
+            if match.uuid != node.uuid:
+                state.duplicate_pairs.append((node, match))
             continue
         if len(existing_matches) > 1:
             state.unresolved_indices.append(idx)
@@ -236,6 +239,8 @@ def _resolve_with_similarity(
         if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
             state.resolved_nodes[idx] = best_candidate
             state.uuid_map[node.uuid] = best_candidate.uuid
+            if best_candidate.uuid != node.uuid:
+                state.duplicate_pairs.append((node, best_candidate))
             continue
         state.unresolved_indices.append(idx)
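
A note on the new field: dataclasses reject a bare mutable default (`duplicate_pairs: list = []` raises a ValueError at class-definition time), so `field(default_factory=list)` is the idiomatic way to give each `DedupResolutionState` instance its own list. A tiny standalone check of that behavior, with `_State` as a hypothetical stand-in:

    from dataclasses import dataclass, field

    @dataclass
    class _State:  # hypothetical stand-in for DedupResolutionState
        pairs: list[tuple[str, str]] = field(default_factory=list)

    a, b = _State(), _State()
    a.pairs.append(('alias', 'canonical'))
    assert b.pairs == []  # each instance gets a fresh list, no shared state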

graphiti_core/utils/maintenance/node_operations.py

@@ -306,6 +306,8 @@ async def _resolve_with_llm(
         state.resolved_nodes[original_index] = resolved_node
         state.uuid_map[extracted_node.uuid] = resolved_node.uuid
+        if resolved_node.uuid != extracted_node.uuid:
+            state.duplicate_pairs.append((extracted_node, resolved_node))


 async def resolve_extracted_nodes(
@@ -332,7 +334,6 @@ async def resolve_extracted_nodes(
         uuid_map={},
         unresolved_indices=[],
     )
-    node_duplicates: list[tuple[EntityNode, EntityNode]] = []

     _resolve_with_similarity(extracted_nodes, indexes, state)
@@ -359,7 +360,7 @@ async def resolve_extracted_nodes(
     new_node_duplicates: list[
         tuple[EntityNode, EntityNode]
-    ] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
+    ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)

     return (
         [node for node in state.resolved_nodes if node is not None],
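
With the accumulator living on the state object, the similarity pass, the LLM pass, and the final edge filter all share one list, and `resolve_extracted_nodes` no longer threads a local `node_duplicates` through. Roughly, stitched together from the hunks above (argument lists abbreviated with `...`):

    state = DedupResolutionState(
        resolved_nodes=[None] * len(extracted_nodes), uuid_map={}, unresolved_indices=[]
    )
    _resolve_with_similarity(extracted_nodes, indexes, state)     # deterministic pass
    await _resolve_with_llm(extracted_nodes, indexes, state, ...)  # LLM pass for leftovers
    new_node_duplicates = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)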

tests for bulk_utils (new file)

@@ -0,0 +1,232 @@
from collections import deque
from unittest.mock import AsyncMock, MagicMock

import pytest

from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils import bulk_utils
from graphiti_core.utils.datetime_utils import utc_now


def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
    return EpisodicNode(
        name=f'episode-{uuid_suffix}',
        group_id=group_id,
        labels=[],
        source=EpisodeType.message,
        content='content',
        source_description='test',
        created_at=utc_now(),
        valid_at=utc_now(),
    )


def _make_clients() -> GraphitiClients:
    driver = MagicMock()
    embedder = MagicMock()
    cross_encoder = MagicMock()
    llm_client = MagicMock()

    return GraphitiClients.model_construct(  # bypass validation to allow test doubles
        driver=driver,
        embedder=embedder,
        cross_encoder=cross_encoder,
        llm_client=llm_client,
        ensure_ascii=False,
    )


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
    clients = _make_clients()

    episode_one = _make_episode('1')
    episode_two = _make_episode('2')

    extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])

    canonical = extracted_one
    call_queue = deque()

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        call_queue.append(existing_nodes_override)
        if nodes_arg == [extracted_one]:
            return [canonical], {canonical.uuid: canonical.uuid}, []
        assert nodes_arg == [extracted_two]
        assert existing_nodes_override is None
        return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted_one], [extracted_two]],
        [(episode_one, []), (episode_two, [])],
    )

    assert len(call_queue) == 2
    assert call_queue[0] is None
    assert call_queue[1] is None
    assert nodes_by_episode[episode_one.uuid] == [canonical]
    assert nodes_by_episode[episode_two.uuid] == [canonical]
    assert compressed_map.get(extracted_two.uuid) == canonical.uuid


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
    clients = _make_clients()

    resolve_mock = AsyncMock()
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [],
        [],
    )

    assert nodes_by_episode == {}
    assert compressed_map == {}
    resolve_mock.assert_not_awaited()


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
    clients = _make_clients()
    episode = _make_episode('solo')
    extracted = EntityNode(name='Solo', group_id='group', labels=['Entity'])

    resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: extracted.uuid}, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted]],
        [(episode, [])],
    )

    assert nodes_by_episode == {episode.uuid: [extracted]}
    assert compressed_map == {extracted.uuid: extracted.uuid}
    resolve_mock.assert_awaited_once()


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
    clients = _make_clients()

    episode_one = _make_episode('one')
    episode_two = _make_episode('two')

    extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
    extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])

    canonical = extracted_one
    alias = extracted_two

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        if nodes_arg == [extracted_one]:
            return [canonical], {canonical.uuid: canonical.uuid}, []
        assert nodes_arg == [extracted_two]
        return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted_one], [extracted_two]],
        [(episode_one, []), (episode_two, [])],
    )

    assert nodes_by_episode[episode_one.uuid] == [canonical]
    assert nodes_by_episode[episode_two.uuid] == [canonical]
    assert compressed_map.get(alias.uuid) == canonical.uuid


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
    clients = _make_clients()
    episode = _make_episode('missing')
    extracted = EntityNode(name='Fallback', group_id='group', labels=['Entity'])

    resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: 'missing-canonical'}, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    with caplog.at_level('WARNING'):
        nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
            clients,
            [[extracted]],
            [(episode, [])],
        )

    assert nodes_by_episode[episode.uuid] == [extracted]
    assert compressed_map.get(extracted.uuid) == 'missing-canonical'
    assert any('Canonical node missing' in rec.message for rec in caplog.records)


def test_build_directed_uuid_map_empty():
    assert bulk_utils._build_directed_uuid_map([]) == {}


def test_build_directed_uuid_map_chain():
    mapping = bulk_utils._build_directed_uuid_map(
        [
            ('a', 'b'),
            ('b', 'c'),
        ]
    )

    assert mapping['a'] == 'c'
    assert mapping['b'] == 'c'
    assert mapping['c'] == 'c'


def test_build_directed_uuid_map_preserves_direction():
    mapping = bulk_utils._build_directed_uuid_map(
        [
            ('alias', 'canonical'),
        ]
    )

    assert mapping['alias'] == 'canonical'
    assert mapping['canonical'] == 'canonical'


def test_resolve_edge_pointers_updates_sources():
    created_at = utc_now()
    edge = EntityEdge(
        name='knows',
        fact='fact',
        group_id='group',
        source_node_uuid='alias',
        target_node_uuid='target',
        created_at=created_at,
    )

    bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})

    assert edge.source_node_uuid == 'canonical'
    assert edge.target_node_uuid == 'target'

tests for dedup_helpers

@@ -257,6 +257,7 @@ def test_resolve_with_similarity_exact_match_updates_state():
     assert state.resolved_nodes[0].uuid == candidate.uuid
     assert state.uuid_map[extracted.uuid] == candidate.uuid
     assert state.unresolved_indices == []
+    assert state.duplicate_pairs == [(extracted, candidate)]


 def test_resolve_with_similarity_low_entropy_defers_resolution():
@@ -274,6 +275,7 @@ def test_resolve_with_similarity_low_entropy_defers_resolution():
     assert state.resolved_nodes[0] is None
     assert state.unresolved_indices == [0]
+    assert state.duplicate_pairs == []


 def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
@@ -288,6 +290,7 @@ def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
     assert state.resolved_nodes[0] is None
     assert state.unresolved_indices == [0]
+    assert state.duplicate_pairs == []


 @pytest.mark.asyncio
@@ -339,3 +342,4 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
     assert state.uuid_map[extracted.uuid] == candidate.uuid
     assert captured_context['existing_nodes'][0]['idx'] == 0
     assert isinstance(captured_context['existing_nodes'], list)
+    assert state.duplicate_pairs == [(extracted, candidate)]