feat: Dedicated embedder interface (#159)

* feat: Add Embedder interface and implement openai embedder

* feat: Add voyage ai embedder
This commit is contained in:
Pavlo Paliychuk 2024-09-27 12:47:04 -04:00 committed by GitHub
parent 790c37de38
commit a7148d6260
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 182 additions and 102 deletions

View file

@ -24,9 +24,9 @@ from uuid import uuid4
from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
from graphiti_core.nodes import Node
logger = logging.getLogger(__name__)
@ -171,17 +171,16 @@ class EntityEdge(Edge):
default=None, description='datetime of when the fact stopped being true'
)
async def generate_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
async def generate_embedding(self, embedder: EmbedderClient):
start = time()
text = self.fact.replace('\n', ' ')
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.fact_embedding = embedding[:EMBEDDING_DIM]
self.fact_embedding = await embedder.create(input=[text])
end = time()
logger.info(f'embedded {text} in {end - start} ms')
return embedding
return self.fact_embedding
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(

View file

@ -0,0 +1,4 @@
from .client import EmbedderClient
from .openai import OpenAIEmbedder, OpenAIEmbedderConfig
__all__ = ['EmbedderClient', 'OpenAIEmbedder', 'OpenAIEmbedderConfig']

View file

@ -0,0 +1,34 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import ABC, abstractmethod
from typing import Iterable, List, Literal
from pydantic import BaseModel, Field
EMBEDDING_DIM = 1024
class EmbedderConfig(BaseModel):
    """Base configuration shared by all embedder clients.

    Pins the embedding dimensionality used across the graph store so that
    vectors from different providers are interchangeable.
    """

    # Frozen and typed as Literal[1024] so it cannot be overridden per instance;
    # provider results are truncated to this length (see the concrete clients).
    embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True)
class EmbedderClient(ABC):
    """Abstract interface for embedding providers.

    Concrete implementations (OpenAI, Voyage AI) wrap a provider SDK and
    return a single embedding vector per call.
    """

    @abstractmethod
    async def create(
        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
    ) -> list[float]:
        """Embed *input* and return one embedding vector.

        Args:
            input: Text, a batch of texts, or pre-tokenized input to embed.
                NOTE(review): the implementations in this package return only
                the first embedding when a batch is supplied — confirm that
                callers never rely on batch output.

        Returns:
            A single embedding vector, truncated to the configured dimension.
        """
        pass

View file

@ -0,0 +1,48 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Iterable, List
from openai import AsyncOpenAI
from openai.types import EmbeddingModel
from .client import EmbedderClient, EmbedderConfig
DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'
class OpenAIEmbedderConfig(EmbedderConfig):
    """Configuration for the OpenAI embedder."""

    # Accepts OpenAI's EmbeddingModel literal type or any raw model-name string.
    embedding_model: EmbeddingModel | str = DEFAULT_EMBEDDING_MODEL
    # None defers to the OpenAI SDK's own resolution (e.g. the OPENAI_API_KEY
    # environment variable and the standard API endpoint) — confirm against SDK.
    api_key: str | None = None
    base_url: str | None = None
class OpenAIEmbedder(EmbedderClient):
    """
    OpenAI Embedder Client
    """

    def __init__(self, config: OpenAIEmbedderConfig | None = None):
        # Fall back to a default configuration when none is supplied.
        self.config = OpenAIEmbedderConfig() if config is None else config
        self.client = AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.base_url,
        )

    async def create(
        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
    ) -> list[float]:
        """Embed *input* via the OpenAI embeddings API and return one vector,
        truncated to the configured embedding dimension."""
        response = await self.client.embeddings.create(
            input=input, model=self.config.embedding_model
        )
        # Only the first embedding is returned, per the EmbedderClient contract.
        full_embedding = response.data[0].embedding
        return full_embedding[: self.config.embedding_dim]

View file

@ -0,0 +1,47 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Iterable, List
import voyageai # type: ignore
from pydantic import Field
from .client import EmbedderClient, EmbedderConfig
DEFAULT_EMBEDDING_MODEL = 'voyage-3'
class VoyageAIEmbedderConfig(EmbedderConfig):
    """Configuration for the Voyage AI embedder."""

    # Voyage model name; defaults to 'voyage-3'.
    embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL)
    # None defers to the voyageai SDK's own credential resolution
    # (presumably the VOYAGE_API_KEY environment variable — confirm).
    api_key: str | None = None
class VoyageAIEmbedder(EmbedderClient):
    """
    VoyageAI Embedder Client
    """

    def __init__(self, config: VoyageAIEmbedderConfig | None = None):
        if config is None:
            config = VoyageAIEmbedderConfig()
        self.config = config
        self.client = voyageai.AsyncClient(api_key=config.api_key)

    async def create(
        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
    ) -> list[float]:
        """Embed *input* via the Voyage AI API and return one vector,
        truncated to the configured embedding dimension.

        Args:
            input: Text or a batch of texts to embed. A bare string is
                wrapped in a list before being passed to the SDK.

        Returns:
            The first embedding from the response as a list of floats.
        """
        # The voyageai embed API takes a list of texts; a bare string would
        # otherwise be treated as an iterable of characters. Wrapping it keeps
        # this client's accepted inputs consistent with OpenAIEmbedder, which
        # handles a plain string directly.
        texts = [input] if isinstance(input, str) else input
        result = await self.client.embed(texts, model=self.config.embedding_model)
        # Only the first embedding is returned, per the EmbedderClient contract.
        return result.embeddings[0][: self.config.embedding_dim]

View file

@ -23,6 +23,7 @@ from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, search
@ -83,6 +84,7 @@ class Graphiti:
user: str,
password: str,
llm_client: LLMClient | None = None,
embedder: EmbedderClient | None = None,
store_raw_episode_content: bool = True,
):
"""
@ -128,6 +130,10 @@ class Graphiti:
self.llm_client = llm_client
else:
self.llm_client = OpenAIClient()
if embedder:
self.embedder = embedder
else:
self.embedder = OpenAIEmbedder()
async def close(self):
"""
@ -290,7 +296,6 @@ class Graphiti:
start = time()
entity_edges: list[EntityEdge] = []
embedder = self.llm_client.get_embedder()
now = datetime.now()
previous_episodes = await self.retrieve_episodes(
@ -318,7 +323,7 @@ class Graphiti:
# Calculate Embeddings
await asyncio.gather(
*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes]
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
)
# Resolve extracted nodes with nodes already in the graph and extract facts
@ -346,7 +351,7 @@ class Graphiti:
# calculate embeddings
await asyncio.gather(
*[
edge.generate_embedding(embedder, self.llm_client.embedding_model)
edge.generate_embedding(self.embedder)
for edge in extracted_edges_with_resolved_pointers
]
)
@ -439,7 +444,7 @@ class Graphiti:
if update_communities:
await asyncio.gather(
*[
update_community(self.driver, self.llm_client, embedder, node)
update_community(self.driver, self.llm_client, self.embedder, node)
for node in nodes
]
)
@ -488,7 +493,6 @@ class Graphiti:
"""
try:
start = time()
embedder = self.llm_client.get_embedder()
now = datetime.now()
episodes = [
@ -520,8 +524,8 @@ class Graphiti:
# Generate embeddings
await asyncio.gather(
*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes],
*[edge.generate_embedding(embedder, self.llm_client.embedding_model) for edge in extracted_edges],
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
*[edge.generate_embedding(self.embedder) for edge in extracted_edges],
)
# Dedupe extracted nodes, compress extracted edges
@ -564,14 +568,14 @@ class Graphiti:
raise e
async def build_communities(self):
embedder = self.llm_client.get_embedder()
# Clear existing communities
await remove_communities(self.driver)
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
await asyncio.gather(*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in community_nodes])
await asyncio.gather(
*[node.generate_name_embedding(self.embedder) for node in community_nodes]
)
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
@ -618,12 +622,11 @@ class Graphiti:
EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
)
search_config.limit = num_results
search_config.embedding_model = self.llm_client.embedding_model
edges = (
await search(
self.driver,
self.llm_client.get_embedder(),
self.embedder,
query,
group_ids,
search_config,
@ -640,9 +643,7 @@ class Graphiti:
group_ids: list[str] | None = None,
center_node_uuid: str | None = None,
) -> SearchResults:
return await search(
self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
)
return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)
async def get_nodes_by_query(
self,
@ -687,14 +688,15 @@ class Graphiti:
to each individual search method before results are combined and deduplicated.
If not specified, a default limit (defined in the search functions) will be used.
"""
embedder = self.llm_client.get_embedder()
search_config = (
NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
)
search_config.limit = limit
nodes = (
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
await search(
self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
)
).nodes
return nodes

View file

@ -20,7 +20,6 @@ import typing
import anthropic
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
from ..prompts.models import Message
from .client import LLMClient
@ -47,10 +46,6 @@ class AnthropicClient(LLMClient):
max_retries=1,
)
def get_embedder(self) -> typing.Any:
openai_client = AsyncOpenAI()
return openai_client.embeddings
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
system_message = messages[0]
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [

View file

@ -54,11 +54,6 @@ class LLMClient(ABC):
self.max_tokens = config.max_tokens
self.cache_enabled = cache
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
self.embedding_model = config.embedding_model
@abstractmethod
def get_embedder(self) -> typing.Any:
pass
@retry(
stop=stop_after_attempt(4),

View file

@ -14,10 +14,9 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
EMBEDDING_DIM = 1024
DEFAULT_MAX_TOKENS = 16384
DEFAULT_TEMPERATURE = 0
DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'
class LLMConfig:
"""
@ -33,7 +32,6 @@ class LLMConfig:
api_key: str | None = None,
model: str | None = None,
base_url: str | None = None,
embedding_model: str | None = DEFAULT_EMBEDDING_MODEL,
temperature: float = DEFAULT_TEMPERATURE,
max_tokens: int = DEFAULT_MAX_TOKENS,
):
@ -51,15 +49,9 @@ class LLMConfig:
base_url (str, optional): The base URL of the LLM API service.
Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint.
This can be changed if using a different provider or a custom endpoint.
embedding_model (str, optional): The specific embedding model.
Defaults to openai "text-embedding-3-small" model.
We currently only support openai embedding models, such as "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002".
"""
self.base_url = base_url
self.api_key = api_key
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
if not embedding_model:
embedding_model = DEFAULT_EMBEDDING_MODEL
self.embedding_model = embedding_model

View file

@ -21,7 +21,6 @@ import typing
import groq
from groq import AsyncGroq
from groq.types.chat import ChatCompletionMessageParam
from openai import AsyncOpenAI
from ..prompts.models import Message
from .client import LLMClient
@ -44,10 +43,6 @@ class GroqClient(LLMClient):
self.client = AsyncGroq(api_key=config.api_key)
def get_embedder(self) -> typing.Any:
openai_client = AsyncOpenAI()
return openai_client.embeddings
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
msgs: list[ChatCompletionMessageParam] = []
for m in messages:

View file

@ -49,9 +49,6 @@ class OpenAIClient(LLMClient):
__init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
Initializes the OpenAIClient with the provided configuration, cache setting, and client.
get_embedder() -> typing.Any:
Returns the embedder from the OpenAI client.
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
Generates a response from the language model based on the provided messages.
"""
@ -78,9 +75,6 @@ class OpenAIClient(LLMClient):
else:
self.client = client
def get_embedder(self) -> typing.Any:
return self.client.embeddings
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
openai_messages: list[ChatCompletionMessageParam] = []
for m in messages:

View file

@ -15,22 +15,18 @@ limitations under the License.
"""
import logging
import typing
from time import time
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
from graphiti_core.embedder.client import EmbedderClient
logger = logging.getLogger(__name__)
async def generate_embedding(
embedder: typing.Any, text: str, model: str = DEFAULT_EMBEDDING_MODEL
):
async def generate_embedding(embedder: EmbedderClient, text: str):
start = time()
text = text.replace('\n', ' ')
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
embedding = embedding[:EMBEDDING_DIM]
embedding = await embedder.create(input=[text])
end = time()
logger.debug(f'embedded text of length {len(text)} in {end - start} ms')

View file

@ -25,8 +25,8 @@ from uuid import uuid4
from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
logger = logging.getLogger(__name__)
@ -212,15 +212,14 @@ class EntityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
async def generate_name_embedding(self, embedder: EmbedderClient):
start = time()
text = self.name.replace('\n', ' ')
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.name_embedding = embedding[:EMBEDDING_DIM]
self.name_embedding = await embedder.create(input=[text])
end = time()
logger.info(f'embedded {text} in {end - start} ms')
return embedding
return self.name_embedding
async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
@ -323,15 +322,14 @@ class CommunityNode(Node):
return result
async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
async def generate_name_embedding(self, embedder: EmbedderClient):
start = time()
text = self.name.replace('\n', ' ')
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.name_embedding = embedding[:EMBEDDING_DIM]
self.name_embedding = await embedder.create(input=[text])
end = time()
logger.info(f'embedded {text} in {end - start} ms')
return embedding
return self.name_embedding
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):

View file

@ -22,8 +22,8 @@ from time import time
from neo4j import AsyncDriver
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import SearchRerankerError
from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.search.search_config import (
DEFAULT_SEARCH_LIMIT,
@ -56,7 +56,7 @@ logger = logging.getLogger(__name__)
async def search(
driver: AsyncDriver,
embedder,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: SearchConfig,
@ -75,7 +75,6 @@ async def search(
config.edge_config,
center_node_uuid,
config.limit,
config.embedding_model,
),
node_search(
driver,
@ -85,7 +84,6 @@ async def search(
config.node_config,
center_node_uuid,
config.limit,
config.embedding_model,
),
community_search(
driver,
@ -94,7 +92,6 @@ async def search(
group_ids,
config.community_config,
config.limit,
config.embedding_model,
),
)
@ -113,13 +110,12 @@ async def search(
async def edge_search(
driver: AsyncDriver,
embedder,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[EntityEdge]:
if config is None:
return []
@ -131,11 +127,7 @@ async def edge_search(
search_results.append(text_search)
if EdgeSearchMethod.cosine_similarity in config.search_methods:
search_vector = (
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
.data[0]
.embedding[:EMBEDDING_DIM]
)
search_vector = await embedder.create(input=[query])
similarity_search = await edge_similarity_search(
driver, search_vector, None, None, group_ids, 2 * limit
@ -182,13 +174,12 @@ async def edge_search(
async def node_search(
driver: AsyncDriver,
embedder,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: NodeSearchConfig | None,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[EntityNode]:
if config is None:
return []
@ -200,11 +191,7 @@ async def node_search(
search_results.append(text_search)
if NodeSearchMethod.cosine_similarity in config.search_methods:
search_vector = (
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
.data[0]
.embedding[:EMBEDDING_DIM]
)
search_vector = await embedder.create(input=[query])
similarity_search = await node_similarity_search(
driver, search_vector, group_ids, 2 * limit
@ -236,12 +223,11 @@ async def node_search(
async def community_search(
driver: AsyncDriver,
embedder,
embedder: EmbedderClient,
query: str,
group_ids: list[str] | None,
config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT,
embedding_model: str | None = None,
) -> list[CommunityNode]:
if config is None:
return []
@ -253,11 +239,7 @@ async def community_search(
search_results.append(text_search)
if CommunitySearchMethod.cosine_similarity in config.search_methods:
search_vector = (
(await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
.data[0]
.embedding[:EMBEDDING_DIM]
)
search_vector = await embedder.create(input=[query])
similarity_search = await community_similarity_search(
driver, search_vector, group_ids, 2 * limit

View file

@ -74,7 +74,6 @@ class SearchConfig(BaseModel):
edge_config: EdgeSearchConfig | None = Field(default=None)
node_config: NodeSearchConfig | None = Field(default=None)
community_config: CommunitySearchConfig | None = Field(default=None)
embedding_model: str | None = Field(default=None)
limit: int = Field(default=DEFAULT_SEARCH_LIMIT)

View file

@ -7,6 +7,7 @@ from neo4j import AsyncDriver
from pydantic import BaseModel
from graphiti_core.edges import CommunityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
from graphiti_core.prompts import prompt_library
@ -288,7 +289,7 @@ async def determine_entity_community(
async def update_community(
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
driver: AsyncDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
):
community, is_new = await determine_entity_community(driver, entity)
@ -305,6 +306,6 @@ async def update_community(
community_edge = (build_community_edges([entity], community, datetime.now()))[0]
await community_edge.save(driver)
await community.generate_name_embedding(embedder, llm_client.embedding_model)
await community.generate_name_embedding(embedder)
await community.save(driver)

View file

@ -25,7 +25,7 @@ class ZepGraphiti(Graphiti):
group_id=group_id,
summary=summary,
)
await new_node.generate_name_embedding(self.llm_client.get_embedder(), self.llm_client.embedding_model)
await new_node.generate_name_embedding(self.embedder)
await new_node.save(self.driver)
return new_node
@ -83,8 +83,7 @@ async def get_graphiti(settings: ZepEnvDep):
client.llm_client.config.api_key = settings.openai_api_key
if settings.model_name is not None:
client.llm_client.model = settings.model_name
if settings.embedding_model_name is not None:
client.llm_client.embedding_model = settings.embedding_model_name
try:
yield client
finally:

View file

@ -101,7 +101,7 @@ async def test_graphiti_init():
@pytest.mark.asyncio
async def test_graph_integration():
client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
embedder = client.llm_client.get_embedder()
embedder = client.embedder
driver = client.driver
now = datetime.now()
@ -145,7 +145,7 @@ async def test_graph_integration():
invalid_at=now,
)
await entity_edge.generate_embedding(embedder, client.llm_client.embedding_model)
await entity_edge.generate_embedding(embedder)
nodes = [episode, alice_node, bob_node]
edges = [episodic_edge_1, episodic_edge_2, entity_edge]