diff --git a/README.md b/README.md
index 4deeaccf..16f36d7f 100644
--- a/README.md
+++ b/README.md
@@ -523,6 +523,8 @@ reranker, leveraging Gemini's log probabilities feature to rank passage relevanc
 Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible
 API. This is ideal for privacy-focused applications or when you want to avoid API costs.
 
+**Note:** Use `OpenAIGenericClient` (not `OpenAIClient`) for Ollama and other OpenAI-compatible providers like LM Studio. The `OpenAIGenericClient` is optimized for local models with a higher default max token limit (16K vs 8K) and full support for structured outputs.
+
 Install the models:
 
 ```bash
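For readers wiring this up, here is a minimal sketch of the configuration the new README note describes: pointing `OpenAIGenericClient` at Ollama's OpenAI-compatible endpoint. The import paths follow the file touched in the next diff, but the `LLMConfig` keyword names and the model tag are assumptions drawn from its docstring rather than verified against the repo.

```python
# Hedged sketch: using OpenAIGenericClient against Ollama's OpenAI-compatible API.
# Assumptions: LLMConfig accepts api_key/model/base_url kwargs (per the docstring
# in the diff below) and 'llama3.1:8b' is a chat model you have pulled locally.
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient

llm_client = OpenAIGenericClient(
    config=LLMConfig(
        api_key='ollama',                      # Ollama ignores the key, but one must be set
        model='llama3.1:8b',                   # any locally pulled chat model
        base_url='http://localhost:11434/v1',  # Ollama's OpenAI-compatible endpoint
    ),
    max_tokens=16384,  # new parameter in this PR; 16384 is already the default
)
```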
""" # removed caching to simplify the `generate_response` override @@ -79,6 +84,9 @@ class OpenAIGenericClient(LLMClient): super().__init__(config, cache) + # Override max_tokens to support higher limits for local models + self.max_tokens = max_tokens + if client is None: self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) else: @@ -99,12 +107,25 @@ class OpenAIGenericClient(LLMClient): elif m.role == 'system': openai_messages.append({'role': 'system', 'content': m.content}) try: + # Prepare response format + response_format: dict[str, Any] = {'type': 'json_object'} + if response_model is not None: + schema_name = getattr(response_model, '__name__', 'structured_response') + json_schema = response_model.model_json_schema() + response_format = { + 'type': 'json_schema', + 'json_schema': { + 'name': schema_name, + 'schema': json_schema, + }, + } + response = await self.client.chat.completions.create( model=self.model or DEFAULT_MODEL, messages=openai_messages, temperature=self.temperature, max_tokens=self.max_tokens, - response_format={'type': 'json_object'}, + response_format=response_format, # type: ignore[arg-type] ) result = response.choices[0].message.content or '' return json.loads(result) @@ -126,14 +147,6 @@ class OpenAIGenericClient(LLMClient): if max_tokens is None: max_tokens = self.max_tokens - if response_model is not None: - serialized_model = json.dumps(response_model.model_json_schema()) - messages[ - -1 - ].content += ( - f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}' - ) - # Add multilingual extraction instructions messages[0].content += get_extraction_language_instruction(group_id)