add embedder/ollama and llm_client/ollama_client
commit c64c128f06 (parent ce1ae30569)
7 changed files with 423 additions and 0 deletions
@@ -1,5 +1,7 @@
OPENAI_API_KEY=

OLLAMA_HOST=http://localhost:11434

# Neo4j database connection
NEO4J_URI=
NEO4J_PORT=
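The new OLLAMA_HOST entry is only a convention in this env file; nothing in the diff reads it automatically. A minimal sketch of how it could be passed into the classes added later in this commit (the wiring below is illustrative and assumes LLMConfig exposes base_url and model, which the new client reads from its config):

import os

from graphiti_core.embedder.ollama import OllamaEmbedder, OllamaEmbedderConfig
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.ollama_client import OllamaClient

# Fall back to the same default host the new modules use when the variable is unset.
host = os.environ.get('OLLAMA_HOST', 'http://localhost:11434')

embedder = OllamaEmbedder(config=OllamaEmbedderConfig(base_url=host))
llm_client = OllamaClient(config=LLMConfig(base_url=host, model='qwen3:4b'))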
@@ -196,6 +196,9 @@ pip install graphiti-core[groq]
# Install with Google Gemini support
pip install graphiti-core[google-genai]

# Install with Ollama support
pip install graphiti-core[ollama]

# Install with multiple providers
pip install graphiti-core[anthropic,groq,google-genai]
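Once the extra is installed, the new classes plug into Graphiti like any other provider. A rough sketch, assuming the usual Graphiti constructor signature with llm_client and embedder keyword arguments (that constructor is not shown in this diff):

from graphiti_core import Graphiti
from graphiti_core.embedder.ollama import OllamaEmbedder, OllamaEmbedderConfig
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.ollama_client import OllamaClient

graphiti = Graphiti(
    'bolt://localhost:7687',
    'neo4j',
    'password',
    llm_client=OllamaClient(config=LLMConfig(model='qwen3:4b')),
    embedder=OllamaEmbedder(config=OllamaEmbedderConfig(embedding_model='bge-m3:567m')),
)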
graphiti_core/embedder/ollama.py (new file, 172 lines)
@@ -0,0 +1,172 @@
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import ollama
|
||||
from ollama import AsyncClient
|
||||
else:
|
||||
try:
|
||||
import ollama
|
||||
from ollama import AsyncClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'ollama is required for OllamaEmbedder. '
|
||||
'Install it with: pip install graphiti-core[ollama]'
|
||||
) from None
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .client import EmbedderClient, EmbedderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_BASE_URL = "http://localhost:11434"
|
||||
DEFAULT_EMBEDDING_MODEL = 'bge-m3:567m'
|
||||
DEFAULT_BATCH_SIZE = 100
|
||||
|
||||
class OllamaEmbedderConfig(EmbedderConfig):
|
||||
embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL)
|
||||
api_key: str | None = None
|
||||
base_url: str = Field(default=DEFAULT_BASE_URL)
|
||||
|
||||
|
||||
class OllamaEmbedder(EmbedderClient):
|
||||
"""
|
||||
Ollama Embedder Client
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
config: OllamaEmbedderConfig | None = None,
|
||||
client: AsyncClient | None = None,
|
||||
batch_size: int | None = None,
|
||||
):
|
||||
if config is None:
|
||||
config = OllamaEmbedderConfig()
|
||||
|
||||
self.config = config
|
||||
|
||||
if client is None:
|
||||
# AsyncClient doesn't necessarily accept api_key; pass host via headers if needed
|
||||
try:
|
||||
self.client = AsyncClient(api_key=config.api_key, host=config.base_url)
|
||||
except TypeError:
|
||||
self.client = AsyncClient()
|
||||
else:
|
||||
self.client = client
|
||||
|
||||
if batch_size is None:
|
||||
self.batch_size = DEFAULT_BATCH_SIZE
|
||||
else:
|
||||
self.batch_size = batch_size
|
||||
|
||||
async def create(self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]) -> list[float]:
|
||||
"""Create a single embedding for the input using Ollama.
|
||||
|
||||
Ollama's embed endpoint accepts either a single string or list of strings.
|
||||
We normalize to a single-item list and return the first embedding vector.
|
||||
"""
|
||||
# Ollama's embed returns an object with 'embedding' or similar fields
|
||||
try:
|
||||
# Support call with client.embed for async client
|
||||
result = await self.client.embed(model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, input=input_data) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
logger.error(f'Ollama embed error: {e}')
|
||||
raise
|
||||
|
||||
# Extract embedding and coerce to list[float]
|
||||
values: list[float] | None = None
|
||||
|
||||
if hasattr(result, 'embedding'):
|
||||
cand = getattr(result, 'embedding')
|
||||
if isinstance(cand, (list, tuple)):
|
||||
values = list(cand) # type: ignore
|
||||
elif isinstance(result, dict):
|
||||
if 'embedding' in result and isinstance(result['embedding'], (list, tuple)):
|
||||
values = list(result['embedding']) # type: ignore
|
||||
elif 'embeddings' in result and isinstance(result['embeddings'], list) and len(result['embeddings']) > 0:
|
||||
first = result['embeddings'][0]
|
||||
if isinstance(first, dict) and 'embedding' in first and isinstance(first['embedding'], (list, tuple)):
|
||||
values = list(first['embedding'])
|
||||
elif isinstance(first, (list, tuple)):
|
||||
values = list(first)
|
||||
|
||||
# If result itself is a list (some clients return list for single input)
|
||||
if values is None and isinstance(result, (list, tuple)):
|
||||
# assume it's already the embedding vector
|
||||
values = list(result) # type: ignore
|
||||
|
||||
if not values:
|
||||
raise ValueError('No embeddings returned from Ollama API in create()')
|
||||
|
||||
return values
|
||||
|
||||
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
|
||||
if not input_data_list:
|
||||
return []
|
||||
|
||||
all_embeddings: list[list[float]] = []
|
||||
|
||||
for i in range(0, len(input_data_list), self.batch_size):
|
||||
batch = input_data_list[i : i + self.batch_size]
|
||||
try:
|
||||
result = await self.client.embed(model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, input=batch)
|
||||
|
||||
# result may be dict with 'embeddings' list or single 'embedding'
|
||||
if isinstance(result, dict) and 'embeddings' in result:
|
||||
for emb in result['embeddings']:
|
||||
if isinstance(emb, dict) and 'embedding' in emb and isinstance(emb['embedding'], (list, tuple)):
|
||||
all_embeddings.append(list(emb['embedding']))
|
||||
elif isinstance(emb, (list, tuple)):
|
||||
all_embeddings.append(list(emb))
|
||||
else:
|
||||
# unexpected shape
|
||||
raise ValueError('Unexpected embedding shape in batch result')
|
||||
else:
|
||||
# Fallback: maybe result itself is a list of vectors
|
||||
if isinstance(result, list):
|
||||
all_embeddings.extend(result)
|
||||
else:
|
||||
# Single embedding returned for whole batch; if so, duplicate per item
|
||||
embedding = None
|
||||
if isinstance(result, dict) and 'embedding' in result:
|
||||
embedding = result['embedding']
|
||||
if embedding is None:
|
||||
raise ValueError('No embeddings returned')
|
||||
for _ in batch:
|
||||
all_embeddings.append(embedding)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'Batch embedding failed for batch {i // self.batch_size + 1}, falling back to individual processing: {e}')
|
||||
for item in batch:
|
||||
try:
|
||||
single = await self.client.embed(model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, input=item)
|
||||
emb = None
|
||||
if isinstance(single, dict) and 'embedding' in single:
|
||||
emb = single['embedding']
|
||||
elif isinstance(single, (list, tuple)):
|
||||
emb = single[0] if single else None # type: ignore
|
||||
if not emb:
|
||||
raise ValueError('No embeddings returned from Ollama API')
|
||||
all_embeddings.append(emb)
|
||||
except Exception as individual_error:
|
||||
logger.error(f'Failed to embed individual item: {individual_error}')
|
||||
raise individual_error
|
||||
|
||||
return all_embeddings
|
||||
|
|
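For reference, a short usage sketch of the embedder above (not part of the commit; it assumes a local Ollama server with the default bge-m3:567m model already pulled):

import asyncio

from graphiti_core.embedder.ollama import OllamaEmbedder, OllamaEmbedderConfig


async def main() -> None:
    embedder = OllamaEmbedder(config=OllamaEmbedderConfig(base_url='http://localhost:11434'))

    # Single input returns one vector; batches are chunked by DEFAULT_BATCH_SIZE internally.
    vector = await embedder.create('graph databases store relationships as first-class data')
    vectors = await embedder.create_batch(['first document', 'second document'])
    print(len(vector), len(vectors))


asyncio.run(main())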
@@ -271,6 +271,8 @@ class Graphiti:
            return 'gemini'
        elif 'groq' in class_name:
            return 'groq'
        elif 'ollama' in class_name:
            return 'ollama'
        # Database providers
        elif 'neo4j' in class_name:
            return 'neo4j'
graphiti_core/llm_client/ollama_client.py (new file, 145 lines)
@@ -0,0 +1,145 @@
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import typing
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ollama import AsyncClient
|
||||
else:
|
||||
try:
|
||||
from ollama import AsyncClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'ollama is required for OllamaClient. Install it with: pip install graphiti-core[ollama]'
|
||||
) from None
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..prompts.models import Message
|
||||
from .client import LLMClient
|
||||
from .config import LLMConfig, ModelSize
|
||||
from .errors import RateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MODEL = 'qwen3:4b'
|
||||
DEFAULT_MAX_TOKENS = 8192
|
||||
|
||||
|
||||
class OllamaClient(LLMClient):
|
||||
"""Ollama async client wrapper for Graphiti.
|
||||
|
||||
This client expects the `ollama` python package to be installed. It uses the
|
||||
AsyncClient.chat(...) API to generate chat responses. The response content
|
||||
is expected to be JSON which will be parsed and returned as a dict.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any | None = None):
|
||||
if config is None:
|
||||
config = LLMConfig(max_tokens=DEFAULT_MAX_TOKENS)
|
||||
elif config.max_tokens is None:
|
||||
config.max_tokens = DEFAULT_MAX_TOKENS
|
||||
super().__init__(config, cache)
|
||||
|
||||
# Allow injecting a preconfigured AsyncClient for testing
|
||||
if client is None:
|
||||
# AsyncClient accepts host and other httpx args; pass api_key/base_url when available
|
||||
try:
|
||||
self.client = AsyncClient(api_key=config.api_key, host=config.base_url)
|
||||
except TypeError:
|
||||
# Fallback if AsyncClient signature differs
|
||||
self.client = AsyncClient()
|
||||
else:
|
||||
self.client = client
|
||||
|
||||
async def _generate_response(
|
||||
self,
|
||||
messages: list[Message],
|
||||
response_model: type[BaseModel] | None = None,
|
||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||
model_size: ModelSize = ModelSize.medium,
|
||||
) -> dict[str, typing.Any]:
|
||||
msgs: list[dict[str, str]] = []
|
||||
for m in messages:
|
||||
if m.role == 'user':
|
||||
msgs.append({'role': 'user', 'content': m.content})
|
||||
elif m.role == 'system':
|
||||
msgs.append({'role': 'system', 'content': m.content})
|
||||
|
||||
try:
|
||||
# Prepare options
|
||||
options: dict[str, typing.Any] = {}
|
||||
if max_tokens is not None:
|
||||
options['max_tokens'] = max_tokens
|
||||
if self.temperature is not None:
|
||||
options['temperature'] = self.temperature
|
||||
|
||||
# If a response_model is provided, try to get its JSON schema for format
|
||||
schema = None
|
||||
if response_model is not None:
|
||||
try:
|
||||
schema = response_model.model_json_schema()
|
||||
except Exception:
|
||||
schema = None
|
||||
response = await self.client.chat(
|
||||
model=self.model or DEFAULT_MODEL,
|
||||
messages=msgs,
|
||||
stream=False,
|
||||
format=schema,
|
||||
options=options,
|
||||
)
|
||||
|
||||
# Extract content
|
||||
content: str | None = None
|
||||
if isinstance(response, dict) and 'message' in response and isinstance(response['message'], dict):
|
||||
content = response['message'].get('content')
|
||||
elif hasattr(response, 'message') and getattr(response, 'message') is not None:
|
||||
msg = getattr(response, 'message')
|
||||
if isinstance(msg, dict):
|
||||
content = msg.get('content')
|
||||
else:
|
||||
content = getattr(msg, 'content', None)
|
||||
|
||||
if content is None:
|
||||
# fallback to string
|
||||
content = str(response)
|
||||
|
||||
# If structured response requested, validate with pydantic model
|
||||
if response_model is not None:
|
||||
# Use pydantic v2 model validate json method
|
||||
try:
|
||||
validated = response_model.model_validate_json(content)
|
||||
# return model as dict
|
||||
return validated.model_dump() # type: ignore[attr-defined]
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to validate response with response_model: {e}')
|
||||
# fallthrough to try json loads
|
||||
|
||||
# Try parse JSON otherwise
|
||||
try:
|
||||
return json.loads(content)
|
||||
except Exception:
|
||||
return {'text': content}
|
||||
except Exception as e:
|
||||
# map obvious ollama rate limit / response errors to RateLimitError when possible
|
||||
err_name = e.__class__.__name__
|
||||
status_code = getattr(e, 'status_code', None) or getattr(e, 'status', None)
|
||||
if err_name in ('RequestError', 'ResponseError') and status_code == 429:
|
||||
raise RateLimitError from e
|
||||
logger.error(f'Error in generating LLM response (ollama): {e}')
|
||||
raise
|
||||
|
|
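A quick usage sketch of the client above; the integration tests added later in this commit exercise the same public generate_response path, which is inherited from LLMClient and delegates to _generate_response. It assumes an Ollama server running locally with the default qwen3:4b model available:

import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.ollama_client import OllamaClient
from graphiti_core.prompts.models import Message


class Summary(BaseModel):
    text: str


async def main() -> None:
    client = OllamaClient()
    messages = [Message(role='user', content='Summarize: Graphiti builds temporal knowledge graphs.')]

    # With a response_model, the reply is validated against the schema and returned as a dict;
    # without one, the parsed JSON (or {'text': ...}) comes back.
    result = await client.generate_response(messages, response_model=Summary)
    print(result['text'])


asyncio.run(main())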
@@ -34,6 +34,7 @@ falkordb = ["falkordb>=1.1.2,<2.0.0"]
voyageai = ["voyageai>=0.2.3"]
sentence-transformers = ["sentence-transformers>=3.2.1"]
neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
ollama = ["ollama>=0.5.3"]
dev = [
    "pyright>=1.1.404",
    "groq>=0.2.0",
@@ -51,6 +52,7 @@ dev = [
    "sentence-transformers>=3.2.1",
    "transformers>=4.45.2",
    "voyageai>=0.2.3",
    "ollama>=0.5.3",
    "pytest>=8.3.3",
    "pytest-asyncio>=0.24.0",
    "pytest-xdist>=3.6.1",
tests/llm_client/test_ollama_client.py (new file, 97 lines)
@@ -0,0 +1,97 @@
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# Running tests: pytest -xvs tests/llm_client/test_ollama_client.py
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core.llm_client.ollama_client import OllamaClient
|
||||
from graphiti_core.prompts.models import Message
|
||||
|
||||
|
||||
# Skip tests if no Ollama API/key or local server available
|
||||
# pytestmark = pytest.mark.skipif(
|
||||
# 'OLLAMA_HOST' not in os.environ,
|
||||
# reason='Ollama API/host not available',
|
||||
# )
|
||||
|
||||
|
||||
# Rename to avoid pytest collection as a test class
|
||||
class SimpleResponseModel(BaseModel):
|
||||
message: str = Field(..., description='A message from the model')
|
||||
|
||||
|
||||
class Pet(BaseModel):
|
||||
name: str
|
||||
animal: str
|
||||
age: int
|
||||
color: str | None
|
||||
favorite_toy: str | None
|
||||
|
||||
|
||||
class PetList(BaseModel):
|
||||
pets: list[Pet]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_generate_simple_response():
|
||||
client = OllamaClient()
|
||||
|
||||
messages = [
|
||||
Message(
|
||||
role='user',
|
||||
content="Respond with a JSON object containing a 'message' field with value 'Hello, world!'",
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
response = await client.generate_response(messages, response_model=SimpleResponseModel)
|
||||
assert isinstance(response, dict)
|
||||
assert 'message' in response
|
||||
assert response['message'] == 'Hello, world!'
|
||||
except Exception as e:
|
||||
pytest.skip(f'Test skipped due to Ollama API error: {str(e)}')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integration
|
||||
async def test_structured_output_with_pydantic():
|
||||
client = OllamaClient()
|
||||
|
||||
messages = [
|
||||
Message(
|
||||
role='user',
|
||||
content='''
|
||||
I have two pets.
|
||||
A cat named Luna who is 5 years old and loves playing with yarn. She has grey fur.
|
||||
I also have a 2 year old black cat named Loki who loves tennis balls.
|
||||
''',
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
response = await client.generate_response(messages, response_model=PetList)
|
||||
assert isinstance(response, dict)
|
||||
assert 'pets' in response
|
||||
assert isinstance(response['pets'], list)
|
||||
assert len(response['pets']) >= 1
|
||||
except Exception as e:
|
||||
pytest.skip(f'Test skipped due to Ollama API error: {str(e)}')
|
||||
|
||||