implement _build_directed_uuid_map for efficient UUID resolution in bulk_utils
This commit is contained in:
parent
7a688ce924
commit
51855f91ae
2 changed files with 78 additions and 18 deletions
|
|
@ -69,6 +69,37 @@ logger = logging.getLogger(__name__)
|
||||||
CHUNK_SIZE = 10
|
CHUNK_SIZE = 10
|
||||||
|
|
||||||
|
|
||||||
|
def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
|
||||||
|
"""Collapse alias -> canonical chains while preserving direction.
|
||||||
|
|
||||||
|
The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
|
||||||
|
union-find with iterative path compression to ensure every source UUID resolves to its ultimate
|
||||||
|
canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
|
||||||
|
"""
|
||||||
|
|
||||||
|
parent: dict[str, str] = {}
|
||||||
|
|
||||||
|
def find(uuid: str) -> str:
|
||||||
|
parent.setdefault(uuid, uuid)
|
||||||
|
root = uuid
|
||||||
|
while parent[root] != root:
|
||||||
|
root = parent[root]
|
||||||
|
|
||||||
|
while parent[uuid] != root:
|
||||||
|
next_uuid = parent[uuid]
|
||||||
|
parent[uuid] = root
|
||||||
|
uuid = next_uuid
|
||||||
|
|
||||||
|
return root
|
||||||
|
|
||||||
|
for source_uuid, target_uuid in pairs:
|
||||||
|
parent.setdefault(source_uuid, source_uuid)
|
||||||
|
parent.setdefault(target_uuid, target_uuid)
|
||||||
|
parent[find(source_uuid)] = find(target_uuid)
|
||||||
|
|
||||||
|
return {uuid: find(uuid) for uuid in parent}
|
||||||
|
|
||||||
|
|
||||||
class RawEpisode(BaseModel):
|
class RawEpisode(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
uuid: str | None = Field(default=None)
|
uuid: str | None = Field(default=None)
|
||||||
|
|
@ -355,24 +386,7 @@ async def dedupe_nodes_bulk(
|
||||||
union_pairs.extend(uuid_map.items())
|
union_pairs.extend(uuid_map.items())
|
||||||
union_pairs.extend(duplicate_pairs)
|
union_pairs.extend(duplicate_pairs)
|
||||||
|
|
||||||
parent: dict[str, str] = {}
|
compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
|
||||||
|
|
||||||
def find(uuid: str) -> str:
|
|
||||||
"""Directed union-find lookup so aliases always point to the true canonical UUID."""
|
|
||||||
parent.setdefault(uuid, uuid)
|
|
||||||
if parent[uuid] != uuid:
|
|
||||||
parent[uuid] = find(parent[uuid])
|
|
||||||
return parent[uuid]
|
|
||||||
|
|
||||||
for source_uuid, target_uuid in union_pairs:
|
|
||||||
parent.setdefault(source_uuid, source_uuid)
|
|
||||||
parent.setdefault(target_uuid, target_uuid)
|
|
||||||
# Force the alias chain (source -> target) to collapse in the canonical direction.
|
|
||||||
root_target = find(target_uuid)
|
|
||||||
root_source = find(source_uuid)
|
|
||||||
parent[root_source] = root_target
|
|
||||||
|
|
||||||
compressed_map: dict[str, str] = {uuid: find(uuid) for uuid in parent}
|
|
||||||
|
|
||||||
nodes_by_episode: dict[str, list[EntityNode]] = {}
|
nodes_by_episode: dict[str, list[EntityNode]] = {}
|
||||||
for episode_uuid, resolved_nodes in episode_resolutions:
|
for episode_uuid, resolved_nodes in episode_resolutions:
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.utils import bulk_utils
|
from graphiti_core.utils import bulk_utils
|
||||||
|
|
@ -184,3 +185,48 @@ async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplo
|
||||||
assert nodes_by_episode[episode.uuid] == [extracted]
|
assert nodes_by_episode[episode.uuid] == [extracted]
|
||||||
assert compressed_map.get(extracted.uuid) == 'missing-canonical'
|
assert compressed_map.get(extracted.uuid) == 'missing-canonical'
|
||||||
assert any('Canonical node missing' in rec.message for rec in caplog.records)
|
assert any('Canonical node missing' in rec.message for rec in caplog.records)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_directed_uuid_map_empty():
    """An empty list of pairs produces an empty mapping."""
    assert bulk_utils._build_directed_uuid_map([]) == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_directed_uuid_map_chain():
    """A two-hop chain a -> b -> c resolves every member to the final target ``c``."""
    mapping = bulk_utils._build_directed_uuid_map(
        [
            ('a', 'b'),
            ('b', 'c'),
        ]
    )

    assert mapping['a'] == 'c'
    assert mapping['b'] == 'c'
    # The end of the chain is its own canonical UUID.
    assert mapping['c'] == 'c'
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_directed_uuid_map_preserves_direction():
    """The source resolves to the target even though 'alias' sorts before 'canonical'."""
    mapping = bulk_utils._build_directed_uuid_map(
        [
            ('alias', 'canonical'),
        ]
    )

    assert mapping['alias'] == 'canonical'
    assert mapping['canonical'] == 'canonical'
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_edge_pointers_updates_sources():
    """A mapped source UUID is rewritten on the edge; an unmapped target UUID is kept."""
    created_at = utc_now()
    edge = EntityEdge(
        name='knows',
        fact='fact',
        group_id='group',
        source_node_uuid='alias',
        target_node_uuid='target',
        created_at=created_at,
    )

    # The return value is ignored: the assertions below inspect the edge object itself.
    bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})

    assert edge.source_node_uuid == 'canonical'
    assert edge.target_node_uuid == 'target'
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue