from collections import deque
from unittest.mock import AsyncMock, MagicMock

import pytest

from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils import bulk_utils
from graphiti_core.utils.datetime_utils import utc_now


def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
    """Build a minimal message episode whose name embeds ``uuid_suffix``.

    ``created_at`` and ``valid_at`` are each stamped with ``utc_now()``.
    """
    episode_name = f'episode-{uuid_suffix}'
    return EpisodicNode(
        name=episode_name,
        group_id=group_id,
        labels=[],
        source=EpisodeType.message,
        content='content',
        source_description='test',
        created_at=utc_now(),
        valid_at=utc_now(),
    )


def _make_clients() -> GraphitiClients:
    """Return a ``GraphitiClients`` whose driver, embedder, cross encoder and LLM are ``MagicMock``s.

    ``model_construct`` skips validation so the test doubles are accepted as-is.
    """
    return GraphitiClients.model_construct(
        driver=MagicMock(),
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        llm_client=MagicMock(),
        ensure_ascii=False,
    )


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
    """Duplicate nodes from two episodes both resolve to one canonical node.

    The resolver must be invoked once per episode, never with an
    ``existing_nodes_override``.
    """
    clients = _make_clients()
    first_episode = _make_episode('1')
    second_episode = _make_episode('2')
    first_node = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    second_node = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    canonical = first_node

    # Records the override argument of every resolver call, in call order.
    overrides_seen = deque()

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        overrides_seen.append(existing_nodes_override)
        if nodes_arg == [first_node]:
            # First episode: the node is its own canonical node, no duplicates.
            return [canonical], {canonical.uuid: canonical.uuid}, []

        # Second episode: report the second node as a duplicate of the canonical one.
        assert nodes_arg == [second_node]
        assert existing_nodes_override is None
        duplicate_pairs = [(second_node, canonical)]
        return [canonical], {second_node.uuid: canonical.uuid}, duplicate_pairs

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[first_node], [second_node]],
        [(first_episode, []), (second_episode, [])],
    )

    assert len(overrides_seen) == 2
    assert all(override is None for override in overrides_seen)
    for episode in (first_episode, second_episode):
        assert nodes_by_episode[episode.uuid] == [canonical]
    assert compressed_map.get(second_node.uuid) == canonical.uuid


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
    """An empty batch returns two empty mappings and never awaits the resolver."""
    resolver = AsyncMock()
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolver)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        _make_clients(),
        [],
        [],
    )

    assert nodes_by_episode == {}
    assert compressed_map == {}
    resolver.assert_not_awaited()


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
    """A single episode with one node maps the node to itself after one resolver call."""
    clients = _make_clients()
    episode = _make_episode('solo')
    solo_node = EntityNode(name='Solo', group_id='group', labels=['Entity'])
    resolver = AsyncMock(return_value=([solo_node], {solo_node.uuid: solo_node.uuid}, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolver)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[solo_node]],
        [(episode, [])],
    )

    assert nodes_by_episode == {episode.uuid: [solo_node]}
    assert compressed_map == {solo_node.uuid: solo_node.uuid}
    resolver.assert_awaited_once()


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
    """The alias UUID must map onto the canonical UUID, not the reverse.

    The alias uuid ('a-uuid') deliberately sorts before the canonical uuid
    ('b-uuid'), so the mapping direction cannot come from ordering alone.
    """
    clients = _make_clients()
    first_episode = _make_episode('one')
    second_episode = _make_episode('two')
    first_node = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
    second_node = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])
    canonical = first_node
    alias = second_node

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        if nodes_arg == [first_node]:
            return [canonical], {canonical.uuid: canonical.uuid}, []

        assert nodes_arg == [second_node]
        alias_pairs = [(alias, canonical)]
        return [canonical], {alias.uuid: canonical.uuid}, alias_pairs

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[first_node], [second_node]],
        [(first_episode, []), (second_episode, [])],
    )

    for episode in (first_episode, second_episode):
        assert nodes_by_episode[episode.uuid] == [canonical]
    assert compressed_map.get(alias.uuid) == canonical.uuid


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
    """The extracted node is kept and a warning logged when its canonical node is absent.

    Here the resolver maps the node to a UUID for which no node exists.
    """
    clients = _make_clients()
    episode = _make_episode('missing')
    orphan = EntityNode(name='Fallback', group_id='group', labels=['Entity'])
    resolver = AsyncMock(return_value=([orphan], {orphan.uuid: 'missing-canonical'}, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolver)

    with caplog.at_level('WARNING'):
        nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
            clients,
            [[orphan]],
            [(episode, [])],
        )

    assert nodes_by_episode[episode.uuid] == [orphan]
    assert compressed_map.get(orphan.uuid) == 'missing-canonical'
    assert any('Canonical node missing' in record.message for record in caplog.records)