diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index ded7cc0e..330e960d 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -45,6 +45,12 @@ from graphiti_core.models.nodes.node_db_queries import ( ) from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings +from graphiti_core.utils.maintenance.dedup_helpers import ( + DedupResolutionState, + _build_candidate_indexes, + _normalize_string_exact, + _resolve_with_similarity, +) from graphiti_core.utils.maintenance.edge_operations import ( extract_edges, resolve_extracted_edge, @@ -266,52 +272,88 @@ async def dedupe_nodes_bulk( episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]], entity_types: dict[str, type[BaseModel]] | None = None, ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]: - canonical_nodes: dict[str, EntityNode] = {} + """Resolve entity duplicates across an in-memory batch using a two-pass strategy. + + 1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is + reconciled against the live graph just like the non-batch flow. + 2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch + duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers + can apply to edges and persistence. + """ + + first_pass_results = await semaphore_gather( + *[ + resolve_extracted_nodes( + clients, + nodes, + episode_tuples[i][0], + episode_tuples[i][1], + entity_types, + ) + for i, nodes in enumerate(extracted_nodes) + ] + ) + episode_resolutions: list[tuple[str, list[EntityNode]]] = [] per_episode_uuid_maps: list[dict[str, str]] = [] duplicate_pairs: list[tuple[str, str]] = [] - for nodes, (episode, previous_episodes) in zip(extracted_nodes, episode_tuples, strict=True): - existing_override = list(canonical_nodes.values()) if canonical_nodes else None - - resolved_nodes, uuid_map, duplicates = await resolve_extracted_nodes( - clients, - nodes, - episode, - previous_episodes, - entity_types, - existing_nodes_override=existing_override, - ) - - per_episode_uuid_maps.append(uuid_map) + for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip( + first_pass_results, episode_tuples, strict=True + ): episode_resolutions.append((episode.uuid, resolved_nodes)) - - for node in resolved_nodes: - canonical_nodes[node.uuid] = node - + per_episode_uuid_maps.append(uuid_map) duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates) - uuid_chains: dict[str, str] = {} + canonical_nodes: dict[str, EntityNode] = {} + for _, resolved_nodes in episode_resolutions: + for node in resolved_nodes: + if not canonical_nodes: + canonical_nodes[node.uuid] = node + continue + + existing_candidates = list(canonical_nodes.values()) + normalized = _normalize_string_exact(node.name) + exact_match = next( + ( + candidate + for candidate in existing_candidates + if _normalize_string_exact(candidate.name) == normalized + ), + None, + ) + if exact_match is not None: + if exact_match.uuid != node.uuid: + duplicate_pairs.append((node.uuid, exact_match.uuid)) + continue + + indexes = _build_candidate_indexes(existing_candidates) + state = DedupResolutionState( + resolved_nodes=[None], + uuid_map={}, + unresolved_indices=[], + ) + _resolve_with_similarity([node], indexes, state) + + resolved = state.resolved_nodes[0] + if resolved is None: + canonical_nodes[node.uuid] = node + continue + + canonical_uuid = resolved.uuid + canonical_nodes.setdefault(canonical_uuid, resolved) + if canonical_uuid != node.uuid: + duplicate_pairs.append((node.uuid, canonical_uuid)) + + union_pairs: list[tuple[str, str]] = [] for uuid_map in per_episode_uuid_maps: - uuid_chains.update(uuid_map) - for duplicate_uuid, canonical_uuid in duplicate_pairs: - uuid_chains[duplicate_uuid] = canonical_uuid + union_pairs.extend(uuid_map.items()) + union_pairs.extend(duplicate_pairs) - def _resolve_uuid(uuid: str) -> str: - seen: set[str] = set() - current = uuid - while True: - target = uuid_chains.get(current) - if target is None or target == current: - return current if target is None else target - if current in seen: - return current - seen.add(current) - current = target - - compressed_map: dict[str, str] = {uuid: _resolve_uuid(uuid) for uuid in uuid_chains} - for canonical_uuid in canonical_nodes: - compressed_map.setdefault(canonical_uuid, canonical_uuid) + compressed_map: dict[str, str] = compress_uuid_map(union_pairs) + for source_uuid, target_uuid in union_pairs: + canonical_uuid = compressed_map.get(target_uuid, target_uuid) + compressed_map[source_uuid] = canonical_uuid nodes_by_episode: dict[str, list[EntityNode]] = {} for episode_uuid, resolved_nodes in episode_resolutions: diff --git a/tests/utils/maintenance/test_bulk_utils.py b/tests/utils/maintenance/test_bulk_utils.py index 83c1eb19..f803d471 100644 --- a/tests/utils/maintenance/test_bulk_utils.py +++ b/tests/utils/maintenance/test_bulk_utils.py @@ -65,8 +65,7 @@ async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch): return [canonical], {canonical.uuid: canonical.uuid}, [] assert nodes_arg == [extracted_two] - assert existing_nodes_override is not None - assert existing_nodes_override[0] is canonical + assert existing_nodes_override is None return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)] @@ -80,7 +79,7 @@ async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch): assert len(call_queue) == 2 assert call_queue[0] is None - assert list(call_queue[1]) == [canonical] + assert call_queue[1] is None assert nodes_by_episode[episode_one.uuid] == [canonical] assert nodes_by_episode[episode_two.uuid] == [canonical]