refactor(embeddings): handle dimensions parameter in LiteLLMEmbeddingEngine
Signed-off-by: shijianglong <stonyme@vip.qq.com>
This commit is contained in:
parent
34c6652939
commit
a6f57bc8ad
2 changed files with 65 additions and 4 deletions
|
|
@ -30,6 +30,7 @@ from cognee.infrastructure.llm.tokenizer.TikToken import (
|
||||||
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
|
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
|
||||||
|
|
||||||
litellm.set_verbose = False
|
litellm.set_verbose = False
|
||||||
|
litellm.drop_params = True
|
||||||
logger = get_logger("LiteLLMEmbeddingEngine")
|
logger = get_logger("LiteLLMEmbeddingEngine")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -70,7 +71,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
self.api_version = api_version
|
self.api_version = api_version
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.model = model
|
self.model = model
|
||||||
self.dimensions = dimensions
|
|
||||||
self.max_completion_tokens = max_completion_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
self.tokenizer = self.get_tokenizer()
|
self.tokenizer = self.get_tokenizer()
|
||||||
self.retry_count = 0
|
self.retry_count = 0
|
||||||
|
|
@ -81,6 +81,11 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
enable_mocking = str(enable_mocking).lower()
|
enable_mocking = str(enable_mocking).lower()
|
||||||
self.mock = enable_mocking in ("true", "1", "yes")
|
self.mock = enable_mocking in ("true", "1", "yes")
|
||||||
|
|
||||||
|
if dimensions is not None:
|
||||||
|
if not isinstance(dimensions, int) or dimensions <= 0:
|
||||||
|
raise ValueError("dimensions must be a positive integer")
|
||||||
|
self.dimensions = dimensions
|
||||||
|
|
||||||
# Validate provided custom embedding endpoint early to avoid long hangs later
|
# Validate provided custom embedding endpoint early to avoid long hangs later
|
||||||
if self.endpoint:
|
if self.endpoint:
|
||||||
try:
|
try:
|
||||||
|
|
@ -125,18 +130,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if self.mock:
|
if self.mock:
|
||||||
response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
|
dim = self.dimensions if self.dimensions is not None else 3072
|
||||||
|
response = {"data": [{"embedding": [0.0] * dim} for _ in text]}
|
||||||
return [data["embedding"] for data in response["data"]]
|
return [data["embedding"] for data in response["data"]]
|
||||||
else:
|
else:
|
||||||
async with embedding_rate_limiter_context_manager():
|
async with embedding_rate_limiter_context_manager():
|
||||||
|
kwargs = {}
|
||||||
|
if self.dimensions is not None:
|
||||||
|
kwargs["dimensions"] = self.dimensions
|
||||||
|
|
||||||
# Ensure each attempt does not hang indefinitely
|
# Ensure each attempt does not hang indefinitely
|
||||||
response = await asyncio.wait_for(
|
response = await asyncio.wait_for(
|
||||||
litellm.aembedding(
|
litellm.aembedding(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
input=text,
|
input=text,
|
||||||
api_key=self.api_key if self.api_key and self.api_key.strip() != "" else "EMPTY",
|
api_key=self.api_key
|
||||||
|
if self.api_key and self.api_key.strip() != ""
|
||||||
|
else "EMPTY",
|
||||||
api_base=self.endpoint,
|
api_base=self.endpoint,
|
||||||
api_version=self.api_version,
|
api_version=self.api_version,
|
||||||
|
**kwargs,
|
||||||
),
|
),
|
||||||
timeout=30.0,
|
timeout=30.0,
|
||||||
)
|
)
|
||||||
|
|
@ -224,7 +237,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
|
|
||||||
- int: The size (dimensionality) of the embedding vectors.
|
- int: The size (dimensionality) of the embedding vectors.
|
||||||
"""
|
"""
|
||||||
return self.dimensions
|
return self.dimensions if self.dimensions is not None else 3072
|
||||||
|
|
||||||
def get_batch_size(self) -> int:
|
def get_batch_size(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_litellm_embedding_custom_dimensions():
|
||||||
|
"""
|
||||||
|
Test that LiteLLMEmbeddingEngine correctly respects the 'dimensions' parameter
|
||||||
|
in mock mode.
|
||||||
|
"""
|
||||||
|
# Force mock mode for this test
|
||||||
|
with patch.dict(os.environ, {"MOCK_EMBEDDING": "true"}):
|
||||||
|
custom_dim = 1024
|
||||||
|
engine = LiteLLMEmbeddingEngine(dimensions=custom_dim)
|
||||||
|
|
||||||
|
text = ["Hello world"]
|
||||||
|
embeddings = await engine.embed_text(text)
|
||||||
|
|
||||||
|
assert len(embeddings) == 1
|
||||||
|
assert len(embeddings[0]) == custom_dim, f"Expected dimension {custom_dim}, but got {len(embeddings[0])}"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_litellm_embedding_default_dimensions():
|
||||||
|
"""
|
||||||
|
Test that LiteLLMEmbeddingEngine uses the default dimension (3072)
|
||||||
|
when no dimension is provided.
|
||||||
|
"""
|
||||||
|
with patch.dict(os.environ, {"MOCK_EMBEDDING": "true"}):
|
||||||
|
engine = LiteLLMEmbeddingEngine(dimensions=None)
|
||||||
|
|
||||||
|
text = ["Hello world"]
|
||||||
|
embeddings = await engine.embed_text(text)
|
||||||
|
|
||||||
|
expected_default = 3072
|
||||||
|
assert len(embeddings) == 1
|
||||||
|
assert len(embeddings[0]) == expected_default, f"Expected default dimension {expected_default}, but got {len(embeddings[0])}"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_litellm_embedding_invalid_dimensions():
|
||||||
|
"""
|
||||||
|
Test that LiteLLMEmbeddingEngine raises ValueError for invalid dimensions.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValueError, match="dimensions must be a positive integer"):
|
||||||
|
LiteLLMEmbeddingEngine(dimensions=0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="dimensions must be a positive integer"):
|
||||||
|
LiteLLMEmbeddingEngine(dimensions=-100)
|
||||||
Loading…
Add table
Reference in a new issue