fix: OpenAIRerankerClient now properly supports AzureOpenAILLMClient
- Updated type detection to use BaseOpenAIClient instead of OpenAIClient - This allows both OpenAIClient and AzureOpenAILLMClient to be properly unwrapped - Added comprehensive test coverage for all client types - Fixes #1006
This commit is contained in:
parent
8b7ad6f84c
commit
6cb2f44849
2 changed files with 150 additions and 4 deletions
|
|
@ -22,7 +22,8 @@ import openai
|
|||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
|
||||
from ..helpers import semaphore_gather
|
||||
from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
|
||||
from ..llm_client import LLMConfig, RateLimitError
|
||||
from ..llm_client.openai_base_client import BaseOpenAIClient
|
||||
from ..prompts import Message
|
||||
from .client import CrossEncoderClient
|
||||
|
||||
|
|
@ -35,7 +36,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|||
def __init__(
|
||||
self,
|
||||
config: LLMConfig | None = None,
|
||||
client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
|
||||
client: AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the OpenAIRerankerClient with the provided configuration and client.
|
||||
|
|
@ -45,7 +46,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|||
|
||||
Args:
|
||||
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
||||
client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
||||
client (AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
||||
"""
|
||||
if config is None:
|
||||
config = LLMConfig()
|
||||
|
|
@ -53,7 +54,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|||
self.config = config
|
||||
if client is None:
|
||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
elif isinstance(client, OpenAIClient):
|
||||
elif isinstance(client, BaseOpenAIClient):
|
||||
self.client = client.client
|
||||
else:
|
||||
self.client = client
|
||||
|
|
|
|||
145
tests/cross_encoder/test_openai_reranker_client.py
Normal file
145
tests/cross_encoder/test_openai_reranker_client.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""
|
||||
Test file for OpenAIRerankerClient, specifically testing compatibility with
|
||||
both OpenAIClient and AzureOpenAILLMClient instances.
|
||||
|
||||
This test validates the fix for issue #1006 where OpenAIRerankerClient
|
||||
failed to properly support AzureOpenAILLMClient.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||
from graphiti_core.llm_client import LLMConfig
|
||||
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
|
||||
from graphiti_core.llm_client.openai_client import OpenAIClient
|
||||
|
||||
|
||||
class MockAsyncOpenAI:
|
||||
"""Mock AsyncOpenAI client for testing"""
|
||||
|
||||
def __init__(self, api_key=None, base_url=None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.chat = MagicMock()
|
||||
self.chat.completions = MagicMock()
|
||||
self.chat.completions.create = AsyncMock()
|
||||
|
||||
|
||||
class MockAsyncAzureOpenAI:
|
||||
"""Mock AsyncAzureOpenAI client for testing"""
|
||||
|
||||
def __init__(self):
|
||||
self.chat = MagicMock()
|
||||
self.chat.completions = MagicMock()
|
||||
self.chat.completions.create = AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client():
|
||||
"""Fixture to create a mocked OpenAIClient"""
|
||||
client = OpenAIClient(config=LLMConfig(api_key='test-key'))
|
||||
# Replace the internal client with our mock
|
||||
client.client = MockAsyncOpenAI()
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_openai_client():
|
||||
"""Fixture to create a mocked AzureOpenAILLMClient"""
|
||||
mock_azure = MockAsyncAzureOpenAI()
|
||||
client = AzureOpenAILLMClient(
|
||||
azure_client=mock_azure,
|
||||
config=LLMConfig(api_key='test-key')
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def test_openai_reranker_accepts_openai_client(mock_openai_client):
|
||||
"""Test that OpenAIRerankerClient properly unwraps OpenAIClient"""
|
||||
# Create reranker with OpenAIClient
|
||||
reranker = OpenAIRerankerClient(client=mock_openai_client)
|
||||
|
||||
# Verify the internal client is the unwrapped AsyncOpenAI instance
|
||||
assert reranker.client == mock_openai_client.client
|
||||
assert hasattr(reranker.client, 'chat')
|
||||
|
||||
|
||||
def test_openai_reranker_accepts_azure_client(mock_azure_openai_client):
|
||||
"""Test that OpenAIRerankerClient properly unwraps AzureOpenAILLMClient
|
||||
|
||||
This test validates the fix for issue #1006.
|
||||
"""
|
||||
# Create reranker with AzureOpenAILLMClient - this would fail before the fix
|
||||
reranker = OpenAIRerankerClient(client=mock_azure_openai_client)
|
||||
|
||||
# Verify the internal client is the unwrapped AsyncAzureOpenAI instance
|
||||
assert reranker.client == mock_azure_openai_client.client
|
||||
assert hasattr(reranker.client, 'chat')
|
||||
|
||||
|
||||
def test_openai_reranker_accepts_async_openai_directly():
|
||||
"""Test that OpenAIRerankerClient accepts AsyncOpenAI directly"""
|
||||
# Create a mock AsyncOpenAI
|
||||
mock_async = MockAsyncOpenAI(api_key='test-key')
|
||||
|
||||
# Create reranker with AsyncOpenAI directly
|
||||
reranker = OpenAIRerankerClient(client=mock_async)
|
||||
|
||||
# Verify the internal client is used as-is
|
||||
assert reranker.client == mock_async
|
||||
assert hasattr(reranker.client, 'chat')
|
||||
|
||||
|
||||
def test_openai_reranker_creates_default_client():
|
||||
"""Test that OpenAIRerankerClient creates a default client when none provided"""
|
||||
config = LLMConfig(api_key='test-key')
|
||||
|
||||
# Create reranker without client
|
||||
reranker = OpenAIRerankerClient(config=config)
|
||||
|
||||
# Verify a client was created
|
||||
assert reranker.client is not None
|
||||
# The default should be an AsyncOpenAI instance
|
||||
from openai import AsyncOpenAI
|
||||
assert isinstance(reranker.client, AsyncOpenAI)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_method_with_azure_client(mock_azure_openai_client):
|
||||
"""Test that rank method works correctly with AzureOpenAILLMClient"""
|
||||
# Setup mock response for the chat completions
|
||||
mock_response = SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
logprobs=SimpleNamespace(
|
||||
content=[
|
||||
SimpleNamespace(
|
||||
top_logprobs=[
|
||||
SimpleNamespace(token='True', logprob=-0.5)
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
mock_azure_openai_client.client.chat.completions.create.return_value = mock_response
|
||||
|
||||
# Create reranker with AzureOpenAILLMClient
|
||||
reranker = OpenAIRerankerClient(client=mock_azure_openai_client)
|
||||
|
||||
# Test ranking
|
||||
query = "test query"
|
||||
passages = ["passage 1"]
|
||||
|
||||
# This would previously fail with AttributeError before the fix
|
||||
results = await reranker.rank(query, passages)
|
||||
|
||||
# Verify the method was called
|
||||
assert mock_azure_openai_client.client.chat.completions.create.called
|
||||
assert len(results) == 1
|
||||
assert results[0][0] == "passage 1"
|
||||
Loading…
Add table
Reference in a new issue