From 62df6624d43d0e07e6ef7f2d34ccf3acf5dac775 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Tue, 15 Jul 2025 11:42:08 -0400 Subject: [PATCH] bulk utils update (#727) * bulk utils update * remove unused imports * edge model type guard --- examples/podcast/podcast_runner.py | 2 +- graphiti_core/utils/bulk_utils.py | 94 ++++++++----------- .../utils/maintenance/edge_operations.py | 14 +-- .../utils/maintenance/node_operations.py | 4 +- uv.lock | 2 +- 5 files changed, 51 insertions(+), 65 deletions(-) diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index c813a8bb..e4573a70 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -78,7 +78,7 @@ async def main(use_bulk: bool = False): group_id = str(uuid4()) raw_episodes: list[RawEpisode] = [] - for i, message in enumerate(messages[3:7]): + for i, message in enumerate(messages[3:10]): raw_episodes.append( RawEpisode( name=f'Message {i}', diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 9987e044..c10c070d 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -263,26 +263,14 @@ async def dedupe_nodes_bulk( ) # Collect all duplicate pairs sorted by uuid - duplicate_pairs: list[tuple[EntityNode, EntityNode]] = [] + duplicate_pairs: list[tuple[str, str]] = [] for _, _, duplicates in bulk_node_resolutions: for duplicate in duplicates: n, m = duplicate - if n.uuid < m.uuid: - duplicate_pairs.append((n, m)) - else: - duplicate_pairs.append((m, n)) - - # Build full deduplication map - duplicate_map: dict[str, str] = {} - for value, key in duplicate_pairs: - if key.uuid in duplicate_map: - existing_value = duplicate_map[key.uuid] - duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value - else: - duplicate_map[key.uuid] = value.uuid + duplicate_pairs.append((n.uuid, m.uuid)) # Now we compress the 
duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid) - compressed_map: dict[str, str] = compress_uuid_map(duplicate_map) + compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs) node_uuid_map: dict[str, EntityNode] = { node.uuid: node for nodes in extracted_nodes for node in nodes @@ -357,26 +345,14 @@ async def dedupe_edges_bulk( ] ) - duplicate_pairs: list[tuple[EntityEdge, EntityEdge]] = [] + duplicate_pairs: list[tuple[str, str]] = [] for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions): episode, edge, candidates = dedupe_tuples[i] for duplicate in duplicates: - if edge.uuid < duplicate.uuid: - duplicate_pairs.append((edge, duplicate)) - else: - duplicate_pairs.append((duplicate, edge)) - - # Build full deduplication map - duplicate_map: dict[str, str] = {} - for value, key in duplicate_pairs: - if key.uuid in duplicate_map: - existing_value = duplicate_map[key.uuid] - duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value - else: - duplicate_map[key.uuid] = value.uuid + duplicate_pairs.append((edge.uuid, duplicate.uuid)) # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid) - compressed_map: dict[str, str] = compress_uuid_map(duplicate_map) + compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs) edge_uuid_map: dict[str, EntityEdge] = { edge.uuid: edge for edges in extracted_edges for edge in edges @@ -393,36 +369,44 @@ async def dedupe_edges_bulk( return edges_by_episode -def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]: - compressed_map = {} +class UnionFind: + def __init__(self, elements): + # start each element in its own set + self.parent = {e: e for e in elements} - def find_min_uuid(start: str) -> str: - path = [] - visited = set() - curr = start + def find(self, x): + # path‐compression + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) + return self.parent[x] - while curr in uuid_map and curr 
not in visited: - visited.add(curr) - path.append(curr) - curr = uuid_map[curr] + def union(self, a, b): + ra, rb = self.find(a), self.find(b) + if ra == rb: + return + # attach the lexicographically larger root under the smaller + if ra < rb: + self.parent[rb] = ra + else: + self.parent[ra] = rb - # Also include the last resolved value (could be outside the map) - path.append(curr) - # Resolve to lex smallest UUID in the path - min_uuid = min(path) +def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]: + """ + Compress duplicate-pair relations into a canonical UUID mapping. + duplicate_pairs: iterable of (id1, id2) pairs + returns: dict mapping each id -> lexicographically smallest id in its duplicate set + """ + all_uuids = set() + for pair in duplicate_pairs: + all_uuids.add(pair[0]) + all_uuids.add(pair[1]) - # Assign all UUIDs in the path to the min_uuid - for node in path: - compressed_map[node] = min_uuid - - return min_uuid - - for key in uuid_map: - if key not in compressed_map: - find_min_uuid(key) - - return compressed_map + uf = UnionFind(all_uuids) + for a, b in duplicate_pairs: + uf.union(a, b) + # ensure full path-compression before mapping + return {uuid: uf.find(uuid) for uuid in all_uuids} E = typing.TypeVar('E', bound=Edge) diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 03415dce..bd030400 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -444,14 +444,14 @@ async def resolve_extracted_edge( } edge_model = edge_types.get(fact_type) + if edge_model is not None and len(edge_model.model_fields) != 0: + edge_attributes_response = await llm_client.generate_response( + prompt_library.extract_edges.extract_attributes(edge_attributes_context), + response_model=edge_model, # type: ignore + model_size=ModelSize.small, + ) - edge_attributes_response = await llm_client.generate_response( - 
prompt_library.extract_edges.extract_attributes(edge_attributes_context), - response_model=edge_model, # type: ignore - model_size=ModelSize.small, - ) - - resolved_edge.attributes = edge_attributes_response + resolved_edge.attributes = edge_attributes_response end = time() logger.debug( diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 8a15fbcb..1588d042 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -277,10 +277,12 @@ async def resolve_extracted_nodes( uuid_map[extracted_node.uuid] = resolved_node.uuid duplicates: list[int] = resolution.get('duplicates', []) + if duplicate_idx not in duplicates and duplicate_idx > -1: + duplicates.append(duplicate_idx) for idx in duplicates: existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node - node_duplicates.append((resolved_node, existing_node)) + node_duplicates.append((extracted_node, existing_node)) logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}') diff --git a/uv.lock b/uv.lock index 77ec9b23..dff9047d 100644 --- a/uv.lock +++ b/uv.lock @@ -746,7 +746,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.17.2" +version = "0.17.3" source = { editable = "." } dependencies = [ { name = "diskcache" },