Merge 7762fdef0c into 5e593dd096
commit ac67fd1c3f
7 changed files with 428 additions and 0 deletions
@@ -1,5 +1,7 @@
 OPENAI_API_KEY=
+
+OLLAMA_HOST=http://localhost:11434
 # Neo4j database connection
 NEO4J_URI=
 NEO4J_PORT=
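OLLAMA_HOST here is only an example environment entry; none of the files changed in this commit read it directly, although the ollama Python client usually falls back to it when no host is passed. A minimal, hypothetical sketch of forwarding it explicitly to the new client:

import os

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.ollama_client import OllamaClient

# Hypothetical wiring: forward OLLAMA_HOST from the environment as the client base_url.
host = os.environ.get('OLLAMA_HOST', 'http://localhost:11434')
llm_client = OllamaClient(config=LLMConfig(base_url=host))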
@@ -213,6 +213,9 @@ pip install graphiti-core[groq]
 # Install with Google Gemini support
 pip install graphiti-core[google-genai]
 
+# Install with Ollama support
+pip install graphiti-core[ollama]
+
 # Install with multiple providers
 pip install graphiti-core[anthropic,groq,google-genai]
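Beyond installing the extra, the new clients still have to be handed to Graphiti explicitly. A minimal sketch, assuming Graphiti's constructor accepts llm_client and embedder keyword arguments and that a local Ollama server is running with the referenced models pulled:

from graphiti_core import Graphiti
from graphiti_core.embedder.ollama import OllamaEmbedder, OllamaEmbedderConfig
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.ollama_client import OllamaClient

# Both clients default to the Ollama models named in this commit (qwen3:4b, bge-m3:567m).
llm_client = OllamaClient(config=LLMConfig(base_url='http://localhost:11434', model='qwen3:4b'))
embedder = OllamaEmbedder(config=OllamaEmbedderConfig(base_url='http://localhost:11434'))

graphiti = Graphiti(
    'bolt://localhost:7687',  # Neo4j connection details are placeholders
    'neo4j',
    'password',
    llm_client=llm_client,
    embedder=embedder,
)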
graphiti_core/embedder/ollama.py (new file, 176 lines)
@@ -0,0 +1,176 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from collections.abc import Iterable
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ollama import AsyncClient
else:
    try:
        from ollama import AsyncClient
    except ImportError:
        raise ImportError(
            'ollama is required for OllamaEmbedder. '
            'Install it with: pip install graphiti-core[ollama]'
        ) from None

from pydantic import Field

from .client import EmbedderClient, EmbedderConfig

logger = logging.getLogger(__name__)

DEFAULT_EMBEDDING_MODEL = 'bge-m3:567m'
DEFAULT_BATCH_SIZE = 100


class OllamaEmbedderConfig(EmbedderConfig):
    embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL)
    api_key: str | None = None
    base_url: str | None = None


class OllamaEmbedder(EmbedderClient):
    """Ollama Embedder Client."""

    def __init__(
        self,
        config: OllamaEmbedderConfig | None = None,
        client: AsyncClient | None = None,
        batch_size: int | None = None,
    ):
        if config is None:
            config = OllamaEmbedderConfig()

        self.config = config
        if client is None:
            # The ollama AsyncClient does not take an api_key; only the host is forwarded here.
            try:
                # Drop an OpenAI-style '/v1' suffix if one was supplied as the base URL.
                host = config.base_url.removesuffix('/v1') if config.base_url else None
                self.client = AsyncClient(host=host)
            except TypeError as e:
                logger.warning(f'Error creating AsyncClient: {e}')
                self.client = AsyncClient()
        else:
            self.client = client

        if batch_size is None:
            self.batch_size = DEFAULT_BATCH_SIZE
        else:
            self.batch_size = batch_size
    async def create(
        self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
    ) -> list[float]:
        """Create a single embedding for the input using Ollama.

        Ollama's embed endpoint accepts either a single string or a list of strings.
        We normalize to a single-item list and return the first embedding vector.
        """
        # Ollama's embed returns an object with 'embedding' or similar fields
        try:
            result = await self.client.embed(
                model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
                input=input_data,  # type: ignore[arg-type]
            )
        except Exception as e:
            logger.error(f'Ollama embed error: {e}')
            raise

        # Extract embedding and coerce to list[float]
        values: list[float] | None = None

        if hasattr(result, 'embeddings'):
            emb = result.embeddings
            if isinstance(emb, list) and len(emb) > 0:
                values = emb[0] if isinstance(emb[0], list | tuple) else emb  # type: ignore
        elif isinstance(result, dict):
            if 'embedding' in result and isinstance(result['embedding'], list | tuple):
                values = list(result['embedding'])  # type: ignore
            elif (
                'embeddings' in result
                and isinstance(result['embeddings'], list)
                and len(result['embeddings']) > 0
            ):
                first = result['embeddings'][0]
                if isinstance(first, dict) and 'embedding' in first and isinstance(first['embedding'], list | tuple):
                    values = list(first['embedding'])
                elif isinstance(first, list | tuple):
                    values = list(first)

        # If result itself is a list (some clients return a list for single input)
        if values is None and isinstance(result, list | tuple):
            # assume it's already the embedding vector
            values = list(result)  # type: ignore

        if not values:
            raise ValueError('No embeddings returned from Ollama API in create()')

        return values
    async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
        if not input_data_list:
            return []

        all_embeddings: list[list[float]] = []

        for i in range(0, len(input_data_list), self.batch_size):
            batch = input_data_list[i : i + self.batch_size]
            try:
                result = await self.client.embed(
                    model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, input=batch
                )

                # result may be a typed response object, a dict with 'embeddings', or a single 'embedding'
                if hasattr(result, 'embeddings') and isinstance(result.embeddings, list | tuple):
                    # Typed client response: one vector per input in the batch
                    for emb in result.embeddings:
                        all_embeddings.append(list(emb))
                elif isinstance(result, dict) and 'embeddings' in result:
                    for emb in result['embeddings']:
                        if isinstance(emb, dict) and 'embedding' in emb and isinstance(emb['embedding'], list | tuple):
                            all_embeddings.append(list(emb['embedding']))
                        elif isinstance(emb, list | tuple):
                            all_embeddings.append(list(emb))
                        else:
                            # unexpected shape
                            raise ValueError('Unexpected embedding shape in batch result')
                else:
                    # Fallback: maybe result itself is a list of vectors
                    if isinstance(result, list):
                        all_embeddings.extend(result)
                    else:
                        # Single embedding returned for the whole batch; duplicate it per item
                        embedding = None
                        if isinstance(result, dict) and 'embedding' in result:
                            embedding = result['embedding']
                        if embedding is None:
                            raise ValueError('No embeddings returned')
                        for _ in batch:
                            all_embeddings.append(embedding)

            except Exception as e:
                logger.warning(
                    f'Batch embedding failed for batch {i // self.batch_size + 1}, '
                    f'falling back to individual processing: {e}'
                )
                for item in batch:
                    try:
                        single = await self.client.embed(
                            model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, input=item
                        )
                        emb = None
                        # Inspect the per-item response, not the failed batch result
                        if hasattr(single, 'embeddings'):
                            _emb = single.embeddings
                            if isinstance(_emb, list) and len(_emb) > 0:
                                emb = _emb[0] if isinstance(_emb[0], list | tuple) else _emb  # type: ignore
                        elif isinstance(single, dict) and 'embedding' in single:
                            emb = single['embedding']
                        elif isinstance(single, dict) and 'embeddings' in single and single['embeddings']:
                            # Per-item call: take the first (only) vector from the embeddings list
                            emb = single['embeddings'][0]
                        elif isinstance(single, list | tuple):
                            emb = single[0] if single else None  # type: ignore
                        if not emb:
                            raise ValueError('No embeddings returned from Ollama API')
                        all_embeddings.append(emb)  # type: ignore
                    except Exception as individual_error:
                        logger.error(f'Failed to embed individual item: {individual_error}')
                        raise individual_error

        return all_embeddings
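For reference, a short usage sketch of the embedder above, assuming a local Ollama server on the default port and that the bge-m3:567m model has already been pulled (ollama pull bge-m3:567m):

import asyncio

from graphiti_core.embedder.ollama import OllamaEmbedder, OllamaEmbedderConfig


async def main() -> None:
    embedder = OllamaEmbedder(config=OllamaEmbedderConfig(base_url='http://localhost:11434'))

    # A single input returns one vector; create_batch chunks inputs by batch_size (default 100).
    vector = await embedder.create('knowledge graphs with Ollama')
    vectors = await embedder.create_batch(['first document', 'second document'])
    print(len(vector), len(vectors))


asyncio.run(main())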
@@ -275,6 +275,8 @@ class Graphiti:
             return 'gemini'
         elif 'groq' in class_name:
             return 'groq'
+        elif 'ollama' in class_name:
+            return 'ollama'
         # Database providers
         elif 'neo4j' in class_name:
             return 'neo4j'
graphiti_core/llm_client/ollama_client.py (new file, 148 lines)
@@ -0,0 +1,148 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import typing
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ollama import AsyncClient
else:
    try:
        from ollama import AsyncClient
    except ImportError:
        raise ImportError(
            'ollama is required for OllamaClient. Install it with: pip install graphiti-core[ollama]'
        ) from None

from pydantic import BaseModel

from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig, ModelSize
from .errors import RateLimitError

logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'qwen3:4b'
DEFAULT_MAX_TOKENS = 8192


class OllamaClient(LLMClient):
    """Ollama async client wrapper for Graphiti.

    This client expects the `ollama` python package to be installed. It uses the
    AsyncClient.chat(...) API to generate chat responses. The response content
    is expected to be JSON which will be parsed and returned as a dict.
    """
    def __init__(self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any | None = None):
        if config is None:
            config = LLMConfig(max_tokens=DEFAULT_MAX_TOKENS)
        elif config.max_tokens is None:
            config.max_tokens = DEFAULT_MAX_TOKENS
        super().__init__(config, cache)

        # Allow injecting a preconfigured AsyncClient for testing
        if client is None:
            # The ollama AsyncClient takes a host (plus httpx kwargs) rather than api_key/base_url
            try:
                # Drop an OpenAI-style '/v1' suffix if one was supplied as the base URL.
                host = config.base_url.removesuffix('/v1') if config.base_url else None
                self.client = AsyncClient(host=host)
            except TypeError as e:
                logger.warning(f'Error creating AsyncClient: {e}')
                self.client = AsyncClient()
        else:
            self.client = client
    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        msgs: list[dict[str, str]] = []
        for m in messages:
            if m.role == 'user':
                msgs.append({'role': 'user', 'content': m.content})
            elif m.role == 'system':
                msgs.append({'role': 'system', 'content': m.content})

        try:
            # Prepare options
            options: dict[str, typing.Any] = {}
            if max_tokens is not None:
                # Ollama caps generated tokens via the 'num_predict' option
                options['num_predict'] = max_tokens
            if self.temperature is not None:
                options['temperature'] = self.temperature

            # If a response_model is provided, try to get its JSON schema for format
            schema = None
            if response_model is not None:
                try:
                    schema = response_model.model_json_schema()
                except Exception:
                    schema = None
            response = await self.client.chat(
                model=self.model or DEFAULT_MODEL,
                messages=msgs,
                stream=False,
                format=schema,
                options=options,
            )

            # Extract content
            content: str | None = None
            if isinstance(response, dict) and 'message' in response and isinstance(response['message'], dict):
                content = response['message'].get('content')
            else:
                # Some clients return objects with a .message attribute instead of dicts
                msg = getattr(response, 'message', None)

                if isinstance(msg, dict):
                    content = msg.get('content')
                elif msg is not None:
                    content = getattr(msg, 'content', None)

            if content is None:
                # fallback to string
                content = str(response)

            # If a structured response was requested, validate it with the pydantic model
            if response_model is not None:
                # Use the pydantic v2 model_validate_json method
                try:
                    validated = response_model.model_validate_json(content)
                    # return the model as a dict
                    return validated.model_dump()  # type: ignore[attr-defined]
                except Exception as e:
                    logger.error(f'Failed to validate response with response_model: {e}')
                    # fall through and try json.loads below

            # Otherwise try to parse the content as JSON
            try:
                return json.loads(content)
            except Exception:
                return {'text': content}
        except Exception as e:
            # Map obvious ollama rate limit / response errors to RateLimitError when possible
            err_name = e.__class__.__name__
            status_code = getattr(e, 'status_code', None) or getattr(e, 'status', None)
            if err_name in ('RequestError', 'ResponseError') and status_code == 429:
                raise RateLimitError from e
            logger.error(f'Error in generating LLM response (ollama): {e}')
            raise
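A short usage sketch of the LLM client above, again assuming a local Ollama server with qwen3:4b pulled; generate_response is the public entry point inherited from LLMClient, as exercised in the tests below:

import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.ollama_client import OllamaClient
from graphiti_core.prompts.models import Message


class CityInfo(BaseModel):
    city: str
    country: str


async def main() -> None:
    client = OllamaClient(config=LLMConfig(base_url='http://localhost:11434', model='qwen3:4b'))
    messages = [Message(role='user', content='Return the city and country of the Eiffel Tower as JSON.')]

    # The JSON schema of CityInfo is passed as Ollama's `format`, so the reply is validated into a dict.
    result = await client.generate_response(messages, response_model=CityInfo)
    print(result)  # e.g. {'city': 'Paris', 'country': 'France'}


asyncio.run(main())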
@@ -35,6 +35,7 @@ voyageai = ["voyageai>=0.2.3"]
 neo4j-opensearch = ["boto3>=1.39.16", "opensearch-py>=3.0.0"]
 sentence-transformers = ["sentence-transformers>=3.2.1"]
 neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"]
+ollama = ["ollama>=0.5.3"]
 tracing = ["opentelemetry-api>=1.20.0", "opentelemetry-sdk>=1.20.0"]
 dev = [
     "pyright>=1.1.404",
@@ -56,6 +57,7 @@ dev = [
     "sentence-transformers>=3.2.1",
     "transformers>=4.45.2",
     "voyageai>=0.2.3",
+    "ollama>=0.5.3",
     "pytest>=8.3.3",
     "pytest-asyncio>=0.24.0",
     "pytest-xdist>=3.6.1",
tests/llm_client/test_ollama_client.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Running tests: pytest -xvs tests/llm_client/test_ollama_client.py


import pytest
from pydantic import BaseModel, Field

from graphiti_core.llm_client.ollama_client import OllamaClient
from graphiti_core.prompts.models import Message

# Skip tests if no Ollama host or local server is available
# pytestmark = pytest.mark.skipif(
#     'OLLAMA_HOST' not in os.environ,
#     reason='Ollama API/host not available',
# )


# Named without a Test prefix so pytest does not collect it as a test class
class SimpleResponseModel(BaseModel):
    message: str = Field(..., description='A message from the model')


class Pet(BaseModel):
    name: str
    animal: str
    age: int
    color: str | None
    favorite_toy: str | None


class PetList(BaseModel):
    pets: list[Pet]


@pytest.mark.asyncio
@pytest.mark.integration
async def test_generate_simple_response():
    client = OllamaClient()

    messages = [
        Message(
            role='user',
            content="Respond with a JSON object containing a 'message' field with value 'Hello, world!'",
        )
    ]

    try:
        response = await client.generate_response(messages, response_model=SimpleResponseModel)
        assert isinstance(response, dict)
        assert 'message' in response
        assert response['message'] == 'Hello, world!'
    except Exception as e:
        pytest.skip(f'Test skipped due to Ollama API error: {str(e)}')


@pytest.mark.asyncio
@pytest.mark.integration
async def test_structured_output_with_pydantic():
    client = OllamaClient()

    messages = [
        Message(
            role='user',
            content='''
            I have two pets.
            A cat named Luna who is 5 years old and loves playing with yarn. She has grey fur.
            I also have a 2 year old black cat named Loki who loves tennis balls.
            ''',
        )
    ]

    try:
        response = await client.generate_response(messages, response_model=PetList)
        assert isinstance(response, dict)
        assert 'pets' in response
        assert isinstance(response['pets'], list)
        assert len(response['pets']) >= 1
    except Exception as e:
        pytest.skip(f'Test skipped due to Ollama API error: {str(e)}')