Bulk embed (#403)
* add batch embeddings
* bulk edge and node embeddings
* update embeddings during add_episode
* Update graphiti_core/embedder/client.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* mypy

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
9ad5c4b4df
commit
0b94e0e603
8 changed files with 53 additions and 11 deletions
|
|
@ -321,8 +321,8 @@ class EntityEdge(Edge):
|
|||
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
||||
query: LiteralString = (
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||
"""
|
||||
+ ENTITY_EDGE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
@ -468,3 +468,9 @@ def get_community_edge_from_record(record: Any):
|
|||
target_node_uuid=record['target_node_uuid'],
|
||||
created_at=record['created_at'].to_native(),
|
||||
)
|
||||
|
||||
|
||||
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
|
||||
fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
|
||||
for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
|
||||
edge.fact_embedding = fact_embedding
|
||||
|
|
|
|||
|
|
@ -32,3 +32,6 @@ class EmbedderClient(ABC):
|
|||
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||
) -> list[float]:
|
||||
pass
|
||||
|
||||
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed several strings in one call; concrete embedders override this.

    Args:
        input_data_list: Texts to embed.

    Returns:
        One embedding vector per input text (contract for overriding classes).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    # Name the concrete class so the failure points at the embedder that
    # lacks batch support rather than at the abstract base.
    raise NotImplementedError(
        f'{type(self).__name__} does not implement create_batch'
    )
|
||||
|
|
|
|||
|
|
@ -66,3 +66,13 @@ class GeminiEmbedder(EmbedderClient):
|
|||
)
|
||||
|
||||
return result.embeddings[0].values
|
||||
|
||||
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed a list of texts with a single ``embed_content`` request.

    Args:
        input_data_list: Texts to embed.

    Returns:
        The ``values`` of each embedding in the response, in response order.
        An empty input list yields an empty list without contacting the API.

    Raises:
        ValueError: If an embedding in the response carries no values.
    """
    # Avoid a pointless request with empty contents.
    if not input_data_list:
        return []

    # Generate embeddings
    result = await self.client.aio.models.embed_content(
        model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
        contents=input_data_list,
        config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
    )

    # NOTE(review): the SDK appears to type `embeddings` and `values` as
    # optional -- confirm. Guard both so a missing field does not surface as
    # an opaque TypeError or a None hidden inside the returned list.
    batch: list[list[float]] = []
    for embedding in result.embeddings or []:
        if embedding.values is None:
            raise ValueError('Embedding response contained an entry without values')
        batch.append(embedding.values)
    return batch
|
||||
|
|
|
|||
|
|
@ -58,3 +58,9 @@ class OpenAIEmbedder(EmbedderClient):
|
|||
input=input_data, model=self.config.embedding_model
|
||||
)
|
||||
return result.data[0].embedding[: self.config.embedding_dim]
|
||||
|
||||
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed a list of texts with a single embeddings request.

    Args:
        input_data_list: Texts to embed.

    Returns:
        One vector per entry of the response's ``data``, each truncated to
        ``self.config.embedding_dim`` components. An empty input list yields
        an empty list without contacting the API.
    """
    # Skip the round trip for an empty batch instead of sending empty input.
    if not input_data_list:
        return []

    result = await self.client.embeddings.create(
        input=input_data_list, model=self.config.embedding_model
    )
    dim = self.config.embedding_dim
    return [item.embedding[:dim] for item in result.data]
|
||||
|
|
|
|||
|
|
@ -56,3 +56,10 @@ class VoyageAIEmbedder(EmbedderClient):
|
|||
|
||||
result = await self.client.embed(input_list, model=self.config.embedding_model)
|
||||
return [float(x) for x in result.embeddings[0][: self.config.embedding_dim]]
|
||||
|
||||
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed a list of texts with a single ``embed`` request.

    Args:
        input_data_list: Texts to embed.

    Returns:
        One vector per entry of the response's ``embeddings``, each truncated
        to ``self.config.embedding_dim`` components and converted to ``float``.
        An empty input list yields an empty list without contacting the API.
    """
    # Skip the round trip for an empty batch instead of sending empty input.
    if not input_data_list:
        return []

    result = await self.client.embed(input_data_list, model=self.config.embedding_model)
    dim = self.config.embedding_dim
    return [[float(x) for x in embedding[:dim]] for embedding in result.embeddings]
|
||||
|
|
|
|||
|
|
@ -332,8 +332,8 @@ class EntityNode(Node):
|
|||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
"""
|
||||
MATCH (n:Entity {uuid: $uuid})
|
||||
"""
|
||||
+ ENTITY_NODE_RETURN
|
||||
)
|
||||
records, _, _ = await driver.execute_query(
|
||||
|
|
@ -560,3 +560,9 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|||
created_at=record['created_at'].to_native(),
|
||||
summary=record['summary'],
|
||||
)
|
||||
|
||||
|
||||
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
|
||||
name_embeddings = await embedder.create_batch([node.name for node in nodes])
|
||||
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
|
||||
node.name_embedding = name_embedding
|
||||
|
|
|
|||
|
|
@ -18,7 +18,12 @@ import logging
|
|||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
||||
from graphiti_core.edges import (
|
||||
CommunityEdge,
|
||||
EntityEdge,
|
||||
EpisodicEdge,
|
||||
create_entity_edge_embeddings,
|
||||
)
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
|
|
@ -152,8 +157,7 @@ async def extract_edges(
|
|||
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
|
||||
)
|
||||
|
||||
# calculate embeddings
|
||||
await semaphore_gather(*[edge.generate_embedding(embedder) for edge in edges])
|
||||
await create_entity_edge_embeddings(embedder, edges)
|
||||
|
||||
logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
|
||||
|
||||
|
|
@ -214,7 +218,7 @@ async def resolve_extracted_edges(
|
|||
llm_client = clients.llm_client
|
||||
|
||||
related_edges_lists: list[list[EntityEdge]] = await get_relevant_edges(
|
||||
driver, extracted_edges, SearchFilters(), 0.8
|
||||
driver, extracted_edges, SearchFilters()
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from pydantic import BaseModel
|
|||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
||||
from graphiti_core.prompts import prompt_library
|
||||
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
||||
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
|
||||
|
|
@ -211,7 +211,7 @@ async def extract_nodes(
|
|||
extracted_nodes.append(new_node)
|
||||
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
||||
|
||||
await semaphore_gather(*[node.generate_name_embedding(embedder) for node in extracted_nodes])
|
||||
await create_entity_node_embeddings(embedder, extracted_nodes)
|
||||
|
||||
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
||||
return extracted_nodes
|
||||
|
|
@ -279,7 +279,7 @@ async def resolve_extracted_nodes(
|
|||
|
||||
# Find relevant nodes already in the graph
|
||||
existing_nodes_lists: list[list[EntityNode]] = await get_relevant_nodes(
|
||||
driver, extracted_nodes, SearchFilters(), 0.8
|
||||
driver, extracted_nodes, SearchFilters()
|
||||
)
|
||||
|
||||
uuid_map: dict[str, str] = {}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue