diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 3b0cf73c..4153a684 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -397,7 +397,7 @@ class Graphiti: episode.content = '' await add_nodes_and_edges_bulk( - self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges + self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder ) # Update any communities @@ -693,7 +693,7 @@ class Graphiti: invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges) await add_nodes_and_edges_bulk( - self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges + self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder ) async def remove_episode(self, episode_uuid: str): diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index aafb54b5..5f253d30 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -26,6 +26,7 @@ from pydantic import BaseModel from typing_extensions import Any from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge +from graphiti_core.embedder import EmbedderClient from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather from graphiti_core.llm_client import LLMClient @@ -95,10 +96,16 @@ async def add_nodes_and_edges_bulk( episodic_edges: list[EpisodicEdge], entity_nodes: list[EntityNode], entity_edges: list[EntityEdge], + embedder: EmbedderClient, ): async with driver.session(database=DEFAULT_DATABASE) as session: await session.execute_write( - add_nodes_and_edges_bulk_tx, episodic_nodes, episodic_edges, entity_nodes, entity_edges + add_nodes_and_edges_bulk_tx, + episodic_nodes, + episodic_edges, + entity_nodes, + entity_edges, + embedder, ) @@ -108,12 +115,15 @@ async def add_nodes_and_edges_bulk_tx( episodic_edges: list[EpisodicEdge], entity_nodes: list[EntityNode], entity_edges: list[EntityEdge], + embedder: EmbedderClient, ): episodes = [dict(episode) for episode in episodic_nodes] for episode in episodes: episode['source'] = str(episode['source'].value) nodes: list[dict[str, Any]] = [] for node in entity_nodes: + if node.name_embedding is None: + await node.generate_name_embedding(embedder) entity_data: dict[str, Any] = { 'uuid': node.uuid, 'name': node.name, @@ -127,6 +137,10 @@ async def add_nodes_and_edges_bulk_tx( entity_data['labels'] = list(set(node.labels + ['Entity'])) nodes.append(entity_data) + for edge in entity_edges: + if edge.fact_embedding is None: + await edge.generate_name_fact(embedder) + await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes) await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes) await tx.run( diff --git a/pyproject.toml b/pyproject.toml index ded5d92f..7ae136c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.11.6pre2" +version = "0.11.6pre3" authors = [ { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },