graphiti/tests/cross_encoder/test_openai_reranker_client.py
supmo668 6cb2f44849 fix: OpenAIRerankerClient now properly supports AzureOpenAILLMClient
- Updated type detection to use BaseOpenAIClient instead of OpenAIClient
- This allows both OpenAIClient and AzureOpenAILLMClient to be properly unwrapped
- Added comprehensive test coverage for all client types
- Fixes #1006
2025-11-18 01:21:54 -08:00

145 lines
4.8 KiB
Python

"""
Test file for OpenAIRerankerClient, specifically testing compatibility with
both OpenAIClient and AzureOpenAILLMClient instances.
This test validates the fix for issue #1006 where OpenAIRerankerClient
failed to properly support AzureOpenAILLMClient.
"""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.openai_client import OpenAIClient
class MockAsyncOpenAI:
"""Mock AsyncOpenAI client for testing"""
def __init__(self, api_key=None, base_url=None):
self.api_key = api_key
self.base_url = base_url
self.chat = MagicMock()
self.chat.completions = MagicMock()
self.chat.completions.create = AsyncMock()
class MockAsyncAzureOpenAI:
"""Mock AsyncAzureOpenAI client for testing"""
def __init__(self):
self.chat = MagicMock()
self.chat.completions = MagicMock()
self.chat.completions.create = AsyncMock()
@pytest.fixture
def mock_openai_client():
"""Fixture to create a mocked OpenAIClient"""
client = OpenAIClient(config=LLMConfig(api_key='test-key'))
# Replace the internal client with our mock
client.client = MockAsyncOpenAI()
return client
@pytest.fixture
def mock_azure_openai_client():
"""Fixture to create a mocked AzureOpenAILLMClient"""
mock_azure = MockAsyncAzureOpenAI()
client = AzureOpenAILLMClient(
azure_client=mock_azure,
config=LLMConfig(api_key='test-key')
)
return client
def test_openai_reranker_accepts_openai_client(mock_openai_client):
"""Test that OpenAIRerankerClient properly unwraps OpenAIClient"""
# Create reranker with OpenAIClient
reranker = OpenAIRerankerClient(client=mock_openai_client)
# Verify the internal client is the unwrapped AsyncOpenAI instance
assert reranker.client == mock_openai_client.client
assert hasattr(reranker.client, 'chat')
def test_openai_reranker_accepts_azure_client(mock_azure_openai_client):
"""Test that OpenAIRerankerClient properly unwraps AzureOpenAILLMClient
This test validates the fix for issue #1006.
"""
# Create reranker with AzureOpenAILLMClient - this would fail before the fix
reranker = OpenAIRerankerClient(client=mock_azure_openai_client)
# Verify the internal client is the unwrapped AsyncAzureOpenAI instance
assert reranker.client == mock_azure_openai_client.client
assert hasattr(reranker.client, 'chat')
def test_openai_reranker_accepts_async_openai_directly():
"""Test that OpenAIRerankerClient accepts AsyncOpenAI directly"""
# Create a mock AsyncOpenAI
mock_async = MockAsyncOpenAI(api_key='test-key')
# Create reranker with AsyncOpenAI directly
reranker = OpenAIRerankerClient(client=mock_async)
# Verify the internal client is used as-is
assert reranker.client == mock_async
assert hasattr(reranker.client, 'chat')
def test_openai_reranker_creates_default_client():
"""Test that OpenAIRerankerClient creates a default client when none provided"""
config = LLMConfig(api_key='test-key')
# Create reranker without client
reranker = OpenAIRerankerClient(config=config)
# Verify a client was created
assert reranker.client is not None
# The default should be an AsyncOpenAI instance
from openai import AsyncOpenAI
assert isinstance(reranker.client, AsyncOpenAI)
@pytest.mark.asyncio
async def test_rank_method_with_azure_client(mock_azure_openai_client):
"""Test that rank method works correctly with AzureOpenAILLMClient"""
# Setup mock response for the chat completions
mock_response = SimpleNamespace(
choices=[
SimpleNamespace(
logprobs=SimpleNamespace(
content=[
SimpleNamespace(
top_logprobs=[
SimpleNamespace(token='True', logprob=-0.5)
]
)
]
)
)
]
)
mock_azure_openai_client.client.chat.completions.create.return_value = mock_response
# Create reranker with AzureOpenAILLMClient
reranker = OpenAIRerankerClient(client=mock_azure_openai_client)
# Test ranking
query = "test query"
passages = ["passage 1"]
# This would previously fail with AttributeError before the fix
results = await reranker.rank(query, passages)
# Verify the method was called
assert mock_azure_openai_client.client.chat.completions.create.called
assert len(results) == 1
assert results[0][0] == "passage 1"