fix: Prevent duplicate edge facts within same episode (#955)
* fix: Prevent duplicate edge facts within same episode This fixes three related bugs that allowed verbatim duplicate edge facts: 1. Fixed LLM deduplication: Changed related_edges_context to use integer indices instead of UUIDs, matching the EdgeDuplicate model expectations. 2. Fixed batch deduplication: Removed episode skip in dedupe_edges_bulk that prevented comparing edges from the same episode. Added self-comparison guard to prevent edge from comparing against itself. 3. Added fast-path deduplication: Added exact string matching before parallel processing in resolve_extracted_edges to catch within-episode duplicates early, preventing race conditions where concurrent edges can't see each other. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * test: Add tests for edge deduplication fixes Added three tests to verify the edge deduplication fixes: 1. test_dedupe_edges_bulk_deduplicates_within_episode: Verifies that dedupe_edges_bulk now compares edges from the same episode after removing the `if i == j: continue` check. 2. test_resolve_extracted_edge_uses_integer_indices_for_duplicates: Validates that the LLM receives integer indices for duplicate detection and correctly processes returned duplicate_facts. 3. test_resolve_extracted_edges_fast_path_deduplication: Confirms that the fast-path exact string matching deduplicates identical edges before parallel processing, preventing race conditions. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * fix: Remove unused variables flagged by ruff - Remove unused loop variable 'j' in bulk_utils.py - Remove unused return value 'edges_by_episode' in test - Replace unused 'edge_uuid' with '_' in test loop 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
4d54493064
commit
420676faf2
5 changed files with 2412 additions and 2091 deletions
|
|
@ -433,14 +433,15 @@ async def dedupe_edges_bulk(
|
|||
dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
|
||||
for i, edges_i in enumerate(extracted_edges):
|
||||
existing_edges: list[EntityEdge] = []
|
||||
for j, edges_j in enumerate(extracted_edges):
|
||||
if i == j:
|
||||
continue
|
||||
for edges_j in extracted_edges:
|
||||
existing_edges += edges_j
|
||||
|
||||
for edge in edges_i:
|
||||
candidates: list[EntityEdge] = []
|
||||
for existing_edge in existing_edges:
|
||||
# Skip self-comparison
|
||||
if edge.uuid == existing_edge.uuid:
|
||||
continue
|
||||
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
|
||||
# This approach will cast a wider net than BM25, which is ideal for this use case
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -232,6 +232,22 @@ async def resolve_extracted_edges(
|
|||
edge_types: dict[str, type[BaseModel]],
|
||||
edge_type_map: dict[tuple[str, str], list[str]],
|
||||
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
||||
# Fast path: deduplicate exact matches within the extracted edges before parallel processing
|
||||
seen: dict[tuple[str, str, str], EntityEdge] = {}
|
||||
deduplicated_edges: list[EntityEdge] = []
|
||||
|
||||
for edge in extracted_edges:
|
||||
key = (
|
||||
edge.source_node_uuid,
|
||||
edge.target_node_uuid,
|
||||
_normalize_string_exact(edge.fact),
|
||||
)
|
||||
if key not in seen:
|
||||
seen[key] = edge
|
||||
deduplicated_edges.append(edge)
|
||||
|
||||
extracted_edges = deduplicated_edges
|
||||
|
||||
driver = clients.driver
|
||||
llm_client = clients.llm_client
|
||||
embedder = clients.embedder
|
||||
|
|
@ -465,7 +481,7 @@ async def resolve_extracted_edge(
|
|||
|
||||
# Prepare context for LLM
|
||||
related_edges_context = [
|
||||
{'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
|
||||
{'id': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)
|
||||
]
|
||||
|
||||
invalidation_edge_candidates_context = [
|
||||
|
|
|
|||
|
|
@ -230,3 +230,101 @@ def test_resolve_edge_pointers_updates_sources():
|
|||
|
||||
assert edge.source_node_uuid == 'canonical'
|
||||
assert edge.target_node_uuid == 'target'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedupe_edges_bulk_deduplicates_within_episode(monkeypatch):
|
||||
"""Test that dedupe_edges_bulk correctly compares edges within the same episode.
|
||||
|
||||
This test verifies the fix that removed the `if i == j: continue` check,
|
||||
which was preventing edges from the same episode from being compared against each other.
|
||||
"""
|
||||
clients = _make_clients()
|
||||
|
||||
# Track which edges are compared
|
||||
comparisons_made = []
|
||||
|
||||
# Create mock embedder that sets embedding values
|
||||
async def mock_create_embeddings(embedder, edges):
|
||||
for edge in edges:
|
||||
edge.fact_embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
monkeypatch.setattr(bulk_utils, 'create_entity_edge_embeddings', mock_create_embeddings)
|
||||
|
||||
# Mock resolve_extracted_edge to track comparisons and mark duplicates
|
||||
async def mock_resolve_extracted_edge(
|
||||
llm_client,
|
||||
extracted_edge,
|
||||
related_edges,
|
||||
existing_edges,
|
||||
episode,
|
||||
edge_type_candidates=None,
|
||||
custom_edge_type_names=None,
|
||||
ensure_ascii=False,
|
||||
):
|
||||
# Track that this edge was compared against the related_edges
|
||||
comparisons_made.append((extracted_edge.uuid, [r.uuid for r in related_edges]))
|
||||
|
||||
# If there are related edges with same source/target/fact, mark as duplicate
|
||||
for related in related_edges:
|
||||
if (
|
||||
related.uuid != extracted_edge.uuid # Can't be duplicate of self
|
||||
and related.source_node_uuid == extracted_edge.source_node_uuid
|
||||
and related.target_node_uuid == extracted_edge.target_node_uuid
|
||||
and related.fact.strip().lower() == extracted_edge.fact.strip().lower()
|
||||
):
|
||||
# Return the related edge and mark extracted_edge as duplicate
|
||||
return related, [], [related]
|
||||
# Otherwise return the extracted edge as-is
|
||||
return extracted_edge, [], []
|
||||
|
||||
monkeypatch.setattr(bulk_utils, 'resolve_extracted_edge', mock_resolve_extracted_edge)
|
||||
|
||||
episode = _make_episode('1')
|
||||
source_uuid = 'source-uuid'
|
||||
target_uuid = 'target-uuid'
|
||||
|
||||
# Create 3 identical edges within the same episode
|
||||
edge1 = EntityEdge(
|
||||
name='recommends',
|
||||
fact='assistant recommends yoga poses',
|
||||
group_id='group',
|
||||
source_node_uuid=source_uuid,
|
||||
target_node_uuid=target_uuid,
|
||||
created_at=utc_now(),
|
||||
episodes=[episode.uuid],
|
||||
)
|
||||
edge2 = EntityEdge(
|
||||
name='recommends',
|
||||
fact='assistant recommends yoga poses',
|
||||
group_id='group',
|
||||
source_node_uuid=source_uuid,
|
||||
target_node_uuid=target_uuid,
|
||||
created_at=utc_now(),
|
||||
episodes=[episode.uuid],
|
||||
)
|
||||
edge3 = EntityEdge(
|
||||
name='recommends',
|
||||
fact='assistant recommends yoga poses',
|
||||
group_id='group',
|
||||
source_node_uuid=source_uuid,
|
||||
target_node_uuid=target_uuid,
|
||||
created_at=utc_now(),
|
||||
episodes=[episode.uuid],
|
||||
)
|
||||
|
||||
await bulk_utils.dedupe_edges_bulk(
|
||||
clients,
|
||||
[[edge1, edge2, edge3]],
|
||||
[(episode, [])],
|
||||
[],
|
||||
{},
|
||||
{},
|
||||
)
|
||||
|
||||
# Verify that edges were compared against each other (within same episode)
|
||||
# Each edge should have been compared against all 3 edges (including itself, which gets filtered)
|
||||
assert len(comparisons_made) == 3
|
||||
for _, compared_against in comparisons_made:
|
||||
# Each edge should have access to all 3 edges as candidates
|
||||
assert len(compared_against) >= 2 # At least 2 others (self is filtered out)
|
||||
|
|
|
|||
|
|
@ -434,3 +434,222 @@ async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client)
|
|||
assert resolved_edge.attributes == {}
|
||||
assert duplicates == []
|
||||
assert invalidated == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_extracted_edge_uses_integer_indices_for_duplicates(mock_llm_client):
|
||||
"""Test that resolve_extracted_edge correctly uses integer indices for LLM duplicate detection."""
|
||||
# Mock LLM to return duplicate_facts with integer indices
|
||||
mock_llm_client.generate_response.return_value = {
|
||||
'duplicate_facts': [0, 1], # LLM identifies first two related edges as duplicates
|
||||
'contradicted_facts': [],
|
||||
'fact_type': 'DEFAULT',
|
||||
}
|
||||
|
||||
extracted_edge = EntityEdge(
|
||||
source_node_uuid='source_uuid',
|
||||
target_node_uuid='target_uuid',
|
||||
name='test_edge',
|
||||
group_id='group_1',
|
||||
fact='User likes yoga',
|
||||
episodes=[],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
episode = EpisodicNode(
|
||||
uuid='episode_uuid',
|
||||
name='Episode',
|
||||
group_id='group_1',
|
||||
source='message',
|
||||
source_description='desc',
|
||||
content='Episode content',
|
||||
valid_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Create multiple related edges - LLM should receive these with integer indices
|
||||
related_edge_0 = EntityEdge(
|
||||
source_node_uuid='source_uuid',
|
||||
target_node_uuid='target_uuid',
|
||||
name='test_edge',
|
||||
group_id='group_1',
|
||||
fact='User enjoys yoga',
|
||||
episodes=['episode_1'],
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
related_edge_1 = EntityEdge(
|
||||
source_node_uuid='source_uuid',
|
||||
target_node_uuid='target_uuid',
|
||||
name='test_edge',
|
||||
group_id='group_1',
|
||||
fact='User practices yoga',
|
||||
episodes=['episode_2'],
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=2),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
related_edge_2 = EntityEdge(
|
||||
source_node_uuid='source_uuid',
|
||||
target_node_uuid='target_uuid',
|
||||
name='test_edge',
|
||||
group_id='group_1',
|
||||
fact='User loves swimming',
|
||||
episodes=['episode_3'],
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=3),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
related_edges = [related_edge_0, related_edge_1, related_edge_2]
|
||||
|
||||
resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
|
||||
mock_llm_client,
|
||||
extracted_edge,
|
||||
related_edges,
|
||||
[],
|
||||
episode,
|
||||
edge_type_candidates=None,
|
||||
custom_edge_type_names=set(),
|
||||
ensure_ascii=True,
|
||||
)
|
||||
|
||||
# Verify LLM was called
|
||||
mock_llm_client.generate_response.assert_called_once()
|
||||
|
||||
# Verify the system correctly identified duplicates using integer indices
|
||||
# The LLM returned [0, 1], so related_edge_0 and related_edge_1 should be marked as duplicates
|
||||
assert len(duplicates) == 2
|
||||
assert related_edge_0 in duplicates
|
||||
assert related_edge_1 in duplicates
|
||||
assert invalidated == []
|
||||
|
||||
# Verify that the resolved edge is one of the duplicates (the first one found)
|
||||
# Check UUID since the episode list gets modified
|
||||
assert resolved_edge.uuid == related_edge_0.uuid
|
||||
assert episode.uuid in resolved_edge.episodes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_extracted_edges_fast_path_deduplication(monkeypatch):
|
||||
"""Test that resolve_extracted_edges deduplicates exact matches before parallel processing."""
|
||||
from graphiti_core.utils.maintenance import edge_operations as edge_ops
|
||||
|
||||
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
|
||||
|
||||
# Track how many times resolve_extracted_edge is called
|
||||
resolve_call_count = 0
|
||||
|
||||
async def mock_resolve_extracted_edge(
|
||||
llm_client,
|
||||
extracted_edge,
|
||||
related_edges,
|
||||
existing_edges,
|
||||
episode,
|
||||
edge_type_candidates=None,
|
||||
custom_edge_type_names=None,
|
||||
ensure_ascii=False,
|
||||
):
|
||||
nonlocal resolve_call_count
|
||||
resolve_call_count += 1
|
||||
return extracted_edge, [], []
|
||||
|
||||
# Mock semaphore_gather to execute awaitable immediately
|
||||
async def immediate_gather(*aws, max_coroutines=None):
|
||||
results = []
|
||||
for aw in aws:
|
||||
results.append(await aw)
|
||||
return results
|
||||
|
||||
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
|
||||
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
|
||||
monkeypatch.setattr(edge_ops, 'resolve_extracted_edge', mock_resolve_extracted_edge)
|
||||
|
||||
llm_client = MagicMock()
|
||||
clients = SimpleNamespace(
|
||||
driver=MagicMock(),
|
||||
llm_client=llm_client,
|
||||
embedder=MagicMock(),
|
||||
cross_encoder=MagicMock(),
|
||||
ensure_ascii=True,
|
||||
)
|
||||
|
||||
source_node = EntityNode(
|
||||
uuid='source_uuid',
|
||||
name='Assistant',
|
||||
group_id='group_1',
|
||||
labels=['Entity'],
|
||||
)
|
||||
target_node = EntityNode(
|
||||
uuid='target_uuid',
|
||||
name='User',
|
||||
group_id='group_1',
|
||||
labels=['Entity'],
|
||||
)
|
||||
|
||||
# Create 3 identical edges
|
||||
edge1 = EntityEdge(
|
||||
source_node_uuid=source_node.uuid,
|
||||
target_node_uuid=target_node.uuid,
|
||||
name='recommends',
|
||||
group_id='group_1',
|
||||
fact='assistant recommends yoga poses',
|
||||
episodes=[],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
edge2 = EntityEdge(
|
||||
source_node_uuid=source_node.uuid,
|
||||
target_node_uuid=target_node.uuid,
|
||||
name='recommends',
|
||||
group_id='group_1',
|
||||
fact=' Assistant Recommends YOGA Poses ', # Different whitespace/case
|
||||
episodes=[],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
edge3 = EntityEdge(
|
||||
source_node_uuid=source_node.uuid,
|
||||
target_node_uuid=target_node.uuid,
|
||||
name='recommends',
|
||||
group_id='group_1',
|
||||
fact='assistant recommends yoga poses',
|
||||
episodes=[],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
episode = EpisodicNode(
|
||||
uuid='episode_uuid',
|
||||
name='Episode',
|
||||
group_id='group_1',
|
||||
source='message',
|
||||
source_description='desc',
|
||||
content='Episode content',
|
||||
valid_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
||||
clients,
|
||||
[edge1, edge2, edge3],
|
||||
episode,
|
||||
[source_node, target_node],
|
||||
{},
|
||||
{},
|
||||
)
|
||||
|
||||
# Fast path should have deduplicated the 3 identical edges to 1
|
||||
# So resolve_extracted_edge should only be called once
|
||||
assert resolve_call_count == 1
|
||||
assert len(resolved_edges) == 1
|
||||
assert invalidated_edges == []
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue