Bulk embed (#403)
* add batch embeddings * bulk edge and node embeddings * update embeddings during add_episode * Update graphiti_core/embedder/client.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * mypy --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
9ad5c4b4df
commit
0b94e0e603
8 changed files with 53 additions and 11 deletions
|
|
@ -321,8 +321,8 @@ class EntityEdge(Edge):
|
||||||
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
||||||
query: LiteralString = (
|
query: LiteralString = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
||||||
"""
|
"""
|
||||||
+ ENTITY_EDGE_RETURN
|
+ ENTITY_EDGE_RETURN
|
||||||
)
|
)
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -468,3 +468,9 @@ def get_community_edge_from_record(record: Any):
|
||||||
target_node_uuid=record['target_node_uuid'],
|
target_node_uuid=record['target_node_uuid'],
|
||||||
created_at=record['created_at'].to_native(),
|
created_at=record['created_at'].to_native(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
|
||||||
|
fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
|
||||||
|
for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
|
||||||
|
edge.fact_embedding = fact_embedding
|
||||||
|
|
|
||||||
|
|
@ -32,3 +32,6 @@ class EmbedderClient(ABC):
|
||||||
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed several strings in one request (optional capability).

    Subclasses that support batching override this and return one embedding
    per input string. This default exists so that embedders without batch
    support fail loudly rather than silently.

    Args:
        input_data_list: Strings to embed.

    Raises:
        NotImplementedError: Always, in this default implementation.
    """
    # Name the concrete class so the failure says which embedder lacks batch
    # support, not just that something did.
    raise NotImplementedError(f'{type(self).__name__} does not implement create_batch()')
|
||||||
|
|
|
||||||
|
|
@ -66,3 +66,13 @@ class GeminiEmbedder(EmbedderClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
return result.embeddings[0].values
|
return result.embeddings[0].values
|
||||||
|
|
||||||
|
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed a list of strings with a single Gemini ``embed_content`` request.

    Args:
        input_data_list: Strings to embed.

    Returns:
        One embedding (list of floats) per entry of the response's
        ``embeddings`` list; an empty list when ``input_data_list`` is empty.

    Raises:
        ValueError: If the response carries no embeddings, or any embedding
            has no values.
    """
    # Nothing to embed: answer locally instead of sending an empty request.
    if not input_data_list:
        return []

    # Generate embeddings
    result = await self.client.aio.models.embed_content(
        model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
        contents=input_data_list,
        config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
    )

    # Fail with a clear message instead of a TypeError on iteration, or
    # None entries leaking into the returned list.
    if not result.embeddings:
        raise ValueError('No embeddings returned from Gemini API in create_batch()')

    vectors: list[list[float]] = []
    for embedding in result.embeddings:
        if not embedding.values:
            raise ValueError('Empty embedding values returned from Gemini API in create_batch()')
        vectors.append(embedding.values)
    return vectors
|
||||||
|
|
|
||||||
|
|
@ -58,3 +58,9 @@ class OpenAIEmbedder(EmbedderClient):
|
||||||
input=input_data, model=self.config.embedding_model
|
input=input_data, model=self.config.embedding_model
|
||||||
)
|
)
|
||||||
return result.data[0].embedding[: self.config.embedding_dim]
|
return result.data[0].embedding[: self.config.embedding_dim]
|
||||||
|
|
||||||
|
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed a list of strings with a single OpenAI embeddings request.

    Args:
        input_data_list: Strings to embed.

    Returns:
        One embedding per item of the response's ``data`` list, each truncated
        to ``self.config.embedding_dim`` values; an empty list when
        ``input_data_list`` is empty.
    """
    # Nothing to embed: skip the request, which the API may reject for an
    # empty input list.
    if not input_data_list:
        return []

    result = await self.client.embeddings.create(
        input=input_data_list, model=self.config.embedding_model
    )

    dim = self.config.embedding_dim
    return [item.embedding[:dim] for item in result.data]
|
||||||
|
|
|
||||||
|
|
@ -56,3 +56,10 @@ class VoyageAIEmbedder(EmbedderClient):
|
||||||
|
|
||||||
result = await self.client.embed(input_list, model=self.config.embedding_model)
|
result = await self.client.embed(input_list, model=self.config.embedding_model)
|
||||||
return [float(x) for x in result.embeddings[0][: self.config.embedding_dim]]
|
return [float(x) for x in result.embeddings[0][: self.config.embedding_dim]]
|
||||||
|
|
||||||
|
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
    """Embed a list of strings with a single Voyage AI ``embed`` request.

    Args:
        input_data_list: Strings to embed.

    Returns:
        One embedding per entry of the response's ``embeddings`` list, each
        truncated to ``self.config.embedding_dim`` values and coerced to
        ``float``; an empty list when ``input_data_list`` is empty.
    """
    # Nothing to embed: skip the request instead of sending an empty batch.
    if not input_data_list:
        return []

    result = await self.client.embed(input_data_list, model=self.config.embedding_model)

    dim = self.config.embedding_dim
    return [[float(x) for x in embedding[:dim]] for embedding in result.embeddings]
|
||||||
|
|
|
||||||
|
|
@ -332,8 +332,8 @@ class EntityNode(Node):
|
||||||
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
||||||
query = (
|
query = (
|
||||||
"""
|
"""
|
||||||
MATCH (n:Entity {uuid: $uuid})
|
MATCH (n:Entity {uuid: $uuid})
|
||||||
"""
|
"""
|
||||||
+ ENTITY_NODE_RETURN
|
+ ENTITY_NODE_RETURN
|
||||||
)
|
)
|
||||||
records, _, _ = await driver.execute_query(
|
records, _, _ = await driver.execute_query(
|
||||||
|
|
@ -560,3 +560,9 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
||||||
created_at=record['created_at'].to_native(),
|
created_at=record['created_at'].to_native(),
|
||||||
summary=record['summary'],
|
summary=record['summary'],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
|
||||||
|
name_embeddings = await embedder.create_batch([node.name for node in nodes])
|
||||||
|
for node, name_embedding in zip(nodes, name_embeddings, strict=True):
|
||||||
|
node.name_embedding = name_embedding
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,12 @@ import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
from graphiti_core.edges import (
|
||||||
|
CommunityEdge,
|
||||||
|
EntityEdge,
|
||||||
|
EpisodicEdge,
|
||||||
|
create_entity_edge_embeddings,
|
||||||
|
)
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
|
@ -152,8 +157,7 @@ async def extract_edges(
|
||||||
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
|
f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
|
||||||
)
|
)
|
||||||
|
|
||||||
# calculate embeddings
|
await create_entity_edge_embeddings(embedder, edges)
|
||||||
await semaphore_gather(*[edge.generate_embedding(embedder) for edge in edges])
|
|
||||||
|
|
||||||
logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
|
logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
|
||||||
|
|
||||||
|
|
@ -214,7 +218,7 @@ async def resolve_extracted_edges(
|
||||||
llm_client = clients.llm_client
|
llm_client = clients.llm_client
|
||||||
|
|
||||||
related_edges_lists: list[list[EntityEdge]] = await get_relevant_edges(
|
related_edges_lists: list[list[EntityEdge]] = await get_relevant_edges(
|
||||||
driver, extracted_edges, SearchFilters(), 0.8
|
driver, extracted_edges, SearchFilters()
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from pydantic import BaseModel
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
from graphiti_core.graphiti_types import GraphitiClients
|
||||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
||||||
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
|
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
|
||||||
|
|
@ -211,7 +211,7 @@ async def extract_nodes(
|
||||||
extracted_nodes.append(new_node)
|
extracted_nodes.append(new_node)
|
||||||
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
||||||
|
|
||||||
await semaphore_gather(*[node.generate_name_embedding(embedder) for node in extracted_nodes])
|
await create_entity_node_embeddings(embedder, extracted_nodes)
|
||||||
|
|
||||||
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
||||||
return extracted_nodes
|
return extracted_nodes
|
||||||
|
|
@ -279,7 +279,7 @@ async def resolve_extracted_nodes(
|
||||||
|
|
||||||
# Find relevant nodes already in the graph
|
# Find relevant nodes already in the graph
|
||||||
existing_nodes_lists: list[list[EntityNode]] = await get_relevant_nodes(
|
existing_nodes_lists: list[list[EntityNode]] = await get_relevant_nodes(
|
||||||
driver, extracted_nodes, SearchFilters(), 0.8
|
driver, extracted_nodes, SearchFilters()
|
||||||
)
|
)
|
||||||
|
|
||||||
uuid_map: dict[str, str] = {}
|
uuid_map: dict[str, str] = {}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue