- Updated type detection to use BaseOpenAIClient instead of OpenAIClient
- This allows both OpenAIClient and AzureOpenAILLMClient to be properly unwrapped
- Added comprehensive test coverage for all client types
- Fixes #1006
145 lines
4.8 KiB
Python
145 lines
4.8 KiB
Python
"""
|
|
Test file for OpenAIRerankerClient, specifically testing compatibility with
|
|
both OpenAIClient and AzureOpenAILLMClient instances.
|
|
|
|
This test validates the fix for issue #1006 where OpenAIRerankerClient
|
|
failed to properly support AzureOpenAILLMClient.
|
|
"""
|
|
|
|
import math
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest
from openai import AsyncOpenAI

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.openai_client import OpenAIClient
|
|
|
|
|
|
class MockAsyncOpenAI:
    """Test double standing in for ``openai.AsyncOpenAI``.

    Records the ``api_key`` / ``base_url`` it was built with and exposes only
    the ``chat.completions.create`` path. ``create`` is an ``AsyncMock`` so
    tests can set its ``return_value`` and await it.
    """

    def __init__(self, api_key=None, base_url=None):
        self.api_key, self.base_url = api_key, base_url
        # Build the nested attribute chain in one expression; MagicMock applies
        # keyword arguments as attributes, exactly like assigning them one by one.
        self.chat = MagicMock(completions=MagicMock(create=AsyncMock()))
|
|
|
|
|
|
class MockAsyncAzureOpenAI:
    """Test double standing in for ``openai.AsyncAzureOpenAI``.

    Only the ``chat.completions.create`` path is modelled; ``create`` is an
    ``AsyncMock`` so tests can configure its result and await it.
    """

    def __init__(self):
        # Keyword arguments to MagicMock become attributes, giving the same
        # chat -> completions -> create chain as explicit assignment.
        self.chat = MagicMock(completions=MagicMock(create=AsyncMock()))
|
|
|
|
|
|
@pytest.fixture
def mock_openai_client():
    """Provide an ``OpenAIClient`` whose inner ``client`` is a ``MockAsyncOpenAI``.

    Swapping the wrapper's ``client`` attribute lets tests inspect and stub
    ``chat.completions.create`` without touching the real SDK object.
    """
    wrapper = OpenAIClient(config=LLMConfig(api_key='test-key'))
    wrapper.client = MockAsyncOpenAI()
    return wrapper
|
|
|
|
|
|
@pytest.fixture
def mock_azure_openai_client():
    """Provide an ``AzureOpenAILLMClient`` built around a ``MockAsyncAzureOpenAI``.

    The mock is passed as ``azure_client`` so tests can stub its
    ``chat.completions.create`` coroutine.
    """
    return AzureOpenAILLMClient(
        azure_client=MockAsyncAzureOpenAI(),
        config=LLMConfig(api_key='test-key'),
    )
|
|
|
|
|
|
def test_openai_reranker_accepts_openai_client(mock_openai_client):
    """OpenAIRerankerClient unwraps an OpenAIClient to its inner SDK client.

    The reranker must talk to the wrapped object that exposes ``chat``,
    not to the graphiti wrapper itself.
    """
    reranker = OpenAIRerankerClient(client=mock_openai_client)

    # Identity, not equality: the reranker must hold the very same inner object.
    assert reranker.client is mock_openai_client.client
    assert hasattr(reranker.client, 'chat')
|
|
|
|
|
|
def test_openai_reranker_accepts_azure_client(mock_azure_openai_client):
    """OpenAIRerankerClient unwraps an AzureOpenAILLMClient to its inner SDK client.

    This test validates the fix for issue #1006; constructing the reranker
    with an AzureOpenAILLMClient failed before that fix.
    """
    reranker = OpenAIRerankerClient(client=mock_azure_openai_client)

    # Identity, not equality: the reranker must hold the very same inner object.
    assert reranker.client is mock_azure_openai_client.client
    assert hasattr(reranker.client, 'chat')
|
|
|
|
|
|
def test_openai_reranker_accepts_async_openai_directly():
    """OpenAIRerankerClient keeps a raw SDK-style client unchanged.

    ``MockAsyncOpenAI`` stands in for ``AsyncOpenAI``; because it is not a
    graphiti LLM-client wrapper, the reranker should store it as given.
    """
    mock_async = MockAsyncOpenAI(api_key='test-key')

    reranker = OpenAIRerankerClient(client=mock_async)

    # Identity, not equality: the object passed in must be used as-is.
    assert reranker.client is mock_async
    assert hasattr(reranker.client, 'chat')
|
|
|
|
|
|
def test_openai_reranker_creates_default_client():
    """OpenAIRerankerClient builds an ``AsyncOpenAI`` client when none is given."""
    config = LLMConfig(api_key='test-key')

    reranker = OpenAIRerankerClient(config=config)

    # isinstance also rules out None, so no separate `is not None` check is needed.
    assert isinstance(reranker.client, AsyncOpenAI)
    # The key from the supplied config should be forwarded to the SDK client.
    assert reranker.client.api_key == 'test-key'
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_rank_method_with_azure_client(mock_azure_openai_client):
    """rank() works end to end when the reranker wraps an AzureOpenAILLMClient.

    The stubbed completion returns a single top logprob for the token 'True'.
    The resulting score is expected to be exp(logprob) for that token.
    """
    # Minimal stand-in for a chat completion exposing
    # choices[0].logprobs.content[0].top_logprobs.
    top_logprob = SimpleNamespace(token='True', logprob=-0.5)
    mock_response = SimpleNamespace(
        choices=[
            SimpleNamespace(
                logprobs=SimpleNamespace(
                    content=[SimpleNamespace(top_logprobs=[top_logprob])],
                ),
            ),
        ],
    )
    create = mock_azure_openai_client.client.chat.completions.create
    create.return_value = mock_response

    reranker = OpenAIRerankerClient(client=mock_azure_openai_client)

    query = 'test query'
    passages = ['passage 1']

    # This would previously fail with AttributeError before the #1006 fix.
    results = await reranker.rank(query, passages)

    # `.called` is true even for a never-awaited AsyncMock; require an actual await.
    create.assert_awaited_once()
    assert len(results) == 1
    passage, score = results[0]
    assert passage == 'passage 1'
    assert score == pytest.approx(math.exp(top_logprob.logprob))
|