From d1e150f7d78c9fc3566d42d947d3bebac46c2ca1 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Fri, 4 Jul 2025 17:37:41 -0700 Subject: [PATCH] feat: support OpenAIClient in OpenAIRerankerClient (#676) fix typing --- graphiti_core/cross_encoder/openai_reranker_client.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/graphiti_core/cross_encoder/openai_reranker_client.py b/graphiti_core/cross_encoder/openai_reranker_client.py index 3901be9e..55cfb40f 100644 --- a/graphiti_core/cross_encoder/openai_reranker_client.py +++ b/graphiti_core/cross_encoder/openai_reranker_client.py @@ -22,7 +22,7 @@ import openai from openai import AsyncAzureOpenAI, AsyncOpenAI from ..helpers import semaphore_gather -from ..llm_client import LLMConfig, RateLimitError +from ..llm_client import LLMConfig, OpenAIClient, RateLimitError from ..prompts import Message from .client import CrossEncoderClient @@ -35,7 +35,7 @@ class OpenAIRerankerClient(CrossEncoderClient): def __init__( self, config: LLMConfig | None = None, - client: AsyncOpenAI | AsyncAzureOpenAI | None = None, + client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None, ): """ Initialize the OpenAIRerankerClient with the provided configuration and client. @@ -45,7 +45,7 @@ class OpenAIRerankerClient(CrossEncoderClient): Args: config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens. - client (AsyncOpenAI | AsyncAzureOpenAI | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created. + client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created. """ if config is None: config = LLMConfig() @@ -53,6 +53,8 @@ class OpenAIRerankerClient(CrossEncoderClient): self.config = config if client is None: self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) + elif isinstance(client, OpenAIClient): + self.client = client.client else: self.client = client