diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index 558b11538..8049e763f 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -129,15 +129,20 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): return [data["embedding"] for data in response["data"]] else: async with embedding_rate_limiter_context_manager(): + embedding_kwargs = { + "model": self.model, + "input": text, + "api_key": self.api_key, + "api_base": self.endpoint, + "api_version": self.api_version, + } + # Pass through target embedding dimensions when supported + if self.dimensions is not None: + embedding_kwargs["dimensions"] = self.dimensions + # Ensure each attempt does not hang indefinitely response = await asyncio.wait_for( - litellm.aembedding( - model=self.model, - input=text, - api_key=self.api_key, - api_base=self.endpoint, - api_version=self.api_version, - ), + litellm.aembedding(**embedding_kwargs), timeout=30.0, ) diff --git a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py index 1d5e7fbfe..6afd056de 100644 --- a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py @@ -57,7 +57,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine): model: Optional[str] = "avr/sfr-embedding-mistral:latest", dimensions: Optional[int] = 1024, max_completion_tokens: int = 512, - endpoint: Optional[str] = "http://localhost:11434/api/embeddings", + endpoint: Optional[str] = "http://localhost:11434/api/embed", huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral", batch_size: int = 100, ): @@ -93,6 +93,10 @@ class OllamaEmbeddingEngine(EmbeddingEngine): if self.mock: return [[0.0] * self.dimensions for _ in text] + # Handle case when a single string is passed instead of a list + if not isinstance(text, list): + text = [text] + embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text]) return embeddings @@ -107,7 +111,12 @@ class OllamaEmbeddingEngine(EmbeddingEngine): """ Internal method to call the Ollama embeddings endpoint for a single prompt. """ - payload = {"model": self.model, "prompt": prompt, "input": prompt} + payload = { + "model": self.model, + "prompt": prompt, + "input": prompt, + "dimensions": self.dimensions, + } headers = {} api_key = os.getenv("LLM_API_KEY") @@ -124,6 +133,8 @@ class OllamaEmbeddingEngine(EmbeddingEngine): data = await response.json() if "embeddings" in data: return data["embeddings"][0] + if "embedding" in data: + return data["embedding"] else: return data["data"][0]["embedding"]