Embedding fix (#917)

* embedding fix

* pre3

* fixedmake format
This commit is contained in:
Preston Rasmussen 2025-09-20 09:00:04 -04:00 committed by GitHub
parent 3ea6f9f9a8
commit da71d118db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 21 additions and 9 deletions

View file

@ -136,12 +136,14 @@ async def add_nodes_and_edges_bulk_tx(
entity_data: dict[str, Any] = {
'uuid': node.uuid,
'name': node.name,
'name_embedding': node.name_embedding,
'group_id': node.group_id,
'summary': node.summary,
'created_at': node.created_at,
}
if not bool(driver.aoss_client):
entity_data['name_embedding'] = node.name_embedding
entity_data['labels'] = list(set(node.labels + ['Entity']))
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
@ -161,7 +163,6 @@ async def add_nodes_and_edges_bulk_tx(
'target_node_uuid': edge.target_node_uuid,
'name': edge.name,
'fact': edge.fact,
'fact_embedding': edge.fact_embedding,
'group_id': edge.group_id,
'episodes': edge.episodes,
'created_at': edge.created_at,
@ -170,6 +171,9 @@ async def add_nodes_and_edges_bulk_tx(
'invalid_at': edge.invalid_at,
}
if not bool(driver.aoss_client):
edge_data['fact_embedding'] = edge.fact_embedding
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
edge_data['attributes'] = json.dumps(attributes)
@ -195,21 +199,29 @@ async def add_nodes_and_edges_bulk_tx(
else:
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
await tx.run(
get_entity_node_save_bulk_query(driver.provider, nodes),
get_entity_node_save_bulk_query(
driver.provider, nodes, has_aoss=bool(driver.aoss_client)
),
nodes=nodes,
has_aoss=bool(driver.aoss_client),
)
await tx.run(
get_episodic_edge_save_bulk_query(driver.provider),
episodic_edges=[edge.model_dump() for edge in episodic_edges],
)
await tx.run(
get_entity_edge_save_bulk_query(driver.provider),
get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
entity_edges=edges,
has_aoss=bool(driver.aoss_client),
)
if driver.aoss_client:
if bool(driver.aoss_client):
for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
if node_data.get('uuid') == entity_node.uuid:
node_data['name_embedding'] = entity_node.name_embedding
for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
if edge_data.get('uuid') == entity_edge.uuid:
edge_data['fact_embedding'] = entity_edge.fact_embedding
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.21.0pre2"
version = "0.21.0pre3"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },

2
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.21.0rc2"
version = "0.21.0rc3"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },