Use OpenAI structured output API for response validation (#1061)

* Use OpenAI structured output API for response validation

Replace prompt-based schema injection with native json_schema response format. This improves token efficiency and reliability by having OpenAI enforce the schema directly instead of embedding it in the prompt message.
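For reference, a minimal sketch of the request shape this change introduces. `ExtractedEntities` is a hypothetical Pydantic model used only for illustration; the dict mirrors what the diff below builds from `response_model`:

```python
from pydantic import BaseModel


class ExtractedEntities(BaseModel):
    """Hypothetical response model; stands in for Graphiti's Pydantic models."""

    names: list[str]


# Build the native structured-output response_format from the model's JSON schema,
# instead of appending the schema text to the prompt message.
response_format = {
    'type': 'json_schema',
    'json_schema': {
        'name': ExtractedEntities.__name__,
        'schema': ExtractedEntities.model_json_schema(),
    },
}

# This dict is passed as response_format to client.chat.completions.create(...),
# so the server enforces the schema rather than the prompt merely requesting it.
```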

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Add type ignore for response_format to fix pyright error

* Increase OpenAIGenericClient max_tokens to 16K and update docs

- Set default max_tokens to 16384 (16K) for OpenAIGenericClient to better support local models
- Add documentation note clarifying that OpenAIGenericClient should be used for Ollama and LM Studio (see the configuration sketch after this list)
- Previous default was 8192 (8K)
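
A rough configuration sketch of how the higher default comes into play with a local Ollama server. The import paths, `LLMConfig` field names, and model name are assumptions drawn from the docstring in the diff below, not verified API:

```python
# Hedged sketch only: import paths, field names, and the model name are assumptions.
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient

llm_client = OpenAIGenericClient(
    config=LLMConfig(
        api_key='ollama',                      # Ollama ignores the key, but one must be set
        model='llama3.1',                      # any chat model served locally by Ollama
        base_url='http://localhost:11434/v1',  # Ollama's OpenAI-compatible endpoint
    ),
)
# With no explicit max_tokens argument, the client now defaults to 16384.
```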

* Refactor max_tokens override to use constructor parameter pattern

- Add max_tokens parameter to __init__ with 16K default
- Override self.max_tokens after super().__init__() instead of mutating config
- Consistent with OpenAIBaseClient and AnthropicClient patterns
- Avoids unintended config mutation side effects (sketched below)
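
A minimal, self-contained sketch of this pattern using simplified stand-in classes, not Graphiti's actual code:

```python
class BaseClient:
    """Simplified stand-in for the base LLM client."""

    def __init__(self, config: dict | None = None):
        config = config or {}
        # The base class derives its limit from config (8K default here).
        self.max_tokens = config.get('max_tokens', 8192)


class GenericClient(BaseClient):
    """Simplified stand-in for OpenAIGenericClient."""

    def __init__(self, config: dict | None = None, max_tokens: int = 16384):
        super().__init__(config)
        # Override the instance attribute after super().__init__() so the
        # caller's config object is never mutated as a side effect.
        self.max_tokens = max_tokens


config = {'max_tokens': 8192}
client = GenericClient(config)
assert client.max_tokens == 16384    # instance override wins
assert config['max_tokens'] == 8192  # original config untouched
```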

---------

Co-authored-by: Claude <noreply@anthropic.com>
Daniel Chalef 2025-11-11 06:53:37 -08:00 committed by GitHub
parent d4a92772ec
commit 90d7757c17
2 changed files with 27 additions and 12 deletions


@@ -523,6 +523,8 @@ reranker, leveraging Gemini's log probabilities feature to rank passage relevance
 Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible API. This is ideal
 for privacy-focused applications or when you want to avoid API costs.
 
+**Note:** Use `OpenAIGenericClient` (not `OpenAIClient`) for Ollama and other OpenAI-compatible providers like LM Studio. The `OpenAIGenericClient` is optimized for local models with a higher default max token limit (16K vs 8K) and full support for structured outputs.
+
 Install the models:
 
 ```bash


@@ -17,7 +17,7 @@ limitations under the License.
 import json
 import logging
 import typing
-from typing import ClassVar
+from typing import Any, ClassVar
 
 import openai
 from openai import AsyncOpenAI
@@ -59,15 +59,20 @@ class OpenAIGenericClient(LLMClient):
     MAX_RETRIES: ClassVar[int] = 2
 
     def __init__(
-        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+        self,
+        config: LLMConfig | None = None,
+        cache: bool = False,
+        client: typing.Any = None,
+        max_tokens: int = 16384,
     ):
         """
-        Initialize the OpenAIClient with the provided configuration, cache setting, and client.
+        Initialize the OpenAIGenericClient with the provided configuration, cache setting, and client.
 
         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
             cache (bool): Whether to use caching for responses. Defaults to False.
             client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+            max_tokens (int): The maximum number of tokens to generate. Defaults to 16384 (16K) for better compatibility with local models.
         """
 
         # removed caching to simplify the `generate_response` override
@@ -79,6 +84,9 @@ class OpenAIGenericClient(LLMClient):
         super().__init__(config, cache)
 
+        # Override max_tokens to support higher limits for local models
+        self.max_tokens = max_tokens
+
         if client is None:
             self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
         else:
@@ -99,12 +107,25 @@ class OpenAIGenericClient(LLMClient):
             elif m.role == 'system':
                 openai_messages.append({'role': 'system', 'content': m.content})
         try:
+            # Prepare response format
+            response_format: dict[str, Any] = {'type': 'json_object'}
+            if response_model is not None:
+                schema_name = getattr(response_model, '__name__', 'structured_response')
+                json_schema = response_model.model_json_schema()
+                response_format = {
+                    'type': 'json_schema',
+                    'json_schema': {
+                        'name': schema_name,
+                        'schema': json_schema,
+                    },
+                }
+
             response = await self.client.chat.completions.create(
                 model=self.model or DEFAULT_MODEL,
                 messages=openai_messages,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
-                response_format={'type': 'json_object'},
+                response_format=response_format,  # type: ignore[arg-type]
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
@@ -126,14 +147,6 @@ class OpenAIGenericClient(LLMClient):
         if max_tokens is None:
             max_tokens = self.max_tokens
 
-        if response_model is not None:
-            serialized_model = json.dumps(response_model.model_json_schema())
-            messages[
-                -1
-            ].content += (
-                f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
-            )
-
         # Add multilingual extraction instructions
         messages[0].content += get_extraction_language_instruction(group_id)