Remove unnecessary nesting in embed_text and add DummyEmbeddingEngine
This commit is contained in:
parent
5c9fd44680
commit
3e1949d895
2 changed files with 21 additions and 17 deletions
|
|
@ -28,24 +28,19 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
|
|
||||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||||
async def get_embedding(text_):
|
try:
|
||||||
try:
|
response = await litellm.aembedding(
|
||||||
response = await litellm.aembedding(
|
self.model,
|
||||||
self.model,
|
input = text,
|
||||||
input = text_,
|
api_key = self.api_key,
|
||||||
api_key = self.api_key,
|
api_base = self.endpoint,
|
||||||
api_base = self.endpoint,
|
api_version = self.api_version
|
||||||
api_version = self.api_version
|
)
|
||||||
)
|
except litellm.exceptions.BadRequestError as error:
|
||||||
except litellm.exceptions.BadRequestError as error:
|
logger.error("Error embedding text: %s", str(error))
|
||||||
logger.error("Error embedding text: %s", str(error))
|
raise error
|
||||||
raise error
|
|
||||||
|
|
||||||
return [data["embedding"] for data in response.data]
|
return [data["embedding"] for data in response.data]
|
||||||
|
|
||||||
# tasks = [get_embedding(text_) for text_ in text]
|
|
||||||
result = await get_embedding(text)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_vector_size(self) -> int:
|
def get_vector_size(self) -> int:
|
||||||
return self.dimensions
|
return self.dimensions
|
||||||
|
|
|
||||||
9
profiling/util/DummyEmbeddingEngine.py
Normal file
9
profiling/util/DummyEmbeddingEngine.py
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
import numpy as np
|
||||||
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
|
class DummyEmbeddingEngine(EmbeddingEngine):
|
||||||
|
async def embed_text(self, text: list[str]) -> list[list[float]]:
|
||||||
|
return(list(list(np.random.randn(3072))))
|
||||||
|
|
||||||
|
def get_vector_size(self) -> int:
|
||||||
|
return(3072)
|
||||||
Loading…
Add table
Reference in a new issue