diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 03b1e402..fe23f1fb 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -26,7 +26,12 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient from graphiti_core.driver.driver import GraphDriver from graphiti_core.driver.neo4j_driver import Neo4jDriver -from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge +from graphiti_core.edges import ( + CommunityEdge, + EntityEdge, + EpisodicEdge, + create_entity_edge_embeddings, +) from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import ( @@ -36,7 +41,13 @@ from graphiti_core.helpers import ( validate_group_id, ) from graphiti_core.llm_client import LLMClient, OpenAIClient -from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode +from graphiti_core.nodes import ( + CommunityNode, + EntityNode, + EpisodeType, + EpisodicNode, + create_entity_node_embeddings, +) from graphiti_core.search.search import SearchConfig, search from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults from graphiti_core.search.search_config_recipes import ( @@ -984,7 +995,7 @@ class Graphiti: if edge.fact_embedding is None: await edge.generate_embedding(self.embedder) - resolved_nodes, uuid_map, _ = await resolve_extracted_nodes( + nodes, uuid_map, _ = await resolve_extracted_nodes( self.clients, [source_node, target_node], ) @@ -1012,9 +1023,12 @@ class Graphiti: ), ) - await add_nodes_and_edges_bulk( - self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder - ) + edges: list[EntityEdge] = [resolved_edge] + invalidated_edges + + await create_entity_edge_embeddings(self.embedder, edges) + await create_entity_node_embeddings(self.embedder, nodes) + + await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder) async def remove_episode(self, episode_uuid: str): # Find the episode to be deleted