From 76802f418fcf19021b6a71cb0dcf26f70ab8b000 Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Fri, 26 Sep 2025 11:50:11 -0700
Subject: [PATCH] add tests for llm dedupe guardrails

---
Review notes (this section sits after the three-dash line, so it is not part
of the commit message or of the diff):
* _resolve_with_llm now skips any LLM resolution whose id is outside
  range(len(state.unresolved_indices)) and logs a warning instead of indexing
  state.unresolved_indices with it.
* A resolution whose id was already processed in the same response is ignored
  with a warning, so the first resolution for an id is the one that is kept.
* duplicate_idx == -1 keeps the extracted node; any other value outside
  indexes.existing_nodes is logged as invalid and also falls back to the
  extracted node.
* The three new tests cover the out-of-range id, the repeated id, and the
  invalid duplicate_idx paths; only the first one asserts on the log text.
* NOTE(review): line breaks and indentation in this patch were restored from a
  whitespace-collapsed copy. The hunk line counts and the diffstat agree with
  the restored layout, but run `git apply --check` before applying.

 .../utils/maintenance/node_operations.py      |  33 ++++-
 .../utils/maintenance/test_node_operations.py | 134 ++++++++++++++++++
 2 files changed, 162 insertions(+), 5 deletions(-)

diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 26bc23d9..122dbc00 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -291,18 +291,41 @@ async def _resolve_with_llm(
 
     node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
 
+    valid_relative_range = range(len(state.unresolved_indices))
+    processed_relative_ids: set[int] = set()
+
     for resolution in node_resolutions:
         relative_id: int = resolution.id
         duplicate_idx: int = resolution.duplicate_idx
 
+        if relative_id not in valid_relative_range:
+            logger.warning(
+                'Skipping invalid LLM dedupe id %s (unresolved indices: %s)',
+                relative_id,
+                state.unresolved_indices,
+            )
+            continue
+
+        if relative_id in processed_relative_ids:
+            logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
+            continue
+        processed_relative_ids.add(relative_id)
+
         original_index = state.unresolved_indices[relative_id]
         extracted_node = extracted_nodes[original_index]
 
-        resolved_node = (
-            indexes.existing_nodes[duplicate_idx]
-            if 0 <= duplicate_idx < len(indexes.existing_nodes)
-            else extracted_node
-        )
+        resolved_node: EntityNode
+        if duplicate_idx == -1:
+            resolved_node = extracted_node
+        elif 0 <= duplicate_idx < len(indexes.existing_nodes):
+            resolved_node = indexes.existing_nodes[duplicate_idx]
+        else:
+            logger.warning(
+                'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
+                duplicate_idx,
+                extracted_node.uuid,
+            )
+            resolved_node = extracted_node
 
         state.resolved_nodes[original_index] = resolved_node
         state.uuid_map[extracted_node.uuid] = resolved_node.uuid
diff --git a/tests/utils/maintenance/test_node_operations.py b/tests/utils/maintenance/test_node_operations.py
index b2fc30b0..0bbae6b3 100644
--- a/tests/utils/maintenance/test_node_operations.py
+++ b/tests/utils/maintenance/test_node_operations.py
@@ -1,3 +1,4 @@
+import logging
 from collections import defaultdict
 from unittest.mock import AsyncMock, MagicMock
 
@@ -343,3 +344,136 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
     assert captured_context['existing_nodes'][0]['idx'] == 0
     assert isinstance(captured_context['existing_nodes'], list)
     assert state.duplicate_pairs == [(extracted, candidate)]
+
+
+@pytest.mark.asyncio
+async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
+    extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
+
+    indexes = _build_candidate_indexes([])
+    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
+
+    monkeypatch.setattr(
+        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
+        lambda context: ['prompt'],
+    )
+
+    llm_client = MagicMock()
+    llm_client.generate_response = AsyncMock(
+        return_value={
+            'entity_resolutions': [
+                {
+                    'id': 5,
+                    'duplicate_idx': -1,
+                    'name': 'Dexter',
+                    'duplicates': [],
+                }
+            ]
+        }
+    )
+
+    with caplog.at_level(logging.WARNING):
+        await _resolve_with_llm(
+            llm_client,
+            [extracted],
+            indexes,
+            state,
+            ensure_ascii=False,
+            episode=_make_episode(),
+            previous_episodes=[],
+            entity_types=None,
+        )
+
+    assert state.resolved_nodes[0] is None
+    assert 'Skipping invalid LLM dedupe id 5' in caplog.text
+
+
+@pytest.mark.asyncio
+async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
+    extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
+    candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
+
+    indexes = _build_candidate_indexes([candidate])
+    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
+
+    monkeypatch.setattr(
+        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
+        lambda context: ['prompt'],
+    )
+
+    llm_client = MagicMock()
+    llm_client.generate_response = AsyncMock(
+        return_value={
+            'entity_resolutions': [
+                {
+                    'id': 0,
+                    'duplicate_idx': 0,
+                    'name': 'Dizzy Gillespie',
+                    'duplicates': [0],
+                },
+                {
+                    'id': 0,
+                    'duplicate_idx': -1,
+                    'name': 'Dizzy',
+                    'duplicates': [],
+                },
+            ]
+        }
+    )
+
+    await _resolve_with_llm(
+        llm_client,
+        [extracted],
+        indexes,
+        state,
+        ensure_ascii=False,
+        episode=_make_episode(),
+        previous_episodes=[],
+        entity_types=None,
+    )
+
+    assert state.resolved_nodes[0].uuid == candidate.uuid
+    assert state.uuid_map[extracted.uuid] == candidate.uuid
+    assert state.duplicate_pairs == [(extracted, candidate)]
+
+
+@pytest.mark.asyncio
+async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
+    extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
+
+    indexes = _build_candidate_indexes([])
+    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
+
+    monkeypatch.setattr(
+        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
+        lambda context: ['prompt'],
+    )
+
+    llm_client = MagicMock()
+    llm_client.generate_response = AsyncMock(
+        return_value={
+            'entity_resolutions': [
+                {
+                    'id': 0,
+                    'duplicate_idx': 10,
+                    'name': 'Dexter',
+                    'duplicates': [],
+                }
+            ]
+        }
+    )
+
+    await _resolve_with_llm(
+        llm_client,
+        [extracted],
+        indexes,
+        state,
+        ensure_ascii=False,
+        episode=_make_episode(),
+        previous_episodes=[],
+        entity_types=None,
+    )
+
+    assert state.resolved_nodes[0] == extracted
+    assert state.uuid_map[extracted.uuid] == extracted.uuid
+    assert state.duplicate_pairs == []