From 7c469e8e2b45f24dd86f986f52117396695a44ef Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Thu, 25 Sep 2025 07:13:19 -0700 Subject: [PATCH] Improve node deduplication w/ deterministic matching, LLM fallbacks (#929) * add repository guidelines and project structure documentation * update neo4j image version and modify test command to disable specific databases * implement deduplication helpers and integrate with node operations * refactor string formatting to use single quotes in node operations * enhance deduplication helpers with UUID indexing and update resolution logic * implement exact fact matching (#931) --- AGENTS.md | 21 ++ Makefile | 4 +- docker-compose.test.yml | 2 +- .../utils/maintenance/dedup_helpers.py | 257 +++++++++++++ .../utils/maintenance/edge_operations.py | 14 + .../utils/maintenance/node_operations.py | 199 +++++++--- tests/test_edge_int.py | 1 + tests/test_node_int.py | 2 + .../utils/maintenance/test_edge_operations.py | 50 +++ .../utils/maintenance/test_node_operations.py | 341 ++++++++++++++++++ 10 files changed, 828 insertions(+), 63 deletions(-) create mode 100644 AGENTS.md create mode 100644 graphiti_core/utils/maintenance/dedup_helpers.py create mode 100644 tests/utils/maintenance/test_node_operations.py diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..cc88527e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,21 @@ +# Repository Guidelines + +## Project Structure & Module Organization +Graphiti's core library lives under `graphiti_core/`, split into domain modules such as `nodes.py`, `edges.py`, `models/`, and `search/` for retrieval pipelines. Service adapters and API glue reside in `server/graph_service/`, while the MCP integration lives in `mcp_server/`. Shared assets and collateral sit in `images/` and `examples/`. Tests cover the package via `tests/`, with configuration in `conftest.py`, `pytest.ini`, and Docker compose files for optional services. Tooling manifests live at the repo root, including `pyproject.toml`, `Makefile`, and deployment compose files. + +## Build, Test, and Development Commands +- `uv sync --extra dev`: install the dev environment declared in `pyproject.toml`. +- `make format`: run `ruff` to sort imports and apply the canonical formatter. +- `make lint`: execute `ruff` plus `pyright` type checks against `graphiti_core`. +- `make test`: run the full `pytest` suite (`uv run pytest`). +- `uv run pytest tests/path/test_file.py`: target a specific module or test selection. +- `docker-compose -f docker-compose.test.yml up`: provision local graph/search dependencies for integration flows. + +## Coding Style & Naming Conventions +Python code uses 4-space indentation, 100-character lines, and prefers single quotes as configured in `pyproject.toml`. Modules, files, and functions stay snake_case; Pydantic models in `graphiti_core/models` use PascalCase with explicit type hints. Keep side-effectful code inside drivers or adapters (`graphiti_core/driver`, `graphiti_core/utils`) and rely on pure helpers elsewhere. Run `make format` before committing to normalize imports and docstring formatting. + +## Testing Guidelines +Author tests alongside features under `tests/`, naming files `test_.py` and functions `test_`. Use `@pytest.mark.integration` for database-reliant scenarios so CI can gate them. Reproduce regressions with a failing test first and validate fixes via `uv run pytest -k "pattern"`. 
Start required backing services through `docker-compose.test.yml` when running integration suites locally. + +## Commit & Pull Request Guidelines +Commits use an imperative, present-tense summary (for example, `add async cache invalidation`) optionally suffixed with the PR number as seen in history (`(#927)`). Squash fixups and keep unrelated changes isolated. Pull requests should include: a concise description, linked tracking issue, notes about schema or API impacts, and screenshots or logs when behavior changes. Confirm `make lint` and `make test` pass locally, and update docs or examples when public interfaces shift. diff --git a/Makefile b/Makefile index de6e6f53..c2a9c8d8 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,7 @@ lint: # Run tests test: - $(PYTEST) + DISABLE_FALKORDB=1 DISABLE_KUZU=1 DISABLE_NEPTUNE=1 $(PYTEST) -m "not integration" # Run format, lint, and test -check: format lint test \ No newline at end of file +check: format lint test diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 8ed07501..0bcd2cc6 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -25,7 +25,7 @@ services: - PORT=8000 neo4j: - image: neo4j:5.22.0 + image: neo4j:5.26.2 ports: - "7474:7474" - "${NEO4J_PORT}:${NEO4J_PORT}" diff --git a/graphiti_core/utils/maintenance/dedup_helpers.py b/graphiti_core/utils/maintenance/dedup_helpers.py new file mode 100644 index 00000000..4916331e --- /dev/null +++ b/graphiti_core/utils/maintenance/dedup_helpers.py @@ -0,0 +1,257 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import math +import re +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from functools import lru_cache +from hashlib import blake2b +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from graphiti_core.nodes import EntityNode + +_NAME_ENTROPY_THRESHOLD = 1.5 +_MIN_NAME_LENGTH = 6 +_MIN_TOKEN_COUNT = 2 +_FUZZY_JACCARD_THRESHOLD = 0.9 +_MINHASH_PERMUTATIONS = 32 +_MINHASH_BAND_SIZE = 4 + + +def _normalize_string_exact(name: str) -> str: + """Lowercase text and collapse whitespace so equal names map to the same key.""" + normalized = re.sub(r'[\s]+', ' ', name.lower()) + return normalized.strip() + + +def _normalize_name_for_fuzzy(name: str) -> str: + """Produce a fuzzier form that keeps alphanumerics and apostrophes for n-gram shingles.""" + normalized = re.sub(r"[^a-z0-9' ]", ' ', _normalize_string_exact(name)) + normalized = normalized.strip() + return re.sub(r'[\s]+', ' ', normalized) + + +def _name_entropy(normalized_name: str) -> float: + """Approximate text specificity using Shannon entropy over characters. + + We strip spaces, count how often each character appears, and sum + probability * -log2(probability). Short or repetitive names yield low + entropy, which signals we should defer resolution to the LLM instead of + trusting fuzzy similarity. 
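+
+    Example: 'bob' -> counts {b: 2, o: 1} -> entropy of about 0.92, below the 1.5
+    threshold, so resolution is deferred to the LLM; 'charlie parker' scores roughly
+    3.0 and stays on the deterministic path.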
+ """ + if not normalized_name: + return 0.0 + + counts: dict[str, int] = {} + for char in normalized_name.replace(' ', ''): + counts[char] = counts.get(char, 0) + 1 + + total = sum(counts.values()) + if total == 0: + return 0.0 + + entropy = 0.0 + for count in counts.values(): + probability = count / total + entropy -= probability * math.log2(probability) + + return entropy + + +def _has_high_entropy(normalized_name: str) -> bool: + """Filter out very short or low-entropy names that are unreliable for fuzzy matching.""" + token_count = len(normalized_name.split()) + if len(normalized_name) < _MIN_NAME_LENGTH and token_count < _MIN_TOKEN_COUNT: + return False + + return _name_entropy(normalized_name) >= _NAME_ENTROPY_THRESHOLD + + +def _shingles(normalized_name: str) -> set[str]: + """Create 3-gram shingles from the normalized name for MinHash calculations.""" + cleaned = normalized_name.replace(' ', '') + if len(cleaned) < 2: + return {cleaned} if cleaned else set() + + return {cleaned[i : i + 3] for i in range(len(cleaned) - 2)} + + +def _hash_shingle(shingle: str, seed: int) -> int: + """Generate a deterministic 64-bit hash for a shingle given the permutation seed.""" + digest = blake2b(f'{seed}:{shingle}'.encode(), digest_size=8) + return int.from_bytes(digest.digest(), 'big') + + +def _minhash_signature(shingles: Iterable[str]) -> tuple[int, ...]: + """Compute the MinHash signature for the shingle set across predefined permutations.""" + if not shingles: + return tuple() + + seeds = range(_MINHASH_PERMUTATIONS) + signature: list[int] = [] + for seed in seeds: + min_hash = min(_hash_shingle(shingle, seed) for shingle in shingles) + signature.append(min_hash) + + return tuple(signature) + + +def _lsh_bands(signature: Iterable[int]) -> list[tuple[int, ...]]: + """Split the MinHash signature into fixed-size bands for locality-sensitive hashing.""" + signature_list = list(signature) + if not signature_list: + return [] + + bands: list[tuple[int, ...]] = [] + for start in range(0, len(signature_list), _MINHASH_BAND_SIZE): + band = tuple(signature_list[start : start + _MINHASH_BAND_SIZE]) + if len(band) == _MINHASH_BAND_SIZE: + bands.append(band) + return bands + + +def _jaccard_similarity(a: set[str], b: set[str]) -> float: + """Return the Jaccard similarity between two shingle sets, handling empty edge cases.""" + if not a and not b: + return 1.0 + if not a or not b: + return 0.0 + + intersection = len(a.intersection(b)) + union = len(a.union(b)) + return intersection / union if union else 0.0 + + +@lru_cache(maxsize=512) +def _cached_shingles(name: str) -> set[str]: + """Cache shingle sets per normalized name to avoid recomputation within a worker.""" + return _shingles(name) + + +@dataclass +class DedupCandidateIndexes: + """Precomputed lookup structures that drive entity deduplication heuristics.""" + + existing_nodes: list[EntityNode] + nodes_by_uuid: dict[str, EntityNode] + normalized_existing: defaultdict[str, list[EntityNode]] + shingles_by_candidate: dict[str, set[str]] + lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] + + +@dataclass +class DedupResolutionState: + """Mutable resolution bookkeeping shared across deterministic and LLM passes.""" + + resolved_nodes: list[EntityNode | None] + uuid_map: dict[str, str] + unresolved_indices: list[int] + + +def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes: + """Precompute exact and fuzzy lookup structures once per dedupe run.""" + normalized_existing: defaultdict[str, 
list[EntityNode]] = defaultdict(list) + nodes_by_uuid: dict[str, EntityNode] = {} + shingles_by_candidate: dict[str, set[str]] = {} + lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] = defaultdict(list) + + for candidate in existing_nodes: + normalized = _normalize_string_exact(candidate.name) + normalized_existing[normalized].append(candidate) + nodes_by_uuid[candidate.uuid] = candidate + + shingles = _cached_shingles(_normalize_name_for_fuzzy(candidate.name)) + shingles_by_candidate[candidate.uuid] = shingles + + signature = _minhash_signature(shingles) + for band_index, band in enumerate(_lsh_bands(signature)): + lsh_buckets[(band_index, band)].append(candidate.uuid) + + return DedupCandidateIndexes( + existing_nodes=existing_nodes, + nodes_by_uuid=nodes_by_uuid, + normalized_existing=normalized_existing, + shingles_by_candidate=shingles_by_candidate, + lsh_buckets=lsh_buckets, + ) + + +def _resolve_with_similarity( + extracted_nodes: list[EntityNode], + indexes: DedupCandidateIndexes, + state: DedupResolutionState, +) -> None: + """Attempt deterministic resolution using exact name hits and fuzzy MinHash comparisons.""" + for idx, node in enumerate(extracted_nodes): + normalized_exact = _normalize_string_exact(node.name) + normalized_fuzzy = _normalize_name_for_fuzzy(node.name) + + if not _has_high_entropy(normalized_fuzzy): + state.unresolved_indices.append(idx) + continue + + existing_matches = indexes.normalized_existing.get(normalized_exact, []) + if len(existing_matches) == 1: + match = existing_matches[0] + state.resolved_nodes[idx] = match + state.uuid_map[node.uuid] = match.uuid + continue + if len(existing_matches) > 1: + state.unresolved_indices.append(idx) + continue + + shingles = _cached_shingles(normalized_fuzzy) + signature = _minhash_signature(shingles) + candidate_ids: set[str] = set() + for band_index, band in enumerate(_lsh_bands(signature)): + candidate_ids.update(indexes.lsh_buckets.get((band_index, band), [])) + + best_candidate: EntityNode | None = None + best_score = 0.0 + for candidate_id in candidate_ids: + candidate_shingles = indexes.shingles_by_candidate.get(candidate_id, set()) + score = _jaccard_similarity(shingles, candidate_shingles) + if score > best_score: + best_score = score + best_candidate = indexes.nodes_by_uuid.get(candidate_id) + + if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD: + state.resolved_nodes[idx] = best_candidate + state.uuid_map[node.uuid] = best_candidate.uuid + continue + + state.unresolved_indices.append(idx) + + +__all__ = [ + 'DedupCandidateIndexes', + 'DedupResolutionState', + '_normalize_string_exact', + '_normalize_name_for_fuzzy', + '_has_high_entropy', + '_minhash_signature', + '_lsh_bands', + '_jaccard_similarity', + '_cached_shingles', + '_FUZZY_JACCARD_THRESHOLD', + '_build_candidate_indexes', + '_resolve_with_similarity', +] diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 259c1db3..4069a0bd 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -41,6 +41,7 @@ from graphiti_core.search.search_config import SearchResults from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF from graphiti_core.search.search_filters import SearchFilters from graphiti_core.utils.datetime_utils import ensure_utc, utc_now +from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact logger = 
logging.getLogger(__name__) @@ -397,6 +398,19 @@ async def resolve_extracted_edge( if len(related_edges) == 0 and len(existing_edges) == 0: return extracted_edge, [], [] + # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge. + normalized_fact = _normalize_string_exact(extracted_edge.fact) + for edge in related_edges: + if ( + edge.source_node_uuid == extracted_edge.source_node_uuid + and edge.target_node_uuid == extracted_edge.target_node_uuid + and _normalize_string_exact(edge.fact) == normalized_fact + ): + resolved = edge + if episode is not None and episode.uuid not in resolved.episodes: + resolved.episodes.append(episode.uuid) + return resolved, [], [] + start = time() # Prepare context for LLM diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index e58d8750..693609d8 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -24,7 +24,12 @@ from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client.config import ModelSize -from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings +from graphiti_core.nodes import ( + EntityNode, + EpisodeType, + EpisodicNode, + create_entity_node_embeddings, +) from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions from graphiti_core.prompts.extract_nodes import ( @@ -38,7 +43,15 @@ from graphiti_core.search.search_config import SearchResults from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF from graphiti_core.search.search_filters import SearchFilters from graphiti_core.utils.datetime_utils import utc_now -from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges +from graphiti_core.utils.maintenance.dedup_helpers import ( + DedupCandidateIndexes, + DedupResolutionState, + _build_candidate_indexes, + _resolve_with_similarity, +) +from graphiti_core.utils.maintenance.edge_operations import ( + filter_existing_duplicate_of_edges, +) logger = logging.getLogger(__name__) @@ -119,11 +132,13 @@ async def extract_nodes( ) elif episode.source == EpisodeType.text: llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.extract_text(context), response_model=ExtractedEntities + prompt_library.extract_nodes.extract_text(context), + response_model=ExtractedEntities, ) elif episode.source == EpisodeType.json: llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities + prompt_library.extract_nodes.extract_json(context), + response_model=ExtractedEntities, ) response_object = ExtractedEntities(**llm_response) @@ -181,17 +196,12 @@ async def extract_nodes( return extracted_nodes -async def resolve_extracted_nodes( +async def _collect_candidate_nodes( clients: GraphitiClients, extracted_nodes: list[EntityNode], - episode: EpisodicNode | None = None, - previous_episodes: list[EpisodicNode] | None = None, - entity_types: dict[str, type[BaseModel]] | None = None, - existing_nodes_override: list[EntityNode] | None = None, -) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]: - llm_client = clients.llm_client - driver = clients.driver - + 
existing_nodes_override: list[EntityNode] | None, +) -> list[EntityNode]: + """Search per extracted name and return unique candidates with overrides honored in order.""" search_results: list[SearchResults] = await semaphore_gather( *[ search( @@ -205,33 +215,40 @@ async def resolve_extracted_nodes( ] ) - candidate_nodes: list[EntityNode] = ( - [node for result in search_results for node in result.nodes] - if existing_nodes_override is None - else existing_nodes_override - ) + candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes] - existing_nodes_dict: dict[str, EntityNode] = {node.uuid: node for node in candidate_nodes} + if existing_nodes_override is not None: + candidate_nodes.extend(existing_nodes_override) - existing_nodes: list[EntityNode] = list(existing_nodes_dict.values()) + seen_candidate_uuids: set[str] = set() + ordered_candidates: list[EntityNode] = [] + for candidate in candidate_nodes: + if candidate.uuid in seen_candidate_uuids: + continue + seen_candidate_uuids.add(candidate.uuid) + ordered_candidates.append(candidate) - existing_nodes_context = ( - [ - { - **{ - 'idx': i, - 'name': candidate.name, - 'entity_types': candidate.labels, - }, - **candidate.attributes, - } - for i, candidate in enumerate(existing_nodes) - ], - ) + return ordered_candidates + + +async def _resolve_with_llm( + llm_client: LLMClient, + extracted_nodes: list[EntityNode], + indexes: DedupCandidateIndexes, + state: DedupResolutionState, + ensure_ascii: bool, + episode: EpisodicNode | None, + previous_episodes: list[EpisodicNode] | None, + entity_types: dict[str, type[BaseModel]] | None, +) -> None: + """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.""" + if not state.unresolved_indices: + return entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {} - # Prepare context for LLM + llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices] + extracted_nodes_context = [ { 'id': i, @@ -242,17 +259,29 @@ async def resolve_extracted_nodes( ).__doc__ or 'Default Entity Type', } - for i, node in enumerate(extracted_nodes) + for i, node in enumerate(llm_extracted_nodes) + ] + + existing_nodes_context = [ + { + **{ + 'idx': i, + 'name': candidate.name, + 'entity_types': candidate.labels, + }, + **candidate.attributes, + } + for i, candidate in enumerate(indexes.existing_nodes) ] context = { 'extracted_nodes': extracted_nodes_context, 'existing_nodes': existing_nodes_context, 'episode_content': episode.content if episode is not None else '', - 'previous_episodes': [ep.content for ep in previous_episodes] - if previous_episodes is not None - else [], - 'ensure_ascii': clients.ensure_ascii, + 'previous_episodes': ( + [ep.content for ep in previous_episodes] if previous_episodes is not None else [] + ), + 'ensure_ascii': ensure_ascii, } llm_response = await llm_client.generate_response( @@ -262,33 +291,81 @@ async def resolve_extracted_nodes( node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions - resolved_nodes: list[EntityNode] = [] - uuid_map: dict[str, str] = {} - node_duplicates: list[tuple[EntityNode, EntityNode]] = [] for resolution in node_resolutions: - resolution_id: int = resolution.id + relative_id: int = resolution.id duplicate_idx: int = resolution.duplicate_idx - extracted_node = extracted_nodes[resolution_id] + original_index = state.unresolved_indices[relative_id] + extracted_node = extracted_nodes[original_index] 
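+        # relative_id indexes the unresolved subset sent to the LLM, so it is mapped
+        # back to the original extracted_nodes position above; duplicate_idx refers to
+        # indexes.existing_nodes, and any out-of-range value (e.g. -1) means the LLM
+        # found no duplicate.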
resolved_node = ( - existing_nodes[duplicate_idx] - if 0 <= duplicate_idx < len(existing_nodes) + indexes.existing_nodes[duplicate_idx] + if 0 <= duplicate_idx < len(indexes.existing_nodes) else extracted_node ) - # resolved_node.name = resolution.get('name') + state.resolved_nodes[original_index] = resolved_node + state.uuid_map[extracted_node.uuid] = resolved_node.uuid - resolved_nodes.append(resolved_node) - uuid_map[extracted_node.uuid] = resolved_node.uuid - logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}') +async def resolve_extracted_nodes( + clients: GraphitiClients, + extracted_nodes: list[EntityNode], + episode: EpisodicNode | None = None, + previous_episodes: list[EpisodicNode] | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, + existing_nodes_override: list[EntityNode] | None = None, +) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]: + """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt.""" + llm_client = clients.llm_client + driver = clients.driver + existing_nodes = await _collect_candidate_nodes( + clients, + extracted_nodes, + existing_nodes_override, + ) + + indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes) + + state = DedupResolutionState( + resolved_nodes=[None] * len(extracted_nodes), + uuid_map={}, + unresolved_indices=[], + ) + node_duplicates: list[tuple[EntityNode, EntityNode]] = [] + + _resolve_with_similarity(extracted_nodes, indexes, state) + + await _resolve_with_llm( + llm_client, + extracted_nodes, + indexes, + state, + clients.ensure_ascii, + episode, + previous_episodes, + entity_types, + ) + + for idx, node in enumerate(extracted_nodes): + if state.resolved_nodes[idx] is None: + state.resolved_nodes[idx] = node + state.uuid_map[node.uuid] = node.uuid + + logger.debug( + 'Resolved nodes: %s', + [(node.name, node.uuid) for node in state.resolved_nodes if node is not None], + ) new_node_duplicates: list[ tuple[EntityNode, EntityNode] ] = await filter_existing_duplicate_of_edges(driver, node_duplicates) - return resolved_nodes, uuid_map, new_node_duplicates + return ( + [node for node in state.resolved_nodes if node is not None], + state.uuid_map, + new_node_duplicates, + ) async def extract_attributes_from_nodes( @@ -307,9 +384,11 @@ async def extract_attributes_from_nodes( node, episode, previous_episodes, - entity_types.get(next((item for item in node.labels if item != 'Entity'), '')) - if entity_types is not None - else None, + ( + entity_types.get(next((item for item in node.labels if item != 'Entity'), '')) + if entity_types is not None + else None + ), clients.ensure_ascii, ) for node in nodes @@ -339,18 +418,18 @@ async def extract_attributes_from_node( attributes_context: dict[str, Any] = { 'node': node_context, 'episode_content': episode.content if episode is not None else '', - 'previous_episodes': [ep.content for ep in previous_episodes] - if previous_episodes is not None - else [], + 'previous_episodes': ( + [ep.content for ep in previous_episodes] if previous_episodes is not None else [] + ), 'ensure_ascii': ensure_ascii, } summary_context: dict[str, Any] = { 'node': node_context, 'episode_content': episode.content if episode is not None else '', - 'previous_episodes': [ep.content for ep in previous_episodes] - if previous_episodes is not None - else [], + 'previous_episodes': ( + [ep.content for ep in previous_episodes] if previous_episodes is not None else [] + ), 'ensure_ascii': 
ensure_ascii, } diff --git a/tests/test_edge_int.py b/tests/test_edge_int.py index 15555d72..b028e452 100644 --- a/tests/test_edge_int.py +++ b/tests/test_edge_int.py @@ -26,6 +26,7 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, Episodic from tests.helpers_test import get_edge_count, get_node_count, group_id pytest_plugins = ('pytest_asyncio',) +pytestmark = pytest.mark.integration def setup_logging(): diff --git a/tests/test_node_int.py b/tests/test_node_int.py index 7e73b856..bdb8add8 100644 --- a/tests/test_node_int.py +++ b/tests/test_node_int.py @@ -33,6 +33,8 @@ from tests.helpers_test import ( group_id, ) +pytestmark = pytest.mark.integration + created_at = datetime.now() deleted_at = created_at + timedelta(days=3) valid_at = created_at + timedelta(days=1) diff --git a/tests/utils/maintenance/test_edge_operations.py b/tests/utils/maintenance/test_edge_operations.py index cdb1de9f..3d5e0433 100644 --- a/tests/utils/maintenance/test_edge_operations.py +++ b/tests/utils/maintenance/test_edge_operations.py @@ -5,6 +5,7 @@ import pytest from graphiti_core.edges import EntityEdge from graphiti_core.nodes import EpisodicNode +from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge @pytest.fixture @@ -92,3 +93,52 @@ def mock_previous_episodes(): # Run the tests if __name__ == '__main__': pytest.main([__file__]) + + +@pytest.mark.asyncio +async def test_resolve_extracted_edge_exact_fact_short_circuit( + mock_llm_client, + mock_existing_edges, + mock_current_episode, +): + extracted = EntityEdge( + source_node_uuid='source_uuid', + target_node_uuid='target_uuid', + name='test_edge', + group_id='group_1', + fact='Related fact', + episodes=['episode_1'], + created_at=datetime.now(timezone.utc), + valid_at=None, + invalid_at=None, + ) + + related_edges = [ + EntityEdge( + source_node_uuid='source_uuid', + target_node_uuid='target_uuid', + name='related_edge', + group_id='group_1', + fact=' related FACT ', + episodes=['episode_2'], + created_at=datetime.now(timezone.utc) - timedelta(days=1), + valid_at=None, + invalid_at=None, + ) + ] + + resolved_edge, duplicate_edges, invalidated = await resolve_extracted_edge( + mock_llm_client, + extracted, + related_edges, + mock_existing_edges, + mock_current_episode, + edge_types=None, + ensure_ascii=True, + ) + + assert resolved_edge is related_edges[0] + assert resolved_edge.episodes.count(mock_current_episode.uuid) == 1 + assert duplicate_edges == [] + assert invalidated == [] + mock_llm_client.generate_response.assert_not_called() diff --git a/tests/utils/maintenance/test_node_operations.py b/tests/utils/maintenance/test_node_operations.py new file mode 100644 index 00000000..a7250559 --- /dev/null +++ b/tests/utils/maintenance/test_node_operations.py @@ -0,0 +1,341 @@ +from collections import defaultdict +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from graphiti_core.graphiti_types import GraphitiClients +from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode +from graphiti_core.search.search_config import SearchResults +from graphiti_core.utils.datetime_utils import utc_now +from graphiti_core.utils.maintenance.dedup_helpers import ( + DedupCandidateIndexes, + DedupResolutionState, + _build_candidate_indexes, + _cached_shingles, + _has_high_entropy, + _hash_shingle, + _jaccard_similarity, + _lsh_bands, + _minhash_signature, + _name_entropy, + _normalize_name_for_fuzzy, + _normalize_string_exact, + _resolve_with_similarity, + _shingles, +) +from 
graphiti_core.utils.maintenance.node_operations import ( + _collect_candidate_nodes, + _resolve_with_llm, + resolve_extracted_nodes, +) + + +def _make_clients(): + driver = MagicMock() + embedder = MagicMock() + cross_encoder = MagicMock() + llm_client = MagicMock() + llm_generate = AsyncMock() + llm_client.generate_response = llm_generate + + clients = GraphitiClients.model_construct( # bypass validation to allow test doubles + driver=driver, + embedder=embedder, + cross_encoder=cross_encoder, + llm_client=llm_client, + ensure_ascii=False, + ) + + return clients, llm_generate + + +def _make_episode(group_id: str = 'group'): + return EpisodicNode( + name='episode', + group_id=group_id, + source=EpisodeType.message, + source_description='test', + content='content', + valid_at=utc_now(), + ) + + +@pytest.mark.asyncio +async def test_resolve_nodes_exact_match_skips_llm(monkeypatch): + clients, llm_generate = _make_clients() + + candidate = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity']) + extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity']) + + async def fake_search(*_, **__): + return SearchResults(nodes=[candidate]) + + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.search', + fake_search, + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges', + AsyncMock(return_value=[]), + ) + + resolved, uuid_map, _ = await resolve_extracted_nodes( + clients, + [extracted], + episode=_make_episode(), + previous_episodes=[], + ) + + assert resolved[0].uuid == candidate.uuid + assert uuid_map[extracted.uuid] == candidate.uuid + llm_generate.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_resolve_nodes_low_entropy_uses_llm(monkeypatch): + clients, llm_generate = _make_clients() + llm_generate.return_value = { + 'entity_resolutions': [ + { + 'id': 0, + 'duplicate_idx': -1, + 'name': 'Joe', + 'duplicates': [], + } + ] + } + + extracted = EntityNode(name='Joe', group_id='group', labels=['Entity']) + + async def fake_search(*_, **__): + return SearchResults(nodes=[]) + + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.search', + fake_search, + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges', + AsyncMock(return_value=[]), + ) + + resolved, uuid_map, _ = await resolve_extracted_nodes( + clients, + [extracted], + episode=_make_episode(), + previous_episodes=[], + ) + + assert resolved[0].uuid == extracted.uuid + assert uuid_map[extracted.uuid] == extracted.uuid + llm_generate.assert_awaited() + + +@pytest.mark.asyncio +async def test_resolve_nodes_fuzzy_match(monkeypatch): + clients, llm_generate = _make_clients() + + candidate = EntityNode(name='Joe-Michaels', group_id='group', labels=['Entity']) + extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity']) + + async def fake_search(*_, **__): + return SearchResults(nodes=[candidate]) + + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.search', + fake_search, + ) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges', + AsyncMock(return_value=[]), + ) + + resolved, uuid_map, _ = await resolve_extracted_nodes( + clients, + [extracted], + episode=_make_episode(), + previous_episodes=[], + ) + + assert resolved[0].uuid == candidate.uuid + assert uuid_map[extracted.uuid] == candidate.uuid + llm_generate.assert_not_awaited() + + 
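+# Editorial sketch, not in the original patch: shows why the fuzzy path above matches
+# 'Joe Michaels' to 'Joe-Michaels'. Both normalize to the same shingle set, so the
+# Jaccard score is 1.0, comfortably above the 0.9 acceptance threshold.
+def test_fuzzy_shingle_overlap_sketch():
+    left = _cached_shingles(_normalize_name_for_fuzzy('Joe Michaels'))
+    right = _cached_shingles(_normalize_name_for_fuzzy('Joe-Michaels'))
+    assert _has_high_entropy(_normalize_name_for_fuzzy('Joe Michaels')) is True
+    assert _jaccard_similarity(left, right) == 1.0
+
+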
+@pytest.mark.asyncio +async def test_collect_candidate_nodes_dedupes_and_merges_override(monkeypatch): + clients, _ = _make_clients() + + candidate = EntityNode(name='Alice', group_id='group', labels=['Entity']) + override_duplicate = EntityNode( + uuid=candidate.uuid, + name='Alice Alt', + group_id='group', + labels=['Entity'], + ) + extracted = EntityNode(name='Alice', group_id='group', labels=['Entity']) + + search_mock = AsyncMock(return_value=SearchResults(nodes=[candidate])) + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.search', + search_mock, + ) + + result = await _collect_candidate_nodes( + clients, + [extracted], + existing_nodes_override=[override_duplicate], + ) + + assert len(result) == 1 + assert result[0].uuid == candidate.uuid + search_mock.assert_awaited() + + +def test_build_candidate_indexes_populates_structures(): + candidate = EntityNode(name='Bob Dylan', group_id='group', labels=['Entity']) + + indexes = _build_candidate_indexes([candidate]) + + normalized_key = candidate.name.lower() + assert indexes.normalized_existing[normalized_key][0].uuid == candidate.uuid + assert indexes.nodes_by_uuid[candidate.uuid] is candidate + assert candidate.uuid in indexes.shingles_by_candidate + assert any(candidate.uuid in bucket for bucket in indexes.lsh_buckets.values()) + + +def test_normalize_helpers(): + assert _normalize_string_exact(' Alice Smith ') == 'alice smith' + assert _normalize_name_for_fuzzy('Alice-Smith!') == 'alice smith' + + +def test_name_entropy_variants(): + assert _name_entropy('alice') > _name_entropy('aaaaa') + assert _name_entropy('') == 0.0 + + +def test_has_high_entropy_rules(): + assert _has_high_entropy('meaningful name') is True + assert _has_high_entropy('aa') is False + + +def test_shingles_and_cache(): + raw = 'alice' + shingle_set = _shingles(raw) + assert shingle_set == {'ali', 'lic', 'ice'} + assert _cached_shingles(raw) == shingle_set + assert _cached_shingles(raw) is _cached_shingles(raw) + + +def test_hash_minhash_and_lsh(): + shingles = {'abc', 'bcd', 'cde'} + signature = _minhash_signature(shingles) + assert len(signature) == 32 + bands = _lsh_bands(signature) + assert all(len(band) == 4 for band in bands) + hashed = {_hash_shingle(s, 0) for s in shingles} + assert len(hashed) == len(shingles) + + +def test_jaccard_similarity_edges(): + a = {'a', 'b'} + b = {'a', 'c'} + assert _jaccard_similarity(a, b) == pytest.approx(1 / 3) + assert _jaccard_similarity(set(), set()) == 1.0 + assert _jaccard_similarity(a, set()) == 0.0 + + +def test_resolve_with_similarity_exact_match_updates_state(): + candidate = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity']) + extracted = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity']) + + indexes = _build_candidate_indexes([candidate]) + state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[]) + + _resolve_with_similarity([extracted], indexes, state) + + assert state.resolved_nodes[0].uuid == candidate.uuid + assert state.uuid_map[extracted.uuid] == candidate.uuid + assert state.unresolved_indices == [] + + +def test_resolve_with_similarity_low_entropy_defers_resolution(): + extracted = EntityNode(name='Bob', group_id='group', labels=['Entity']) + indexes = DedupCandidateIndexes( + existing_nodes=[], + nodes_by_uuid={}, + normalized_existing=defaultdict(list), + shingles_by_candidate={}, + lsh_buckets=defaultdict(list), + ) + state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[]) + + 
_resolve_with_similarity([extracted], indexes, state) + + assert state.resolved_nodes[0] is None + assert state.unresolved_indices == [0] + + +def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm(): + candidate1 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity']) + candidate2 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity']) + extracted = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity']) + + indexes = _build_candidate_indexes([candidate1, candidate2]) + state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[]) + + _resolve_with_similarity([extracted], indexes, state) + + assert state.resolved_nodes[0] is None + assert state.unresolved_indices == [0] + + +@pytest.mark.asyncio +async def test_resolve_with_llm_updates_unresolved(monkeypatch): + extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity']) + candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity']) + + indexes = _build_candidate_indexes([candidate]) + state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0]) + + captured_context = {} + + def fake_prompt_nodes(context): + captured_context.update(context) + return ['prompt'] + + monkeypatch.setattr( + 'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes', + fake_prompt_nodes, + ) + + async def fake_generate_response(*_, **__): + return { + 'entity_resolutions': [ + { + 'id': 0, + 'duplicate_idx': 0, + 'name': 'Dizzy Gillespie', + 'duplicates': [0], + } + ] + } + + llm_client = MagicMock() + llm_client.generate_response = AsyncMock(side_effect=fake_generate_response) + + await _resolve_with_llm( + llm_client, + [extracted], + indexes, + state, + ensure_ascii=False, + episode=_make_episode(), + previous_episodes=[], + entity_types=None, + ) + + assert state.resolved_nodes[0].uuid == candidate.uuid + assert state.uuid_map[extracted.uuid] == candidate.uuid + assert captured_context['existing_nodes'][0]['idx'] == 0 + assert isinstance(captured_context['existing_nodes'], list)
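
For reference, a minimal sketch of how the new dedupe helpers compose outside the test suite (not part of the diff above; the toy names and `group_id` are illustrative): `_build_candidate_indexes` precomputes the exact-name and MinHash/LSH lookups once, `_resolve_with_similarity` fills `resolved_nodes` and `uuid_map` for deterministic hits, and whatever remains in `unresolved_indices` is what `_resolve_with_llm` escalates to the dedupe prompt.

from graphiti_core.nodes import EntityNode
from graphiti_core.utils.maintenance.dedup_helpers import (
    DedupResolutionState,
    _build_candidate_indexes,
    _resolve_with_similarity,
)

# One previously persisted candidate and two freshly extracted nodes.
existing = [EntityNode(name='Charlie Parker', group_id='demo', labels=['Entity'])]
extracted = [
    EntityNode(name='Charlie-Parker', group_id='demo', labels=['Entity']),  # fuzzy hit
    EntityNode(name='Bob', group_id='demo', labels=['Entity']),  # short, low-entropy name
]

indexes = _build_candidate_indexes(existing)
state = DedupResolutionState(
    resolved_nodes=[None] * len(extracted),
    uuid_map={},
    unresolved_indices=[],
)
_resolve_with_similarity(extracted, indexes, state)

assert state.resolved_nodes[0] is existing[0]  # resolved deterministically, no LLM call
assert state.unresolved_indices == [1]  # 'Bob' is deferred to the LLM dedupe prompt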