graphiti/tests/utils/maintenance/test_bulk_utils.py
Daniel Chalef 9aee3174bd
Refactor batch deduplication logic to enhance node resolution and track duplicate pairs (#929) (#936)
* Refactor deduplication logic to enhance node resolution and track duplicate pairs (#929)

* Simplify deduplication process in bulk_utils by reusing canonical nodes.
* Update dedup_helpers to store duplicate pairs during resolution.
* Modify node_operations to append duplicate pairs when resolving nodes.
* Add tests to verify deduplication behavior and ensure correct state updates.

* revert to concurrent dedup with fanout and then reconciliation

* add performance note for deduplication loop in bulk_utils

* enhance deduplication logic in bulk_utils to handle missing canonical nodes gracefully

* Update graphiti_core/utils/bulk_utils.py

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>

* refactor deduplication logic in bulk_utils to use directed union-find for canonical UUID resolution

* implement _build_directed_uuid_map for efficient UUID resolution in bulk_utils

* document directed union-find lookup in bulk_utils for clarity

---------

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
2025-09-26 08:40:18 -07:00

232 lines
6.9 KiB
Python

from collections import deque
from unittest.mock import AsyncMock, MagicMock
import pytest
from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils import bulk_utils
from graphiti_core.utils.datetime_utils import utc_now
def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
    """Build a minimal message-type episode to use as a test fixture.

    Args:
        uuid_suffix: Appended to ``'episode-'`` to form the episode name. No
            explicit ``uuid`` is passed to the node.
        group_id: Group the episode belongs to.

    Returns:
        An ``EpisodicNode`` with placeholder content and fresh timestamps.
    """
    episode_fields = {
        'name': f'episode-{uuid_suffix}',
        'group_id': group_id,
        'labels': [],
        'source': EpisodeType.message,
        'content': 'content',
        'source_description': 'test',
        'created_at': utc_now(),
        'valid_at': utc_now(),
    }
    return EpisodicNode(**episode_fields)
def _make_clients() -> GraphitiClients:
    """Return a ``GraphitiClients`` bundle whose collaborators are all mocks."""
    doubles = {
        field_name: MagicMock()
        for field_name in ('driver', 'embedder', 'cross_encoder', 'llm_client')
    }
    # model_construct bypasses validation so the test doubles are accepted as-is
    return GraphitiClients.model_construct(**doubles, ensure_ascii=False)
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
    """Both episodes resolve to the same canonical node and the alias is mapped to it."""
    clients = _make_clients()
    first_episode, second_episode = _make_episode('1'), _make_episode('2')

    first_node = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    second_node = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    canonical = first_node

    # Records the existing_nodes_override value seen on each resolver call.
    seen_overrides = deque()

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        seen_overrides.append(existing_nodes_override)

        if nodes_arg == [first_node]:
            # First episode: the node is its own canonical, no duplicates reported.
            self_map = {canonical.uuid: canonical.uuid}
            return [canonical], self_map, []

        # Second episode: report the extracted node as a duplicate of the canonical.
        assert nodes_arg == [second_node]
        assert existing_nodes_override is None
        alias_map = {second_node.uuid: canonical.uuid}
        return [canonical], alias_map, [(second_node, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    extracted_batches = [[first_node], [second_node]]
    episode_tuples = [(first_episode, []), (second_episode, [])]
    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients, extracted_batches, episode_tuples
    )

    assert len(seen_overrides) == 2
    assert seen_overrides[0] is None
    assert seen_overrides[1] is None
    assert nodes_by_episode[first_episode.uuid] == [canonical]
    assert nodes_by_episode[second_episode.uuid] == [canonical]
    assert compressed_map.get(second_node.uuid) == canonical.uuid
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
    """An empty batch yields empty results and never awaits the resolver."""
    clients = _make_clients()
    resolver = AsyncMock()
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolver)

    no_nodes, no_episodes = [], []
    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients, no_nodes, no_episodes
    )

    assert nodes_by_episode == {}
    assert compressed_map == {}
    resolver.assert_not_awaited()
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
    """A single episode with one node passes through and the resolver is awaited once."""
    clients = _make_clients()
    episode = _make_episode('solo')
    node = EntityNode(name='Solo', group_id='group', labels=['Entity'])

    # The resolver reports the node as its own canonical with no duplicate pairs.
    resolver_result = ([node], {node.uuid: node.uuid}, [])
    resolver = AsyncMock(return_value=resolver_result)
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolver)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients, [[node]], [(episode, [])]
    )

    assert nodes_by_episode == {episode.uuid: [node]}
    assert compressed_map == {node.uuid: node.uuid}
    resolver.assert_awaited_once()
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
    """The alias UUID must map to the canonical UUID reported by the resolver."""
    clients = _make_clients()
    first_episode, second_episode = _make_episode('one'), _make_episode('two')

    # NOTE(review): the canonical gets 'b-uuid' and the alias 'a-uuid', presumably so an
    # ordering-based merge would pick the wrong winner -- confirm against bulk_utils.
    canonical = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
    alias = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        if nodes_arg == [canonical]:
            self_map = {canonical.uuid: canonical.uuid}
            return [canonical], self_map, []

        assert nodes_arg == [alias]
        alias_map = {alias.uuid: canonical.uuid}
        return [canonical], alias_map, [(alias, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    extracted_batches = [[canonical], [alias]]
    episode_tuples = [(first_episode, []), (second_episode, [])]
    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients, extracted_batches, episode_tuples
    )

    assert nodes_by_episode[first_episode.uuid] == [canonical]
    assert nodes_by_episode[second_episode.uuid] == [canonical]
    assert compressed_map.get(alias.uuid) == canonical.uuid
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
    """If the mapped canonical UUID is unknown, the extracted node is kept and a warning logged."""
    clients = _make_clients()
    episode = _make_episode('missing')
    node = EntityNode(name='Fallback', group_id='group', labels=['Entity'])

    # The resolver maps the node to a UUID that no returned node carries.
    dangling_map = {node.uuid: 'missing-canonical'}
    resolver = AsyncMock(return_value=([node], dangling_map, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolver)

    with caplog.at_level('WARNING'):
        nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
            clients, [[node]], [(episode, [])]
        )

    assert nodes_by_episode[episode.uuid] == [node]
    assert compressed_map.get(node.uuid) == 'missing-canonical'
    assert any('Canonical node missing' in record.message for record in caplog.records)
def test_build_directed_uuid_map_empty():
    """No input pairs produce an empty mapping."""
    no_pairs = []
    mapping = bulk_utils._build_directed_uuid_map(no_pairs)
    assert mapping == {}
def test_build_directed_uuid_map_chain():
    """A chain a -> b -> c collapses so every member resolves to 'c'."""
    chained_pairs = [('a', 'b'), ('b', 'c')]
    mapping = bulk_utils._build_directed_uuid_map(chained_pairs)

    for member in ('a', 'b', 'c'):
        assert mapping[member] == 'c'
def test_build_directed_uuid_map_preserves_direction():
    """A single (alias, canonical) pair maps both keys to 'canonical'."""
    single_pair = [('alias', 'canonical')]
    mapping = bulk_utils._build_directed_uuid_map(single_pair)

    for key in ('alias', 'canonical'):
        assert mapping[key] == 'canonical'
def test_resolve_edge_pointers_updates_sources():
    """An aliased source UUID is rewritten in place; the unmapped target is left alone."""
    edge_fields = {
        'name': 'knows',
        'fact': 'fact',
        'group_id': 'group',
        'source_node_uuid': 'alias',
        'target_node_uuid': 'target',
        'created_at': utc_now(),
    }
    edge = EntityEdge(**edge_fields)
    uuid_map = {'alias': 'canonical'}

    bulk_utils.resolve_edge_pointers([edge], uuid_map)

    assert edge.source_node_uuid == 'canonical'
    assert edge.target_node_uuid == 'target'