Remove unnecessary nesting in embed_text and add DummyEmbeddingEngine
This commit is contained in:
parent
5c9fd44680
commit
3e1949d895
2 changed files with 21 additions and 17 deletions
|
|
@ -28,24 +28,19 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
self.dimensions = dimensions
|
||||
|
||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||
async def get_embedding(text_):
|
||||
try:
|
||||
response = await litellm.aembedding(
|
||||
self.model,
|
||||
input = text_,
|
||||
api_key = self.api_key,
|
||||
api_base = self.endpoint,
|
||||
api_version = self.api_version
|
||||
)
|
||||
except litellm.exceptions.BadRequestError as error:
|
||||
logger.error("Error embedding text: %s", str(error))
|
||||
raise error
|
||||
try:
|
||||
response = await litellm.aembedding(
|
||||
self.model,
|
||||
input = text,
|
||||
api_key = self.api_key,
|
||||
api_base = self.endpoint,
|
||||
api_version = self.api_version
|
||||
)
|
||||
except litellm.exceptions.BadRequestError as error:
|
||||
logger.error("Error embedding text: %s", str(error))
|
||||
raise error
|
||||
|
||||
return [data["embedding"] for data in response.data]
|
||||
|
||||
# tasks = [get_embedding(text_) for text_ in text]
|
||||
result = await get_embedding(text)
|
||||
return result
|
||||
return [data["embedding"] for data in response.data]
|
||||
|
||||
def get_vector_size(self) -> int:
|
||||
return self.dimensions
|
||||
|
|
|
|||
9
profiling/util/DummyEmbeddingEngine.py
Normal file
9
profiling/util/DummyEmbeddingEngine.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import numpy as np
|
||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
||||
class DummyEmbeddingEngine(EmbeddingEngine):
|
||||
async def embed_text(self, text: list[str]) -> list[list[float]]:
|
||||
return(list(list(np.random.randn(3072))))
|
||||
|
||||
def get_vector_size(self) -> int:
|
||||
return(3072)
|
||||
Loading…
Add table
Reference in a new issue