diff --git a/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py index b25536d52..943351729 100644 --- a/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py @@ -1,53 +1,59 @@ import asyncio -from typing import List - -import instructor +from typing import List, Optional from openai import AsyncOpenAI from fastembed import TextEmbedding -from cognee.config import Config from cognee.root_dir import get_absolute_path from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from litellm import aembedding import litellm litellm.set_verbose = True -from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config -config = get_embedding_config() class DefaultEmbeddingEngine(EmbeddingEngine): + embedding_model: str + embedding_dimensions: int + def __init__( + self, + embedding_model: Optional[str], + embedding_dimensions: Optional[int], + ): + self.embedding_model = embedding_model + self.embedding_dimensions = embedding_dimensions + async def embed_text(self, text: List[str]) -> List[float]: - embedding_model = TextEmbedding(model_name = config.embedding_model, cache_dir = get_absolute_path("cache/embeddings")) + embedding_model = TextEmbedding(model_name = self.embedding_model, cache_dir = get_absolute_path("cache/embeddings")) embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text))) return embeddings_list def get_vector_size(self) -> int: - return config.embedding_dimensions + return self.embedding_dimensions class LiteLLMEmbeddingEngine(EmbeddingEngine): + embedding_model: str + embedding_dimensions: int + def __init__( + self, + embedding_model: Optional[str], + embedding_dimensions: Optional[int], + ): + self.embedding_model = embedding_model + self.embedding_dimensions = embedding_dimensions import asyncio from typing import List async def embed_text(self, text: List[str]) -> List[List[float]]: async def get_embedding(text_): - response = await aembedding(config.litellm_embedding_model, input=text_) + response = await aembedding(self.embedding_model, input=text_) return response.data[0]['embedding'] tasks = [get_embedding(text_) for text_ in text] result = await asyncio.gather(*tasks) return result - - # embedding = response.data[0].embedding - # # embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text))) - # print("response", type(response.data[0]['embedding'])) - # print("response", response.data[0]) - # return [response.data[0]['embedding']] - - def get_vector_size(self) -> int: - return config.litellm_embedding_dimensions + return self.embedding_dimensions if __name__ == "__main__": diff --git a/cognee/infrastructure/databases/vector/embeddings/config.py b/cognee/infrastructure/databases/vector/embeddings/config.py index 7a953847b..a6750c767 100644 --- a/cognee/infrastructure/databases/vector/embeddings/config.py +++ b/cognee/infrastructure/databases/vector/embeddings/config.py @@ -9,7 +9,7 @@ class EmbeddingConfig(BaseSettings): openai_embedding_dimensions: int = 3072 litellm_embedding_model: str = "text-embedding-3-large" litellm_embedding_dimensions: int = 3072 - embedding_engine:object = DefaultEmbeddingEngine() + embedding_engine:object = DefaultEmbeddingEngine(embedding_model=openai_embedding_model, embedding_dimensions=openai_embedding_dimensions) model_config = SettingsConfigDict(env_file = ".env", extra = "allow")