Refactor batch deduplication logic to enhance node resolution and track duplicate pairs (#929) (#936)

* Refactor deduplication logic to enhance node resolution and track duplicate pairs (#929)

* Simplify deduplication process in bulk_utils by reusing canonical nodes.
* Update dedup_helpers to store duplicate pairs during resolution.
* Modify node_operations to append duplicate pairs when resolving nodes.
* Add tests to verify deduplication behavior and ensure correct state updates.

* revert to concurrent dedup with fanout and then reconciliation

* add performance note for deduplication loop in bulk_utils

* enhance deduplication logic in bulk_utils to handle missing canonical nodes gracefully

* Update graphiti_core/utils/bulk_utils.py

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>

* refactor deduplication logic in bulk_utils to use directed union-find for canonical UUID resolution

* implement _build_directed_uuid_map for efficient UUID resolution in bulk_utils

* document directed union-find lookup in bulk_utils for clarity

---------

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Daniel Chalef 2025-09-26 08:40:18 -07:00 committed by GitHub
parent 1e56019027
commit 9aee3174bd
5 changed files with 371 additions and 63 deletions
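To make the union-find bullets in the commit message concrete, here is a minimal usage sketch of the new helper; the UUID strings are hypothetical, and the behavior mirrors the implementation and tests in the diff below.

from graphiti_core.utils.bulk_utils import _build_directed_uuid_map

# Hypothetical alias -> canonical pairs collected while resolving nodes.
pairs = [('alias-a', 'alias-b'), ('alias-b', 'canonical-c')]

# Chains collapse to the final canonical target and direction is preserved,
# so an alias never becomes canonical just because it sorts lower.
mapping = _build_directed_uuid_map(pairs)
assert mapping == {
    'alias-a': 'canonical-c',
    'alias-b': 'canonical-c',
    'canonical-c': 'canonical-c',
}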


@@ -43,8 +43,14 @@ from graphiti_core.models.nodes.node_db_queries import (
get_entity_node_save_bulk_query,
get_episode_node_save_bulk_query,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
from graphiti_core.utils.maintenance.dedup_helpers import (
DedupResolutionState,
_build_candidate_indexes,
_normalize_string_exact,
_resolve_with_similarity,
)
from graphiti_core.utils.maintenance.edge_operations import (
extract_edges,
resolve_extracted_edge,
@@ -63,6 +69,38 @@ logger = logging.getLogger(__name__)
CHUNK_SIZE = 10
def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
"""Collapse alias -> canonical chains while preserving direction.
The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
union-find with iterative path compression to ensure every source UUID resolves to its ultimate
canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
"""
parent: dict[str, str] = {}
def find(uuid: str) -> str:
"""Directed union-find lookup using iterative path compression."""
parent.setdefault(uuid, uuid)
root = uuid
while parent[root] != root:
root = parent[root]
while parent[uuid] != root:
next_uuid = parent[uuid]
parent[uuid] = root
uuid = next_uuid
return root
for source_uuid, target_uuid in pairs:
parent.setdefault(source_uuid, source_uuid)
parent.setdefault(target_uuid, target_uuid)
parent[find(source_uuid)] = find(target_uuid)
return {uuid: find(uuid) for uuid in parent}
class RawEpisode(BaseModel):
name: str
uuid: str | None = Field(default=None)
@@ -266,83 +304,111 @@ async def dedupe_nodes_bulk(
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
entity_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
embedder = clients.embedder
min_score = 0.8
"""Resolve entity duplicates across an in-memory batch using a two-pass strategy.
# generate embeddings
await semaphore_gather(
*[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
)
1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
reconciled against the live graph just like the non-batch flow.
2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
can apply to edges and persistence.
"""
# Find similar results
dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
for i, nodes_i in enumerate(extracted_nodes):
existing_nodes: list[EntityNode] = []
for j, nodes_j in enumerate(extracted_nodes):
if i == j:
continue
existing_nodes += nodes_j
candidates_i: list[EntityNode] = []
for node in nodes_i:
for existing_node in existing_nodes:
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
# This approach will cast a wider net than BM25, which is ideal for this use case
node_words = set(node.name.lower().split())
existing_node_words = set(existing_node.name.lower().split())
has_overlap = not node_words.isdisjoint(existing_node_words)
if has_overlap:
candidates_i.append(existing_node)
continue
# Check for semantic similarity even if there is no overlap
similarity = np.dot(
normalize_l2(node.name_embedding or []),
normalize_l2(existing_node.name_embedding or []),
)
if similarity >= min_score:
candidates_i.append(existing_node)
dedupe_tuples.append((nodes_i, candidates_i))
# Determine Node Resolutions
bulk_node_resolutions: list[
tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
] = await semaphore_gather(
first_pass_results = await semaphore_gather(
*[
resolve_extracted_nodes(
clients,
dedupe_tuple[0],
nodes,
episode_tuples[i][0],
episode_tuples[i][1],
entity_types,
existing_nodes_override=dedupe_tuples[i][1],
)
for i, dedupe_tuple in enumerate(dedupe_tuples)
for i, nodes in enumerate(extracted_nodes)
]
)
# Collect all duplicate pairs sorted by uuid
episode_resolutions: list[tuple[str, list[EntityNode]]] = []
per_episode_uuid_maps: list[dict[str, str]] = []
duplicate_pairs: list[tuple[str, str]] = []
for _, _, duplicates in bulk_node_resolutions:
for duplicate in duplicates:
n, m = duplicate
duplicate_pairs.append((n.uuid, m.uuid))
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
first_pass_results, episode_tuples, strict=True
):
episode_resolutions.append((episode.uuid, resolved_nodes))
per_episode_uuid_maps.append(uuid_map)
duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
node_uuid_map: dict[str, EntityNode] = {
node.uuid: node for nodes in extracted_nodes for node in nodes
}
canonical_nodes: dict[str, EntityNode] = {}
for _, resolved_nodes in episode_resolutions:
for node in resolved_nodes:
# NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
# the MinHash index for the accumulated canonical pool each time. The LRU-backed
# shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
# but if batches grow significantly we should switch to an incremental index or chunked
# processing.
if not canonical_nodes:
canonical_nodes[node.uuid] = node
continue
existing_candidates = list(canonical_nodes.values())
normalized = _normalize_string_exact(node.name)
exact_match = next(
(
candidate
for candidate in existing_candidates
if _normalize_string_exact(candidate.name) == normalized
),
None,
)
if exact_match is not None:
if exact_match.uuid != node.uuid:
duplicate_pairs.append((node.uuid, exact_match.uuid))
continue
indexes = _build_candidate_indexes(existing_candidates)
state = DedupResolutionState(
resolved_nodes=[None],
uuid_map={},
unresolved_indices=[],
)
_resolve_with_similarity([node], indexes, state)
resolved = state.resolved_nodes[0]
if resolved is None:
canonical_nodes[node.uuid] = node
continue
canonical_uuid = resolved.uuid
canonical_nodes.setdefault(canonical_uuid, resolved)
if canonical_uuid != node.uuid:
duplicate_pairs.append((node.uuid, canonical_uuid))
union_pairs: list[tuple[str, str]] = []
for uuid_map in per_episode_uuid_maps:
union_pairs.extend(uuid_map.items())
union_pairs.extend(duplicate_pairs)
compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
nodes_by_episode: dict[str, list[EntityNode]] = {}
for i, nodes in enumerate(extracted_nodes):
episode = episode_tuples[i][0]
for episode_uuid, resolved_nodes in episode_resolutions:
deduped_nodes: list[EntityNode] = []
seen: set[str] = set()
for node in resolved_nodes:
canonical_uuid = compressed_map.get(node.uuid, node.uuid)
if canonical_uuid in seen:
continue
seen.add(canonical_uuid)
canonical_node = canonical_nodes.get(canonical_uuid)
if canonical_node is None:
logger.error(
'Canonical node %s missing during batch dedupe; falling back to %s',
canonical_uuid,
node.uuid,
)
canonical_node = node
deduped_nodes.append(canonical_node)
nodes_by_episode[episode.uuid] = [
node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
]
nodes_by_episode[episode_uuid] = deduped_nodes
return nodes_by_episode, compressed_map
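To show where the returned values are meant to flow, a rough usage sketch follows (not part of the diff); the wrapper name and arguments are placeholders, and the call signatures follow the function above and the tests further down.

async def dedupe_and_repoint(clients, extracted_nodes, episode_tuples, extracted_edges):
    # Resolve duplicates across the batch; uuid_map maps alias UUIDs to canonical UUIDs.
    nodes_by_episode, uuid_map = await dedupe_nodes_bulk(clients, extracted_nodes, episode_tuples)
    # Re-point edge endpoints at canonical node UUIDs before persistence.
    resolve_edge_pointers(extracted_edges, uuid_map)
    return nodes_by_episode, extracted_edges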


@@ -20,7 +20,7 @@ import math
import re
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import lru_cache
from hashlib import blake2b
from typing import TYPE_CHECKING
@@ -164,6 +164,7 @@ class DedupResolutionState:
resolved_nodes: list[EntityNode | None]
uuid_map: dict[str, str]
unresolved_indices: list[int]
duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
@@ -213,6 +214,8 @@ def _resolve_with_similarity(
match = existing_matches[0]
state.resolved_nodes[idx] = match
state.uuid_map[node.uuid] = match.uuid
if match.uuid != node.uuid:
state.duplicate_pairs.append((node, match))
continue
if len(existing_matches) > 1:
state.unresolved_indices.append(idx)
@@ -236,6 +239,8 @@
if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
state.resolved_nodes[idx] = best_candidate
state.uuid_map[node.uuid] = best_candidate.uuid
if best_candidate.uuid != node.uuid:
state.duplicate_pairs.append((node, best_candidate))
continue
state.unresolved_indices.append(idx)
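As a minimal sketch of how the new duplicate_pairs field is populated and consumed (the node lists are hypothetical; the same pattern appears in dedupe_nodes_bulk above):

# extracted_nodes / existing_nodes are hypothetical EntityNode lists.
indexes = _build_candidate_indexes(existing_nodes)
state = DedupResolutionState(
    resolved_nodes=[None] * len(extracted_nodes),
    uuid_map={},
    unresolved_indices=[],
)
_resolve_with_similarity(extracted_nodes, indexes, state)
# Every non-self match recorded above is now available as an (extracted, canonical) pair.
for extracted, canonical in state.duplicate_pairs:
    print(extracted.uuid, '->', canonical.uuid)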


@@ -306,6 +306,8 @@ async def _resolve_with_llm(
state.resolved_nodes[original_index] = resolved_node
state.uuid_map[extracted_node.uuid] = resolved_node.uuid
if resolved_node.uuid != extracted_node.uuid:
state.duplicate_pairs.append((extracted_node, resolved_node))
async def resolve_extracted_nodes(
@@ -332,7 +334,6 @@ async def resolve_extracted_nodes(
uuid_map={},
unresolved_indices=[],
)
node_duplicates: list[tuple[EntityNode, EntityNode]] = []
_resolve_with_similarity(extracted_nodes, indexes, state)
@@ -359,7 +360,7 @@ async def resolve_extracted_nodes(
new_node_duplicates: list[
tuple[EntityNode, EntityNode]
] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
return (
[node for node in state.resolved_nodes if node is not None],
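Pulling the node_operations changes together, a rough consumption sketch (not part of the diff; the argument names are placeholders and the tuple ordering follows the zip in dedupe_nodes_bulk above):

resolved_nodes, uuid_map, duplicate_pairs = await resolve_extracted_nodes(
    clients, extracted_nodes, episode, previous_episodes, entity_types
)
# duplicate_pairs are the (extracted, canonical) matches gathered in state.duplicate_pairs;
# per the filter_existing_duplicate_of_edges call above, pairs whose DUPLICATE_OF edge
# already exists are dropped before being returned.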


@@ -0,0 +1,232 @@
from collections import deque
from unittest.mock import AsyncMock, MagicMock
import pytest
from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils import bulk_utils
from graphiti_core.utils.datetime_utils import utc_now
def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
return EpisodicNode(
name=f'episode-{uuid_suffix}',
group_id=group_id,
labels=[],
source=EpisodeType.message,
content='content',
source_description='test',
created_at=utc_now(),
valid_at=utc_now(),
)
def _make_clients() -> GraphitiClients:
driver = MagicMock()
embedder = MagicMock()
cross_encoder = MagicMock()
llm_client = MagicMock()
return GraphitiClients.model_construct( # bypass validation to allow test doubles
driver=driver,
embedder=embedder,
cross_encoder=cross_encoder,
llm_client=llm_client,
ensure_ascii=False,
)
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
clients = _make_clients()
episode_one = _make_episode('1')
episode_two = _make_episode('2')
extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
canonical = extracted_one
call_queue = deque()
async def fake_resolve(
clients_arg,
nodes_arg,
episode_arg,
previous_episodes_arg,
entity_types_arg,
existing_nodes_override=None,
):
call_queue.append(existing_nodes_override)
if nodes_arg == [extracted_one]:
return [canonical], {canonical.uuid: canonical.uuid}, []
assert nodes_arg == [extracted_two]
assert existing_nodes_override is None
return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
clients,
[[extracted_one], [extracted_two]],
[(episode_one, []), (episode_two, [])],
)
assert len(call_queue) == 2
assert call_queue[0] is None
assert call_queue[1] is None
assert nodes_by_episode[episode_one.uuid] == [canonical]
assert nodes_by_episode[episode_two.uuid] == [canonical]
assert compressed_map.get(extracted_two.uuid) == canonical.uuid
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
clients = _make_clients()
resolve_mock = AsyncMock()
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
clients,
[],
[],
)
assert nodes_by_episode == {}
assert compressed_map == {}
resolve_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
clients = _make_clients()
episode = _make_episode('solo')
extracted = EntityNode(name='Solo', group_id='group', labels=['Entity'])
resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: extracted.uuid}, []))
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
clients,
[[extracted]],
[(episode, [])],
)
assert nodes_by_episode == {episode.uuid: [extracted]}
assert compressed_map == {extracted.uuid: extracted.uuid}
resolve_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
clients = _make_clients()
episode_one = _make_episode('one')
episode_two = _make_episode('two')
extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])
canonical = extracted_one
alias = extracted_two
async def fake_resolve(
clients_arg,
nodes_arg,
episode_arg,
previous_episodes_arg,
entity_types_arg,
existing_nodes_override=None,
):
if nodes_arg == [extracted_one]:
return [canonical], {canonical.uuid: canonical.uuid}, []
assert nodes_arg == [extracted_two]
return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
clients,
[[extracted_one], [extracted_two]],
[(episode_one, []), (episode_two, [])],
)
assert nodes_by_episode[episode_one.uuid] == [canonical]
assert nodes_by_episode[episode_two.uuid] == [canonical]
assert compressed_map.get(alias.uuid) == canonical.uuid
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
clients = _make_clients()
episode = _make_episode('missing')
extracted = EntityNode(name='Fallback', group_id='group', labels=['Entity'])
resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: 'missing-canonical'}, []))
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
with caplog.at_level('WARNING'):
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
clients,
[[extracted]],
[(episode, [])],
)
assert nodes_by_episode[episode.uuid] == [extracted]
assert compressed_map.get(extracted.uuid) == 'missing-canonical'
assert any('Canonical node missing' in rec.message for rec in caplog.records)
def test_build_directed_uuid_map_empty():
assert bulk_utils._build_directed_uuid_map([]) == {}
def test_build_directed_uuid_map_chain():
mapping = bulk_utils._build_directed_uuid_map(
[
('a', 'b'),
('b', 'c'),
]
)
assert mapping['a'] == 'c'
assert mapping['b'] == 'c'
assert mapping['c'] == 'c'
def test_build_directed_uuid_map_preserves_direction():
mapping = bulk_utils._build_directed_uuid_map(
[
('alias', 'canonical'),
]
)
assert mapping['alias'] == 'canonical'
assert mapping['canonical'] == 'canonical'
def test_resolve_edge_pointers_updates_sources():
created_at = utc_now()
edge = EntityEdge(
name='knows',
fact='fact',
group_id='group',
source_node_uuid='alias',
target_node_uuid='target',
created_at=created_at,
)
bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})
assert edge.source_node_uuid == 'canonical'
assert edge.target_node_uuid == 'target'


@@ -257,6 +257,7 @@ def test_resolve_with_similarity_exact_match_updates_state():
assert state.resolved_nodes[0].uuid == candidate.uuid
assert state.uuid_map[extracted.uuid] == candidate.uuid
assert state.unresolved_indices == []
assert state.duplicate_pairs == [(extracted, candidate)]
def test_resolve_with_similarity_low_entropy_defers_resolution():
@@ -274,6 +275,7 @@ def test_resolve_with_similarity_low_entropy_defers_resolution():
assert state.resolved_nodes[0] is None
assert state.unresolved_indices == [0]
assert state.duplicate_pairs == []
def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
@@ -288,6 +290,7 @@ def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
assert state.resolved_nodes[0] is None
assert state.unresolved_indices == [0]
assert state.duplicate_pairs == []
@pytest.mark.asyncio
@@ -339,3 +342,4 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
assert state.uuid_map[extracted.uuid] == candidate.uuid
assert captured_context['existing_nodes'][0]['idx'] == 0
assert isinstance(captured_context['existing_nodes'], list)
assert state.duplicate_pairs == [(extracted, candidate)]