From 18550c84ac6dcaf9593cc8f5cafb039fcc08930a Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Thu, 25 Sep 2025 14:32:03 -0700
Subject: [PATCH] Refactor deduplication logic to enhance node resolution and
 track duplicate pairs (#929)

* Simplify deduplication process in bulk_utils by reusing canonical nodes.
* Update dedup_helpers to store duplicate pairs during resolution.
* Modify node_operations to append duplicate pairs when resolving nodes.
* Add tests to verify deduplication behavior and ensure correct state updates.
---
 graphiti_core/utils/bulk_utils.py             | 127 ++++++++----------
 .../utils/maintenance/dedup_helpers.py        |   7 +-
 .../utils/maintenance/node_operations.py      |   5 +-
 tests/utils/maintenance/test_bulk_utils.py    |  87 ++++++++++++
 .../utils/maintenance/test_node_operations.py |   4 +
 5 files changed, 155 insertions(+), 75 deletions(-)
 create mode 100644 tests/utils/maintenance/test_bulk_utils.py

diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index 181cda74..ded7cc0e 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -43,7 +43,7 @@ from graphiti_core.models.nodes.node_db_queries import (
     get_entity_node_save_bulk_query,
     get_episode_node_save_bulk_query,
 )
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
@@ -266,83 +266,66 @@ async def dedupe_nodes_bulk(
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
-    embedder = clients.embedder
-    min_score = 0.8
-
-    # generate embeddings
-    await semaphore_gather(
-        *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
-    )
-
-    # Find similar results
-    dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
-    for i, nodes_i in enumerate(extracted_nodes):
-        existing_nodes: list[EntityNode] = []
-        for j, nodes_j in enumerate(extracted_nodes):
-            if i == j:
-                continue
-            existing_nodes += nodes_j
-
-        candidates_i: list[EntityNode] = []
-        for node in nodes_i:
-            for existing_node in existing_nodes:
-                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
-                # This approach will cast a wider net than BM25, which is ideal for this use case
-                node_words = set(node.name.lower().split())
-                existing_node_words = set(existing_node.name.lower().split())
-                has_overlap = not node_words.isdisjoint(existing_node_words)
-                if has_overlap:
-                    candidates_i.append(existing_node)
-                    continue
-
-                # Check for semantic similarity even if there is no overlap
-                similarity = np.dot(
-                    normalize_l2(node.name_embedding or []),
-                    normalize_l2(existing_node.name_embedding or []),
-                )
-                if similarity >= min_score:
-                    candidates_i.append(existing_node)
-
-        dedupe_tuples.append((nodes_i, candidates_i))
-
-    # Determine Node Resolutions
-    bulk_node_resolutions: list[
-        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
-    ] = await semaphore_gather(
-        *[
-            resolve_extracted_nodes(
-                clients,
-                dedupe_tuple[0],
-                episode_tuples[i][0],
-                episode_tuples[i][1],
-                entity_types,
-                existing_nodes_override=dedupe_tuples[i][1],
-            )
-            for i, dedupe_tuple in enumerate(dedupe_tuples)
-        ]
-    )
-
-    # Collect all duplicate pairs sorted by uuid
+    canonical_nodes: dict[str, EntityNode] = {}
+    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
+    per_episode_uuid_maps: list[dict[str, str]] = []
     duplicate_pairs: list[tuple[str, str]] = []
-    for _, _, duplicates in bulk_node_resolutions:
-        for duplicate in duplicates:
-            n, m = duplicate
-            duplicate_pairs.append((n.uuid, m.uuid))
-    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
-    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
+    for nodes, (episode, previous_episodes) in zip(extracted_nodes, episode_tuples, strict=True):
+        existing_override = list(canonical_nodes.values()) if canonical_nodes else None
 
-    node_uuid_map: dict[str, EntityNode] = {
-        node.uuid: node for nodes in extracted_nodes for node in nodes
-    }
+        resolved_nodes, uuid_map, duplicates = await resolve_extracted_nodes(
+            clients,
+            nodes,
+            episode,
+            previous_episodes,
+            entity_types,
+            existing_nodes_override=existing_override,
+        )
+
+        per_episode_uuid_maps.append(uuid_map)
+        episode_resolutions.append((episode.uuid, resolved_nodes))
+
+        for node in resolved_nodes:
+            canonical_nodes[node.uuid] = node
+
+        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
+
+    uuid_chains: dict[str, str] = {}
+    for uuid_map in per_episode_uuid_maps:
+        uuid_chains.update(uuid_map)
+    for duplicate_uuid, canonical_uuid in duplicate_pairs:
+        uuid_chains[duplicate_uuid] = canonical_uuid
+
+    def _resolve_uuid(uuid: str) -> str:
+        seen: set[str] = set()
+        current = uuid
+        while True:
+            target = uuid_chains.get(current)
+            if target is None or target == current:
+                return current if target is None else target
+            if current in seen:
+                return current
+            seen.add(current)
+            current = target
+
+    compressed_map: dict[str, str] = {uuid: _resolve_uuid(uuid) for uuid in uuid_chains}
+    for canonical_uuid in canonical_nodes:
+        compressed_map.setdefault(canonical_uuid, canonical_uuid)
 
     nodes_by_episode: dict[str, list[EntityNode]] = {}
-    for i, nodes in enumerate(extracted_nodes):
-        episode = episode_tuples[i][0]
+    for episode_uuid, resolved_nodes in episode_resolutions:
+        deduped_nodes: list[EntityNode] = []
+        seen: set[str] = set()
+        for node in resolved_nodes:
+            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
+            if canonical_uuid in seen:
+                continue
+            seen.add(canonical_uuid)
+            canonical_node = canonical_nodes.get(canonical_uuid, node)
+            deduped_nodes.append(canonical_node)
 
-        nodes_by_episode[episode.uuid] = [
-            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
-        ]
+        nodes_by_episode[episode_uuid] = deduped_nodes
 
     return nodes_by_episode, compressed_map
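
Note on the bulk_utils.py hunks above: dedupe_nodes_bulk no longer builds embedding-overlap candidate sets and resolves every episode in parallel; it now walks episodes in order, feeding the canonical nodes accumulated so far into resolve_extracted_nodes via existing_nodes_override, then flattens the per-episode uuid_map entries and duplicate pairs into one transitive mapping. The sketch below is a minimal standalone model of that chain compression; the names compress and resolve are illustrative and do not appear in the patch:

    # Minimal sketch (not part of the patch): collapse chained duplicate
    # mappings like c -> b -> a down to their canonical target.
    def compress(uuid_chains: dict[str, str]) -> dict[str, str]:
        def resolve(uuid: str) -> str:
            seen: set[str] = set()
            current = uuid
            # Follow the chain until it ends, self-references, or cycles.
            while current in uuid_chains and uuid_chains[current] != current:
                if current in seen:
                    return current  # cycle guard, mirrors _resolve_uuid
                seen.add(current)
                current = uuid_chains[current]
            return current

        return {uuid: resolve(uuid) for uuid in uuid_chains}

    assert compress({'c': 'b', 'b': 'a'}) == {'c': 'a', 'b': 'a'}

The seen set serves the same purpose as the one in _resolve_uuid: a malformed cycle such as a -> b -> a terminates instead of looping forever.
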
diff --git a/graphiti_core/utils/maintenance/dedup_helpers.py b/graphiti_core/utils/maintenance/dedup_helpers.py
index 4916331e..b8ce68b8 100644
--- a/graphiti_core/utils/maintenance/dedup_helpers.py
+++ b/graphiti_core/utils/maintenance/dedup_helpers.py
@@ -20,7 +20,7 @@ import math
 import re
 from collections import defaultdict
 from collections.abc import Iterable
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from hashlib import blake2b
 from typing import TYPE_CHECKING
@@ -164,6 +164,7 @@ class DedupResolutionState:
     resolved_nodes: list[EntityNode | None]
     uuid_map: dict[str, str]
     unresolved_indices: list[int]
+    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
 
 
 def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
@@ -213,6 +214,8 @@ def _resolve_with_similarity(
             match = existing_matches[0]
             state.resolved_nodes[idx] = match
             state.uuid_map[node.uuid] = match.uuid
+            if match.uuid != node.uuid:
+                state.duplicate_pairs.append((node, match))
             continue
         if len(existing_matches) > 1:
             state.unresolved_indices.append(idx)
@@ -236,6 +239,8 @@ def _resolve_with_similarity(
         if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
             state.resolved_nodes[idx] = best_candidate
             state.uuid_map[node.uuid] = best_candidate.uuid
+            if best_candidate.uuid != node.uuid:
+                state.duplicate_pairs.append((node, best_candidate))
             continue
 
         state.unresolved_indices.append(idx)
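
Note on the dedup_helpers.py hunks above: the new duplicate_pairs field is declared with field(default_factory=list) because dataclasses reject bare mutable defaults, and the factory guarantees each DedupResolutionState gets its own list rather than sharing one. A small standalone illustration of that behavior; the State class here is hypothetical, not from the codebase:

    from dataclasses import dataclass, field

    @dataclass
    class State:
        # A bare `pairs: list = []` default would raise ValueError here.
        pairs: list[tuple[str, str]] = field(default_factory=list)

    a, b = State(), State()
    a.pairs.append(('duplicate-uuid', 'canonical-uuid'))
    assert b.pairs == []  # each instance owns a fresh list
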
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 693609d8..26bc23d9 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -306,6 +306,8 @@ async def _resolve_with_llm(
 
         state.resolved_nodes[original_index] = resolved_node
         state.uuid_map[extracted_node.uuid] = resolved_node.uuid
+        if resolved_node.uuid != extracted_node.uuid:
+            state.duplicate_pairs.append((extracted_node, resolved_node))
 
 
 async def resolve_extracted_nodes(
@@ -332,7 +334,6 @@ async def resolve_extracted_nodes(
         uuid_map={},
         unresolved_indices=[],
     )
-    node_duplicates: list[tuple[EntityNode, EntityNode]] = []
 
     _resolve_with_similarity(extracted_nodes, indexes, state)
 
@@ -359,7 +360,7 @@ async def resolve_extracted_nodes(
 
     new_node_duplicates: list[
         tuple[EntityNode, EntityNode]
-    ] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
+    ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
 
     return (
         [node for node in state.resolved_nodes if node is not None],
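
Note on the node_operations.py hunks above: the local node_duplicates list is removed, and both resolution passes (similarity matching in dedup_helpers and the LLM fallback here) now append (extracted, resolved) pairs to the shared state.duplicate_pairs, which resolve_extracted_nodes finally hands to filter_existing_duplicate_of_edges in one place. A toy model of the recording rule, with illustrative names not taken from the patch:

    # Toy sketch (illustrative, not from the patch): self-resolutions are
    # skipped, so only genuine duplicates reach the edge filter.
    duplicate_pairs: list[tuple[str, str]] = []

    def record(extracted_uuid: str, resolved_uuid: str) -> None:
        if resolved_uuid != extracted_uuid:
            duplicate_pairs.append((extracted_uuid, resolved_uuid))

    record('n1', 'n1')         # resolved to itself: nothing recorded
    record('n2', 'canonical')  # matched an existing node: recorded
    assert duplicate_pairs == [('n2', 'canonical')]
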
diff --git a/tests/utils/maintenance/test_bulk_utils.py b/tests/utils/maintenance/test_bulk_utils.py
new file mode 100644
index 00000000..83c1eb19
--- /dev/null
+++ b/tests/utils/maintenance/test_bulk_utils.py
@@ -0,0 +1,87 @@
+from collections import deque
+from unittest.mock import MagicMock
+
+import pytest
+
+from graphiti_core.graphiti_types import GraphitiClients
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.utils import bulk_utils
+from graphiti_core.utils.datetime_utils import utc_now
+
+
+def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
+    return EpisodicNode(
+        name=f'episode-{uuid_suffix}',
+        group_id=group_id,
+        labels=[],
+        source=EpisodeType.message,
+        content='content',
+        source_description='test',
+        created_at=utc_now(),
+        valid_at=utc_now(),
+    )
+
+
+def _make_clients() -> GraphitiClients:
+    driver = MagicMock()
+    embedder = MagicMock()
+    cross_encoder = MagicMock()
+    llm_client = MagicMock()
+
+    return GraphitiClients.model_construct(  # bypass validation to allow test doubles
+        driver=driver,
+        embedder=embedder,
+        cross_encoder=cross_encoder,
+        llm_client=llm_client,
+        ensure_ascii=False,
+    )
+
+
+@pytest.mark.asyncio
+async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
+    clients = _make_clients()
+
+    episode_one = _make_episode('1')
+    episode_two = _make_episode('2')
+
+    extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
+    extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
+
+    canonical = extracted_one
+
+    call_queue = deque()
+
+    async def fake_resolve(
+        clients_arg,
+        nodes_arg,
+        episode_arg,
+        previous_episodes_arg,
+        entity_types_arg,
+        existing_nodes_override=None,
+    ):
+        call_queue.append(existing_nodes_override)
+
+        if nodes_arg == [extracted_one]:
+            return [canonical], {canonical.uuid: canonical.uuid}, []
+
+        assert nodes_arg == [extracted_two]
+        assert existing_nodes_override is not None
+        assert existing_nodes_override[0] is canonical
+
+        return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]
+
+    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
+
+    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
+        clients,
+        [[extracted_one], [extracted_two]],
+        [(episode_one, []), (episode_two, [])],
+    )
+
+    assert len(call_queue) == 2
+    assert call_queue[0] is None
+    assert list(call_queue[1]) == [canonical]
+
+    assert nodes_by_episode[episode_one.uuid] == [canonical]
+    assert nodes_by_episode[episode_two.uuid] == [canonical]
+    assert compressed_map.get(extracted_two.uuid) == canonical.uuid
diff --git a/tests/utils/maintenance/test_node_operations.py b/tests/utils/maintenance/test_node_operations.py
index a7250559..b2fc30b0 100644
--- a/tests/utils/maintenance/test_node_operations.py
+++ b/tests/utils/maintenance/test_node_operations.py
@@ -257,6 +257,7 @@ def test_resolve_with_similarity_exact_match_updates_state():
     assert state.resolved_nodes[0].uuid == candidate.uuid
     assert state.uuid_map[extracted.uuid] == candidate.uuid
     assert state.unresolved_indices == []
+    assert state.duplicate_pairs == [(extracted, candidate)]
 
 
 def test_resolve_with_similarity_low_entropy_defers_resolution():
@@ -274,6 +275,7 @@ def test_resolve_with_similarity_low_entropy_defers_resolution():
 
     assert state.resolved_nodes[0] is None
     assert state.unresolved_indices == [0]
+    assert state.duplicate_pairs == []
 
 
 def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
@@ -288,6 +290,7 @@ def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
 
     assert state.resolved_nodes[0] is None
     assert state.unresolved_indices == [0]
+    assert state.duplicate_pairs == []
 
 
 @pytest.mark.asyncio
@@ -339,3 +342,4 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
     assert state.uuid_map[extracted.uuid] == candidate.uuid
     assert captured_context['existing_nodes'][0]['idx'] == 0
     assert isinstance(captured_context['existing_nodes'], list)
+    assert state.duplicate_pairs == [(extracted, candidate)]
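
Note on the tests above: test_dedupe_nodes_bulk_reuses_canonical_nodes pins the new contract that the second episode's resolve call receives the first episode's canonical node through existing_nodes_override, and that duplicates collapse in both return values. A standalone sketch of the per-episode collapse those assertions rely on, using made-up UUID strings:

    # Sketch (illustrative values): duplicates within one episode collapse
    # to a single canonical entry, as in the nodes_by_episode loop.
    compressed_map = {'dup-uuid': 'canon-uuid'}
    resolved = ['canon-uuid', 'dup-uuid']

    deduped: list[str] = []
    seen: set[str] = set()
    for uuid in resolved:
        canonical = compressed_map.get(uuid, uuid)
        if canonical not in seen:
            seen.add(canonical)
            deduped.append(canonical)

    assert deduped == ['canon-uuid']
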