parent
8977138a43
commit
d1e150f7d7
1 changed files with 5 additions and 3 deletions
|
|
@ -22,7 +22,7 @@ import openai
|
||||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||||
|
|
||||||
from ..helpers import semaphore_gather
|
from ..helpers import semaphore_gather
|
||||||
from ..llm_client import LLMConfig, RateLimitError
|
from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
|
||||||
from ..prompts import Message
|
from ..prompts import Message
|
||||||
from .client import CrossEncoderClient
|
from .client import CrossEncoderClient
|
||||||
|
|
||||||
|
|
@ -35,7 +35,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: LLMConfig | None = None,
|
config: LLMConfig | None = None,
|
||||||
client: AsyncOpenAI | AsyncAzureOpenAI | None = None,
|
client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the OpenAIRerankerClient with the provided configuration and client.
|
Initialize the OpenAIRerankerClient with the provided configuration and client.
|
||||||
|
|
@ -45,7 +45,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
||||||
client (AsyncOpenAI | AsyncAzureOpenAI | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = LLMConfig()
|
config = LLMConfig()
|
||||||
|
|
@ -53,6 +53,8 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
||||||
self.config = config
|
self.config = config
|
||||||
if client is None:
|
if client is None:
|
||||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
elif isinstance(client, OpenAIClient):
|
||||||
|
self.client = client.client
|
||||||
else:
|
else:
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue