cognee/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py
Boris d192d1fe20
chore: remove unused dependencies and make some optional (#661)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin
2025-03-25 10:19:52 +01:00

99 lines
3.5 KiB
Python

import asyncio
import aiohttp
import logging
from typing import List, Optional
import os
import aiohttp.http_exceptions
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
logger = logging.getLogger("OllamaEmbeddingEngine")
class OllamaEmbeddingEngine(EmbeddingEngine):
    """
    Embedding engine that produces vectors via an Ollama-compatible
    ``/api/embeddings`` HTTP endpoint.

    One POST request is issued per prompt; requests for a batch of prompts
    are executed concurrently with ``asyncio.gather``. When the
    ``MOCK_EMBEDDING`` environment variable is truthy ("true"/"1"/"yes",
    case-insensitive), zero vectors are returned without any network I/O.
    """

    model: str
    dimensions: int
    max_tokens: int
    endpoint: str
    mock: bool
    huggingface_tokenizer_name: str

    # Maximum number of attempts per prompt before raising EmbeddingException.
    MAX_RETRIES = 5

    def __init__(
        self,
        model: Optional[str] = "avr/sfr-embedding-mistral:latest",
        dimensions: Optional[int] = 1024,
        max_tokens: int = 512,
        endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
        huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
    ):
        """
        Configure the engine and eagerly load the HuggingFace tokenizer.

        Args:
            model: Ollama model name to embed with.
            dimensions: Size of the embedding vectors the model emits.
            max_tokens: Token budget passed to the tokenizer.
            endpoint: Full URL of the Ollama embeddings endpoint.
            huggingface_tokenizer: HuggingFace model id used for tokenization.
        """
        self.model = model
        self.dimensions = dimensions
        self.max_tokens = max_tokens
        self.endpoint = endpoint
        self.huggingface_tokenizer_name = huggingface_tokenizer
        self.tokenizer = self.get_tokenizer()

        # os.getenv always returns a string (never a bool), so the previous
        # isinstance(..., bool) branch was dead code and its .lower() never
        # ran — making the check case-sensitive by accident. Normalize here
        # so MOCK_EMBEDDING=True / TRUE / Yes also enable mocking.
        self.mock = os.getenv("MOCK_EMBEDDING", "false").strip().lower() in (
            "true",
            "1",
            "yes",
        )

    async def embed_text(self, text: List[str]) -> List[List[float]]:
        """
        Embed a list of text prompts.

        Args:
            text: Prompts to embed; may be empty, in which case an empty
                list is returned.

        Returns:
            One embedding vector (list of floats) per input prompt, in the
            same order as ``text``.
        """
        if self.mock:
            # Deterministic zero vectors keep tests hermetic and offline.
            return [[0.0] * self.dimensions for _ in text]
        return await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])

    async def _get_embedding(self, prompt: str) -> List[float]:
        """
        Call the Ollama embeddings endpoint for a single prompt, retrying
        with exponential backoff (capped at 60s) on any failure.

        Raises:
            EmbeddingException: After MAX_RETRIES failed attempts.
        """
        payload = {
            "model": self.model,
            "prompt": prompt,
        }

        headers = {}
        api_key = os.getenv("LLM_API_KEY")
        if api_key:
            # Optional bearer auth for proxied / hosted Ollama deployments.
            headers["Authorization"] = f"Bearer {api_key}"

        # Per-request total timeout; ClientTimeout is the non-deprecated way
        # to pass a timeout to aiohttp.
        timeout = aiohttp.ClientTimeout(total=60.0)

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.post(
                        self.endpoint, json=payload, headers=headers
                    ) as response:
                        # Surface HTTP error statuses explicitly instead of
                        # failing later with a KeyError on "embedding".
                        response.raise_for_status()
                        data = await response.json()
                        return data["embedding"]
            except Exception as e:
                # The original had two identical handlers (HttpBadRequest and
                # Exception); a single handler with exponential backoff covers
                # both connection errors and bad responses.
                logger.error("Error on attempt %d: %s", attempt, e)
                await asyncio.sleep(min(2**attempt, 60))

        raise EmbeddingException(
            f"Failed to embed text using model {self.model} after {self.MAX_RETRIES} retries"
        )

    def get_vector_size(self) -> int:
        """Return the dimensionality of the vectors this engine produces."""
        return self.dimensions

    def get_tokenizer(self):
        """
        Build and return the HuggingFace tokenizer configured for this engine.
        """
        logger.debug("Loading HuggingfaceTokenizer for OllamaEmbeddingEngine...")
        tokenizer = HuggingFaceTokenizer(
            model=self.huggingface_tokenizer_name, max_tokens=self.max_tokens
        )
        logger.debug("Tokenizer loaded for OllamaEmbeddingEngine")
        return tokenizer