support ollama
This commit is contained in:
parent
61e9da13e4
commit
03b3cf2b4e
4 changed files with 254 additions and 0 deletions
131
graphiti_core/embedder/ollama.py
Normal file
131
graphiti_core/embedder/ollama.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import Field
|
||||
|
||||
from .client import EmbedderClient, EmbedderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_EMBEDDING_MODEL = 'nomic-embed-text'
|
||||
DEFAULT_BASE_URL = 'http://localhost:11434'
|
||||
|
||||
|
||||
class OllamaEmbedderConfig(EmbedderConfig):
|
||||
embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL)
|
||||
base_url: str = Field(default=DEFAULT_BASE_URL)
|
||||
|
||||
|
||||
class OllamaEmbedder(EmbedderClient):
|
||||
"""
|
||||
Ollama Embedder Client
|
||||
|
||||
Uses Ollama's native API endpoint for embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config: OllamaEmbedderConfig | None = None):
|
||||
if config is None:
|
||||
config = OllamaEmbedderConfig()
|
||||
self.config = config
|
||||
self.base_url = config.base_url.rstrip('/')
|
||||
self.embed_url = f"{self.base_url}/api/embed"
|
||||
|
||||
async def create(
|
||||
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||
) -> list[float]:
|
||||
"""
|
||||
Create embeddings for the given input data using Ollama's embedding model.
|
||||
|
||||
Args:
|
||||
input_data: The input data to create embeddings for. Can be a string, list of strings,
|
||||
or an iterable of integers or iterables of integers.
|
||||
|
||||
Returns:
|
||||
A list of floats representing the embedding vector.
|
||||
"""
|
||||
# Convert input to string if needed
|
||||
if isinstance(input_data, str):
|
||||
text_input = input_data
|
||||
elif isinstance(input_data, list) and len(input_data) > 0:
|
||||
if isinstance(input_data[0], str):
|
||||
# For list of strings, take the first one for single embedding
|
||||
text_input = input_data[0]
|
||||
else:
|
||||
# Convert other types to string
|
||||
text_input = str(input_data[0])
|
||||
else:
|
||||
text_input = str(input_data)
|
||||
|
||||
payload = {
|
||||
"model": self.config.embedding_model,
|
||||
"input": text_input
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.embed_url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
raise Exception(f"Ollama API error {response.status_code}: {error_text}")
|
||||
|
||||
result = response.json()
|
||||
|
||||
if "embeddings" not in result:
|
||||
raise Exception(f"No embeddings in response: {result}")
|
||||
|
||||
embeddings = result["embeddings"]
|
||||
if not embeddings or len(embeddings) == 0:
|
||||
raise Exception("Empty embeddings returned")
|
||||
|
||||
# Return the first embedding, truncated to the configured dimension
|
||||
embedding = embeddings[0]
|
||||
return embedding[: self.config.embedding_dim]
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error creating Ollama embedding: {e.response.status_code} - {e.response.text}")
|
||||
raise Exception(f"Ollama API error {e.response.status_code}: {e.response.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating Ollama embedding: {e}")
|
||||
raise
|
||||
|
||||
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
Create batch embeddings using Ollama's embedding model.
|
||||
|
||||
Note: Ollama doesn't support batch embeddings natively, so we process them sequentially.
|
||||
"""
|
||||
embeddings = []
|
||||
|
||||
for text in input_data_list:
|
||||
try:
|
||||
embedding = await self.create(text)
|
||||
embeddings.append(embedding)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating embedding for text '{text[:50]}...': {e}")
|
||||
raise
|
||||
|
||||
return embeddings
|
||||
44
mcp_server/.env.example.gemini_ollama
Normal file
44
mcp_server/.env.example.gemini_ollama
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# Graphiti MCP Server Environment Configuration
|
||||
|
||||
# Neo4j Database Configuration
|
||||
# These settings are used to connect to your Neo4j database
|
||||
NEO4J_URI=bolt://localhost:7687
|
||||
NEO4J_USER=neo4j
|
||||
NEO4J_PASSWORD=demodemo
|
||||
|
||||
# OpenAI API Configuration
|
||||
# Required for LLM operations
|
||||
OPENAI_API_KEY=your_gemini_api_key_here
|
||||
MODEL_NAME=gemini-2.5-flash
|
||||
SMALL_MODEL_NAME=gemini-2.5-flash
|
||||
|
||||
# Optional: Only needed for non-standard OpenAI endpoints
|
||||
OPENAI_BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
||||
|
||||
# Embedder Configuration
|
||||
# Optional: Separate API key and URL for embedder (falls back to OPENAI_API_KEY and OPENAI_BASE_URL if not set)
|
||||
# Note: OpenRouter does not support embeddings API, using Ollama as free alternative
|
||||
EMBEDDER_API_KEY=ollama
|
||||
EMBEDDER_BASE_URL=http://localhost:11434
|
||||
EMBEDDER_MODEL_NAME=nomic-embed-text
|
||||
EMBEDDER_DIMENSION=768
|
||||
|
||||
# Optional: Group ID for namespacing graph data
|
||||
# GROUP_ID=my_project
|
||||
|
||||
# Optional: Path configuration for Docker
|
||||
# PATH=/root/.local/bin:${PATH}
|
||||
|
||||
# Optional: Memory settings for Neo4j (used in Docker Compose)
|
||||
# NEO4J_server_memory_heap_initial__size=512m
|
||||
# NEO4J_server_memory_heap_max__size=1G
|
||||
# NEO4J_server_memory_pagecache_size=512m
|
||||
|
||||
# Azure OpenAI configuration
|
||||
# Optional: Only needed for Azure OpenAI endpoints
|
||||
# AZURE_OPENAI_ENDPOINT=your_azure_openai_endpoint_here
|
||||
# AZURE_OPENAI_API_VERSION=2025-01-01-preview
|
||||
# AZURE_OPENAI_DEPLOYMENT_NAME=gpt-4o-gpt-4o-mini-deployment
|
||||
# AZURE_OPENAI_EMBEDDING_API_VERSION=2023-05-15
|
||||
# AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME=text-embedding-3-large-deployment
|
||||
# AZURE_OPENAI_USE_MANAGED_IDENTITY=false
|
||||
43
mcp_server/.env.example.openrouter_ollama
Normal file
43
mcp_server/.env.example.openrouter_ollama
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
# Graphiti MCP Server Environment Configuration
|
||||
|
||||
# Neo4j Database Configuration
|
||||
# These settings are used to connect to your Neo4j database
|
||||
NEO4J_URI=bolt://localhost:7687
|
||||
NEO4J_USER=neo4j
|
||||
NEO4J_PASSWORD=demodemo
|
||||
|
||||
# OpenAI API Configuration
|
||||
# Required for LLM operations
|
||||
OPENAI_API_KEY=your_open_router_api_key_here
|
||||
MODEL_NAME=gpt-4.1-mini
|
||||
|
||||
# Optional: Only needed for non-standard OpenAI endpoints
|
||||
OPENAI_BASE_URL=https://openrouter.ai/api/v1
|
||||
|
||||
# Embedder Configuration
|
||||
# Optional: Separate API key and URL for embedder (falls back to OPENAI_API_KEY and OPENAI_BASE_URL if not set)
|
||||
# Note: OpenRouter does not support embeddings API, using Ollama as free alternative
|
||||
EMBEDDER_API_KEY=ollama
|
||||
EMBEDDER_BASE_URL=http://localhost:11434
|
||||
EMBEDDER_MODEL_NAME=nomic-embed-text
|
||||
EMBEDDER_DIMENSION=768
|
||||
|
||||
# Optional: Group ID for namespacing graph data
|
||||
# GROUP_ID=my_project
|
||||
|
||||
# Optional: Path configuration for Docker
|
||||
# PATH=/root/.local/bin:${PATH}
|
||||
|
||||
# Optional: Memory settings for Neo4j (used in Docker Compose)
|
||||
# NEO4J_server_memory_heap_initial__size=512m
|
||||
# NEO4J_server_memory_heap_max__size=1G
|
||||
# NEO4J_server_memory_pagecache_size=512m
|
||||
|
||||
# Azure OpenAI configuration
|
||||
# Optional: Only needed for Azure OpenAI endpoints
|
||||
# AZURE_OPENAI_ENDPOINT=your_azure_openai_endpoint_here
|
||||
# AZURE_OPENAI_API_VERSION=2025-01-01-preview
|
||||
# AZURE_OPENAI_DEPLOYMENT_NAME=gpt-4o-gpt-4o-mini-deployment
|
||||
# AZURE_OPENAI_EMBEDDING_API_VERSION=2023-05-15
|
||||
# AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME=text-embedding-3-large-deployment
|
||||
# AZURE_OPENAI_USE_MANAGED_IDENTITY=false
|
||||
|
|
@ -23,6 +23,7 @@ from graphiti_core.edges import EntityEdge
|
|||
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
|
||||
from graphiti_core.embedder.client import EmbedderClient
|
||||
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
|
||||
from graphiti_core.embedder.ollama import OllamaEmbedder, OllamaEmbedderConfig
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
|
||||
from graphiti_core.llm_client.config import LLMConfig
|
||||
|
|
@ -354,6 +355,7 @@ class GraphitiEmbedderConfig(BaseModel):
|
|||
|
||||
model: str = DEFAULT_EMBEDDER_MODEL
|
||||
api_key: str | None = None
|
||||
provider: str = "openai" # "openai", "ollama", or "azure"
|
||||
azure_openai_endpoint: str | None = None
|
||||
azure_openai_deployment_name: str | None = None
|
||||
azure_openai_api_version: str | None = None
|
||||
|
|
@ -367,6 +369,16 @@ class GraphitiEmbedderConfig(BaseModel):
|
|||
model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
|
||||
model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL
|
||||
|
||||
# Get embedder-specific API key and base URL, fallback to general OpenAI settings
|
||||
api_key = os.environ.get('EMBEDDER_API_KEY') or os.environ.get('OPENAI_API_KEY')
|
||||
|
||||
# Detect provider based on configuration
|
||||
provider = "openai" # default
|
||||
if api_key and api_key.lower() == "ollama":
|
||||
provider = "ollama"
|
||||
|
||||
logger.info(f'GraphitiEmbedderConfig provider: {provider}')
|
||||
|
||||
azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
|
||||
azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
|
||||
azure_openai_deployment_name = os.environ.get(
|
||||
|
|
@ -408,6 +420,7 @@ class GraphitiEmbedderConfig(BaseModel):
|
|||
return cls(
|
||||
model=model,
|
||||
api_key=os.environ.get('OPENAI_API_KEY'),
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
def create_client(self) -> EmbedderClient | None:
|
||||
|
|
@ -439,6 +452,29 @@ class GraphitiEmbedderConfig(BaseModel):
|
|||
else:
|
||||
logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
|
||||
return None
|
||||
elif self.provider == "ollama":
|
||||
|
||||
base_url_env = os.environ.get('EMBEDDER_BASE_URL')
|
||||
base_url = base_url_env if base_url_env else 'http://localhost:11434'
|
||||
|
||||
model_env = os.environ.get('EMBEDDER_MODEL_NAME')
|
||||
model = model_env if model_env else 'nomic-embed-text'
|
||||
|
||||
# Get embedding dimension from environment
|
||||
embedding_dim_env = os.environ.get('EMBEDDER_DIMENSION')
|
||||
embedding_dim = int(embedding_dim_env) if embedding_dim_env else 768
|
||||
|
||||
logger.info(f'ollama model: {model}')
|
||||
logger.info(f'ollama base_url: {base_url}')
|
||||
logger.info(f'ollama embedding_dim: {embedding_dim}')
|
||||
|
||||
# Ollama API setup
|
||||
ollama_config = OllamaEmbedderConfig(
|
||||
embedding_model=model,
|
||||
base_url=base_url,
|
||||
embedding_dim=embedding_dim # nomic-embed-text default
|
||||
)
|
||||
return OllamaEmbedder(config=ollama_config)
|
||||
else:
|
||||
# OpenAI API setup
|
||||
if not self.api_key:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue