implement _build_directed_uuid_map for efficient UUID resolution in bulk_utils

This commit is contained in:
Daniel Chalef 2025-09-26 07:16:42 -07:00
parent 7a688ce924
commit 51855f91ae
2 changed files with 78 additions and 18 deletions

View file

@ -69,6 +69,37 @@ logger = logging.getLogger(__name__)
CHUNK_SIZE = 10
def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
"""Collapse alias -> canonical chains while preserving direction.
The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
union-find with iterative path compression to ensure every source UUID resolves to its ultimate
canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
"""
parent: dict[str, str] = {}
def find(uuid: str) -> str:
parent.setdefault(uuid, uuid)
root = uuid
while parent[root] != root:
root = parent[root]
while parent[uuid] != root:
next_uuid = parent[uuid]
parent[uuid] = root
uuid = next_uuid
return root
for source_uuid, target_uuid in pairs:
parent.setdefault(source_uuid, source_uuid)
parent.setdefault(target_uuid, target_uuid)
parent[find(source_uuid)] = find(target_uuid)
return {uuid: find(uuid) for uuid in parent}
class RawEpisode(BaseModel):
name: str
uuid: str | None = Field(default=None)
@ -355,24 +386,7 @@ async def dedupe_nodes_bulk(
union_pairs.extend(uuid_map.items())
union_pairs.extend(duplicate_pairs)
parent: dict[str, str] = {}
def find(uuid: str) -> str:
"""Directed union-find lookup so aliases always point to the true canonical UUID."""
parent.setdefault(uuid, uuid)
if parent[uuid] != uuid:
parent[uuid] = find(parent[uuid])
return parent[uuid]
for source_uuid, target_uuid in union_pairs:
parent.setdefault(source_uuid, source_uuid)
parent.setdefault(target_uuid, target_uuid)
# Force the alias chain (source -> target) to collapse in the canonical direction.
root_target = find(target_uuid)
root_source = find(source_uuid)
parent[root_source] = root_target
compressed_map: dict[str, str] = {uuid: find(uuid) for uuid in parent}
compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
nodes_by_episode: dict[str, list[EntityNode]] = {}
for episode_uuid, resolved_nodes in episode_resolutions:

View file

@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils import bulk_utils
@ -184,3 +185,48 @@ async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplo
assert nodes_by_episode[episode.uuid] == [extracted]
assert compressed_map.get(extracted.uuid) == 'missing-canonical'
assert any('Canonical node missing' in rec.message for rec in caplog.records)
def test_build_directed_uuid_map_empty():
assert bulk_utils._build_directed_uuid_map([]) == {}
def test_build_directed_uuid_map_chain():
mapping = bulk_utils._build_directed_uuid_map(
[
('a', 'b'),
('b', 'c'),
]
)
assert mapping['a'] == 'c'
assert mapping['b'] == 'c'
assert mapping['c'] == 'c'
def test_build_directed_uuid_map_preserves_direction():
mapping = bulk_utils._build_directed_uuid_map(
[
('alias', 'canonical'),
]
)
assert mapping['alias'] == 'canonical'
assert mapping['canonical'] == 'canonical'
def test_resolve_edge_pointers_updates_sources():
created_at = utc_now()
edge = EntityEdge(
name='knows',
fact='fact',
group_id='group',
source_node_uuid='alias',
target_node_uuid='target',
created_at=created_at,
)
bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})
assert edge.source_node_uuid == 'canonical'
assert edge.target_node_uuid == 'target'