refactor deduplication logic in bulk_utils to use directed union-find for canonical UUID resolution
This commit is contained in:
parent
2d695cc7f0
commit
7a688ce924
2 changed files with 118 additions and 8 deletions
|
|
@ -355,14 +355,24 @@ async def dedupe_nodes_bulk(
|
||||||
union_pairs.extend(uuid_map.items())
|
union_pairs.extend(uuid_map.items())
|
||||||
union_pairs.extend(duplicate_pairs)
|
union_pairs.extend(duplicate_pairs)
|
||||||
|
|
||||||
compressed_map: dict[str, str] = compress_uuid_map(union_pairs)
|
parent: dict[str, str] = {}
|
||||||
# We pass directed edges (extracted -> canonical) to the compressor, but the utility treats
|
|
||||||
# them as undirected pairs and picks the lexicographically smaller UUID as the component root.
|
def find(uuid: str) -> str:
|
||||||
# Re-write the entries using the original direction so that each source maps to the canonical
|
"""Directed union-find lookup so aliases always point to the true canonical UUID."""
|
||||||
# target returned by the first/second pass even if its UUID sorts before the canonical one.
|
parent.setdefault(uuid, uuid)
|
||||||
|
if parent[uuid] != uuid:
|
||||||
|
parent[uuid] = find(parent[uuid])
|
||||||
|
return parent[uuid]
|
||||||
|
|
||||||
for source_uuid, target_uuid in union_pairs:
|
for source_uuid, target_uuid in union_pairs:
|
||||||
canonical_uuid = compressed_map.get(target_uuid, target_uuid)
|
parent.setdefault(source_uuid, source_uuid)
|
||||||
compressed_map[source_uuid] = canonical_uuid
|
parent.setdefault(target_uuid, target_uuid)
|
||||||
|
# Force the alias chain (source -> target) to collapse in the canonical direction.
|
||||||
|
root_target = find(target_uuid)
|
||||||
|
root_source = find(source_uuid)
|
||||||
|
parent[root_source] = root_target
|
||||||
|
|
||||||
|
compressed_map: dict[str, str] = {uuid: find(uuid) for uuid in parent}
|
||||||
|
|
||||||
nodes_by_episode: dict[str, list[EntityNode]] = {}
|
nodes_by_episode: dict[str, list[EntityNode]] = {}
|
||||||
for episode_uuid, resolved_nodes in episode_resolutions:
|
for episode_uuid, resolved_nodes in episode_resolutions:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -84,3 +84,103 @@ async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
|
||||||
assert nodes_by_episode[episode_one.uuid] == [canonical]
|
assert nodes_by_episode[episode_one.uuid] == [canonical]
|
||||||
assert nodes_by_episode[episode_two.uuid] == [canonical]
|
assert nodes_by_episode[episode_two.uuid] == [canonical]
|
||||||
assert compressed_map.get(extracted_two.uuid) == canonical.uuid
|
assert compressed_map.get(extracted_two.uuid) == canonical.uuid
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
    """An empty batch must short-circuit: empty outputs and no resolver awaits."""
    clients = _make_clients()

    resolve_mock = AsyncMock()
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    result = await bulk_utils.dedupe_nodes_bulk(clients, [], [])
    nodes_by_episode, compressed_map = result

    # Nothing to dedupe: both outputs are empty and the resolver was never called.
    assert nodes_by_episode == {}
    assert compressed_map == {}
    resolve_mock.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
    """A single episode with one extracted node passes through unchanged."""
    clients = _make_clients()
    episode = _make_episode('solo')
    extracted = EntityNode(name='Solo', group_id='group', labels=['Entity'])

    # The resolver reports the node as its own canonical (identity mapping).
    identity_result = ([extracted], {extracted.uuid: extracted.uuid}, [])
    resolve_mock = AsyncMock(return_value=identity_result)
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients, [[extracted]], [(episode, [])]
    )

    assert nodes_by_episode == {episode.uuid: [extracted]}
    assert compressed_map == {extracted.uuid: extracted.uuid}
    resolve_mock.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
    """Regression test: the alias -> canonical direction must be preserved.

    The canonical node deliberately gets uuid 'b-uuid' and the alias 'a-uuid',
    so a compressor that picks the lexicographically smaller UUID as the
    component root would invert the mapping. The directed union-find must
    still map alias ('a-uuid') -> canonical ('b-uuid').
    """
    clients = _make_clients()

    episode_one = _make_episode('one')
    episode_two = _make_episode('two')

    # uuid ordering is the crux: canonical sorts AFTER its alias.
    extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
    extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])

    canonical = extracted_one
    alias = extracted_two

    # Stub with the same call shape bulk_utils uses for resolve_extracted_nodes.
    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        # First episode: the canonical node resolves to itself.
        if nodes_arg == [extracted_one]:
            return [canonical], {canonical.uuid: canonical.uuid}, []
        # Second episode: the alias resolves to the canonical node.
        assert nodes_arg == [extracted_two]
        return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted_one], [extracted_two]],
        [(episode_one, []), (episode_two, [])],
    )

    # Both episodes end up on the canonical node, and the map keeps the
    # resolver's direction despite 'a-uuid' < 'b-uuid'.
    assert nodes_by_episode[episode_one.uuid] == [canonical]
    assert nodes_by_episode[episode_two.uuid] == [canonical]
    assert compressed_map.get(alias.uuid) == canonical.uuid
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
    """When the mapped canonical UUID has no node, keep the extracted node and warn."""
    clients = _make_clients()
    episode = _make_episode('missing')
    extracted = EntityNode(name='Fallback', group_id='group', labels=['Entity'])

    # The resolver points at a canonical UUID that no known node carries.
    resolve_mock = AsyncMock(
        return_value=([extracted], {extracted.uuid: 'missing-canonical'}, [])
    )
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    with caplog.at_level('WARNING'):
        nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
            clients, [[extracted]], [(episode, [])]
        )

    # The episode falls back to the extracted node, the map still records the
    # (dangling) canonical UUID, and a warning was logged about it.
    assert nodes_by_episode[episode.uuid] == [extracted]
    assert compressed_map.get(extracted.uuid) == 'missing-canonical'
    warning_messages = [record.message for record in caplog.records]
    assert any('Canonical node missing' in message for message in warning_messages)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue