add generic client (#237)

* add generic client

* format
This commit is contained in:
Preston Rasmussen 2024-12-10 22:02:46 -05:00 committed by GitHub
parent a9091b06ff
commit 9f3dd5552a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 222 additions and 130 deletions

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,162 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import typing
from typing import ClassVar
import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel
from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig
from .errors import RateLimitError, RefusalError
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gpt-4o-mini'
class OpenAIGenericClient(LLMClient):
"""
OpenAIClient is a client class for interacting with OpenAI's language models.
This class extends the LLMClient and provides methods to initialize the client,
get an embedder, and generate responses from the language model.
Attributes:
client (AsyncOpenAI): The OpenAI client used to interact with the API.
model (str): The model name to use for generating responses.
temperature (float): The temperature to use for generating responses.
max_tokens (int): The maximum number of tokens to generate in a response.
Methods:
__init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
Initializes the OpenAIClient with the provided configuration, cache setting, and client.
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
Generates a response from the language model based on the provided messages.
"""
# Class-level constants
MAX_RETRIES: ClassVar[int] = 2
def __init__(
self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
):
"""
Initialize the OpenAIClient with the provided configuration, cache setting, and client.
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
cache (bool): Whether to use caching for responses. Defaults to False.
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
"""
# removed caching to simplify the `generate_response` override
if cache:
raise NotImplementedError('Caching is not implemented for OpenAI')
if config is None:
config = LLMConfig()
super().__init__(config, cache)
if client is None:
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
else:
self.client = client
async def _generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None
) -> dict[str, typing.Any]:
openai_messages: list[ChatCompletionMessageParam] = []
for m in messages:
if m.role == 'user':
openai_messages.append({'role': 'user', 'content': m.content})
elif m.role == 'system':
openai_messages.append({'role': 'system', 'content': m.content})
try:
response = await self.client.chat.completions.create(
model=self.model or DEFAULT_MODEL,
messages=openai_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
response_format={'type': 'json_object'},
)
result = response.choices[0].message.content or ''
return json.loads(result)
except openai.RateLimitError as e:
raise RateLimitError from e
except Exception as e:
logger.error(f'Error in generating LLM response: {e}')
raise
async def generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None
) -> dict[str, typing.Any]:
retry_count = 0
last_error = None
if response_model is not None:
serialized_model = json.dumps(response_model.model_json_schema())
messages[
-1
].content += (
f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
)
while retry_count <= self.MAX_RETRIES:
try:
response = await self._generate_response(messages, response_model)
return response
except (RateLimitError, RefusalError):
# These errors should not trigger retries
raise
except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
# Let OpenAI's client handle these retries
raise
except Exception as e:
last_error = e
# Don't retry if we've hit the max retries
if retry_count >= self.MAX_RETRIES:
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
raise
retry_count += 1
# Construct a detailed error message for the LLM
error_context = (
f'The previous response attempt was invalid. '
f'Error type: {e.__class__.__name__}. '
f'Error details: {str(e)}. '
f'Please try again with a valid response, ensuring the output matches '
f'the expected format and constraints.'
)
error_message = Message(role='user', content=error_context)
messages.append(error_message)
logger.warning(
f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
)
# If we somehow get here, raise the last error
raise last_error or Exception('Max retries exceeded with no specific error')