From 6cb2f448494ddfc8af84eb6d3d3b3ec5c48adf35 Mon Sep 17 00:00:00 2001 From: supmo668 Date: Tue, 18 Nov 2025 01:21:54 -0800 Subject: [PATCH] fix: OpenAIRerankerClient now properly supports AzureOpenAILLMClient - Updated type detection to use BaseOpenAIClient instead of OpenAIClient - This allows both OpenAIClient and AzureOpenAILLMClient to be properly unwrapped - Added comprehensive test coverage for all client types - Fixes #1006 --- .../cross_encoder/openai_reranker_client.py | 9 +- .../test_openai_reranker_client.py | 145 ++++++++++++++++++ 2 files changed, 150 insertions(+), 4 deletions(-) create mode 100644 tests/cross_encoder/test_openai_reranker_client.py diff --git a/graphiti_core/cross_encoder/openai_reranker_client.py b/graphiti_core/cross_encoder/openai_reranker_client.py index 2e6c5b2f..16f60065 100644 --- a/graphiti_core/cross_encoder/openai_reranker_client.py +++ b/graphiti_core/cross_encoder/openai_reranker_client.py @@ -22,7 +22,8 @@ import openai from openai import AsyncAzureOpenAI, AsyncOpenAI from ..helpers import semaphore_gather -from ..llm_client import LLMConfig, OpenAIClient, RateLimitError +from ..llm_client import LLMConfig, RateLimitError +from ..llm_client.openai_base_client import BaseOpenAIClient from ..prompts import Message from .client import CrossEncoderClient @@ -35,7 +36,7 @@ class OpenAIRerankerClient(CrossEncoderClient): def __init__( self, config: LLMConfig | None = None, - client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None, + client: AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None = None, ): """ Initialize the OpenAIRerankerClient with the provided configuration and client. @@ -45,7 +46,7 @@ class OpenAIRerankerClient(CrossEncoderClient): Args: config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens. - client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created. + client (AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created. """ if config is None: config = LLMConfig() @@ -53,7 +54,7 @@ class OpenAIRerankerClient(CrossEncoderClient): self.config = config if client is None: self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) - elif isinstance(client, OpenAIClient): + elif isinstance(client, BaseOpenAIClient): self.client = client.client else: self.client = client diff --git a/tests/cross_encoder/test_openai_reranker_client.py b/tests/cross_encoder/test_openai_reranker_client.py new file mode 100644 index 00000000..8dc3ae54 --- /dev/null +++ b/tests/cross_encoder/test_openai_reranker_client.py @@ -0,0 +1,145 @@ +""" +Test file for OpenAIRerankerClient, specifically testing compatibility with +both OpenAIClient and AzureOpenAILLMClient instances. + +This test validates the fix for issue #1006 where OpenAIRerankerClient +failed to properly support AzureOpenAILLMClient. +""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient +from graphiti_core.llm_client import LLMConfig +from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient +from graphiti_core.llm_client.openai_client import OpenAIClient + + +class MockAsyncOpenAI: + """Mock AsyncOpenAI client for testing""" + + def __init__(self, api_key=None, base_url=None): + self.api_key = api_key + self.base_url = base_url + self.chat = MagicMock() + self.chat.completions = MagicMock() + self.chat.completions.create = AsyncMock() + + +class MockAsyncAzureOpenAI: + """Mock AsyncAzureOpenAI client for testing""" + + def __init__(self): + self.chat = MagicMock() + self.chat.completions = MagicMock() + self.chat.completions.create = AsyncMock() + + +@pytest.fixture +def mock_openai_client(): + """Fixture to create a mocked OpenAIClient""" + client = OpenAIClient(config=LLMConfig(api_key='test-key')) + # Replace the internal client with our mock + client.client = MockAsyncOpenAI() + return client + + +@pytest.fixture +def mock_azure_openai_client(): + """Fixture to create a mocked AzureOpenAILLMClient""" + mock_azure = MockAsyncAzureOpenAI() + client = AzureOpenAILLMClient( + azure_client=mock_azure, + config=LLMConfig(api_key='test-key') + ) + return client + + +def test_openai_reranker_accepts_openai_client(mock_openai_client): + """Test that OpenAIRerankerClient properly unwraps OpenAIClient""" + # Create reranker with OpenAIClient + reranker = OpenAIRerankerClient(client=mock_openai_client) + + # Verify the internal client is the unwrapped AsyncOpenAI instance + assert reranker.client == mock_openai_client.client + assert hasattr(reranker.client, 'chat') + + +def test_openai_reranker_accepts_azure_client(mock_azure_openai_client): + """Test that OpenAIRerankerClient properly unwraps AzureOpenAILLMClient + + This test validates the fix for issue #1006. + """ + # Create reranker with AzureOpenAILLMClient - this would fail before the fix + reranker = OpenAIRerankerClient(client=mock_azure_openai_client) + + # Verify the internal client is the unwrapped AsyncAzureOpenAI instance + assert reranker.client == mock_azure_openai_client.client + assert hasattr(reranker.client, 'chat') + + +def test_openai_reranker_accepts_async_openai_directly(): + """Test that OpenAIRerankerClient accepts AsyncOpenAI directly""" + # Create a mock AsyncOpenAI + mock_async = MockAsyncOpenAI(api_key='test-key') + + # Create reranker with AsyncOpenAI directly + reranker = OpenAIRerankerClient(client=mock_async) + + # Verify the internal client is used as-is + assert reranker.client == mock_async + assert hasattr(reranker.client, 'chat') + + +def test_openai_reranker_creates_default_client(): + """Test that OpenAIRerankerClient creates a default client when none provided""" + config = LLMConfig(api_key='test-key') + + # Create reranker without client + reranker = OpenAIRerankerClient(config=config) + + # Verify a client was created + assert reranker.client is not None + # The default should be an AsyncOpenAI instance + from openai import AsyncOpenAI + assert isinstance(reranker.client, AsyncOpenAI) + + +@pytest.mark.asyncio +async def test_rank_method_with_azure_client(mock_azure_openai_client): + """Test that rank method works correctly with AzureOpenAILLMClient""" + # Setup mock response for the chat completions + mock_response = SimpleNamespace( + choices=[ + SimpleNamespace( + logprobs=SimpleNamespace( + content=[ + SimpleNamespace( + top_logprobs=[ + SimpleNamespace(token='True', logprob=-0.5) + ] + ) + ] + ) + ) + ] + ) + + mock_azure_openai_client.client.chat.completions.create.return_value = mock_response + + # Create reranker with AzureOpenAILLMClient + reranker = OpenAIRerankerClient(client=mock_azure_openai_client) + + # Test ranking + query = "test query" + passages = ["passage 1"] + + # This would previously fail with AttributeError before the fix + results = await reranker.rank(query, passages) + + # Verify the method was called + assert mock_azure_openai_client.client.chat.completions.create.called + assert len(results) == 1 + assert results[0][0] == "passage 1"