diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 61176c12..691aa1a9 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -69,6 +69,37 @@ logger = logging.getLogger(__name__) CHUNK_SIZE = 10 +def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]: + """Collapse alias -> canonical chains while preserving direction. + + The incoming pairs represent directed mappings discovered during node dedupe. We use a simple + union-find with iterative path compression to ensure every source UUID resolves to its ultimate + canonical target, even if aliases appear lexicographically smaller than the canonical UUID. + """ + + parent: dict[str, str] = {} + + def find(uuid: str) -> str: + parent.setdefault(uuid, uuid) + root = uuid + while parent[root] != root: + root = parent[root] + + while parent[uuid] != root: + next_uuid = parent[uuid] + parent[uuid] = root + uuid = next_uuid + + return root + + for source_uuid, target_uuid in pairs: + parent.setdefault(source_uuid, source_uuid) + parent.setdefault(target_uuid, target_uuid) + parent[find(source_uuid)] = find(target_uuid) + + return {uuid: find(uuid) for uuid in parent} + + class RawEpisode(BaseModel): name: str uuid: str | None = Field(default=None) @@ -355,24 +386,7 @@ async def dedupe_nodes_bulk( union_pairs.extend(uuid_map.items()) union_pairs.extend(duplicate_pairs) - parent: dict[str, str] = {} - - def find(uuid: str) -> str: - """Directed union-find lookup so aliases always point to the true canonical UUID.""" - parent.setdefault(uuid, uuid) - if parent[uuid] != uuid: - parent[uuid] = find(parent[uuid]) - return parent[uuid] - - for source_uuid, target_uuid in union_pairs: - parent.setdefault(source_uuid, source_uuid) - parent.setdefault(target_uuid, target_uuid) - # Force the alias chain (source -> target) to collapse in the canonical direction. - root_target = find(target_uuid) - root_source = find(source_uuid) - parent[root_source] = root_target - - compressed_map: dict[str, str] = {uuid: find(uuid) for uuid in parent} + compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs) nodes_by_episode: dict[str, list[EntityNode]] = {} for episode_uuid, resolved_nodes in episode_resolutions: diff --git a/tests/utils/maintenance/test_bulk_utils.py b/tests/utils/maintenance/test_bulk_utils.py index 7c49906a..9ba2c25d 100644 --- a/tests/utils/maintenance/test_bulk_utils.py +++ b/tests/utils/maintenance/test_bulk_utils.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from graphiti_core.edges import EntityEdge from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.utils import bulk_utils @@ -184,3 +185,48 @@ async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplo assert nodes_by_episode[episode.uuid] == [extracted] assert compressed_map.get(extracted.uuid) == 'missing-canonical' assert any('Canonical node missing' in rec.message for rec in caplog.records) + + +def test_build_directed_uuid_map_empty(): + assert bulk_utils._build_directed_uuid_map([]) == {} + + +def test_build_directed_uuid_map_chain(): + mapping = bulk_utils._build_directed_uuid_map( + [ + ('a', 'b'), + ('b', 'c'), + ] + ) + + assert mapping['a'] == 'c' + assert mapping['b'] == 'c' + assert mapping['c'] == 'c' + + +def test_build_directed_uuid_map_preserves_direction(): + mapping = bulk_utils._build_directed_uuid_map( + [ + ('alias', 'canonical'), + ] + ) + + assert mapping['alias'] == 'canonical' + assert mapping['canonical'] == 'canonical' + + +def test_resolve_edge_pointers_updates_sources(): + created_at = utc_now() + edge = EntityEdge( + name='knows', + fact='fact', + group_id='group', + source_node_uuid='alias', + target_node_uuid='target', + created_at=created_at, + ) + + bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'}) + + assert edge.source_node_uuid == 'canonical' + assert edge.target_node_uuid == 'target'