cognee/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py

64 lines
2.2 KiB
Python

import asyncio
from typing import List
import instructor
from openai import AsyncOpenAI
from fastembed import TextEmbedding
from cognee.config import Config
from cognee.root_dir import get_absolute_path
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from litellm import aembedding
import litellm
# Turn on litellm's verbose flag as a side effect of importing this module.
# NOTE(review): looks like a debugging leftover — confirm it is wanted outside local runs.
litellm.set_verbose = True
# Module-level configuration object read by both embedding engines below;
# it is created and loaded once, at import time.
config = Config()
config.load()
class DefaultEmbeddingEngine(EmbeddingEngine):
    """Embedding engine backed by a fastembed ``TextEmbedding`` model.

    The model name and the vector size are read from the module-level
    ``config`` object.
    """

    # Lazily created model and the model name it was built for. Keeping the
    # model on the instance avoids rebuilding it on every ``embed_text`` call.
    _embedding_model = None
    _embedding_model_name = None

    def _get_embedding_model(self):
        """Return the ``TextEmbedding`` model for the configured model name.

        The model is built on first use and rebuilt only when
        ``config.embedding_model`` differs from the name it was built with.
        """
        model_name = config.embedding_model
        if self._embedding_model is None or self._embedding_model_name != model_name:
            self._embedding_model = TextEmbedding(
                model_name=model_name,
                cache_dir=get_absolute_path("cache/embeddings"),
            )
            self._embedding_model_name = model_name
        return self._embedding_model

    async def embed_text(self, text: List[str]) -> List[List[float]]:
        """Embed the given texts and return one vector per text.

        Args:
            text: The texts to embed.

        Returns:
            A list holding, for each item produced by the model's ``embed``
            call, the result of its ``tolist()`` conversion.
        """
        embedding_model = self._get_embedding_model()
        # NOTE(review): ``embed`` is a synchronous call inside a coroutine, so
        # it presumably blocks the event loop while it runs — confirm.
        return [embedding.tolist() for embedding in embedding_model.embed(text)]

    def get_vector_size(self) -> int:
        """Return ``config.embedding_dimensions``."""
        return config.embedding_dimensions
class LiteLLMEmbeddingEngine(EmbeddingEngine):
    """Embedding engine that requests embeddings through litellm's ``aembedding``.

    The model name and the vector size are read from the module-level
    ``config`` object.
    """

    async def embed_text(self, text: List[str]) -> List[List[float]]:
        """Embed each input text with one ``aembedding`` request per text.

        All requests are awaited together with ``asyncio.gather``, so the
        returned list follows the order of ``text``.

        Args:
            text: The texts to embed.

        Returns:
            One embedding per input text, in input order.
        """
        async def get_embedding(text_):
            response = await aembedding(config.litellm_embedding_model, input=text_)
            # Each request carries a single text, so only the first entry of
            # the response data is read.
            return response.data[0]['embedding']

        tasks = [get_embedding(text_) for text_ in text]
        return await asyncio.gather(*tasks)

    def get_vector_size(self) -> int:
        """Return ``config.litellm_embedding_dimensions``."""
        return config.litellm_embedding_dimensions
if __name__ == "__main__":
    async def gg():
        """Manual check: embed one sample sentence via LiteLLM and print the result."""
        engine = LiteLLMEmbeddingEngine()
        sample_texts = ["Hello, sadasdas are you?"]
        embeddings = await engine.embed_text(sample_texts)
        print(embeddings)

    asyncio.run(gg())