add tests for llm dedupe guardrails
This commit is contained in:
parent
27b8dd34a5
commit
76802f418f
2 changed files with 162 additions and 5 deletions
|
|
@ -291,18 +291,41 @@ async def _resolve_with_llm(
|
||||||
|
|
||||||
node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
|
node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
|
||||||
|
|
||||||
|
valid_relative_range = range(len(state.unresolved_indices))
|
||||||
|
processed_relative_ids: set[int] = set()
|
||||||
|
|
||||||
for resolution in node_resolutions:
|
for resolution in node_resolutions:
|
||||||
relative_id: int = resolution.id
|
relative_id: int = resolution.id
|
||||||
duplicate_idx: int = resolution.duplicate_idx
|
duplicate_idx: int = resolution.duplicate_idx
|
||||||
|
|
||||||
|
if relative_id not in valid_relative_range:
|
||||||
|
logger.warning(
|
||||||
|
'Skipping invalid LLM dedupe id %s (unresolved indices: %s)',
|
||||||
|
relative_id,
|
||||||
|
state.unresolved_indices,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if relative_id in processed_relative_ids:
|
||||||
|
logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
|
||||||
|
continue
|
||||||
|
processed_relative_ids.add(relative_id)
|
||||||
|
|
||||||
original_index = state.unresolved_indices[relative_id]
|
original_index = state.unresolved_indices[relative_id]
|
||||||
extracted_node = extracted_nodes[original_index]
|
extracted_node = extracted_nodes[original_index]
|
||||||
|
|
||||||
resolved_node = (
|
resolved_node: EntityNode
|
||||||
indexes.existing_nodes[duplicate_idx]
|
if duplicate_idx == -1:
|
||||||
if 0 <= duplicate_idx < len(indexes.existing_nodes)
|
resolved_node = extracted_node
|
||||||
else extracted_node
|
elif 0 <= duplicate_idx < len(indexes.existing_nodes):
|
||||||
)
|
resolved_node = indexes.existing_nodes[duplicate_idx]
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
|
||||||
|
duplicate_idx,
|
||||||
|
extracted_node.uuid,
|
||||||
|
)
|
||||||
|
resolved_node = extracted_node
|
||||||
|
|
||||||
state.resolved_nodes[original_index] = resolved_node
|
state.resolved_nodes[original_index] = resolved_node
|
||||||
state.uuid_map[extracted_node.uuid] = resolved_node.uuid
|
state.uuid_map[extracted_node.uuid] = resolved_node.uuid
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
|
@ -343,3 +344,136 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
|
||||||
assert captured_context['existing_nodes'][0]['idx'] == 0
|
assert captured_context['existing_nodes'][0]['idx'] == 0
|
||||||
assert isinstance(captured_context['existing_nodes'], list)
|
assert isinstance(captured_context['existing_nodes'], list)
|
||||||
assert state.duplicate_pairs == [(extracted, candidate)]
|
assert state.duplicate_pairs == [(extracted, candidate)]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
|
||||||
|
extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
|
||||||
|
|
||||||
|
indexes = _build_candidate_indexes([])
|
||||||
|
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
|
||||||
|
lambda context: ['prompt'],
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_client = MagicMock()
|
||||||
|
llm_client.generate_response = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
'entity_resolutions': [
|
||||||
|
{
|
||||||
|
'id': 5,
|
||||||
|
'duplicate_idx': -1,
|
||||||
|
'name': 'Dexter',
|
||||||
|
'duplicates': [],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
await _resolve_with_llm(
|
||||||
|
llm_client,
|
||||||
|
[extracted],
|
||||||
|
indexes,
|
||||||
|
state,
|
||||||
|
ensure_ascii=False,
|
||||||
|
episode=_make_episode(),
|
||||||
|
previous_episodes=[],
|
||||||
|
entity_types=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert state.resolved_nodes[0] is None
|
||||||
|
assert 'Skipping invalid LLM dedupe id 5' in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
|
||||||
|
extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
|
||||||
|
candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
|
||||||
|
|
||||||
|
indexes = _build_candidate_indexes([candidate])
|
||||||
|
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
|
||||||
|
lambda context: ['prompt'],
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_client = MagicMock()
|
||||||
|
llm_client.generate_response = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
'entity_resolutions': [
|
||||||
|
{
|
||||||
|
'id': 0,
|
||||||
|
'duplicate_idx': 0,
|
||||||
|
'name': 'Dizzy Gillespie',
|
||||||
|
'duplicates': [0],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'id': 0,
|
||||||
|
'duplicate_idx': -1,
|
||||||
|
'name': 'Dizzy',
|
||||||
|
'duplicates': [],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await _resolve_with_llm(
|
||||||
|
llm_client,
|
||||||
|
[extracted],
|
||||||
|
indexes,
|
||||||
|
state,
|
||||||
|
ensure_ascii=False,
|
||||||
|
episode=_make_episode(),
|
||||||
|
previous_episodes=[],
|
||||||
|
entity_types=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert state.resolved_nodes[0].uuid == candidate.uuid
|
||||||
|
assert state.uuid_map[extracted.uuid] == candidate.uuid
|
||||||
|
assert state.duplicate_pairs == [(extracted, candidate)]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
|
||||||
|
extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
|
||||||
|
|
||||||
|
indexes = _build_candidate_indexes([])
|
||||||
|
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
|
||||||
|
lambda context: ['prompt'],
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_client = MagicMock()
|
||||||
|
llm_client.generate_response = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
'entity_resolutions': [
|
||||||
|
{
|
||||||
|
'id': 0,
|
||||||
|
'duplicate_idx': 10,
|
||||||
|
'name': 'Dexter',
|
||||||
|
'duplicates': [],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await _resolve_with_llm(
|
||||||
|
llm_client,
|
||||||
|
[extracted],
|
||||||
|
indexes,
|
||||||
|
state,
|
||||||
|
ensure_ascii=False,
|
||||||
|
episode=_make_episode(),
|
||||||
|
previous_episodes=[],
|
||||||
|
entity_types=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert state.resolved_nodes[0] == extracted
|
||||||
|
assert state.uuid_map[extracted.uuid] == extracted.uuid
|
||||||
|
assert state.duplicate_pairs == []
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue