From a47cfe7272f228349e0e2d1d5639ebbbd6d1e663 Mon Sep 17 00:00:00 2001 From: vasilije Date: Sun, 28 Dec 2025 21:37:22 +0100 Subject: [PATCH 1/3] added dimensions in a simple way --- .../embeddings/LiteLLMEmbeddingEngine.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index 12de57617..3b8d356fd 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -111,13 +111,18 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): return [data["embedding"] for data in response["data"]] else: async with embedding_rate_limiter_context_manager(): - response = await litellm.aembedding( - model=self.model, - input=text, - api_key=self.api_key, - api_base=self.endpoint, - api_version=self.api_version, - ) + embedding_kwargs = { + "model": self.model, + "input": text, + "api_key": self.api_key, + "api_base": self.endpoint, + "api_version": self.api_version, + } + # Pass through target embedding dimensions when supported + if self.dimensions is not None: + embedding_kwargs["dimensions"] = self.dimensions + + response = await litellm.aembedding(**embedding_kwargs) return [data["embedding"] for data in response.data] From 3a6bb778e2a386ca2be4ffc85f437086e786f123 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 20 Jan 2026 20:50:59 +0100 Subject: [PATCH 2/3] refactor: Handle embedding case in data --- .../databases/vector/embeddings/OllamaEmbeddingEngine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py index 1d5e7fbfe..74d33a76c 100644 --- a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py @@ -124,6 +124,8 @@ class OllamaEmbeddingEngine(EmbeddingEngine): data = await response.json() if "embeddings" in data: return data["embeddings"][0] + if "embedding" in data: + return data["embedding"] else: return data["data"][0]["embedding"] From 9d73b493c8cd26251c1eb4407b5ad49a32baa8b7 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 20 Jan 2026 21:29:29 +0100 Subject: [PATCH 3/3] refactor: add support for Ollama embedding size definition --- .../vector/embeddings/OllamaEmbeddingEngine.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py index 74d33a76c..6afd056de 100644 --- a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py @@ -57,7 +57,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine): model: Optional[str] = "avr/sfr-embedding-mistral:latest", dimensions: Optional[int] = 1024, max_completion_tokens: int = 512, - endpoint: Optional[str] = "http://localhost:11434/api/embeddings", + endpoint: Optional[str] = "http://localhost:11434/api/embed", huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral", batch_size: int = 100, ): @@ -93,6 +93,10 @@ class OllamaEmbeddingEngine(EmbeddingEngine): if self.mock: return [[0.0] * self.dimensions for _ in text] + # Handle case when a single string is passed instead of a list + if not isinstance(text, list): + text = [text] + embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text]) return embeddings @@ -107,7 +111,12 @@ class OllamaEmbeddingEngine(EmbeddingEngine): """ Internal method to call the Ollama embeddings endpoint for a single prompt. """ - payload = {"model": self.model, "prompt": prompt, "input": prompt} + payload = { + "model": self.model, + "prompt": prompt, + "input": prompt, + "dimensions": self.dimensions, + } headers = {} api_key = os.getenv("LLM_API_KEY")