Azure OpenAI improvements and fixes; Improve Graphiti Azure OpenAI config (#620)
* Azure OpenAI improvements and fixes; Improve Graphiti Azure OpenAI config * format
This commit is contained in:
parent
587f1b9876
commit
9cc2e86071
5 changed files with 307 additions and 178 deletions
|
|
@ -14,60 +14,64 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import ClassVar
|
||||||
|
|
||||||
from openai import AsyncAzureOpenAI
|
from openai import AsyncAzureOpenAI
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
||||||
from .client import LLMClient
|
from .openai_base_client import BaseOpenAIClient
|
||||||
from .config import LLMConfig, ModelSize
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAILLMClient(LLMClient):
|
class AzureOpenAILLMClient(BaseOpenAIClient):
|
||||||
"""Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
|
"""Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
|
||||||
|
|
||||||
def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
|
# Class-level constants
|
||||||
super().__init__(config, cache=False)
|
MAX_RETRIES: ClassVar[int] = 2
|
||||||
self.azure_client = azure_client
|
|
||||||
|
|
||||||
async def _generate_response(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
messages: list[Message],
|
azure_client: AsyncAzureOpenAI,
|
||||||
|
config: LLMConfig | None = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
|
):
|
||||||
|
super().__init__(config, cache=False, max_tokens=max_tokens)
|
||||||
|
self.client = azure_client
|
||||||
|
|
||||||
|
async def _create_structured_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list[ChatCompletionMessageParam],
|
||||||
|
temperature: float | None,
|
||||||
|
max_tokens: int,
|
||||||
|
response_model: type[BaseModel],
|
||||||
|
):
|
||||||
|
"""Create a structured completion using Azure OpenAI's beta parse API."""
|
||||||
|
return await self.client.beta.chat.completions.parse(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
response_format=response_model, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _create_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list[ChatCompletionMessageParam],
|
||||||
|
temperature: float | None,
|
||||||
|
max_tokens: int,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int = 1024,
|
):
|
||||||
model_size: ModelSize = ModelSize.medium,
|
"""Create a regular completion with JSON format using Azure OpenAI."""
|
||||||
) -> dict[str, Any]:
|
return await self.client.chat.completions.create(
|
||||||
"""Generate response using Azure OpenAI client."""
|
model=model,
|
||||||
# Convert messages to OpenAI format
|
messages=messages,
|
||||||
openai_messages: list[ChatCompletionMessageParam] = []
|
temperature=temperature,
|
||||||
for message in messages:
|
max_tokens=max_tokens,
|
||||||
message.content = self._clean_input(message.content)
|
response_format={'type': 'json_object'},
|
||||||
if message.role == 'user':
|
)
|
||||||
openai_messages.append({'role': 'user', 'content': message.content})
|
|
||||||
elif message.role == 'system':
|
|
||||||
openai_messages.append({'role': 'system', 'content': message.content})
|
|
||||||
|
|
||||||
# Ensure model is a string
|
|
||||||
model_name = self.model if self.model else 'gpt-4o-mini'
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.azure_client.chat.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
messages=openai_messages,
|
|
||||||
temperature=float(self.temperature) if self.temperature is not None else 0.7,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
response_format={'type': 'json_object'},
|
|
||||||
)
|
|
||||||
result = response.choices[0].message.content or '{}'
|
|
||||||
|
|
||||||
# Parse JSON response
|
|
||||||
return json.loads(result)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f'Error in Azure OpenAI LLM response: {e}')
|
|
||||||
raise
|
|
||||||
|
|
|
||||||
217
graphiti_core/llm_client/openai_base_client.py
Normal file
217
graphiti_core/llm_client/openai_base_client.py
Normal file
|
|
@ -0,0 +1,217 @@
|
||||||
|
"""
|
||||||
|
Copyright 2024, Zep Software, Inc.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import typing
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..prompts.models import Message
|
||||||
|
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
|
||||||
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
|
||||||
|
from .errors import RateLimitError, RefusalError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MODEL = 'gpt-4.1-mini'
|
||||||
|
DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOpenAIClient(LLMClient):
|
||||||
|
"""
|
||||||
|
Base client class for OpenAI-compatible APIs (OpenAI and Azure OpenAI).
|
||||||
|
|
||||||
|
This class contains shared logic for both OpenAI and Azure OpenAI clients,
|
||||||
|
reducing code duplication while allowing for implementation-specific differences.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Class-level constants
|
||||||
|
MAX_RETRIES: ClassVar[int] = 2
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LLMConfig | None = None,
|
||||||
|
cache: bool = False,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
|
):
|
||||||
|
if cache:
|
||||||
|
raise NotImplementedError('Caching is not implemented for OpenAI-based clients')
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
config = LLMConfig()
|
||||||
|
|
||||||
|
super().__init__(config, cache)
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _create_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list[ChatCompletionMessageParam],
|
||||||
|
temperature: float | None,
|
||||||
|
max_tokens: int,
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Create a completion using the specific client implementation."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _create_structured_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list[ChatCompletionMessageParam],
|
||||||
|
temperature: float | None,
|
||||||
|
max_tokens: int,
|
||||||
|
response_model: type[BaseModel],
|
||||||
|
) -> Any:
|
||||||
|
"""Create a structured completion using the specific client implementation."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _convert_messages_to_openai_format(
|
||||||
|
self, messages: list[Message]
|
||||||
|
) -> list[ChatCompletionMessageParam]:
|
||||||
|
"""Convert internal Message format to OpenAI ChatCompletionMessageParam format."""
|
||||||
|
openai_messages: list[ChatCompletionMessageParam] = []
|
||||||
|
for m in messages:
|
||||||
|
m.content = self._clean_input(m.content)
|
||||||
|
if m.role == 'user':
|
||||||
|
openai_messages.append({'role': 'user', 'content': m.content})
|
||||||
|
elif m.role == 'system':
|
||||||
|
openai_messages.append({'role': 'system', 'content': m.content})
|
||||||
|
return openai_messages
|
||||||
|
|
||||||
|
def _get_model_for_size(self, model_size: ModelSize) -> str:
|
||||||
|
"""Get the appropriate model name based on the requested size."""
|
||||||
|
if model_size == ModelSize.small:
|
||||||
|
return self.small_model or DEFAULT_SMALL_MODEL
|
||||||
|
else:
|
||||||
|
return self.model or DEFAULT_MODEL
|
||||||
|
|
||||||
|
def _handle_structured_response(self, response: Any) -> dict[str, Any]:
|
||||||
|
"""Handle structured response parsing and validation."""
|
||||||
|
response_object = response.choices[0].message
|
||||||
|
|
||||||
|
if response_object.parsed:
|
||||||
|
return response_object.parsed.model_dump()
|
||||||
|
elif response_object.refusal:
|
||||||
|
raise RefusalError(response_object.refusal)
|
||||||
|
else:
|
||||||
|
raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
|
||||||
|
|
||||||
|
def _handle_json_response(self, response: Any) -> dict[str, Any]:
|
||||||
|
"""Handle JSON response parsing."""
|
||||||
|
result = response.choices[0].message.content or '{}'
|
||||||
|
return json.loads(result)
|
||||||
|
|
||||||
|
async def _generate_response(
|
||||||
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Generate a response using the appropriate client implementation."""
|
||||||
|
openai_messages = self._convert_messages_to_openai_format(messages)
|
||||||
|
model = self._get_model_for_size(model_size)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if response_model:
|
||||||
|
response = await self._create_structured_completion(
|
||||||
|
model=model,
|
||||||
|
messages=openai_messages,
|
||||||
|
temperature=self.temperature,
|
||||||
|
max_tokens=max_tokens or self.max_tokens,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
return self._handle_structured_response(response)
|
||||||
|
else:
|
||||||
|
response = await self._create_completion(
|
||||||
|
model=model,
|
||||||
|
messages=openai_messages,
|
||||||
|
temperature=self.temperature,
|
||||||
|
max_tokens=max_tokens or self.max_tokens,
|
||||||
|
)
|
||||||
|
return self._handle_json_response(response)
|
||||||
|
|
||||||
|
except openai.LengthFinishReasonError as e:
|
||||||
|
raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
|
||||||
|
except openai.RateLimitError as e:
|
||||||
|
raise RateLimitError from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'Error in generating LLM response: {e}')
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def generate_response(
|
||||||
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
model_size: ModelSize = ModelSize.medium,
|
||||||
|
) -> dict[str, typing.Any]:
|
||||||
|
"""Generate a response with retry logic and error handling."""
|
||||||
|
if max_tokens is None:
|
||||||
|
max_tokens = self.max_tokens
|
||||||
|
|
||||||
|
retry_count = 0
|
||||||
|
last_error = None
|
||||||
|
|
||||||
|
# Add multilingual extraction instructions
|
||||||
|
messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
|
||||||
|
|
||||||
|
while retry_count <= self.MAX_RETRIES:
|
||||||
|
try:
|
||||||
|
response = await self._generate_response(
|
||||||
|
messages, response_model, max_tokens, model_size
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except (RateLimitError, RefusalError):
|
||||||
|
# These errors should not trigger retries
|
||||||
|
raise
|
||||||
|
except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
|
||||||
|
# Let OpenAI's client handle these retries
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
|
||||||
|
# Don't retry if we've hit the max retries
|
||||||
|
if retry_count >= self.MAX_RETRIES:
|
||||||
|
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
|
||||||
|
raise
|
||||||
|
|
||||||
|
retry_count += 1
|
||||||
|
|
||||||
|
# Construct a detailed error message for the LLM
|
||||||
|
error_context = (
|
||||||
|
f'The previous response attempt was invalid. '
|
||||||
|
f'Error type: {e.__class__.__name__}. '
|
||||||
|
f'Error details: {str(e)}. '
|
||||||
|
f'Please try again with a valid response, ensuring the output matches '
|
||||||
|
f'the expected format and constraints.'
|
||||||
|
)
|
||||||
|
|
||||||
|
error_message = Message(role='user', content=error_context)
|
||||||
|
messages.append(error_message)
|
||||||
|
logger.warning(
|
||||||
|
f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we somehow get here, raise the last error
|
||||||
|
raise last_error or Exception('Max retries exceeded with no specific error')
|
||||||
|
|
@ -14,50 +14,27 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import typing
|
import typing
|
||||||
from typing import ClassVar
|
|
||||||
|
|
||||||
import openai
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
||||||
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
|
from .openai_base_client import BaseOpenAIClient
|
||||||
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
|
|
||||||
from .errors import RateLimitError, RefusalError
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
DEFAULT_MODEL = 'gpt-4.1-mini'
|
|
||||||
DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIClient(LLMClient):
|
class OpenAIClient(BaseOpenAIClient):
|
||||||
"""
|
"""
|
||||||
OpenAIClient is a client class for interacting with OpenAI's language models.
|
OpenAIClient is a client class for interacting with OpenAI's language models.
|
||||||
|
|
||||||
This class extends the LLMClient and provides methods to initialize the client,
|
This class extends the BaseOpenAIClient and provides OpenAI-specific implementation
|
||||||
get an embedder, and generate responses from the language model.
|
for creating completions.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
client (AsyncOpenAI): The OpenAI client used to interact with the API.
|
client (AsyncOpenAI): The OpenAI client used to interact with the API.
|
||||||
model (str): The model name to use for generating responses.
|
|
||||||
temperature (float): The temperature to use for generating responses.
|
|
||||||
max_tokens (int): The maximum number of tokens to generate in a response.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
__init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
|
|
||||||
Initializes the OpenAIClient with the provided configuration, cache setting, and client.
|
|
||||||
|
|
||||||
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
|
|
||||||
Generates a response from the language model based on the provided messages.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Class-level constants
|
|
||||||
MAX_RETRIES: ClassVar[int] = 2
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: LLMConfig | None = None,
|
config: LLMConfig | None = None,
|
||||||
|
|
@ -72,120 +49,47 @@ class OpenAIClient(LLMClient):
|
||||||
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
||||||
cache (bool): Whether to use caching for responses. Defaults to False.
|
cache (bool): Whether to use caching for responses. Defaults to False.
|
||||||
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# removed caching to simplify the `generate_response` override
|
super().__init__(config, cache, max_tokens)
|
||||||
if cache:
|
|
||||||
raise NotImplementedError('Caching is not implemented for OpenAI')
|
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = LLMConfig()
|
config = LLMConfig()
|
||||||
|
|
||||||
super().__init__(config, cache)
|
|
||||||
|
|
||||||
if client is None:
|
if client is None:
|
||||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
else:
|
else:
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
self.max_tokens = max_tokens
|
async def _create_structured_completion(
|
||||||
|
|
||||||
async def _generate_response(
|
|
||||||
self,
|
self,
|
||||||
messages: list[Message],
|
model: str,
|
||||||
response_model: type[BaseModel] | None = None,
|
messages: list[ChatCompletionMessageParam],
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
temperature: float | None,
|
||||||
model_size: ModelSize = ModelSize.medium,
|
max_tokens: int,
|
||||||
) -> dict[str, typing.Any]:
|
response_model: type[BaseModel],
|
||||||
openai_messages: list[ChatCompletionMessageParam] = []
|
):
|
||||||
for m in messages:
|
"""Create a structured completion using OpenAI's beta parse API."""
|
||||||
m.content = self._clean_input(m.content)
|
return await self.client.beta.chat.completions.parse(
|
||||||
if m.role == 'user':
|
model=model,
|
||||||
openai_messages.append({'role': 'user', 'content': m.content})
|
messages=messages,
|
||||||
elif m.role == 'system':
|
temperature=temperature,
|
||||||
openai_messages.append({'role': 'system', 'content': m.content})
|
max_tokens=max_tokens,
|
||||||
try:
|
response_format=response_model, # type: ignore
|
||||||
if model_size == ModelSize.small:
|
)
|
||||||
model = self.small_model or DEFAULT_SMALL_MODEL
|
|
||||||
else:
|
|
||||||
model = self.model or DEFAULT_MODEL
|
|
||||||
|
|
||||||
response = await self.client.beta.chat.completions.parse(
|
async def _create_completion(
|
||||||
model=model,
|
|
||||||
messages=openai_messages,
|
|
||||||
temperature=self.temperature,
|
|
||||||
max_tokens=max_tokens or self.max_tokens,
|
|
||||||
response_format=response_model, # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
response_object = response.choices[0].message
|
|
||||||
|
|
||||||
if response_object.parsed:
|
|
||||||
return response_object.parsed.model_dump()
|
|
||||||
elif response_object.refusal:
|
|
||||||
raise RefusalError(response_object.refusal)
|
|
||||||
else:
|
|
||||||
raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
|
|
||||||
except openai.LengthFinishReasonError as e:
|
|
||||||
raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
|
|
||||||
except openai.RateLimitError as e:
|
|
||||||
raise RateLimitError from e
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f'Error in generating LLM response: {e}')
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def generate_response(
|
|
||||||
self,
|
self,
|
||||||
messages: list[Message],
|
model: str,
|
||||||
|
messages: list[ChatCompletionMessageParam],
|
||||||
|
temperature: float | None,
|
||||||
|
max_tokens: int,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
max_tokens: int | None = None,
|
):
|
||||||
model_size: ModelSize = ModelSize.medium,
|
"""Create a regular completion with JSON format."""
|
||||||
) -> dict[str, typing.Any]:
|
return await self.client.chat.completions.create(
|
||||||
if max_tokens is None:
|
model=model,
|
||||||
max_tokens = self.max_tokens
|
messages=messages,
|
||||||
|
temperature=temperature,
|
||||||
retry_count = 0
|
max_tokens=max_tokens,
|
||||||
last_error = None
|
response_format={'type': 'json_object'},
|
||||||
|
)
|
||||||
# Add multilingual extraction instructions
|
|
||||||
messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
|
|
||||||
|
|
||||||
while retry_count <= self.MAX_RETRIES:
|
|
||||||
try:
|
|
||||||
response = await self._generate_response(
|
|
||||||
messages, response_model, max_tokens, model_size
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except (RateLimitError, RefusalError):
|
|
||||||
# These errors should not trigger retries
|
|
||||||
raise
|
|
||||||
except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
|
|
||||||
# Let OpenAI's client handle these retries
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
last_error = e
|
|
||||||
|
|
||||||
# Don't retry if we've hit the max retries
|
|
||||||
if retry_count >= self.MAX_RETRIES:
|
|
||||||
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
|
|
||||||
raise
|
|
||||||
|
|
||||||
retry_count += 1
|
|
||||||
|
|
||||||
# Construct a detailed error message for the LLM
|
|
||||||
error_context = (
|
|
||||||
f'The previous response attempt was invalid. '
|
|
||||||
f'Error type: {e.__class__.__name__}. '
|
|
||||||
f'Error details: {str(e)}. '
|
|
||||||
f'Please try again with a valid response, ensuring the output matches '
|
|
||||||
f'the expected format and constraints.'
|
|
||||||
)
|
|
||||||
|
|
||||||
error_message = Message(role='user', content=error_context)
|
|
||||||
messages.append(error_message)
|
|
||||||
logger.warning(
|
|
||||||
f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# If we somehow get here, raise the last error
|
|
||||||
raise last_error or Exception('Max retries exceeded with no specific error')
|
|
||||||
|
|
|
||||||
|
|
@ -78,9 +78,11 @@ The server uses the following environment variables:
|
||||||
- `MODEL_NAME`: OpenAI model name to use for LLM operations.
|
- `MODEL_NAME`: OpenAI model name to use for LLM operations.
|
||||||
- `SMALL_MODEL_NAME`: OpenAI model name to use for smaller LLM operations.
|
- `SMALL_MODEL_NAME`: OpenAI model name to use for smaller LLM operations.
|
||||||
- `LLM_TEMPERATURE`: Temperature for LLM responses (0.0-2.0).
|
- `LLM_TEMPERATURE`: Temperature for LLM responses (0.0-2.0).
|
||||||
- `AZURE_OPENAI_ENDPOINT`: Optional Azure OpenAI endpoint URL
|
- `AZURE_OPENAI_ENDPOINT`: Optional Azure OpenAI LLM endpoint URL
|
||||||
- `AZURE_OPENAI_DEPLOYMENT_NAME`: Optional Azure OpenAI deployment name
|
- `AZURE_OPENAI_DEPLOYMENT_NAME`: Optional Azure OpenAI LLM deployment name
|
||||||
- `AZURE_OPENAI_API_VERSION`: Optional Azure OpenAI API version
|
- `AZURE_OPENAI_API_VERSION`: Optional Azure OpenAI LLM API version
|
||||||
|
- `AZURE_OPENAI_EMBEDDING_API_KEY`: Optional Azure OpenAI Embedding deployment key (if other than `OPENAI_API_KEY`)
|
||||||
|
- `AZURE_OPENAI_EMBEDDING_ENDPOINT`: Optional Azure OpenAI Embedding endpoint URL
|
||||||
- `AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME`: Optional Azure OpenAI embedding deployment name
|
- `AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME`: Optional Azure OpenAI embedding deployment name
|
||||||
- `AZURE_OPENAI_EMBEDDING_API_VERSION`: Optional Azure OpenAI API version
|
- `AZURE_OPENAI_EMBEDDING_API_VERSION`: Optional Azure OpenAI API version
|
||||||
- `AZURE_OPENAI_USE_MANAGED_IDENTITY`: Optional use Azure Managed Identities for authentication
|
- `AZURE_OPENAI_USE_MANAGED_IDENTITY`: Optional use Azure Managed Identities for authentication
|
||||||
|
|
|
||||||
|
|
@ -367,7 +367,7 @@ class GraphitiEmbedderConfig(BaseModel):
|
||||||
model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
|
model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
|
||||||
model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL
|
model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL
|
||||||
|
|
||||||
azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
|
azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
|
||||||
azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
|
azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
|
||||||
azure_openai_deployment_name = os.environ.get(
|
azure_openai_deployment_name = os.environ.get(
|
||||||
'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
|
'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
|
||||||
|
|
@ -390,7 +390,9 @@ class GraphitiEmbedderConfig(BaseModel):
|
||||||
|
|
||||||
if not azure_openai_use_managed_identity:
|
if not azure_openai_use_managed_identity:
|
||||||
# api key
|
# api key
|
||||||
api_key = os.environ.get('OPENAI_API_KEY', None)
|
api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
|
||||||
|
'OPENAI_API_KEY', None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Managed identity
|
# Managed identity
|
||||||
api_key = None
|
api_key = None
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue