Refactor deduplication logic to enhance node resolution and track duplicate pairs (#929)

* Simplify deduplication process in bulk_utils by reusing canonical nodes.
* Update dedup_helpers to store duplicate pairs during resolution.
* Modify node_operations to append duplicate pairs when resolving nodes.
* Add tests to verify deduplication behavior and ensure correct state updates.
Daniel Chalef 2025-09-25 14:32:03 -07:00
parent 1e56019027
commit 18550c84ac
5 changed files with 155 additions and 75 deletions
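
In outline, the refactor replaces the old all-pairs candidate search with a sequential pass that feeds each episode's resolved nodes back in as candidates for the next episode. A minimal runnable sketch of that accumulation pattern — the Node class, resolve() helper, and episode data below are simplified stand-ins for EntityNode and resolve_extracted_nodes, not the real API:

    from dataclasses import dataclass

    @dataclass
    class Node:  # stand-in for EntityNode
        uuid: str
        name: str

    def resolve(nodes: list[Node], existing: list[Node] | None) -> tuple[list[Node], list[tuple[Node, Node]]]:
        # Hypothetical resolver: treat an existing node with the same name as canonical.
        by_name = {n.name: n for n in (existing or [])}
        resolved: list[Node] = []
        duplicates: list[tuple[Node, Node]] = []
        for node in nodes:
            match = by_name.get(node.name)
            if match is not None and match.uuid != node.uuid:
                resolved.append(match)
                duplicates.append((node, match))  # (extracted duplicate, canonical)
            else:
                resolved.append(node)
        return resolved, duplicates

    canonical: dict[str, Node] = {}
    duplicate_pairs: list[tuple[str, str]] = []
    episodes = [[Node('a1', 'Alice Smith')], [Node('a2', 'Alice Smith')]]
    for nodes in episodes:
        override = list(canonical.values()) if canonical else None  # reuse earlier canonicals
        resolved, duplicates = resolve(nodes, override)
        for node in resolved:
            canonical[node.uuid] = node
        duplicate_pairs.extend((s.uuid, t.uuid) for s, t in duplicates)

    print(duplicate_pairs)  # [('a2', 'a1')]: the second mention collapses onto the first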

@@ -43,7 +43,7 @@ from graphiti_core.models.nodes.node_db_queries import (
     get_entity_node_save_bulk_query,
     get_episode_node_save_bulk_query,
 )
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
@@ -266,83 +266,66 @@ async def dedupe_nodes_bulk(
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
-    embedder = clients.embedder
-    min_score = 0.8
-
-    # generate embeddings
-    await semaphore_gather(
-        *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
-    )
-
-    # Find similar results
-    dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
-    for i, nodes_i in enumerate(extracted_nodes):
-        existing_nodes: list[EntityNode] = []
-        for j, nodes_j in enumerate(extracted_nodes):
-            if i == j:
-                continue
-            existing_nodes += nodes_j
-
-        candidates_i: list[EntityNode] = []
-        for node in nodes_i:
-            for existing_node in existing_nodes:
-                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
-                # This approach will cast a wider net than BM25, which is ideal for this use case
-                node_words = set(node.name.lower().split())
-                existing_node_words = set(existing_node.name.lower().split())
-                has_overlap = not node_words.isdisjoint(existing_node_words)
-                if has_overlap:
-                    candidates_i.append(existing_node)
-                    continue
-
-                # Check for semantic similarity even if there is no overlap
-                similarity = np.dot(
-                    normalize_l2(node.name_embedding or []),
-                    normalize_l2(existing_node.name_embedding or []),
-                )
-                if similarity >= min_score:
-                    candidates_i.append(existing_node)
-
-        dedupe_tuples.append((nodes_i, candidates_i))
-
-    # Determine Node Resolutions
-    bulk_node_resolutions: list[
-        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
-    ] = await semaphore_gather(
-        *[
-            resolve_extracted_nodes(
-                clients,
-                dedupe_tuple[0],
-                episode_tuples[i][0],
-                episode_tuples[i][1],
-                entity_types,
-                existing_nodes_override=dedupe_tuples[i][1],
-            )
-            for i, dedupe_tuple in enumerate(dedupe_tuples)
-        ]
-    )
-
-    # Collect all duplicate pairs sorted by uuid
+    canonical_nodes: dict[str, EntityNode] = {}
+    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
+    per_episode_uuid_maps: list[dict[str, str]] = []
     duplicate_pairs: list[tuple[str, str]] = []
-    for _, _, duplicates in bulk_node_resolutions:
-        for duplicate in duplicates:
-            n, m = duplicate
-            duplicate_pairs.append((n.uuid, m.uuid))
-
-    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
-    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
-
-    node_uuid_map: dict[str, EntityNode] = {
-        node.uuid: node for nodes in extracted_nodes for node in nodes
-    }
+
+    for nodes, (episode, previous_episodes) in zip(extracted_nodes, episode_tuples, strict=True):
+        existing_override = list(canonical_nodes.values()) if canonical_nodes else None
+
+        resolved_nodes, uuid_map, duplicates = await resolve_extracted_nodes(
+            clients,
+            nodes,
+            episode,
+            previous_episodes,
+            entity_types,
+            existing_nodes_override=existing_override,
+        )
+
+        per_episode_uuid_maps.append(uuid_map)
+        episode_resolutions.append((episode.uuid, resolved_nodes))
+
+        for node in resolved_nodes:
+            canonical_nodes[node.uuid] = node
+
+        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
+
+    uuid_chains: dict[str, str] = {}
+    for uuid_map in per_episode_uuid_maps:
+        uuid_chains.update(uuid_map)
+    for duplicate_uuid, canonical_uuid in duplicate_pairs:
+        uuid_chains[duplicate_uuid] = canonical_uuid
+
+    def _resolve_uuid(uuid: str) -> str:
+        seen: set[str] = set()
+        current = uuid
+        while True:
+            target = uuid_chains.get(current)
+            if target is None or target == current:
+                return current if target is None else target
+            if current in seen:
+                return current
+            seen.add(current)
+            current = target
+
+    compressed_map: dict[str, str] = {uuid: _resolve_uuid(uuid) for uuid in uuid_chains}
+    for canonical_uuid in canonical_nodes:
+        compressed_map.setdefault(canonical_uuid, canonical_uuid)
 
     nodes_by_episode: dict[str, list[EntityNode]] = {}
-    for i, nodes in enumerate(extracted_nodes):
-        episode = episode_tuples[i][0]
-
-        nodes_by_episode[episode.uuid] = [
-            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
-        ]
+    for episode_uuid, resolved_nodes in episode_resolutions:
+        deduped_nodes: list[EntityNode] = []
+        seen: set[str] = set()
+        for node in resolved_nodes:
+            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
+            if canonical_uuid in seen:
+                continue
+            seen.add(canonical_uuid)
+            canonical_node = canonical_nodes.get(canonical_uuid, node)
+            deduped_nodes.append(canonical_node)
+        nodes_by_episode[episode_uuid] = deduped_nodes
 
     return nodes_by_episode, compressed_map
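
The _resolve_uuid helper above replaces the old compress_uuid_map call: it walks each uuid chain to its terminal canonical uuid, with a seen set guarding against cycles. A standalone sketch of the same walk over a hypothetical three-link chain:

    uuid_chains = {'c': 'b', 'b': 'a'}  # hypothetical chain: c -> b -> a

    def resolve_uuid(uuid: str) -> str:
        # Same walk as _resolve_uuid: follow the chain until it terminates,
        # returning early if we revisit a uuid (a cycle).
        seen: set[str] = set()
        current = uuid
        while True:
            target = uuid_chains.get(current)
            if target is None or target == current:
                return current if target is None else target
            if current in seen:
                return current
            seen.add(current)
            current = target

    compressed = {uuid: resolve_uuid(uuid) for uuid in uuid_chains}
    assert compressed == {'c': 'a', 'b': 'a'}  # both links compress to the root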

@@ -20,7 +20,7 @@ import math
 import re
 from collections import defaultdict
 from collections.abc import Iterable
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from hashlib import blake2b
 from typing import TYPE_CHECKING
@@ -164,6 +164,7 @@ class DedupResolutionState:
     resolved_nodes: list[EntityNode | None]
     uuid_map: dict[str, str]
     unresolved_indices: list[int]
+    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)


 def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
@@ -213,6 +214,8 @@ def _resolve_with_similarity(
             match = existing_matches[0]
             state.resolved_nodes[idx] = match
             state.uuid_map[node.uuid] = match.uuid
+            if match.uuid != node.uuid:
+                state.duplicate_pairs.append((node, match))
             continue
         if len(existing_matches) > 1:
             state.unresolved_indices.append(idx)
@@ -236,6 +239,8 @@ def _resolve_with_similarity(
         if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
             state.resolved_nodes[idx] = best_candidate
             state.uuid_map[node.uuid] = best_candidate.uuid
+            if best_candidate.uuid != node.uuid:
+                state.duplicate_pairs.append((node, best_candidate))
             continue

         state.unresolved_indices.append(idx)
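
Because the pairs now live on DedupResolutionState, the similarity, fuzzy, and LLM passes all append to the one list that resolve_extracted_nodes reads back later. A small usage sketch against the real classes — the dedup_helpers import path is assumed from the commit message, and the guard mirrors the branch added above:

    from graphiti_core.nodes import EntityNode
    from graphiti_core.utils.maintenance.dedup_helpers import DedupResolutionState  # assumed path

    extracted = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    match = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])

    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
    state.resolved_nodes[0] = match
    state.uuid_map[extracted.uuid] = match.uuid
    if match.uuid != extracted.uuid:  # only record genuine duplicates
        state.duplicate_pairs.append((extracted, match))

    assert state.duplicate_pairs == [(extracted, match)]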

@@ -306,6 +306,8 @@ async def _resolve_with_llm(
         state.resolved_nodes[original_index] = resolved_node
         state.uuid_map[extracted_node.uuid] = resolved_node.uuid
+        if resolved_node.uuid != extracted_node.uuid:
+            state.duplicate_pairs.append((extracted_node, resolved_node))


 async def resolve_extracted_nodes(
@@ -332,7 +334,6 @@ async def resolve_extracted_nodes(
         uuid_map={},
         unresolved_indices=[],
     )
-    node_duplicates: list[tuple[EntityNode, EntityNode]] = []

     _resolve_with_similarity(extracted_nodes, indexes, state)
@@ -359,7 +360,7 @@ async def resolve_extracted_nodes(
     new_node_duplicates: list[
         tuple[EntityNode, EntityNode]
-    ] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
+    ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)

     return (
         [node for node in state.resolved_nodes if node is not None],

@@ -0,0 +1,87 @@
+from collections import deque
+from unittest.mock import MagicMock
+
+import pytest
+
+from graphiti_core.graphiti_types import GraphitiClients
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.utils import bulk_utils
+from graphiti_core.utils.datetime_utils import utc_now
+
+
+def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
+    return EpisodicNode(
+        name=f'episode-{uuid_suffix}',
+        group_id=group_id,
+        labels=[],
+        source=EpisodeType.message,
+        content='content',
+        source_description='test',
+        created_at=utc_now(),
+        valid_at=utc_now(),
+    )
+
+
+def _make_clients() -> GraphitiClients:
+    driver = MagicMock()
+    embedder = MagicMock()
+    cross_encoder = MagicMock()
+    llm_client = MagicMock()
+
+    return GraphitiClients.model_construct(  # bypass validation to allow test doubles
+        driver=driver,
+        embedder=embedder,
+        cross_encoder=cross_encoder,
+        llm_client=llm_client,
+        ensure_ascii=False,
+    )
+
+
+@pytest.mark.asyncio
+async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
+    clients = _make_clients()
+
+    episode_one = _make_episode('1')
+    episode_two = _make_episode('2')
+
+    extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
+    extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
+
+    canonical = extracted_one
+    call_queue = deque()
+
+    async def fake_resolve(
+        clients_arg,
+        nodes_arg,
+        episode_arg,
+        previous_episodes_arg,
+        entity_types_arg,
+        existing_nodes_override=None,
+    ):
+        call_queue.append(existing_nodes_override)
+        if nodes_arg == [extracted_one]:
+            return [canonical], {canonical.uuid: canonical.uuid}, []
+        assert nodes_arg == [extracted_two]
+        assert existing_nodes_override is not None
+        assert existing_nodes_override[0] is canonical
+        return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]
+
+    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
+
+    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
+        clients,
+        [[extracted_one], [extracted_two]],
+        [(episode_one, []), (episode_two, [])],
+    )
+
+    assert len(call_queue) == 2
+    assert call_queue[0] is None
+    assert list(call_queue[1]) == [canonical]
+    assert nodes_by_episode[episode_one.uuid] == [canonical]
+    assert nodes_by_episode[episode_two.uuid] == [canonical]
+    assert compressed_map.get(extracted_two.uuid) == canonical.uuid

@@ -257,6 +257,7 @@ def test_resolve_with_similarity_exact_match_updates_state():
     assert state.resolved_nodes[0].uuid == candidate.uuid
     assert state.uuid_map[extracted.uuid] == candidate.uuid
     assert state.unresolved_indices == []
+    assert state.duplicate_pairs == [(extracted, candidate)]


 def test_resolve_with_similarity_low_entropy_defers_resolution():
@@ -274,6 +275,7 @@ def test_resolve_with_similarity_low_entropy_defers_resolution():
     assert state.resolved_nodes[0] is None
     assert state.unresolved_indices == [0]
+    assert state.duplicate_pairs == []


 def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
@@ -288,6 +290,7 @@ def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
     assert state.resolved_nodes[0] is None
     assert state.unresolved_indices == [0]
+    assert state.duplicate_pairs == []


 @pytest.mark.asyncio
@@ -339,3 +342,4 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
     assert state.uuid_map[extracted.uuid] == candidate.uuid
     assert captured_context['existing_nodes'][0]['idx'] == 0
     assert isinstance(captured_context['existing_nodes'], list)
+    assert state.duplicate_pairs == [(extracted, candidate)]