bulk utils update (#727)
* bulk utils update * remove unused imports * edge model type guard
This commit is contained in:
parent
5821776512
commit
62df6624d4
5 changed files with 51 additions and 65 deletions
|
|
@ -78,7 +78,7 @@ async def main(use_bulk: bool = False):
|
||||||
group_id = str(uuid4())
|
group_id = str(uuid4())
|
||||||
|
|
||||||
raw_episodes: list[RawEpisode] = []
|
raw_episodes: list[RawEpisode] = []
|
||||||
for i, message in enumerate(messages[3:7]):
|
for i, message in enumerate(messages[3:10]):
|
||||||
raw_episodes.append(
|
raw_episodes.append(
|
||||||
RawEpisode(
|
RawEpisode(
|
||||||
name=f'Message {i}',
|
name=f'Message {i}',
|
||||||
|
|
|
||||||
|
|
@ -263,26 +263,14 @@ async def dedupe_nodes_bulk(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Collect all duplicate pairs sorted by uuid
|
# Collect all duplicate pairs sorted by uuid
|
||||||
duplicate_pairs: list[tuple[EntityNode, EntityNode]] = []
|
duplicate_pairs: list[tuple[str, str]] = []
|
||||||
for _, _, duplicates in bulk_node_resolutions:
|
for _, _, duplicates in bulk_node_resolutions:
|
||||||
for duplicate in duplicates:
|
for duplicate in duplicates:
|
||||||
n, m = duplicate
|
n, m = duplicate
|
||||||
if n.uuid < m.uuid:
|
duplicate_pairs.append((n.uuid, m.uuid))
|
||||||
duplicate_pairs.append((n, m))
|
|
||||||
else:
|
|
||||||
duplicate_pairs.append((m, n))
|
|
||||||
|
|
||||||
# Build full deduplication map
|
|
||||||
duplicate_map: dict[str, str] = {}
|
|
||||||
for value, key in duplicate_pairs:
|
|
||||||
if key.uuid in duplicate_map:
|
|
||||||
existing_value = duplicate_map[key.uuid]
|
|
||||||
duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
|
|
||||||
else:
|
|
||||||
duplicate_map[key.uuid] = value.uuid
|
|
||||||
|
|
||||||
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
|
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
|
||||||
compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
|
compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
|
||||||
|
|
||||||
node_uuid_map: dict[str, EntityNode] = {
|
node_uuid_map: dict[str, EntityNode] = {
|
||||||
node.uuid: node for nodes in extracted_nodes for node in nodes
|
node.uuid: node for nodes in extracted_nodes for node in nodes
|
||||||
|
|
@ -357,26 +345,14 @@ async def dedupe_edges_bulk(
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
duplicate_pairs: list[tuple[EntityEdge, EntityEdge]] = []
|
duplicate_pairs: list[tuple[str, str]] = []
|
||||||
for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
|
for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
|
||||||
episode, edge, candidates = dedupe_tuples[i]
|
episode, edge, candidates = dedupe_tuples[i]
|
||||||
for duplicate in duplicates:
|
for duplicate in duplicates:
|
||||||
if edge.uuid < duplicate.uuid:
|
duplicate_pairs.append((edge.uuid, duplicate.uuid))
|
||||||
duplicate_pairs.append((edge, duplicate))
|
|
||||||
else:
|
|
||||||
duplicate_pairs.append((duplicate, edge))
|
|
||||||
|
|
||||||
# Build full deduplication map
|
|
||||||
duplicate_map: dict[str, str] = {}
|
|
||||||
for value, key in duplicate_pairs:
|
|
||||||
if key.uuid in duplicate_map:
|
|
||||||
existing_value = duplicate_map[key.uuid]
|
|
||||||
duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
|
|
||||||
else:
|
|
||||||
duplicate_map[key.uuid] = value.uuid
|
|
||||||
|
|
||||||
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
|
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
|
||||||
compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)
|
compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
|
||||||
|
|
||||||
edge_uuid_map: dict[str, EntityEdge] = {
|
edge_uuid_map: dict[str, EntityEdge] = {
|
||||||
edge.uuid: edge for edges in extracted_edges for edge in edges
|
edge.uuid: edge for edges in extracted_edges for edge in edges
|
||||||
|
|
@ -393,36 +369,44 @@ async def dedupe_edges_bulk(
|
||||||
return edges_by_episode
|
return edges_by_episode
|
||||||
|
|
||||||
|
|
||||||
def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
|
class UnionFind:
|
||||||
compressed_map = {}
|
def __init__(self, elements):
|
||||||
|
# start each element in its own set
|
||||||
|
self.parent = {e: e for e in elements}
|
||||||
|
|
||||||
def find_min_uuid(start: str) -> str:
|
def find(self, x):
|
||||||
path = []
|
# path‐compression
|
||||||
visited = set()
|
if self.parent[x] != x:
|
||||||
curr = start
|
self.parent[x] = self.find(self.parent[x])
|
||||||
|
return self.parent[x]
|
||||||
|
|
||||||
while curr in uuid_map and curr not in visited:
|
def union(self, a, b):
|
||||||
visited.add(curr)
|
ra, rb = self.find(a), self.find(b)
|
||||||
path.append(curr)
|
if ra == rb:
|
||||||
curr = uuid_map[curr]
|
return
|
||||||
|
# attach the lexicographically larger root under the smaller
|
||||||
|
if ra < rb:
|
||||||
|
self.parent[rb] = ra
|
||||||
|
else:
|
||||||
|
self.parent[ra] = rb
|
||||||
|
|
||||||
# Also include the last resolved value (could be outside the map)
|
|
||||||
path.append(curr)
|
|
||||||
|
|
||||||
# Resolve to lex smallest UUID in the path
|
def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
|
||||||
min_uuid = min(path)
|
"""
|
||||||
|
all_ids: iterable of all entity IDs (strings)
|
||||||
|
duplicate_pairs: iterable of (id1, id2) pairs
|
||||||
|
returns: dict mapping each id -> lexicographically smallest id in its duplicate set
|
||||||
|
"""
|
||||||
|
all_uuids = set()
|
||||||
|
for pair in duplicate_pairs:
|
||||||
|
all_uuids.add(pair[0])
|
||||||
|
all_uuids.add(pair[1])
|
||||||
|
|
||||||
# Assign all UUIDs in the path to the min_uuid
|
uf = UnionFind(all_uuids)
|
||||||
for node in path:
|
for a, b in duplicate_pairs:
|
||||||
compressed_map[node] = min_uuid
|
uf.union(a, b)
|
||||||
|
# ensure full path‐compression before mapping
|
||||||
return min_uuid
|
return {uuid: uf.find(uuid) for uuid in all_uuids}
|
||||||
|
|
||||||
for key in uuid_map:
|
|
||||||
if key not in compressed_map:
|
|
||||||
find_min_uuid(key)
|
|
||||||
|
|
||||||
return compressed_map
|
|
||||||
|
|
||||||
|
|
||||||
E = typing.TypeVar('E', bound=Edge)
|
E = typing.TypeVar('E', bound=Edge)
|
||||||
|
|
|
||||||
|
|
@ -444,14 +444,14 @@ async def resolve_extracted_edge(
|
||||||
}
|
}
|
||||||
|
|
||||||
edge_model = edge_types.get(fact_type)
|
edge_model = edge_types.get(fact_type)
|
||||||
|
if edge_model is not None and len(edge_model.model_fields) != 0:
|
||||||
|
edge_attributes_response = await llm_client.generate_response(
|
||||||
|
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
|
||||||
|
response_model=edge_model, # type: ignore
|
||||||
|
model_size=ModelSize.small,
|
||||||
|
)
|
||||||
|
|
||||||
edge_attributes_response = await llm_client.generate_response(
|
resolved_edge.attributes = edge_attributes_response
|
||||||
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
|
|
||||||
response_model=edge_model, # type: ignore
|
|
||||||
model_size=ModelSize.small,
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved_edge.attributes = edge_attributes_response
|
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
|
||||||
|
|
@ -277,10 +277,12 @@ async def resolve_extracted_nodes(
|
||||||
uuid_map[extracted_node.uuid] = resolved_node.uuid
|
uuid_map[extracted_node.uuid] = resolved_node.uuid
|
||||||
|
|
||||||
duplicates: list[int] = resolution.get('duplicates', [])
|
duplicates: list[int] = resolution.get('duplicates', [])
|
||||||
|
if duplicate_idx not in duplicates and duplicate_idx > -1:
|
||||||
|
duplicates.append(duplicate_idx)
|
||||||
for idx in duplicates:
|
for idx in duplicates:
|
||||||
existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
|
existing_node = existing_nodes[idx] if idx < len(existing_nodes) else resolved_node
|
||||||
|
|
||||||
node_duplicates.append((resolved_node, existing_node))
|
node_duplicates.append((extracted_node, existing_node))
|
||||||
|
|
||||||
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
|
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
|
||||||
|
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -746,7 +746,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.17.2"
|
version = "0.17.3"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue