fixedmake format

This commit is contained in:
prestonrasmussen 2025-09-20 08:56:16 -04:00
parent e0ddcdc03c
commit 2ecf5ad87b

View file

@ -136,12 +136,14 @@ async def add_nodes_and_edges_bulk_tx(
entity_data: dict[str, Any] = {
'uuid': node.uuid,
'name': node.name,
'name_embedding': node.name_embedding,
'group_id': node.group_id,
'summary': node.summary,
'created_at': node.created_at,
}
if not bool(driver.aoss_client):
entity_data['name_embedding'] = node.name_embedding
entity_data['labels'] = list(set(node.labels + ['Entity']))
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
@ -161,7 +163,6 @@ async def add_nodes_and_edges_bulk_tx(
'target_node_uuid': edge.target_node_uuid,
'name': edge.name,
'fact': edge.fact,
'fact_embedding': edge.fact_embedding,
'group_id': edge.group_id,
'episodes': edge.episodes,
'created_at': edge.created_at,
@ -170,6 +171,9 @@ async def add_nodes_and_edges_bulk_tx(
'invalid_at': edge.invalid_at,
}
if not bool(driver.aoss_client):
edge_data['fact_embedding'] = edge.fact_embedding
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
edge_data['attributes'] = json.dumps(attributes)
@ -209,7 +213,15 @@ async def add_nodes_and_edges_bulk_tx(
entity_edges=edges,
)
if driver.aoss_client:
if bool(driver.aoss_client):
for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
if node_data.get('uuid') == entity_node.uuid:
node_data['name_embedding'] = entity_node.name_embedding
for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
if edge_data.get('uuid') == entity_edge.uuid:
edge_data['fact_embedding'] = entity_edge.fact_embedding
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)