Remove unnecessary nesting in embed_text and add DummyEmbeddingEngine

This commit is contained in:
Leon Luithlen 2024-11-28 15:42:20 +01:00
parent 5c9fd44680
commit 3e1949d895
2 changed files with 21 additions and 17 deletions

View file

@ -28,24 +28,19 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.dimensions = dimensions
async def embed_text(self, text: List[str]) -> List[List[float]]:
async def get_embedding(text_):
try:
response = await litellm.aembedding(
self.model,
input = text_,
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version
)
except litellm.exceptions.BadRequestError as error:
logger.error("Error embedding text: %s", str(error))
raise error
try:
response = await litellm.aembedding(
self.model,
input = text,
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version
)
except litellm.exceptions.BadRequestError as error:
logger.error("Error embedding text: %s", str(error))
raise error
return [data["embedding"] for data in response.data]
# tasks = [get_embedding(text_) for text_ in text]
result = await get_embedding(text)
return result
return [data["embedding"] for data in response.data]
def get_vector_size(self) -> int:
return self.dimensions

View file

@ -0,0 +1,9 @@
import numpy as np
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
class DummyEmbeddingEngine(EmbeddingEngine):
async def embed_text(self, text: list[str]) -> list[list[float]]:
return(list(list(np.random.randn(3072))))
def get_vector_size(self) -> int:
return(3072)