This commit is contained in:
Stony 2026-01-20 16:27:41 +00:00 committed by GitHub
commit 0ec98f1fc5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 74 additions and 6 deletions

View file

@ -49,7 +49,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_version: str
provider: str
model: str
dimensions: int
dimensions: Optional[int]
mock: bool
MAX_RETRIES = 5
@ -70,7 +70,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.api_version = api_version
self.provider = provider
self.model = model
self.dimensions = dimensions
self.max_completion_tokens = max_completion_tokens
self.tokenizer = self.get_tokenizer()
self.retry_count = 0
@ -81,6 +80,11 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
enable_mocking = str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes")
if dimensions is not None:
if not isinstance(dimensions, int) or dimensions <= 0:
raise ValueError("dimensions must be a positive integer")
self.dimensions = dimensions
# Validate provided custom embedding endpoint early to avoid long hangs later
if self.endpoint:
try:
@ -125,18 +129,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
"""
try:
if self.mock:
response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
dim = self.dimensions if self.dimensions is not None else 3072
response = {"data": [{"embedding": [0.0] * dim} for _ in text]}
return [data["embedding"] for data in response["data"]]
else:
async with embedding_rate_limiter_context_manager():
kwargs = {}
if self.dimensions is not None:
kwargs["dimensions"] = self.dimensions
# Ensure each attempt does not hang indefinitely
response = await asyncio.wait_for(
litellm.aembedding(
model=self.model,
input=text,
api_key=self.api_key,
api_key=self.api_key
if self.api_key and self.api_key.strip() != ""
else "EMPTY",
api_base=self.endpoint,
api_version=self.api_version,
**kwargs,
),
timeout=30.0,
)
@ -224,7 +236,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
- int: The size (dimensionality) of the embedding vectors.
"""
return self.dimensions
return self.dimensions if self.dimensions is not None else 3072
def get_batch_size(self) -> int:
"""
@ -280,4 +292,4 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
)
logger.debug(f"Tokenizer loaded for model: {self.model}")
return tokenizer
return tokenizer

View file

@ -0,0 +1,56 @@
import os
from unittest.mock import patch
import pytest
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
@pytest.mark.asyncio
async def test_litellm_embedding_custom_dimensions():
"""
Test that LiteLLMEmbeddingEngine correctly respects the 'dimensions' parameter
in mock mode.
"""
# Force mock mode for this test
with patch.dict(os.environ, {"MOCK_EMBEDDING": "true"}):
custom_dim = 1024
engine = LiteLLMEmbeddingEngine(dimensions=custom_dim)
text = ["Hello world"]
embeddings = await engine.embed_text(text)
assert len(embeddings) == 1
assert len(embeddings[0]) == custom_dim, f"Expected dimension {custom_dim}, but got {len(embeddings[0])}"
@pytest.mark.asyncio
async def test_litellm_embedding_default_dimensions():
"""
Test that LiteLLMEmbeddingEngine uses the default dimension (3072)
when no dimension is provided.
"""
with patch.dict(os.environ, {"MOCK_EMBEDDING": "true"}):
engine = LiteLLMEmbeddingEngine(dimensions=None)
text = ["Hello world"]
embeddings = await engine.embed_text(text)
expected_default = 3072
assert len(embeddings) == 1
assert len(embeddings[0]) == expected_default, f"Expected default dimension {expected_default}, but got {len(embeddings[0])}"
@pytest.mark.asyncio
async def test_litellm_embedding_invalid_dimensions():
"""
Test that LiteLLMEmbeddingEngine raises ValueError for invalid dimensions.
"""
with pytest.raises(ValueError, match="dimensions must be a positive integer"):
LiteLLMEmbeddingEngine(dimensions=0)
with pytest.raises(ValueError, match="dimensions must be a positive integer"):
LiteLLMEmbeddingEngine(dimensions=-100)
with pytest.raises(ValueError, match="dimensions must be a positive integer"):
LiteLLMEmbeddingEngine(dimensions="1024") # type: ignore
with pytest.raises(ValueError, match="dimensions must be a positive integer"):
LiteLLMEmbeddingEngine(dimensions=1024.5) # type: ignore