feat: support OpenAIClient in OpenAIRerankerClient (#676)

fix typing
This commit is contained in:
Daniel Chalef 2025-07-04 17:37:41 -07:00 committed by GitHub
parent 8977138a43
commit d1e150f7d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -22,7 +22,7 @@ import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI
from ..helpers import semaphore_gather
from ..llm_client import LLMConfig, RateLimitError
from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
from ..prompts import Message
from .client import CrossEncoderClient
@ -35,7 +35,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
def __init__(
self,
config: LLMConfig | None = None,
client: AsyncOpenAI | AsyncAzureOpenAI | None = None,
client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
):
"""
Initialize the OpenAIRerankerClient with the provided configuration and client.
@ -45,7 +45,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
client (AsyncOpenAI | AsyncAzureOpenAI | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
"""
if config is None:
config = LLMConfig()
@ -53,6 +53,8 @@ class OpenAIRerankerClient(CrossEncoderClient):
self.config = config
if client is None:
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
elif isinstance(client, OpenAIClient):
self.client = client.client
else:
self.client = client