make bulk save more robust (#461)

* make bulk save more robust

* updates
This commit is contained in:
Preston Rasmussen 2025-05-08 15:34:13 -04:00 committed by GitHub
parent a5f1f03372
commit 89c4ee8cad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 18 additions and 4 deletions

View file

@ -397,7 +397,7 @@ class Graphiti:
episode.content = ''
await add_nodes_and_edges_bulk(
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
)
# Update any communities
@ -693,7 +693,7 @@ class Graphiti:
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
await add_nodes_and_edges_bulk(
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
)
async def remove_episode(self, episode_uuid: str):

View file

@ -26,6 +26,7 @@ from pydantic import BaseModel
from typing_extensions import Any
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient
@ -95,10 +96,16 @@ async def add_nodes_and_edges_bulk(
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
embedder: EmbedderClient,
):
async with driver.session(database=DEFAULT_DATABASE) as session:
await session.execute_write(
add_nodes_and_edges_bulk_tx, episodic_nodes, episodic_edges, entity_nodes, entity_edges
add_nodes_and_edges_bulk_tx,
episodic_nodes,
episodic_edges,
entity_nodes,
entity_edges,
embedder,
)
@ -108,12 +115,15 @@ async def add_nodes_and_edges_bulk_tx(
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
embedder: EmbedderClient,
):
episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes:
episode['source'] = str(episode['source'].value)
nodes: list[dict[str, Any]] = []
for node in entity_nodes:
if node.name_embedding is None:
await node.generate_name_embedding(embedder)
entity_data: dict[str, Any] = {
'uuid': node.uuid,
'name': node.name,
@ -127,6 +137,10 @@ async def add_nodes_and_edges_bulk_tx(
entity_data['labels'] = list(set(node.labels + ['Entity']))
nodes.append(entity_data)
for edge in entity_edges:
if edge.fact_embedding is None:
await edge.generate_name_fact(embedder)
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
await tx.run(

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.11.6pre2"
version = "0.11.6pre3"
authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },