diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
index b5d780d60..31e49a676 100644
--- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -3,6 +3,7 @@ import logging
 import math
 from typing import List, Optional
 import litellm
+import os
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 
 litellm.set_verbose = False
@@ -14,6 +15,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
     api_version: str
     model: str
     dimensions: int
+    mock: bool
 
     def __init__(
         self,
@@ -28,6 +30,9 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         self.api_version = api_version
         self.model = model
         self.dimensions = dimensions
+        # The default must be a string: os.getenv returns the default unchanged
+        # when MOCK_EMBEDDING is unset, and a bool default has no .lower().
+        self.mock = os.getenv("MOCK_EMBEDDING", "false").lower() in ("true", "1", "yes")
 
     MAX_RETRIES = 5
     retry_count = 0
@@ -38,17 +43,27 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
             await asyncio.sleep(wait_time)
 
         try:
-            response = await litellm.aembedding(
-                self.model,
-                input = text,
-                api_key = self.api_key,
-                api_base = self.endpoint,
-                api_version = self.api_version
-            )
+            if self.mock:
+                # Mock mode: build one zero vector per input item locally, no litellm call.
+                response = {
+                    "data": [{"embedding": [0.0] * self.dimensions} for _ in text]
+                }
 
-            self.retry_count = 0
+                self.retry_count = 0
 
-            return [data["embedding"] for data in response.data]
+                return [data["embedding"] for data in response["data"]]
+            else:
+                response = await litellm.aembedding(
+                    self.model,
+                    input = text,
+                    api_key = self.api_key,
+                    api_base = self.endpoint,
+                    api_version = self.api_version
+                )
+
+                self.retry_count = 0
+
+                return [data["embedding"] for data in response.data]
 
         except litellm.exceptions.ContextWindowExceededError as error:
             if isinstance(text, list):