test: Add tests for edge deduplication fixes
Added three tests to verify the edge deduplication fixes: 1. test_dedupe_edges_bulk_deduplicates_within_episode: Verifies that dedupe_edges_bulk now compares edges from the same episode after removing the `if i == j: continue` check. 2. test_resolve_extracted_edge_uses_integer_indices_for_duplicates: Validates that the LLM receives integer indices for duplicate detection and correctly processes returned duplicate_facts. 3. test_resolve_extracted_edges_fast_path_deduplication: Confirms that the fast-path exact string matching deduplicates identical edges before parallel processing, preventing race conditions. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
160a8a1310
commit
3432602eb7
2 changed files with 317 additions and 0 deletions
|
|
@ -230,3 +230,101 @@ def test_resolve_edge_pointers_updates_sources():
|
|||
|
||||
assert edge.source_node_uuid == 'canonical'
|
||||
assert edge.target_node_uuid == 'target'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedupe_edges_bulk_deduplicates_within_episode(monkeypatch):
|
||||
"""Test that dedupe_edges_bulk correctly compares edges within the same episode.
|
||||
|
||||
This test verifies the fix that removed the `if i == j: continue` check,
|
||||
which was preventing edges from the same episode from being compared against each other.
|
||||
"""
|
||||
clients = _make_clients()
|
||||
|
||||
# Track which edges are compared
|
||||
comparisons_made = []
|
||||
|
||||
# Create mock embedder that sets embedding values
|
||||
async def mock_create_embeddings(embedder, edges):
|
||||
for edge in edges:
|
||||
edge.fact_embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
monkeypatch.setattr(bulk_utils, 'create_entity_edge_embeddings', mock_create_embeddings)
|
||||
|
||||
# Mock resolve_extracted_edge to track comparisons and mark duplicates
|
||||
async def mock_resolve_extracted_edge(
|
||||
llm_client,
|
||||
extracted_edge,
|
||||
related_edges,
|
||||
existing_edges,
|
||||
episode,
|
||||
edge_type_candidates=None,
|
||||
custom_edge_type_names=None,
|
||||
ensure_ascii=False,
|
||||
):
|
||||
# Track that this edge was compared against the related_edges
|
||||
comparisons_made.append((extracted_edge.uuid, [r.uuid for r in related_edges]))
|
||||
|
||||
# If there are related edges with same source/target/fact, mark as duplicate
|
||||
for related in related_edges:
|
||||
if (
|
||||
related.uuid != extracted_edge.uuid # Can't be duplicate of self
|
||||
and related.source_node_uuid == extracted_edge.source_node_uuid
|
||||
and related.target_node_uuid == extracted_edge.target_node_uuid
|
||||
and related.fact.strip().lower() == extracted_edge.fact.strip().lower()
|
||||
):
|
||||
# Return the related edge and mark extracted_edge as duplicate
|
||||
return related, [], [related]
|
||||
# Otherwise return the extracted edge as-is
|
||||
return extracted_edge, [], []
|
||||
|
||||
monkeypatch.setattr(bulk_utils, 'resolve_extracted_edge', mock_resolve_extracted_edge)
|
||||
|
||||
episode = _make_episode('1')
|
||||
source_uuid = 'source-uuid'
|
||||
target_uuid = 'target-uuid'
|
||||
|
||||
# Create 3 identical edges within the same episode
|
||||
edge1 = EntityEdge(
|
||||
name='recommends',
|
||||
fact='assistant recommends yoga poses',
|
||||
group_id='group',
|
||||
source_node_uuid=source_uuid,
|
||||
target_node_uuid=target_uuid,
|
||||
created_at=utc_now(),
|
||||
episodes=[episode.uuid],
|
||||
)
|
||||
edge2 = EntityEdge(
|
||||
name='recommends',
|
||||
fact='assistant recommends yoga poses',
|
||||
group_id='group',
|
||||
source_node_uuid=source_uuid,
|
||||
target_node_uuid=target_uuid,
|
||||
created_at=utc_now(),
|
||||
episodes=[episode.uuid],
|
||||
)
|
||||
edge3 = EntityEdge(
|
||||
name='recommends',
|
||||
fact='assistant recommends yoga poses',
|
||||
group_id='group',
|
||||
source_node_uuid=source_uuid,
|
||||
target_node_uuid=target_uuid,
|
||||
created_at=utc_now(),
|
||||
episodes=[episode.uuid],
|
||||
)
|
||||
|
||||
edges_by_episode = await bulk_utils.dedupe_edges_bulk(
|
||||
clients,
|
||||
[[edge1, edge2, edge3]],
|
||||
[(episode, [])],
|
||||
[],
|
||||
{},
|
||||
{},
|
||||
)
|
||||
|
||||
# Verify that edges were compared against each other (within same episode)
|
||||
# Each edge should have been compared against all 3 edges (including itself, which gets filtered)
|
||||
assert len(comparisons_made) == 3
|
||||
for edge_uuid, compared_against in comparisons_made:
|
||||
# Each edge should have access to all 3 edges as candidates
|
||||
assert len(compared_against) >= 2 # At least 2 others (self is filtered out)
|
||||
|
|
|
|||
|
|
@ -434,3 +434,222 @@ async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client)
|
|||
assert resolved_edge.attributes == {}
|
||||
assert duplicates == []
|
||||
assert invalidated == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_extracted_edge_uses_integer_indices_for_duplicates(mock_llm_client):
|
||||
"""Test that resolve_extracted_edge correctly uses integer indices for LLM duplicate detection."""
|
||||
# Mock LLM to return duplicate_facts with integer indices
|
||||
mock_llm_client.generate_response.return_value = {
|
||||
'duplicate_facts': [0, 1], # LLM identifies first two related edges as duplicates
|
||||
'contradicted_facts': [],
|
||||
'fact_type': 'DEFAULT',
|
||||
}
|
||||
|
||||
extracted_edge = EntityEdge(
|
||||
source_node_uuid='source_uuid',
|
||||
target_node_uuid='target_uuid',
|
||||
name='test_edge',
|
||||
group_id='group_1',
|
||||
fact='User likes yoga',
|
||||
episodes=[],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
episode = EpisodicNode(
|
||||
uuid='episode_uuid',
|
||||
name='Episode',
|
||||
group_id='group_1',
|
||||
source='message',
|
||||
source_description='desc',
|
||||
content='Episode content',
|
||||
valid_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Create multiple related edges - LLM should receive these with integer indices
|
||||
related_edge_0 = EntityEdge(
|
||||
source_node_uuid='source_uuid',
|
||||
target_node_uuid='target_uuid',
|
||||
name='test_edge',
|
||||
group_id='group_1',
|
||||
fact='User enjoys yoga',
|
||||
episodes=['episode_1'],
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
related_edge_1 = EntityEdge(
|
||||
source_node_uuid='source_uuid',
|
||||
target_node_uuid='target_uuid',
|
||||
name='test_edge',
|
||||
group_id='group_1',
|
||||
fact='User practices yoga',
|
||||
episodes=['episode_2'],
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=2),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
related_edge_2 = EntityEdge(
|
||||
source_node_uuid='source_uuid',
|
||||
target_node_uuid='target_uuid',
|
||||
name='test_edge',
|
||||
group_id='group_1',
|
||||
fact='User loves swimming',
|
||||
episodes=['episode_3'],
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=3),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
related_edges = [related_edge_0, related_edge_1, related_edge_2]
|
||||
|
||||
resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
|
||||
mock_llm_client,
|
||||
extracted_edge,
|
||||
related_edges,
|
||||
[],
|
||||
episode,
|
||||
edge_type_candidates=None,
|
||||
custom_edge_type_names=set(),
|
||||
ensure_ascii=True,
|
||||
)
|
||||
|
||||
# Verify LLM was called
|
||||
mock_llm_client.generate_response.assert_called_once()
|
||||
|
||||
# Verify the system correctly identified duplicates using integer indices
|
||||
# The LLM returned [0, 1], so related_edge_0 and related_edge_1 should be marked as duplicates
|
||||
assert len(duplicates) == 2
|
||||
assert related_edge_0 in duplicates
|
||||
assert related_edge_1 in duplicates
|
||||
assert invalidated == []
|
||||
|
||||
# Verify that the resolved edge is one of the duplicates (the first one found)
|
||||
# Check UUID since the episode list gets modified
|
||||
assert resolved_edge.uuid == related_edge_0.uuid
|
||||
assert episode.uuid in resolved_edge.episodes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_extracted_edges_fast_path_deduplication(monkeypatch):
|
||||
"""Test that resolve_extracted_edges deduplicates exact matches before parallel processing."""
|
||||
from graphiti_core.utils.maintenance import edge_operations as edge_ops
|
||||
|
||||
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
|
||||
|
||||
# Track how many times resolve_extracted_edge is called
|
||||
resolve_call_count = 0
|
||||
|
||||
async def mock_resolve_extracted_edge(
|
||||
llm_client,
|
||||
extracted_edge,
|
||||
related_edges,
|
||||
existing_edges,
|
||||
episode,
|
||||
edge_type_candidates=None,
|
||||
custom_edge_type_names=None,
|
||||
ensure_ascii=False,
|
||||
):
|
||||
nonlocal resolve_call_count
|
||||
resolve_call_count += 1
|
||||
return extracted_edge, [], []
|
||||
|
||||
# Mock semaphore_gather to execute awaitable immediately
|
||||
async def immediate_gather(*aws, max_coroutines=None):
|
||||
results = []
|
||||
for aw in aws:
|
||||
results.append(await aw)
|
||||
return results
|
||||
|
||||
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
|
||||
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
|
||||
monkeypatch.setattr(edge_ops, 'resolve_extracted_edge', mock_resolve_extracted_edge)
|
||||
|
||||
llm_client = MagicMock()
|
||||
clients = SimpleNamespace(
|
||||
driver=MagicMock(),
|
||||
llm_client=llm_client,
|
||||
embedder=MagicMock(),
|
||||
cross_encoder=MagicMock(),
|
||||
ensure_ascii=True,
|
||||
)
|
||||
|
||||
source_node = EntityNode(
|
||||
uuid='source_uuid',
|
||||
name='Assistant',
|
||||
group_id='group_1',
|
||||
labels=['Entity'],
|
||||
)
|
||||
target_node = EntityNode(
|
||||
uuid='target_uuid',
|
||||
name='User',
|
||||
group_id='group_1',
|
||||
labels=['Entity'],
|
||||
)
|
||||
|
||||
# Create 3 identical edges
|
||||
edge1 = EntityEdge(
|
||||
source_node_uuid=source_node.uuid,
|
||||
target_node_uuid=target_node.uuid,
|
||||
name='recommends',
|
||||
group_id='group_1',
|
||||
fact='assistant recommends yoga poses',
|
||||
episodes=[],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
edge2 = EntityEdge(
|
||||
source_node_uuid=source_node.uuid,
|
||||
target_node_uuid=target_node.uuid,
|
||||
name='recommends',
|
||||
group_id='group_1',
|
||||
fact=' Assistant Recommends YOGA Poses ', # Different whitespace/case
|
||||
episodes=[],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
edge3 = EntityEdge(
|
||||
source_node_uuid=source_node.uuid,
|
||||
target_node_uuid=target_node.uuid,
|
||||
name='recommends',
|
||||
group_id='group_1',
|
||||
fact='assistant recommends yoga poses',
|
||||
episodes=[],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
episode = EpisodicNode(
|
||||
uuid='episode_uuid',
|
||||
name='Episode',
|
||||
group_id='group_1',
|
||||
source='message',
|
||||
source_description='desc',
|
||||
content='Episode content',
|
||||
valid_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
||||
clients,
|
||||
[edge1, edge2, edge3],
|
||||
episode,
|
||||
[source_node, target_node],
|
||||
{},
|
||||
{},
|
||||
)
|
||||
|
||||
# Fast path should have deduplicated the 3 identical edges to 1
|
||||
# So resolve_extracted_edge should only be called once
|
||||
assert resolve_call_count == 1
|
||||
assert len(resolved_edges) == 1
|
||||
assert invalidated_edges == []
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue