Refactor dimensions handling in LiteLLMEmbeddingEngine

Remove dimensions initialization from constructor and add validation for dimensions if provided.
This commit is contained in:
Stony 2026-01-05 15:07:17 +08:00 committed by GitHub
parent 7028140c89
commit 76bfb3ac3e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -68,7 +68,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.api_version = api_version self.api_version = api_version
self.provider = provider self.provider = provider
self.model = model self.model = model
self.dimensions = dimensions
self.max_completion_tokens = max_completion_tokens self.max_completion_tokens = max_completion_tokens
self.tokenizer = self.get_tokenizer() self.tokenizer = self.get_tokenizer()
self.retry_count = 0 self.retry_count = 0
@ -79,6 +78,11 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
enable_mocking = str(enable_mocking).lower() enable_mocking = str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes") self.mock = enable_mocking in ("true", "1", "yes")
if dimensions is not None:
if not isinstance(dimensions, int) or dimensions <= 0:
raise ValueError("dimensions must be a positive integer")
self.dimensions = dimensions
@retry( @retry(
stop=stop_after_delay(128), stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128), wait=wait_exponential_jitter(2, 128),
@ -111,13 +115,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
return [data["embedding"] for data in response["data"]] return [data["embedding"] for data in response["data"]]
else: else:
async with embedding_rate_limiter_context_manager(): async with embedding_rate_limiter_context_manager():
kwargs = {}
if self.dimensions is not None:
kwargs['dimensions'] = self.dimensions
response = await litellm.aembedding( response = await litellm.aembedding(
model=self.model, model=self.model,
input=text, input=text,
api_key=self.api_key, api_key=self.api_key,
api_base=self.endpoint, api_base=self.endpoint,
api_version=self.api_version, api_version=self.api_version,
dimensions=self.dimensions, **kwargs,
) )
return [data["embedding"] for data in response.data] return [data["embedding"] for data in response.data]