Refactor batch deduplication logic to enhance node resolution and track duplicate pairs (#929) (#936)
* Refactor deduplication logic to enhance node resolution and track duplicate pairs (#929)
* Simplify deduplication process in bulk_utils by reusing canonical nodes.
* Update dedup_helpers to store duplicate pairs during resolution.
* Modify node_operations to append duplicate pairs when resolving nodes.
* Add tests to verify deduplication behavior and ensure correct state updates.
* Revert to concurrent dedup with fan-out and then reconciliation.
* Add a performance note for the deduplication loop in bulk_utils.
* Enhance deduplication logic in bulk_utils to handle missing canonical nodes gracefully.
* Update graphiti_core/utils/bulk_utils.py

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>

* Refactor deduplication logic in bulk_utils to use a directed union-find for canonical UUID resolution.
* Implement _build_directed_uuid_map for efficient UUID resolution in bulk_utils.
* Document the directed union-find lookup in bulk_utils for clarity.

---------

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
parent 1e56019027
commit 9aee3174bd
5 changed files with 371 additions and 63 deletions
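For orientation, a minimal sketch of how the refactored batch entry point is meant to be called, based on the signatures and tests in this diff; the wrapper function and variable names here are illustrative and not part of the commit:

    from graphiti_core.graphiti_types import GraphitiClients
    from graphiti_core.nodes import EntityNode, EpisodicNode
    from graphiti_core.utils import bulk_utils


    async def dedupe_batch(
        clients: GraphitiClients,
        extracted_nodes: list[list[EntityNode]],
        episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
        # One inner list of extracted EntityNodes per episode, paired positionally with
        # (episode, previous_episodes) tuples; entity_types is optional and omitted here.
        nodes_by_episode, uuid_map = await bulk_utils.dedupe_nodes_bulk(
            clients, extracted_nodes, episode_tuples
        )
        # uuid_map sends every alias UUID to its canonical UUID; UUIDs absent from the
        # map are already canonical.
        return nodes_by_episode, uuid_map
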
graphiti_core/utils/bulk_utils.py

@@ -43,8 +43,14 @@ from graphiti_core.models.nodes.node_db_queries import (
     get_entity_node_save_bulk_query,
     get_episode_node_save_bulk_query,
 )
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
+from graphiti_core.utils.maintenance.dedup_helpers import (
+    DedupResolutionState,
+    _build_candidate_indexes,
+    _normalize_string_exact,
+    _resolve_with_similarity,
+)
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edge,
@@ -63,6 +69,38 @@ logger = logging.getLogger(__name__)
 CHUNK_SIZE = 10


+def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
+    """Collapse alias -> canonical chains while preserving direction.
+
+    The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
+    union-find with iterative path compression to ensure every source UUID resolves to its ultimate
+    canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
+    """
+
+    parent: dict[str, str] = {}
+
+    def find(uuid: str) -> str:
+        """Directed union-find lookup using iterative path compression."""
+        parent.setdefault(uuid, uuid)
+        root = uuid
+        while parent[root] != root:
+            root = parent[root]
+
+        while parent[uuid] != root:
+            next_uuid = parent[uuid]
+            parent[uuid] = root
+            uuid = next_uuid
+
+        return root
+
+    for source_uuid, target_uuid in pairs:
+        parent.setdefault(source_uuid, source_uuid)
+        parent.setdefault(target_uuid, target_uuid)
+        parent[find(source_uuid)] = find(target_uuid)
+
+    return {uuid: find(uuid) for uuid in parent}
+
+
 class RawEpisode(BaseModel):
     name: str
     uuid: str | None = Field(default=None)
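As an illustration of the directed union-find above (mirroring the unit tests added at the end of this commit), chains of alias -> canonical pairs collapse to the final target and the direction is never flipped:

    from graphiti_core.utils import bulk_utils

    # 'a' was merged into 'b' and 'b' later into 'c': every UUID in the chain resolves to 'c'.
    mapping = bulk_utils._build_directed_uuid_map([('a', 'b'), ('b', 'c')])
    assert mapping == {'a': 'c', 'b': 'c', 'c': 'c'}

    # The alias always points at the canonical target, even if the alias sorts first.
    mapping = bulk_utils._build_directed_uuid_map([('alias', 'canonical')])
    assert mapping['alias'] == 'canonical'
    assert mapping['canonical'] == 'canonical'
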
@@ -266,83 +304,111 @@ async def dedupe_nodes_bulk(
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
-    embedder = clients.embedder
-    min_score = 0.8
+    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.

-    # generate embeddings
-    await semaphore_gather(
-        *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
-    )
+    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
+       reconciled against the live graph just like the non-batch flow.
+    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
+       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
+       can apply to edges and persistence.
+    """

-    # Find similar results
-    dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
-    for i, nodes_i in enumerate(extracted_nodes):
-        existing_nodes: list[EntityNode] = []
-        for j, nodes_j in enumerate(extracted_nodes):
-            if i == j:
-                continue
-            existing_nodes += nodes_j
-
-        candidates_i: list[EntityNode] = []
-        for node in nodes_i:
-            for existing_node in existing_nodes:
-                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
-                # This approach will cast a wider net than BM25, which is ideal for this use case
-                node_words = set(node.name.lower().split())
-                existing_node_words = set(existing_node.name.lower().split())
-                has_overlap = not node_words.isdisjoint(existing_node_words)
-                if has_overlap:
-                    candidates_i.append(existing_node)
-                    continue
-
-                # Check for semantic similarity even if there is no overlap
-                similarity = np.dot(
-                    normalize_l2(node.name_embedding or []),
-                    normalize_l2(existing_node.name_embedding or []),
-                )
-                if similarity >= min_score:
-                    candidates_i.append(existing_node)
-
-        dedupe_tuples.append((nodes_i, candidates_i))
-
-    # Determine Node Resolutions
-    bulk_node_resolutions: list[
-        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
-    ] = await semaphore_gather(
+    first_pass_results = await semaphore_gather(
         *[
             resolve_extracted_nodes(
                 clients,
-                dedupe_tuple[0],
+                nodes,
                 episode_tuples[i][0],
                 episode_tuples[i][1],
                 entity_types,
-                existing_nodes_override=dedupe_tuples[i][1],
             )
-            for i, dedupe_tuple in enumerate(dedupe_tuples)
+            for i, nodes in enumerate(extracted_nodes)
         ]
     )

-    # Collect all duplicate pairs sorted by uuid
+    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
+    per_episode_uuid_maps: list[dict[str, str]] = []
     duplicate_pairs: list[tuple[str, str]] = []
-    for _, _, duplicates in bulk_node_resolutions:
-        for duplicate in duplicates:
-            n, m = duplicate
-            duplicate_pairs.append((n.uuid, m.uuid))
-
-    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
-    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
+    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
+        first_pass_results, episode_tuples, strict=True
+    ):
+        episode_resolutions.append((episode.uuid, resolved_nodes))
+        per_episode_uuid_maps.append(uuid_map)
+        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)

-    node_uuid_map: dict[str, EntityNode] = {
-        node.uuid: node for nodes in extracted_nodes for node in nodes
-    }
+    canonical_nodes: dict[str, EntityNode] = {}
+    for _, resolved_nodes in episode_resolutions:
+        for node in resolved_nodes:
+            # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
+            # the MinHash index for the accumulated canonical pool each time. The LRU-backed
+            # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
+            # but if batches grow significantly we should switch to an incremental index or chunked
+            # processing.
+            if not canonical_nodes:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            existing_candidates = list(canonical_nodes.values())
+            normalized = _normalize_string_exact(node.name)
+            exact_match = next(
+                (
+                    candidate
+                    for candidate in existing_candidates
+                    if _normalize_string_exact(candidate.name) == normalized
+                ),
+                None,
+            )
+            if exact_match is not None:
+                if exact_match.uuid != node.uuid:
+                    duplicate_pairs.append((node.uuid, exact_match.uuid))
+                continue
+
+            indexes = _build_candidate_indexes(existing_candidates)
+            state = DedupResolutionState(
+                resolved_nodes=[None],
+                uuid_map={},
+                unresolved_indices=[],
+            )
+            _resolve_with_similarity([node], indexes, state)
+
+            resolved = state.resolved_nodes[0]
+            if resolved is None:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            canonical_uuid = resolved.uuid
+            canonical_nodes.setdefault(canonical_uuid, resolved)
+            if canonical_uuid != node.uuid:
+                duplicate_pairs.append((node.uuid, canonical_uuid))
+
+    union_pairs: list[tuple[str, str]] = []
+    for uuid_map in per_episode_uuid_maps:
+        union_pairs.extend(uuid_map.items())
+    union_pairs.extend(duplicate_pairs)
+
+    compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)

     nodes_by_episode: dict[str, list[EntityNode]] = {}
-    for i, nodes in enumerate(extracted_nodes):
-        episode = episode_tuples[i][0]
+    for episode_uuid, resolved_nodes in episode_resolutions:
+        deduped_nodes: list[EntityNode] = []
+        seen: set[str] = set()
+        for node in resolved_nodes:
+            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
+            if canonical_uuid in seen:
+                continue
+            seen.add(canonical_uuid)
+            canonical_node = canonical_nodes.get(canonical_uuid)
+            if canonical_node is None:
+                logger.error(
+                    'Canonical node %s missing during batch dedupe; falling back to %s',
+                    canonical_uuid,
+                    node.uuid,
+                )
+                canonical_node = node
+            deduped_nodes.append(canonical_node)

-        nodes_by_episode[episode.uuid] = [
-            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
-        ]
+        nodes_by_episode[episode_uuid] = deduped_nodes

     return nodes_by_episode, compressed_map
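The docstring above says the canonical UUID map is meant to be applied to edges; a small sketch of that downstream step using resolve_edge_pointers, which the new tests below also exercise (the edge fields here are illustrative):

    from graphiti_core.edges import EntityEdge
    from graphiti_core.utils import bulk_utils
    from graphiti_core.utils.datetime_utils import utc_now

    # Suppose the batch dedupe reported that 'alias' collapses into 'canonical'.
    uuid_map = {'alias': 'canonical'}

    edge = EntityEdge(
        name='knows',
        fact='fact',
        group_id='group',
        source_node_uuid='alias',
        target_node_uuid='target',
        created_at=utc_now(),
    )

    # Rewrites source/target pointers in place so edges reference canonical node UUIDs.
    bulk_utils.resolve_edge_pointers([edge], uuid_map)
    assert edge.source_node_uuid == 'canonical'
    assert edge.target_node_uuid == 'target'
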
graphiti_core/utils/maintenance/dedup_helpers.py

@@ -20,7 +20,7 @@ import math
 import re
 from collections import defaultdict
 from collections.abc import Iterable
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from hashlib import blake2b
 from typing import TYPE_CHECKING

@@ -164,6 +164,7 @@ class DedupResolutionState:
     resolved_nodes: list[EntityNode | None]
     uuid_map: dict[str, str]
     unresolved_indices: list[int]
+    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)


 def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:

@@ -213,6 +214,8 @@ def _resolve_with_similarity(
             match = existing_matches[0]
             state.resolved_nodes[idx] = match
             state.uuid_map[node.uuid] = match.uuid
+            if match.uuid != node.uuid:
+                state.duplicate_pairs.append((node, match))
             continue
         if len(existing_matches) > 1:
             state.unresolved_indices.append(idx)

@@ -236,6 +239,8 @@ def _resolve_with_similarity(
         if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
             state.resolved_nodes[idx] = best_candidate
             state.uuid_map[node.uuid] = best_candidate.uuid
+            if best_candidate.uuid != node.uuid:
+                state.duplicate_pairs.append((node, best_candidate))
             continue

         state.unresolved_indices.append(idx)
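With the duplicate_pairs field added above, the deterministic pass records each (extracted, canonical) pair it resolves. A hedged sketch of reading that state outside the LLM path follows; the node names are illustrative, and whether the pair resolves here depends on the same entropy and similarity gates the updated tests cover:

    from graphiti_core.nodes import EntityNode
    from graphiti_core.utils.maintenance.dedup_helpers import (
        DedupResolutionState,
        _build_candidate_indexes,
        _resolve_with_similarity,
    )

    candidate = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    extracted = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([candidate])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
    _resolve_with_similarity([extracted], indexes, state)

    # When the heuristics resolve the pair, the state carries the UUID mapping and the
    # duplicate pair; otherwise the index is parked in unresolved_indices for the LLM pass.
    if state.resolved_nodes[0] is not None:
        print(state.uuid_map[extracted.uuid], state.duplicate_pairs)
    else:
        print('deferred:', state.unresolved_indices)
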
graphiti_core/utils/maintenance/node_operations.py

@@ -306,6 +306,8 @@ async def _resolve_with_llm(

         state.resolved_nodes[original_index] = resolved_node
         state.uuid_map[extracted_node.uuid] = resolved_node.uuid
+        if resolved_node.uuid != extracted_node.uuid:
+            state.duplicate_pairs.append((extracted_node, resolved_node))


 async def resolve_extracted_nodes(

@@ -332,7 +334,6 @@ async def resolve_extracted_nodes(
         uuid_map={},
         unresolved_indices=[],
     )
-    node_duplicates: list[tuple[EntityNode, EntityNode]] = []

     _resolve_with_similarity(extracted_nodes, indexes, state)


@@ -359,7 +360,7 @@ async def resolve_extracted_nodes(

     new_node_duplicates: list[
         tuple[EntityNode, EntityNode]
-    ] = await filter_existing_duplicate_of_edges(driver, node_duplicates)
+    ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)

     return (
         [node for node in state.resolved_nodes if node is not None],
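For context, the first-pass contract that dedupe_nodes_bulk now relies on: resolve_extracted_nodes returns the resolved nodes, a per-episode UUID map, and the duplicate pairs accumulated in the resolution state. A sketch of a caller, with the helper name and the explicit entity_types=None being illustrative assumptions:

    from graphiti_core.graphiti_types import GraphitiClients
    from graphiti_core.nodes import EntityNode, EpisodicNode
    from graphiti_core.utils import bulk_utils


    async def first_pass_for_episode(
        clients: GraphitiClients,
        nodes: list[EntityNode],
        episode: EpisodicNode,
        previous_episodes: list[EpisodicNode],
    ) -> list[tuple[str, str]]:
        # The third element holds the (extracted, canonical) pairs gathered in
        # DedupResolutionState.duplicate_pairs, filtered against edges that already
        # record the duplication.
        resolved_nodes, uuid_map, duplicate_pairs = await bulk_utils.resolve_extracted_nodes(
            clients, nodes, episode, previous_episodes, None
        )
        return [(source.uuid, target.uuid) for source, target in duplicate_pairs]
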
tests/utils/maintenance/test_bulk_utils.py (new file, 232 lines)
@@ -0,0 +1,232 @@
from collections import deque
from unittest.mock import AsyncMock, MagicMock

import pytest

from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils import bulk_utils
from graphiti_core.utils.datetime_utils import utc_now


def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
    return EpisodicNode(
        name=f'episode-{uuid_suffix}',
        group_id=group_id,
        labels=[],
        source=EpisodeType.message,
        content='content',
        source_description='test',
        created_at=utc_now(),
        valid_at=utc_now(),
    )


def _make_clients() -> GraphitiClients:
    driver = MagicMock()
    embedder = MagicMock()
    cross_encoder = MagicMock()
    llm_client = MagicMock()

    return GraphitiClients.model_construct(  # bypass validation to allow test doubles
        driver=driver,
        embedder=embedder,
        cross_encoder=cross_encoder,
        llm_client=llm_client,
        ensure_ascii=False,
    )


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
    clients = _make_clients()

    episode_one = _make_episode('1')
    episode_two = _make_episode('2')

    extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])

    canonical = extracted_one

    call_queue = deque()

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        call_queue.append(existing_nodes_override)

        if nodes_arg == [extracted_one]:
            return [canonical], {canonical.uuid: canonical.uuid}, []

        assert nodes_arg == [extracted_two]
        assert existing_nodes_override is None

        return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted_one], [extracted_two]],
        [(episode_one, []), (episode_two, [])],
    )

    assert len(call_queue) == 2
    assert call_queue[0] is None
    assert call_queue[1] is None

    assert nodes_by_episode[episode_one.uuid] == [canonical]
    assert nodes_by_episode[episode_two.uuid] == [canonical]
    assert compressed_map.get(extracted_two.uuid) == canonical.uuid


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
    clients = _make_clients()

    resolve_mock = AsyncMock()
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [],
        [],
    )

    assert nodes_by_episode == {}
    assert compressed_map == {}
    resolve_mock.assert_not_awaited()


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
    clients = _make_clients()

    episode = _make_episode('solo')
    extracted = EntityNode(name='Solo', group_id='group', labels=['Entity'])

    resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: extracted.uuid}, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted]],
        [(episode, [])],
    )

    assert nodes_by_episode == {episode.uuid: [extracted]}
    assert compressed_map == {extracted.uuid: extracted.uuid}
    resolve_mock.assert_awaited_once()


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
    clients = _make_clients()

    episode_one = _make_episode('one')
    episode_two = _make_episode('two')

    extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
    extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])

    canonical = extracted_one
    alias = extracted_two

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        if nodes_arg == [extracted_one]:
            return [canonical], {canonical.uuid: canonical.uuid}, []
        assert nodes_arg == [extracted_two]
        return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted_one], [extracted_two]],
        [(episode_one, []), (episode_two, [])],
    )

    assert nodes_by_episode[episode_one.uuid] == [canonical]
    assert nodes_by_episode[episode_two.uuid] == [canonical]
    assert compressed_map.get(alias.uuid) == canonical.uuid


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
    clients = _make_clients()

    episode = _make_episode('missing')
    extracted = EntityNode(name='Fallback', group_id='group', labels=['Entity'])

    resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: 'missing-canonical'}, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    with caplog.at_level('WARNING'):
        nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
            clients,
            [[extracted]],
            [(episode, [])],
        )

    assert nodes_by_episode[episode.uuid] == [extracted]
    assert compressed_map.get(extracted.uuid) == 'missing-canonical'
    assert any('Canonical node missing' in rec.message for rec in caplog.records)


def test_build_directed_uuid_map_empty():
    assert bulk_utils._build_directed_uuid_map([]) == {}


def test_build_directed_uuid_map_chain():
    mapping = bulk_utils._build_directed_uuid_map(
        [
            ('a', 'b'),
            ('b', 'c'),
        ]
    )

    assert mapping['a'] == 'c'
    assert mapping['b'] == 'c'
    assert mapping['c'] == 'c'


def test_build_directed_uuid_map_preserves_direction():
    mapping = bulk_utils._build_directed_uuid_map(
        [
            ('alias', 'canonical'),
        ]
    )

    assert mapping['alias'] == 'canonical'
    assert mapping['canonical'] == 'canonical'


def test_resolve_edge_pointers_updates_sources():
    created_at = utc_now()
    edge = EntityEdge(
        name='knows',
        fact='fact',
        group_id='group',
        source_node_uuid='alias',
        target_node_uuid='target',
        created_at=created_at,
    )

    bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})

    assert edge.source_node_uuid == 'canonical'
    assert edge.target_node_uuid == 'target'
@@ -257,6 +257,7 @@ def test_resolve_with_similarity_exact_match_updates_state():
     assert state.resolved_nodes[0].uuid == candidate.uuid
     assert state.uuid_map[extracted.uuid] == candidate.uuid
     assert state.unresolved_indices == []
+    assert state.duplicate_pairs == [(extracted, candidate)]


 def test_resolve_with_similarity_low_entropy_defers_resolution():

@@ -274,6 +275,7 @@ def test_resolve_with_similarity_low_entropy_defers_resolution():

     assert state.resolved_nodes[0] is None
     assert state.unresolved_indices == [0]
+    assert state.duplicate_pairs == []


 def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():

@@ -288,6 +290,7 @@ def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():

     assert state.resolved_nodes[0] is None
     assert state.unresolved_indices == [0]
+    assert state.duplicate_pairs == []


 @pytest.mark.asyncio

@@ -339,3 +342,4 @@ async def test_resolve_with_llm_updates_unresolved(monkeypatch):
     assert state.uuid_map[extracted.uuid] == candidate.uuid
     assert captured_context['existing_nodes'][0]['idx'] == 0
     assert isinstance(captured_context['existing_nodes'], list)
+    assert state.duplicate_pairs == [(extracted, candidate)]