make bulk save more robust (#461)
* make bulk save more robust * updates
This commit is contained in:
parent
a5f1f03372
commit
89c4ee8cad
3 changed files with 18 additions and 4 deletions
|
|
@ -397,7 +397,7 @@ class Graphiti:
|
|||
episode.content = ''
|
||||
|
||||
await add_nodes_and_edges_bulk(
|
||||
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges
|
||||
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
|
||||
)
|
||||
|
||||
# Update any communities
|
||||
|
|
@ -693,7 +693,7 @@ class Graphiti:
|
|||
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
|
||||
|
||||
await add_nodes_and_edges_bulk(
|
||||
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
|
||||
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
|
||||
)
|
||||
|
||||
async def remove_episode(self, episode_uuid: str):
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from pydantic import BaseModel
|
|||
from typing_extensions import Any
|
||||
|
||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
||||
from graphiti_core.embedder import EmbedderClient
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
|
|
@ -95,10 +96,16 @@ async def add_nodes_and_edges_bulk(
|
|||
episodic_edges: list[EpisodicEdge],
|
||||
entity_nodes: list[EntityNode],
|
||||
entity_edges: list[EntityEdge],
|
||||
embedder: EmbedderClient,
|
||||
):
|
||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||
await session.execute_write(
|
||||
add_nodes_and_edges_bulk_tx, episodic_nodes, episodic_edges, entity_nodes, entity_edges
|
||||
add_nodes_and_edges_bulk_tx,
|
||||
episodic_nodes,
|
||||
episodic_edges,
|
||||
entity_nodes,
|
||||
entity_edges,
|
||||
embedder,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -108,12 +115,15 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
episodic_edges: list[EpisodicEdge],
|
||||
entity_nodes: list[EntityNode],
|
||||
entity_edges: list[EntityEdge],
|
||||
embedder: EmbedderClient,
|
||||
):
|
||||
episodes = [dict(episode) for episode in episodic_nodes]
|
||||
for episode in episodes:
|
||||
episode['source'] = str(episode['source'].value)
|
||||
nodes: list[dict[str, Any]] = []
|
||||
for node in entity_nodes:
|
||||
if node.name_embedding is None:
|
||||
await node.generate_name_embedding(embedder)
|
||||
entity_data: dict[str, Any] = {
|
||||
'uuid': node.uuid,
|
||||
'name': node.name,
|
||||
|
|
@ -127,6 +137,10 @@ async def add_nodes_and_edges_bulk_tx(
|
|||
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
||||
nodes.append(entity_data)
|
||||
|
||||
for edge in entity_edges:
|
||||
if edge.fact_embedding is None:
|
||||
await edge.generate_name_fact(embedder)
|
||||
|
||||
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
||||
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
|
||||
await tx.run(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.11.6pre2"
|
||||
version = "0.11.6pre3"
|
||||
authors = [
|
||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue