Merge 38f0ab865b into 2ef347f8fa
This commit is contained in:
commit
93b6ae027a
2 changed files with 25 additions and 9 deletions
|
|
@ -129,15 +129,20 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
return [data["embedding"] for data in response["data"]]
|
return [data["embedding"] for data in response["data"]]
|
||||||
else:
|
else:
|
||||||
async with embedding_rate_limiter_context_manager():
|
async with embedding_rate_limiter_context_manager():
|
||||||
|
embedding_kwargs = {
|
||||||
|
"model": self.model,
|
||||||
|
"input": text,
|
||||||
|
"api_key": self.api_key,
|
||||||
|
"api_base": self.endpoint,
|
||||||
|
"api_version": self.api_version,
|
||||||
|
}
|
||||||
|
# Pass through target embedding dimensions when supported
|
||||||
|
if self.dimensions is not None:
|
||||||
|
embedding_kwargs["dimensions"] = self.dimensions
|
||||||
|
|
||||||
# Ensure each attempt does not hang indefinitely
|
# Ensure each attempt does not hang indefinitely
|
||||||
response = await asyncio.wait_for(
|
response = await asyncio.wait_for(
|
||||||
litellm.aembedding(
|
litellm.aembedding(**embedding_kwargs),
|
||||||
model=self.model,
|
|
||||||
input=text,
|
|
||||||
api_key=self.api_key,
|
|
||||||
api_base=self.endpoint,
|
|
||||||
api_version=self.api_version,
|
|
||||||
),
|
|
||||||
timeout=30.0,
|
timeout=30.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
model: Optional[str] = "avr/sfr-embedding-mistral:latest",
|
model: Optional[str] = "avr/sfr-embedding-mistral:latest",
|
||||||
dimensions: Optional[int] = 1024,
|
dimensions: Optional[int] = 1024,
|
||||||
max_completion_tokens: int = 512,
|
max_completion_tokens: int = 512,
|
||||||
endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
|
endpoint: Optional[str] = "http://localhost:11434/api/embed",
|
||||||
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
|
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
|
||||||
batch_size: int = 100,
|
batch_size: int = 100,
|
||||||
):
|
):
|
||||||
|
|
@ -93,6 +93,10 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
return [[0.0] * self.dimensions for _ in text]
|
return [[0.0] * self.dimensions for _ in text]
|
||||||
|
|
||||||
|
# Handle case when a single string is passed instead of a list
|
||||||
|
if not isinstance(text, list):
|
||||||
|
text = [text]
|
||||||
|
|
||||||
embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
|
embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
@ -107,7 +111,12 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
"""
|
"""
|
||||||
Internal method to call the Ollama embeddings endpoint for a single prompt.
|
Internal method to call the Ollama embeddings endpoint for a single prompt.
|
||||||
"""
|
"""
|
||||||
payload = {"model": self.model, "prompt": prompt, "input": prompt}
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"input": prompt,
|
||||||
|
"dimensions": self.dimensions,
|
||||||
|
}
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
api_key = os.getenv("LLM_API_KEY")
|
api_key = os.getenv("LLM_API_KEY")
|
||||||
|
|
@ -124,6 +133,8 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
if "embeddings" in data:
|
if "embeddings" in data:
|
||||||
return data["embeddings"][0]
|
return data["embeddings"][0]
|
||||||
|
if "embedding" in data:
|
||||||
|
return data["embedding"]
|
||||||
else:
|
else:
|
||||||
return data["data"][0]["embedding"]
|
return data["data"][0]["embedding"]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue