bulk utils update (#727)

* bulk utils update

* remove unused imports

* edge model type guard
Preston Rasmussen 2025-07-15 11:42:08 -04:00 committed by GitHub
parent 5821776512
commit 62df6624d4
5 changed files with 51 additions and 65 deletions

@@ -78,7 +78,7 @@ async def main(use_bulk: bool = False):
     group_id = str(uuid4())

     raw_episodes: list[RawEpisode] = []
-    for i, message in enumerate(messages[3:7]):
+    for i, message in enumerate(messages[3:10]):
         raw_episodes.append(
             RawEpisode(
                 name=f'Message {i}',
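
Net effect of the slice change, assuming messages is the sample transcript list the example already loads: three more messages are ingested as episodes.

    messages = [f'msg {i}' for i in range(12)]  # stand-in for the sample transcript
    print(len(messages[3:7]), len(messages[3:10]))  # 4 7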

@@ -263,26 +263,14 @@ async def dedupe_nodes_bulk(
     )

     # Collect all duplicate pairs sorted by uuid
-    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = []
+    duplicate_pairs: list[tuple[str, str]] = []
     for _, _, duplicates in bulk_node_resolutions:
         for duplicate in duplicates:
             n, m = duplicate
-            if n.uuid < m.uuid:
-                duplicate_pairs.append((n, m))
-            else:
-                duplicate_pairs.append((m, n))
-
-    # Build full deduplication map
-    duplicate_map: dict[str, str] = {}
-    for value, key in duplicate_pairs:
-        if key.uuid in duplicate_map:
-            existing_value = duplicate_map[key.uuid]
-            duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
-        else:
-            duplicate_map[key.uuid] = value.uuid
+            duplicate_pairs.append((n.uuid, m.uuid))

     # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
-    compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
+    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)

     node_uuid_map: dict[str, EntityNode] = {
         node.uuid: node for nodes in extracted_nodes for node in nodes
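
The loop now records raw UUID pairs and defers all ordering and canonicalization to compress_uuid_map. A minimal sketch of the new pair collection, with EntityNode reduced to a stand-in holding only a uuid:

    from dataclasses import dataclass

    @dataclass
    class Node:  # stand-in for EntityNode
        uuid: str

    # stand-in for bulk_node_resolutions: (node, resolved, duplicates) triples
    bulk_node_resolutions = [(None, None, [(Node('b'), Node('a')), (Node('c'), Node('b'))])]

    duplicate_pairs: list[tuple[str, str]] = []
    for _, _, duplicates in bulk_node_resolutions:
        for duplicate in duplicates:
            n, m = duplicate
            duplicate_pairs.append((n.uuid, m.uuid))
    print(duplicate_pairs)  # [('b', 'a'), ('c', 'b')]
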
@@ -357,26 +345,14 @@ async def dedupe_edges_bulk(
         ]
     )

-    duplicate_pairs: list[tuple[EntityEdge, EntityEdge]] = []
+    duplicate_pairs: list[tuple[str, str]] = []
     for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
         episode, edge, candidates = dedupe_tuples[i]
         for duplicate in duplicates:
-            if edge.uuid < duplicate.uuid:
-                duplicate_pairs.append((edge, duplicate))
-            else:
-                duplicate_pairs.append((duplicate, edge))
-
-    # Build full deduplication map
-    duplicate_map: dict[str, str] = {}
-    for value, key in duplicate_pairs:
-        if key.uuid in duplicate_map:
-            existing_value = duplicate_map[key.uuid]
-            duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
-        else:
-            duplicate_map[key.uuid] = value.uuid
+            duplicate_pairs.append((edge.uuid, duplicate.uuid))

     # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
-    compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
+    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)

     edge_uuid_map: dict[str, EntityEdge] = {
         edge.uuid: edge for edges in extracted_edges for edge in edges
@@ -393,36 +369,44 @@ async def dedupe_edges_bulk(
     return edges_by_episode


-def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
-    compressed_map = {}
-
-    def find_min_uuid(start: str) -> str:
-        path = []
-        visited = set()
-        curr = start
-
-        while curr in uuid_map and curr not in visited:
-            visited.add(curr)
-            path.append(curr)
-            curr = uuid_map[curr]
-
-        # Also include the last resolved value (could be outside the map)
-        path.append(curr)
-
-        # Resolve to lex smallest UUID in the path
-        min_uuid = min(path)
-
-        # Assign all UUIDs in the path to the min_uuid
-        for node in path:
-            compressed_map[node] = min_uuid
-
-        return min_uuid
-
-    for key in uuid_map:
-        if key not in compressed_map:
-            find_min_uuid(key)
-
-    return compressed_map
+class UnionFind:
+    def __init__(self, elements):
+        # start each element in its own set
+        self.parent = {e: e for e in elements}
+
+    def find(self, x):
+        # path compression
+        if self.parent[x] != x:
+            self.parent[x] = self.find(self.parent[x])
+        return self.parent[x]
+
+    def union(self, a, b):
+        ra, rb = self.find(a), self.find(b)
+        if ra == rb:
+            return
+        # attach the lexicographically larger root under the smaller
+        if ra < rb:
+            self.parent[rb] = ra
+        else:
+            self.parent[ra] = rb
+
+
+def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
+    """
+    duplicate_pairs: iterable of (id1, id2) pairs of entity UUID strings
+    returns: dict mapping each id -> lexicographically smallest id in its duplicate set
+    """
+    all_uuids = set()
+    for pair in duplicate_pairs:
+        all_uuids.add(pair[0])
+        all_uuids.add(pair[1])
+
+    uf = UnionFind(all_uuids)
+    for a, b in duplicate_pairs:
+        uf.union(a, b)
+
+    # ensure full path compression before mapping
+    return {uuid: uf.find(uuid) for uuid in all_uuids}


 E = typing.TypeVar('E', bound=Edge)
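
For reference, a quick sketch (not part of the commit) of how the union-find version behaves; the single-letter strings stand in for real UUIDs:

    pairs = [('c', 'b'), ('b', 'a'), ('e', 'd')]
    mapping = compress_uuid_map(pairs)
    # each UUID maps to the lexicographically smallest member of its duplicate set:
    # {'a': 'a', 'b': 'a', 'c': 'a', 'd': 'd', 'e': 'd'}
    assert mapping['c'] == 'a'  # the chain c -> b -> a collapses to c -> a

Because union always keeps the lexicographically smaller root, the root of each set is its minimum element, and path compression keeps repeated find calls near-constant, where the old per-key chain walk re-traversed paths.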

@@ -444,14 +444,14 @@ async def resolve_extracted_edge(
     }

     edge_model = edge_types.get(fact_type)
-
-    edge_attributes_response = await llm_client.generate_response(
-        prompt_library.extract_edges.extract_attributes(edge_attributes_context),
-        response_model=edge_model,  # type: ignore
-        model_size=ModelSize.small,
-    )
-
-    resolved_edge.attributes = edge_attributes_response
+    if edge_model is not None and len(edge_model.model_fields) != 0:
+        edge_attributes_response = await llm_client.generate_response(
+            prompt_library.extract_edges.extract_attributes(edge_attributes_context),
+            response_model=edge_model,  # type: ignore
+            model_size=ModelSize.small,
+        )
+
+        resolved_edge.attributes = edge_attributes_response

     end = time()
     logger.debug(
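
The new guard leans on Pydantic v2's model_fields. A rough illustration with hypothetical edge models (not from the repo) of why edge types without declared attributes now skip the LLM call:

    from pydantic import BaseModel

    class Likes(BaseModel):
        strength: float  # a declared attribute worth extracting

    class Mentions(BaseModel):
        pass  # no declared attributes

    edge_types = {'LIKES': Likes, 'MENTIONS': Mentions}  # hypothetical registry

    for fact_type in ('LIKES', 'MENTIONS', 'UNKNOWN'):
        edge_model = edge_types.get(fact_type)
        wants_llm = edge_model is not None and len(edge_model.model_fields) != 0
        print(fact_type, wants_llm)  # LIKES True, MENTIONS False, UNKNOWN False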

@@ -277,10 +277,12 @@ async def resolve_extracted_nodes(
         uuid_map[extracted_node.uuid] = resolved_node.uuid

         duplicates: list[int] = resolution.get('duplicates', [])
+        if duplicate_idx not in duplicates and duplicate_idx > -1:
+            duplicates.append(duplicate_idx)
         for idx in duplicates:
             existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
-            node_duplicates.append((resolved_node, existing_node))
+            node_duplicates.append((extracted_node, existing_node))

     logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
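
A rough sketch (hypothetical payload, not from the repo) of the duplicate-index handling above, assuming the resolver returns a best-match duplicate_idx of -1 when nothing matches plus an optional duplicates list:

    resolution = {'duplicate_idx': 1, 'duplicates': [2]}  # hypothetical LLM output

    duplicate_idx: int = resolution['duplicate_idx']
    duplicates: list[int] = resolution.get('duplicates', [])
    # fold the primary match into the duplicate list exactly once
    if duplicate_idx not in duplicates and duplicate_idx > -1:
        duplicates.append(duplicate_idx)
    print(duplicates)  # [2, 1]

The swap from resolved_node to extracted_node records each duplicate pair against the node as originally extracted rather than its resolved replacement.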

uv.lock (generated)

@@ -746,7 +746,7 @@ wheels = [
 [[package]]
 name = "graphiti-core"
-version = "0.17.2"
+version = "0.17.3"
 source = { editable = "." }
 dependencies = [
     { name = "diskcache" },