Use OpenAI structured output API for response validation (#1061)
* Use OpenAI structured output API for response validation

  Replace prompt-based schema injection with native json_schema response format. This improves token efficiency and reliability by having OpenAI enforce the schema directly instead of embedding it in the prompt message.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

* Add type ignore for response_format to fix pyright error

* Increase OpenAIGenericClient max_tokens to 16K and update docs

  - Set default max_tokens to 16384 (16K) for OpenAIGenericClient to better support local models
  - Add documentation note clarifying OpenAIGenericClient should be used for Ollama and LM Studio
  - Previous default was 8192 (8K)

* Refactor max_tokens override to use constructor parameter pattern

  - Add max_tokens parameter to __init__ with 16K default
  - Override self.max_tokens after super().__init__() instead of mutating config
  - Consistent with OpenAIBaseClient and AnthropicClient patterns
  - Avoids unintended config mutation side effects

---------

Co-authored-by: Claude <noreply@anthropic.com>
Parent: d4a92772ec
Commit: 90d7757c17

2 changed files with 27 additions and 12 deletions
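The gist of the change, as a standalone sketch: the response model's JSON schema is passed through OpenAI's `response_format` parameter so the server enforces it, rather than being serialized into the prompt text. The `ExtractedEntities` model, the prompt, and the model name below are illustrative assumptions, not taken from the diff itself.

```python
import asyncio
import json

from openai import AsyncOpenAI
from pydantic import BaseModel


class ExtractedEntities(BaseModel):
    """Hypothetical response model, used only for illustration."""

    names: list[str]


async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment

    # Native structured output: the schema is enforced server-side, so the prompt
    # no longer needs the serialized schema embedded in it.
    response = await client.chat.completions.create(
        model='gpt-4o-mini',  # assumed model name for illustration
        messages=[{'role': 'user', 'content': 'List the people mentioned: Alice met Bob.'}],
        response_format={
            'type': 'json_schema',
            'json_schema': {
                'name': ExtractedEntities.__name__,
                'schema': ExtractedEntities.model_json_schema(),
            },
        },
    )
    print(json.loads(response.choices[0].message.content or '{}'))


asyncio.run(main())
```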
````diff
@@ -523,6 +523,8 @@ reranker, leveraging Gemini's log probabilities feature to rank passage relevance
 Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible API. This is ideal
 for privacy-focused applications or when you want to avoid API costs.
 
+**Note:** Use `OpenAIGenericClient` (not `OpenAIClient`) for Ollama and other OpenAI-compatible providers like LM Studio. The `OpenAIGenericClient` is optimized for local models with a higher default max token limit (16K vs 8K) and full support for structured outputs.
+
 Install the models:
 
 ```bash
````
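The added note above is the user-facing half of the change. As a rough usage sketch: the import paths, the `LLMConfig` field names, and the `llama3.1` model tag are assumptions based on this diff rather than verbatim from it, and Ollama is assumed to be serving its OpenAI-compatible API at the default `http://localhost:11434/v1`.

```python
from graphiti_core.llm_client.config import LLMConfig  # assumed import path
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient  # assumed import path

# Ollama ignores the API key, but the OpenAI SDK requires one, so any placeholder works.
config = LLMConfig(
    api_key='ollama',
    model='llama3.1',  # whichever model you have pulled locally
    base_url='http://localhost:11434/v1',
)

# Picks up the new 16K default max_tokens; structured output is requested via json_schema.
llm_client = OpenAIGenericClient(config=config)
```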
```diff
@@ -17,7 +17,7 @@ limitations under the License.
 import json
 import logging
 import typing
-from typing import ClassVar
+from typing import Any, ClassVar
 
 import openai
 from openai import AsyncOpenAI
```
```diff
@@ -59,15 +59,20 @@ class OpenAIGenericClient(LLMClient):
     MAX_RETRIES: ClassVar[int] = 2
 
     def __init__(
-        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+        self,
+        config: LLMConfig | None = None,
+        cache: bool = False,
+        client: typing.Any = None,
+        max_tokens: int = 16384,
     ):
         """
-        Initialize the OpenAIClient with the provided configuration, cache setting, and client.
+        Initialize the OpenAIGenericClient with the provided configuration, cache setting, and client.
 
         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
             cache (bool): Whether to use caching for responses. Defaults to False.
             client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+            max_tokens (int): The maximum number of tokens to generate. Defaults to 16384 (16K) for better compatibility with local models.
 
         """
         # removed caching to simplify the `generate_response` override
```
```diff
@@ -79,6 +84,9 @@ class OpenAIGenericClient(LLMClient):
 
         super().__init__(config, cache)
 
+        # Override max_tokens to support higher limits for local models
+        self.max_tokens = max_tokens
+
         if client is None:
             self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
         else:
```
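The two hunks above implement the constructor-parameter pattern described in the commit message: the 16K default is set on the instance after `super().__init__()` rather than by mutating the shared config. Continuing the Ollama sketch earlier, a caller who wants a different cap can pass it explicitly; the 32K figure is only an example.

```python
# `config` is the LLMConfig from the Ollama sketch above; config.max_tokens itself is left untouched.
llm_client = OpenAIGenericClient(config=config, max_tokens=32768)
```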
```diff
@@ -99,12 +107,25 @@
             elif m.role == 'system':
                 openai_messages.append({'role': 'system', 'content': m.content})
         try:
+            # Prepare response format
+            response_format: dict[str, Any] = {'type': 'json_object'}
+            if response_model is not None:
+                schema_name = getattr(response_model, '__name__', 'structured_response')
+                json_schema = response_model.model_json_schema()
+                response_format = {
+                    'type': 'json_schema',
+                    'json_schema': {
+                        'name': schema_name,
+                        'schema': json_schema,
+                    },
+                }
+
             response = await self.client.chat.completions.create(
                 model=self.model or DEFAULT_MODEL,
                 messages=openai_messages,
                 temperature=self.temperature,
                 max_tokens=self.max_tokens,
-                response_format={'type': 'json_object'},
+                response_format=response_format,  # type: ignore[arg-type]
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
```
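Because the provider now enforces the schema, the dict returned from `json.loads` should already match the response model. The diff itself stops there; as a hedged follow-on, a caller who wants typed access and an explicit validation step could round-trip the result through Pydantic. The `EntitySummary` model and the sample dict below are illustrative only.

```python
from pydantic import BaseModel


class EntitySummary(BaseModel):
    """Hypothetical response model, used only for illustration."""

    name: str
    summary: str


# Stand-in for the dict the client returns after json.loads().
result = {'name': 'Graphiti', 'summary': 'Temporal knowledge graph framework.'}

# Raises pydantic.ValidationError if a provider ignored the json_schema request,
# which some OpenAI-compatible local servers may still do.
validated = EntitySummary.model_validate(result)
print(validated.name)
```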
```diff
@@ -126,14 +147,6 @@
         if max_tokens is None:
             max_tokens = self.max_tokens
 
-        if response_model is not None:
-            serialized_model = json.dumps(response_model.model_json_schema())
-            messages[
-                -1
-            ].content += (
-                f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
-            )
-
         # Add multilingual extraction instructions
         messages[0].content += get_extraction_language_instruction(group_id)
 
```