diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index d515cd15..b5903f31 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -24,9 +24,9 @@ from uuid import uuid4 from neo4j import AsyncDriver from pydantic import BaseModel, Field +from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError from graphiti_core.helpers import parse_db_date -from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM from graphiti_core.nodes import Node logger = logging.getLogger(__name__) @@ -171,17 +171,16 @@ class EntityEdge(Edge): default=None, description='datetime of when the fact stopped being true' ) - async def generate_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL): + async def generate_embedding(self, embedder: EmbedderClient): start = time() text = self.fact.replace('\n', ' ') - embedding = (await embedder.create(input=[text], model=model)).data[0].embedding - self.fact_embedding = embedding[:EMBEDDING_DIM] + self.fact_embedding = await embedder.create(input=[text]) end = time() logger.info(f'embedded {text} in {end - start} ms') - return embedding + return self.fact_embedding async def save(self, driver: AsyncDriver): result = await driver.execute_query( diff --git a/graphiti_core/embedder/__init__.py b/graphiti_core/embedder/__init__.py new file mode 100644 index 00000000..9e952726 --- /dev/null +++ b/graphiti_core/embedder/__init__.py @@ -0,0 +1,4 @@ +from .client import EmbedderClient +from .openai import OpenAIEmbedder, OpenAIEmbedderConfig + +__all__ = ['EmbedderClient', 'OpenAIEmbedder', 'OpenAIEmbedderConfig'] diff --git a/graphiti_core/embedder/client.py b/graphiti_core/embedder/client.py new file mode 100644 index 00000000..950298e4 --- /dev/null +++ b/graphiti_core/embedder/client.py @@ -0,0 +1,34 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from abc import ABC, abstractmethod +from typing import Iterable, List, Literal + +from pydantic import BaseModel, Field + +EMBEDDING_DIM = 1024 + + +class EmbedderConfig(BaseModel): + embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True) + + +class EmbedderClient(ABC): + @abstractmethod + async def create( + self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]] + ) -> list[float]: + pass diff --git a/graphiti_core/embedder/openai.py b/graphiti_core/embedder/openai.py new file mode 100644 index 00000000..a209dba1 --- /dev/null +++ b/graphiti_core/embedder/openai.py @@ -0,0 +1,48 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Iterable, List + +from openai import AsyncOpenAI +from openai.types import EmbeddingModel + +from .client import EmbedderClient, EmbedderConfig + +DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small' + + +class OpenAIEmbedderConfig(EmbedderConfig): + embedding_model: EmbeddingModel | str = DEFAULT_EMBEDDING_MODEL + api_key: str | None = None + base_url: str | None = None + + +class OpenAIEmbedder(EmbedderClient): + """ + OpenAI Embedder Client + """ + + def __init__(self, config: OpenAIEmbedderConfig | None = None): + if config is None: + config = OpenAIEmbedderConfig() + self.config = config + self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) + + async def create( + self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]] + ) -> list[float]: + result = await self.client.embeddings.create(input=input, model=self.config.embedding_model) + return result.data[0].embedding[: self.config.embedding_dim] diff --git a/graphiti_core/embedder/voyage.py b/graphiti_core/embedder/voyage.py new file mode 100644 index 00000000..f0fca309 --- /dev/null +++ b/graphiti_core/embedder/voyage.py @@ -0,0 +1,47 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Iterable, List + +import voyageai # type: ignore +from pydantic import Field + +from .client import EmbedderClient, EmbedderConfig + +DEFAULT_EMBEDDING_MODEL = 'voyage-3' + + +class VoyageAIEmbedderConfig(EmbedderConfig): + embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL) + api_key: str | None = None + + +class VoyageAIEmbedder(EmbedderClient): + """ + VoyageAI Embedder Client + """ + + def __init__(self, config: VoyageAIEmbedderConfig | None = None): + if config is None: + config = VoyageAIEmbedderConfig() + self.config = config + self.client = voyageai.AsyncClient(api_key=config.api_key) + + async def create( + self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]] + ) -> list[float]: + result = await self.client.embed(input, model=self.config.embedding_model) + return result.embeddings[0][: self.config.embedding_dim] diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 4b104698..5d0c6ba4 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -23,6 +23,7 @@ from dotenv import load_dotenv from neo4j import AsyncGraphDatabase from graphiti_core.edges import EntityEdge, EpisodicEdge +from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.llm_client import LLMClient, OpenAIClient from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.search.search import SearchConfig, search @@ -83,6 +84,7 @@ class Graphiti: user: str, password: str, llm_client: LLMClient | None = None, + embedder: EmbedderClient | None = None, store_raw_episode_content: bool = True, ): """ @@ -128,6 +130,10 @@ class Graphiti: self.llm_client = llm_client else: self.llm_client = OpenAIClient() + if embedder: + self.embedder = embedder + else: + self.embedder = OpenAIEmbedder() async def close(self): """ @@ -290,7 +296,6 @@ class Graphiti: start = time() entity_edges: list[EntityEdge] = [] - embedder = self.llm_client.get_embedder() now = datetime.now() previous_episodes = await self.retrieve_episodes( @@ -318,7 +323,7 @@ class Graphiti: # Calculate Embeddings await asyncio.gather( - *[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes] + *[node.generate_name_embedding(self.embedder) for node in extracted_nodes] ) # Resolve extracted nodes with nodes already in the graph and extract facts @@ -346,7 +351,7 @@ class Graphiti: # calculate embeddings await asyncio.gather( *[ - edge.generate_embedding(embedder, self.llm_client.embedding_model) + edge.generate_embedding(self.embedder) for edge in extracted_edges_with_resolved_pointers ] ) @@ -439,7 +444,7 @@ class Graphiti: if update_communities: await asyncio.gather( *[ - update_community(self.driver, self.llm_client, embedder, node) + update_community(self.driver, self.llm_client, self.embedder, node) for node in nodes ] ) @@ -488,7 +493,6 @@ class Graphiti: """ try: start = time() - embedder = self.llm_client.get_embedder() now = datetime.now() episodes = [ @@ -520,8 +524,8 @@ class Graphiti: # Generate embeddings await asyncio.gather( - *[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes], - *[edge.generate_embedding(embedder, self.llm_client.embedding_model) for edge in extracted_edges], + *[node.generate_name_embedding(self.embedder) for node in extracted_nodes], + *[edge.generate_embedding(self.embedder) for edge in extracted_edges], ) # Dedupe extracted nodes, compress extracted edges @@ -564,14 +568,14 @@ class Graphiti: raise e async def build_communities(self): - embedder = self.llm_client.get_embedder() - # Clear existing communities await remove_communities(self.driver) community_nodes, community_edges = await build_communities(self.driver, self.llm_client) - await asyncio.gather(*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in community_nodes]) + await asyncio.gather( + *[node.generate_name_embedding(self.embedder) for node in community_nodes] + ) await asyncio.gather(*[node.save(self.driver) for node in community_nodes]) await asyncio.gather(*[edge.save(self.driver) for edge in community_edges]) @@ -618,12 +622,11 @@ class Graphiti: EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE ) search_config.limit = num_results - search_config.embedding_model = self.llm_client.embedding_model edges = ( await search( self.driver, - self.llm_client.get_embedder(), + self.embedder, query, group_ids, search_config, @@ -640,9 +643,7 @@ class Graphiti: group_ids: list[str] | None = None, center_node_uuid: str | None = None, ) -> SearchResults: - return await search( - self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid - ) + return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid) async def get_nodes_by_query( self, @@ -687,14 +688,15 @@ class Graphiti: to each individual search method before results are combined and deduplicated. If not specified, a default limit (defined in the search functions) will be used. """ - embedder = self.llm_client.get_embedder() search_config = ( NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE ) search_config.limit = limit nodes = ( - await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid) + await search( + self.driver, self.embedder, query, group_ids, search_config, center_node_uuid + ) ).nodes return nodes diff --git a/graphiti_core/llm_client/anthropic_client.py b/graphiti_core/llm_client/anthropic_client.py index ee186f1e..ec6e88ff 100644 --- a/graphiti_core/llm_client/anthropic_client.py +++ b/graphiti_core/llm_client/anthropic_client.py @@ -20,7 +20,6 @@ import typing import anthropic from anthropic import AsyncAnthropic -from openai import AsyncOpenAI from ..prompts.models import Message from .client import LLMClient @@ -47,10 +46,6 @@ class AnthropicClient(LLMClient): max_retries=1, ) - def get_embedder(self) -> typing.Any: - openai_client = AsyncOpenAI() - return openai_client.embeddings - async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: system_message = messages[0] user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [ diff --git a/graphiti_core/llm_client/client.py b/graphiti_core/llm_client/client.py index 9da42ce4..7886c7f8 100644 --- a/graphiti_core/llm_client/client.py +++ b/graphiti_core/llm_client/client.py @@ -54,11 +54,6 @@ class LLMClient(ABC): self.max_tokens = config.max_tokens self.cache_enabled = cache self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory - self.embedding_model = config.embedding_model - - @abstractmethod - def get_embedder(self) -> typing.Any: - pass @retry( stop=stop_after_attempt(4), diff --git a/graphiti_core/llm_client/config.py b/graphiti_core/llm_client/config.py index 1f584245..5e346367 100644 --- a/graphiti_core/llm_client/config.py +++ b/graphiti_core/llm_client/config.py @@ -14,10 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. """ -EMBEDDING_DIM = 1024 DEFAULT_MAX_TOKENS = 16384 DEFAULT_TEMPERATURE = 0 -DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small' + class LLMConfig: """ @@ -33,7 +32,6 @@ class LLMConfig: api_key: str | None = None, model: str | None = None, base_url: str | None = None, - embedding_model: str | None = DEFAULT_EMBEDDING_MODEL, temperature: float = DEFAULT_TEMPERATURE, max_tokens: int = DEFAULT_MAX_TOKENS, ): @@ -51,15 +49,9 @@ class LLMConfig: base_url (str, optional): The base URL of the LLM API service. Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint. This can be changed if using a different provider or a custom endpoint. - embedding_model (str, optional): The specific embedding model. - Defaults to openai "text-embedding-3-small" model. - We currently only support openai embedding models, such as "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002". """ self.base_url = base_url self.api_key = api_key self.model = model self.temperature = temperature self.max_tokens = max_tokens - if not embedding_model: - embedding_model = DEFAULT_EMBEDDING_MODEL - self.embedding_model = embedding_model diff --git a/graphiti_core/llm_client/groq_client.py b/graphiti_core/llm_client/groq_client.py index 673b8db1..9f59e621 100644 --- a/graphiti_core/llm_client/groq_client.py +++ b/graphiti_core/llm_client/groq_client.py @@ -21,7 +21,6 @@ import typing import groq from groq import AsyncGroq from groq.types.chat import ChatCompletionMessageParam -from openai import AsyncOpenAI from ..prompts.models import Message from .client import LLMClient @@ -44,10 +43,6 @@ class GroqClient(LLMClient): self.client = AsyncGroq(api_key=config.api_key) - def get_embedder(self) -> typing.Any: - openai_client = AsyncOpenAI() - return openai_client.embeddings - async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: msgs: list[ChatCompletionMessageParam] = [] for m in messages: diff --git a/graphiti_core/llm_client/openai_client.py b/graphiti_core/llm_client/openai_client.py index f459a3f4..957317cc 100644 --- a/graphiti_core/llm_client/openai_client.py +++ b/graphiti_core/llm_client/openai_client.py @@ -49,9 +49,6 @@ class OpenAIClient(LLMClient): __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None): Initializes the OpenAIClient with the provided configuration, cache setting, and client. - get_embedder() -> typing.Any: - Returns the embedder from the OpenAI client. - _generate_response(messages: list[Message]) -> dict[str, typing.Any]: Generates a response from the language model based on the provided messages. """ @@ -78,9 +75,6 @@ class OpenAIClient(LLMClient): else: self.client = client - def get_embedder(self) -> typing.Any: - return self.client.embeddings - async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: openai_messages: list[ChatCompletionMessageParam] = [] for m in messages: diff --git a/graphiti_core/llm_client/utils.py b/graphiti_core/llm_client/utils.py index c26a746d..2e367bad 100644 --- a/graphiti_core/llm_client/utils.py +++ b/graphiti_core/llm_client/utils.py @@ -15,22 +15,18 @@ limitations under the License. """ import logging -import typing from time import time -from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM +from graphiti_core.embedder.client import EmbedderClient logger = logging.getLogger(__name__) -async def generate_embedding( - embedder: typing.Any, text: str, model: str = DEFAULT_EMBEDDING_MODEL -): +async def generate_embedding(embedder: EmbedderClient, text: str): start = time() text = text.replace('\n', ' ') - embedding = (await embedder.create(input=[text], model=model)).data[0].embedding - embedding = embedding[:EMBEDDING_DIM] + embedding = await embedder.create(input=[text]) end = time() logger.debug(f'embedded text of length {len(text)} in {end - start} ms') diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index a2d13bdd..2635ae89 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -25,8 +25,8 @@ from uuid import uuid4 from neo4j import AsyncDriver from pydantic import BaseModel, Field +from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import NodeNotFoundError -from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM logger = logging.getLogger(__name__) @@ -212,15 +212,14 @@ class EntityNode(Node): name_embedding: list[float] | None = Field(default=None, description='embedding of the name') summary: str = Field(description='regional summary of surrounding edges', default_factory=str) - async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL): + async def generate_name_embedding(self, embedder: EmbedderClient): start = time() text = self.name.replace('\n', ' ') - embedding = (await embedder.create(input=[text], model=model)).data[0].embedding - self.name_embedding = embedding[:EMBEDDING_DIM] + self.name_embedding = await embedder.create(input=[text]) end = time() logger.info(f'embedded {text} in {end - start} ms') - return embedding + return self.name_embedding async def save(self, driver: AsyncDriver): result = await driver.execute_query( @@ -323,15 +322,14 @@ class CommunityNode(Node): return result - async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL): + async def generate_name_embedding(self, embedder: EmbedderClient): start = time() text = self.name.replace('\n', ' ') - embedding = (await embedder.create(input=[text], model=model)).data[0].embedding - self.name_embedding = embedding[:EMBEDDING_DIM] + self.name_embedding = await embedder.create(input=[text]) end = time() logger.info(f'embedded {text} in {end - start} ms') - return embedding + return self.name_embedding @classmethod async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 070efcb3..daf80f3e 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -22,8 +22,8 @@ from time import time from neo4j import AsyncDriver from graphiti_core.edges import EntityEdge +from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import SearchRerankerError -from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM from graphiti_core.nodes import CommunityNode, EntityNode from graphiti_core.search.search_config import ( DEFAULT_SEARCH_LIMIT, @@ -56,7 +56,7 @@ logger = logging.getLogger(__name__) async def search( driver: AsyncDriver, - embedder, + embedder: EmbedderClient, query: str, group_ids: list[str] | None, config: SearchConfig, @@ -75,7 +75,6 @@ async def search( config.edge_config, center_node_uuid, config.limit, - config.embedding_model, ), node_search( driver, @@ -85,7 +84,6 @@ async def search( config.node_config, center_node_uuid, config.limit, - config.embedding_model, ), community_search( driver, @@ -94,7 +92,6 @@ async def search( group_ids, config.community_config, config.limit, - config.embedding_model, ), ) @@ -113,13 +110,12 @@ async def search( async def edge_search( driver: AsyncDriver, - embedder, + embedder: EmbedderClient, query: str, group_ids: list[str] | None, config: EdgeSearchConfig | None, center_node_uuid: str | None = None, limit=DEFAULT_SEARCH_LIMIT, - embedding_model: str | None = None, ) -> list[EntityEdge]: if config is None: return [] @@ -131,11 +127,7 @@ async def edge_search( search_results.append(text_search) if EdgeSearchMethod.cosine_similarity in config.search_methods: - search_vector = ( - (await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL)) - .data[0] - .embedding[:EMBEDDING_DIM] - ) + search_vector = await embedder.create(input=[query]) similarity_search = await edge_similarity_search( driver, search_vector, None, None, group_ids, 2 * limit @@ -182,13 +174,12 @@ async def edge_search( async def node_search( driver: AsyncDriver, - embedder, + embedder: EmbedderClient, query: str, group_ids: list[str] | None, config: NodeSearchConfig | None, center_node_uuid: str | None = None, limit=DEFAULT_SEARCH_LIMIT, - embedding_model: str | None = None, ) -> list[EntityNode]: if config is None: return [] @@ -200,11 +191,7 @@ async def node_search( search_results.append(text_search) if NodeSearchMethod.cosine_similarity in config.search_methods: - search_vector = ( - (await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL)) - .data[0] - .embedding[:EMBEDDING_DIM] - ) + search_vector = await embedder.create(input=[query]) similarity_search = await node_similarity_search( driver, search_vector, group_ids, 2 * limit @@ -236,12 +223,11 @@ async def node_search( async def community_search( driver: AsyncDriver, - embedder, + embedder: EmbedderClient, query: str, group_ids: list[str] | None, config: CommunitySearchConfig | None, limit=DEFAULT_SEARCH_LIMIT, - embedding_model: str | None = None, ) -> list[CommunityNode]: if config is None: return [] @@ -253,11 +239,7 @@ async def community_search( search_results.append(text_search) if CommunitySearchMethod.cosine_similarity in config.search_methods: - search_vector = ( - (await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL)) - .data[0] - .embedding[:EMBEDDING_DIM] - ) + search_vector = await embedder.create(input=[query]) similarity_search = await community_similarity_search( driver, search_vector, group_ids, 2 * limit diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py index e0c1127f..ceb644b9 100644 --- a/graphiti_core/search/search_config.py +++ b/graphiti_core/search/search_config.py @@ -74,7 +74,6 @@ class SearchConfig(BaseModel): edge_config: EdgeSearchConfig | None = Field(default=None) node_config: NodeSearchConfig | None = Field(default=None) community_config: CommunitySearchConfig | None = Field(default=None) - embedding_model: str | None = Field(default=None) limit: int = Field(default=DEFAULT_SEARCH_LIMIT) diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index 88dba26f..eb637617 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -7,6 +7,7 @@ from neo4j import AsyncDriver from pydantic import BaseModel from graphiti_core.edges import CommunityEdge +from graphiti_core.embedder import EmbedderClient from graphiti_core.llm_client import LLMClient from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record from graphiti_core.prompts import prompt_library @@ -288,7 +289,7 @@ async def determine_entity_community( async def update_community( - driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode + driver: AsyncDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode ): community, is_new = await determine_entity_community(driver, entity) @@ -305,6 +306,6 @@ async def update_community( community_edge = (build_community_edges([entity], community, datetime.now()))[0] await community_edge.save(driver) - await community.generate_name_embedding(embedder, llm_client.embedding_model) + await community.generate_name_embedding(embedder) await community.save(driver) diff --git a/server/graph_service/zep_graphiti.py b/server/graph_service/zep_graphiti.py index 66d74cfe..097c9f39 100644 --- a/server/graph_service/zep_graphiti.py +++ b/server/graph_service/zep_graphiti.py @@ -25,7 +25,7 @@ class ZepGraphiti(Graphiti): group_id=group_id, summary=summary, ) - await new_node.generate_name_embedding(self.llm_client.get_embedder(), self.llm_client.embedding_model) + await new_node.generate_name_embedding(self.embedder) await new_node.save(self.driver) return new_node @@ -83,8 +83,7 @@ async def get_graphiti(settings: ZepEnvDep): client.llm_client.config.api_key = settings.openai_api_key if settings.model_name is not None: client.llm_client.model = settings.model_name - if settings.embedding_model_name is not None: - client.llm_client.embedding_model = settings.embedding_model_name + try: yield client finally: diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index f9a94201..4c0a34b5 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -101,7 +101,7 @@ async def test_graphiti_init(): @pytest.mark.asyncio async def test_graph_integration(): client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) - embedder = client.llm_client.get_embedder() + embedder = client.embedder driver = client.driver now = datetime.now() @@ -145,7 +145,7 @@ async def test_graph_integration(): invalid_at=now, ) - await entity_edge.generate_embedding(embedder, client.llm_client.embedding_model) + await entity_edge.generate_embedding(embedder) nodes = [episode, alice_node, bob_node] edges = [episodic_edge_1, episodic_edge_2, entity_edge]