Refactor deduplication logic to enhance node resolution and track duplicate pairs (#929)
* Simplify deduplication process in bulk_utils by reusing canonical nodes. * Update dedup_helpers to store duplicate pairs during resolution. * Modify node_operations to append duplicate pairs when resolving nodes. * Add tests to verify deduplication behavior and ensure correct state updates.
This commit is contained in:
parent
1e56019027
commit
18550c84ac
5 changed files with 155 additions and 75 deletions
|
|
@ -43,7 +43,7 @@ from graphiti_core.models.nodes.node_db_queries import (
|
|||
get_entity_node_save_bulk_query,
|
||||
get_episode_node_save_bulk_query,
|
||||
)
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
||||
from graphiti_core.utils.maintenance.edge_operations import (
|
||||
extract_edges,
|
||||
|
|
@ -266,83 +266,66 @@ async def dedupe_nodes_bulk(
|
|||
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
||||
entity_types: dict[str, type[BaseModel]] | None = None,
|
||||
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
|
||||
embedder = clients.embedder
|
||||
min_score = 0.8
|
||||
|
||||
# generate embeddings
|
||||
await semaphore_gather(
|
||||
*[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
|
||||
)
|
||||
|
||||
# Find similar results
|
||||
dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
|
||||
for i, nodes_i in enumerate(extracted_nodes):
|
||||
existing_nodes: list[EntityNode] = []
|
||||
for j, nodes_j in enumerate(extracted_nodes):
|
||||
if i == j:
|
||||
continue
|
||||
existing_nodes += nodes_j
|
||||
|
||||
candidates_i: list[EntityNode] = []
|
||||
for node in nodes_i:
|
||||
for existing_node in existing_nodes:
|
||||
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
|
||||
# This approach will cast a wider net than BM25, which is ideal for this use case
|
||||
node_words = set(node.name.lower().split())
|
||||
existing_node_words = set(existing_node.name.lower().split())
|
||||
has_overlap = not node_words.isdisjoint(existing_node_words)
|
||||
if has_overlap:
|
||||
candidates_i.append(existing_node)
|
||||
continue
|
||||
|
||||
# Check for semantic similarity even if there is no overlap
|
||||
similarity = np.dot(
|
||||
normalize_l2(node.name_embedding or []),
|
||||
normalize_l2(existing_node.name_embedding or []),
|
||||
)
|
||||
if similarity >= min_score:
|
||||
candidates_i.append(existing_node)
|
||||
|
||||
dedupe_tuples.append((nodes_i, candidates_i))
|
||||
|
||||
# Determine Node Resolutions
|
||||
bulk_node_resolutions: list[
|
||||
tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
|
||||
] = await semaphore_gather(
|
||||
*[
|
||||
resolve_extracted_nodes(
|
||||
clients,
|
||||
dedupe_tuple[0],
|
||||
episode_tuples[i][0],
|
||||
episode_tuples[i][1],
|
||||
entity_types,
|
||||
existing_nodes_override=dedupe_tuples[i][1],
|
||||
)
|
||||
for i, dedupe_tuple in enumerate(dedupe_tuples)
|
||||
]
|
||||
)
|
||||
|
||||
# Collect all duplicate pairs sorted by uuid
|
||||
canonical_nodes: dict[str, EntityNode] = {}
|
||||
episode_resolutions: list[tuple[str, list[EntityNode]]] = []
|
||||
per_episode_uuid_maps: list[dict[str, str]] = []
|
||||
duplicate_pairs: list[tuple[str, str]] = []
|
||||
for _, _, duplicates in bulk_node_resolutions:
|
||||
for duplicate in duplicates:
|
||||
n, m = duplicate
|
||||
duplicate_pairs.append((n.uuid, m.uuid))
|
||||
|
||||
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
|
||||
compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
|
||||
for nodes, (episode, previous_episodes) in zip(extracted_nodes, episode_tuples, strict=True):
|
||||
existing_override = list(canonical_nodes.values()) if canonical_nodes else None
|
||||
|
||||
node_uuid_map: dict[str, EntityNode] = {
|
||||
node.uuid: node for nodes in extracted_nodes for node in nodes
|
||||
}
|
||||
resolved_nodes, uuid_map, duplicates = await resolve_extracted_nodes(
|
||||
clients,
|
||||
nodes,
|
||||
episode,
|
||||
previous_episodes,
|
||||
entity_types,
|
||||
existing_nodes_override=existing_override,
|
||||
)
|
||||
|
||||
per_episode_uuid_maps.append(uuid_map)
|
||||
episode_resolutions.append((episode.uuid, resolved_nodes))
|
||||
|
||||
for node in resolved_nodes:
|
||||
canonical_nodes[node.uuid] = node
|
||||
|
||||
duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
|
||||
|
||||
uuid_chains: dict[str, str] = {}
|
||||
for uuid_map in per_episode_uuid_maps:
|
||||
uuid_chains.update(uuid_map)
|
||||
for duplicate_uuid, canonical_uuid in duplicate_pairs:
|
||||
uuid_chains[duplicate_uuid] = canonical_uuid
|
||||
|
||||
def _resolve_uuid(uuid: str) -> str:
|
||||
seen: set[str] = set()
|
||||
current = uuid
|
||||
while True:
|
||||
target = uuid_chains.get(current)
|
||||
if target is None or target == current:
|
||||
return current if target is None else target
|
||||
if current in seen:
|
||||
return current
|
||||
seen.add(current)
|
||||
current = target
|
||||
|
||||
compressed_map: dict[str, str] = {uuid: _resolve_uuid(uuid) for uuid in uuid_chains}
|
||||
for canonical_uuid in canonical_nodes:
|
||||
compressed_map.setdefault(canonical_uuid, canonical_uuid)
|
||||
|
||||
nodes_by_episode: dict[str, list[EntityNode]] = {}
|
||||
for i, nodes in enumerate(extracted_nodes):
|
||||
episode = episode_tuples[i][0]
|
||||
for episode_uuid, resolved_nodes in episode_resolutions:
|
||||
deduped_nodes: list[EntityNode] = []
|
||||
seen: set[str] = set()
|
||||
for node in resolved_nodes:
|
||||
canonical_uuid = compressed_map.get(node.uuid, node.uuid)
|
||||
if canonical_uuid in seen:
|
||||
continue
|
||||
seen.add(canonical_uuid)
|
||||
canonical_node = canonical_nodes.get(canonical_uuid, node)
|
||||
deduped_nodes.append(canonical_node)
|
||||
|
||||
nodes_by_episode[episode.uuid] = [
|
||||
node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
|
||||
]
|
||||
nodes_by_episode[episode_uuid] = deduped_nodes
|
||||
|
||||
return nodes_by_episode, compressed_map
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ import math
|
|||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from hashlib import blake2b
|
||||
from typing import TYPE_CHECKING
|
||||
|
|
@ -164,6 +164,7 @@ class DedupResolutionState:
|
|||
resolved_nodes: list[EntityNode | None]
|
||||
uuid_map: dict[str, str]
|
||||
unresolved_indices: list[int]
|
||||
duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
|
||||
|
||||
|
||||
def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
|
||||
|
|
@ -213,6 +214,8 @@ def _resolve_with_similarity(
|
|||
match = existing_matches[0]
|
||||
state.resolved_nodes[idx] = match
|
||||
state.uuid_map[node.uuid] = match.uuid
|
||||
if match.uuid != node.uuid:
|
||||
state.duplicate_pairs.append((node, match))
|
||||
continue
|
||||
if len(existing_matches) > 1:
|
||||
state.unresolved_indices.append(idx)
|
||||
|
|
@ -236,6 +239,8 @@ def _resolve_with_similarity(
|
|||
if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
|
||||
state.resolved_nodes[idx] = best_candidate
|
||||
state.uuid_map[node.uuid] = best_candidate.uuid
|
||||
if best_candidate.uuid != node.uuid:
|
||||
state.duplicate_pairs.append((node, best_candidate))
|
||||
continue
|
||||
|
||||
state.unresolved_indices.append(idx)
|
||||
|
|
|
|||
|
|
@ -306,6 +306,8 @@ async def _resolve_with_llm(
|
|||
|
||||
state.resolved_nodes[original_index] = resolved_node
|
||||
state.uuid_map[extracted_node.uuid] = resolved_node.uuid
|
||||
if resolved_node.uuid != extracted_node.uuid:
|
||||
state.duplicate_pairs.append((extracted_node, resolved_node))
|
||||
|
||||
|
||||
async def resolve_extracted_nodes(
|
||||
|
|
@ -332,7 +334,6 @@ async def resolve_extracted_nodes(
|
|||
uuid_map={},
|
||||
unresolved_indices=[],
|
||||
)
|
||||
node_duplicates: list[tuple[EntityNode, EntityNode]] = []
|
||||
|
||||
_resolve_with_similarity(extracted_nodes, indexes, state)
|
||||
|
||||
|
|
@ -359,7 +360,7 @@ async def resolve_extracted_nodes(
|
|||
|
||||
new_node_duplicates: list[
|
||||
tuple[EntityNode, EntityNode]
|
||||
] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
|
||||
] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
|
||||
|
||||
return (
|
||||
[node for node in state.resolved_nodes if node is not None],
|
||||
|
|
|
|||
87
tests/utils/maintenance/test_bulk_utils.py
Normal file
87
tests/utils/maintenance/test_bulk_utils.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
from collections import deque
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.utils import bulk_utils
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
|
||||
|
||||
def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
|
||||
return EpisodicNode(
|
||||
name=f'episode-{uuid_suffix}',
|
||||
group_id=group_id,
|
||||
labels=[],
|
||||
source=EpisodeType.message,
|
||||
content='content',
|
||||
source_description='test',
|
||||
created_at=utc_now(),
|
||||
valid_at=utc_now(),
|
||||
)
|
||||
|
||||
|
||||
def _make_clients() -> GraphitiClients:
|
||||
driver = MagicMock()
|
||||
embedder = MagicMock()
|
||||
cross_encoder = MagicMock()
|
||||
llm_client = MagicMock()
|
||||
|
||||
return GraphitiClients.model_construct( # bypass validation to allow test doubles
|
||||
driver=driver,
|
||||
embedder=embedder,
|
||||
cross_encoder=cross_encoder,
|
||||
llm_client=llm_client,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
|
||||
clients = _make_clients()
|
||||
|
||||
episode_one = _make_episode('1')
|
||||
episode_two = _make_episode('2')
|
||||
|
||||
extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
|
||||
extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
|
||||
|
||||
canonical = extracted_one
|
||||
|
||||
call_queue = deque()
|
||||
|
||||
async def fake_resolve(
|
||||
clients_arg,
|
||||
nodes_arg,
|
||||
episode_arg,
|
||||
previous_episodes_arg,
|
||||
entity_types_arg,
|
||||
existing_nodes_override=None,
|
||||
):
|
||||
call_queue.append(existing_nodes_override)
|
||||
|
||||
if nodes_arg == [extracted_one]:
|
||||
return [canonical], {canonical.uuid: canonical.uuid}, []
|
||||
|
||||
assert nodes_arg == [extracted_two]
|
||||
assert existing_nodes_override is not None
|
||||
assert existing_nodes_override[0] is canonical
|
||||
|
||||
return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]
|
||||
|
||||
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
|
||||
|
||||
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
|
||||
clients,
|
||||
[[extracted_one], [extracted_two]],
|
||||
[(episode_one, []), (episode_two, [])],
|
||||
)
|
||||
|
||||
assert len(call_queue) == 2
|
||||
assert call_queue[0] is None
|
||||
assert list(call_queue[1]) == [canonical]
|
||||
|
||||
assert nodes_by_episode[episode_one.uuid] == [canonical]
|
||||
assert nodes_by_episode[episode_two.uuid] == [canonical]
|
||||
assert compressed_map.get(extracted_two.uuid) == canonical.uuid
|
||||
|
|
@ -257,6 +257,7 @@ def test_resolve_with_similarity_exact_match_updates_state():
|
|||
assert state.resolved_nodes[0].uuid == candidate.uuid
|
||||
assert state.uuid_map[extracted.uuid] == candidate.uuid
|
||||
assert state.unresolved_indices == []
|
||||
assert state.duplicate_pairs == [(extracted, candidate)]
|
||||
|
||||
|
||||
def test_resolve_with_similarity_low_entropy_defers_resolution():
|
||||
|
|
@ -274,6 +275,7 @@ def test_resolve_with_similarity_low_entropy_defers_resolution():
|
|||
|
||||
assert state.resolved_nodes[0] is None
|
||||
assert state.unresolved_indices == [0]
|
||||
assert state.duplicate_pairs == []
|
||||
|
||||
|
||||
def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
|
||||
|
|
@ -288,6 +290,7 @@ def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
|
|||
|
||||
assert state.resolved_nodes[0] is None
|
||||
assert state.unresolved_indices == [0]
|
||||
assert state.duplicate_pairs == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -339,3 +342,4 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
|
|||
assert state.uuid_map[extracted.uuid] == candidate.uuid
|
||||
assert captured_context['existing_nodes'][0]['idx'] == 0
|
||||
assert isinstance(captured_context['existing_nodes'], list)
|
||||
assert state.duplicate_pairs == [(extracted, candidate)]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue