diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index 93e3dd26a..3a08b7eaf 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -68,7 +68,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): self.api_version = api_version self.provider = provider self.model = model - self.dimensions = dimensions self.max_completion_tokens = max_completion_tokens self.tokenizer = self.get_tokenizer() self.retry_count = 0 @@ -78,6 +77,11 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): if isinstance(enable_mocking, bool): enable_mocking = str(enable_mocking).lower() self.mock = enable_mocking in ("true", "1", "yes") + + if dimensions is not None: + if not isinstance(dimensions, int) or dimensions <= 0: + raise ValueError("dimensions must be a positive integer") + self.dimensions = dimensions @retry( stop=stop_after_delay(128), @@ -111,13 +115,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): return [data["embedding"] for data in response["data"]] else: async with embedding_rate_limiter_context_manager(): + kwargs = {} + if self.dimensions is not None: + kwargs['dimensions'] = self.dimensions response = await litellm.aembedding( model=self.model, input=text, api_key=self.api_key, api_base=self.endpoint, api_version=self.api_version, - dimensions=self.dimensions, + **kwargs, ) return [data["embedding"] for data in response.data]