add tests for llm dedupe guardrails
This commit is contained in:
parent
27b8dd34a5
commit
76802f418f
2 changed files with 162 additions and 5 deletions
|
|
@ -291,18 +291,41 @@ async def _resolve_with_llm(
|
|||
|
||||
node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
|
||||
|
||||
valid_relative_range = range(len(state.unresolved_indices))
|
||||
processed_relative_ids: set[int] = set()
|
||||
|
||||
for resolution in node_resolutions:
|
||||
relative_id: int = resolution.id
|
||||
duplicate_idx: int = resolution.duplicate_idx
|
||||
|
||||
if relative_id not in valid_relative_range:
|
||||
logger.warning(
|
||||
'Skipping invalid LLM dedupe id %s (unresolved indices: %s)',
|
||||
relative_id,
|
||||
state.unresolved_indices,
|
||||
)
|
||||
continue
|
||||
|
||||
if relative_id in processed_relative_ids:
|
||||
logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
|
||||
continue
|
||||
processed_relative_ids.add(relative_id)
|
||||
|
||||
original_index = state.unresolved_indices[relative_id]
|
||||
extracted_node = extracted_nodes[original_index]
|
||||
|
||||
resolved_node = (
|
||||
indexes.existing_nodes[duplicate_idx]
|
||||
if 0 <= duplicate_idx < len(indexes.existing_nodes)
|
||||
else extracted_node
|
||||
)
|
||||
resolved_node: EntityNode
|
||||
if duplicate_idx == -1:
|
||||
resolved_node = extracted_node
|
||||
elif 0 <= duplicate_idx < len(indexes.existing_nodes):
|
||||
resolved_node = indexes.existing_nodes[duplicate_idx]
|
||||
else:
|
||||
logger.warning(
|
||||
'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
|
||||
duplicate_idx,
|
||||
extracted_node.uuid,
|
||||
)
|
||||
resolved_node = extracted_node
|
||||
|
||||
state.resolved_nodes[original_index] = resolved_node
|
||||
state.uuid_map[extracted_node.uuid] = resolved_node.uuid
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
from collections import defaultdict
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
|
@ -343,3 +344,136 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
|
|||
assert captured_context['existing_nodes'][0]['idx'] == 0
|
||||
assert isinstance(captured_context['existing_nodes'], list)
|
||||
assert state.duplicate_pairs == [(extracted, candidate)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
    """An LLM resolution whose id falls outside the unresolved range is skipped with a warning."""
    extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])

    # No existing candidates; a single unresolved extracted node at index 0.
    indexes = _build_candidate_indexes([])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    # id=5 is out of range for a single unresolved node.
    out_of_range_resolution = {
        'id': 5,
        'duplicate_idx': -1,
        'name': 'Dexter',
        'duplicates': [],
    }
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'entity_resolutions': [out_of_range_resolution]}
    )

    with caplog.at_level(logging.WARNING):
        await _resolve_with_llm(
            llm_client,
            [extracted],
            indexes,
            state,
            ensure_ascii=False,
            episode=_make_episode(),
            previous_episodes=[],
            entity_types=None,
        )

    # The invalid id must leave the node unresolved and emit the guardrail warning.
    assert state.resolved_nodes[0] is None
    assert 'Skipping invalid LLM dedupe id 5' in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
    """Only the first resolution for a given id is honored; repeats are ignored."""
    extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
    candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([candidate])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    # Two resolutions with the same id=0: the first maps to the candidate,
    # the second (which would undo the match) must be discarded.
    first_resolution = {
        'id': 0,
        'duplicate_idx': 0,
        'name': 'Dizzy Gillespie',
        'duplicates': [0],
    }
    repeated_resolution = {
        'id': 0,
        'duplicate_idx': -1,
        'name': 'Dizzy',
        'duplicates': [],
    }
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'entity_resolutions': [first_resolution, repeated_resolution]}
    )

    await _resolve_with_llm(
        llm_client,
        [extracted],
        indexes,
        state,
        ensure_ascii=False,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    # The first resolution wins: extracted is deduped onto the candidate.
    assert state.resolved_nodes[0].uuid == candidate.uuid
    assert state.uuid_map[extracted.uuid] == candidate.uuid
    assert state.duplicate_pairs == [(extracted, candidate)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
    """A duplicate_idx beyond the candidate list falls back to the extracted node itself."""
    extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])

    # Empty candidate pool, so any non-negative duplicate_idx is invalid.
    indexes = _build_candidate_indexes([])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    invalid_idx_resolution = {
        'id': 0,
        'duplicate_idx': 10,
        'name': 'Dexter',
        'duplicates': [],
    }
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'entity_resolutions': [invalid_idx_resolution]}
    )

    await _resolve_with_llm(
        llm_client,
        [extracted],
        indexes,
        state,
        ensure_ascii=False,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    # The node resolves to itself (no duplicate) and no pair is recorded.
    assert state.resolved_nodes[0] == extracted
    assert state.uuid_map[extracted.uuid] == extracted.uuid
    assert state.duplicate_pairs == []
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue