update add triple to always have embeddings (#803)

update
This commit is contained in:
Preston Rasmussen 2025-08-04 16:33:07 -04:00 committed by GitHub
parent 0eb98647a9
commit 6f3f369be4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -26,7 +26,12 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.edges import (
CommunityEdge,
EntityEdge,
EpisodicEdge,
create_entity_edge_embeddings,
)
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import (
@ -36,7 +41,13 @@ from graphiti_core.helpers import (
validate_group_id,
)
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodeType,
EpisodicNode,
create_entity_node_embeddings,
)
from graphiti_core.search.search import SearchConfig, search
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
from graphiti_core.search.search_config_recipes import (
@ -984,7 +995,7 @@ class Graphiti:
if edge.fact_embedding is None:
await edge.generate_embedding(self.embedder)
resolved_nodes, uuid_map, _ = await resolve_extracted_nodes(
nodes, uuid_map, _ = await resolve_extracted_nodes(
self.clients,
[source_node, target_node],
)
@ -1012,9 +1023,12 @@ class Graphiti:
),
)
await add_nodes_and_edges_bulk(
self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
)
edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
await create_entity_edge_embeddings(self.embedder, edges)
await create_entity_node_embeddings(self.embedder, nodes)
await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
async def remove_episode(self, episode_uuid: str):
# Find the episode to be deleted