commit c7cabf3ab1
Matthew Mo 2025-12-09 20:19:54 -08:00, committed by GitHub
3 changed files with 154 additions and 4 deletions

graphiti_core/cross_encoder/openai_reranker_client.py

@@ -22,7 +22,8 @@ import openai
 from openai import AsyncAzureOpenAI, AsyncOpenAI
 from ..helpers import semaphore_gather
-from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
+from ..llm_client import LLMConfig, RateLimitError
+from ..llm_client.openai_base_client import BaseOpenAIClient
 from ..prompts import Message
 from .client import CrossEncoderClient
@@ -35,7 +36,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
     def __init__(
         self,
         config: LLMConfig | None = None,
-        client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
+        client: AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None = None,
     ):
         """
         Initialize the OpenAIRerankerClient with the provided configuration and client.
@@ -45,7 +46,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
-            client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+            client (AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
         """
         if config is None:
             config = LLMConfig()
@@ -53,7 +54,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
         self.config = config
         if client is None:
             self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
-        elif isinstance(client, OpenAIClient):
+        elif isinstance(client, BaseOpenAIClient):
             self.client = client.client
         else:
             self.client = client

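For context, a minimal usage sketch of what this change enables: the reranker can now be handed an AzureOpenAILLMClient (or any other BaseOpenAIClient subclass) and will reuse its wrapped SDK client. The endpoint and key values below are placeholders, and the assumption that AzureOpenAILLMClient stores its SDK client on .client is taken from the new test further down.

from openai import AsyncAzureOpenAI

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient

# Placeholder Azure credentials/endpoint for illustration only.
azure_client = AsyncAzureOpenAI(
    api_key='<azure-api-key>',
    api_version='2024-10-21',
    azure_endpoint='https://<resource>.openai.azure.com',
)
llm_client = AzureOpenAILLMClient(azure_client=azure_client, config=LLMConfig())

# With the isinstance(client, BaseOpenAIClient) check above, the reranker
# unwraps the LLM client and talks to the same underlying AsyncAzureOpenAI.
reranker = OpenAIRerankerClient(client=llm_client)
assert reranker.client is azure_client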
graphiti_core/llm_client/openai_base_client.py

@@ -21,6 +21,7 @@ from abc import abstractmethod
 from typing import Any, ClassVar

 import openai
+from openai import AsyncAzureOpenAI, AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
@@ -48,6 +49,9 @@ class BaseOpenAIClient(LLMClient):
     # Class-level constants
     MAX_RETRIES: ClassVar[int] = 2

+    # Instance attribute (initialized in subclasses)
+    client: AsyncOpenAI | AsyncAzureOpenAI
+
     def __init__(
         self,
         config: LLMConfig | None = None,

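The annotation matters for consumers like the reranker above: every BaseOpenAIClient subclass (OpenAIClient, AzureOpenAILLMClient, ...) is now declared, as far as type checkers are concerned, to expose its wrapped SDK client. A minimal, hypothetical consumer-side sketch:

from openai import AsyncAzureOpenAI, AsyncOpenAI

from graphiti_core.llm_client.openai_base_client import BaseOpenAIClient


def unwrap_sdk_client(llm_client: BaseOpenAIClient) -> AsyncOpenAI | AsyncAzureOpenAI:
    # Thanks to the class-level `client` annotation, this attribute access
    # type-checks for any BaseOpenAIClient subclass; it is what
    # OpenAIRerankerClient relies on when it is handed an LLM client.
    return llm_client.client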
New test file for OpenAIRerankerClient

@@ -0,0 +1,145 @@
"""
Test file for OpenAIRerankerClient, specifically testing compatibility with
both OpenAIClient and AzureOpenAILLMClient instances.
This test validates the fix for issue #1006 where OpenAIRerankerClient
failed to properly support AzureOpenAILLMClient.
"""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.openai_client import OpenAIClient
class MockAsyncOpenAI:
    """Mock AsyncOpenAI client for testing"""

    def __init__(self, api_key=None, base_url=None):
        self.api_key = api_key
        self.base_url = base_url
        self.chat = MagicMock()
        self.chat.completions = MagicMock()
        self.chat.completions.create = AsyncMock()


class MockAsyncAzureOpenAI:
    """Mock AsyncAzureOpenAI client for testing"""

    def __init__(self):
        self.chat = MagicMock()
        self.chat.completions = MagicMock()
        self.chat.completions.create = AsyncMock()
@pytest.fixture
def mock_openai_client():
    """Fixture to create a mocked OpenAIClient"""
    client = OpenAIClient(config=LLMConfig(api_key='test-key'))
    # Replace the internal client with our mock
    client.client = MockAsyncOpenAI()
    return client


@pytest.fixture
def mock_azure_openai_client():
    """Fixture to create a mocked AzureOpenAILLMClient"""
    mock_azure = MockAsyncAzureOpenAI()
    client = AzureOpenAILLMClient(
        azure_client=mock_azure,
        config=LLMConfig(api_key='test-key')
    )
    return client
def test_openai_reranker_accepts_openai_client(mock_openai_client):
    """Test that OpenAIRerankerClient properly unwraps OpenAIClient"""
    # Create reranker with OpenAIClient
    reranker = OpenAIRerankerClient(client=mock_openai_client)

    # Verify the internal client is the unwrapped AsyncOpenAI instance
    assert reranker.client == mock_openai_client.client
    assert hasattr(reranker.client, 'chat')


def test_openai_reranker_accepts_azure_client(mock_azure_openai_client):
    """Test that OpenAIRerankerClient properly unwraps AzureOpenAILLMClient

    This test validates the fix for issue #1006.
    """
    # Create reranker with AzureOpenAILLMClient - this would fail before the fix
    reranker = OpenAIRerankerClient(client=mock_azure_openai_client)

    # Verify the internal client is the unwrapped AsyncAzureOpenAI instance
    assert reranker.client == mock_azure_openai_client.client
    assert hasattr(reranker.client, 'chat')


def test_openai_reranker_accepts_async_openai_directly():
    """Test that OpenAIRerankerClient accepts AsyncOpenAI directly"""
    # Create a mock AsyncOpenAI
    mock_async = MockAsyncOpenAI(api_key='test-key')

    # Create reranker with AsyncOpenAI directly
    reranker = OpenAIRerankerClient(client=mock_async)

    # Verify the internal client is used as-is
    assert reranker.client == mock_async
    assert hasattr(reranker.client, 'chat')


def test_openai_reranker_creates_default_client():
    """Test that OpenAIRerankerClient creates a default client when none provided"""
    config = LLMConfig(api_key='test-key')

    # Create reranker without client
    reranker = OpenAIRerankerClient(config=config)

    # Verify a client was created
    assert reranker.client is not None

    # The default should be an AsyncOpenAI instance
    from openai import AsyncOpenAI

    assert isinstance(reranker.client, AsyncOpenAI)
@pytest.mark.asyncio
async def test_rank_method_with_azure_client(mock_azure_openai_client):
    """Test that rank method works correctly with AzureOpenAILLMClient"""
    # Setup mock response for the chat completions
    mock_response = SimpleNamespace(
        choices=[
            SimpleNamespace(
                logprobs=SimpleNamespace(
                    content=[
                        SimpleNamespace(
                            top_logprobs=[
                                SimpleNamespace(token='True', logprob=-0.5)
                            ]
                        )
                    ]
                )
            )
        ]
    )
    mock_azure_openai_client.client.chat.completions.create.return_value = mock_response

    # Create reranker with AzureOpenAILLMClient
    reranker = OpenAIRerankerClient(client=mock_azure_openai_client)

    # Test ranking
    query = 'test query'
    passages = ['passage 1']

    # This would fail with AttributeError before the fix
    results = await reranker.rank(query, passages)

    # Verify the method was called
    assert mock_azure_openai_client.client.chat.completions.create.called
    assert len(results) == 1
    assert results[0][0] == 'passage 1'
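The mocked response mirrors the shape the reranker reads back from the chat-completions API: choices[0].logprobs.content[0].top_logprobs[0] carries the model's True/False relevance judgement for a passage. As a rough, illustrative sketch (not the library's exact formula) of how such a logprob becomes a relevance score:

import math
from types import SimpleNamespace

# Same shape as the mock above: the first token's top logprob is the model's
# True/False relevance judgement for one passage.
top_logprob = SimpleNamespace(token='True', logprob=-0.5)

probability = math.exp(top_logprob.logprob)  # exp(-0.5) is roughly 0.607
# Rough scoring sketch: treat 'True' as relevant, anything else as not.
score = probability if top_logprob.token.lower().startswith('true') else 1.0 - probability
print(round(score, 3))  # ~0.607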