* Refactor deduplication logic to enhance node resolution and track duplicate pairs (#929) * Simplify deduplication process in bulk_utils by reusing canonical nodes. * Update dedup_helpers to store duplicate pairs during resolution. * Modify node_operations to append duplicate pairs when resolving nodes. * Add tests to verify deduplication behavior and ensure correct state updates. * revert to concurrent dedup with fanout and then reconciliation * add performance note for deduplication loop in bulk_utils * enhance deduplication logic in bulk_utils to handle missing canonical nodes gracefully * Update graphiti_core/utils/bulk_utils.py Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> * refactor deduplication logic in bulk_utils to use directed union-find for canonical UUID resolution * implement _build_directed_uuid_map for efficient UUID resolution in bulk_utils * document directed union-find lookup in bulk_utils for clarity --------- Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
232 lines
6.9 KiB
Python
232 lines
6.9 KiB
Python
from collections import deque
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from graphiti_core.edges import EntityEdge
|
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
|
from graphiti_core.utils import bulk_utils
|
|
from graphiti_core.utils.datetime_utils import utc_now
|
|
|
|
|
|
def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
    """Build a minimal message-type ``EpisodicNode`` fixture.

    The episode is named ``episode-<uuid_suffix>``, belongs to ``group_id``, and
    carries placeholder content with timestamps taken from ``utc_now()``.
    """
    episode_fields = {
        'name': f'episode-{uuid_suffix}',
        'group_id': group_id,
        'labels': [],
        'source': EpisodeType.message,
        'content': 'content',
        'source_description': 'test',
        'created_at': utc_now(),
        'valid_at': utc_now(),
    }
    return EpisodicNode(**episode_fields)
|
|
|
|
|
|
def _make_clients() -> GraphitiClients:
    """Return a ``GraphitiClients`` whose collaborators are all ``MagicMock`` doubles."""
    # model_construct bypasses validation so the test doubles are accepted as-is.
    return GraphitiClients.model_construct(
        driver=MagicMock(),
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        llm_client=MagicMock(),
        ensure_ascii=False,
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
    """Two episodes extracting the same entity should share one canonical node.

    ``resolve_extracted_nodes`` is stubbed: the first extraction resolves to itself,
    the second is reported as a duplicate of the first. Each stubbed call is expected
    to receive no ``existing_nodes_override``.
    """
    clients = _make_clients()
    first_episode = _make_episode('1')
    second_episode = _make_episode('2')

    first_node = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    second_node = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    canonical = first_node

    # Records the existing_nodes_override value seen by every stubbed call.
    seen_overrides = deque()

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        seen_overrides.append(existing_nodes_override)

        if nodes_arg == [first_node]:
            identity_map = {canonical.uuid: canonical.uuid}
            return [canonical], identity_map, []

        assert nodes_arg == [second_node]
        assert existing_nodes_override is None

        alias_map = {second_node.uuid: canonical.uuid}
        return [canonical], alias_map, [(second_node, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    extracted_batches = [[first_node], [second_node]]
    episode_tuples = [(first_episode, []), (second_episode, [])]
    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients, extracted_batches, episode_tuples
    )

    assert len(seen_overrides) == 2
    assert all(override is None for override in seen_overrides)

    for episode in (first_episode, second_episode):
        assert nodes_by_episode[episode.uuid] == [canonical]
    assert compressed_map.get(second_node.uuid) == canonical.uuid
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
    """With no extracted nodes and no episodes, nothing is resolved or returned."""
    clients = _make_clients()

    resolver_stub = AsyncMock()
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolver_stub)

    result = await bulk_utils.dedupe_nodes_bulk(clients, [], [])
    nodes_by_episode, compressed_map = result

    assert nodes_by_episode == {}
    assert compressed_map == {}
    # The resolver must never be awaited for an empty batch.
    resolver_stub.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
    """A lone episode with one self-resolving node passes through unchanged."""
    clients = _make_clients()
    solo_episode = _make_episode('solo')
    solo_node = EntityNode(name='Solo', group_id='group', labels=['Entity'])

    # Stubbed resolver result: the node maps to itself and there are no duplicate pairs.
    identity_map = {solo_node.uuid: solo_node.uuid}
    stub_result = ([solo_node], identity_map, [])
    resolver_stub = AsyncMock(return_value=stub_result)
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolver_stub)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients, [[solo_node]], [(solo_episode, [])]
    )

    assert nodes_by_episode == {solo_episode.uuid: [solo_node]}
    assert compressed_map == {solo_node.uuid: solo_node.uuid}
    resolver_stub.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
    """The alias must map to the canonical UUID, not the other way round.

    The canonical node deliberately has the UUID that sorts later ('b-uuid') and
    the alias the one that sorts earlier ('a-uuid').
    """
    clients = _make_clients()
    first_episode = _make_episode('one')
    second_episode = _make_episode('two')

    canonical = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
    alias = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        if nodes_arg == [canonical]:
            # First extraction resolves to itself with no duplicate pairs.
            return [canonical], {canonical.uuid: canonical.uuid}, []
        assert nodes_arg == [alias]
        # Second extraction is reported as a duplicate of the canonical node.
        return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    extracted_batches = [[canonical], [alias]]
    episode_tuples = [(first_episode, []), (second_episode, [])]
    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients, extracted_batches, episode_tuples
    )

    for episode in (first_episode, second_episode):
        assert nodes_by_episode[episode.uuid] == [canonical]
    assert compressed_map.get(alias.uuid) == canonical.uuid
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
    """A UUID map pointing at an unknown canonical keeps the extracted node and warns."""
    clients = _make_clients()
    lone_episode = _make_episode('missing')
    lone_node = EntityNode(name='Fallback', group_id='group', labels=['Entity'])

    # The stubbed resolver maps the node to a UUID that no returned node carries.
    dangling_map = {lone_node.uuid: 'missing-canonical'}
    resolver_stub = AsyncMock(return_value=([lone_node], dangling_map, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolver_stub)

    with caplog.at_level('WARNING'):
        nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
            clients, [[lone_node]], [(lone_episode, [])]
        )

    assert nodes_by_episode[lone_episode.uuid] == [lone_node]
    assert compressed_map.get(lone_node.uuid) == 'missing-canonical'

    warnings_seen = [rec for rec in caplog.records if 'Canonical node missing' in rec.message]
    assert warnings_seen
|
|
|
|
|
|
def test_build_directed_uuid_map_empty():
    """No pairs in should give an empty mapping out."""
    mapping = bulk_utils._build_directed_uuid_map([])
    assert mapping == {}
|
|
|
|
|
|
def test_build_directed_uuid_map_chain():
    """For the chain a -> b -> c, every member should resolve to the final target 'c'."""
    chained_pairs = [('a', 'b'), ('b', 'c')]
    mapping = bulk_utils._build_directed_uuid_map(chained_pairs)

    for member in ('a', 'b', 'c'):
        assert mapping[member] == 'c'
|
|
|
|
|
|
def test_build_directed_uuid_map_preserves_direction():
    """A single (alias, canonical) pair should map both keys to 'canonical'."""
    single_pair = [('alias', 'canonical')]
    mapping = bulk_utils._build_directed_uuid_map(single_pair)

    assert mapping['alias'] == 'canonical'
    # The canonical side must map to itself rather than being redirected to the alias.
    assert mapping['canonical'] == 'canonical'
|
|
|
|
|
|
def test_resolve_edge_pointers_updates_sources():
    """An aliased source UUID should be rewritten; an unmapped target should be left alone."""
    created_at = utc_now()
    edge_fields = {
        'name': 'knows',
        'fact': 'fact',
        'group_id': 'group',
        'source_node_uuid': 'alias',
        'target_node_uuid': 'target',
        'created_at': created_at,
    }
    edge = EntityEdge(**edge_fields)

    uuid_map = {'alias': 'canonical'}
    bulk_utils.resolve_edge_pointers([edge], uuid_map)

    # The assertions read the edge itself, so the helper is expected to update it in place.
    assert (edge.source_node_uuid, edge.target_node_uuid) == ('canonical', 'target')
|