make bulk save more robust (#461)
* make bulk save more robust * updates
This commit is contained in:
parent
a5f1f03372
commit
89c4ee8cad
3 changed files with 18 additions and 4 deletions
|
|
@ -397,7 +397,7 @@ class Graphiti:
|
||||||
episode.content = ''
|
episode.content = ''
|
||||||
|
|
||||||
await add_nodes_and_edges_bulk(
|
await add_nodes_and_edges_bulk(
|
||||||
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges
|
self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update any communities
|
# Update any communities
|
||||||
|
|
@ -693,7 +693,7 @@ class Graphiti:
|
||||||
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
|
invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
|
||||||
|
|
||||||
await add_nodes_and_edges_bulk(
|
await add_nodes_and_edges_bulk(
|
||||||
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
|
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
|
||||||
)
|
)
|
||||||
|
|
||||||
async def remove_episode(self, episode_uuid: str):
|
async def remove_episode(self, episode_uuid: str):
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from pydantic import BaseModel
|
||||||
from typing_extensions import Any
|
from typing_extensions import Any
|
||||||
|
|
||||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
||||||
|
from graphiti_core.embedder import EmbedderClient
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
|
@ -95,10 +96,16 @@ async def add_nodes_and_edges_bulk(
|
||||||
episodic_edges: list[EpisodicEdge],
|
episodic_edges: list[EpisodicEdge],
|
||||||
entity_nodes: list[EntityNode],
|
entity_nodes: list[EntityNode],
|
||||||
entity_edges: list[EntityEdge],
|
entity_edges: list[EntityEdge],
|
||||||
|
embedder: EmbedderClient,
|
||||||
):
|
):
|
||||||
async with driver.session(database=DEFAULT_DATABASE) as session:
|
async with driver.session(database=DEFAULT_DATABASE) as session:
|
||||||
await session.execute_write(
|
await session.execute_write(
|
||||||
add_nodes_and_edges_bulk_tx, episodic_nodes, episodic_edges, entity_nodes, entity_edges
|
add_nodes_and_edges_bulk_tx,
|
||||||
|
episodic_nodes,
|
||||||
|
episodic_edges,
|
||||||
|
entity_nodes,
|
||||||
|
entity_edges,
|
||||||
|
embedder,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -108,12 +115,15 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
episodic_edges: list[EpisodicEdge],
|
episodic_edges: list[EpisodicEdge],
|
||||||
entity_nodes: list[EntityNode],
|
entity_nodes: list[EntityNode],
|
||||||
entity_edges: list[EntityEdge],
|
entity_edges: list[EntityEdge],
|
||||||
|
embedder: EmbedderClient,
|
||||||
):
|
):
|
||||||
episodes = [dict(episode) for episode in episodic_nodes]
|
episodes = [dict(episode) for episode in episodic_nodes]
|
||||||
for episode in episodes:
|
for episode in episodes:
|
||||||
episode['source'] = str(episode['source'].value)
|
episode['source'] = str(episode['source'].value)
|
||||||
nodes: list[dict[str, Any]] = []
|
nodes: list[dict[str, Any]] = []
|
||||||
for node in entity_nodes:
|
for node in entity_nodes:
|
||||||
|
if node.name_embedding is None:
|
||||||
|
await node.generate_name_embedding(embedder)
|
||||||
entity_data: dict[str, Any] = {
|
entity_data: dict[str, Any] = {
|
||||||
'uuid': node.uuid,
|
'uuid': node.uuid,
|
||||||
'name': node.name,
|
'name': node.name,
|
||||||
|
|
@ -127,6 +137,10 @@ async def add_nodes_and_edges_bulk_tx(
|
||||||
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
||||||
nodes.append(entity_data)
|
nodes.append(entity_data)
|
||||||
|
|
||||||
|
for edge in entity_edges:
|
||||||
|
if edge.fact_embedding is None:
|
||||||
|
await edge.generate_name_fact(embedder)
|
||||||
|
|
||||||
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
||||||
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
|
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
|
||||||
await tx.run(
|
await tx.run(
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.11.6pre2"
|
version = "0.11.6pre3"
|
||||||
authors = [
|
authors = [
|
||||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue