feat: Refactor OpenAIClient initialization and add client parameter (#140)

The code changes refactor the `OpenAIClient` initialization to accept an optional `client` parameter. This allows the client to be passed in from outside, providing more flexibility and enabling easier testing.
This commit is contained in:
Daniel Chalef 2024-09-21 12:09:04 -07:00 committed by GitHub
parent 32b51530ec
commit 9b71b46c0f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -33,13 +33,50 @@ DEFAULT_MODEL = 'gpt-4o-2024-08-06'
class OpenAIClient(LLMClient):
def __init__(self, config: LLMConfig | None = None, cache: bool = False):
"""
OpenAIClient is a client class for interacting with OpenAI's language models.
This class extends the LLMClient and provides methods to initialize the client,
get an embedder, and generate responses from the language model.
Attributes:
client (AsyncOpenAI): The OpenAI client used to interact with the API.
model (str): The model name to use for generating responses.
temperature (float): The temperature to use for generating responses.
max_tokens (int): The maximum number of tokens to generate in a response.
Methods:
__init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
Initializes the OpenAIClient with the provided configuration, cache setting, and client.
get_embedder() -> typing.Any:
Returns the embedder from the OpenAI client.
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
Generates a response from the language model based on the provided messages.
"""
def __init__(
self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
):
"""
Initialize the OpenAIClient with the provided configuration, cache setting, and client.
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
cache (bool): Whether to use caching for responses. Defaults to False.
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
"""
if config is None:
config = LLMConfig()
super().__init__(config, cache)
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
if client is None:
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
else:
self.client = client
def get_embedder(self) -> typing.Any:
return self.client.embeddings