Fix Azure structured completions (#1039)
update azure structured parsing
parent 9cc04e61c9
commit 8d99984204
2 changed files with 160 additions and 16 deletions
graphiti_core/llm_client/azure_openai_client.py

@@ -38,8 +38,16 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
         azure_client: AsyncAzureOpenAI,
         config: LLMConfig | None = None,
         max_tokens: int = DEFAULT_MAX_TOKENS,
+        reasoning: str | None = None,
+        verbosity: str | None = None,
     ):
-        super().__init__(config, cache=False, max_tokens=max_tokens)
+        super().__init__(
+            config,
+            cache=False,
+            max_tokens=max_tokens,
+            reasoning=reasoning,
+            verbosity=verbosity,
+        )
         self.client = azure_client

     async def _create_structured_completion(
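For orientation, a minimal sketch of how the widened constructor is driven. The endpoint, key, API version, and reasoning/verbosity values below are placeholders, not from this commit; the imports mirror the ones used in the new test file.

# Hedged usage sketch; endpoint, key, and api_version are placeholders.
from openai import AsyncAzureOpenAI

from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig

azure_client = AsyncAzureOpenAI(
    azure_endpoint='https://example.openai.azure.com',  # placeholder
    api_key='<key>',                                    # placeholder
    api_version='2024-10-21',                           # placeholder
)

# The new reasoning/verbosity arguments are stored by the base class and
# handed back to _create_structured_completion/_create_completion per call.
client = AzureOpenAILLMClient(
    azure_client=azure_client,
    config=LLMConfig(),
    reasoning='minimal',
    verbosity='low',
)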
@@ -49,15 +57,29 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
         temperature: float | None,
         max_tokens: int,
         response_model: type[BaseModel],
+        reasoning: str | None,
+        verbosity: str | None,
     ):
-        """Create a structured completion using Azure OpenAI's beta parse API."""
-        return await self.client.beta.chat.completions.parse(
-            model=model,
-            messages=messages,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            response_format=response_model,  # type: ignore
-        )
+        """Create a structured completion using Azure OpenAI's responses.parse API."""
+        supports_reasoning = self._supports_reasoning_features(model)
+        request_kwargs = {
+            'model': model,
+            'input': messages,
+            'max_output_tokens': max_tokens,
+            'text_format': response_model,  # type: ignore
+        }
+
+        temperature_value = temperature if not supports_reasoning else None
+        if temperature_value is not None:
+            request_kwargs['temperature'] = temperature_value
+
+        if supports_reasoning and reasoning:
+            request_kwargs['reasoning'] = {'effort': reasoning}  # type: ignore
+
+        if supports_reasoning and verbosity:
+            request_kwargs['text'] = {'verbosity': verbosity}  # type: ignore
+
+        return await self.client.responses.parse(**request_kwargs)

     async def _create_completion(
         self,
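The move from beta.chat.completions.parse to responses.parse renames three parameters: messages becomes input, max_tokens becomes max_output_tokens, and response_format becomes text_format. A sketch of the call the method now assembles for a non-reasoning model; ExtractedEntity, the message content, and the token budget are illustrative, and azure_client is the instance from the sketch above.

from pydantic import BaseModel


class ExtractedEntity(BaseModel):  # hypothetical response model
    name: str


async def demo() -> None:
    # What _create_structured_completion effectively sends for
    # model='gpt-4.1' (a non-reasoning model); values are examples.
    await azure_client.responses.parse(
        model='gpt-4.1',
        input=[{'role': 'user', 'content': 'Extract the entity name.'}],
        max_output_tokens=1024,
        text_format=ExtractedEntity,
        temperature=0.5,  # dropped entirely when the model supports reasoning
    )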
@@ -68,10 +90,23 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
         response_model: type[BaseModel] | None = None,
     ):
         """Create a regular completion with JSON format using Azure OpenAI."""
-        return await self.client.chat.completions.create(
-            model=model,
-            messages=messages,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            response_format={'type': 'json_object'},
-        )
+        supports_reasoning = self._supports_reasoning_features(model)
+
+        request_kwargs = {
+            'model': model,
+            'messages': messages,
+            'max_tokens': max_tokens,
+            'response_format': {'type': 'json_object'},
+        }
+
+        temperature_value = temperature if not supports_reasoning else None
+        if temperature_value is not None:
+            request_kwargs['temperature'] = temperature_value
+
+        return await self.client.chat.completions.create(**request_kwargs)
+
+    @staticmethod
+    def _supports_reasoning_features(model: str) -> bool:
+        """Return True when the Azure model supports reasoning/verbosity options."""
+        reasoning_prefixes = ('o1', 'o3', 'gpt-5')
+        return model.startswith(reasoning_prefixes)
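str.startswith accepts a tuple of prefixes, so the helper covers all three model families in one call. A quick sketch of the resulting behavior, assuming AzureOpenAILLMClient is imported as in the tests; 'gpt-5-mini' is an example name, the other two come from this commit's tests.

# Prefix matching on the staticmethod; no instance is needed.
assert AzureOpenAILLMClient._supports_reasoning_features('o1-custom')
assert AzureOpenAILLMClient._supports_reasoning_features('gpt-5-mini')
assert not AzureOpenAILLMClient._supports_reasoning_features('gpt-4.1')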
tests/llm_client/test_azure_openai_client.py (new file, 109 additions)
@@ -0,0 +1,109 @@
+from types import SimpleNamespace
+
+import pytest
+from pydantic import BaseModel
+
+from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
+from graphiti_core.llm_client.config import LLMConfig
+
+
+class DummyResponses:
+    def __init__(self):
+        self.parse_calls: list[dict] = []
+
+    async def parse(self, **kwargs):
+        self.parse_calls.append(kwargs)
+        return SimpleNamespace(output_text='{}')
+
+
+class DummyChatCompletions:
+    def __init__(self):
+        self.create_calls: list[dict] = []
+
+    async def create(self, **kwargs):
+        self.create_calls.append(kwargs)
+        message = SimpleNamespace(content='{}')
+        choice = SimpleNamespace(message=message)
+        return SimpleNamespace(choices=[choice])
+
+
+class DummyChat:
+    def __init__(self):
+        self.completions = DummyChatCompletions()
+
+
+class DummyAzureClient:
+    def __init__(self):
+        self.responses = DummyResponses()
+        self.chat = DummyChat()
+
+
+class DummyResponseModel(BaseModel):
+    foo: str
+
+
+@pytest.mark.asyncio
+async def test_structured_completion_strips_reasoning_for_unsupported_models():
+    dummy_client = DummyAzureClient()
+    client = AzureOpenAILLMClient(
+        azure_client=dummy_client,
+        config=LLMConfig(),
+        reasoning='minimal',
+        verbosity='low',
+    )
+
+    await client._create_structured_completion(
+        model='gpt-4.1',
+        messages=[],
+        temperature=0.4,
+        max_tokens=64,
+        response_model=DummyResponseModel,
+        reasoning='minimal',
+        verbosity='low',
+    )
+
+    assert len(dummy_client.responses.parse_calls) == 1
+    call_args = dummy_client.responses.parse_calls[0]
+    assert call_args['model'] == 'gpt-4.1'
+    assert call_args['input'] == []
+    assert call_args['max_output_tokens'] == 64
+    assert call_args['text_format'] is DummyResponseModel
+    assert call_args['temperature'] == 0.4
+    assert 'reasoning' not in call_args
+    assert 'text' not in call_args
+
+
+@pytest.mark.asyncio
+async def test_reasoning_fields_forwarded_for_supported_models():
+    dummy_client = DummyAzureClient()
+    client = AzureOpenAILLMClient(
+        azure_client=dummy_client,
+        config=LLMConfig(),
+        reasoning='intense',
+        verbosity='high',
+    )
+
+    await client._create_structured_completion(
+        model='o1-custom',
+        messages=[],
+        temperature=0.7,
+        max_tokens=128,
+        response_model=DummyResponseModel,
+        reasoning='intense',
+        verbosity='high',
+    )
+
+    call_args = dummy_client.responses.parse_calls[0]
+    assert 'temperature' not in call_args
+    assert call_args['reasoning'] == {'effort': 'intense'}
+    assert call_args['text'] == {'verbosity': 'high'}
+
+    await client._create_completion(
+        model='o1-custom',
+        messages=[],
+        temperature=0.7,
+        max_tokens=128,
+    )
+
+    create_args = dummy_client.chat.completions.create_calls[0]
+    assert 'temperature' not in create_args
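The dummy doubles record keyword arguments in memory instead of calling Azure, so the suite runs offline. Assuming pytest-asyncio is installed (the @pytest.mark.asyncio markers require it), both tests run with pytest tests/llm_client/test_azure_openai_client.py.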