From 2ecf5ad87b580eade36386b6db6f9c8e32e1ec6f Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Sat, 20 Sep 2025 08:56:16 -0400 Subject: [PATCH] fixedmake format --- graphiti_core/utils/bulk_utils.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index ae3ad5f4..181cda74 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -136,12 +136,14 @@ async def add_nodes_and_edges_bulk_tx( entity_data: dict[str, Any] = { 'uuid': node.uuid, 'name': node.name, - 'name_embedding': node.name_embedding, 'group_id': node.group_id, 'summary': node.summary, 'created_at': node.created_at, } + if not bool(driver.aoss_client): + entity_data['name_embedding'] = node.name_embedding + entity_data['labels'] = list(set(node.labels + ['Entity'])) if driver.provider == GraphProvider.KUZU: attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {} @@ -161,7 +163,6 @@ async def add_nodes_and_edges_bulk_tx( 'target_node_uuid': edge.target_node_uuid, 'name': edge.name, 'fact': edge.fact, - 'fact_embedding': edge.fact_embedding, 'group_id': edge.group_id, 'episodes': edge.episodes, 'created_at': edge.created_at, @@ -170,6 +171,9 @@ async def add_nodes_and_edges_bulk_tx( 'invalid_at': edge.invalid_at, } + if not bool(driver.aoss_client): + edge_data['fact_embedding'] = edge.fact_embedding + if driver.provider == GraphProvider.KUZU: attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {} edge_data['attributes'] = json.dumps(attributes) @@ -209,7 +213,15 @@ async def add_nodes_and_edges_bulk_tx( entity_edges=edges, ) - if driver.aoss_client: + if bool(driver.aoss_client): + for node_data, entity_node in zip(nodes, entity_nodes, strict=True): + if node_data.get('uuid') == entity_node.uuid: + node_data['name_embedding'] = entity_node.name_embedding + + for edge_data, entity_edge in zip(edges, entity_edges, strict=True): + if edge_data.get('uuid') == entity_edge.uuid: + edge_data['fact_embedding'] = entity_edge.fact_embedding + await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes) await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes) await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)