From da71d118db3b83d38bb4bcf870b6c3864c03254f Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Sat, 20 Sep 2025 09:00:04 -0400 Subject: [PATCH] Embedding fix (#917) * embedding fix * pre3 * fixedmake format --- graphiti_core/utils/bulk_utils.py | 26 +++++++++++++++++++------- pyproject.toml | 2 +- uv.lock | 2 +- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 78397e87..181cda74 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -136,12 +136,14 @@ async def add_nodes_and_edges_bulk_tx( entity_data: dict[str, Any] = { 'uuid': node.uuid, 'name': node.name, - 'name_embedding': node.name_embedding, 'group_id': node.group_id, 'summary': node.summary, 'created_at': node.created_at, } + if not bool(driver.aoss_client): + entity_data['name_embedding'] = node.name_embedding + entity_data['labels'] = list(set(node.labels + ['Entity'])) if driver.provider == GraphProvider.KUZU: attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {} @@ -161,7 +163,6 @@ async def add_nodes_and_edges_bulk_tx( 'target_node_uuid': edge.target_node_uuid, 'name': edge.name, 'fact': edge.fact, - 'fact_embedding': edge.fact_embedding, 'group_id': edge.group_id, 'episodes': edge.episodes, 'created_at': edge.created_at, @@ -170,6 +171,9 @@ async def add_nodes_and_edges_bulk_tx( 'invalid_at': edge.invalid_at, } + if not bool(driver.aoss_client): + edge_data['fact_embedding'] = edge.fact_embedding + if driver.provider == GraphProvider.KUZU: attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {} edge_data['attributes'] = json.dumps(attributes) @@ -195,21 +199,29 @@ async def add_nodes_and_edges_bulk_tx( else: await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes) await tx.run( - get_entity_node_save_bulk_query(driver.provider, nodes), + get_entity_node_save_bulk_query( + driver.provider, nodes, has_aoss=bool(driver.aoss_client) + ), nodes=nodes, - has_aoss=bool(driver.aoss_client), ) await tx.run( get_episodic_edge_save_bulk_query(driver.provider), episodic_edges=[edge.model_dump() for edge in episodic_edges], ) await tx.run( - get_entity_edge_save_bulk_query(driver.provider), + get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)), entity_edges=edges, - has_aoss=bool(driver.aoss_client), ) - if driver.aoss_client: + if bool(driver.aoss_client): + for node_data, entity_node in zip(nodes, entity_nodes, strict=True): + if node_data.get('uuid') == entity_node.uuid: + node_data['name_embedding'] = entity_node.name_embedding + + for edge_data, entity_edge in zip(edges, entity_edges, strict=True): + if edge_data.get('uuid') == entity_edge.uuid: + edge_data['fact_embedding'] = entity_edge.fact_embedding + await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes) await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes) await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges) diff --git a/pyproject.toml b/pyproject.toml index c86f90e4..c77d9ca8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.21.0pre2" +version = "0.21.0pre3" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, diff --git a/uv.lock b/uv.lock index bad253b8..0b75961a 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.21.0rc2" +version = "0.21.0rc3" source = { editable = "." } dependencies = [ { name = "diskcache" },