fix: OpenAIRerankerClient now properly supports AzureOpenAILLMClient
- Updated type detection to use BaseOpenAIClient instead of OpenAIClient - This allows both OpenAIClient and AzureOpenAILLMClient to be properly unwrapped - Added comprehensive test coverage for all client types - Fixes #1006
This commit is contained in:
parent
8b7ad6f84c
commit
6cb2f44849
2 changed files with 150 additions and 4 deletions
|
|
@ -22,7 +22,8 @@ import openai
|
||||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||||
|
|
||||||
from ..helpers import semaphore_gather
|
from ..helpers import semaphore_gather
|
||||||
from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
|
from ..llm_client import LLMConfig, RateLimitError
|
||||||
|
from ..llm_client.openai_base_client import BaseOpenAIClient
|
||||||
from ..prompts import Message
|
from ..prompts import Message
|
||||||
from .client import CrossEncoderClient
|
from .client import CrossEncoderClient
|
||||||
|
|
||||||
|
|
@ -35,7 +36,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: LLMConfig | None = None,
|
config: LLMConfig | None = None,
|
||||||
client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
|
client: AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the OpenAIRerankerClient with the provided configuration and client.
|
Initialize the OpenAIRerankerClient with the provided configuration and client.
|
||||||
|
|
@ -45,7 +46,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
||||||
client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
client (AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = LLMConfig()
|
config = LLMConfig()
|
||||||
|
|
@ -53,7 +54,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
||||||
self.config = config
|
self.config = config
|
||||||
if client is None:
|
if client is None:
|
||||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
elif isinstance(client, OpenAIClient):
|
elif isinstance(client, BaseOpenAIClient):
|
||||||
self.client = client.client
|
self.client = client.client
|
||||||
else:
|
else:
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
|
||||||
145
tests/cross_encoder/test_openai_reranker_client.py
Normal file
145
tests/cross_encoder/test_openai_reranker_client.py
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
"""
|
||||||
|
Test file for OpenAIRerankerClient, specifically testing compatibility with
|
||||||
|
both OpenAIClient and AzureOpenAILLMClient instances.
|
||||||
|
|
||||||
|
This test validates the fix for issue #1006 where OpenAIRerankerClient
|
||||||
|
failed to properly support AzureOpenAILLMClient.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||||
|
from graphiti_core.llm_client import LLMConfig
|
||||||
|
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
|
||||||
|
from graphiti_core.llm_client.openai_client import OpenAIClient
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncOpenAI:
|
||||||
|
"""Mock AsyncOpenAI client for testing"""
|
||||||
|
|
||||||
|
def __init__(self, api_key=None, base_url=None):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url
|
||||||
|
self.chat = MagicMock()
|
||||||
|
self.chat.completions = MagicMock()
|
||||||
|
self.chat.completions.create = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncAzureOpenAI:
|
||||||
|
"""Mock AsyncAzureOpenAI client for testing"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.chat = MagicMock()
|
||||||
|
self.chat.completions = MagicMock()
|
||||||
|
self.chat.completions.create = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_openai_client():
|
||||||
|
"""Fixture to create a mocked OpenAIClient"""
|
||||||
|
client = OpenAIClient(config=LLMConfig(api_key='test-key'))
|
||||||
|
# Replace the internal client with our mock
|
||||||
|
client.client = MockAsyncOpenAI()
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_azure_openai_client():
|
||||||
|
"""Fixture to create a mocked AzureOpenAILLMClient"""
|
||||||
|
mock_azure = MockAsyncAzureOpenAI()
|
||||||
|
client = AzureOpenAILLMClient(
|
||||||
|
azure_client=mock_azure,
|
||||||
|
config=LLMConfig(api_key='test-key')
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_reranker_accepts_openai_client(mock_openai_client):
|
||||||
|
"""Test that OpenAIRerankerClient properly unwraps OpenAIClient"""
|
||||||
|
# Create reranker with OpenAIClient
|
||||||
|
reranker = OpenAIRerankerClient(client=mock_openai_client)
|
||||||
|
|
||||||
|
# Verify the internal client is the unwrapped AsyncOpenAI instance
|
||||||
|
assert reranker.client == mock_openai_client.client
|
||||||
|
assert hasattr(reranker.client, 'chat')
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_reranker_accepts_azure_client(mock_azure_openai_client):
|
||||||
|
"""Test that OpenAIRerankerClient properly unwraps AzureOpenAILLMClient
|
||||||
|
|
||||||
|
This test validates the fix for issue #1006.
|
||||||
|
"""
|
||||||
|
# Create reranker with AzureOpenAILLMClient - this would fail before the fix
|
||||||
|
reranker = OpenAIRerankerClient(client=mock_azure_openai_client)
|
||||||
|
|
||||||
|
# Verify the internal client is the unwrapped AsyncAzureOpenAI instance
|
||||||
|
assert reranker.client == mock_azure_openai_client.client
|
||||||
|
assert hasattr(reranker.client, 'chat')
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_reranker_accepts_async_openai_directly():
|
||||||
|
"""Test that OpenAIRerankerClient accepts AsyncOpenAI directly"""
|
||||||
|
# Create a mock AsyncOpenAI
|
||||||
|
mock_async = MockAsyncOpenAI(api_key='test-key')
|
||||||
|
|
||||||
|
# Create reranker with AsyncOpenAI directly
|
||||||
|
reranker = OpenAIRerankerClient(client=mock_async)
|
||||||
|
|
||||||
|
# Verify the internal client is used as-is
|
||||||
|
assert reranker.client == mock_async
|
||||||
|
assert hasattr(reranker.client, 'chat')
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_reranker_creates_default_client():
|
||||||
|
"""Test that OpenAIRerankerClient creates a default client when none provided"""
|
||||||
|
config = LLMConfig(api_key='test-key')
|
||||||
|
|
||||||
|
# Create reranker without client
|
||||||
|
reranker = OpenAIRerankerClient(config=config)
|
||||||
|
|
||||||
|
# Verify a client was created
|
||||||
|
assert reranker.client is not None
|
||||||
|
# The default should be an AsyncOpenAI instance
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
assert isinstance(reranker.client, AsyncOpenAI)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_method_with_azure_client(mock_azure_openai_client):
|
||||||
|
"""Test that rank method works correctly with AzureOpenAILLMClient"""
|
||||||
|
# Setup mock response for the chat completions
|
||||||
|
mock_response = SimpleNamespace(
|
||||||
|
choices=[
|
||||||
|
SimpleNamespace(
|
||||||
|
logprobs=SimpleNamespace(
|
||||||
|
content=[
|
||||||
|
SimpleNamespace(
|
||||||
|
top_logprobs=[
|
||||||
|
SimpleNamespace(token='True', logprob=-0.5)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_azure_openai_client.client.chat.completions.create.return_value = mock_response
|
||||||
|
|
||||||
|
# Create reranker with AzureOpenAILLMClient
|
||||||
|
reranker = OpenAIRerankerClient(client=mock_azure_openai_client)
|
||||||
|
|
||||||
|
# Test ranking
|
||||||
|
query = "test query"
|
||||||
|
passages = ["passage 1"]
|
||||||
|
|
||||||
|
# This would previously fail with AttributeError before the fix
|
||||||
|
results = await reranker.rank(query, passages)
|
||||||
|
|
||||||
|
# Verify the method was called
|
||||||
|
assert mock_azure_openai_client.client.chat.completions.create.called
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0][0] == "passage 1"
|
||||||
Loading…
Add table
Reference in a new issue