bulk utils update (#727)
* bulk utils update * remove unused imports * edge model type guard
This commit is contained in:
parent
5821776512
commit
62df6624d4
5 changed files with 51 additions and 65 deletions
|
|
@ -78,7 +78,7 @@ async def main(use_bulk: bool = False):
|
|||
group_id = str(uuid4())
|
||||
|
||||
raw_episodes: list[RawEpisode] = []
|
||||
for i, message in enumerate(messages[3:7]):
|
||||
for i, message in enumerate(messages[3:10]):
|
||||
raw_episodes.append(
|
||||
RawEpisode(
|
||||
name=f'Message {i}',
|
||||
|
|
|
|||
|
|
@ -263,26 +263,14 @@ async def dedupe_nodes_bulk(
|
|||
)
|
||||
|
||||
# Collect all duplicate pairs sorted by uuid
|
||||
duplicate_pairs: list[tuple[EntityNode, EntityNode]] = []
|
||||
duplicate_pairs: list[tuple[str, str]] = []
|
||||
for _, _, duplicates in bulk_node_resolutions:
|
||||
for duplicate in duplicates:
|
||||
n, m = duplicate
|
||||
if n.uuid < m.uuid:
|
||||
duplicate_pairs.append((n, m))
|
||||
else:
|
||||
duplicate_pairs.append((m, n))
|
||||
|
||||
# Build full deduplication map
|
||||
duplicate_map: dict[str, str] = {}
|
||||
for value, key in duplicate_pairs:
|
||||
if key.uuid in duplicate_map:
|
||||
existing_value = duplicate_map[key.uuid]
|
||||
duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
|
||||
else:
|
||||
duplicate_map[key.uuid] = value.uuid
|
||||
duplicate_pairs.append((n.uuid, m.uuid))
|
||||
|
||||
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
|
||||
compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
|
||||
compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
|
||||
|
||||
node_uuid_map: dict[str, EntityNode] = {
|
||||
node.uuid: node for nodes in extracted_nodes for node in nodes
|
||||
|
|
@ -357,26 +345,14 @@ async def dedupe_edges_bulk(
|
|||
]
|
||||
)
|
||||
|
||||
duplicate_pairs: list[tuple[EntityEdge, EntityEdge]] = []
|
||||
duplicate_pairs: list[tuple[str, str]] = []
|
||||
for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
|
||||
episode, edge, candidates = dedupe_tuples[i]
|
||||
for duplicate in duplicates:
|
||||
if edge.uuid < duplicate.uuid:
|
||||
duplicate_pairs.append((edge, duplicate))
|
||||
else:
|
||||
duplicate_pairs.append((duplicate, edge))
|
||||
|
||||
# Build full deduplication map
|
||||
duplicate_map: dict[str, str] = {}
|
||||
for value, key in duplicate_pairs:
|
||||
if key.uuid in duplicate_map:
|
||||
existing_value = duplicate_map[key.uuid]
|
||||
duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
|
||||
else:
|
||||
duplicate_map[key.uuid] = value.uuid
|
||||
duplicate_pairs.append((edge.uuid, duplicate.uuid))
|
||||
|
||||
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
|
||||
compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
|
||||
compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
|
||||
|
||||
edge_uuid_map: dict[str, EntityEdge] = {
|
||||
edge.uuid: edge for edges in extracted_edges for edge in edges
|
||||
|
|
@ -393,36 +369,44 @@ async def dedupe_edges_bulk(
|
|||
return edges_by_episode
|
||||
|
||||
|
||||
def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
|
||||
compressed_map = {}
|
||||
class UnionFind:
|
||||
def __init__(self, elements):
|
||||
# start each element in its own set
|
||||
self.parent = {e: e for e in elements}
|
||||
|
||||
def find_min_uuid(start: str) -> str:
|
||||
path = []
|
||||
visited = set()
|
||||
curr = start
|
||||
def find(self, x):
|
||||
# path‐compression
|
||||
if self.parent[x] != x:
|
||||
self.parent[x] = self.find(self.parent[x])
|
||||
return self.parent[x]
|
||||
|
||||
while curr in uuid_map and curr not in visited:
|
||||
visited.add(curr)
|
||||
path.append(curr)
|
||||
curr = uuid_map[curr]
|
||||
def union(self, a, b):
|
||||
ra, rb = self.find(a), self.find(b)
|
||||
if ra == rb:
|
||||
return
|
||||
# attach the lexicographically larger root under the smaller
|
||||
if ra < rb:
|
||||
self.parent[rb] = ra
|
||||
else:
|
||||
self.parent[ra] = rb
|
||||
|
||||
# Also include the last resolved value (could be outside the map)
|
||||
path.append(curr)
|
||||
|
||||
# Resolve to lex smallest UUID in the path
|
||||
min_uuid = min(path)
|
||||
def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
|
||||
"""
|
||||
all_ids: iterable of all entity IDs (strings)
|
||||
duplicate_pairs: iterable of (id1, id2) pairs
|
||||
returns: dict mapping each id -> lexicographically smallest id in its duplicate set
|
||||
"""
|
||||
all_uuids = set()
|
||||
for pair in duplicate_pairs:
|
||||
all_uuids.add(pair[0])
|
||||
all_uuids.add(pair[1])
|
||||
|
||||
# Assign all UUIDs in the path to the min_uuid
|
||||
for node in path:
|
||||
compressed_map[node] = min_uuid
|
||||
|
||||
return min_uuid
|
||||
|
||||
for key in uuid_map:
|
||||
if key not in compressed_map:
|
||||
find_min_uuid(key)
|
||||
|
||||
return compressed_map
|
||||
uf = UnionFind(all_uuids)
|
||||
for a, b in duplicate_pairs:
|
||||
uf.union(a, b)
|
||||
# ensure full path‐compression before mapping
|
||||
return {uuid: uf.find(uuid) for uuid in all_uuids}
|
||||
|
||||
|
||||
E = typing.TypeVar('E', bound=Edge)
|
||||
|
|
|
|||
|
|
@ -444,14 +444,14 @@ async def resolve_extracted_edge(
|
|||
}
|
||||
|
||||
edge_model = edge_types.get(fact_type)
|
||||
if edge_model is not None and len(edge_model.model_fields) != 0:
|
||||
edge_attributes_response = await llm_client.generate_response(
|
||||
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
|
||||
response_model=edge_model, # type: ignore
|
||||
model_size=ModelSize.small,
|
||||
)
|
||||
|
||||
edge_attributes_response = await llm_client.generate_response(
|
||||
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
|
||||
response_model=edge_model, # type: ignore
|
||||
model_size=ModelSize.small,
|
||||
)
|
||||
|
||||
resolved_edge.attributes = edge_attributes_response
|
||||
resolved_edge.attributes = edge_attributes_response
|
||||
|
||||
end = time()
|
||||
logger.debug(
|
||||
|
|
|
|||
|
|
@ -277,10 +277,12 @@ async def resolve_extracted_nodes(
|
|||
uuid_map[extracted_node.uuid] = resolved_node.uuid
|
||||
|
||||
duplicates: list[int] = resolution.get('duplicates', [])
|
||||
if duplicate_idx not in duplicates and duplicate_idx > -1:
|
||||
duplicates.append(duplicate_idx)
|
||||
for idx in duplicates:
|
||||
existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
|
||||
|
||||
node_duplicates.append((resolved_node, existing_node))
|
||||
node_duplicates.append((extracted_node, existing_node))
|
||||
|
||||
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
|
||||
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -746,7 +746,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.17.2"
|
||||
version = "0.17.3"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue