Feature/cog 1358 local ollama model support for cognee (#555)

This PR adds the Ollama-specific LLM adapter together with the Ollama embedding engine.

Tested with the following models and settings:

```env
LLM_API_KEY="ollama"
LLM_MODEL="llama3.1:8b"
LLM_PROVIDER="ollama"
LLM_ENDPOINT="http://localhost:11434/v1"

EMBEDDING_PROVIDER="ollama"
EMBEDDING_MODEL="avr/sfr-embedding-mistral:latest"
EMBEDDING_ENDPOINT="http://localhost:11434/api/embeddings"
EMBEDDING_DIMENSIONS=4096
HUGGINGFACE_TOKENIZER="Salesforce/SFR-Embedding-Mistral"
```
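
To try it locally, here is a minimal smoke-test sketch (not part of the diff; it assumes cognee's top-level `add`/`cognify` entry points and a local Ollama server with both models pulled):

```python
# Hypothetical smoke test, not part of this PR. Assumes `ollama serve`
# is running and both models were pulled first:
#   ollama pull llama3.1:8b
#   ollama pull avr/sfr-embedding-mistral:latest
import os

# Settings must be in the environment (or an .env file) before cognee loads them.
os.environ["LLM_PROVIDER"] = "ollama"
os.environ["LLM_API_KEY"] = "ollama"
os.environ["LLM_MODEL"] = "llama3.1:8b"
os.environ["LLM_ENDPOINT"] = "http://localhost:11434/v1"
os.environ["EMBEDDING_PROVIDER"] = "ollama"
os.environ["EMBEDDING_MODEL"] = "avr/sfr-embedding-mistral:latest"
os.environ["EMBEDDING_ENDPOINT"] = "http://localhost:11434/api/embeddings"
os.environ["EMBEDDING_DIMENSIONS"] = "4096"
os.environ["HUGGINGFACE_TOKENIZER"] = "Salesforce/SFR-Embedding-Mistral"

import asyncio

import cognee


async def main():
    await cognee.add("Ollama runs large language models locally.")
    await cognee.cognify()  # builds the knowledge graph using the local models


asyncio.run(main())
```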

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.



## Summary by CodeRabbit

- **New Features**
  - Introduced a new embedding option that leverages an external provider for asynchronous text processing.
  - Added enhanced language model integration using a dedicated adapter to improve interaction quality.

- **Enhancements**
  - Expanded configuration settings to include a new tokenizer option.
  - Updated provider selection logic to incorporate the additional embedding and language model features.


---------

Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
Co-authored-by: vasilije <vas.markovic@gmail.com>
Commit 0bcaf5c477 (parent e98d51aac9), authored by hajdul88 on 2025-02-19 02:54:04 +01:00, committed by GitHub.
6 changed files with 158 additions and 2 deletions

**OllamaEmbeddingEngine.py** (new file)

@@ -0,0 +1,101 @@

```python
import asyncio
import logging
import os
from typing import List, Optional

import httpx

from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer

logger = logging.getLogger("OllamaEmbeddingEngine")


class OllamaEmbeddingEngine(EmbeddingEngine):
    model: str
    dimensions: int
    max_tokens: int
    endpoint: str
    mock: bool
    huggingface_tokenizer_name: str

    MAX_RETRIES = 5

    def __init__(
        self,
        model: Optional[str] = "avr/sfr-embedding-mistral:latest",
        dimensions: Optional[int] = 1024,
        max_tokens: int = 512,
        endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
        huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
    ):
        self.model = model
        self.dimensions = dimensions
        self.max_tokens = max_tokens
        self.endpoint = endpoint
        self.huggingface_tokenizer_name = huggingface_tokenizer
        self.tokenizer = self.get_tokenizer()

        enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
        if isinstance(enable_mocking, bool):
            enable_mocking = str(enable_mocking).lower()
        self.mock = enable_mocking in ("true", "1", "yes")

    async def embed_text(self, text: List[str]) -> List[List[float]]:
        """
        Given a list of text prompts, returns a list of embedding vectors.
        """
        if self.mock:
            return [[0.0] * self.dimensions for _ in text]

        embeddings = []
        async with httpx.AsyncClient() as client:
            for prompt in text:
                embedding = await self._get_embedding(client, prompt)
                embeddings.append(embedding)
        return embeddings

    async def _get_embedding(self, client: httpx.AsyncClient, prompt: str) -> List[float]:
        """
        Internal method to call the Ollama embeddings endpoint for a single prompt.
        """
        payload = {
            "model": self.model,
            "prompt": prompt,
        }

        headers = {}
        api_key = os.getenv("LLM_API_KEY")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        retries = 0
        while retries < self.MAX_RETRIES:
            try:
                response = await client.post(
                    self.endpoint, json=payload, headers=headers, timeout=60.0
                )
                response.raise_for_status()
                data = response.json()
                return data["embedding"]
            except httpx.HTTPStatusError as e:
                logger.error(f"HTTP error on attempt {retries + 1}: {e}")
                retries += 1
                await asyncio.sleep(min(2**retries, 60))
            except Exception as e:
                logger.error(f"Error on attempt {retries + 1}: {e}")
                retries += 1
                await asyncio.sleep(min(2**retries, 60))

        raise EmbeddingException(
            f"Failed to embed text using model {self.model} after {self.MAX_RETRIES} retries"
        )

    def get_vector_size(self) -> int:
        return self.dimensions

    def get_tokenizer(self):
        logger.debug("Loading HuggingfaceTokenizer for OllamaEmbeddingEngine...")
        tokenizer = HuggingFaceTokenizer(
            model=self.huggingface_tokenizer_name, max_tokens=self.max_tokens
        )
        logger.debug("Tokenizer loaded for OllamaEmbeddingEngine")
        return tokenizer
```
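
For reference, a hedged sketch of exercising the new engine directly (not part of the diff; the import path is inferred from the relative import in `get_embedding_engine` shown below, and a local Ollama server with the embedding model pulled is required):

```python
# Hypothetical usage, not part of this PR. Import path assumed; requires
# a local Ollama server with avr/sfr-embedding-mistral pulled. If the
# MOCK_EMBEDDING env var is truthy, zero vectors are returned instead.
import asyncio

from cognee.infrastructure.databases.vector.embeddings.OllamaEmbeddingEngine import (
    OllamaEmbeddingEngine,
)


async def main():
    engine = OllamaEmbeddingEngine(dimensions=4096)  # SFR-Embedding-Mistral is 4096-dim
    vectors = await engine.embed_text(["hello world", "local embeddings"])
    print(len(vectors), len(vectors[0]))  # expected: 2 4096


asyncio.run(main())
```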


```diff
@@ -11,6 +11,7 @@ class EmbeddingConfig(BaseSettings):
     embedding_api_key: Optional[str] = None
     embedding_api_version: Optional[str] = None
     embedding_max_tokens: Optional[int] = 8191
+    huggingface_tokenizer: Optional[str] = None
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
```
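
The new `huggingface_tokenizer` field is filled from the environment by pydantic-settings, which matches env var names to field names case-insensitively. A self-contained illustration of that behavior (not code from this PR):

```python
# Self-contained sketch of the pydantic-settings behavior relied on here:
# the HUGGINGFACE_TOKENIZER env var populates the new optional field.
import os
from typing import Optional

from pydantic_settings import BaseSettings, SettingsConfigDict


class EmbeddingConfig(BaseSettings):
    huggingface_tokenizer: Optional[str] = None
    model_config = SettingsConfigDict(env_file=".env", extra="allow")


os.environ["HUGGINGFACE_TOKENIZER"] = "Salesforce/SFR-Embedding-Mistral"
print(EmbeddingConfig().huggingface_tokenizer)
# -> Salesforce/SFR-Embedding-Mistral
```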


```diff
@@ -16,10 +16,19 @@ def get_embedding_engine() -> EmbeddingEngine:
             max_tokens=config.embedding_max_tokens,
         )
 
+    if config.embedding_provider == "ollama":
+        from .OllamaEmbeddingEngine import OllamaEmbeddingEngine
+
+        return OllamaEmbeddingEngine(
+            model=config.embedding_model,
+            dimensions=config.embedding_dimensions,
+            max_tokens=config.embedding_max_tokens,
+            huggingface_tokenizer=config.huggingface_tokenizer,
+        )
+
     from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
 
     return LiteLLMEmbeddingEngine(
         # If OpenAI API is used for embeddings, litellm needs only the api_key.
         provider=config.embedding_provider,
         api_key=config.embedding_api_key or llm_config.llm_api_key,
         endpoint=config.embedding_endpoint,
```
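
With `EMBEDDING_PROVIDER=ollama` set, the factory now returns the new engine instead of LiteLLM. A rough check (illustrative only; the factory's module path is assumed):

```python
# Rough provider-selection check, not part of this PR. The env vars must
# be set before the embedding config is first loaded.
import os

os.environ["EMBEDDING_PROVIDER"] = "ollama"
os.environ["EMBEDDING_MODEL"] = "avr/sfr-embedding-mistral:latest"
os.environ["EMBEDDING_DIMENSIONS"] = "4096"
os.environ["HUGGINGFACE_TOKENIZER"] = "Salesforce/SFR-Embedding-Mistral"

from cognee.infrastructure.databases.vector.embeddings.get_embedding_engine import (
    get_embedding_engine,  # module path assumed
)

engine = get_embedding_engine()
print(type(engine).__name__)  # expected: OllamaEmbeddingEngine
```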


```diff
@@ -4,6 +4,7 @@ from enum import Enum
 
 from cognee.exceptions import InvalidValueError
 from cognee.infrastructure.llm import get_llm_config
+from cognee.infrastructure.llm.ollama.adapter import OllamaAPIAdapter
 
 # Define an Enum for LLM Providers
@@ -52,7 +53,7 @@ def get_llm_client():
         from .generic_llm_api.adapter import GenericAPIAdapter
 
-        return GenericAPIAdapter(
+        return OllamaAPIAdapter(
             llm_config.llm_endpoint,
             llm_config.llm_api_key,
             llm_config.llm_model,
```
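
On the LLM side, setting `LLM_PROVIDER=ollama` now routes `get_llm_client()` to the new adapter. A rough check (illustrative only; the import path is assumed):

```python
# Rough routing check, not part of this PR. Env vars must be set before
# the LLM config is first loaded.
import os

os.environ["LLM_PROVIDER"] = "ollama"
os.environ["LLM_API_KEY"] = "ollama"
os.environ["LLM_MODEL"] = "llama3.1:8b"
os.environ["LLM_ENDPOINT"] = "http://localhost:11434/v1"

from cognee.infrastructure.llm.get_llm_client import get_llm_client  # path assumed

client = get_llm_client()
print(type(client).__name__)  # expected: OllamaAPIAdapter
```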

**cognee/infrastructure/llm/ollama/adapter.py** (new file)

@@ -0,0 +1,44 @@

```python
from typing import Type

import instructor
from openai import OpenAI
from pydantic import BaseModel

from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config


class OllamaAPIAdapter(LLMInterface):
    """Adapter for the Ollama LLM provider, using instructor over Ollama's OpenAI-compatible API."""

    def __init__(self, endpoint: str, api_key: str, model: str, name: str, max_tokens: int):
        self.name = name
        self.model = model
        self.api_key = api_key
        self.endpoint = endpoint
        self.max_tokens = max_tokens

        # Instructor wraps the OpenAI client so completions are parsed
        # straight into Pydantic models; JSON mode suits local models.
        self.aclient = instructor.from_openai(
            OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
        )

    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Generate a structured output from the LLM using the provided text and system prompt."""
        response = self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": f"Use the given format to extract information from the following input: {text_input}",
                },
                {
                    "role": "system",
                    "content": system_prompt,
                },
            ],
            max_retries=5,
            response_model=response_model,
        )

        return response
```
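
Finally, a hedged end-to-end sketch of using the adapter directly (not part of the diff; assumes a running Ollama server and a model that handles instructor's JSON mode):

```python
# Hypothetical usage sketch, not part of this PR. Requires `ollama serve`
# with llama3.1:8b pulled; the Person model is illustrative only.
import asyncio

from pydantic import BaseModel

from cognee.infrastructure.llm.ollama.adapter import OllamaAPIAdapter


class Person(BaseModel):
    name: str
    age: int


async def main():
    adapter = OllamaAPIAdapter(
        endpoint="http://localhost:11434/v1",
        api_key="ollama",
        model="llama3.1:8b",
        name="ollama",
        max_tokens=512,
    )
    person = await adapter.acreate_structured_output(
        text_input="Ada is 36 years old.",
        system_prompt="Extract the person described in the input.",
        response_model=Person,
    )
    print(person)  # expected: name='Ada' age=36


asyncio.run(main())
```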