From 3e1949d895f1450a3b8a436eb9491c11e60daad8 Mon Sep 17 00:00:00 2001
From: Leon Luithlen
Date: Thu, 28 Nov 2024 15:42:20 +0100
Subject: [PATCH] Remove unnecessary nesting in embed_text and add DummyEmbeddingEngine

LiteLLMEmbeddingEngine.embed_text wrapped its litellm.aembedding call in a
nested get_embedding coroutine that was awaited exactly once with the full
input; make the call directly in embed_text instead.

Add DummyEmbeddingEngine under profiling/util. Its embed_text returns one
random 3072-dimensional vector per input text (an empty list for empty
input) without calling litellm, matching the list[list[float]] annotation
and the value reported by get_vector_size.

---
 .../embeddings/LiteLLMEmbeddingEngine.py      | 29 ++++++++-----------
 profiling/util/DummyEmbeddingEngine.py        |  9 ++++++
 2 files changed, 21 insertions(+), 17 deletions(-)
 create mode 100644 profiling/util/DummyEmbeddingEngine.py

diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
index de30640e5..edc8eb57f 100644
--- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -28,24 +28,19 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         self.dimensions = dimensions
 
     async def embed_text(self, text: List[str]) -> List[List[float]]:
-        async def get_embedding(text_):
-            try:
-                response = await litellm.aembedding(
-                    self.model,
-                    input = text_,
-                    api_key = self.api_key,
-                    api_base = self.endpoint,
-                    api_version = self.api_version
-                )
-            except litellm.exceptions.BadRequestError as error:
-                logger.error("Error embedding text: %s", str(error))
-                raise error
+        try:
+            response = await litellm.aembedding(
+                self.model,
+                input = text,
+                api_key = self.api_key,
+                api_base = self.endpoint,
+                api_version = self.api_version
+            )
+        except litellm.exceptions.BadRequestError as error:
+            logger.error("Error embedding text: %s", str(error))
+            raise error
 
-            return [data["embedding"] for data in response.data]
-
-        # tasks = [get_embedding(text_) for text_ in text]
-        result = await get_embedding(text)
-        return result
+        return [data["embedding"] for data in response.data]
 
     def get_vector_size(self) -> int:
         return self.dimensions
diff --git a/profiling/util/DummyEmbeddingEngine.py b/profiling/util/DummyEmbeddingEngine.py
new file mode 100644
index 000000000..7f5b3e847
--- /dev/null
+++ b/profiling/util/DummyEmbeddingEngine.py
@@ -0,0 +1,9 @@
+import numpy as np
+from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
+
+class DummyEmbeddingEngine(EmbeddingEngine):
+    async def embed_text(self, text: list[str]) -> list[list[float]]:
+        return np.random.randn(len(text), 3072).tolist()
+
+    def get_vector_size(self) -> int:
+        return(3072)