diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py
index 2b326a1c..d515cd15 100644
--- a/graphiti_core/edges.py
+++ b/graphiti_core/edges.py
@@ -26,7 +26,7 @@ from pydantic import BaseModel, Field
 
 from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
 from graphiti_core.helpers import parse_db_date
-from graphiti_core.llm_client.config import EMBEDDING_DIM
+from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
 from graphiti_core.nodes import Node
 
 logger = logging.getLogger(__name__)
@@ -171,7 +171,7 @@ class EntityEdge(Edge):
         default=None, description='datetime of when the fact stopped being true'
     )
 
-    async def generate_embedding(self, embedder, model='text-embedding-3-small'):
+    async def generate_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
         start = time()
 
         text = self.fact.replace('\n', ' ')
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 540c9ff1..0037c74e 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -315,7 +315,7 @@ class Graphiti:
 
         # Calculate Embeddings
         await asyncio.gather(
-            *[node.generate_name_embedding(embedder) for node in extracted_nodes]
+            *[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes]
         )
 
         # Resolve extracted nodes with nodes already in the graph and extract facts
@@ -343,7 +343,7 @@ class Graphiti:
         # calculate embeddings
         await asyncio.gather(
             *[
-                edge.generate_embedding(embedder)
+                edge.generate_embedding(embedder, self.llm_client.embedding_model)
                 for edge in extracted_edges_with_resolved_pointers
             ]
         )
@@ -517,8 +517,8 @@ class Graphiti:
 
             # Generate embeddings
             await asyncio.gather(
-                *[node.generate_name_embedding(embedder) for node in extracted_nodes],
-                *[edge.generate_embedding(embedder) for edge in extracted_edges],
+                *[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in extracted_nodes],
+                *[edge.generate_embedding(embedder, self.llm_client.embedding_model) for edge in extracted_edges],
             )
 
             # Dedupe extracted nodes, compress extracted edges
@@ -568,7 +568,7 @@ class Graphiti:
 
         community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
 
-        await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
+        await asyncio.gather(*[node.generate_name_embedding(embedder, self.llm_client.embedding_model) for node in community_nodes])
 
         await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
         await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
@@ -615,6 +615,7 @@ class Graphiti:
             EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
         )
         search_config.limit = num_results
+        search_config.embedding_model = self.llm_client.embedding_model
 
         edges = (
             await search(
diff --git a/graphiti_core/llm_client/client.py b/graphiti_core/llm_client/client.py
index fe5e9177..9da42ce4 100644
--- a/graphiti_core/llm_client/client.py
+++ b/graphiti_core/llm_client/client.py
@@ -54,6 +54,7 @@ class LLMClient(ABC):
         self.max_tokens = config.max_tokens
         self.cache_enabled = cache
         self.cache_dir = Cache(DEFAULT_CACHE_DIR)  # Create a cache directory
+        self.embedding_model = config.embedding_model
 
     @abstractmethod
     def get_embedder(self) -> typing.Any:
diff --git a/graphiti_core/llm_client/config.py b/graphiti_core/llm_client/config.py
index cedfe7c8..26274771 100644
--- a/graphiti_core/llm_client/config.py
+++ b/graphiti_core/llm_client/config.py
@@ -17,7 +17,7 @@ limitations under the License.
 
 EMBEDDING_DIM = 1024
 DEFAULT_MAX_TOKENS = 16384
 DEFAULT_TEMPERATURE = 0
-
+DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'
 
 class LLMConfig:
@@ -33,6 +33,7 @@ class LLMConfig:
         api_key: str | None = None,
         model: str | None = None,
         base_url: str | None = None,
+        embedding_model: str | None = DEFAULT_EMBEDDING_MODEL,
         temperature: float = DEFAULT_TEMPERATURE,
         max_tokens: int = DEFAULT_MAX_TOKENS,
     ):
@@ -50,9 +51,15 @@ class LLMConfig:
             base_url (str, optional): The base URL of the LLM API service.
                 Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint.
                 This can be changed if using a different provider or a custom endpoint.
+            embedding_model (str, optional): The embedding model to use when generating embeddings.
+                Defaults to "text-embedding-3-small", OpenAI's small embedding model.
+                Other common values include "text-embedding-3-large" and "text-embedding-ada-002".
         """
         self.base_url = base_url
         self.api_key = api_key
         self.model = model
         self.temperature = temperature
         self.max_tokens = max_tokens
+        if not embedding_model:
+            embedding_model = DEFAULT_EMBEDDING_MODEL
+        self.embedding_model = embedding_model
diff --git a/graphiti_core/llm_client/utils.py b/graphiti_core/llm_client/utils.py
index d98b49b9..c26a746d 100644
--- a/graphiti_core/llm_client/utils.py
+++ b/graphiti_core/llm_client/utils.py
@@ -18,13 +18,13 @@ import logging
 import typing
 from time import time
 
-from graphiti_core.llm_client.config import EMBEDDING_DIM
+from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
 
 logger = logging.getLogger(__name__)
 
 
 async def generate_embedding(
-    embedder: typing.Any, text: str, model: str = 'text-embedding-3-small'
+    embedder: typing.Any, text: str, model: str = DEFAULT_EMBEDDING_MODEL
 ):
     start = time()
diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py
index 828a7ebb..a2d13bdd 100644
--- a/graphiti_core/nodes.py
+++ b/graphiti_core/nodes.py
@@ -26,7 +26,7 @@ from neo4j import AsyncDriver
 from pydantic import BaseModel, Field
 
 from graphiti_core.errors import NodeNotFoundError
-from graphiti_core.llm_client.config import EMBEDDING_DIM
+from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
 
 logger = logging.getLogger(__name__)
@@ -212,7 +212,7 @@ class EntityNode(Node):
     name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
     summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
 
-    async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
+    async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
         start = time()
         text = self.name.replace('\n', ' ')
         embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
@@ -323,7 +323,7 @@ class CommunityNode(Node):
 
         return result
 
-    async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
+    async def generate_name_embedding(self, embedder, model=DEFAULT_EMBEDDING_MODEL):
         start = time()
         text = self.name.replace('\n', ' ')
         embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index 862ececd..9a5fa311 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -22,7 +22,7 @@ from neo4j import AsyncDriver
 
 from graphiti_core.edges import EntityEdge
 from graphiti_core.errors import SearchRerankerError
-from graphiti_core.llm_client.config import EMBEDDING_DIM
+from graphiti_core.llm_client.config import DEFAULT_EMBEDDING_MODEL, EMBEDDING_DIM
 from graphiti_core.nodes import CommunityNode, EntityNode
 from graphiti_core.search.search_config import (
     DEFAULT_SEARCH_LIMIT,
@@ -67,21 +67,21 @@ async def search(
     group_ids = group_ids if group_ids else None
     edges = (
         await edge_search(
-            driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit
+            driver, embedder, query, group_ids, config.edge_config, center_node_uuid, config.limit, config.embedding_model
         )
         if config.edge_config is not None
         else []
     )
     nodes = (
         await node_search(
-            driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit
+            driver, embedder, query, group_ids, config.node_config, center_node_uuid, config.limit, config.embedding_model
        )
         if config.node_config is not None
         else []
     )
     communities = (
         await community_search(
-            driver, embedder, query, group_ids, config.community_config, config.limit
+            driver, embedder, query, group_ids, config.community_config, config.limit, config.embedding_model
         )
         if config.community_config is not None
         else []
@@ -108,6 +108,7 @@ async def edge_search(
     config: EdgeSearchConfig,
     center_node_uuid: str | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
+    embedding_model: str | None = None,
 ) -> list[EntityEdge]:
     search_results: list[list[EntityEdge]] = []
 
@@ -117,7 +118,7 @@ async def edge_search(
 
     if EdgeSearchMethod.cosine_similarity in config.search_methods:
         search_vector = (
-            (await embedder.create(input=[query], model='text-embedding-3-small'))
+            (await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
             .data[0]
             .embedding[:EMBEDDING_DIM]
         )
@@ -173,6 +174,7 @@ async def node_search(
     config: NodeSearchConfig,
     center_node_uuid: str | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
+    embedding_model: str | None = None,
 ) -> list[EntityNode]:
     search_results: list[list[EntityNode]] = []
 
@@ -182,7 +184,7 @@ async def node_search(
 
     if NodeSearchMethod.cosine_similarity in config.search_methods:
         search_vector = (
-            (await embedder.create(input=[query], model='text-embedding-3-small'))
+            (await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
             .data[0]
             .embedding[:EMBEDDING_DIM]
         )
@@ -222,6 +224,7 @@ async def community_search(
     group_ids: list[str] | None,
     config: CommunitySearchConfig,
     limit=DEFAULT_SEARCH_LIMIT,
+    embedding_model: str | None = None,
 ) -> list[CommunityNode]:
     search_results: list[list[CommunityNode]] = []
 
@@ -231,7 +234,7 @@ async def community_search(
 
     if CommunitySearchMethod.cosine_similarity in config.search_methods:
         search_vector = (
-            (await embedder.create(input=[query], model='text-embedding-3-small'))
+            (await embedder.create(input=[query], model=embedding_model or DEFAULT_EMBEDDING_MODEL))
             .data[0]
             .embedding[:EMBEDDING_DIM]
         )
diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py
index ceb644b9..e0c1127f 100644
--- a/graphiti_core/search/search_config.py
+++ b/graphiti_core/search/search_config.py
@@ -74,6 +74,7 @@ class SearchConfig(BaseModel):
     edge_config: EdgeSearchConfig | None = Field(default=None)
     node_config: NodeSearchConfig | None = Field(default=None)
     community_config: CommunitySearchConfig | None = Field(default=None)
+    embedding_model: str | None = Field(default=None)
     limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
 
 
diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py
index fa3046a2..88dba26f 100644
--- a/graphiti_core/utils/maintenance/community_operations.py
+++ b/graphiti_core/utils/maintenance/community_operations.py
@@ -305,6 +305,6 @@ async def update_community(
     community_edge = (build_community_edges([entity], community, datetime.now()))[0]
     await community_edge.save(driver)
 
-    await community.generate_name_embedding(embedder)
+    await community.generate_name_embedding(embedder, llm_client.embedding_model)
 
     await community.save(driver)
diff --git a/server/graph_service/config.py b/server/graph_service/config.py
index f3082a5f..af4eca84 100644
--- a/server/graph_service/config.py
+++ b/server/graph_service/config.py
@@ -10,6 +10,7 @@ class Settings(BaseSettings):
     openai_api_key: str
     openai_base_url: str | None = Field(None)
     model_name: str | None = Field(None)
+    embedding_model_name: str | None = Field(None)
     neo4j_uri: str
     neo4j_user: str
     neo4j_password: str
diff --git a/server/graph_service/zep_graphiti.py b/server/graph_service/zep_graphiti.py
index 4a901d61..66d74cfe 100644
--- a/server/graph_service/zep_graphiti.py
+++ b/server/graph_service/zep_graphiti.py
@@ -25,7 +25,7 @@ class ZepGraphiti(Graphiti):
             group_id=group_id,
             summary=summary,
         )
-        await new_node.generate_name_embedding(self.llm_client.get_embedder())
+        await new_node.generate_name_embedding(self.llm_client.get_embedder(), self.llm_client.embedding_model)
         await new_node.save(self.driver)
         return new_node
 
@@ -83,6 +83,8 @@ async def get_graphiti(settings: ZepEnvDep):
         client.llm_client.config.api_key = settings.openai_api_key
     if settings.model_name is not None:
         client.llm_client.model = settings.model_name
+    if settings.embedding_model_name is not None:
+        client.llm_client.embedding_model = settings.embedding_model_name
     try:
         yield client
     finally:
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index 5ab04541..eec33923 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -146,7 +146,7 @@ async def test_graph_integration():
         invalid_at=now,
     )
 
-    await entity_edge.generate_embedding(embedder)
+    await entity_edge.generate_embedding(embedder, client.llm_client.embedding_model)
 
     nodes = [episode, alice_node, bob_node]
     edges = [episodic_edge_1, episodic_edge_2, entity_edge]
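
Usage note (not part of the patch): a minimal sketch of how the new embedding_model option might be exercised from calling code. The connection details, model names, and the Graphiti/OpenAIClient constructor signatures shown here are assumptions based on the existing graphiti_core API and README, not something introduced by this change.

    import asyncio

    from graphiti_core import Graphiti
    from graphiti_core.llm_client import LLMConfig, OpenAIClient


    async def main():
        # Hypothetical values: any OpenAI embedding model name could be supplied here.
        llm_client = OpenAIClient(
            LLMConfig(
                api_key='sk-...',
                model='gpt-4o-mini',
                embedding_model='text-embedding-3-large',
            )
        )
        graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password', llm_client)

        # With this patch, add_episode and search pass llm_client.embedding_model
        # through to the embedder instead of the hard-coded 'text-embedding-3-small'.
        results = await graphiti.search('Who is Alice?')
        print(results)


    asyncio.run(main())

For the FastAPI server, the equivalent knob is the new embedding_model_name setting; with pydantic BaseSettings it would normally be supplied via an EMBEDDING_MODEL_NAME environment variable, and get_graphiti copies it onto client.llm_client.embedding_model when present.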