This commit is contained in:
Matthew Mo 2025-12-09 20:19:54 -08:00 committed by GitHub
commit c7cabf3ab1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 154 additions and 4 deletions

View file

@ -22,7 +22,8 @@ import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI from openai import AsyncAzureOpenAI, AsyncOpenAI
from ..helpers import semaphore_gather from ..helpers import semaphore_gather
from ..llm_client import LLMConfig, OpenAIClient, RateLimitError from ..llm_client import LLMConfig, RateLimitError
from ..llm_client.openai_base_client import BaseOpenAIClient
from ..prompts import Message from ..prompts import Message
from .client import CrossEncoderClient from .client import CrossEncoderClient
@ -35,7 +36,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
def __init__( def __init__(
self, self,
config: LLMConfig | None = None, config: LLMConfig | None = None,
client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None, client: AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None = None,
): ):
""" """
Initialize the OpenAIRerankerClient with the provided configuration and client. Initialize the OpenAIRerankerClient with the provided configuration and client.
@ -45,7 +46,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
Args: Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens. config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created. client (AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
""" """
if config is None: if config is None:
config = LLMConfig() config = LLMConfig()
@ -53,7 +54,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
self.config = config self.config = config
if client is None: if client is None:
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
elif isinstance(client, OpenAIClient): elif isinstance(client, BaseOpenAIClient):
self.client = client.client self.client = client.client
else: else:
self.client = client self.client = client

View file

@ -21,6 +21,7 @@ from abc import abstractmethod
from typing import Any, ClassVar from typing import Any, ClassVar
import openai import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel from pydantic import BaseModel
@ -48,6 +49,9 @@ class BaseOpenAIClient(LLMClient):
# Class-level constants # Class-level constants
MAX_RETRIES: ClassVar[int] = 2 MAX_RETRIES: ClassVar[int] = 2
# Instance attribute (initialized in subclasses)
client: AsyncOpenAI | AsyncAzureOpenAI
def __init__( def __init__(
self, self,
config: LLMConfig | None = None, config: LLMConfig | None = None,

View file

@ -0,0 +1,145 @@
"""
Test file for OpenAIRerankerClient, specifically testing compatibility with
both OpenAIClient and AzureOpenAILLMClient instances.
This test validates the fix for issue #1006 where OpenAIRerankerClient
failed to properly support AzureOpenAILLMClient.
"""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.openai_client import OpenAIClient
class MockAsyncOpenAI:
"""Mock AsyncOpenAI client for testing"""
def __init__(self, api_key=None, base_url=None):
self.api_key = api_key
self.base_url = base_url
self.chat = MagicMock()
self.chat.completions = MagicMock()
self.chat.completions.create = AsyncMock()
class MockAsyncAzureOpenAI:
"""Mock AsyncAzureOpenAI client for testing"""
def __init__(self):
self.chat = MagicMock()
self.chat.completions = MagicMock()
self.chat.completions.create = AsyncMock()
@pytest.fixture
def mock_openai_client():
"""Fixture to create a mocked OpenAIClient"""
client = OpenAIClient(config=LLMConfig(api_key='test-key'))
# Replace the internal client with our mock
client.client = MockAsyncOpenAI()
return client
@pytest.fixture
def mock_azure_openai_client():
"""Fixture to create a mocked AzureOpenAILLMClient"""
mock_azure = MockAsyncAzureOpenAI()
client = AzureOpenAILLMClient(
azure_client=mock_azure,
config=LLMConfig(api_key='test-key')
)
return client
def test_openai_reranker_accepts_openai_client(mock_openai_client):
"""Test that OpenAIRerankerClient properly unwraps OpenAIClient"""
# Create reranker with OpenAIClient
reranker = OpenAIRerankerClient(client=mock_openai_client)
# Verify the internal client is the unwrapped AsyncOpenAI instance
assert reranker.client == mock_openai_client.client
assert hasattr(reranker.client, 'chat')
def test_openai_reranker_accepts_azure_client(mock_azure_openai_client):
"""Test that OpenAIRerankerClient properly unwraps AzureOpenAILLMClient
This test validates the fix for issue #1006.
"""
# Create reranker with AzureOpenAILLMClient - this would fail before the fix
reranker = OpenAIRerankerClient(client=mock_azure_openai_client)
# Verify the internal client is the unwrapped AsyncAzureOpenAI instance
assert reranker.client == mock_azure_openai_client.client
assert hasattr(reranker.client, 'chat')
def test_openai_reranker_accepts_async_openai_directly():
"""Test that OpenAIRerankerClient accepts AsyncOpenAI directly"""
# Create a mock AsyncOpenAI
mock_async = MockAsyncOpenAI(api_key='test-key')
# Create reranker with AsyncOpenAI directly
reranker = OpenAIRerankerClient(client=mock_async)
# Verify the internal client is used as-is
assert reranker.client == mock_async
assert hasattr(reranker.client, 'chat')
def test_openai_reranker_creates_default_client():
"""Test that OpenAIRerankerClient creates a default client when none provided"""
config = LLMConfig(api_key='test-key')
# Create reranker without client
reranker = OpenAIRerankerClient(config=config)
# Verify a client was created
assert reranker.client is not None
# The default should be an AsyncOpenAI instance
from openai import AsyncOpenAI
assert isinstance(reranker.client, AsyncOpenAI)
@pytest.mark.asyncio
async def test_rank_method_with_azure_client(mock_azure_openai_client):
"""Test that rank method works correctly with AzureOpenAILLMClient"""
# Setup mock response for the chat completions
mock_response = SimpleNamespace(
choices=[
SimpleNamespace(
logprobs=SimpleNamespace(
content=[
SimpleNamespace(
top_logprobs=[
SimpleNamespace(token='True', logprob=-0.5)
]
)
]
)
)
]
)
mock_azure_openai_client.client.chat.completions.create.return_value = mock_response
# Create reranker with AzureOpenAILLMClient
reranker = OpenAIRerankerClient(client=mock_azure_openai_client)
# Test ranking
query = "test query"
passages = ["passage 1"]
# This would previously fail with AttributeError before the fix
results = await reranker.rank(query, passages)
# Verify the method was called
assert mock_azure_openai_client.client.chat.completions.create.called
assert len(results) == 1
assert results[0][0] == "passage 1"