parent
3ea6f9f9a8
commit
da71d118db
3 changed files with 21 additions and 9 deletions
|
|
@ -136,12 +136,14 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
entity_data: dict[str, Any] = {
|
||||
'uuid': node.uuid,
|
||||
'name': node.name,
|
||||
'name_embedding': node.name_embedding,
|
||||
'group_id': node.group_id,
|
||||
'summary': node.summary,
|
||||
'created_at': node.created_at,
|
||||
}
|
||||
|
||||
if not bool(driver.aoss_client):
|
||||
entity_data['name_embedding'] = node.name_embedding
|
||||
|
||||
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
|
||||
|
|
@ -161,7 +163,6 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
'target_node_uuid': edge.target_node_uuid,
|
||||
'name': edge.name,
|
||||
'fact': edge.fact,
|
||||
'fact_embedding': edge.fact_embedding,
|
||||
'group_id': edge.group_id,
|
||||
'episodes': edge.episodes,
|
||||
'created_at': edge.created_at,
|
||||
|
|
@ -170,6 +171,9 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
'invalid_at': edge.invalid_at,
|
||||
}
|
||||
|
||||
if not bool(driver.aoss_client):
|
||||
edge_data['fact_embedding'] = edge.fact_embedding
|
||||
|
||||
if driver.provider == GraphProvider.KUZU:
|
||||
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
|
||||
edge_data['attributes'] = json.dumps(attributes)
|
||||
|
|
@ -195,21 +199,29 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
else:
|
||||
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
||||
await tx.run(
|
||||
get_entity_node_save_bulk_query(driver.provider, nodes),
|
||||
get_entity_node_save_bulk_query(
|
||||
driver.provider, nodes, has_aoss=bool(driver.aoss_client)
|
||||
),
|
||||
nodes=nodes,
|
||||
has_aoss=bool(driver.aoss_client),
|
||||
)
|
||||
await tx.run(
|
||||
get_episodic_edge_save_bulk_query(driver.provider),
|
||||
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
||||
)
|
||||
await tx.run(
|
||||
get_entity_edge_save_bulk_query(driver.provider),
|
||||
get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
|
||||
entity_edges=edges,
|
||||
has_aoss=bool(driver.aoss_client),
|
||||
)
|
||||
|
||||
if driver.aoss_client:
|
||||
if bool(driver.aoss_client):
|
||||
for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
|
||||
if node_data.get('uuid') == entity_node.uuid:
|
||||
node_data['name_embedding'] = entity_node.name_embedding
|
||||
|
||||
for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
|
||||
if edge_data.get('uuid') == entity_edge.uuid:
|
||||
edge_data['fact_embedding'] = entity_edge.fact_embedding
|
||||
|
||||
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
||||
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
||||
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.21.0pre2"
|
||||
version = "0.21.0pre3"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
|
|||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.21.0rc2"
|
||||
version = "0.21.0rc3"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue