feat: Refactor OpenAIClient initialization and add client parameter (#140)
The code changes refactor the `OpenAIClient` initialization to accept an optional `client` parameter. This allows the client to be passed in from outside, providing more flexibility and enabling easier testing.
This commit is contained in:
parent
32b51530ec
commit
9b71b46c0f
1 changed files with 39 additions and 2 deletions
|
|
@ -33,13 +33,50 @@ DEFAULT_MODEL = 'gpt-4o-2024-08-06'
|
||||||
|
|
||||||
|
|
||||||
class OpenAIClient(LLMClient):
|
class OpenAIClient(LLMClient):
|
||||||
def __init__(self, config: LLMConfig | None = None, cache: bool = False):
|
"""
|
||||||
|
OpenAIClient is a client class for interacting with OpenAI's language models.
|
||||||
|
|
||||||
|
This class extends the LLMClient and provides methods to initialize the client,
|
||||||
|
get an embedder, and generate responses from the language model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
client (AsyncOpenAI): The OpenAI client used to interact with the API.
|
||||||
|
model (str): The model name to use for generating responses.
|
||||||
|
temperature (float): The temperature to use for generating responses.
|
||||||
|
max_tokens (int): The maximum number of tokens to generate in a response.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
__init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
|
||||||
|
Initializes the OpenAIClient with the provided configuration, cache setting, and client.
|
||||||
|
|
||||||
|
get_embedder() -> typing.Any:
|
||||||
|
Returns the embedder from the OpenAI client.
|
||||||
|
|
||||||
|
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
|
||||||
|
Generates a response from the language model based on the provided messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the OpenAIClient with the provided configuration, cache setting, and client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
||||||
|
cache (bool): Whether to use caching for responses. Defaults to False.
|
||||||
|
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
||||||
|
|
||||||
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = LLMConfig()
|
config = LLMConfig()
|
||||||
|
|
||||||
super().__init__(config, cache)
|
super().__init__(config, cache)
|
||||||
|
|
||||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
if client is None:
|
||||||
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
else:
|
||||||
|
self.client = client
|
||||||
|
|
||||||
def get_embedder(self) -> typing.Any:
|
def get_embedder(self) -> typing.Any:
|
||||||
return self.client.embeddings
|
return self.client.embeddings
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue