Fix index out of range errors in LLM deduplication responses (#939)

* add tests for llm dedupe guardrails

* document llm dedupe guardrails
This commit is contained in:
Daniel Chalef 2025-09-26 14:57:48 -07:00 committed by GitHub
parent 27b8dd34a5
commit d7828d48d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 167 additions and 6 deletions

View file

@ -241,7 +241,11 @@ async def _resolve_with_llm(
previous_episodes: list[EpisodicNode] | None,
entity_types: dict[str, type[BaseModel]] | None,
) -> None:
"""Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates."""
"""Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
The guardrails below defensively ignore malformed or duplicate LLM responses so the
ingestion workflow remains deterministic even when the model misbehaves.
"""
if not state.unresolved_indices:
return
@ -291,18 +295,41 @@ async def _resolve_with_llm(
node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
valid_relative_range = range(len(state.unresolved_indices))
processed_relative_ids: set[int] = set()
for resolution in node_resolutions:
relative_id: int = resolution.id
duplicate_idx: int = resolution.duplicate_idx
if relative_id not in valid_relative_range:
logger.warning(
'Skipping invalid LLM dedupe id %s (unresolved indices: %s)',
relative_id,
state.unresolved_indices,
)
continue
if relative_id in processed_relative_ids:
logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
continue
processed_relative_ids.add(relative_id)
original_index = state.unresolved_indices[relative_id]
extracted_node = extracted_nodes[original_index]
resolved_node = (
indexes.existing_nodes[duplicate_idx]
if 0 <= duplicate_idx < len(indexes.existing_nodes)
else extracted_node
)
resolved_node: EntityNode
if duplicate_idx == -1:
resolved_node = extracted_node
elif 0 <= duplicate_idx < len(indexes.existing_nodes):
resolved_node = indexes.existing_nodes[duplicate_idx]
else:
logger.warning(
'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
duplicate_idx,
extracted_node.uuid,
)
resolved_node = extracted_node
state.resolved_nodes[original_index] = resolved_node
state.uuid_map[extracted_node.uuid] = resolved_node.uuid

View file

@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from unittest.mock import AsyncMock, MagicMock
@ -343,3 +344,136 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
assert captured_context['existing_nodes'][0]['idx'] == 0
assert isinstance(captured_context['existing_nodes'], list)
assert state.duplicate_pairs == [(extracted, candidate)]
@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
indexes = _build_candidate_indexes([])
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
lambda context: ['prompt'],
)
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={
'entity_resolutions': [
{
'id': 5,
'duplicate_idx': -1,
'name': 'Dexter',
'duplicates': [],
}
]
}
)
with caplog.at_level(logging.WARNING):
await _resolve_with_llm(
llm_client,
[extracted],
indexes,
state,
ensure_ascii=False,
episode=_make_episode(),
previous_episodes=[],
entity_types=None,
)
assert state.resolved_nodes[0] is None
assert 'Skipping invalid LLM dedupe id 5' in caplog.text
@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
indexes = _build_candidate_indexes([candidate])
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
lambda context: ['prompt'],
)
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={
'entity_resolutions': [
{
'id': 0,
'duplicate_idx': 0,
'name': 'Dizzy Gillespie',
'duplicates': [0],
},
{
'id': 0,
'duplicate_idx': -1,
'name': 'Dizzy',
'duplicates': [],
},
]
}
)
await _resolve_with_llm(
llm_client,
[extracted],
indexes,
state,
ensure_ascii=False,
episode=_make_episode(),
previous_episodes=[],
entity_types=None,
)
assert state.resolved_nodes[0].uuid == candidate.uuid
assert state.uuid_map[extracted.uuid] == candidate.uuid
assert state.duplicate_pairs == [(extracted, candidate)]
@pytest.mark.asyncio
async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
indexes = _build_candidate_indexes([])
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
lambda context: ['prompt'],
)
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={
'entity_resolutions': [
{
'id': 0,
'duplicate_idx': 10,
'name': 'Dexter',
'duplicates': [],
}
]
}
)
await _resolve_with_llm(
llm_client,
[extracted],
indexes,
state,
ensure_ascii=False,
episode=_make_episode(),
previous_episodes=[],
entity_types=None,
)
assert state.resolved_nodes[0] == extracted
assert state.uuid_map[extracted.uuid] == extracted.uuid
assert state.duplicate_pairs == []