From 9aee3174bd6f9fcab0ab252a6149e0359704c84e Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Fri, 26 Sep 2025 08:40:18 -0700
Subject: [PATCH] Refactor batch deduplication logic to enhance node resolution and track duplicate pairs (#929) (#936)

* Refactor deduplication logic to enhance node resolution and track duplicate pairs (#929)

* Simplify deduplication process in bulk_utils by reusing canonical nodes.

* Update dedup_helpers to store duplicate pairs during resolution.

* Modify node_operations to append duplicate pairs when resolving nodes.

* Add tests to verify deduplication behavior and ensure correct state updates.

* revert to concurrent dedup with fanout and then reconciliation

* add performance note for deduplication loop in bulk_utils

* enhance deduplication logic in bulk_utils to handle missing canonical nodes gracefully

* Update graphiti_core/utils/bulk_utils.py

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>

* refactor deduplication logic in bulk_utils to use directed union-find for canonical UUID resolution

* implement _build_directed_uuid_map for efficient UUID resolution in bulk_utils

* document directed union-find lookup in bulk_utils for clarity

---------

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
---
 graphiti_core/utils/bulk_utils.py             | 186 +++++++++-----
 .../utils/maintenance/dedup_helpers.py        |   7 +-
 .../utils/maintenance/node_operations.py      |   5 +-
 tests/utils/maintenance/test_bulk_utils.py    | 232 ++++++++++++++++++
 .../utils/maintenance/test_node_operations.py |   4 +
 5 files changed, 371 insertions(+), 63 deletions(-)
 create mode 100644 tests/utils/maintenance/test_bulk_utils.py

diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index 181cda74..66a15bcb 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -43,8 +43,14 @@ from graphiti_core.models.nodes.node_db_queries import (
     get_entity_node_save_bulk_query,
     get_episode_node_save_bulk_query,
 )
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
+from graphiti_core.utils.maintenance.dedup_helpers import (
+    DedupResolutionState,
+    _build_candidate_indexes,
+    _normalize_string_exact,
+    _resolve_with_similarity,
+)
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edge,
@@ -63,6 +69,38 @@ logger = logging.getLogger(__name__)

 CHUNK_SIZE = 10

+
+def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
+    """Collapse alias -> canonical chains while preserving direction.
+
+    The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
+    union-find with iterative path compression to ensure every source UUID resolves to its ultimate
+    canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
+    """
+
+    parent: dict[str, str] = {}
+
+    def find(uuid: str) -> str:
+        """Directed union-find lookup using iterative path compression."""
+        parent.setdefault(uuid, uuid)
+        root = uuid
+        while parent[root] != root:
+            root = parent[root]
+
+        while parent[uuid] != root:
+            next_uuid = parent[uuid]
+            parent[uuid] = root
+            uuid = next_uuid
+
+        return root
+
+    for source_uuid, target_uuid in pairs:
+        parent.setdefault(source_uuid, source_uuid)
+        parent.setdefault(target_uuid, target_uuid)
+        parent[find(source_uuid)] = find(target_uuid)
+
+    return {uuid: find(uuid) for uuid in parent}
+
+
 class RawEpisode(BaseModel):
     name: str
     uuid: str | None = Field(default=None)
@@ -266,83 +304,111 @@ async def dedupe_nodes_bulk(
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
-    embedder = clients.embedder
-    min_score = 0.8
+    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.

-    # generate embeddings
-    await semaphore_gather(
-        *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
-    )
+    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
+       reconciled against the live graph just like the non-batch flow.
+    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
+       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
+       can apply to edges and persistence.
+    """

-    # Find similar results
-    dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
-    for i, nodes_i in enumerate(extracted_nodes):
-        existing_nodes: list[EntityNode] = []
-        for j, nodes_j in enumerate(extracted_nodes):
-            if i == j:
-                continue
-            existing_nodes += nodes_j
-
-        candidates_i: list[EntityNode] = []
-        for node in nodes_i:
-            for existing_node in existing_nodes:
-                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
-                # This approach will cast a wider net than BM25, which is ideal for this use case
-                node_words = set(node.name.lower().split())
-                existing_node_words = set(existing_node.name.lower().split())
-                has_overlap = not node_words.isdisjoint(existing_node_words)
-                if has_overlap:
-                    candidates_i.append(existing_node)
-                    continue
-
-                # Check for semantic similarity even if there is no overlap
-                similarity = np.dot(
-                    normalize_l2(node.name_embedding or []),
-                    normalize_l2(existing_node.name_embedding or []),
-                )
-                if similarity >= min_score:
-                    candidates_i.append(existing_node)
-
-        dedupe_tuples.append((nodes_i, candidates_i))
-
-    # Determine Node Resolutions
-    bulk_node_resolutions: list[
-        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
-    ] = await semaphore_gather(
+    first_pass_results = await semaphore_gather(
         *[
             resolve_extracted_nodes(
                 clients,
-                dedupe_tuple[0],
+                nodes,
                 episode_tuples[i][0],
                 episode_tuples[i][1],
                 entity_types,
-                existing_nodes_override=dedupe_tuples[i][1],
             )
-            for i, dedupe_tuple in enumerate(dedupe_tuples)
+            for i, nodes in enumerate(extracted_nodes)
         ]
     )

-    # Collect all duplicate pairs sorted by uuid
+    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
+    per_episode_uuid_maps: list[dict[str, str]] = []
     duplicate_pairs: list[tuple[str, str]] = []
-    for _, _, duplicates in bulk_node_resolutions:
-        for duplicate in duplicates:
-            n, m = duplicate
-            duplicate_pairs.append((n.uuid, m.uuid))

-    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
-    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
+    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
+        first_pass_results, episode_tuples, strict=True
+    ):
+        episode_resolutions.append((episode.uuid, resolved_nodes))
+        per_episode_uuid_maps.append(uuid_map)
+        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)

-    node_uuid_map: dict[str, EntityNode] = {
-        node.uuid: node for nodes in extracted_nodes for node in nodes
-    }
+    canonical_nodes: dict[str, EntityNode] = {}
+    for _, resolved_nodes in episode_resolutions:
+        for node in resolved_nodes:
+            # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
+            # the MinHash index for the accumulated canonical pool each time. The LRU-backed
+            # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
+            # but if batches grow significantly we should switch to an incremental index or chunked
+            # processing.
+            if not canonical_nodes:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            existing_candidates = list(canonical_nodes.values())
+            normalized = _normalize_string_exact(node.name)
+            exact_match = next(
+                (
+                    candidate
+                    for candidate in existing_candidates
+                    if _normalize_string_exact(candidate.name) == normalized
+                ),
+                None,
+            )
+            if exact_match is not None:
+                if exact_match.uuid != node.uuid:
+                    duplicate_pairs.append((node.uuid, exact_match.uuid))
+                continue
+
+            indexes = _build_candidate_indexes(existing_candidates)
+            state = DedupResolutionState(
+                resolved_nodes=[None],
+                uuid_map={},
+                unresolved_indices=[],
+            )
+            _resolve_with_similarity([node], indexes, state)
+
+            resolved = state.resolved_nodes[0]
+            if resolved is None:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            canonical_uuid = resolved.uuid
+            canonical_nodes.setdefault(canonical_uuid, resolved)
+            if canonical_uuid != node.uuid:
+                duplicate_pairs.append((node.uuid, canonical_uuid))
+
+    union_pairs: list[tuple[str, str]] = []
+    for uuid_map in per_episode_uuid_maps:
+        union_pairs.extend(uuid_map.items())
+    union_pairs.extend(duplicate_pairs)
+
+    compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)

     nodes_by_episode: dict[str, list[EntityNode]] = {}
-    for i, nodes in enumerate(extracted_nodes):
-        episode = episode_tuples[i][0]
+    for episode_uuid, resolved_nodes in episode_resolutions:
+        deduped_nodes: list[EntityNode] = []
+        seen: set[str] = set()
+        for node in resolved_nodes:
+            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
+            if canonical_uuid in seen:
+                continue
+            seen.add(canonical_uuid)
+            canonical_node = canonical_nodes.get(canonical_uuid)
+            if canonical_node is None:
+                logger.error(
+                    'Canonical node %s missing during batch dedupe; falling back to %s',
+                    canonical_uuid,
+                    node.uuid,
+                )
+                canonical_node = node
+            deduped_nodes.append(canonical_node)

-        nodes_by_episode[episode.uuid] = [
-            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
-        ]
+        nodes_by_episode[episode_uuid] = deduped_nodes

     return nodes_by_episode, compressed_map
diff --git a/graphiti_core/utils/maintenance/dedup_helpers.py b/graphiti_core/utils/maintenance/dedup_helpers.py
index 4916331e..b8ce68b8 100644
--- a/graphiti_core/utils/maintenance/dedup_helpers.py
+++ b/graphiti_core/utils/maintenance/dedup_helpers.py
@@ -20,7 +20,7 @@ import math
 import re
 from collections import defaultdict
 from collections.abc import Iterable
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from hashlib import blake2b
 from typing import TYPE_CHECKING
@@ -164,6 +164,7 @@ class DedupResolutionState:
     resolved_nodes: list[EntityNode | None]
     uuid_map: dict[str, str]
     unresolved_indices: list[int]
+    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)


 def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
@@ -213,6 +214,8 @@ def _resolve_with_similarity(
             match = existing_matches[0]
             state.resolved_nodes[idx] = match
             state.uuid_map[node.uuid] = match.uuid
+            if match.uuid != node.uuid:
+                state.duplicate_pairs.append((node, match))
             continue
         if len(existing_matches) > 1:
             state.unresolved_indices.append(idx)
@@ -236,6 +239,8 @@ def _resolve_with_similarity(
         if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
             state.resolved_nodes[idx] = best_candidate
             state.uuid_map[node.uuid] = best_candidate.uuid
+            if best_candidate.uuid != node.uuid:
+                state.duplicate_pairs.append((node, best_candidate))
             continue

         state.unresolved_indices.append(idx)
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 693609d8..26bc23d9 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -306,6 +306,8 @@ async def _resolve_with_llm(

         state.resolved_nodes[original_index] = resolved_node
         state.uuid_map[extracted_node.uuid] = resolved_node.uuid
+        if resolved_node.uuid != extracted_node.uuid:
+            state.duplicate_pairs.append((extracted_node, resolved_node))


 async def resolve_extracted_nodes(
@@ -332,7 +334,6 @@ async def resolve_extracted_nodes(
         uuid_map={},
         unresolved_indices=[],
     )
-    node_duplicates: list[tuple[EntityNode, EntityNode]] = []

     _resolve_with_similarity(extracted_nodes, indexes, state)

@@ -359,7 +360,7 @@ async def resolve_extracted_nodes(

     new_node_duplicates: list[
         tuple[EntityNode, EntityNode]
-    ] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
+    ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)

     return (
         [node for node in state.resolved_nodes if node is not None],
diff --git a/tests/utils/maintenance/test_bulk_utils.py b/tests/utils/maintenance/test_bulk_utils.py
new file mode 100644
index 00000000..9ba2c25d
--- /dev/null
+++ b/tests/utils/maintenance/test_bulk_utils.py
@@ -0,0 +1,232 @@
+from collections import deque
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from graphiti_core.edges import EntityEdge
+from graphiti_core.graphiti_types import GraphitiClients
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.utils import bulk_utils
+from graphiti_core.utils.datetime_utils import utc_now
+
+
+def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
+    return EpisodicNode(
+        name=f'episode-{uuid_suffix}',
+        group_id=group_id,
+        labels=[],
+        source=EpisodeType.message,
+        content='content',
+        source_description='test',
+        created_at=utc_now(),
+        valid_at=utc_now(),
+    )
+
+
+def _make_clients() -> GraphitiClients:
+    driver = MagicMock()
+    embedder = MagicMock()
+    cross_encoder = MagicMock()
+    llm_client = MagicMock()
+
+    return GraphitiClients.model_construct(  # bypass validation to allow test doubles
+        driver=driver,
+        embedder=embedder,
+        cross_encoder=cross_encoder,
+        llm_client=llm_client,
+        ensure_ascii=False,
+    )
+
+
+@pytest.mark.asyncio
+async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
+    clients = _make_clients()
+
+    episode_one = _make_episode('1')
+    episode_two = _make_episode('2')
+
+    extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
+    extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
+
+    canonical = extracted_one
+
+    call_queue = deque()
+
+    async def fake_resolve(
+        clients_arg,
+        nodes_arg,
+        episode_arg,
+        previous_episodes_arg,
+        entity_types_arg,
+        existing_nodes_override=None,
+    ):
+        call_queue.append(existing_nodes_override)
+
+        if nodes_arg == [extracted_one]:
+            return [canonical], {canonical.uuid: canonical.uuid}, []
+
+        assert nodes_arg == [extracted_two]
+        assert existing_nodes_override is None
+
+        return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]
+
+    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
+
+    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
+        clients,
+        [[extracted_one], [extracted_two]],
+        [(episode_one, []), (episode_two, [])],
+    )
+
+    assert len(call_queue) == 2
+    assert call_queue[0] is None
+    assert call_queue[1] is None
+
+    assert nodes_by_episode[episode_one.uuid] == [canonical]
+    assert nodes_by_episode[episode_two.uuid] == [canonical]
+    assert compressed_map.get(extracted_two.uuid) == canonical.uuid
+
+
+@pytest.mark.asyncio
+async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
+    clients = _make_clients()
+
+    resolve_mock = AsyncMock()
+    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
+
+    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
+        clients,
+        [],
+        [],
+    )
+
+    assert nodes_by_episode == {}
+    assert compressed_map == {}
+    resolve_mock.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
+    clients = _make_clients()
+
+    episode = _make_episode('solo')
+    extracted = EntityNode(name='Solo', group_id='group', labels=['Entity'])
+
+    resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: extracted.uuid}, []))
+    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
+
+    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
+        clients,
+        [[extracted]],
+        [(episode, [])],
+    )
+
+    assert nodes_by_episode == {episode.uuid: [extracted]}
+    assert compressed_map == {extracted.uuid: extracted.uuid}
+    resolve_mock.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
+    clients = _make_clients()
+
+    episode_one = _make_episode('one')
+    episode_two = _make_episode('two')
+
+    extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
+    extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])
+
+    canonical = extracted_one
+    alias = extracted_two
+
+    async def fake_resolve(
+        clients_arg,
+        nodes_arg,
+        episode_arg,
+        previous_episodes_arg,
+        entity_types_arg,
+        existing_nodes_override=None,
+    ):
+        if nodes_arg == [extracted_one]:
+            return [canonical], {canonical.uuid: canonical.uuid}, []
+        assert nodes_arg == [extracted_two]
+        return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]
+
+    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
+
+    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
+        clients,
+        [[extracted_one], [extracted_two]],
+        [(episode_one, []), (episode_two, [])],
+    )
+
+    assert nodes_by_episode[episode_one.uuid] == [canonical]
+    assert nodes_by_episode[episode_two.uuid] == [canonical]
+    assert compressed_map.get(alias.uuid) == canonical.uuid
+
+
+@pytest.mark.asyncio
+async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
+    clients = _make_clients()
+
+    episode = _make_episode('missing')
+    extracted = EntityNode(name='Fallback', group_id='group', labels=['Entity'])
+
+    resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: 'missing-canonical'}, []))
+    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
+
+    with caplog.at_level('WARNING'):
+        nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
+            clients,
+            [[extracted]],
+            [(episode, [])],
+        )
+
+    assert nodes_by_episode[episode.uuid] == [extracted]
+    assert compressed_map.get(extracted.uuid) == 'missing-canonical'
+    assert any('Canonical node missing' in rec.message for rec in caplog.records)
+
+
+def test_build_directed_uuid_map_empty():
+    assert bulk_utils._build_directed_uuid_map([]) == {}
+
+
+def test_build_directed_uuid_map_chain():
+    mapping = bulk_utils._build_directed_uuid_map(
+        [
+            ('a', 'b'),
+            ('b', 'c'),
+        ]
+    )
+
+    assert mapping['a'] == 'c'
+    assert mapping['b'] == 'c'
+    assert mapping['c'] == 'c'
+
+
+def test_build_directed_uuid_map_preserves_direction():
+    mapping = bulk_utils._build_directed_uuid_map(
+        [
+            ('alias', 'canonical'),
+        ]
+    )
+
+    assert mapping['alias'] == 'canonical'
+    assert mapping['canonical'] == 'canonical'
+
+
+def test_resolve_edge_pointers_updates_sources():
+    created_at = utc_now()
+    edge = EntityEdge(
+        name='knows',
+        fact='fact',
+        group_id='group',
+        source_node_uuid='alias',
+        target_node_uuid='target',
+        created_at=created_at,
+    )
+
+    bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})
+
+    assert edge.source_node_uuid == 'canonical'
+    assert edge.target_node_uuid == 'target'
diff --git a/tests/utils/maintenance/test_node_operations.py b/tests/utils/maintenance/test_node_operations.py
index a7250559..b2fc30b0 100644
--- a/tests/utils/maintenance/test_node_operations.py
+++ b/tests/utils/maintenance/test_node_operations.py
@@ -257,6 +257,7 @@ def test_resolve_with_similarity_exact_match_updates_state():
     assert state.resolved_nodes[0].uuid == candidate.uuid
     assert state.uuid_map[extracted.uuid] == candidate.uuid
     assert state.unresolved_indices == []
+    assert state.duplicate_pairs == [(extracted, candidate)]


 def test_resolve_with_similarity_low_entropy_defers_resolution():
@@ -274,6 +275,7 @@ def test_resolve_with_similarity_low_entropy_defers_resolution():

     assert state.resolved_nodes[0] is None
     assert state.unresolved_indices == [0]
+    assert state.duplicate_pairs == []


 def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
@@ -288,6 +290,7 @@ def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():

     assert state.resolved_nodes[0] is None
     assert state.unresolved_indices == [0]
+    assert state.duplicate_pairs == []


 @pytest.mark.asyncio
@@ -339,3 +342,4 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
     assert state.uuid_map[extracted.uuid] == candidate.uuid
     assert captured_context['existing_nodes'][0]['idx'] == 0
     assert isinstance(captured_context['existing_nodes'], list)
+    assert state.duplicate_pairs == [(extracted, candidate)]
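
A minimal standalone sketch of the directed union-find behaviour documented on _build_directed_uuid_map
in this patch. It assumes the patch is applied so the helper is importable; the UUID strings are
made-up placeholders, not values taken from the tests above.

    from graphiti_core.utils.bulk_utils import _build_directed_uuid_map

    # Directed pairs discovered during dedupe: alias -> canonical.
    pairs = [('a-uuid', 'b-uuid'), ('c-uuid', 'b-uuid'), ('d-uuid', 'c-uuid')]
    mapping = _build_directed_uuid_map(pairs)

    # Direction is preserved ('a-uuid' maps to 'b-uuid' even though it sorts first),
    # and the d-uuid -> c-uuid -> b-uuid chain collapses to the ultimate canonical target.
    assert mapping == {
        'a-uuid': 'b-uuid',
        'b-uuid': 'b-uuid',
        'c-uuid': 'b-uuid',
        'd-uuid': 'b-uuid',
    }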
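
The second pass emits a canonical UUID map that callers are expected to apply to edges before
persistence, per the new dedupe_nodes_bulk docstring. A hedged sketch of that caller-side flow
follows; dedupe_and_remap, its parameters, and extracted_edges are illustrative names introduced
here, not identifiers from this patch.

    from graphiti_core.edges import EntityEdge
    from graphiti_core.graphiti_types import GraphitiClients
    from graphiti_core.nodes import EntityNode, EpisodicNode
    from graphiti_core.utils.bulk_utils import dedupe_nodes_bulk, resolve_edge_pointers


    async def dedupe_and_remap(
        clients: GraphitiClients,
        extracted_nodes: list[list[EntityNode]],
        episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
        extracted_edges: list[list[EntityEdge]],
    ) -> dict[str, list[EntityNode]]:
        # Both passes happen inside dedupe_nodes_bulk: per-episode resolution against the
        # live graph, then batch-wide reconciliation that yields the canonical UUID map.
        nodes_by_episode, uuid_map = await dedupe_nodes_bulk(clients, extracted_nodes, episode_tuples)

        # Remap edge endpoints to canonical node UUIDs before saving the batch.
        for edges in extracted_edges:
            resolve_edge_pointers(edges, uuid_map)

        return nodes_by_episode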