Use OpenAI structured output API for response validation (#1061)
* Use OpenAI structured output API for response validation

  Replace prompt-based schema injection with native json_schema response format. This improves token efficiency and reliability by having OpenAI enforce the schema directly instead of embedding it in the prompt message.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

* Add type ignore for response_format to fix pyright error

* Increase OpenAIGenericClient max_tokens to 16K and update docs

  - Set default max_tokens to 16384 (16K) for OpenAIGenericClient to better support local models
  - Add documentation note clarifying OpenAIGenericClient should be used for Ollama and LM Studio
  - Previous default was 8192 (8K)

* Refactor max_tokens override to use constructor parameter pattern

  - Add max_tokens parameter to __init__ with 16K default
  - Override self.max_tokens after super().__init__() instead of mutating config
  - Consistent with OpenAIBaseClient and AnthropicClient patterns
  - Avoids unintended config mutation side effects

---------

Co-authored-by: Claude <noreply@anthropic.com>
parent d4a92772ec
commit 90d7757c17

2 changed files with 27 additions and 12 deletions
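The "native json_schema response format" mentioned above is OpenAI's structured-outputs option on the Chat Completions API. As a minimal standalone sketch of the before/after (the `Summary` model, model name, and prompt are illustrative, not taken from this repository):

```python
import asyncio

from openai import AsyncOpenAI
from pydantic import BaseModel


class Summary(BaseModel):  # hypothetical response model, for illustration only
    title: str
    key_points: list[str]


async def main() -> None:
    client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set in the environment

    # Old approach (what this commit removes): append the schema text to the prompt
    # and request generic JSON mode:
    #   messages[-1]['content'] += f'\n\nRespond with a JSON object ...\n\n{Summary.model_json_schema()}'
    #   response_format={'type': 'json_object'}

    # New approach: send the Pydantic-generated schema in response_format so the API
    # applies it directly instead of relying on prompt text.
    response = await client.chat.completions.create(
        model='gpt-4o-mini',  # illustrative model name
        messages=[{'role': 'user', 'content': 'Summarize Graphiti in one title and three key points.'}],
        response_format={
            'type': 'json_schema',
            'json_schema': {'name': 'Summary', 'schema': Summary.model_json_schema()},
        },
    )
    print(response.choices[0].message.content)  # JSON shaped like Summary


asyncio.run(main())
```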
@@ -523,6 +523,8 @@ reranker, leveraging Gemini's log probabilities feature to rank passage relevance
 Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible API. This is ideal
 for privacy-focused applications or when you want to avoid API costs.
 
+**Note:** Use `OpenAIGenericClient` (not `OpenAIClient`) for Ollama and other OpenAI-compatible providers like LM Studio. The `OpenAIGenericClient` is optimized for local models with a higher default max token limit (16K vs 8K) and full support for structured outputs.
+
 Install the models:
 
 ```bash
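To make the added note concrete, here is a minimal configuration sketch for a local Ollama server. The import paths, the `LLMConfig` fields, and the model name are assumptions based on the surrounding codebase rather than part of this diff; Ollama's OpenAI-compatible endpoint conventionally runs at `http://localhost:11434/v1`.

```python
from graphiti_core.llm_client.config import LLMConfig  # import path assumed
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient  # import path assumed

config = LLMConfig(
    api_key='ollama',                      # Ollama ignores the key, but the client expects one
    model='llama3.1:8b',                   # any locally pulled model, e.g. `ollama pull llama3.1:8b`
    base_url='http://localhost:11434/v1',  # Ollama's OpenAI-compatible endpoint
)

# After this commit the client defaults to max_tokens=16384 (16K);
# OpenAIClient keeps the lower 8K default aimed at hosted models.
llm_client = OpenAIGenericClient(config=config)
```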
@@ -17,7 +17,7 @@ limitations under the License.
 import json
 import logging
 import typing
-from typing import ClassVar
+from typing import Any, ClassVar
 
 import openai
 from openai import AsyncOpenAI
@@ -59,15 +59,20 @@ class OpenAIGenericClient(LLMClient):
     MAX_RETRIES: ClassVar[int] = 2
 
     def __init__(
-        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+        self,
+        config: LLMConfig | None = None,
+        cache: bool = False,
+        client: typing.Any = None,
+        max_tokens: int = 16384,
     ):
         """
-        Initialize the OpenAIClient with the provided configuration, cache setting, and client.
+        Initialize the OpenAIGenericClient with the provided configuration, cache setting, and client.
 
         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
             cache (bool): Whether to use caching for responses. Defaults to False.
             client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+            max_tokens (int): The maximum number of tokens to generate. Defaults to 16384 (16K) for better compatibility with local models.
 
         """
         # removed caching to simplify the `generate_response` override
@@ -79,6 +84,9 @@ class OpenAIGenericClient(LLMClient):
 
         super().__init__(config, cache)
 
+        # Override max_tokens to support higher limits for local models
+        self.max_tokens = max_tokens
+
         if client is None:
             self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
         else:
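Because the override lands on `self.max_tokens` after `super().__init__()`, the shared `LLMConfig` object is not mutated; callers that need a different ceiling pass the new parameter directly. A small sketch under the assumption that `LLMConfig` exposes a `max_tokens` field and the import paths shown (the LM Studio URL and values are illustrative):

```python
from graphiti_core.llm_client.config import LLMConfig  # import path assumed
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient  # import path assumed

config = LLMConfig(api_key='not-needed-locally', base_url='http://localhost:1234/v1')  # e.g. LM Studio

before = config.max_tokens
client = OpenAIGenericClient(config=config, max_tokens=32768)

assert client.max_tokens == 32768    # instance-level override from the new constructor parameter
assert config.max_tokens == before   # the config object itself is left untouched
```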
@@ -99,12 +107,25 @@ class OpenAIGenericClient(LLMClient):
             elif m.role == 'system':
                 openai_messages.append({'role': 'system', 'content': m.content})
         try:
+            # Prepare response format
+            response_format: dict[str, Any] = {'type': 'json_object'}
+            if response_model is not None:
+                schema_name = getattr(response_model, '__name__', 'structured_response')
+                json_schema = response_model.model_json_schema()
+                response_format = {
+                    'type': 'json_schema',
+                    'json_schema': {
+                        'name': schema_name,
+                        'schema': json_schema,
+                    },
+                }
+
             response = await self.client.chat.completions.create(
                 model=self.model or DEFAULT_MODEL,
                 messages=openai_messages,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
-                response_format={'type': 'json_object'},
+                response_format=response_format,  # type: ignore[arg-type]
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
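The hunk above still returns the raw `json.loads` output; a caller that wants full Pydantic validation on top of the server-enforced schema can run the parsed dict back through the same model. A hedged sketch, where the `Entity` model and payload are illustrative stand-ins rather than repo code:

```python
import json

from pydantic import BaseModel


class Entity(BaseModel):  # hypothetical stand-in for a response_model
    name: str
    summary: str


# Stand-in for response.choices[0].message.content from the call above.
result = '{"name": "Graphiti", "summary": "Framework for temporal knowledge graphs."}'

# Server-side json_schema handling should already produce this shape, but local
# OpenAI-compatible servers vary, so an explicit client-side check is a cheap backstop.
entity = Entity.model_validate(json.loads(result))
print(entity.name)
```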
@@ -126,14 +147,6 @@ class OpenAIGenericClient(LLMClient):
         if max_tokens is None:
             max_tokens = self.max_tokens
 
-        if response_model is not None:
-            serialized_model = json.dumps(response_model.model_json_schema())
-            messages[
-                -1
-            ].content += (
-                f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
-            )
-
         # Add multilingual extraction instructions
         messages[0].content += get_extraction_language_instruction(group_id)
 