diff --git a/graphiti_core/cross_encoder/openai_reranker_client.py b/graphiti_core/cross_encoder/openai_reranker_client.py
index 3901be9e..55cfb40f 100644
--- a/graphiti_core/cross_encoder/openai_reranker_client.py
+++ b/graphiti_core/cross_encoder/openai_reranker_client.py
@@ -22,7 +22,7 @@ import openai
 from openai import AsyncAzureOpenAI, AsyncOpenAI
 
 from ..helpers import semaphore_gather
-from ..llm_client import LLMConfig, RateLimitError
+from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
 from ..prompts import Message
 from .client import CrossEncoderClient
 
@@ -35,7 +35,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
     def __init__(
         self,
         config: LLMConfig | None = None,
-        client: AsyncOpenAI | AsyncAzureOpenAI | None = None,
+        client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
     ):
         """
         Initialize the OpenAIRerankerClient with the provided configuration and client.
@@ -45,7 +45,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
 
         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
-            client (AsyncOpenAI | AsyncAzureOpenAI | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+            client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
         """
         if config is None:
             config = LLMConfig()
@@ -53,6 +53,8 @@ class OpenAIRerankerClient(CrossEncoderClient):
         self.config = config
         if client is None:
             self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+        elif isinstance(client, OpenAIClient):
+            self.client = client.client
         else:
             self.client = client
 