feat: support OpenAIClient in OpenAIRerankerClient (#676)

fix typing
This commit is contained in:
Daniel Chalef 2025-07-04 17:37:41 -07:00 committed by GitHub
parent 8977138a43
commit d1e150f7d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -22,7 +22,7 @@ import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI from openai import AsyncAzureOpenAI, AsyncOpenAI
from ..helpers import semaphore_gather from ..helpers import semaphore_gather
from ..llm_client import LLMConfig, RateLimitError from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
from ..prompts import Message from ..prompts import Message
from .client import CrossEncoderClient from .client import CrossEncoderClient
@ -35,7 +35,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
def __init__( def __init__(
self, self,
config: LLMConfig | None = None, config: LLMConfig | None = None,
client: AsyncOpenAI | AsyncAzureOpenAI | None = None, client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
): ):
""" """
Initialize the OpenAIRerankerClient with the provided configuration and client. Initialize the OpenAIRerankerClient with the provided configuration and client.
@ -45,7 +45,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
Args: Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens. config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
client (AsyncOpenAI | AsyncAzureOpenAI | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created. client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
""" """
if config is None: if config is None:
config = LLMConfig() config = LLMConfig()
@ -53,6 +53,8 @@ class OpenAIRerankerClient(CrossEncoderClient):
self.config = config self.config = config
if client is None: if client is None:
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
elif isinstance(client, OpenAIClient):
self.client = client.client
else: else:
self.client = client self.client = client