Bulk embed (#403)

* add batch embeddings

* bulk edge and node embeddings

* update embeddings during add_episode

* Update graphiti_core/embedder/client.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* mypy

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Preston Rasmussen 2025-04-26 22:09:12 -04:00 committed by GitHub
parent 9ad5c4b4df
commit 0b94e0e603
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 53 additions and 11 deletions

View file

@ -321,8 +321,8 @@ class EntityEdge(Edge):
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
query: LiteralString = (
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
+ ENTITY_EDGE_RETURN
)
records, _, _ = await driver.execute_query(
@ -468,3 +468,9 @@ def get_community_edge_from_record(record: Any):
target_node_uuid=record['target_node_uuid'],
created_at=record['created_at'].to_native(),
)
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
edge.fact_embedding = fact_embedding

View file

@ -32,3 +32,6 @@ class EmbedderClient(ABC):
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
pass
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed several strings in a single call.

    The base class supplies no implementation; this method always raises.

    Args:
        input_data_list: The strings to embed.

    Returns:
        Declared as one embedding vector per input string; this base
        implementation never returns.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

View file

@ -66,3 +66,13 @@ class GeminiEmbedder(EmbedderClient):
)
return result.embeddings[0].values
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed several strings with one ``embed_content`` request.

    Args:
        input_data_list: The strings to embed.

    Returns:
        One embedding vector per input string, in the order the service
        returned them. An empty input yields an empty list without any request.

    Raises:
        ValueError: If the response carries no embeddings, or one of the
            returned embeddings has no values.
    """
    # Zero inputs means zero embeddings; do not issue a request for nothing.
    if not input_data_list:
        return []

    # Generate embeddings
    result = await self.client.aio.models.embed_content(
        model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
        contents=input_data_list,
        config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
    )

    # NOTE(review): the SDK appears to type both `embeddings` and `values` as
    # optional. Validate them here so a bad response fails with a clear message
    # instead of a TypeError or a `None` hidden inside the returned list.
    if not result.embeddings:
        raise ValueError('No embeddings returned')

    embeddings: list[list[float]] = []
    for embedding in result.embeddings:
        if not embedding.values:
            raise ValueError('Empty embedding values returned')
        embeddings.append(embedding.values)
    return embeddings

View file

@ -58,3 +58,9 @@ class OpenAIEmbedder(EmbedderClient):
input=input_data, model=self.config.embedding_model
)
return result.data[0].embedding[: self.config.embedding_dim]
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed several strings with one embeddings request.

    Args:
        input_data_list: The strings to embed.

    Returns:
        One vector per entry of the response's ``data``, each truncated to
        ``self.config.embedding_dim`` components. An empty input yields an
        empty list without any request.
    """
    # Zero inputs means zero embeddings; do not issue a request for nothing.
    if not input_data_list:
        return []

    result = await self.client.embeddings.create(
        input=input_data_list, model=self.config.embedding_model
    )
    dim = self.config.embedding_dim
    return [item.embedding[:dim] for item in result.data]

View file

@ -56,3 +56,10 @@ class VoyageAIEmbedder(EmbedderClient):
result = await self.client.embed(input_list, model=self.config.embedding_model)
return [float(x) for x in result.embeddings[0][: self.config.embedding_dim]]
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed several strings with one ``embed`` request.

    Args:
        input_data_list: The strings to embed.

    Returns:
        One vector per entry of the response's ``embeddings``, each truncated
        to ``self.config.embedding_dim`` components and converted to ``float``.
        An empty input yields an empty list without any request.
    """
    # Zero inputs means zero embeddings; do not issue a request for nothing.
    if not input_data_list:
        return []

    result = await self.client.embed(input_data_list, model=self.config.embedding_model)
    dim = self.config.embedding_dim
    return [[float(x) for x in embedding[:dim]] for embedding in result.embeddings]

View file

@ -332,8 +332,8 @@ class EntityNode(Node):
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
query = (
"""
MATCH (n:Entity {uuid: $uuid})
"""
MATCH (n:Entity {uuid: $uuid})
"""
+ ENTITY_NODE_RETURN
)
records, _, _ = await driver.execute_query(
@ -560,3 +560,9 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
name_embeddings = await embedder.create_batch([node.name for node in nodes])
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
node.name_embedding = name_embedding

View file

@ -18,7 +18,12 @@ import logging
from datetime import datetime
from time import time
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.edges import (
CommunityEdge,
EntityEdge,
EpisodicEdge,
create_entity_edge_embeddings,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
@ -152,8 +157,7 @@ async def extract_edges(
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
)
# calculate embeddings
await semaphore_gather(*[edge.generate_embedding(embedder) for edge in edges])
await create_entity_edge_embeddings(embedder, edges)
logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
@ -214,7 +218,7 @@ async def resolve_extracted_edges(
llm_client = clients.llm_client
related_edges_lists: list[list[EntityEdge]] = await get_relevant_edges(
driver, extracted_edges, SearchFilters(), 0.8
driver, extracted_edges, SearchFilters()
)
logger.debug(

View file

@ -25,7 +25,7 @@ from pydantic import BaseModel
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
@ -211,7 +211,7 @@ async def extract_nodes(
extracted_nodes.append(new_node)
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
await semaphore_gather(*[node.generate_name_embedding(embedder) for node in extracted_nodes])
await create_entity_node_embeddings(embedder, extracted_nodes)
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
return extracted_nodes
@ -279,7 +279,7 @@ async def resolve_extracted_nodes(
# Find relevant nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = await get_relevant_nodes(
driver, extracted_nodes, SearchFilters(), 0.8
driver, extracted_nodes, SearchFilters()
)
uuid_map: dict[str, str] = {}