Fix Azure structured completions (#1039)

Update Azure structured output parsing (a usage sketch follows the change summary below)
Daniel Chalef 2025-11-01 18:40:43 -07:00 committed by GitHub
parent 9cc04e61c9
commit 8d99984204
GPG key ID: B5690EEEBB952194
2 changed files with 160 additions and 16 deletions
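
The change switches Azure structured output from the beta chat-completions parse API to the Responses API and threads optional reasoning/verbosity settings from the client constructor into each request. A minimal usage sketch, assuming only the constructor arguments shown in this diff (the endpoint, key, and API version values are placeholders):

from openai import AsyncAzureOpenAI

from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig

azure_client = AsyncAzureOpenAI(
    azure_endpoint='https://example-resource.openai.azure.com',  # placeholder
    api_key='<azure-openai-key>',                                # placeholder
    api_version='2024-10-21',                                    # placeholder
)

llm_client = AzureOpenAILLMClient(
    azure_client=azure_client,
    config=LLMConfig(),
    reasoning='minimal',  # forwarded only when the model supports reasoning options
    verbosity='low',      # likewise gated by the o1/o3/gpt-5 prefix check
)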

graphiti_core/llm_client/azure_openai_client.py

@@ -38,8 +38,16 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
         azure_client: AsyncAzureOpenAI,
         config: LLMConfig | None = None,
         max_tokens: int = DEFAULT_MAX_TOKENS,
+        reasoning: str | None = None,
+        verbosity: str | None = None,
     ):
-        super().__init__(config, cache=False, max_tokens=max_tokens)
+        super().__init__(
+            config,
+            cache=False,
+            max_tokens=max_tokens,
+            reasoning=reasoning,
+            verbosity=verbosity,
+        )
         self.client = azure_client
 
     async def _create_structured_completion(
@@ -49,15 +57,29 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
         temperature: float | None,
         max_tokens: int,
         response_model: type[BaseModel],
+        reasoning: str | None,
+        verbosity: str | None,
     ):
-        """Create a structured completion using Azure OpenAI's beta parse API."""
-        return await self.client.beta.chat.completions.parse(
-            model=model,
-            messages=messages,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            response_format=response_model,  # type: ignore
-        )
+        """Create a structured completion using Azure OpenAI's responses.parse API."""
+        supports_reasoning = self._supports_reasoning_features(model)
+
+        request_kwargs = {
+            'model': model,
+            'input': messages,
+            'max_output_tokens': max_tokens,
+            'text_format': response_model,  # type: ignore
+        }
+
+        temperature_value = temperature if not supports_reasoning else None
+        if temperature_value is not None:
+            request_kwargs['temperature'] = temperature_value
+
+        if supports_reasoning and reasoning:
+            request_kwargs['reasoning'] = {'effort': reasoning}  # type: ignore
+        if supports_reasoning and verbosity:
+            request_kwargs['text'] = {'verbosity': verbosity}  # type: ignore
+
+        return await self.client.responses.parse(**request_kwargs)
 
     async def _create_completion(
         self,
@@ -68,10 +90,23 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
         response_model: type[BaseModel] | None = None,
     ):
         """Create a regular completion with JSON format using Azure OpenAI."""
-        return await self.client.chat.completions.create(
-            model=model,
-            messages=messages,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            response_format={'type': 'json_object'},
-        )
+        supports_reasoning = self._supports_reasoning_features(model)
+
+        request_kwargs = {
+            'model': model,
+            'messages': messages,
+            'max_tokens': max_tokens,
+            'response_format': {'type': 'json_object'},
+        }
+
+        temperature_value = temperature if not supports_reasoning else None
+        if temperature_value is not None:
+            request_kwargs['temperature'] = temperature_value
+
+        return await self.client.chat.completions.create(**request_kwargs)
+
+    @staticmethod
+    def _supports_reasoning_features(model: str) -> bool:
+        """Return True when the Azure model supports reasoning/verbosity options."""
+        reasoning_prefixes = ('o1', 'o3', 'gpt-5')
+        return model.startswith(reasoning_prefixes)
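
For illustration, the gating helper is a plain prefix check, so (the model names here are only examples, not values taken from the commit):

AzureOpenAILLMClient._supports_reasoning_features('o3-mini')  # True: reasoning/verbosity sent, temperature dropped
AzureOpenAILLMClient._supports_reasoning_features('gpt-5')    # True
AzureOpenAILLMClient._supports_reasoning_features('gpt-4.1')  # False: temperature kept, no reasoning options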

New test file for AzureOpenAILLMClient:

@@ -0,0 +1,109 @@
+from types import SimpleNamespace
+
+import pytest
+from pydantic import BaseModel
+
+from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
+from graphiti_core.llm_client.config import LLMConfig
+
+
+class DummyResponses:
+    def __init__(self):
+        self.parse_calls: list[dict] = []
+
+    async def parse(self, **kwargs):
+        self.parse_calls.append(kwargs)
+        return SimpleNamespace(output_text='{}')
+
+
+class DummyChatCompletions:
+    def __init__(self):
+        self.create_calls: list[dict] = []
+
+    async def create(self, **kwargs):
+        self.create_calls.append(kwargs)
+        message = SimpleNamespace(content='{}')
+        choice = SimpleNamespace(message=message)
+        return SimpleNamespace(choices=[choice])
+
+
+class DummyChat:
+    def __init__(self):
+        self.completions = DummyChatCompletions()
+
+
+class DummyAzureClient:
+    def __init__(self):
+        self.responses = DummyResponses()
+        self.chat = DummyChat()
+
+
+class DummyResponseModel(BaseModel):
+    foo: str
+
+
+@pytest.mark.asyncio
+async def test_structured_completion_strips_reasoning_for_unsupported_models():
+    dummy_client = DummyAzureClient()
+    client = AzureOpenAILLMClient(
+        azure_client=dummy_client,
+        config=LLMConfig(),
+        reasoning='minimal',
+        verbosity='low',
+    )
+
+    await client._create_structured_completion(
+        model='gpt-4.1',
+        messages=[],
+        temperature=0.4,
+        max_tokens=64,
+        response_model=DummyResponseModel,
+        reasoning='minimal',
+        verbosity='low',
+    )
+
+    assert len(dummy_client.responses.parse_calls) == 1
+    call_args = dummy_client.responses.parse_calls[0]
+    assert call_args['model'] == 'gpt-4.1'
+    assert call_args['input'] == []
+    assert call_args['max_output_tokens'] == 64
+    assert call_args['text_format'] is DummyResponseModel
+    assert call_args['temperature'] == 0.4
+    assert 'reasoning' not in call_args
+    assert 'text' not in call_args
+
+
+@pytest.mark.asyncio
+async def test_reasoning_fields_forwarded_for_supported_models():
+    dummy_client = DummyAzureClient()
+    client = AzureOpenAILLMClient(
+        azure_client=dummy_client,
+        config=LLMConfig(),
+        reasoning='intense',
+        verbosity='high',
+    )
+
+    await client._create_structured_completion(
+        model='o1-custom',
+        messages=[],
+        temperature=0.7,
+        max_tokens=128,
+        response_model=DummyResponseModel,
+        reasoning='intense',
+        verbosity='high',
+    )
+
+    call_args = dummy_client.responses.parse_calls[0]
+    assert 'temperature' not in call_args
+    assert call_args['reasoning'] == {'effort': 'intense'}
+    assert call_args['text'] == {'verbosity': 'high'}
+
+    await client._create_completion(
+        model='o1-custom',
+        messages=[],
+        temperature=0.7,
+        max_tokens=128,
+    )
+
+    create_args = dummy_client.chat.completions.create_calls[0]
+    assert 'temperature' not in create_args
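
These tests depend only on the in-module dummies, so for a quick manual check outside pytest/pytest-asyncio the coroutines can be driven from within the test module with asyncio directly (a sketch, not part of the commit):

import asyncio

# Drive the test coroutines by hand; no fixtures or event-loop plugins are needed.
asyncio.run(test_structured_completion_strips_reasoning_for_unsupported_models())
asyncio.run(test_reasoning_fields_forwarded_for_supported_models())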