From 6f3f369be4c75f3c3d0b98c73bf613cdee72373c Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Mon, 4 Aug 2025 16:33:07 -0400 Subject: [PATCH] update add triple to always have embeddings (#803) update --- graphiti_core/graphiti.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 03b1e402..fe23f1fb 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -26,7 +26,12 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient from graphiti_core.driver.driver import GraphDriver from graphiti_core.driver.neo4j_driver import Neo4jDriver -from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge +from graphiti_core.edges import ( + CommunityEdge, + EntityEdge, + EpisodicEdge, + create_entity_edge_embeddings, +) from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import ( @@ -36,7 +41,13 @@ from graphiti_core.helpers import ( validate_group_id, ) from graphiti_core.llm_client import LLMClient, OpenAIClient -from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode +from graphiti_core.nodes import ( + CommunityNode, + EntityNode, + EpisodeType, + EpisodicNode, + create_entity_node_embeddings, +) from graphiti_core.search.search import SearchConfig, search from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults from graphiti_core.search.search_config_recipes import ( @@ -984,7 +995,7 @@ class Graphiti: if edge.fact_embedding is None: await edge.generate_embedding(self.embedder) - resolved_nodes, uuid_map, _ = await resolve_extracted_nodes( + nodes, uuid_map, _ = await resolve_extracted_nodes( self.clients, [source_node, target_node], ) @@ -1012,9 +1023,12 @@ class Graphiti: ), ) - await add_nodes_and_edges_bulk( - self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder - ) + edges: list[EntityEdge] = [resolved_edge] + invalidated_edges + + await create_entity_edge_embeddings(self.embedder, edges) + await create_entity_node_embeddings(self.embedder, nodes) + + await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder) async def remove_episode(self, episode_uuid: str): # Find the episode to be deleted