Fix Azure structured completions (#1039)
Update Azure structured-completion parsing
This commit is contained in:
parent
9cc04e61c9
commit
8d99984204
2 changed files with 160 additions and 16 deletions
|
|
@ -38,8 +38,16 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
|
||||||
azure_client: AsyncAzureOpenAI,
|
azure_client: AsyncAzureOpenAI,
|
||||||
config: LLMConfig | None = None,
|
config: LLMConfig | None = None,
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
|
reasoning: str | None = None,
|
||||||
|
verbosity: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(config, cache=False, max_tokens=max_tokens)
|
super().__init__(
|
||||||
|
config,
|
||||||
|
cache=False,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
reasoning=reasoning,
|
||||||
|
verbosity=verbosity,
|
||||||
|
)
|
||||||
self.client = azure_client
|
self.client = azure_client
|
||||||
|
|
||||||
async def _create_structured_completion(
|
async def _create_structured_completion(
|
||||||
|
|
@ -49,15 +57,29 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
|
||||||
temperature: float | None,
|
temperature: float | None,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
response_model: type[BaseModel],
|
response_model: type[BaseModel],
|
||||||
|
reasoning: str | None,
|
||||||
|
verbosity: str | None,
|
||||||
):
|
):
|
||||||
"""Create a structured completion using Azure OpenAI's beta parse API."""
|
"""Create a structured completion using Azure OpenAI's responses.parse API."""
|
||||||
return await self.client.beta.chat.completions.parse(
|
supports_reasoning = self._supports_reasoning_features(model)
|
||||||
model=model,
|
request_kwargs = {
|
||||||
messages=messages,
|
'model': model,
|
||||||
temperature=temperature,
|
'input': messages,
|
||||||
max_tokens=max_tokens,
|
'max_output_tokens': max_tokens,
|
||||||
response_format=response_model, # type: ignore
|
'text_format': response_model, # type: ignore
|
||||||
)
|
}
|
||||||
|
|
||||||
|
temperature_value = temperature if not supports_reasoning else None
|
||||||
|
if temperature_value is not None:
|
||||||
|
request_kwargs['temperature'] = temperature_value
|
||||||
|
|
||||||
|
if supports_reasoning and reasoning:
|
||||||
|
request_kwargs['reasoning'] = {'effort': reasoning} # type: ignore
|
||||||
|
|
||||||
|
if supports_reasoning and verbosity:
|
||||||
|
request_kwargs['text'] = {'verbosity': verbosity} # type: ignore
|
||||||
|
|
||||||
|
return await self.client.responses.parse(**request_kwargs)
|
||||||
|
|
||||||
async def _create_completion(
|
async def _create_completion(
|
||||||
self,
|
self,
|
||||||
|
|
@ -68,10 +90,23 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
):
|
):
|
||||||
"""Create a regular completion with JSON format using Azure OpenAI."""
|
"""Create a regular completion with JSON format using Azure OpenAI."""
|
||||||
return await self.client.chat.completions.create(
|
supports_reasoning = self._supports_reasoning_features(model)
|
||||||
model=model,
|
|
||||||
messages=messages,
|
request_kwargs = {
|
||||||
temperature=temperature,
|
'model': model,
|
||||||
max_tokens=max_tokens,
|
'messages': messages,
|
||||||
response_format={'type': 'json_object'},
|
'max_tokens': max_tokens,
|
||||||
)
|
'response_format': {'type': 'json_object'},
|
||||||
|
}
|
||||||
|
|
||||||
|
temperature_value = temperature if not supports_reasoning else None
|
||||||
|
if temperature_value is not None:
|
||||||
|
request_kwargs['temperature'] = temperature_value
|
||||||
|
|
||||||
|
return await self.client.chat.completions.create(**request_kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _supports_reasoning_features(model: str) -> bool:
|
||||||
|
"""Return True when the Azure model supports reasoning/verbosity options."""
|
||||||
|
reasoning_prefixes = ('o1', 'o3', 'gpt-5')
|
||||||
|
return model.startswith(reasoning_prefixes)
|
||||||
|
|
|
||||||
109
tests/llm_client/test_azure_openai_client.py
Normal file
109
tests/llm_client/test_azure_openai_client.py
Normal file
|
|
@ -0,0 +1,109 @@
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
|
||||||
|
from graphiti_core.llm_client.config import LLMConfig
|
||||||
|
|
||||||
|
|
||||||
|
class DummyResponses:
    """Test double for the Azure client's ``responses`` API surface."""

    def __init__(self):
        # Keyword-argument dict of every parse() invocation, in call order.
        self.parse_calls: list[dict] = []

    async def parse(self, **kwargs):
        """Record the call and return a minimal parsed-response stand-in."""
        self.parse_calls.append(kwargs)
        return SimpleNamespace(output_text='{}')
|
||||||
|
|
||||||
|
|
||||||
|
class DummyChatCompletions:
    """Test double for the Azure client's ``chat.completions`` surface."""

    def __init__(self):
        # Keyword-argument dict of every create() invocation, in call order.
        self.create_calls: list[dict] = []

    async def create(self, **kwargs):
        """Record the call and return a one-choice completion stub."""
        self.create_calls.append(kwargs)
        stub_choice = SimpleNamespace(message=SimpleNamespace(content='{}'))
        return SimpleNamespace(choices=[stub_choice])
|
||||||
|
|
||||||
|
|
||||||
|
class DummyChat:
    """Test double mirroring the Azure client's ``chat`` attribute."""

    def __init__(self):
        # Exposes chat.completions backed by the call-recording stub.
        self.completions = DummyChatCompletions()
|
||||||
|
|
||||||
|
|
||||||
|
class DummyAzureClient:
    """Minimal stand-in for ``AsyncAzureOpenAI`` used by these tests."""

    def __init__(self):
        # Recording stubs for the two API surfaces the client exercises.
        self.responses = DummyResponses()
        self.chat = DummyChat()
|
||||||
|
|
||||||
|
|
||||||
|
class DummyResponseModel(BaseModel):
    """Trivial structured-output schema passed as ``response_model`` in tests."""

    # Single required string field; content is irrelevant to the tests.
    foo: str
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_structured_completion_strips_reasoning_for_unsupported_models():
    """Non-reasoning models must not receive reasoning/verbosity kwargs.

    'gpt-4.1' is outside the reasoning-capable prefixes, so the client should
    forward temperature but drop the 'reasoning' and 'text' options entirely.
    """
    dummy_client = DummyAzureClient()
    client = AzureOpenAILLMClient(
        azure_client=dummy_client,
        config=LLMConfig(),
        reasoning='minimal',
        verbosity='low',
    )

    await client._create_structured_completion(
        model='gpt-4.1',
        messages=[],
        temperature=0.4,
        max_tokens=64,
        response_model=DummyResponseModel,
        reasoning='minimal',
        verbosity='low',
    )

    # Exactly one responses.parse call, using the Responses-API parameter
    # names ('input', 'max_output_tokens', 'text_format').
    assert len(dummy_client.responses.parse_calls) == 1
    call_args = dummy_client.responses.parse_calls[0]
    assert call_args['model'] == 'gpt-4.1'
    assert call_args['input'] == []
    assert call_args['max_output_tokens'] == 64
    assert call_args['text_format'] is DummyResponseModel
    # Temperature is kept for non-reasoning models...
    assert call_args['temperature'] == 0.4
    # ...while reasoning-only options are stripped.
    assert 'reasoning' not in call_args
    assert 'text' not in call_args
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_reasoning_fields_forwarded_for_supported_models():
    """Reasoning-capable models get reasoning/verbosity but no temperature.

    'o1-custom' matches the reasoning prefixes, so both the structured path
    (responses.parse) and the plain path (chat.completions.create) should
    omit temperature; the structured path additionally forwards the
    reasoning-effort and verbosity options.
    """
    dummy_client = DummyAzureClient()
    client = AzureOpenAILLMClient(
        azure_client=dummy_client,
        config=LLMConfig(),
        reasoning='intense',
        verbosity='high',
    )

    await client._create_structured_completion(
        model='o1-custom',
        messages=[],
        temperature=0.7,
        max_tokens=128,
        response_model=DummyResponseModel,
        reasoning='intense',
        verbosity='high',
    )

    call_args = dummy_client.responses.parse_calls[0]
    # Reasoning models reject temperature, so it must be dropped.
    assert 'temperature' not in call_args
    assert call_args['reasoning'] == {'effort': 'intense'}
    assert call_args['text'] == {'verbosity': 'high'}

    await client._create_completion(
        model='o1-custom',
        messages=[],
        temperature=0.7,
        max_tokens=128,
    )

    # The JSON-mode chat path must also drop temperature for these models.
    create_args = dummy_client.chat.completions.create_calls[0]
    assert 'temperature' not in create_args
|
||||||
Loading…
Add table
Reference in a new issue