From 152deb930df20020d9321f9c6040beddb5caaf6f Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Wed, 24 Sep 2025 21:16:08 -0700 Subject: [PATCH] implement deduplication helpers and integrate with node operations --- .../utils/maintenance/dedup_helpers.py | 253 +++++++++++++ .../utils/maintenance/node_operations.py | 337 +++++++++++------- tests/test_edge_int.py | 1 + tests/test_node_int.py | 2 + .../utils/maintenance/test_node_operations.py | 320 +++++++++++++++++ 5 files changed, 794 insertions(+), 119 deletions(-) create mode 100644 graphiti_core/utils/maintenance/dedup_helpers.py create mode 100644 tests/utils/maintenance/test_node_operations.py diff --git a/graphiti_core/utils/maintenance/dedup_helpers.py b/graphiti_core/utils/maintenance/dedup_helpers.py new file mode 100644 index 00000000..c5ee8024 --- /dev/null +++ b/graphiti_core/utils/maintenance/dedup_helpers.py @@ -0,0 +1,253 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import math +import re +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from functools import lru_cache +from hashlib import blake2b +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from graphiti_core.nodes import EntityNode + +_NAME_ENTROPY_THRESHOLD = 1.5 +_MIN_NAME_LENGTH = 6 +_MIN_TOKEN_COUNT = 2 +_FUZZY_JACCARD_THRESHOLD = 0.9 +_MINHASH_PERMUTATIONS = 32 +_MINHASH_BAND_SIZE = 4 + + +def _normalize_name_exact(name: str) -> str: + """Lowercase text and collapse whitespace so equal names map to the same key.""" + normalized = re.sub(r'[\s]+', ' ', name.lower()) + return normalized.strip() + + +def _normalize_name_for_fuzzy(name: str) -> str: + """Produce a fuzzier form that keeps alphanumerics and apostrophes for n-gram shingles.""" + normalized = re.sub(r"[^a-z0-9' ]", ' ', _normalize_name_exact(name)) + normalized = normalized.strip() + return re.sub(r'[\s]+', ' ', normalized) + + +def _name_entropy(normalized_name: str) -> float: + """Approximate text specificity using Shannon entropy over characters. + + We strip spaces, count how often each character appears, and sum + probability * -log2(probability). Short or repetitive names yield low + entropy, which signals we should defer resolution to the LLM instead of + trusting fuzzy similarity. 
+ """ + if not normalized_name: + return 0.0 + + counts: dict[str, int] = {} + for char in normalized_name.replace(' ', ''): + counts[char] = counts.get(char, 0) + 1 + + total = sum(counts.values()) + if total == 0: + return 0.0 + + entropy = 0.0 + for count in counts.values(): + probability = count / total + entropy -= probability * math.log2(probability) + + return entropy + + +def _has_high_entropy(normalized_name: str) -> bool: + """Filter out very short or low-entropy names that are unreliable for fuzzy matching.""" + token_count = len(normalized_name.split()) + if len(normalized_name) < _MIN_NAME_LENGTH and token_count < _MIN_TOKEN_COUNT: + return False + + return _name_entropy(normalized_name) >= _NAME_ENTROPY_THRESHOLD + + +def _shingles(normalized_name: str) -> set[str]: + """Create 3-gram shingles from the normalized name for MinHash calculations.""" + cleaned = normalized_name.replace(' ', '') + if len(cleaned) < 2: + return {cleaned} if cleaned else set() + + return {cleaned[i : i + 3] for i in range(len(cleaned) - 2)} + + +def _hash_shingle(shingle: str, seed: int) -> int: + """Generate a deterministic 64-bit hash for a shingle given the permutation seed.""" + digest = blake2b(f'{seed}:{shingle}'.encode(), digest_size=8) + return int.from_bytes(digest.digest(), 'big') + + +def _minhash_signature(shingles: Iterable[str]) -> tuple[int, ...]: + """Compute the MinHash signature for the shingle set across predefined permutations.""" + if not shingles: + return tuple() + + seeds = range(_MINHASH_PERMUTATIONS) + signature: list[int] = [] + for seed in seeds: + min_hash = min(_hash_shingle(shingle, seed) for shingle in shingles) + signature.append(min_hash) + + return tuple(signature) + + +def _lsh_bands(signature: Iterable[int]) -> list[tuple[int, ...]]: + """Split the MinHash signature into fixed-size bands for locality-sensitive hashing.""" + signature_list = list(signature) + if not signature_list: + return [] + + bands: list[tuple[int, ...]] = [] + for start in range(0, len(signature_list), _MINHASH_BAND_SIZE): + band = tuple(signature_list[start : start + _MINHASH_BAND_SIZE]) + if len(band) == _MINHASH_BAND_SIZE: + bands.append(band) + return bands + + +def _jaccard_similarity(a: set[str], b: set[str]) -> float: + """Return the Jaccard similarity between two shingle sets, handling empty edge cases.""" + if not a and not b: + return 1.0 + if not a or not b: + return 0.0 + + intersection = len(a.intersection(b)) + union = len(a.union(b)) + return intersection / union if union else 0.0 + + +@lru_cache(maxsize=512) +def _cached_shingles(name: str) -> set[str]: + """Cache shingle sets per normalized name to avoid recomputation within a worker.""" + return _shingles(name) + + +@dataclass +class DedupCandidateIndexes: + """Precomputed lookup structures that drive entity deduplication heuristics.""" + + existing_nodes: list[EntityNode] + normalized_existing: defaultdict[str, list[EntityNode]] + shingles_by_candidate: dict[str, set[str]] + lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] + + +@dataclass +class DedupResolutionState: + """Mutable resolution bookkeeping shared across deterministic and LLM passes.""" + + resolved_nodes: list[EntityNode | None] + uuid_map: dict[str, str] + unresolved_indices: list[int] + + +def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes: + """Precompute exact and fuzzy lookup structures once per dedupe run.""" + normalized_existing: defaultdict[str, list[EntityNode]] = defaultdict(list) + 
shingles_by_candidate: dict[str, set[str]] = {} + lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] = defaultdict(list) + + for candidate in existing_nodes: + normalized = _normalize_name_exact(candidate.name) + normalized_existing[normalized].append(candidate) + + shingles = _cached_shingles(_normalize_name_for_fuzzy(candidate.name)) + shingles_by_candidate[candidate.uuid] = shingles + + signature = _minhash_signature(shingles) + for band_index, band in enumerate(_lsh_bands(signature)): + lsh_buckets[(band_index, band)].append(candidate.uuid) + + return DedupCandidateIndexes( + existing_nodes=existing_nodes, + normalized_existing=normalized_existing, + shingles_by_candidate=shingles_by_candidate, + lsh_buckets=lsh_buckets, + ) + + +def _resolve_with_similarity( + extracted_nodes: list[EntityNode], + indexes: DedupCandidateIndexes, + state: DedupResolutionState, +) -> None: + """Attempt deterministic resolution using exact name hits and fuzzy MinHash comparisons.""" + for idx, node in enumerate(extracted_nodes): + normalized_exact = _normalize_name_exact(node.name) + normalized_fuzzy = _normalize_name_for_fuzzy(node.name) + + if not _has_high_entropy(normalized_fuzzy): + state.unresolved_indices.append(idx) + continue + + existing_matches = indexes.normalized_existing.get(normalized_exact, []) + if len(existing_matches) == 1: + match = existing_matches[0] + state.resolved_nodes[idx] = match + state.uuid_map[node.uuid] = match.uuid + continue + + shingles = _cached_shingles(normalized_fuzzy) + signature = _minhash_signature(shingles) + candidate_ids: set[str] = set() + for band_index, band in enumerate(_lsh_bands(signature)): + candidate_ids.update(indexes.lsh_buckets.get((band_index, band), [])) + + best_candidate: EntityNode | None = None + best_score = 0.0 + for candidate_id in candidate_ids: + candidate_shingles = indexes.shingles_by_candidate.get(candidate_id, set()) + score = _jaccard_similarity(shingles, candidate_shingles) + if score > best_score: + best_score = score + best_candidate = next( + (cand for cand in indexes.existing_nodes if cand.uuid == candidate_id), + None, + ) + + if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD: + state.resolved_nodes[idx] = best_candidate + state.uuid_map[node.uuid] = best_candidate.uuid + continue + + state.unresolved_indices.append(idx) + + +__all__ = [ + 'DedupCandidateIndexes', + 'DedupResolutionState', + '_normalize_name_exact', + '_normalize_name_for_fuzzy', + '_has_high_entropy', + '_minhash_signature', + '_lsh_bands', + '_jaccard_similarity', + '_cached_shingles', + '_FUZZY_JACCARD_THRESHOLD', + '_build_candidate_indexes', + '_resolve_with_similarity', +] diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index e58d8750..94773557 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -24,7 +24,12 @@ from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client.config import ModelSize -from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings +from graphiti_core.nodes import ( + EntityNode, + EpisodeType, + EpisodicNode, + create_entity_node_embeddings, +) from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions 
from graphiti_core.prompts.extract_nodes import ( @@ -38,7 +43,15 @@ from graphiti_core.search.search_config import SearchResults from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF from graphiti_core.search.search_filters import SearchFilters from graphiti_core.utils.datetime_utils import utc_now -from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges +from graphiti_core.utils.maintenance.dedup_helpers import ( + DedupCandidateIndexes, + DedupResolutionState, + _build_candidate_indexes, + _resolve_with_similarity, +) +from graphiti_core.utils.maintenance.edge_operations import ( + filter_existing_duplicate_of_edges, +) logger = logging.getLogger(__name__) @@ -52,16 +65,16 @@ async def extract_nodes_reflexion( ) -> list[str]: # Prepare context for LLM context = { - 'episode_content': episode.content, - 'previous_episodes': [ep.content for ep in previous_episodes], - 'extracted_entities': node_names, - 'ensure_ascii': ensure_ascii, + "episode_content": episode.content, + "previous_episodes": [ep.content for ep in previous_episodes], + "extracted_entities": node_names, + "ensure_ascii": ensure_ascii, } llm_response = await llm_client.generate_response( prompt_library.extract_nodes.reflexion(context), MissedEntities ) - missed_entities = llm_response.get('missed_entities', []) + missed_entities = llm_response.get("missed_entities", []) return missed_entities @@ -76,24 +89,24 @@ async def extract_nodes( start = time() llm_client = clients.llm_client llm_response = {} - custom_prompt = '' + custom_prompt = "" entities_missed = True reflexion_iterations = 0 entity_types_context = [ { - 'entity_type_id': 0, - 'entity_type_name': 'Entity', - 'entity_type_description': 'Default entity classification. Use this entity type if the entity is not one of the other listed types.', + "entity_type_id": 0, + "entity_type_name": "Entity", + "entity_type_description": "Default entity classification. 
Use this entity type if the entity is not one of the other listed types.", } ] entity_types_context += ( [ { - 'entity_type_id': i + 1, - 'entity_type_name': type_name, - 'entity_type_description': type_model.__doc__, + "entity_type_id": i + 1, + "entity_type_name": type_name, + "entity_type_description": type_model.__doc__, } for i, (type_name, type_model) in enumerate(entity_types.items()) ] @@ -102,13 +115,13 @@ async def extract_nodes( ) context = { - 'episode_content': episode.content, - 'episode_timestamp': episode.valid_at.isoformat(), - 'previous_episodes': [ep.content for ep in previous_episodes], - 'custom_prompt': custom_prompt, - 'entity_types': entity_types_context, - 'source_description': episode.source_description, - 'ensure_ascii': clients.ensure_ascii, + "episode_content": episode.content, + "episode_timestamp": episode.valid_at.isoformat(), + "previous_episodes": [ep.content for ep in previous_episodes], + "custom_prompt": custom_prompt, + "entity_types": entity_types_context, + "source_description": episode.source_description, + "ensure_ascii": clients.ensure_ascii, } while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS: @@ -119,11 +132,13 @@ async def extract_nodes( ) elif episode.source == EpisodeType.text: llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.extract_text(context), response_model=ExtractedEntities + prompt_library.extract_nodes.extract_text(context), + response_model=ExtractedEntities, ) elif episode.source == EpisodeType.json: llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities + prompt_library.extract_nodes.extract_json(context), + response_model=ExtractedEntities, ) response_object = ExtractedEntities(**llm_response) @@ -142,56 +157,57 @@ async def extract_nodes( entities_missed = len(missing_entities) != 0 - custom_prompt = 'Make sure that the following entities are extracted: ' + custom_prompt = "Make sure that the following entities are extracted: " for entity in missing_entities: - custom_prompt += f'\n{entity},' + custom_prompt += f"\n{entity}," - filtered_extracted_entities = [entity for entity in extracted_entities if entity.name.strip()] + filtered_extracted_entities = [ + entity for entity in extracted_entities if entity.name.strip() + ] end = time() - logger.debug(f'Extracted new nodes: {filtered_extracted_entities} in {(end - start) * 1000} ms') + logger.debug( + f"Extracted new nodes: {filtered_extracted_entities} in {(end - start) * 1000} ms" + ) # Convert the extracted data into EntityNode objects extracted_nodes = [] for extracted_entity in filtered_extracted_entities: type_id = extracted_entity.entity_type_id if 0 <= type_id < len(entity_types_context): - entity_type_name = entity_types_context[extracted_entity.entity_type_id].get( - 'entity_type_name' - ) + entity_type_name = entity_types_context[ + extracted_entity.entity_type_id + ].get("entity_type_name") else: - entity_type_name = 'Entity' + entity_type_name = "Entity" # Check if this entity type should be excluded if excluded_entity_types and entity_type_name in excluded_entity_types: - logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"') + logger.debug( + f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"' + ) continue - labels: list[str] = list({'Entity', str(entity_type_name)}) + labels: list[str] = list({"Entity", str(entity_type_name)}) new_node = EntityNode( 
name=extracted_entity.name, group_id=episode.group_id, labels=labels, - summary='', + summary="", created_at=utc_now(), ) extracted_nodes.append(new_node) - logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') + logger.debug(f"Created new node: {new_node.name} (UUID: {new_node.uuid})") - logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') + logger.debug(f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}") return extracted_nodes -async def resolve_extracted_nodes( +async def _collect_candidate_nodes( clients: GraphitiClients, extracted_nodes: list[EntityNode], - episode: EpisodicNode | None = None, - previous_episodes: list[EpisodicNode] | None = None, - entity_types: dict[str, type[BaseModel]] | None = None, - existing_nodes_override: list[EntityNode] | None = None, -) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]: - llm_client = clients.llm_client - driver = clients.driver - + existing_nodes_override: list[EntityNode] | None, +) -> list[EntityNode]: + """Search per extracted name and return unique candidates with overrides honored in order.""" search_results: list[SearchResults] = await semaphore_gather( *[ search( @@ -205,54 +221,79 @@ async def resolve_extracted_nodes( ] ) - candidate_nodes: list[EntityNode] = ( - [node for result in search_results for node in result.nodes] - if existing_nodes_override is None - else existing_nodes_override + candidate_nodes: list[EntityNode] = [ + node for result in search_results for node in result.nodes + ] + + if existing_nodes_override is not None: + candidate_nodes.extend(existing_nodes_override) + + seen_candidate_uuids: set[str] = set() + ordered_candidates: list[EntityNode] = [] + for candidate in candidate_nodes: + if candidate.uuid in seen_candidate_uuids: + continue + seen_candidate_uuids.add(candidate.uuid) + ordered_candidates.append(candidate) + + return ordered_candidates + + +async def _resolve_with_llm( + llm_client: LLMClient, + extracted_nodes: list[EntityNode], + indexes: DedupCandidateIndexes, + state: DedupResolutionState, + ensure_ascii: bool, + episode: EpisodicNode | None, + previous_episodes: list[EpisodicNode] | None, + entity_types: dict[str, type[BaseModel]] | None, +) -> None: + """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.""" + if not state.unresolved_indices: + return + + entity_types_dict: dict[str, type[BaseModel]] = ( + entity_types if entity_types is not None else {} ) - existing_nodes_dict: dict[str, EntityNode] = {node.uuid: node for node in candidate_nodes} + llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices] - existing_nodes: list[EntityNode] = list(existing_nodes_dict.values()) - - existing_nodes_context = ( - [ - { - **{ - 'idx': i, - 'name': candidate.name, - 'entity_types': candidate.labels, - }, - **candidate.attributes, - } - for i, candidate in enumerate(existing_nodes) - ], - ) - - entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {} - - # Prepare context for LLM extracted_nodes_context = [ { - 'id': i, - 'name': node.name, - 'entity_type': node.labels, - 'entity_type_description': entity_types_dict.get( - next((item for item in node.labels if item != 'Entity'), '') + "id": i, + "name": node.name, + "entity_type": node.labels, + "entity_type_description": entity_types_dict.get( + next((item for item in node.labels if item != "Entity"), "") ).__doc__ - or 'Default Entity Type', + or "Default 
Entity Type", } - for i, node in enumerate(extracted_nodes) + for i, node in enumerate(llm_extracted_nodes) + ] + + existing_nodes_context = [ + { + **{ + "idx": i, + "name": candidate.name, + "entity_types": candidate.labels, + }, + **candidate.attributes, + } + for i, candidate in enumerate(indexes.existing_nodes) ] context = { - 'extracted_nodes': extracted_nodes_context, - 'existing_nodes': existing_nodes_context, - 'episode_content': episode.content if episode is not None else '', - 'previous_episodes': [ep.content for ep in previous_episodes] - if previous_episodes is not None - else [], - 'ensure_ascii': clients.ensure_ascii, + "extracted_nodes": extracted_nodes_context, + "existing_nodes": existing_nodes_context, + "episode_content": episode.content if episode is not None else "", + "previous_episodes": ( + [ep.content for ep in previous_episodes] + if previous_episodes is not None + else [] + ), + "ensure_ascii": ensure_ascii, } llm_response = await llm_client.generate_response( @@ -260,35 +301,85 @@ async def resolve_extracted_nodes( response_model=NodeResolutions, ) - node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions + node_resolutions: list[NodeDuplicate] = NodeResolutions( + **llm_response + ).entity_resolutions - resolved_nodes: list[EntityNode] = [] - uuid_map: dict[str, str] = {} - node_duplicates: list[tuple[EntityNode, EntityNode]] = [] for resolution in node_resolutions: - resolution_id: int = resolution.id + relative_id: int = resolution.id duplicate_idx: int = resolution.duplicate_idx - extracted_node = extracted_nodes[resolution_id] + original_index = state.unresolved_indices[relative_id] + extracted_node = extracted_nodes[original_index] resolved_node = ( - existing_nodes[duplicate_idx] - if 0 <= duplicate_idx < len(existing_nodes) + indexes.existing_nodes[duplicate_idx] + if 0 <= duplicate_idx < len(indexes.existing_nodes) else extracted_node ) - # resolved_node.name = resolution.get('name') + state.resolved_nodes[original_index] = resolved_node + state.uuid_map[extracted_node.uuid] = resolved_node.uuid - resolved_nodes.append(resolved_node) - uuid_map[extracted_node.uuid] = resolved_node.uuid - logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}') +async def resolve_extracted_nodes( + clients: GraphitiClients, + extracted_nodes: list[EntityNode], + episode: EpisodicNode | None = None, + previous_episodes: list[EpisodicNode] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, + existing_nodes_override: list[EntityNode] | None = None, +) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]: + """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt.""" + llm_client = clients.llm_client + driver = clients.driver + existing_nodes = await _collect_candidate_nodes( + clients, + extracted_nodes, + existing_nodes_override, + ) - new_node_duplicates: list[ - tuple[EntityNode, EntityNode] - ] = await filter_existing_duplicate_of_edges(driver, node_duplicates) + indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes) - return resolved_nodes, uuid_map, new_node_duplicates + state = DedupResolutionState( + resolved_nodes=[None] * len(extracted_nodes), + uuid_map={}, + unresolved_indices=[], + ) + node_duplicates: list[tuple[EntityNode, EntityNode]] = [] + + _resolve_with_similarity(extracted_nodes, indexes, state) + + await _resolve_with_llm( + llm_client, + extracted_nodes, + indexes, + state, + 
clients.ensure_ascii, + episode, + previous_episodes, + entity_types, + ) + + for idx, node in enumerate(extracted_nodes): + if state.resolved_nodes[idx] is None: + state.resolved_nodes[idx] = node + state.uuid_map[node.uuid] = node.uuid + + logger.debug( + "Resolved nodes: %s", + [(node.name, node.uuid) for node in state.resolved_nodes if node is not None], + ) + + new_node_duplicates: list[tuple[EntityNode, EntityNode]] = ( + await filter_existing_duplicate_of_edges(driver, node_duplicates) + ) + + return ( + [node for node in state.resolved_nodes if node is not None], + state.uuid_map, + new_node_duplicates, + ) async def extract_attributes_from_nodes( @@ -307,9 +398,13 @@ async def extract_attributes_from_nodes( node, episode, previous_episodes, - entity_types.get(next((item for item in node.labels if item != 'Entity'), '')) - if entity_types is not None - else None, + ( + entity_types.get( + next((item for item in node.labels if item != "Entity"), "") + ) + if entity_types is not None + else None + ), clients.ensure_ascii, ) for node in nodes @@ -330,28 +425,32 @@ async def extract_attributes_from_node( ensure_ascii: bool = False, ) -> EntityNode: node_context: dict[str, Any] = { - 'name': node.name, - 'summary': node.summary, - 'entity_types': node.labels, - 'attributes': node.attributes, + "name": node.name, + "summary": node.summary, + "entity_types": node.labels, + "attributes": node.attributes, } attributes_context: dict[str, Any] = { - 'node': node_context, - 'episode_content': episode.content if episode is not None else '', - 'previous_episodes': [ep.content for ep in previous_episodes] - if previous_episodes is not None - else [], - 'ensure_ascii': ensure_ascii, + "node": node_context, + "episode_content": episode.content if episode is not None else "", + "previous_episodes": ( + [ep.content for ep in previous_episodes] + if previous_episodes is not None + else [] + ), + "ensure_ascii": ensure_ascii, } summary_context: dict[str, Any] = { - 'node': node_context, - 'episode_content': episode.content if episode is not None else '', - 'previous_episodes': [ep.content for ep in previous_episodes] - if previous_episodes is not None - else [], - 'ensure_ascii': ensure_ascii, + "node": node_context, + "episode_content": episode.content if episode is not None else "", + "previous_episodes": ( + [ep.content for ep in previous_episodes] + if previous_episodes is not None + else [] + ), + "ensure_ascii": ensure_ascii, } has_entity_attributes: bool = bool( @@ -379,7 +478,7 @@ async def extract_attributes_from_node( if has_entity_attributes and entity_type is not None: entity_type(**llm_response) - node.summary = summary_response.get('summary', '') + node.summary = summary_response.get("summary", "") node_attributes = {key: value for key, value in llm_response.items()} node.attributes.update(node_attributes) diff --git a/tests/test_edge_int.py b/tests/test_edge_int.py index 15555d72..b028e452 100644 --- a/tests/test_edge_int.py +++ b/tests/test_edge_int.py @@ -26,6 +26,7 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, Episodic from tests.helpers_test import get_edge_count, get_node_count, group_id pytest_plugins = ('pytest_asyncio',) +pytestmark = pytest.mark.integration def setup_logging(): diff --git a/tests/test_node_int.py b/tests/test_node_int.py index 7e73b856..bdb8add8 100644 --- a/tests/test_node_int.py +++ b/tests/test_node_int.py @@ -33,6 +33,8 @@ from tests.helpers_test import ( group_id, ) +pytestmark = pytest.mark.integration + created_at = 
datetime.now() deleted_at = created_at + timedelta(days=3) valid_at = created_at + timedelta(days=1) diff --git a/tests/utils/maintenance/test_node_operations.py b/tests/utils/maintenance/test_node_operations.py new file mode 100644 index 00000000..58a8f27a --- /dev/null +++ b/tests/utils/maintenance/test_node_operations.py @@ -0,0 +1,320 @@ +from collections import defaultdict +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from graphiti_core.graphiti_types import GraphitiClients +from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode +from graphiti_core.search.search_config import SearchResults +from graphiti_core.utils.datetime_utils import utc_now +from graphiti_core.utils.maintenance.dedup_helpers import ( + DedupCandidateIndexes, + DedupResolutionState, + _build_candidate_indexes, + _cached_shingles, + _has_high_entropy, + _hash_shingle, + _jaccard_similarity, + _lsh_bands, + _minhash_signature, + _name_entropy, + _normalize_name_exact, + _normalize_name_for_fuzzy, + _resolve_with_similarity, + _shingles, +) +from graphiti_core.utils.maintenance.node_operations import ( + _collect_candidate_nodes, + _resolve_with_llm, + resolve_extracted_nodes, +) + + +def _make_clients(): + driver = MagicMock() + embedder = MagicMock() + cross_encoder = MagicMock() + llm_client = MagicMock() + llm_generate = AsyncMock() + llm_client.generate_response = llm_generate + + clients = GraphitiClients.model_construct( # bypass validation to allow test doubles + driver=driver, + embedder=embedder, + cross_encoder=cross_encoder, + llm_client=llm_client, + ensure_ascii=False, + ) + + return clients, llm_generate + + +def _make_episode(group_id: str = 'group'): + return EpisodicNode( + name='episode', + group_id=group_id, + source=EpisodeType.message, + source_description='test', + content='content', + valid_at=utc_now(), + ) + + +@pytest.mark.asyncio +async def test_resolve_nodes_exact_match_skips_llm(monkeypatch): + clients, llm_generate = _make_clients() + + candidate = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity']) + extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity']) + + async def fake_search(*_, **__): + return SearchResults(nodes=[candidate]) + + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.search', + fake_search, + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges', + AsyncMock(return_value=[]), + ) + + resolved, uuid_map, _ = await resolve_extracted_nodes( + clients, + [extracted], + episode=_make_episode(), + previous_episodes=[], + ) + + assert resolved[0].uuid == candidate.uuid + assert uuid_map[extracted.uuid] == candidate.uuid + llm_generate.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_resolve_nodes_low_entropy_uses_llm(monkeypatch): + clients, llm_generate = _make_clients() + llm_generate.return_value = { + 'entity_resolutions': [ + { + 'id': 0, + 'duplicate_idx': -1, + 'name': 'Joe', + 'duplicates': [], + } + ] + } + + extracted = EntityNode(name='Joe', group_id='group', labels=['Entity']) + + async def fake_search(*_, **__): + return SearchResults(nodes=[]) + + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.search', + fake_search, + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges', + AsyncMock(return_value=[]), + ) + + resolved, uuid_map, _ = await resolve_extracted_nodes( + clients, + [extracted], + 
episode=_make_episode(), + previous_episodes=[], + ) + + assert resolved[0].uuid == extracted.uuid + assert uuid_map[extracted.uuid] == extracted.uuid + llm_generate.assert_awaited() + + +@pytest.mark.asyncio +async def test_resolve_nodes_fuzzy_match(monkeypatch): + clients, llm_generate = _make_clients() + + candidate = EntityNode(name='Joe-Michaels', group_id='group', labels=['Entity']) + extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity']) + + async def fake_search(*_, **__): + return SearchResults(nodes=[candidate]) + + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.search', + fake_search, + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges', + AsyncMock(return_value=[]), + ) + + resolved, uuid_map, _ = await resolve_extracted_nodes( + clients, + [extracted], + episode=_make_episode(), + previous_episodes=[], + ) + + assert resolved[0].uuid == candidate.uuid + assert uuid_map[extracted.uuid] == candidate.uuid + llm_generate.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_collect_candidate_nodes_dedupes_and_merges_override(monkeypatch): + clients, _ = _make_clients() + + candidate = EntityNode(name='Alice', group_id='group', labels=['Entity']) + override_duplicate = EntityNode( + uuid=candidate.uuid, + name='Alice Alt', + group_id='group', + labels=['Entity'], + ) + extracted = EntityNode(name='Alice', group_id='group', labels=['Entity']) + + search_mock = AsyncMock(return_value=SearchResults(nodes=[candidate])) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.search', + search_mock, + ) + + result = await _collect_candidate_nodes( + clients, + [extracted], + existing_nodes_override=[override_duplicate], + ) + + assert len(result) == 1 + assert result[0].uuid == candidate.uuid + search_mock.assert_awaited() + + +def test_build_candidate_indexes_populates_structures(): + candidate = EntityNode(name='Bob Dylan', group_id='group', labels=['Entity']) + + indexes = _build_candidate_indexes([candidate]) + + normalized_key = candidate.name.lower() + assert indexes.normalized_existing[normalized_key][0].uuid == candidate.uuid + assert candidate.uuid in indexes.shingles_by_candidate + assert any(candidate.uuid in bucket for bucket in indexes.lsh_buckets.values()) + + +def test_normalize_helpers(): + assert _normalize_name_exact(' Alice Smith ') == 'alice smith' + assert _normalize_name_for_fuzzy('Alice-Smith!') == 'alice smith' + + +def test_name_entropy_variants(): + assert _name_entropy('alice') > _name_entropy('aaaaa') + assert _name_entropy('') == 0.0 + + +def test_has_high_entropy_rules(): + assert _has_high_entropy('meaningful name') is True + assert _has_high_entropy('aa') is False + + +def test_shingles_and_cache(): + raw = 'alice' + shingle_set = _shingles(raw) + assert shingle_set == {'ali', 'lic', 'ice'} + assert _cached_shingles(raw) == shingle_set + assert _cached_shingles(raw) is _cached_shingles(raw) + + +def test_hash_minhash_and_lsh(): + shingles = {'abc', 'bcd', 'cde'} + signature = _minhash_signature(shingles) + assert len(signature) == 32 + bands = _lsh_bands(signature) + assert all(len(band) == 4 for band in bands) + hashed = {_hash_shingle(s, 0) for s in shingles} + assert len(hashed) == len(shingles) + + +def test_jaccard_similarity_edges(): + a = {'a', 'b'} + b = {'a', 'c'} + assert _jaccard_similarity(a, b) == pytest.approx(1 / 3) + assert _jaccard_similarity(set(), set()) == 1.0 + assert _jaccard_similarity(a, set()) 
== 0.0 + + +def test_resolve_with_similarity_exact_match_updates_state(): + candidate = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity']) + extracted = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity']) + + indexes = _build_candidate_indexes([candidate]) + state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[]) + + _resolve_with_similarity([extracted], indexes, state) + + assert state.resolved_nodes[0].uuid == candidate.uuid + assert state.uuid_map[extracted.uuid] == candidate.uuid + assert state.unresolved_indices == [] + + +def test_resolve_with_similarity_low_entropy_defers_resolution(): + extracted = EntityNode(name='Bob', group_id='group', labels=['Entity']) + indexes = DedupCandidateIndexes([], defaultdict(list), {}, defaultdict(list)) + state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[]) + + _resolve_with_similarity([extracted], indexes, state) + + assert state.resolved_nodes[0] is None + assert state.unresolved_indices == [0] + + +@pytest.mark.asyncio +async def test_resolve_with_llm_updates_unresolved(monkeypatch): + extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity']) + candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity']) + + indexes = _build_candidate_indexes([candidate]) + state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0]) + + captured_context = {} + + def fake_prompt_nodes(context): + captured_context.update(context) + return ['prompt'] + + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes', + fake_prompt_nodes, + ) + + async def fake_generate_response(*_, **__): + return { + 'entity_resolutions': [ + { + 'id': 0, + 'duplicate_idx': 0, + 'name': 'Dizzy Gillespie', + 'duplicates': [0], + } + ] + } + + llm_client = MagicMock() + llm_client.generate_response = AsyncMock(side_effect=fake_generate_response) + + await _resolve_with_llm( + llm_client, + [extracted], + indexes, + state, + ensure_ascii=False, + episode=_make_episode(), + previous_episodes=[], + entity_types=None, + ) + + assert state.resolved_nodes[0].uuid == candidate.uuid + assert state.uuid_map[extracted.uuid] == candidate.uuid + assert captured_context['existing_nodes'][0]['idx'] == 0 + assert isinstance(captured_context['existing_nodes'], list)
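
Usage sketch (not part of the patch): a minimal, hedged example of the deterministic pass this change introduces, assuming graphiti_core is importable with the patch applied and using only helpers defined above (_build_candidate_indexes, _resolve_with_similarity, DedupResolutionState, EntityNode). The names 'Joe Michaels' and 'Jo' are illustrative only. An exact normalized-name hit resolves without any LLM call, while a short, low-entropy name is deferred to the unresolved list for the _resolve_with_llm pass driven by resolve_extracted_nodes.

    from graphiti_core.nodes import EntityNode
    from graphiti_core.utils.maintenance.dedup_helpers import (
        DedupResolutionState,
        _build_candidate_indexes,
        _resolve_with_similarity,
    )

    # Candidates already persisted in the graph (normally gathered by _collect_candidate_nodes).
    existing = [EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])]

    # Freshly extracted nodes from the current episode.
    extracted = [
        EntityNode(name='Joe Michaels', group_id='group', labels=['Entity']),  # exact normalized match
        EntityNode(name='Jo', group_id='group', labels=['Entity']),            # short/low-entropy name
    ]

    # Precompute exact and MinHash/LSH lookup structures once per dedupe run.
    indexes = _build_candidate_indexes(existing)
    state = DedupResolutionState(
        resolved_nodes=[None] * len(extracted),
        uuid_map={},
        unresolved_indices=[],
    )

    # Deterministic pass: exact name hits and fuzzy Jaccard >= 0.9 resolve here;
    # anything else is queued for the LLM dedupe prompt.
    _resolve_with_similarity(extracted, indexes, state)

    assert state.resolved_nodes[0].uuid == existing[0].uuid  # resolved without the LLM
    assert state.unresolved_indices == [1]                   # 'Jo' is deferred to _resolve_with_llm

This mirrors what resolve_extracted_nodes now does end to end: collect candidates via search, run the similarity pass, escalate only the unresolved indices to the LLM, and fall back to treating a node as new when neither pass finds a duplicate.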