refactor: add support for Ollama embedding size definition
This commit is contained in:
parent
3a6bb778e2
commit
9d73b493c8
1 changed files with 11 additions and 2 deletions
|
|
@ -57,7 +57,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
model: Optional[str] = "avr/sfr-embedding-mistral:latest",
|
||||
dimensions: Optional[int] = 1024,
|
||||
max_completion_tokens: int = 512,
|
||||
endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
|
||||
endpoint: Optional[str] = "http://localhost:11434/api/embed",
|
||||
huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
|
||||
batch_size: int = 100,
|
||||
):
|
||||
|
|
@ -93,6 +93,10 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
if self.mock:
|
||||
return [[0.0] * self.dimensions for _ in text]
|
||||
|
||||
# Handle case when a single string is passed instead of a list
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
|
||||
return embeddings
|
||||
|
||||
|
|
@ -107,7 +111,12 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
"""
|
||||
Internal method to call the Ollama embeddings endpoint for a single prompt.
|
||||
"""
|
||||
payload = {"model": self.model, "prompt": prompt, "input": prompt}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": prompt,
|
||||
"input": prompt,
|
||||
"dimensions": self.dimensions,
|
||||
}
|
||||
|
||||
headers = {}
|
||||
api_key = os.getenv("LLM_API_KEY")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue