From 7a688ce9240acac1bb8e621073d196467551c73a Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Thu, 25 Sep 2025 22:50:21 -0700 Subject: [PATCH] refactor deduplication logic in bulk_utils to use directed union-find for canonical UUID resolution --- graphiti_core/utils/bulk_utils.py | 24 +++-- tests/utils/maintenance/test_bulk_utils.py | 102 ++++++++++++++++++++- 2 files changed, 118 insertions(+), 8 deletions(-) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index a611808d..61176c12 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -355,14 +355,24 @@ async def dedupe_nodes_bulk( union_pairs.extend(uuid_map.items()) union_pairs.extend(duplicate_pairs) - compressed_map: dict[str, str] = compress_uuid_map(union_pairs) - # We pass directed edges (extracted -> canonical) to the compressor, but the utility treats - # them as undirected pairs and picks the lexicographically smaller UUID as the component root. - # Re-write the entries using the original direction so that each source maps to the canonical - # target returned by the first/second pass even if its UUID sorts before the canonical one. + parent: dict[str, str] = {} + + def find(uuid: str) -> str: + """Directed union-find lookup so aliases always point to the true canonical UUID.""" + parent.setdefault(uuid, uuid) + if parent[uuid] != uuid: + parent[uuid] = find(parent[uuid]) + return parent[uuid] + for source_uuid, target_uuid in union_pairs: - canonical_uuid = compressed_map.get(target_uuid, target_uuid) - compressed_map[source_uuid] = canonical_uuid + parent.setdefault(source_uuid, source_uuid) + parent.setdefault(target_uuid, target_uuid) + # Force the alias chain (source -> target) to collapse in the canonical direction. + root_target = find(target_uuid) + root_source = find(source_uuid) + parent[root_source] = root_target + + compressed_map: dict[str, str] = {uuid: find(uuid) for uuid in parent} nodes_by_episode: dict[str, list[EntityNode]] = {} for episode_uuid, resolved_nodes in episode_resolutions: diff --git a/tests/utils/maintenance/test_bulk_utils.py b/tests/utils/maintenance/test_bulk_utils.py index f803d471..7c49906a 100644 --- a/tests/utils/maintenance/test_bulk_utils.py +++ b/tests/utils/maintenance/test_bulk_utils.py @@ -1,5 +1,5 @@ from collections import deque -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest @@ -84,3 +84,103 @@ async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch): assert nodes_by_episode[episode_one.uuid] == [canonical] assert nodes_by_episode[episode_two.uuid] == [canonical] assert compressed_map.get(extracted_two.uuid) == canonical.uuid + + +@pytest.mark.asyncio +async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch): + clients = _make_clients() + + resolve_mock = AsyncMock() + monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock) + + nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk( + clients, + [], + [], + ) + + assert nodes_by_episode == {} + assert compressed_map == {} + resolve_mock.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_dedupe_nodes_bulk_single_episode(monkeypatch): + clients = _make_clients() + + episode = _make_episode('solo') + extracted = EntityNode(name='Solo', group_id='group', labels=['Entity']) + + resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: extracted.uuid}, [])) + monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock) + + nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk( + clients, + [[extracted]], + [(episode, [])], + ) + + assert nodes_by_episode == {episode.uuid: [extracted]} + assert compressed_map == {extracted.uuid: extracted.uuid} + resolve_mock.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch): + clients = _make_clients() + + episode_one = _make_episode('one') + episode_two = _make_episode('two') + + extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity']) + extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity']) + + canonical = extracted_one + alias = extracted_two + + async def fake_resolve( + clients_arg, + nodes_arg, + episode_arg, + previous_episodes_arg, + entity_types_arg, + existing_nodes_override=None, + ): + if nodes_arg == [extracted_one]: + return [canonical], {canonical.uuid: canonical.uuid}, [] + assert nodes_arg == [extracted_two] + return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)] + + monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve) + + nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk( + clients, + [[extracted_one], [extracted_two]], + [(episode_one, []), (episode_two, [])], + ) + + assert nodes_by_episode[episode_one.uuid] == [canonical] + assert nodes_by_episode[episode_two.uuid] == [canonical] + assert compressed_map.get(alias.uuid) == canonical.uuid + + +@pytest.mark.asyncio +async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog): + clients = _make_clients() + + episode = _make_episode('missing') + extracted = EntityNode(name='Fallback', group_id='group', labels=['Entity']) + + resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: 'missing-canonical'}, [])) + monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock) + + with caplog.at_level('WARNING'): + nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk( + clients, + [[extracted]], + [(episode, [])], + ) + + assert nodes_by_episode[episode.uuid] == [extracted] + assert compressed_map.get(extracted.uuid) == 'missing-canonical' + assert any('Canonical node missing' in rec.message for rec in caplog.records)