Refactor dimensions handling in LiteLLMEmbeddingEngine
Remove dimensions initialization from constructor and add validation for dimensions if provided.
This commit is contained in:
parent
7028140c89
commit
76bfb3ac3e
1 changed files with 9 additions and 2 deletions
|
|
@ -68,7 +68,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
self.api_version = api_version
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.dimensions = dimensions
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.retry_count = 0
|
||||
|
|
@ -78,6 +77,11 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
if isinstance(enable_mocking, bool):
|
||||
enable_mocking = str(enable_mocking).lower()
|
||||
self.mock = enable_mocking in ("true", "1", "yes")
|
||||
|
||||
if dimensions is not None:
|
||||
if not isinstance(dimensions, int) or dimensions <= 0:
|
||||
raise ValueError("dimensions must be a positive integer")
|
||||
self.dimensions = dimensions
|
||||
|
||||
@retry(
|
||||
stop=stop_after_delay(128),
|
||||
|
|
@ -111,13 +115,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
return [data["embedding"] for data in response["data"]]
|
||||
else:
|
||||
async with embedding_rate_limiter_context_manager():
|
||||
kwargs = {}
|
||||
if self.dimensions is not None:
|
||||
kwargs['dimensions'] = self.dimensions
|
||||
response = await litellm.aembedding(
|
||||
model=self.model,
|
||||
input=text,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
dimensions=self.dimensions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return [data["embedding"] for data in response.data]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue