Refactor deduplication logic to enhance node resolution and track duplicate pairs (#929)
* Simplify deduplication process in bulk_utils by reusing canonical nodes. * Update dedup_helpers to store duplicate pairs during resolution. * Modify node_operations to append duplicate pairs when resolving nodes. * Add tests to verify deduplication behavior and ensure correct state updates.
This commit is contained in:
parent
1e56019027
commit
18550c84ac
5 changed files with 155 additions and 75 deletions
|
|
@ -43,7 +43,7 @@ from graphiti_core.models.nodes.node_db_queries import (
|
||||||
get_entity_node_save_bulk_query,
|
get_entity_node_save_bulk_query,
|
||||||
get_episode_node_save_bulk_query,
|
get_episode_node_save_bulk_query,
|
||||||
)
|
)
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
extract_edges,
|
extract_edges,
|
||||||
|
|
@ -266,83 +266,66 @@ async def dedupe_nodes_bulk(
|
||||||
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
||||||
entity_types: dict[str, type[BaseModel]] | None = None,
|
entity_types: dict[str, type[BaseModel]] | None = None,
|
||||||
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
|
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
|
||||||
embedder = clients.embedder
|
canonical_nodes: dict[str, EntityNode] = {}
|
||||||
min_score = 0.8
|
episode_resolutions: list[tuple[str, list[EntityNode]]] = []
|
||||||
|
per_episode_uuid_maps: list[dict[str, str]] = []
|
||||||
# generate embeddings
|
|
||||||
await semaphore_gather(
|
|
||||||
*[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find similar results
|
|
||||||
dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
|
|
||||||
for i, nodes_i in enumerate(extracted_nodes):
|
|
||||||
existing_nodes: list[EntityNode] = []
|
|
||||||
for j, nodes_j in enumerate(extracted_nodes):
|
|
||||||
if i == j:
|
|
||||||
continue
|
|
||||||
existing_nodes += nodes_j
|
|
||||||
|
|
||||||
candidates_i: list[EntityNode] = []
|
|
||||||
for node in nodes_i:
|
|
||||||
for existing_node in existing_nodes:
|
|
||||||
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
|
|
||||||
# This approach will cast a wider net than BM25, which is ideal for this use case
|
|
||||||
node_words = set(node.name.lower().split())
|
|
||||||
existing_node_words = set(existing_node.name.lower().split())
|
|
||||||
has_overlap = not node_words.isdisjoint(existing_node_words)
|
|
||||||
if has_overlap:
|
|
||||||
candidates_i.append(existing_node)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check for semantic similarity even if there is no overlap
|
|
||||||
similarity = np.dot(
|
|
||||||
normalize_l2(node.name_embedding or []),
|
|
||||||
normalize_l2(existing_node.name_embedding or []),
|
|
||||||
)
|
|
||||||
if similarity >= min_score:
|
|
||||||
candidates_i.append(existing_node)
|
|
||||||
|
|
||||||
dedupe_tuples.append((nodes_i, candidates_i))
|
|
||||||
|
|
||||||
# Determine Node Resolutions
|
|
||||||
bulk_node_resolutions: list[
|
|
||||||
tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
|
|
||||||
] = await semaphore_gather(
|
|
||||||
*[
|
|
||||||
resolve_extracted_nodes(
|
|
||||||
clients,
|
|
||||||
dedupe_tuple[0],
|
|
||||||
episode_tuples[i][0],
|
|
||||||
episode_tuples[i][1],
|
|
||||||
entity_types,
|
|
||||||
existing_nodes_override=dedupe_tuples[i][1],
|
|
||||||
)
|
|
||||||
for i, dedupe_tuple in enumerate(dedupe_tuples)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Collect all duplicate pairs sorted by uuid
|
|
||||||
duplicate_pairs: list[tuple[str, str]] = []
|
duplicate_pairs: list[tuple[str, str]] = []
|
||||||
for _, _, duplicates in bulk_node_resolutions:
|
|
||||||
for duplicate in duplicates:
|
|
||||||
n, m = duplicate
|
|
||||||
duplicate_pairs.append((n.uuid, m.uuid))
|
|
||||||
|
|
||||||
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
|
for nodes, (episode, previous_episodes) in zip(extracted_nodes, episode_tuples, strict=True):
|
||||||
compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
|
existing_override = list(canonical_nodes.values()) if canonical_nodes else None
|
||||||
|
|
||||||
node_uuid_map: dict[str, EntityNode] = {
|
resolved_nodes, uuid_map, duplicates = await resolve_extracted_nodes(
|
||||||
node.uuid: node for nodes in extracted_nodes for node in nodes
|
clients,
|
||||||
}
|
nodes,
|
||||||
|
episode,
|
||||||
|
previous_episodes,
|
||||||
|
entity_types,
|
||||||
|
existing_nodes_override=existing_override,
|
||||||
|
)
|
||||||
|
|
||||||
|
per_episode_uuid_maps.append(uuid_map)
|
||||||
|
episode_resolutions.append((episode.uuid, resolved_nodes))
|
||||||
|
|
||||||
|
for node in resolved_nodes:
|
||||||
|
canonical_nodes[node.uuid] = node
|
||||||
|
|
||||||
|
duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
|
||||||
|
|
||||||
|
uuid_chains: dict[str, str] = {}
|
||||||
|
for uuid_map in per_episode_uuid_maps:
|
||||||
|
uuid_chains.update(uuid_map)
|
||||||
|
for duplicate_uuid, canonical_uuid in duplicate_pairs:
|
||||||
|
uuid_chains[duplicate_uuid] = canonical_uuid
|
||||||
|
|
||||||
|
def _resolve_uuid(uuid: str) -> str:
|
||||||
|
seen: set[str] = set()
|
||||||
|
current = uuid
|
||||||
|
while True:
|
||||||
|
target = uuid_chains.get(current)
|
||||||
|
if target is None or target == current:
|
||||||
|
return current if target is None else target
|
||||||
|
if current in seen:
|
||||||
|
return current
|
||||||
|
seen.add(current)
|
||||||
|
current = target
|
||||||
|
|
||||||
|
compressed_map: dict[str, str] = {uuid: _resolve_uuid(uuid) for uuid in uuid_chains}
|
||||||
|
for canonical_uuid in canonical_nodes:
|
||||||
|
compressed_map.setdefault(canonical_uuid, canonical_uuid)
|
||||||
|
|
||||||
nodes_by_episode: dict[str, list[EntityNode]] = {}
|
nodes_by_episode: dict[str, list[EntityNode]] = {}
|
||||||
for i, nodes in enumerate(extracted_nodes):
|
for episode_uuid, resolved_nodes in episode_resolutions:
|
||||||
episode = episode_tuples[i][0]
|
deduped_nodes: list[EntityNode] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for node in resolved_nodes:
|
||||||
|
canonical_uuid = compressed_map.get(node.uuid, node.uuid)
|
||||||
|
if canonical_uuid in seen:
|
||||||
|
continue
|
||||||
|
seen.add(canonical_uuid)
|
||||||
|
canonical_node = canonical_nodes.get(canonical_uuid, node)
|
||||||
|
deduped_nodes.append(canonical_node)
|
||||||
|
|
||||||
nodes_by_episode[episode.uuid] = [
|
nodes_by_episode[episode_uuid] = deduped_nodes
|
||||||
node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
|
|
||||||
]
|
|
||||||
|
|
||||||
return nodes_by_episode, compressed_map
|
return nodes_by_episode, compressed_map
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ import math
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from hashlib import blake2b
|
from hashlib import blake2b
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
@ -164,6 +164,7 @@ class DedupResolutionState:
|
||||||
resolved_nodes: list[EntityNode | None]
|
resolved_nodes: list[EntityNode | None]
|
||||||
uuid_map: dict[str, str]
|
uuid_map: dict[str, str]
|
||||||
unresolved_indices: list[int]
|
unresolved_indices: list[int]
|
||||||
|
duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
|
def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
|
||||||
|
|
@ -213,6 +214,8 @@ def _resolve_with_similarity(
|
||||||
match = existing_matches[0]
|
match = existing_matches[0]
|
||||||
state.resolved_nodes[idx] = match
|
state.resolved_nodes[idx] = match
|
||||||
state.uuid_map[node.uuid] = match.uuid
|
state.uuid_map[node.uuid] = match.uuid
|
||||||
|
if match.uuid != node.uuid:
|
||||||
|
state.duplicate_pairs.append((node, match))
|
||||||
continue
|
continue
|
||||||
if len(existing_matches) > 1:
|
if len(existing_matches) > 1:
|
||||||
state.unresolved_indices.append(idx)
|
state.unresolved_indices.append(idx)
|
||||||
|
|
@ -236,6 +239,8 @@ def _resolve_with_similarity(
|
||||||
if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
|
if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
|
||||||
state.resolved_nodes[idx] = best_candidate
|
state.resolved_nodes[idx] = best_candidate
|
||||||
state.uuid_map[node.uuid] = best_candidate.uuid
|
state.uuid_map[node.uuid] = best_candidate.uuid
|
||||||
|
if best_candidate.uuid != node.uuid:
|
||||||
|
state.duplicate_pairs.append((node, best_candidate))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
state.unresolved_indices.append(idx)
|
state.unresolved_indices.append(idx)
|
||||||
|
|
|
||||||
|
|
@ -306,6 +306,8 @@ async def _resolve_with_llm(
|
||||||
|
|
||||||
state.resolved_nodes[original_index] = resolved_node
|
state.resolved_nodes[original_index] = resolved_node
|
||||||
state.uuid_map[extracted_node.uuid] = resolved_node.uuid
|
state.uuid_map[extracted_node.uuid] = resolved_node.uuid
|
||||||
|
if resolved_node.uuid != extracted_node.uuid:
|
||||||
|
state.duplicate_pairs.append((extracted_node, resolved_node))
|
||||||
|
|
||||||
|
|
||||||
async def resolve_extracted_nodes(
|
async def resolve_extracted_nodes(
|
||||||
|
|
@ -332,7 +334,6 @@ async def resolve_extracted_nodes(
|
||||||
uuid_map={},
|
uuid_map={},
|
||||||
unresolved_indices=[],
|
unresolved_indices=[],
|
||||||
)
|
)
|
||||||
node_duplicates: list[tuple[EntityNode, EntityNode]] = []
|
|
||||||
|
|
||||||
_resolve_with_similarity(extracted_nodes, indexes, state)
|
_resolve_with_similarity(extracted_nodes, indexes, state)
|
||||||
|
|
||||||
|
|
@ -359,7 +360,7 @@ async def resolve_extracted_nodes(
|
||||||
|
|
||||||
new_node_duplicates: list[
|
new_node_duplicates: list[
|
||||||
tuple[EntityNode, EntityNode]
|
tuple[EntityNode, EntityNode]
|
||||||
] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
|
] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
[node for node in state.resolved_nodes if node is not None],
|
[node for node in state.resolved_nodes if node is not None],
|
||||||
|
|
|
||||||
87
tests/utils/maintenance/test_bulk_utils.py
Normal file
87
tests/utils/maintenance/test_bulk_utils.py
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
from collections import deque
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
|
from graphiti_core.utils import bulk_utils
|
||||||
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
|
|
||||||
|
|
||||||
|
def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
|
||||||
|
return EpisodicNode(
|
||||||
|
name=f'episode-{uuid_suffix}',
|
||||||
|
group_id=group_id,
|
||||||
|
labels=[],
|
||||||
|
source=EpisodeType.message,
|
||||||
|
content='content',
|
||||||
|
source_description='test',
|
||||||
|
created_at=utc_now(),
|
||||||
|
valid_at=utc_now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_clients() -> GraphitiClients:
|
||||||
|
driver = MagicMock()
|
||||||
|
embedder = MagicMock()
|
||||||
|
cross_encoder = MagicMock()
|
||||||
|
llm_client = MagicMock()
|
||||||
|
|
||||||
|
return GraphitiClients.model_construct( # bypass validation to allow test doubles
|
||||||
|
driver=driver,
|
||||||
|
embedder=embedder,
|
||||||
|
cross_encoder=cross_encoder,
|
||||||
|
llm_client=llm_client,
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
|
||||||
|
clients = _make_clients()
|
||||||
|
|
||||||
|
episode_one = _make_episode('1')
|
||||||
|
episode_two = _make_episode('2')
|
||||||
|
|
||||||
|
extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
|
||||||
|
extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
|
||||||
|
|
||||||
|
canonical = extracted_one
|
||||||
|
|
||||||
|
call_queue = deque()
|
||||||
|
|
||||||
|
async def fake_resolve(
|
||||||
|
clients_arg,
|
||||||
|
nodes_arg,
|
||||||
|
episode_arg,
|
||||||
|
previous_episodes_arg,
|
||||||
|
entity_types_arg,
|
||||||
|
existing_nodes_override=None,
|
||||||
|
):
|
||||||
|
call_queue.append(existing_nodes_override)
|
||||||
|
|
||||||
|
if nodes_arg == [extracted_one]:
|
||||||
|
return [canonical], {canonical.uuid: canonical.uuid}, []
|
||||||
|
|
||||||
|
assert nodes_arg == [extracted_two]
|
||||||
|
assert existing_nodes_override is not None
|
||||||
|
assert existing_nodes_override[0] is canonical
|
||||||
|
|
||||||
|
return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]
|
||||||
|
|
||||||
|
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
|
||||||
|
|
||||||
|
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
|
||||||
|
clients,
|
||||||
|
[[extracted_one], [extracted_two]],
|
||||||
|
[(episode_one, []), (episode_two, [])],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(call_queue) == 2
|
||||||
|
assert call_queue[0] is None
|
||||||
|
assert list(call_queue[1]) == [canonical]
|
||||||
|
|
||||||
|
assert nodes_by_episode[episode_one.uuid] == [canonical]
|
||||||
|
assert nodes_by_episode[episode_two.uuid] == [canonical]
|
||||||
|
assert compressed_map.get(extracted_two.uuid) == canonical.uuid
|
||||||
|
|
@ -257,6 +257,7 @@ def test_resolve_with_similarity_exact_match_updates_state():
|
||||||
assert state.resolved_nodes[0].uuid == candidate.uuid
|
assert state.resolved_nodes[0].uuid == candidate.uuid
|
||||||
assert state.uuid_map[extracted.uuid] == candidate.uuid
|
assert state.uuid_map[extracted.uuid] == candidate.uuid
|
||||||
assert state.unresolved_indices == []
|
assert state.unresolved_indices == []
|
||||||
|
assert state.duplicate_pairs == [(extracted, candidate)]
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_with_similarity_low_entropy_defers_resolution():
|
def test_resolve_with_similarity_low_entropy_defers_resolution():
|
||||||
|
|
@ -274,6 +275,7 @@ def test_resolve_with_similarity_low_entropy_defers_resolution():
|
||||||
|
|
||||||
assert state.resolved_nodes[0] is None
|
assert state.resolved_nodes[0] is None
|
||||||
assert state.unresolved_indices == [0]
|
assert state.unresolved_indices == [0]
|
||||||
|
assert state.duplicate_pairs == []
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
|
def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
|
||||||
|
|
@ -288,6 +290,7 @@ def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
|
||||||
|
|
||||||
assert state.resolved_nodes[0] is None
|
assert state.resolved_nodes[0] is None
|
||||||
assert state.unresolved_indices == [0]
|
assert state.unresolved_indices == [0]
|
||||||
|
assert state.duplicate_pairs == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -339,3 +342,4 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
|
||||||
assert state.uuid_map[extracted.uuid] == candidate.uuid
|
assert state.uuid_map[extracted.uuid] == candidate.uuid
|
||||||
assert captured_context['existing_nodes'][0]['idx'] == 0
|
assert captured_context['existing_nodes'][0]['idx'] == 0
|
||||||
assert isinstance(captured_context['existing_nodes'], list)
|
assert isinstance(captured_context['existing_nodes'], list)
|
||||||
|
assert state.duplicate_pairs == [(extracted, candidate)]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue