From 6cb2f448494ddfc8af84eb6d3d3b3ec5c48adf35 Mon Sep 17 00:00:00 2001
From: supmo668
Date: Tue, 18 Nov 2025 01:21:54 -0800
Subject: [PATCH 1/2] fix: OpenAIRerankerClient now properly supports
 AzureOpenAILLMClient

- Updated type detection to use BaseOpenAIClient instead of OpenAIClient
- This allows both OpenAIClient and AzureOpenAILLMClient to be properly unwrapped
- Added comprehensive test coverage for all client types
- Fixes #1006
---
 .../cross_encoder/openai_reranker_client.py  |   9 +-
 .../test_openai_reranker_client.py           | 145 ++++++++++++++++++
 2 files changed, 150 insertions(+), 4 deletions(-)
 create mode 100644 tests/cross_encoder/test_openai_reranker_client.py

diff --git a/graphiti_core/cross_encoder/openai_reranker_client.py b/graphiti_core/cross_encoder/openai_reranker_client.py
index 2e6c5b2f..16f60065 100644
--- a/graphiti_core/cross_encoder/openai_reranker_client.py
+++ b/graphiti_core/cross_encoder/openai_reranker_client.py
@@ -22,7 +22,8 @@ import openai
 from openai import AsyncAzureOpenAI, AsyncOpenAI
 
 from ..helpers import semaphore_gather
-from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
+from ..llm_client import LLMConfig, RateLimitError
+from ..llm_client.openai_base_client import BaseOpenAIClient
 from ..prompts import Message
 from .client import CrossEncoderClient
 
@@ -35,7 +36,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
     def __init__(
         self,
         config: LLMConfig | None = None,
-        client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
+        client: AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None = None,
     ):
         """
         Initialize the OpenAIRerankerClient with the provided configuration and client.
@@ -45,7 +46,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
 
         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
-            client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+            client (AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
         """
         if config is None:
             config = LLMConfig()
@@ -53,7 +54,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
         self.config = config
         if client is None:
             self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
-        elif isinstance(client, OpenAIClient):
+        elif isinstance(client, BaseOpenAIClient):
            self.client = client.client
         else:
             self.client = client
diff --git a/tests/cross_encoder/test_openai_reranker_client.py b/tests/cross_encoder/test_openai_reranker_client.py
new file mode 100644
index 00000000..8dc3ae54
--- /dev/null
+++ b/tests/cross_encoder/test_openai_reranker_client.py
@@ -0,0 +1,145 @@
+"""
+Test file for OpenAIRerankerClient, specifically testing compatibility with
+both OpenAIClient and AzureOpenAILLMClient instances.
+
+This test validates the fix for issue #1006 where OpenAIRerankerClient
+failed to properly support AzureOpenAILLMClient.
+""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient +from graphiti_core.llm_client import LLMConfig +from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient +from graphiti_core.llm_client.openai_client import OpenAIClient + + +class MockAsyncOpenAI: + """Mock AsyncOpenAI client for testing""" + + def __init__(self, api_key=None, base_url=None): + self.api_key = api_key + self.base_url = base_url + self.chat = MagicMock() + self.chat.completions = MagicMock() + self.chat.completions.create = AsyncMock() + + +class MockAsyncAzureOpenAI: + """Mock AsyncAzureOpenAI client for testing""" + + def __init__(self): + self.chat = MagicMock() + self.chat.completions = MagicMock() + self.chat.completions.create = AsyncMock() + + +@pytest.fixture +def mock_openai_client(): + """Fixture to create a mocked OpenAIClient""" + client = OpenAIClient(config=LLMConfig(api_key='test-key')) + # Replace the internal client with our mock + client.client = MockAsyncOpenAI() + return client + + +@pytest.fixture +def mock_azure_openai_client(): + """Fixture to create a mocked AzureOpenAILLMClient""" + mock_azure = MockAsyncAzureOpenAI() + client = AzureOpenAILLMClient( + azure_client=mock_azure, + config=LLMConfig(api_key='test-key') + ) + return client + + +def test_openai_reranker_accepts_openai_client(mock_openai_client): + """Test that OpenAIRerankerClient properly unwraps OpenAIClient""" + # Create reranker with OpenAIClient + reranker = OpenAIRerankerClient(client=mock_openai_client) + + # Verify the internal client is the unwrapped AsyncOpenAI instance + assert reranker.client == mock_openai_client.client + assert hasattr(reranker.client, 'chat') + + +def test_openai_reranker_accepts_azure_client(mock_azure_openai_client): + """Test that OpenAIRerankerClient properly unwraps AzureOpenAILLMClient + + This test validates the fix for issue #1006. 
+ """ + # Create reranker with AzureOpenAILLMClient - this would fail before the fix + reranker = OpenAIRerankerClient(client=mock_azure_openai_client) + + # Verify the internal client is the unwrapped AsyncAzureOpenAI instance + assert reranker.client == mock_azure_openai_client.client + assert hasattr(reranker.client, 'chat') + + +def test_openai_reranker_accepts_async_openai_directly(): + """Test that OpenAIRerankerClient accepts AsyncOpenAI directly""" + # Create a mock AsyncOpenAI + mock_async = MockAsyncOpenAI(api_key='test-key') + + # Create reranker with AsyncOpenAI directly + reranker = OpenAIRerankerClient(client=mock_async) + + # Verify the internal client is used as-is + assert reranker.client == mock_async + assert hasattr(reranker.client, 'chat') + + +def test_openai_reranker_creates_default_client(): + """Test that OpenAIRerankerClient creates a default client when none provided""" + config = LLMConfig(api_key='test-key') + + # Create reranker without client + reranker = OpenAIRerankerClient(config=config) + + # Verify a client was created + assert reranker.client is not None + # The default should be an AsyncOpenAI instance + from openai import AsyncOpenAI + assert isinstance(reranker.client, AsyncOpenAI) + + +@pytest.mark.asyncio +async def test_rank_method_with_azure_client(mock_azure_openai_client): + """Test that rank method works correctly with AzureOpenAILLMClient""" + # Setup mock response for the chat completions + mock_response = SimpleNamespace( + choices=[ + SimpleNamespace( + logprobs=SimpleNamespace( + content=[ + SimpleNamespace( + top_logprobs=[ + SimpleNamespace(token='True', logprob=-0.5) + ] + ) + ] + ) + ) + ] + ) + + mock_azure_openai_client.client.chat.completions.create.return_value = mock_response + + # Create reranker with AzureOpenAILLMClient + reranker = OpenAIRerankerClient(client=mock_azure_openai_client) + + # Test ranking + query = "test query" + passages = ["passage 1"] + + # This would previously fail with AttributeError before the fix + results = await reranker.rank(query, passages) + + # Verify the method was called + assert mock_azure_openai_client.client.chat.completions.create.called + assert len(results) == 1 + assert results[0][0] == "passage 1" From 9d9ae1fa7646675b9e0b3bd8b3f681aa75727f77 Mon Sep 17 00:00:00 2001 From: supmo668 Date: Sun, 23 Nov 2025 17:13:33 -0800 Subject: [PATCH 2/2] fix: Add client attribute type annotation to BaseOpenAIClient MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes pyright type checking error by declaring the client attribute in the base class. All concrete implementations (OpenAIClient, AzureOpenAILLMClient, OpenAIGenericClient) initialize this attribute, and it needs to be accessible for type checking in OpenAIRerankerClient. 

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude
---
 graphiti_core/llm_client/openai_base_client.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/graphiti_core/llm_client/openai_base_client.py b/graphiti_core/llm_client/openai_base_client.py
index 93e9c598..25927270 100644
--- a/graphiti_core/llm_client/openai_base_client.py
+++ b/graphiti_core/llm_client/openai_base_client.py
@@ -21,6 +21,7 @@ from abc import abstractmethod
 from typing import Any, ClassVar
 
 import openai
+from openai import AsyncAzureOpenAI, AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
@@ -48,6 +49,9 @@ class BaseOpenAIClient(LLMClient):
     # Class-level constants
     MAX_RETRIES: ClassVar[int] = 2
 
+    # Instance attribute (initialized in subclasses)
+    client: AsyncOpenAI | AsyncAzureOpenAI
+
     def __init__(
         self,
         config: LLMConfig | None = None,
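
Taken together, the two patches make the following construction path work. The sketch below is illustrative only; the Azure endpoint, API version, and key are placeholder values, not anything taken from the patches.

```python
# Sketch only: the endpoint, API version, and key below are placeholders.
from openai import AsyncAzureOpenAI

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient

azure_sdk_client = AsyncAzureOpenAI(
    api_key='<azure-api-key>',                          # placeholder
    api_version='2024-10-21',                           # placeholder
    azure_endpoint='https://example.openai.azure.com',  # placeholder
)
llm_client = AzureOpenAILLMClient(azure_client=azure_sdk_client, config=LLMConfig())

# Before PATCH 1/2 the isinstance check matched only OpenAIClient, so the wrapper
# was stored as-is and a later access to `.chat.completions` raised AttributeError.
# With the BaseOpenAIClient check the wrapper is unwrapped, so the reranker holds
# the underlying AsyncAzureOpenAI instance; PATCH 2/2 annotates that `client`
# attribute on the base class so pyright accepts the unwrapping.
reranker = OpenAIRerankerClient(client=llm_client)
assert reranker.client is llm_client.client
```

The unwrapping mirrors what the new tests assert: `reranker.client` ends up being the SDK client held by the wrapper, which is why the `client` annotation added in PATCH 2/2 is sufficient for the type checker.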