diff --git a/graphiti_core/llm_client/openai_client.py b/graphiti_core/llm_client/openai_client.py
index 9ac71a56..a1d9010e 100644
--- a/graphiti_core/llm_client/openai_client.py
+++ b/graphiti_core/llm_client/openai_client.py
@@ -33,13 +33,50 @@ DEFAULT_MODEL = 'gpt-4o-2024-08-06'
 
 
 class OpenAIClient(LLMClient):
-    def __init__(self, config: LLMConfig | None = None, cache: bool = False):
+    """
+    OpenAIClient is a client class for interacting with OpenAI's language models.
+
+    This class extends the LLMClient and provides methods to initialize the client,
+    get an embedder, and generate responses from the language model.
+
+    Attributes:
+        client (AsyncOpenAI): The OpenAI client used to interact with the API.
+        model (str): The model name to use for generating responses.
+        temperature (float): The temperature to use for generating responses.
+        max_tokens (int): The maximum number of tokens to generate in a response.
+
+    Methods:
+        __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
+            Initializes the OpenAIClient with the provided configuration, cache setting, and client.
+
+        get_embedder() -> typing.Any:
+            Returns the embedder from the OpenAI client.
+
+        _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
+            Generates a response from the language model based on the provided messages.
+    """
+
+    def __init__(
+        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
+    ):
+        """
+        Initialize the OpenAIClient with the provided configuration, cache setting, and client.
+
+        Args:
+            config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
+            cache (bool): Whether to use caching for responses. Defaults to False.
+            client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created. 
+
+        """
         if config is None:
             config = LLMConfig()
 
         super().__init__(config, cache)
 
-        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+        if client is None:
+            self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+        else:
+            self.client = client
 
     def get_embedder(self) -> typing.Any:
         return self.client.embeddings