Updates to the configs
This commit is contained in:
parent
df32ba622e
commit
6c739f3357
2 changed files with 25 additions and 19 deletions
|
|
@ -1,53 +1,59 @@
|
|||
import asyncio
|
||||
from typing import List
|
||||
|
||||
import instructor
|
||||
from typing import List, Optional
|
||||
from openai import AsyncOpenAI
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
from cognee.config import Config
|
||||
from cognee.root_dir import get_absolute_path
|
||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from litellm import aembedding
|
||||
import litellm
|
||||
|
||||
litellm.set_verbose = True
|
||||
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
|
||||
config = get_embedding_config()
|
||||
|
||||
class DefaultEmbeddingEngine(EmbeddingEngine):
|
||||
embedding_model: str
|
||||
embedding_dimensions: int
|
||||
def __init__(
|
||||
self,
|
||||
embedding_model: Optional[str],
|
||||
embedding_dimensions: Optional[int],
|
||||
):
|
||||
self.embedding_model = embedding_model
|
||||
self.embedding_dimensions = embedding_dimensions
|
||||
|
||||
async def embed_text(self, text: List[str]) -> List[float]:
|
||||
embedding_model = TextEmbedding(model_name = config.embedding_model, cache_dir = get_absolute_path("cache/embeddings"))
|
||||
embedding_model = TextEmbedding(model_name = self.embedding_model, cache_dir = get_absolute_path("cache/embeddings"))
|
||||
embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text)))
|
||||
|
||||
return embeddings_list
|
||||
|
||||
def get_vector_size(self) -> int:
|
||||
return config.embedding_dimensions
|
||||
return self.embedding_dimensions
|
||||
|
||||
|
||||
class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||
embedding_model: str
|
||||
embedding_dimensions: int
|
||||
def __init__(
|
||||
self,
|
||||
embedding_model: Optional[str],
|
||||
embedding_dimensions: Optional[int],
|
||||
):
|
||||
self.embedding_model = embedding_model
|
||||
self.embedding_dimensions = embedding_dimensions
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||
async def get_embedding(text_):
|
||||
response = await aembedding(config.litellm_embedding_model, input=text_)
|
||||
response = await aembedding(self.embedding_model, input=text_)
|
||||
return response.data[0]['embedding']
|
||||
|
||||
tasks = [get_embedding(text_) for text_ in text]
|
||||
result = await asyncio.gather(*tasks)
|
||||
return result
|
||||
|
||||
# embedding = response.data[0].embedding
|
||||
# # embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text)))
|
||||
# print("response", type(response.data[0]['embedding']))
|
||||
# print("response", response.data[0])
|
||||
# return [response.data[0]['embedding']]
|
||||
|
||||
|
||||
def get_vector_size(self) -> int:
|
||||
return config.litellm_embedding_dimensions
|
||||
return self.embedding_dimensions
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class EmbeddingConfig(BaseSettings):
|
|||
openai_embedding_dimensions: int = 3072
|
||||
litellm_embedding_model: str = "text-embedding-3-large"
|
||||
litellm_embedding_dimensions: int = 3072
|
||||
embedding_engine:object = DefaultEmbeddingEngine()
|
||||
embedding_engine:object = DefaultEmbeddingEngine(embedding_model=openai_embedding_model, embedding_dimensions=openai_embedding_dimensions)
|
||||
|
||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue