reveret to concurrent dedup with fanout and then reconcilation

This commit is contained in:
Daniel Chalef 2025-09-25 15:08:32 -07:00
parent 18550c84ac
commit ecab825684
2 changed files with 81 additions and 40 deletions

View file

@ -45,6 +45,12 @@ from graphiti_core.models.nodes.node_db_queries import (
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
from graphiti_core.utils.maintenance.dedup_helpers import (
DedupResolutionState,
_build_candidate_indexes,
_normalize_string_exact,
_resolve_with_similarity,
)
from graphiti_core.utils.maintenance.edge_operations import (
extract_edges,
resolve_extracted_edge,
@ -266,52 +272,88 @@ async def dedupe_nodes_bulk(
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
entity_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
canonical_nodes: dict[str, EntityNode] = {}
"""Resolve entity duplicates across an in-memory batch using a two-pass strategy.
1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
reconciled against the live graph just like the non-batch flow.
2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
can apply to edges and persistence.
"""
first_pass_results = await semaphore_gather(
*[
resolve_extracted_nodes(
clients,
nodes,
episode_tuples[i][0],
episode_tuples[i][1],
entity_types,
)
for i, nodes in enumerate(extracted_nodes)
]
)
episode_resolutions: list[tuple[str, list[EntityNode]]] = []
per_episode_uuid_maps: list[dict[str, str]] = []
duplicate_pairs: list[tuple[str, str]] = []
for nodes, (episode, previous_episodes) in zip(extracted_nodes, episode_tuples, strict=True):
existing_override = list(canonical_nodes.values()) if canonical_nodes else None
resolved_nodes, uuid_map, duplicates = await resolve_extracted_nodes(
clients,
nodes,
episode,
previous_episodes,
entity_types,
existing_nodes_override=existing_override,
)
per_episode_uuid_maps.append(uuid_map)
for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
first_pass_results, episode_tuples, strict=True
):
episode_resolutions.append((episode.uuid, resolved_nodes))
for node in resolved_nodes:
canonical_nodes[node.uuid] = node
per_episode_uuid_maps.append(uuid_map)
duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
uuid_chains: dict[str, str] = {}
canonical_nodes: dict[str, EntityNode] = {}
for _, resolved_nodes in episode_resolutions:
for node in resolved_nodes:
if not canonical_nodes:
canonical_nodes[node.uuid] = node
continue
existing_candidates = list(canonical_nodes.values())
normalized = _normalize_string_exact(node.name)
exact_match = next(
(
candidate
for candidate in existing_candidates
if _normalize_string_exact(candidate.name) == normalized
),
None,
)
if exact_match is not None:
if exact_match.uuid != node.uuid:
duplicate_pairs.append((node.uuid, exact_match.uuid))
continue
indexes = _build_candidate_indexes(existing_candidates)
state = DedupResolutionState(
resolved_nodes=[None],
uuid_map={},
unresolved_indices=[],
)
_resolve_with_similarity([node], indexes, state)
resolved = state.resolved_nodes[0]
if resolved is None:
canonical_nodes[node.uuid] = node
continue
canonical_uuid = resolved.uuid
canonical_nodes.setdefault(canonical_uuid, resolved)
if canonical_uuid != node.uuid:
duplicate_pairs.append((node.uuid, canonical_uuid))
union_pairs: list[tuple[str, str]] = []
for uuid_map in per_episode_uuid_maps:
uuid_chains.update(uuid_map)
for duplicate_uuid, canonical_uuid in duplicate_pairs:
uuid_chains[duplicate_uuid] = canonical_uuid
union_pairs.extend(uuid_map.items())
union_pairs.extend(duplicate_pairs)
def _resolve_uuid(uuid: str) -> str:
seen: set[str] = set()
current = uuid
while True:
target = uuid_chains.get(current)
if target is None or target == current:
return current if target is None else target
if current in seen:
return current
seen.add(current)
current = target
compressed_map: dict[str, str] = {uuid: _resolve_uuid(uuid) for uuid in uuid_chains}
for canonical_uuid in canonical_nodes:
compressed_map.setdefault(canonical_uuid, canonical_uuid)
compressed_map: dict[str, str] = compress_uuid_map(union_pairs)
for source_uuid, target_uuid in union_pairs:
canonical_uuid = compressed_map.get(target_uuid, target_uuid)
compressed_map[source_uuid] = canonical_uuid
nodes_by_episode: dict[str, list[EntityNode]] = {}
for episode_uuid, resolved_nodes in episode_resolutions:

View file

@ -65,8 +65,7 @@ async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
return [canonical], {canonical.uuid: canonical.uuid}, []
assert nodes_arg == [extracted_two]
assert existing_nodes_override is not None
assert existing_nodes_override[0] is canonical
assert existing_nodes_override is None
return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]
@ -80,7 +79,7 @@ async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
assert len(call_queue) == 2
assert call_queue[0] is None
assert list(call_queue[1]) == [canonical]
assert call_queue[1] is None
assert nodes_by_episode[episode_one.uuid] == [canonical]
assert nodes_by_episode[episode_two.uuid] == [canonical]