bulk utils update (#727)

* bulk utils update

* remove unused imports

* edge model type guard
This commit is contained in:
Preston Rasmussen 2025-07-15 11:42:08 -04:00 committed by GitHub
parent 5821776512
commit 62df6624d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 51 additions and 65 deletions

View file

@ -78,7 +78,7 @@ async def main(use_bulk: bool = False):
group_id = str(uuid4()) group_id = str(uuid4())
raw_episodes: list[RawEpisode] = [] raw_episodes: list[RawEpisode] = []
for i, message in enumerate(messages[3:7]): for i, message in enumerate(messages[3:10]):
raw_episodes.append( raw_episodes.append(
RawEpisode( RawEpisode(
name=f'Message {i}', name=f'Message {i}',

View file

@ -263,26 +263,14 @@ async def dedupe_nodes_bulk(
) )
# Collect all duplicate pairs sorted by uuid # Collect all duplicate pairs sorted by uuid
duplicate_pairs: list[tuple[EntityNode, EntityNode]] = [] duplicate_pairs: list[tuple[str, str]] = []
for _, _, duplicates in bulk_node_resolutions: for _, _, duplicates in bulk_node_resolutions:
for duplicate in duplicates: for duplicate in duplicates:
n, m = duplicate n, m = duplicate
if n.uuid < m.uuid: duplicate_pairs.append((n.uuid, m.uuid))
duplicate_pairs.append((n, m))
else:
duplicate_pairs.append((m, n))
# Build full deduplication map
duplicate_map: dict[str, str] = {}
for value, key in duplicate_pairs:
if key.uuid in duplicate_map:
existing_value = duplicate_map[key.uuid]
duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
else:
duplicate_map[key.uuid] = value.uuid
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid) # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
compressed_map: dict[str, str] = compress_uuid_map(duplicate_map) compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
node_uuid_map: dict[str, EntityNode] = { node_uuid_map: dict[str, EntityNode] = {
node.uuid: node for nodes in extracted_nodes for node in nodes node.uuid: node for nodes in extracted_nodes for node in nodes
@ -357,26 +345,14 @@ async def dedupe_edges_bulk(
] ]
) )
duplicate_pairs: list[tuple[EntityEdge, EntityEdge]] = [] duplicate_pairs: list[tuple[str, str]] = []
for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions): for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
episode, edge, candidates = dedupe_tuples[i] episode, edge, candidates = dedupe_tuples[i]
for duplicate in duplicates: for duplicate in duplicates:
if edge.uuid < duplicate.uuid: duplicate_pairs.append((edge.uuid, duplicate.uuid))
duplicate_pairs.append((edge, duplicate))
else:
duplicate_pairs.append((duplicate, edge))
# Build full deduplication map
duplicate_map: dict[str, str] = {}
for value, key in duplicate_pairs:
if key.uuid in duplicate_map:
existing_value = duplicate_map[key.uuid]
duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
else:
duplicate_map[key.uuid] = value.uuid
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid) # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
compressed_map: dict[str, str] = compress_uuid_map(duplicate_map) compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
edge_uuid_map: dict[str, EntityEdge] = { edge_uuid_map: dict[str, EntityEdge] = {
edge.uuid: edge for edges in extracted_edges for edge in edges edge.uuid: edge for edges in extracted_edges for edge in edges
@ -393,36 +369,44 @@ async def dedupe_edges_bulk(
return edges_by_episode return edges_by_episode
def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]: class UnionFind:
compressed_map = {} def __init__(self, elements):
# start each element in its own set
self.parent = {e: e for e in elements}
def find_min_uuid(start: str) -> str: def find(self, x):
path = [] # path compression
visited = set() if self.parent[x] != x:
curr = start self.parent[x] = self.find(self.parent[x])
return self.parent[x]
while curr in uuid_map and curr not in visited: def union(self, a, b):
visited.add(curr) ra, rb = self.find(a), self.find(b)
path.append(curr) if ra == rb:
curr = uuid_map[curr] return
# attach the lexicographically larger root under the smaller
if ra < rb:
self.parent[rb] = ra
else:
self.parent[ra] = rb
# Also include the last resolved value (could be outside the map)
path.append(curr)
# Resolve to lex smallest UUID in the path def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
min_uuid = min(path) """
duplicate_pairs: iterable of (id1, id2) pairs
returns: dict mapping each id -> lexicographically smallest id in its duplicate set
"""
all_uuids = set()
for pair in duplicate_pairs:
all_uuids.add(pair[0])
all_uuids.add(pair[1])
# Assign all UUIDs in the path to the min_uuid uf = UnionFind(all_uuids)
for node in path: for a, b in duplicate_pairs:
compressed_map[node] = min_uuid uf.union(a, b)
# ensure full path compression before mapping
return min_uuid return {uuid: uf.find(uuid) for uuid in all_uuids}
for key in uuid_map:
if key not in compressed_map:
find_min_uuid(key)
return compressed_map
E = typing.TypeVar('E', bound=Edge) E = typing.TypeVar('E', bound=Edge)

View file

@ -444,14 +444,14 @@ async def resolve_extracted_edge(
} }
edge_model = edge_types.get(fact_type) edge_model = edge_types.get(fact_type)
if edge_model is not None and len(edge_model.model_fields) != 0:
edge_attributes_response = await llm_client.generate_response(
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
response_model=edge_model, # type: ignore
model_size=ModelSize.small,
)
edge_attributes_response = await llm_client.generate_response( resolved_edge.attributes = edge_attributes_response
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
response_model=edge_model, # type: ignore
model_size=ModelSize.small,
)
resolved_edge.attributes = edge_attributes_response
end = time() end = time()
logger.debug( logger.debug(

View file

@ -277,10 +277,12 @@ async def resolve_extracted_nodes(
uuid_map[extracted_node.uuid] = resolved_node.uuid uuid_map[extracted_node.uuid] = resolved_node.uuid
duplicates: list[int] = resolution.get('duplicates', []) duplicates: list[int] = resolution.get('duplicates', [])
if duplicate_idx not in duplicates and duplicate_idx > -1:
duplicates.append(duplicate_idx)
for idx in duplicates: for idx in duplicates:
existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
node_duplicates.append((resolved_node, existing_node)) node_duplicates.append((extracted_node, existing_node))
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}') logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')

2
uv.lock generated
View file

@ -746,7 +746,7 @@ wheels = [
[[package]] [[package]]
name = "graphiti-core" name = "graphiti-core"
version = "0.17.2" version = "0.17.3"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "diskcache" }, { name = "diskcache" },