This commit is contained in:
Vasilije 2026-01-20 21:30:11 +01:00 committed by GitHub
commit 93b6ae027a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 9 deletions

View file

@ -129,15 +129,20 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
return [data["embedding"] for data in response["data"]]
else:
async with embedding_rate_limiter_context_manager():
embedding_kwargs = {
"model": self.model,
"input": text,
"api_key": self.api_key,
"api_base": self.endpoint,
"api_version": self.api_version,
}
# Pass through target embedding dimensions when supported
if self.dimensions is not None:
embedding_kwargs["dimensions"] = self.dimensions
# Ensure each attempt does not hang indefinitely
response = await asyncio.wait_for(
litellm.aembedding(
model=self.model,
input=text,
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
),
litellm.aembedding(**embedding_kwargs),
timeout=30.0,
)

View file

@ -57,7 +57,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
model: Optional[str] = "avr/sfr-embedding-mistral:latest",
dimensions: Optional[int] = 1024,
max_completion_tokens: int = 512,
endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
endpoint: Optional[str] = "http://localhost:11434/api/embed",
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
batch_size: int = 100,
):
@ -93,6 +93,10 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
if self.mock:
return [[0.0] * self.dimensions for _ in text]
# Handle case when a single string is passed instead of a list
if not isinstance(text, list):
text = [text]
embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
return embeddings
@ -107,7 +111,12 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
"""
Internal method to call the Ollama embeddings endpoint for a single prompt.
"""
payload = {"model": self.model, "prompt": prompt, "input": prompt}
payload = {
"model": self.model,
"prompt": prompt,
"input": prompt,
"dimensions": self.dimensions,
}
headers = {}
api_key = os.getenv("LLM_API_KEY")
@ -124,6 +133,8 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
data = await response.json()
if "embeddings" in data:
return data["embeddings"][0]
if "embedding" in data:
return data["embedding"]
else:
return data["data"][0]["embedding"]