diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
index 28fa22b4f..e1bad3a8c 100644
--- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -30,6 +30,10 @@ from cognee.infrastructure.llm.tokenizer.TikToken import (
 from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
 
 litellm.set_verbose = False
+litellm.drop_params = True
 logger = get_logger("LiteLLMEmbeddingEngine")
 
+# Vector size assumed when the caller does not configure ``dimensions``.
+DEFAULT_DIMENSIONS = 3072
+
 
@@ -70,7 +74,13 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         self.api_version = api_version
         self.provider = provider
         self.model = model
+        # Validate before the tokenizer is built so a bad value fails fast.
+        # bool is rejected explicitly because it is a subclass of int.
+        if dimensions is not None and (
+            isinstance(dimensions, bool) or not isinstance(dimensions, int) or dimensions <= 0
+        ):
+            raise ValueError("dimensions must be a positive integer")
         self.dimensions = dimensions
         self.max_completion_tokens = max_completion_tokens
         self.tokenizer = self.get_tokenizer()
         self.retry_count = 0
@@ -125,18 +135,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         """
         try:
             if self.mock:
-                response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
+                dim = self.dimensions if self.dimensions is not None else DEFAULT_DIMENSIONS
+                response = {"data": [{"embedding": [0.0] * dim} for _ in text]}
                 return [data["embedding"] for data in response["data"]]
             else:
                 async with embedding_rate_limiter_context_manager():
+                    kwargs = {}
+                    if self.dimensions is not None:
+                        kwargs["dimensions"] = self.dimensions
+
                     # Ensure each attempt does not hang indefinitely
                     response = await asyncio.wait_for(
                         litellm.aembedding(
                             model=self.model,
                             input=text,
-                            api_key=self.api_key if self.api_key and self.api_key.strip() != "" else "EMPTY",
+                            api_key=self.api_key
+                            if self.api_key and self.api_key.strip() != ""
+                            else "EMPTY",
                             api_base=self.endpoint,
                             api_version=self.api_version,
+                            **kwargs,
                         ),
                         timeout=30.0,
                     )
@@ -224,7 +242,9 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
 
             - int: The size (dimensionality) of the embedding vectors.
         """
-        return self.dimensions
+        # NOTE(review): the fallback is only right when the model's native output
+        # size equals DEFAULT_DIMENSIONS - confirm for non-default models.
+        return self.dimensions if self.dimensions is not None else DEFAULT_DIMENSIONS
 
     def get_batch_size(self) -> int:
         """
diff --git a/cognee/tests/unit/infrastructure/test_litellm_embedding_dimensions.py b/cognee/tests/unit/infrastructure/test_litellm_embedding_dimensions.py
new file mode 100644
index 000000000..af52d27d6
--- /dev/null
+++ b/cognee/tests/unit/infrastructure/test_litellm_embedding_dimensions.py
@@ -0,0 +1,40 @@
+import os
+from unittest.mock import patch
+
+import pytest
+
+from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
+    DEFAULT_DIMENSIONS,
+    LiteLLMEmbeddingEngine,
+)
+
+
+@pytest.mark.asyncio
+async def test_litellm_embedding_custom_dimensions():
+    """Mock embeddings have exactly the configured number of dimensions."""
+    custom_dim = 1024
+    # Force mock mode so no embedding provider is contacted.
+    with patch.dict(os.environ, {"MOCK_EMBEDDING": "true"}):
+        engine = LiteLLMEmbeddingEngine(dimensions=custom_dim)
+        embeddings = await engine.embed_text(["Hello world"])
+
+    assert len(embeddings) == 1
+    assert len(embeddings[0]) == custom_dim
+
+
+@pytest.mark.asyncio
+async def test_litellm_embedding_default_dimensions():
+    """Mock embeddings fall back to the default size when dimensions is None."""
+    with patch.dict(os.environ, {"MOCK_EMBEDDING": "true"}):
+        engine = LiteLLMEmbeddingEngine(dimensions=None)
+        embeddings = await engine.embed_text(["Hello world"])
+
+    assert len(embeddings) == 1
+    assert len(embeddings[0]) == DEFAULT_DIMENSIONS
+
+
+@pytest.mark.parametrize("bad_dimensions", [0, -100, True, 1.5, "1024"])
+def test_litellm_embedding_invalid_dimensions(bad_dimensions):
+    """Non-positive or non-integer dimensions are rejected at construction."""
+    with pytest.raises(ValueError, match="dimensions must be a positive integer"):
+        LiteLLMEmbeddingEngine(dimensions=bad_dimensions)