implement _build_directed_uuid_map for efficient UUID resolution in bulk_utils
This commit is contained in:
parent
7a688ce924
commit
51855f91ae
2 changed files with 78 additions and 18 deletions
|
|
@ -69,6 +69,37 @@ logger = logging.getLogger(__name__)
|
|||
CHUNK_SIZE = 10
|
||||
|
||||
|
||||
def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
|
||||
"""Collapse alias -> canonical chains while preserving direction.
|
||||
|
||||
The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
|
||||
union-find with iterative path compression to ensure every source UUID resolves to its ultimate
|
||||
canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
|
||||
"""
|
||||
|
||||
parent: dict[str, str] = {}
|
||||
|
||||
def find(uuid: str) -> str:
|
||||
parent.setdefault(uuid, uuid)
|
||||
root = uuid
|
||||
while parent[root] != root:
|
||||
root = parent[root]
|
||||
|
||||
while parent[uuid] != root:
|
||||
next_uuid = parent[uuid]
|
||||
parent[uuid] = root
|
||||
uuid = next_uuid
|
||||
|
||||
return root
|
||||
|
||||
for source_uuid, target_uuid in pairs:
|
||||
parent.setdefault(source_uuid, source_uuid)
|
||||
parent.setdefault(target_uuid, target_uuid)
|
||||
parent[find(source_uuid)] = find(target_uuid)
|
||||
|
||||
return {uuid: find(uuid) for uuid in parent}
|
||||
|
||||
|
||||
class RawEpisode(BaseModel):
|
||||
name: str
|
||||
uuid: str | None = Field(default=None)
|
||||
|
|
@ -355,24 +386,7 @@ async def dedupe_nodes_bulk(
|
|||
union_pairs.extend(uuid_map.items())
|
||||
union_pairs.extend(duplicate_pairs)
|
||||
|
||||
parent: dict[str, str] = {}
|
||||
|
||||
def find(uuid: str) -> str:
|
||||
"""Directed union-find lookup so aliases always point to the true canonical UUID."""
|
||||
parent.setdefault(uuid, uuid)
|
||||
if parent[uuid] != uuid:
|
||||
parent[uuid] = find(parent[uuid])
|
||||
return parent[uuid]
|
||||
|
||||
for source_uuid, target_uuid in union_pairs:
|
||||
parent.setdefault(source_uuid, source_uuid)
|
||||
parent.setdefault(target_uuid, target_uuid)
|
||||
# Force the alias chain (source -> target) to collapse in the canonical direction.
|
||||
root_target = find(target_uuid)
|
||||
root_source = find(source_uuid)
|
||||
parent[root_source] = root_target
|
||||
|
||||
compressed_map: dict[str, str] = {uuid: find(uuid) for uuid in parent}
|
||||
compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
|
||||
|
||||
nodes_by_episode: dict[str, list[EntityNode]] = {}
|
||||
for episode_uuid, resolved_nodes in episode_resolutions:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.utils import bulk_utils
|
||||
|
|
@ -184,3 +185,48 @@ async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplo
|
|||
assert nodes_by_episode[episode.uuid] == [extracted]
|
||||
assert compressed_map.get(extracted.uuid) == 'missing-canonical'
|
||||
assert any('Canonical node missing' in rec.message for rec in caplog.records)
|
||||
|
||||
|
||||
def test_build_directed_uuid_map_empty():
|
||||
assert bulk_utils._build_directed_uuid_map([]) == {}
|
||||
|
||||
|
||||
def test_build_directed_uuid_map_chain():
|
||||
mapping = bulk_utils._build_directed_uuid_map(
|
||||
[
|
||||
('a', 'b'),
|
||||
('b', 'c'),
|
||||
]
|
||||
)
|
||||
|
||||
assert mapping['a'] == 'c'
|
||||
assert mapping['b'] == 'c'
|
||||
assert mapping['c'] == 'c'
|
||||
|
||||
|
||||
def test_build_directed_uuid_map_preserves_direction():
|
||||
mapping = bulk_utils._build_directed_uuid_map(
|
||||
[
|
||||
('alias', 'canonical'),
|
||||
]
|
||||
)
|
||||
|
||||
assert mapping['alias'] == 'canonical'
|
||||
assert mapping['canonical'] == 'canonical'
|
||||
|
||||
|
||||
def test_resolve_edge_pointers_updates_sources():
|
||||
created_at = utc_now()
|
||||
edge = EntityEdge(
|
||||
name='knows',
|
||||
fact='fact',
|
||||
group_id='group',
|
||||
source_node_uuid='alias',
|
||||
target_node_uuid='target',
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})
|
||||
|
||||
assert edge.source_node_uuid == 'canonical'
|
||||
assert edge.target_node_uuid == 'target'
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue