From 0b94e0e603209978c19b855d50fd7c0fcb4423a8 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Sat, 26 Apr 2025 22:09:12 -0400 Subject: [PATCH] Bulk embed (#403) * add batch embeddings * bulk edge and node embeddings * update embeddings during add_episode * Update graphiti_core/embedder/client.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * mypy --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- graphiti_core/edges.py | 10 ++++++++-- graphiti_core/embedder/client.py | 3 +++ graphiti_core/embedder/gemini.py | 10 ++++++++++ graphiti_core/embedder/openai.py | 6 ++++++ graphiti_core/embedder/voyage.py | 7 +++++++ graphiti_core/nodes.py | 10 ++++++++-- graphiti_core/utils/maintenance/edge_operations.py | 12 ++++++++---- graphiti_core/utils/maintenance/node_operations.py | 6 +++--- 8 files changed, 53 insertions(+), 11 deletions(-) diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 108c143e..e392a0a5 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -321,8 +321,8 @@ class EntityEdge(Edge): async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str): query: LiteralString = ( """ - MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) - """ + MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity) + """ + ENTITY_EDGE_RETURN ) records, _, _ = await driver.execute_query( @@ -468,3 +468,9 @@ def get_community_edge_from_record(record: Any): target_node_uuid=record['target_node_uuid'], created_at=record['created_at'].to_native(), ) + + +async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]): + fact_embeddings = await embedder.create_batch([edge.fact for edge in edges]) + for edge, fact_embedding in zip(edges, fact_embeddings, strict=True): + edge.fact_embedding = fact_embedding diff --git a/graphiti_core/embedder/client.py b/graphiti_core/embedder/client.py index 8b8a15f3..9ffc0653 100644 --- a/graphiti_core/embedder/client.py +++ b/graphiti_core/embedder/client.py @@ -32,3 +32,6 @@ class EmbedderClient(ABC): self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]] ) -> list[float]: pass + + async def create_batch(self, input_data_list: list[str]) -> list[list[float]]: + raise NotImplementedError() diff --git a/graphiti_core/embedder/gemini.py b/graphiti_core/embedder/gemini.py index fc92f96e..a9a4e49b 100644 --- a/graphiti_core/embedder/gemini.py +++ b/graphiti_core/embedder/gemini.py @@ -66,3 +66,13 @@ class GeminiEmbedder(EmbedderClient): ) return result.embeddings[0].values + + async def create_batch(self, input_data_list: list[str]) -> list[list[float]]: + # Generate embeddings + result = await self.client.aio.models.embed_content( + model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, + contents=input_data_list, + config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim), + ) + + return [embedding.values for embedding in result.embeddings] diff --git a/graphiti_core/embedder/openai.py b/graphiti_core/embedder/openai.py index 5cba71b2..ed290849 100644 --- a/graphiti_core/embedder/openai.py +++ b/graphiti_core/embedder/openai.py @@ -58,3 +58,9 @@ class OpenAIEmbedder(EmbedderClient): input=input_data, model=self.config.embedding_model ) return result.data[0].embedding[: self.config.embedding_dim] + + async def create_batch(self, input_data_list: list[str]) -> list[list[float]]: + result = await self.client.embeddings.create( + input=input_data_list, model=self.config.embedding_model + ) + return [embedding.embedding[: self.config.embedding_dim] for embedding in result.data] diff --git a/graphiti_core/embedder/voyage.py b/graphiti_core/embedder/voyage.py index 8e043960..6a22e2e2 100644 --- a/graphiti_core/embedder/voyage.py +++ b/graphiti_core/embedder/voyage.py @@ -56,3 +56,10 @@ class VoyageAIEmbedder(EmbedderClient): result = await self.client.embed(input_list, model=self.config.embedding_model) return [float(x) for x in result.embeddings[0][: self.config.embedding_dim]] + + async def create_batch(self, input_data_list: list[str]) -> list[list[float]]: + result = await self.client.embed(input_data_list, model=self.config.embedding_model) + return [ + [float(x) for x in embedding[: self.config.embedding_dim]] + for embedding in result.embeddings + ] diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index ed68958a..3e53584c 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -332,8 +332,8 @@ class EntityNode(Node): async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): query = ( """ - MATCH (n:Entity {uuid: $uuid}) - """ + MATCH (n:Entity {uuid: $uuid}) + """ + ENTITY_NODE_RETURN ) records, _, _ = await driver.execute_query( @@ -560,3 +560,9 @@ def get_community_node_from_record(record: Any) -> CommunityNode: created_at=record['created_at'].to_native(), summary=record['summary'], ) + + +async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]): + name_embeddings = await embedder.create_batch([node.name for node in nodes]) + for node, name_embedding in zip(nodes, name_embeddings, strict=True): + node.name_embedding = name_embedding diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 97cd1c21..1ce1e161 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -18,7 +18,12 @@ import logging from datetime import datetime from time import time -from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge +from graphiti_core.edges import ( + CommunityEdge, + EntityEdge, + EpisodicEdge, + create_entity_edge_embeddings, +) from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather from graphiti_core.llm_client import LLMClient @@ -152,8 +157,7 @@ async def extract_edges( f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})' ) - # calculate embeddings - await semaphore_gather(*[edge.generate_embedding(embedder) for edge in edges]) + await create_entity_edge_embeddings(embedder, edges) logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}') @@ -214,7 +218,7 @@ async def resolve_extracted_edges( llm_client = clients.llm_client related_edges_lists: list[list[EntityEdge]] = await get_relevant_edges( - driver, extracted_edges, SearchFilters(), 0.8 + driver, extracted_edges, SearchFilters() ) logger.debug( diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 69233cc8..255aae22 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -25,7 +25,7 @@ from pydantic import BaseModel from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather from graphiti_core.llm_client import LLMClient -from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode +from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_nodes import NodeDuplicate from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities @@ -211,7 +211,7 @@ async def extract_nodes( extracted_nodes.append(new_node) logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') - await semaphore_gather(*[node.generate_name_embedding(embedder) for node in extracted_nodes]) + await create_entity_node_embeddings(embedder, extracted_nodes) logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') return extracted_nodes @@ -279,7 +279,7 @@ async def resolve_extracted_nodes( # Find relevant nodes already in the graph existing_nodes_lists: list[list[EntityNode]] = await get_relevant_nodes( - driver, extracted_nodes, SearchFilters(), 0.8 + driver, extracted_nodes, SearchFilters() ) uuid_map: dict[str, str] = {}