Feature/cog 1358 local ollama model support for cognee (#555)
This PR adds the Ollama-specific LLM adapter together with the Ollama embedding engine. Tested with the following configuration:

```
LLM_API_KEY="ollama"
LLM_MODEL="llama3.1:8b"
LLM_PROVIDER="ollama"
LLM_ENDPOINT="http://localhost:11434/v1"
EMBEDDING_PROVIDER="ollama"
EMBEDDING_MODEL="avr/sfr-embedding-mistral:latest"
EMBEDDING_ENDPOINT="http://localhost:11434/api/embeddings"
EMBEDDING_DIMENSIONS=4096
HUGGINGFACE_TOKENIZER="Salesforce/SFR-Embedding-Mistral"
```

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

## Summary by CodeRabbit

- **New Features**
  - Introduced a new embedding option that leverages an external provider for asynchronous text processing.
  - Added enhanced language model integration using a dedicated adapter to improve interaction quality.
- **Enhancements**
  - Expanded configuration settings to include a new tokenizer option.
  - Updated provider selection logic to incorporate the additional embedding and language model features.

Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
Co-authored-by: vasilije <vas.markovic@gmail.com>
Parent: e98d51aac9
Commit: 0bcaf5c477

6 changed files with 158 additions and 2 deletions
@@ -0,0 +1,101 @@ (new file: OllamaEmbeddingEngine)

```python
import asyncio
import httpx
import logging
from typing import List, Optional
import os

from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer

logger = logging.getLogger("OllamaEmbeddingEngine")


class OllamaEmbeddingEngine(EmbeddingEngine):
    model: str
    dimensions: int
    max_tokens: int
    endpoint: str
    mock: bool
    huggingface_tokenizer_name: str

    MAX_RETRIES = 5

    def __init__(
        self,
        model: Optional[str] = "avr/sfr-embedding-mistral:latest",
        dimensions: Optional[int] = 1024,
        max_tokens: int = 512,
        endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
        huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
    ):
        self.model = model
        self.dimensions = dimensions
        self.max_tokens = max_tokens
        self.endpoint = endpoint
        self.huggingface_tokenizer_name = huggingface_tokenizer
        self.tokenizer = self.get_tokenizer()

        # MOCK_EMBEDDING lets tests run without a live Ollama server.
        enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
        if isinstance(enable_mocking, bool):
            enable_mocking = str(enable_mocking).lower()
        self.mock = enable_mocking in ("true", "1", "yes")

    async def embed_text(self, text: List[str]) -> List[List[float]]:
        """
        Given a list of text prompts, returns a list of embedding vectors.
        """
        if self.mock:
            return [[0.0] * self.dimensions for _ in text]

        embeddings = []
        async with httpx.AsyncClient() as client:
            for prompt in text:
                embedding = await self._get_embedding(client, prompt)
                embeddings.append(embedding)
        return embeddings

    async def _get_embedding(self, client: httpx.AsyncClient, prompt: str) -> List[float]:
        """
        Internal method to call the Ollama embeddings endpoint for a single prompt.
        """
        payload = {
            "model": self.model,
            "prompt": prompt,
        }
        headers = {}
        api_key = os.getenv("LLM_API_KEY")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        # Retry with exponential backoff (capped at 60 s) before giving up.
        retries = 0
        while retries < self.MAX_RETRIES:
            try:
                response = await client.post(
                    self.endpoint, json=payload, headers=headers, timeout=60.0
                )
                response.raise_for_status()
                data = response.json()
                return data["embedding"]
            except httpx.HTTPStatusError as e:
                logger.error(f"HTTP error on attempt {retries + 1}: {e}")
                retries += 1
                await asyncio.sleep(min(2**retries, 60))
            except Exception as e:
                logger.error(f"Error on attempt {retries + 1}: {e}")
                retries += 1
                await asyncio.sleep(min(2**retries, 60))
        raise EmbeddingException(
            f"Failed to embed text using model {self.model} after {self.MAX_RETRIES} retries"
        )

    def get_vector_size(self) -> int:
        return self.dimensions

    def get_tokenizer(self):
        logger.debug("Loading HuggingfaceTokenizer for OllamaEmbeddingEngine...")
        tokenizer = HuggingFaceTokenizer(
            model=self.huggingface_tokenizer_name, max_tokens=self.max_tokens
        )
        logger.debug("Tokenizer loaded for OllamaEmbeddingEngine")
        return tokenizer
```
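For reviewers who want to exercise the engine in isolation, here is a minimal usage sketch. The import path is an assumption (the diff does not show where the file lives, but the relative import in `get_embedding_engine()` below suggests it sits next to `LiteLLMEmbeddingEngine`); it also assumes a local Ollama server on port 11434 with `avr/sfr-embedding-mistral:latest` pulled and access to the `Salesforce/SFR-Embedding-Mistral` tokenizer on Hugging Face.

```python
import asyncio

# Assumed import path; adjust to wherever the new engine actually lives.
from cognee.infrastructure.databases.vector.embeddings.OllamaEmbeddingEngine import (
    OllamaEmbeddingEngine,
)


async def main():
    # Values mirror the configuration tested in this PR.
    engine = OllamaEmbeddingEngine(
        model="avr/sfr-embedding-mistral:latest",
        dimensions=4096,
        endpoint="http://localhost:11434/api/embeddings",
        huggingface_tokenizer="Salesforce/SFR-Embedding-Mistral",
    )
    vectors = await engine.embed_text(["local embeddings via Ollama", "a second prompt"])
    # One vector per prompt; the length is whatever the model returns
    # (4096 for sfr-embedding-mistral), or `dimensions` zeros when MOCK_EMBEDDING is set.
    print(len(vectors), len(vectors[0]))


if __name__ == "__main__":
    asyncio.run(main())
```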
@@ -11,6 +11,7 @@ class EmbeddingConfig(BaseSettings):

```python
    embedding_api_key: Optional[str] = None
    embedding_api_version: Optional[str] = None
    embedding_max_tokens: Optional[int] = 8191
    huggingface_tokenizer: Optional[str] = None
    model_config = SettingsConfigDict(env_file=".env", extra="allow")
```
@@ -16,10 +16,19 @@ def get_embedding_engine() -> EmbeddingEngine:

```python
            max_tokens=config.embedding_max_tokens,
        )

    if config.embedding_provider == "ollama":
        from .OllamaEmbeddingEngine import OllamaEmbeddingEngine

        return OllamaEmbeddingEngine(
            model=config.embedding_model,
            dimensions=config.embedding_dimensions,
            max_tokens=config.embedding_max_tokens,
            huggingface_tokenizer=config.huggingface_tokenizer,
        )

    from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine

    return LiteLLMEmbeddingEngine(
        # If OpenAI API is used for embeddings, litellm needs only the api_key.
        provider=config.embedding_provider,
        api_key=config.embedding_api_key or llm_config.llm_api_key,
        endpoint=config.embedding_endpoint,
```
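A quick way to check the new selection branch without a running Ollama server is to combine it with the `MOCK_EMBEDDING` flag read by the engine above. A sketch, assuming the factory is importable as shown (the module path is not part of this diff) and noting that the HuggingFace tokenizer is still loaded at construction time:

```python
import os

# These settings are normally read from .env (see SettingsConfigDict above);
# set them before cognee loads its configuration.
os.environ["EMBEDDING_PROVIDER"] = "ollama"
os.environ["EMBEDDING_MODEL"] = "avr/sfr-embedding-mistral:latest"
os.environ["EMBEDDING_ENDPOINT"] = "http://localhost:11434/api/embeddings"
os.environ["EMBEDDING_DIMENSIONS"] = "4096"
os.environ["HUGGINGFACE_TOKENIZER"] = "Salesforce/SFR-Embedding-Mistral"
os.environ["MOCK_EMBEDDING"] = "true"  # zero vectors instead of HTTP calls to Ollama

# Assumed import path for the factory shown in this hunk.
from cognee.infrastructure.databases.vector.embeddings.get_embedding_engine import (
    get_embedding_engine,
)

engine = get_embedding_engine()
print(type(engine).__name__)  # expected: OllamaEmbeddingEngine
```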
@@ -4,6 +4,7 @@ from enum import Enum

```python
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm import get_llm_config
from cognee.infrastructure.llm.ollama.adapter import OllamaAPIAdapter


# Define an Enum for LLM Providers
```

@@ -52,7 +53,7 @@ def get_llm_client():

```diff
         from .generic_llm_api.adapter import GenericAPIAdapter

-        return GenericAPIAdapter(
+        return OllamaAPIAdapter(
             llm_config.llm_endpoint,
             llm_config.llm_api_key,
             llm_config.llm_model,
```
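On the LLM side, the tested settings from the description route `get_llm_client()` to the new `OllamaAPIAdapter` (its implementation appears below). A sketch of that path, assuming `get_llm_client` is importable as shown (the module path is not part of this diff) and that the lowercase variants in the tested configuration map to the same settings via pydantic's case-insensitive env handling:

```python
import os

os.environ["LLM_PROVIDER"] = "ollama"
os.environ["LLM_API_KEY"] = "ollama"
os.environ["LLM_MODEL"] = "llama3.1:8b"
os.environ["LLM_ENDPOINT"] = "http://localhost:11434/v1"

# Assumed import path; adjust to wherever get_llm_client() is defined.
from cognee.infrastructure.llm.get_llm_client import get_llm_client

client = get_llm_client()
print(type(client).__name__)  # expected: OllamaAPIAdapter
```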
New files:

- cognee/infrastructure/llm/ollama/__init__.py (0 lines)
- cognee/infrastructure/llm/ollama/adapter.py (44 lines)
@@ -0,0 +1,44 @@ cognee/infrastructure/llm/ollama/adapter.py (new file)

```python
from typing import Type
from pydantic import BaseModel
import instructor
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config
from openai import OpenAI


class OllamaAPIAdapter(LLMInterface):
    """Adapter for a Generic API LLM provider using instructor with an OpenAI backend."""

    def __init__(self, endpoint: str, api_key: str, model: str, name: str, max_tokens: int):
        self.name = name
        self.model = model
        self.api_key = api_key
        self.endpoint = endpoint
        self.max_tokens = max_tokens

        self.aclient = instructor.from_openai(
            OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
        )

    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Generate a structured output from the LLM using the provided text and system prompt."""

        response = self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": f"Use the given format to extract information from the following input: {text_input}",
                },
                {
                    "role": "system",
                    "content": system_prompt,
                },
            ],
            max_retries=5,
            response_model=response_model,
        )

        return response
```
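Finally, a minimal sketch of calling the adapter directly for structured extraction. `PersonInfo`, the `name="Ollama"` label, and `max_tokens=512` are illustrative values (only the endpoint, API key, and model come from the tested configuration), and a local `llama3.1:8b` must be available in Ollama:

```python
import asyncio

from pydantic import BaseModel

# Import path taken from the new file added in this PR.
from cognee.infrastructure.llm.ollama.adapter import OllamaAPIAdapter


class PersonInfo(BaseModel):
    # Hypothetical response model, used only for this example.
    name: str
    occupation: str


async def main():
    adapter = OllamaAPIAdapter(
        endpoint="http://localhost:11434/v1",
        api_key="ollama",
        model="llama3.1:8b",
        name="Ollama",    # illustrative; not taken from the diff
        max_tokens=512,   # illustrative; not taken from the diff
    )
    result = await adapter.acreate_structured_output(
        text_input="Ada Lovelace was a mathematician and writer.",
        system_prompt="Extract the person's name and occupation.",
        response_model=PersonInfo,
    )
    print(result)


if __name__ == "__main__":
    asyncio.run(main())
```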