test: Add tests for edge deduplication fixes

Added three tests to verify the edge deduplication fixes:

1. test_dedupe_edges_bulk_deduplicates_within_episode: Verifies that
   dedupe_edges_bulk now compares edges from the same episode after
   removing the `if i == j: continue` check.

2. test_resolve_extracted_edge_uses_integer_indices_for_duplicates:
   Validates that the LLM receives integer indices for duplicate
   detection and correctly processes returned duplicate_facts.

3. test_resolve_extracted_edges_fast_path_deduplication: Confirms that
   the fast-path exact string matching deduplicates identical edges
   before parallel processing, preventing race conditions.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Daniel Chalef 2025-09-30 20:48:53 -07:00
parent 160a8a1310
commit 3432602eb7
2 changed files with 317 additions and 0 deletions

View file

@@ -230,3 +230,101 @@ def test_resolve_edge_pointers_updates_sources():
assert edge.source_node_uuid == 'canonical'
assert edge.target_node_uuid == 'target'
@pytest.mark.asyncio
async def test_dedupe_edges_bulk_deduplicates_within_episode(monkeypatch):
    """dedupe_edges_bulk must compare edges that belong to the same episode.

    Guards the fix that removed the `if i == j: continue` short-circuit, which
    previously prevented same-episode edges from ever being compared against
    each other.
    """
    clients = _make_clients()

    # Records (edge_uuid, [candidate uuids]) for every resolution attempt.
    observed = []

    async def fake_embeddings(embedder, edges):
        # Stand-in embedder: give every edge the same fixed vector.
        for item in edges:
            item.fact_embedding = [0.1, 0.2, 0.3]

    async def fake_resolve(
        llm_client,
        extracted_edge,
        related_edges,
        existing_edges,
        episode,
        edge_type_candidates=None,
        custom_edge_type_names=None,
        ensure_ascii=False,
    ):
        observed.append(
            (extracted_edge.uuid, [candidate.uuid for candidate in related_edges])
        )
        for candidate in related_edges:
            # An edge can never be a duplicate of itself.
            is_self = candidate.uuid == extracted_edge.uuid
            same_endpoints = (
                candidate.source_node_uuid == extracted_edge.source_node_uuid
                and candidate.target_node_uuid == extracted_edge.target_node_uuid
            )
            same_fact = (
                candidate.fact.strip().lower() == extracted_edge.fact.strip().lower()
            )
            if not is_self and same_endpoints and same_fact:
                # Resolve to the pre-existing candidate and report the duplicate.
                return candidate, [], [candidate]
        return extracted_edge, [], []

    monkeypatch.setattr(bulk_utils, 'create_entity_edge_embeddings', fake_embeddings)
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_edge', fake_resolve)

    episode = _make_episode('1')

    def identical_edge():
        # Byte-identical fact between the same two nodes, in the same episode.
        return EntityEdge(
            name='recommends',
            fact='assistant recommends yoga poses',
            group_id='group',
            source_node_uuid='source-uuid',
            target_node_uuid='target-uuid',
            created_at=utc_now(),
            episodes=[episode.uuid],
        )

    batch = [identical_edge() for _ in range(3)]

    await bulk_utils.dedupe_edges_bulk(
        clients,
        [batch],
        [(episode, [])],
        [],
        {},
        {},
    )

    # Every one of the three edges must have gone through resolution...
    assert len(observed) == 3
    # ...and each must have seen the other same-episode edges as candidates
    # (self is filtered out, leaving at least the two siblings).
    for _edge_uuid, candidate_uuids in observed:
        assert len(candidate_uuids) >= 2

View file

@@ -434,3 +434,222 @@ async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
assert resolved_edge.attributes == {}
assert duplicates == []
assert invalidated == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_uses_integer_indices_for_duplicates(mock_llm_client):
    """resolve_extracted_edge maps LLM integer indices back onto related edges."""
    # The LLM flags the related edges at positions 0 and 1 as duplicates.
    mock_llm_client.generate_response.return_value = {
        'duplicate_facts': [0, 1],
        'contradicted_facts': [],
        'fact_type': 'DEFAULT',
    }

    now = datetime.now(timezone.utc)

    def build_edge(fact, episode_ids, age_days):
        # All edges share the same endpoints/name/group; only the fact,
        # episode provenance, and age vary.
        return EntityEdge(
            source_node_uuid='source_uuid',
            target_node_uuid='target_uuid',
            name='test_edge',
            group_id='group_1',
            fact=fact,
            episodes=episode_ids,
            created_at=now - timedelta(days=age_days),
            valid_at=None,
            invalid_at=None,
        )

    extracted = build_edge('User likes yoga', [], 0)

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=now,
    )

    # Candidate list order defines the integer indices the LLM sees.
    candidates = [
        build_edge('User enjoys yoga', ['episode_1'], 1),
        build_edge('User practices yoga', ['episode_2'], 2),
        build_edge('User loves swimming', ['episode_3'], 3),
    ]

    resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
        mock_llm_client,
        extracted,
        candidates,
        [],
        episode,
        edge_type_candidates=None,
        custom_edge_type_names=set(),
        ensure_ascii=True,
    )

    mock_llm_client.generate_response.assert_called_once()

    # Indices [0, 1] -> the first two candidates are duplicates; index 2 is not.
    assert len(duplicates) == 2
    assert candidates[0] in duplicates
    assert candidates[1] in duplicates
    assert invalidated == []

    # The resolved edge is the first duplicate found; compare UUIDs because the
    # winner's episode list is mutated to include the current episode.
    assert resolved_edge.uuid == candidates[0].uuid
    assert episode.uuid in resolved_edge.episodes
@pytest.mark.asyncio
async def test_resolve_extracted_edges_fast_path_deduplication(monkeypatch):
    """Exact-match edges are collapsed before the parallel resolution fan-out."""
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    monkeypatch.setattr(
        edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None)
    )
    monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))

    # Counts how often the (slow) per-edge resolution path actually runs.
    call_counter = {'resolve': 0}

    async def counting_resolve(
        llm_client,
        extracted_edge,
        related_edges,
        existing_edges,
        episode,
        edge_type_candidates=None,
        custom_edge_type_names=None,
        ensure_ascii=False,
    ):
        call_counter['resolve'] += 1
        return extracted_edge, [], []

    async def sequential_gather(*aws, max_coroutines=None):
        # Run awaitables one at a time instead of through the real semaphore.
        return [await aw for aw in aws]

    monkeypatch.setattr(edge_ops, 'semaphore_gather', sequential_gather)
    monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
    monkeypatch.setattr(edge_ops, 'resolve_extracted_edge', counting_resolve)

    clients = SimpleNamespace(
        driver=MagicMock(),
        llm_client=MagicMock(),
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        ensure_ascii=True,
    )

    source_node = EntityNode(
        uuid='source_uuid',
        name='Assistant',
        group_id='group_1',
        labels=['Entity'],
    )
    target_node = EntityNode(
        uuid='target_uuid',
        name='User',
        group_id='group_1',
        labels=['Entity'],
    )

    def make_edge(fact):
        return EntityEdge(
            source_node_uuid=source_node.uuid,
            target_node_uuid=target_node.uuid,
            name='recommends',
            group_id='group_1',
            fact=fact,
            episodes=[],
            created_at=datetime.now(timezone.utc),
            valid_at=None,
            invalid_at=None,
        )

    # Three edges whose facts are identical once whitespace/case is normalized.
    extracted = [
        make_edge('assistant recommends yoga poses'),
        make_edge(' Assistant Recommends YOGA Poses '),
        make_edge('assistant recommends yoga poses'),
    ]

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    resolved_edges, invalidated_edges = await resolve_extracted_edges(
        clients,
        extracted,
        episode,
        [source_node, target_node],
        {},
        {},
    )

    # All three collapse to a single edge, so the slow path runs exactly once
    # and no race between parallel resolutions is possible.
    assert call_counter['resolve'] == 1
    assert len(resolved_edges) == 1
    assert invalidated_edges == []