refactor deduplication logic in bulk_utils to use directed union-find for canonical UUID resolution
This commit is contained in:
parent
2d695cc7f0
commit
7a688ce924
2 changed files with 118 additions and 8 deletions
|
|
@ -355,14 +355,24 @@ async def dedupe_nodes_bulk(
|
|||
union_pairs.extend(uuid_map.items())
|
||||
union_pairs.extend(duplicate_pairs)
|
||||
|
||||
compressed_map: dict[str, str] = compress_uuid_map(union_pairs)
|
||||
# We pass directed edges (extracted -> canonical) to the compressor, but the utility treats
|
||||
# them as undirected pairs and picks the lexicographically smaller UUID as the component root.
|
||||
# Re-write the entries using the original direction so that each source maps to the canonical
|
||||
# target returned by the first/second pass even if its UUID sorts before the canonical one.
|
||||
parent: dict[str, str] = {}
|
||||
|
||||
def find(uuid: str) -> str:
|
||||
"""Directed union-find lookup so aliases always point to the true canonical UUID."""
|
||||
parent.setdefault(uuid, uuid)
|
||||
if parent[uuid] != uuid:
|
||||
parent[uuid] = find(parent[uuid])
|
||||
return parent[uuid]
|
||||
|
||||
for source_uuid, target_uuid in union_pairs:
|
||||
canonical_uuid = compressed_map.get(target_uuid, target_uuid)
|
||||
compressed_map[source_uuid] = canonical_uuid
|
||||
parent.setdefault(source_uuid, source_uuid)
|
||||
parent.setdefault(target_uuid, target_uuid)
|
||||
# Force the alias chain (source -> target) to collapse in the canonical direction.
|
||||
root_target = find(target_uuid)
|
||||
root_source = find(source_uuid)
|
||||
parent[root_source] = root_target
|
||||
|
||||
compressed_map: dict[str, str] = {uuid: find(uuid) for uuid in parent}
|
||||
|
||||
nodes_by_episode: dict[str, list[EntityNode]] = {}
|
||||
for episode_uuid, resolved_nodes in episode_resolutions:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections import deque
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -84,3 +84,103 @@ async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
|
|||
assert nodes_by_episode[episode_one.uuid] == [canonical]
|
||||
assert nodes_by_episode[episode_two.uuid] == [canonical]
|
||||
assert compressed_map.get(extracted_two.uuid) == canonical.uuid
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
|
||||
clients = _make_clients()
|
||||
|
||||
resolve_mock = AsyncMock()
|
||||
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
|
||||
|
||||
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
|
||||
clients,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
assert nodes_by_episode == {}
|
||||
assert compressed_map == {}
|
||||
resolve_mock.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
|
||||
clients = _make_clients()
|
||||
|
||||
episode = _make_episode('solo')
|
||||
extracted = EntityNode(name='Solo', group_id='group', labels=['Entity'])
|
||||
|
||||
resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: extracted.uuid}, []))
|
||||
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
|
||||
|
||||
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
|
||||
clients,
|
||||
[[extracted]],
|
||||
[(episode, [])],
|
||||
)
|
||||
|
||||
assert nodes_by_episode == {episode.uuid: [extracted]}
|
||||
assert compressed_map == {extracted.uuid: extracted.uuid}
|
||||
resolve_mock.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
|
||||
clients = _make_clients()
|
||||
|
||||
episode_one = _make_episode('one')
|
||||
episode_two = _make_episode('two')
|
||||
|
||||
extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
|
||||
extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])
|
||||
|
||||
canonical = extracted_one
|
||||
alias = extracted_two
|
||||
|
||||
async def fake_resolve(
|
||||
clients_arg,
|
||||
nodes_arg,
|
||||
episode_arg,
|
||||
previous_episodes_arg,
|
||||
entity_types_arg,
|
||||
existing_nodes_override=None,
|
||||
):
|
||||
if nodes_arg == [extracted_one]:
|
||||
return [canonical], {canonical.uuid: canonical.uuid}, []
|
||||
assert nodes_arg == [extracted_two]
|
||||
return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]
|
||||
|
||||
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
|
||||
|
||||
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
|
||||
clients,
|
||||
[[extracted_one], [extracted_two]],
|
||||
[(episode_one, []), (episode_two, [])],
|
||||
)
|
||||
|
||||
assert nodes_by_episode[episode_one.uuid] == [canonical]
|
||||
assert nodes_by_episode[episode_two.uuid] == [canonical]
|
||||
assert compressed_map.get(alias.uuid) == canonical.uuid
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
|
||||
clients = _make_clients()
|
||||
|
||||
episode = _make_episode('missing')
|
||||
extracted = EntityNode(name='Fallback', group_id='group', labels=['Entity'])
|
||||
|
||||
resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: 'missing-canonical'}, []))
|
||||
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
|
||||
|
||||
with caplog.at_level('WARNING'):
|
||||
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
|
||||
clients,
|
||||
[[extracted]],
|
||||
[(episode, [])],
|
||||
)
|
||||
|
||||
assert nodes_by_episode[episode.uuid] == [extracted]
|
||||
assert compressed_map.get(extracted.uuid) == 'missing-canonical'
|
||||
assert any('Canonical node missing' in rec.message for rec in caplog.records)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue