parent
3ea6f9f9a8
commit
da71d118db
3 changed files with 21 additions and 9 deletions
|
|
@ -136,12 +136,14 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
entity_data: dict[str, Any] = {
|
entity_data: dict[str, Any] = {
|
||||||
'uuid': node.uuid,
|
'uuid': node.uuid,
|
||||||
'name': node.name,
|
'name': node.name,
|
||||||
'name_embedding': node.name_embedding,
|
|
||||||
'group_id': node.group_id,
|
'group_id': node.group_id,
|
||||||
'summary': node.summary,
|
'summary': node.summary,
|
||||||
'created_at': node.created_at,
|
'created_at': node.created_at,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if not bool(driver.aoss_client):
|
||||||
|
entity_data['name_embedding'] = node.name_embedding
|
||||||
|
|
||||||
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
||||||
if driver.provider == GraphProvider.KUZU:
|
if driver.provider == GraphProvider.KUZU:
|
||||||
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
|
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
|
||||||
|
|
@ -161,7 +163,6 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
'target_node_uuid': edge.target_node_uuid,
|
'target_node_uuid': edge.target_node_uuid,
|
||||||
'name': edge.name,
|
'name': edge.name,
|
||||||
'fact': edge.fact,
|
'fact': edge.fact,
|
||||||
'fact_embedding': edge.fact_embedding,
|
|
||||||
'group_id': edge.group_id,
|
'group_id': edge.group_id,
|
||||||
'episodes': edge.episodes,
|
'episodes': edge.episodes,
|
||||||
'created_at': edge.created_at,
|
'created_at': edge.created_at,
|
||||||
|
|
@ -170,6 +171,9 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
'invalid_at': edge.invalid_at,
|
'invalid_at': edge.invalid_at,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if not bool(driver.aoss_client):
|
||||||
|
edge_data['fact_embedding'] = edge.fact_embedding
|
||||||
|
|
||||||
if driver.provider == GraphProvider.KUZU:
|
if driver.provider == GraphProvider.KUZU:
|
||||||
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
|
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
|
||||||
edge_data['attributes'] = json.dumps(attributes)
|
edge_data['attributes'] = json.dumps(attributes)
|
||||||
|
|
@ -195,21 +199,29 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
else:
|
else:
|
||||||
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
||||||
await tx.run(
|
await tx.run(
|
||||||
get_entity_node_save_bulk_query(driver.provider, nodes),
|
get_entity_node_save_bulk_query(
|
||||||
|
driver.provider, nodes, has_aoss=bool(driver.aoss_client)
|
||||||
|
),
|
||||||
nodes=nodes,
|
nodes=nodes,
|
||||||
has_aoss=bool(driver.aoss_client),
|
|
||||||
)
|
)
|
||||||
await tx.run(
|
await tx.run(
|
||||||
get_episodic_edge_save_bulk_query(driver.provider),
|
get_episodic_edge_save_bulk_query(driver.provider),
|
||||||
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
||||||
)
|
)
|
||||||
await tx.run(
|
await tx.run(
|
||||||
get_entity_edge_save_bulk_query(driver.provider),
|
get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
|
||||||
entity_edges=edges,
|
entity_edges=edges,
|
||||||
has_aoss=bool(driver.aoss_client),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if driver.aoss_client:
|
if bool(driver.aoss_client):
|
||||||
|
for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
|
||||||
|
if node_data.get('uuid') == entity_node.uuid:
|
||||||
|
node_data['name_embedding'] = entity_node.name_embedding
|
||||||
|
|
||||||
|
for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
|
||||||
|
if edge_data.get('uuid') == entity_edge.uuid:
|
||||||
|
edge_data['fact_embedding'] = entity_edge.fact_embedding
|
||||||
|
|
||||||
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
||||||
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
||||||
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.21.0pre2"
|
version = "0.21.0pre3"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -783,7 +783,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.21.0rc2"
|
version = "0.21.0rc3"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue