revert to concurrent dedup with fanout and then reconciliation
parent 18550c84ac
commit ecab825684
2 changed files with 81 additions and 40 deletions
```diff
@@ -45,6 +45,12 @@ from graphiti_core.models.nodes.node_db_queries import (
 )
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
+from graphiti_core.utils.maintenance.dedup_helpers import (
+    DedupResolutionState,
+    _build_candidate_indexes,
+    _normalize_string_exact,
+    _resolve_with_similarity,
+)
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edge,
```
```diff
@@ -266,52 +272,88 @@ async def dedupe_nodes_bulk(
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
-    canonical_nodes: dict[str, EntityNode] = {}
+    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.
+
+    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
+       reconciled against the live graph just like the non-batch flow.
+    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
+       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
+       can apply to edges and persistence.
+    """
+
+    first_pass_results = await semaphore_gather(
+        *[
+            resolve_extracted_nodes(
+                clients,
+                nodes,
+                episode_tuples[i][0],
+                episode_tuples[i][1],
+                entity_types,
+            )
+            for i, nodes in enumerate(extracted_nodes)
+        ]
+    )

     episode_resolutions: list[tuple[str, list[EntityNode]]] = []
     per_episode_uuid_maps: list[dict[str, str]] = []
     duplicate_pairs: list[tuple[str, str]] = []

-    for nodes, (episode, previous_episodes) in zip(extracted_nodes, episode_tuples, strict=True):
-        existing_override = list(canonical_nodes.values()) if canonical_nodes else None
-
-        resolved_nodes, uuid_map, duplicates = await resolve_extracted_nodes(
-            clients,
-            nodes,
-            episode,
-            previous_episodes,
-            entity_types,
-            existing_nodes_override=existing_override,
-        )
-
-        per_episode_uuid_maps.append(uuid_map)
+    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
+        first_pass_results, episode_tuples, strict=True
+    ):
         episode_resolutions.append((episode.uuid, resolved_nodes))
-        for node in resolved_nodes:
-            canonical_nodes[node.uuid] = node
+        per_episode_uuid_maps.append(uuid_map)
         duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)

-    uuid_chains: dict[str, str] = {}
+    canonical_nodes: dict[str, EntityNode] = {}
+    for _, resolved_nodes in episode_resolutions:
+        for node in resolved_nodes:
+            if not canonical_nodes:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            existing_candidates = list(canonical_nodes.values())
+            normalized = _normalize_string_exact(node.name)
+            exact_match = next(
+                (
+                    candidate
+                    for candidate in existing_candidates
+                    if _normalize_string_exact(candidate.name) == normalized
+                ),
+                None,
+            )
+            if exact_match is not None:
+                if exact_match.uuid != node.uuid:
+                    duplicate_pairs.append((node.uuid, exact_match.uuid))
+                continue
+
+            indexes = _build_candidate_indexes(existing_candidates)
+            state = DedupResolutionState(
+                resolved_nodes=[None],
+                uuid_map={},
+                unresolved_indices=[],
+            )
+            _resolve_with_similarity([node], indexes, state)
+
+            resolved = state.resolved_nodes[0]
+            if resolved is None:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            canonical_uuid = resolved.uuid
+            canonical_nodes.setdefault(canonical_uuid, resolved)
+            if canonical_uuid != node.uuid:
+                duplicate_pairs.append((node.uuid, canonical_uuid))
+
+    union_pairs: list[tuple[str, str]] = []
     for uuid_map in per_episode_uuid_maps:
-        uuid_chains.update(uuid_map)
-    for duplicate_uuid, canonical_uuid in duplicate_pairs:
-        uuid_chains[duplicate_uuid] = canonical_uuid
+        union_pairs.extend(uuid_map.items())
+    union_pairs.extend(duplicate_pairs)

-    def _resolve_uuid(uuid: str) -> str:
-        seen: set[str] = set()
-        current = uuid
-        while True:
-            target = uuid_chains.get(current)
-            if target is None or target == current:
-                return current if target is None else target
-            if current in seen:
-                return current
-            seen.add(current)
-            current = target
-
-    compressed_map: dict[str, str] = {uuid: _resolve_uuid(uuid) for uuid in uuid_chains}
-    for canonical_uuid in canonical_nodes:
-        compressed_map.setdefault(canonical_uuid, canonical_uuid)
+    compressed_map: dict[str, str] = compress_uuid_map(union_pairs)
+    for source_uuid, target_uuid in union_pairs:
+        canonical_uuid = compressed_map.get(target_uuid, target_uuid)
+        compressed_map[source_uuid] = canonical_uuid

     nodes_by_episode: dict[str, list[EntityNode]] = {}
     for episode_uuid, resolved_nodes in episode_resolutions:
```
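Editor's note: the first pass fans every per-episode `resolve_extracted_nodes` call out concurrently through `semaphore_gather`, whose implementation is not part of this diff. A minimal sketch of the bounded fan-out the new code appears to rely on; the `max_coroutines` default here is an illustrative assumption, not the library's actual value:

```python
import asyncio
from collections.abc import Coroutine
from typing import Any


# Hedged sketch only: graphiti_core's real semaphore_gather is not shown in
# this diff. The idea is bounded fan-out: run every coroutine through
# asyncio.gather while a semaphore caps how many are in flight at once.
async def semaphore_gather(
    *coroutines: Coroutine[Any, Any, Any],
    max_coroutines: int = 20,  # assumed default, for illustration
) -> list[Any]:
    semaphore = asyncio.Semaphore(max_coroutines)

    async def _bounded(coroutine: Coroutine[Any, Any, Any]) -> Any:
        async with semaphore:  # acquire a slot before awaiting the work
            return await coroutine

    return await asyncio.gather(*(_bounded(c) for c in coroutines))
```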
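The second pass reconciles across the batch (exact name match first, then the similarity heuristics), recording each hit as a `(duplicate_uuid, canonical_uuid)` pair. Those pairs plus the per-episode uuid maps feed `compress_uuid_map`, which replaces the removed `_resolve_uuid` chain walker and is likewise not shown in this diff. A minimal sketch, assuming it collapses chains like `a -> b -> c` so every UUID points at its final canonical target, with cycle protection:

```python
def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
    """Assumed behavior: flatten duplicate->canonical chains to their endpoints."""
    direct: dict[str, str] = dict(duplicate_pairs)
    compressed: dict[str, str] = {}
    for start in direct:
        seen = {start}
        current = direct[start]
        # Walk the chain until it leaves the map or would revisit a node.
        while current in direct and current not in seen:
            seen.add(current)
            current = direct[current]
        compressed[start] = current
    return compressed


# a -> b -> c collapses so both a and b resolve directly to c.
assert compress_uuid_map([("a", "b"), ("b", "c")]) == {"a": "c", "b": "c"}
```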
```diff
@@ -65,8 +65,7 @@ async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
         return [canonical], {canonical.uuid: canonical.uuid}, []

     assert nodes_arg == [extracted_two]
-    assert existing_nodes_override is not None
-    assert existing_nodes_override[0] is canonical
+    assert existing_nodes_override is None

     return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]

@@ -80,7 +79,7 @@ async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):

     assert len(call_queue) == 2
     assert call_queue[0] is None
-    assert list(call_queue[1]) == [canonical]
+    assert call_queue[1] is None

     assert nodes_by_episode[episode_one.uuid] == [canonical]
     assert nodes_by_episode[episode_two.uuid] == [canonical]
```