parent a9091b06ff
commit 9f3dd5552a
2 changed files with 222 additions and 130 deletions
File diff suppressed because one or more lines are too long
162 graphiti_core/llm_client/openai_generic_client.py Normal file
@@ -0,0 +1,162 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import typing
from typing import ClassVar

import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig
from .errors import RateLimitError, RefusalError

logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'gpt-4o-mini'


class OpenAIGenericClient(LLMClient):
    """
    OpenAIGenericClient is a client class for interacting with OpenAI's language models.

    This class extends the LLMClient and provides methods to initialize the client,
    get an embedder, and generate responses from the language model.

    Attributes:
        client (AsyncOpenAI): The OpenAI client used to interact with the API.
        model (str): The model name to use for generating responses.
        temperature (float): The temperature to use for generating responses.
        max_tokens (int): The maximum number of tokens to generate in a response.

    Methods:
        __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
            Initializes the OpenAIGenericClient with the provided configuration, cache setting, and client.

        _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
            Generates a response from the language model based on the provided messages.
    """

    # Class-level constants
    MAX_RETRIES: ClassVar[int] = 2

    def __init__(
        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
    ):
        """
        Initialize the OpenAIGenericClient with the provided configuration, cache setting, and client.

        Args:
            config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
            cache (bool): Whether to use caching for responses. Defaults to False.
            client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.

        """
        # removed caching to simplify the `generate_response` override
        if cache:
            raise NotImplementedError('Caching is not implemented for OpenAI')

        if config is None:
            config = LLMConfig()

        super().__init__(config, cache)

        if client is None:
            self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
        else:
            self.client = client

    async def _generate_response(
        self, messages: list[Message], response_model: type[BaseModel] | None = None
    ) -> dict[str, typing.Any]:
        openai_messages: list[ChatCompletionMessageParam] = []
        for m in messages:
            if m.role == 'user':
                openai_messages.append({'role': 'user', 'content': m.content})
            elif m.role == 'system':
                openai_messages.append({'role': 'system', 'content': m.content})
        try:
            response = await self.client.chat.completions.create(
                model=self.model or DEFAULT_MODEL,
                messages=openai_messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                response_format={'type': 'json_object'},
            )
            result = response.choices[0].message.content or ''
            return json.loads(result)
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise

    async def generate_response(
        self, messages: list[Message], response_model: type[BaseModel] | None = None
    ) -> dict[str, typing.Any]:
        retry_count = 0
        last_error = None

        if response_model is not None:
            serialized_model = json.dumps(response_model.model_json_schema())
            messages[
                -1
            ].content += (
                f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
            )

        while retry_count <= self.MAX_RETRIES:
            try:
                response = await self._generate_response(messages, response_model)
                return response
            except (RateLimitError, RefusalError):
                # These errors should not trigger retries
                raise
            except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
                # Let OpenAI's client handle these retries
                raise
            except Exception as e:
                last_error = e

                # Don't retry if we've hit the max retries
                if retry_count >= self.MAX_RETRIES:
                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
                    raise

                retry_count += 1

                # Construct a detailed error message for the LLM
                error_context = (
                    f'The previous response attempt was invalid. '
                    f'Error type: {e.__class__.__name__}. '
                    f'Error details: {str(e)}. '
                    f'Please try again with a valid response, ensuring the output matches '
                    f'the expected format and constraints.'
                )

                error_message = Message(role='user', content=error_context)
                messages.append(error_message)
                logger.warning(
                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
                )

        # If we somehow get here, raise the last error
        raise last_error or Exception('Max retries exceeded with no specific error')
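
Below is a minimal usage sketch for the new client, added for illustration and not part of this commit. It assumes LLMConfig accepts the api_key, model, and base_url fields referenced in __init__ above, that Message takes role and content as in the retry path, and that the API key, model name, and endpoint URL shown are placeholders for any OpenAI-compatible server.

import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient
from graphiti_core.prompts.models import Message


class EntitySummary(BaseModel):
    name: str
    summary: str


async def main() -> None:
    # Placeholder credentials and endpoint; the client only relies on the
    # chat.completions API with JSON-object output, so any OpenAI-compatible
    # backend reachable at base_url should work.
    config = LLMConfig(
        api_key='sk-placeholder',
        model='gpt-4o-mini',
        base_url='https://api.openai.com/v1',
    )
    llm = OpenAIGenericClient(config=config)

    messages = [
        Message(role='system', content='Extract an entity summary and answer in JSON.'),
        Message(role='user', content='Graphiti builds temporally-aware knowledge graphs.'),
    ]

    # generate_response appends EntitySummary's JSON schema to the last message and
    # retries up to MAX_RETRIES times, feeding validation errors back to the model.
    result = await llm.generate_response(messages, response_model=EntitySummary)
    print(result)


if __name__ == '__main__':
    asyncio.run(main())

Because generate_response mutates the messages list in place (schema appended to the last message, error feedback appended on retry), callers should pass a fresh list of Message objects per call rather than reusing one across requests.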