Gemini client improvements; Gemini reranker (#645)
* add support for Gemini 2.5 model thinking budget
* allow adding thinking config to support current and future Gemini models
* merge
* improve client; add reranker
* refactor: change type hint for gemini_messages to Any for flexibility
* refactor: update GeminiRerankerClient to use direct relevance scoring and improve ranking logic; add tests
* fix fixtures
---------
Co-authored-by: realugbun <github.disorder751@passmail.net>
This commit is contained in: parent daec70db65, commit 689d669559
7 changed files with 1335 additions and 68 deletions
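The thinking-budget support described in the commit message amounts to passing a `types.ThinkingConfig` through to the new `thinking_config` parameter of `GeminiClient`. A minimal sketch of how the two additions from this commit fit together; the model name and budget value are illustrative, not prescribed by the commit:

```python
from google.genai import types

from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient

api_key = '<your-google-api-key>'

# Thinking config is only honored by models that support it (gemini-2.5+);
# the budget of 1024 tokens here is an arbitrary example value.
llm = GeminiClient(
    config=LLMConfig(api_key=api_key, model='gemini-2.5-flash'),
    thinking_config=types.ThinkingConfig(thinking_budget=1024),
)

# The new reranker scores each passage on a 0-100 scale and normalizes to [0, 1].
reranker = GeminiRerankerClient(config=LLMConfig(api_key=api_key))
```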
README.md (14 lines changed)
@@ -248,7 +248,6 @@ graphiti = Graphiti(
        ),
        client=azure_openai_client
    ),
    # Optional: Configure the OpenAI cross encoder with Azure OpenAI
    cross_encoder=OpenAIRerankerClient(
        llm_config=azure_llm_config,
        client=azure_openai_client

@@ -262,7 +261,7 @@ Make sure to replace the placeholder values with your actual Azure OpenAI creden

## Using Graphiti with Google Gemini

Graphiti supports Google's Gemini models for both LLM inference and embeddings. To use Gemini, you'll need to configure both the LLM client and embedder with your Google API key.
Graphiti supports Google's Gemini models for LLM inference, embeddings, and cross-encoding/reranking. To use Gemini, you'll need to configure the LLM client, embedder, and the cross-encoder with your Google API key.

Install Graphiti:

@@ -278,6 +277,7 @@ pip install "graphiti-core[google-genai]"
from graphiti_core import Graphiti
from graphiti_core.llm_client.gemini_client import GeminiClient, LLMConfig
from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig
from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient

# Google API key configuration
api_key = "<your-google-api-key>"

@@ -298,12 +298,20 @@ graphiti = Graphiti(
            api_key=api_key,
            embedding_model="embedding-001"
        )
    ),
    cross_encoder=GeminiRerankerClient(
        config=LLMConfig(
            api_key=api_key,
            model="gemini-2.5-flash-lite-preview-06-17"
        )
    )
)

# Now you can use Graphiti with Google Gemini
# Now you can use Graphiti with Google Gemini for all components
```
The Gemini reranker uses the `gemini-2.5-flash-lite-preview-06-17` model by default, which is optimized for cost-effective, low-latency classification tasks. Unlike the OpenAI reranker, it does not rely on log probabilities (the Gemini Developer API does not yet expose them); instead it asks the model to rate each passage's relevance to the query on a 0-100 scale, normalizes the scores to the [0, 1] range, and returns the passages sorted by score.
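As a quick illustration of that scoring behavior, here is a minimal standalone sketch of calling the reranker directly; the query and passages are made up, and the returned scores depend on the model:

```python
import asyncio

from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client.config import LLMConfig


async def main() -> None:
    reranker = GeminiRerankerClient(config=LLMConfig(api_key='<your-google-api-key>'))
    # rank() returns (passage, score) tuples sorted by descending relevance,
    # with scores normalized to the [0, 1] range.
    ranked = await reranker.rank(
        query='What is the capital of France?',
        passages=[
            'Paris is the capital and most populous city of France.',
            'Berlin is the capital and largest city of Germany.',
        ],
    )
    for passage, score in ranked:
        print(f'{score:.2f}  {passage}')


asyncio.run(main())
```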

## Using Graphiti with Ollama (Local LLM)

Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible API. This is ideal for privacy-focused applications or when you want to avoid API costs.
graphiti_core/cross_encoder/__init__.py
@@ -15,6 +15,7 @@ limitations under the License.
"""

from .client import CrossEncoderClient
from .gemini_reranker_client import GeminiRerankerClient
from .openai_reranker_client import OpenAIRerankerClient

__all__ = ['CrossEncoderClient', 'OpenAIRerankerClient']
__all__ = ['CrossEncoderClient', 'GeminiRerankerClient', 'OpenAIRerankerClient']
graphiti_core/cross_encoder/gemini_reranker_client.py (new file, 146 lines)

@@ -0,0 +1,146 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import re

from google import genai  # type: ignore
from google.genai import types  # type: ignore

from ..helpers import semaphore_gather
from ..llm_client import LLMConfig, RateLimitError
from .client import CrossEncoderClient

logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'


class GeminiRerankerClient(CrossEncoderClient):
    def __init__(
        self,
        config: LLMConfig | None = None,
        client: genai.Client | None = None,
    ):
        """
        Initialize the GeminiRerankerClient with the provided configuration and client.

        The Gemini Developer API does not yet support logprobs. Unlike the OpenAI reranker,
        this reranker uses the Gemini API to perform direct relevance scoring of passages.
        Each passage is scored individually on a 0-100 scale.

        Args:
            config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
        """
        if config is None:
            config = LLMConfig()

        self.config = config
        if client is None:
            self.client = genai.Client(api_key=config.api_key)
        else:
            self.client = client

    async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
        """
        Rank passages based on their relevance to the query using direct scoring.

        Each passage is scored individually on a 0-100 scale, then normalized to [0,1].
        """
        if len(passages) <= 1:
            return [(passage, 1.0) for passage in passages]

        # Generate scoring prompts for each passage
        scoring_prompts = []
        for passage in passages:
            prompt = f"""Rate how well this passage answers or relates to the query. Use a scale from 0 to 100.

Query: {query}

Passage: {passage}

Provide only a number between 0 and 100 (no explanation, just the number):"""

            scoring_prompts.append(
                [
                    types.Content(
                        role='user',
                        parts=[types.Part.from_text(text=prompt)],
                    ),
                ]
            )

        try:
            # Execute all scoring requests concurrently - O(n) API calls
            responses = await semaphore_gather(
                *[
                    self.client.aio.models.generate_content(
                        model=self.config.model or DEFAULT_MODEL,
                        contents=prompt_messages,  # type: ignore
                        config=types.GenerateContentConfig(
                            system_instruction='You are an expert at rating passage relevance. Respond with only a number from 0-100.',
                            temperature=0.0,
                            max_output_tokens=3,
                        ),
                    )
                    for prompt_messages in scoring_prompts
                ]
            )

            # Extract scores and create results
            results = []
            for passage, response in zip(passages, responses, strict=True):
                try:
                    if hasattr(response, 'text') and response.text:
                        # Extract numeric score from response
                        score_text = response.text.strip()
                        # Handle cases where model might return non-numeric text
                        score_match = re.search(r'\b(\d{1,3})\b', score_text)
                        if score_match:
                            score = float(score_match.group(1))
                            # Normalize to [0, 1] range and clamp to valid range
                            normalized_score = max(0.0, min(1.0, score / 100.0))
                            results.append((passage, normalized_score))
                        else:
                            logger.warning(
                                f'Could not extract numeric score from response: {score_text}'
                            )
                            results.append((passage, 0.0))
                    else:
                        logger.warning('Empty response from Gemini for passage scoring')
                        results.append((passage, 0.0))
                except (ValueError, AttributeError) as e:
                    logger.warning(f'Error parsing score from Gemini response: {e}')
                    results.append((passage, 0.0))

            # Sort by score in descending order (highest relevance first)
            results.sort(reverse=True, key=lambda x: x[1])
            return results

        except Exception as e:
            # Check if it's a rate limit error based on Gemini API error codes
            error_message = str(e).lower()
            if (
                'rate limit' in error_message
                or 'quota' in error_message
                or 'resource_exhausted' in error_message
                or '429' in str(e)
            ):
                raise RateLimitError from e

            logger.error(f'Error in generating LLM response: {e}')
            raise
graphiti_core/llm_client/gemini_client.py
@@ -17,19 +17,21 @@ limitations under the License.
import json
import logging
import typing
from typing import ClassVar

from google import genai  # type: ignore
from google.genai import types  # type: ignore
from pydantic import BaseModel

from ..prompts.models import Message
from .client import LLMClient
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError

logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'gemini-2.0-flash'
DEFAULT_MODEL = 'gemini-2.5-flash'
DEFAULT_SMALL_MODEL = 'models/gemini-2.5-flash-lite-preview-06-17'


class GeminiClient(LLMClient):

@@ -43,27 +45,34 @@ class GeminiClient(LLMClient):
        model (str): The model name to use for generating responses.
        temperature (float): The temperature to use for generating responses.
        max_tokens (int): The maximum number of tokens to generate in a response.

        thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
    Methods:
        __init__(config: LLMConfig | None = None, cache: bool = False):
            Initializes the GeminiClient with the provided configuration and cache setting.
        __init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
            Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.

        _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
            Generates a response from the language model based on the provided messages.
    """

    # Class-level constants
    MAX_RETRIES: ClassVar[int] = 2

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        thinking_config: types.ThinkingConfig | None = None,
    ):
        """
        Initialize the GeminiClient with the provided configuration and cache setting.
        Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.

        Args:
            config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
            cache (bool): Whether to use caching for responses. Defaults to False.
            thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
                Only use with models that support thinking (gemini-2.5+). Defaults to None.

        """
        if config is None:
            config = LLMConfig()

@@ -76,6 +85,50 @@ class GeminiClient(LLMClient):
            api_key=config.api_key,
        )
        self.max_tokens = max_tokens
        self.thinking_config = thinking_config

    def _check_safety_blocks(self, response) -> None:
        """Check if response was blocked for safety reasons and raise appropriate exceptions."""
        # Check if the response was blocked for safety reasons
        if not (hasattr(response, 'candidates') and response.candidates):
            return

        candidate = response.candidates[0]
        if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
            return

        # Content was blocked for safety reasons - collect safety details
        safety_info = []
        safety_ratings = getattr(candidate, 'safety_ratings', None)

        if safety_ratings:
            for rating in safety_ratings:
                if getattr(rating, 'blocked', False):
                    category = getattr(rating, 'category', 'Unknown')
                    probability = getattr(rating, 'probability', 'Unknown')
                    safety_info.append(f'{category}: {probability}')

        safety_details = (
            ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
        )
        raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')

    def _check_prompt_blocks(self, response) -> None:
        """Check if prompt was blocked and raise appropriate exceptions."""
        prompt_feedback = getattr(response, 'prompt_feedback', None)
        if not prompt_feedback:
            return

        block_reason = getattr(prompt_feedback, 'block_reason', None)
        if block_reason:
            raise Exception(f'Prompt blocked by Gemini: {block_reason}')

    def _get_model_for_size(self, model_size: ModelSize) -> str:
        """Get the appropriate model name based on the requested size."""
        if model_size == ModelSize.small:
            return self.small_model or DEFAULT_SMALL_MODEL
        else:
            return self.model or DEFAULT_MODEL

    async def _generate_response(
        self,

@@ -91,17 +144,17 @@ class GeminiClient(LLMClient):
            messages (list[Message]): A list of messages to send to the language model.
            response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
            max_tokens (int): The maximum number of tokens to generate in the response.
            model_size (ModelSize): The size of the model to use (small or medium).

        Returns:
            dict[str, typing.Any]: The response from the language model.

        Raises:
            RateLimitError: If the API rate limit is exceeded.
            RefusalError: If the content is blocked by the model.
            Exception: If there is an error generating the response.
            Exception: If there is an error generating the response or content is blocked.
        """
        try:
            gemini_messages: list[types.Content] = []
            gemini_messages: typing.Any = []
            # If a response model is provided, add schema for structured output
            system_prompt = ''
            if response_model is not None:

@@ -127,6 +180,9 @@ class GeminiClient(LLMClient):
                    types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
                )

            # Get the appropriate model for the requested size
            model = self._get_model_for_size(model_size)

            # Create generation config
            generation_config = types.GenerateContentConfig(
                temperature=self.temperature,

@@ -134,15 +190,20 @@ class GeminiClient(LLMClient):
                response_mime_type='application/json' if response_model else None,
                response_schema=response_model if response_model else None,
                system_instruction=system_prompt,
                thinking_config=self.thinking_config,
            )

            # Generate content using the simple string approach
            response = await self.client.aio.models.generate_content(
                model=self.model or DEFAULT_MODEL,
                contents=gemini_messages,  # type: ignore[arg-type]  # mypy fails on broad union type
                model=model,
                contents=gemini_messages,
                config=generation_config,
            )

            # Check for safety and prompt blocks
            self._check_safety_blocks(response)
            self._check_prompt_blocks(response)

            # If this was a structured output request, parse the response into the Pydantic model
            if response_model is not None:
                try:

@@ -160,9 +221,16 @@ class GeminiClient(LLMClient):
            return {'content': response.text}

        except Exception as e:
            # Check if it's a rate limit error
            if 'rate limit' in str(e).lower() or 'quota' in str(e).lower():
            # Check if it's a rate limit error based on Gemini API error codes
            error_message = str(e).lower()
            if (
                'rate limit' in error_message
                or 'quota' in error_message
                or 'resource_exhausted' in error_message
                or '429' in str(e)
            ):
                raise RateLimitError from e

            logger.error(f'Error in generating LLM response: {e}')
            raise

@@ -174,13 +242,14 @@ class GeminiClient(LLMClient):
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Gemini language model.
        This method overrides the parent class method to provide a direct implementation.
        Generate a response from the Gemini language model with retry logic and error handling.
        This method overrides the parent class method to provide a direct implementation with advanced retry logic.

        Args:
            messages (list[Message]): A list of messages to send to the language model.
            response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
            max_tokens (int): The maximum number of tokens to generate in the response.
            max_tokens (int | None): The maximum number of tokens to generate in the response.
            model_size (ModelSize): The size of the model to use (small or medium).

        Returns:
            dict[str, typing.Any]: The response from the language model.

@@ -188,10 +257,53 @@ class GeminiClient(LLMClient):
        if max_tokens is None:
            max_tokens = self.max_tokens

        # Call the internal _generate_response method
        return await self._generate_response(
            messages=messages,
            response_model=response_model,
            max_tokens=max_tokens,
            model_size=model_size,
        )
        retry_count = 0
        last_error = None

        # Add multilingual extraction instructions
        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES

        while retry_count <= self.MAX_RETRIES:
            try:
                response = await self._generate_response(
                    messages=messages,
                    response_model=response_model,
                    max_tokens=max_tokens,
                    model_size=model_size,
                )
                return response
            except RateLimitError:
                # Rate limit errors should not trigger retries (fail fast)
                raise
            except Exception as e:
                last_error = e

                # Check if this is a safety block - these typically shouldn't be retried
                if 'safety' in str(e).lower() or 'blocked' in str(e).lower():
                    logger.warning(f'Content blocked by safety filters: {e}')
                    raise

                # Don't retry if we've hit the max retries
                if retry_count >= self.MAX_RETRIES:
                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
                    raise

                retry_count += 1

                # Construct a detailed error message for the LLM
                error_context = (
                    f'The previous response attempt was invalid. '
                    f'Error type: {e.__class__.__name__}. '
                    f'Error details: {str(e)}. '
                    f'Please try again with a valid response, ensuring the output matches '
                    f'the expected format and constraints.'
                )

                error_message = Message(role='user', content=error_context)
                messages.append(error_message)
                logger.warning(
                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
                )

        # If we somehow get here, raise the last error
        raise last_error or Exception('Max retries exceeded with no specific error')
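For reference, a minimal sketch of calling the updated client with structured output; the `Summary` schema, the prompts, and the final `Summary(**data)` construction are illustrative assumptions, not part of this commit:

```python
from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient
from graphiti_core.prompts.models import Message


class Summary(BaseModel):  # hypothetical response schema for illustration
    title: str
    bullet_points: list[str]


async def summarize(client: GeminiClient, text: str) -> Summary:
    # generate_response() now retries up to MAX_RETRIES times on invalid output,
    # appending the error details as an extra user message before each retry;
    # rate-limit and safety-block errors are raised immediately.
    data = await client.generate_response(
        messages=[
            Message(role='system', content='Summarize the user text.'),
            Message(role='user', content=text),
        ],
        response_model=Summary,
    )
    return Summary(**data)
```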

tests/cross_encoder/test_gemini_reranker_client.py (new file, 353 lines)

@@ -0,0 +1,353 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# Running tests: pytest -xvs tests/cross_encoder/test_gemini_reranker_client.py
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
|
||||
from graphiti_core.llm_client import LLMConfig, RateLimitError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gemini_client():
|
||||
"""Fixture to mock the Google Gemini client."""
|
||||
with patch('google.genai.Client') as mock_client:
|
||||
# Setup mock instance and its methods
|
||||
mock_instance = mock_client.return_value
|
||||
mock_instance.aio = MagicMock()
|
||||
mock_instance.aio.models = MagicMock()
|
||||
mock_instance.aio.models.generate_content = AsyncMock()
|
||||
yield mock_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gemini_reranker_client(mock_gemini_client):
|
||||
"""Fixture to create a GeminiRerankerClient with a mocked client."""
|
||||
config = LLMConfig(api_key='test_api_key', model='test-model')
|
||||
client = GeminiRerankerClient(config=config)
|
||||
# Replace the client's client with our mock to ensure we're using the mock
|
||||
client.client = mock_gemini_client
|
||||
return client
|
||||
|
||||
|
||||
def create_mock_response(score_text: str) -> MagicMock:
|
||||
"""Helper function to create a mock Gemini response."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = score_text
|
||||
return mock_response
|
||||
|
||||
|
||||
class TestGeminiRerankerClientInitialization:
|
||||
"""Tests for GeminiRerankerClient initialization."""
|
||||
|
||||
def test_init_with_config(self):
|
||||
"""Test initialization with a config object."""
|
||||
config = LLMConfig(api_key='test_api_key', model='test-model')
|
||||
client = GeminiRerankerClient(config=config)
|
||||
|
||||
assert client.config == config
|
||||
|
||||
@patch('google.genai.Client')
|
||||
def test_init_without_config(self, mock_client):
|
||||
"""Test initialization without a config uses defaults."""
|
||||
client = GeminiRerankerClient()
|
||||
|
||||
assert client.config is not None
|
||||
|
||||
def test_init_with_custom_client(self):
|
||||
"""Test initialization with a custom client."""
|
||||
mock_client = MagicMock()
|
||||
client = GeminiRerankerClient(client=mock_client)
|
||||
|
||||
assert client.client == mock_client
|
||||
|
||||
|
||||
class TestGeminiRerankerClientRanking:
|
||||
"""Tests for GeminiRerankerClient rank method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_basic_functionality(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test basic ranking functionality."""
|
||||
# Setup mock responses with different scores
|
||||
mock_responses = [
|
||||
create_mock_response('85'), # High relevance
|
||||
create_mock_response('45'), # Medium relevance
|
||||
create_mock_response('20'), # Low relevance
|
||||
]
|
||||
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
|
||||
|
||||
# Test data
|
||||
query = 'What is the capital of France?'
|
||||
passages = [
|
||||
'Paris is the capital and most populous city of France.',
|
||||
'London is the capital city of England and the United Kingdom.',
|
||||
'Berlin is the capital and largest city of Germany.',
|
||||
]
|
||||
|
||||
# Call method
|
||||
result = await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
# Assertions
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(item, tuple) for item in result)
|
||||
assert all(
|
||||
isinstance(passage, str) and isinstance(score, float) for passage, score in result
|
||||
)
|
||||
|
||||
# Check scores are normalized to [0, 1] and sorted in descending order
|
||||
scores = [score for _, score in result]
|
||||
assert all(0.0 <= score <= 1.0 for score in scores)
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
|
||||
# Check that the highest scoring passage is first
|
||||
assert result[0][1] == 0.85 # 85/100
|
||||
assert result[1][1] == 0.45 # 45/100
|
||||
assert result[2][1] == 0.20 # 20/100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_empty_passages(self, gemini_reranker_client):
|
||||
"""Test ranking with empty passages list."""
|
||||
query = 'Test query'
|
||||
passages = []
|
||||
|
||||
result = await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_single_passage(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test ranking with a single passage."""
|
||||
# Setup mock response
|
||||
mock_gemini_client.aio.models.generate_content.return_value = create_mock_response('75')
|
||||
|
||||
query = 'Test query'
|
||||
passages = ['Single test passage']
|
||||
|
||||
result = await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == 'Single test passage'
|
||||
assert result[0][1] == 1.0 # Single passage gets full score
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_score_extraction_with_regex(
|
||||
self, gemini_reranker_client, mock_gemini_client
|
||||
):
|
||||
"""Test score extraction from various response formats."""
|
||||
# Setup mock responses with different formats
|
||||
mock_responses = [
|
||||
create_mock_response('Score: 90'), # Contains text before number
|
||||
create_mock_response('The relevance is 65 out of 100'), # Contains text around number
|
||||
create_mock_response('8'), # Just the number
|
||||
]
|
||||
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
|
||||
|
||||
query = 'Test query'
|
||||
passages = ['Passage 1', 'Passage 2', 'Passage 3']
|
||||
|
||||
result = await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
# Check that scores were extracted correctly and normalized
|
||||
scores = [score for _, score in result]
|
||||
assert 0.90 in scores # 90/100
|
||||
assert 0.65 in scores # 65/100
|
||||
assert 0.08 in scores # 8/100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_invalid_score_handling(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test handling of invalid or non-numeric scores."""
|
||||
# Setup mock responses with invalid scores
|
||||
mock_responses = [
|
||||
create_mock_response('Not a number'), # Invalid response
|
||||
create_mock_response(''), # Empty response
|
||||
create_mock_response('95'), # Valid response
|
||||
]
|
||||
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
|
||||
|
||||
query = 'Test query'
|
||||
passages = ['Passage 1', 'Passage 2', 'Passage 3']
|
||||
|
||||
result = await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
# Check that invalid scores are handled gracefully (assigned 0.0)
|
||||
scores = [score for _, score in result]
|
||||
assert 0.95 in scores # Valid score
|
||||
assert scores.count(0.0) == 2 # Two invalid scores assigned 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_score_clamping(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test that scores are properly clamped to [0, 1] range."""
|
||||
# Setup mock responses with extreme scores
|
||||
# Note: regex only matches 1-3 digits, so negative numbers won't match
|
||||
mock_responses = [
|
||||
create_mock_response('999'), # Above 100 but within regex range
|
||||
create_mock_response('invalid'), # Invalid response becomes 0.0
|
||||
create_mock_response('50'), # Normal score
|
||||
]
|
||||
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
|
||||
|
||||
query = 'Test query'
|
||||
passages = ['Passage 1', 'Passage 2', 'Passage 3']
|
||||
|
||||
result = await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
# Check that scores are normalized and clamped
|
||||
scores = [score for _, score in result]
|
||||
assert all(0.0 <= score <= 1.0 for score in scores)
|
||||
# 999 should be clamped to 1.0 (999/100 = 9.99, clamped to 1.0)
|
||||
assert 1.0 in scores
|
||||
# Invalid response should be 0.0
|
||||
assert 0.0 in scores
|
||||
# Normal score should be normalized (50/100 = 0.5)
|
||||
assert 0.5 in scores
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_rate_limit_error(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test handling of rate limit errors."""
|
||||
# Setup mock to raise rate limit error
|
||||
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
|
||||
'Rate limit exceeded'
|
||||
)
|
||||
|
||||
query = 'Test query'
|
||||
passages = ['Passage 1', 'Passage 2']
|
||||
|
||||
with pytest.raises(RateLimitError):
|
||||
await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_quota_error(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test handling of quota errors."""
|
||||
# Setup mock to raise quota error
|
||||
mock_gemini_client.aio.models.generate_content.side_effect = Exception('Quota exceeded')
|
||||
|
||||
query = 'Test query'
|
||||
passages = ['Passage 1', 'Passage 2']
|
||||
|
||||
with pytest.raises(RateLimitError):
|
||||
await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_resource_exhausted_error(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test handling of resource exhausted errors."""
|
||||
# Setup mock to raise resource exhausted error
|
||||
mock_gemini_client.aio.models.generate_content.side_effect = Exception('resource_exhausted')
|
||||
|
||||
query = 'Test query'
|
||||
passages = ['Passage 1', 'Passage 2']
|
||||
|
||||
with pytest.raises(RateLimitError):
|
||||
await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_429_error(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test handling of HTTP 429 errors."""
|
||||
# Setup mock to raise 429 error
|
||||
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
|
||||
'HTTP 429 Too Many Requests'
|
||||
)
|
||||
|
||||
query = 'Test query'
|
||||
passages = ['Passage 1', 'Passage 2']
|
||||
|
||||
with pytest.raises(RateLimitError):
|
||||
await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_generic_error(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test handling of generic errors."""
|
||||
# Setup mock to raise generic error
|
||||
mock_gemini_client.aio.models.generate_content.side_effect = Exception('Generic error')
|
||||
|
||||
query = 'Test query'
|
||||
passages = ['Passage 1', 'Passage 2']
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
assert 'Generic error' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_concurrent_requests(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test that multiple passages are scored concurrently."""
|
||||
# Setup mock responses
|
||||
mock_responses = [
|
||||
create_mock_response('80'),
|
||||
create_mock_response('60'),
|
||||
create_mock_response('40'),
|
||||
]
|
||||
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
|
||||
|
||||
query = 'Test query'
|
||||
passages = ['Passage 1', 'Passage 2', 'Passage 3']
|
||||
|
||||
await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
# Verify that generate_content was called for each passage
|
||||
assert mock_gemini_client.aio.models.generate_content.call_count == 3
|
||||
|
||||
# Verify that all calls were made with correct parameters
|
||||
calls = mock_gemini_client.aio.models.generate_content.call_args_list
|
||||
for call in calls:
|
||||
args, kwargs = call
|
||||
assert kwargs['model'] == gemini_reranker_client.config.model
|
||||
assert kwargs['config'].temperature == 0.0
|
||||
assert kwargs['config'].max_output_tokens == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_response_parsing_error(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test handling of response parsing errors."""
|
||||
# Setup mock responses that will trigger ValueError during parsing
|
||||
mock_responses = [
|
||||
create_mock_response('not a number at all'), # Will fail regex match
|
||||
create_mock_response('also invalid text'), # Will fail regex match
|
||||
]
|
||||
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
|
||||
|
||||
query = 'Test query'
|
||||
# Use multiple passages to avoid the single passage special case
|
||||
passages = ['Passage 1', 'Passage 2']
|
||||
|
||||
result = await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
# Should handle the error gracefully and assign 0.0 score to both
|
||||
assert len(result) == 2
|
||||
assert all(score == 0.0 for _, score in result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rank_empty_response_text(self, gemini_reranker_client, mock_gemini_client):
|
||||
"""Test handling of empty response text."""
|
||||
# Setup mock response with empty text
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = '' # Empty string instead of None
|
||||
mock_gemini_client.aio.models.generate_content.return_value = mock_response
|
||||
|
||||
query = 'Test query'
|
||||
# Use multiple passages to avoid the single passage special case
|
||||
passages = ['Passage 1', 'Passage 2']
|
||||
|
||||
result = await gemini_reranker_client.rank(query, passages)
|
||||
|
||||
# Should handle empty text gracefully and assign 0.0 score to both
|
||||
assert len(result) == 2
|
||||
assert all(score == 0.0 for _, score in result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main(['-v', 'test_gemini_reranker_client.py'])
|
||||
|
|
@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
# Running tests: pytest -xvs tests/embedder/test_gemini.py
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
|
@ -28,10 +30,10 @@ from graphiti_core.embedder.gemini import (
|
|||
from tests.embedder.embedder_fixtures import create_embedding_values
|
||||
|
||||
|
||||
def create_gemini_embedding(multiplier: float = 0.1) -> MagicMock:
|
||||
"""Create a mock Gemini embedding with specified value multiplier."""
|
||||
def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock:
|
||||
"""Create a mock Gemini embedding with specified value multiplier and dimension."""
|
||||
mock_embedding = MagicMock()
|
||||
mock_embedding.values = create_embedding_values(multiplier)
|
||||
mock_embedding.values = create_embedding_values(multiplier, dimension)
|
||||
return mock_embedding
|
||||
|
||||
|
||||
|
|
@ -75,52 +77,304 @@ def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
|
|||
return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_calls_api_correctly(
|
||||
gemini_embedder: GeminiEmbedder, mock_gemini_client: Any, mock_gemini_response: MagicMock
|
||||
) -> None:
|
||||
"""Test that create method correctly calls the API and processes the response."""
|
||||
# Setup
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
|
||||
class TestGeminiEmbedderInitialization:
|
||||
"""Tests for GeminiEmbedder initialization."""
|
||||
|
||||
# Call method
|
||||
result = await gemini_embedder.create('Test input')
|
||||
@patch('google.genai.Client')
|
||||
def test_init_with_config(self, mock_client):
|
||||
"""Test initialization with a config object."""
|
||||
config = GeminiEmbedderConfig(
|
||||
api_key='test_api_key', embedding_model='custom-model', embedding_dim=768
|
||||
)
|
||||
embedder = GeminiEmbedder(config=config)
|
||||
|
||||
# Verify API is called with correct parameters
|
||||
mock_gemini_client.aio.models.embed_content.assert_called_once()
|
||||
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
|
||||
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
|
||||
assert kwargs['contents'] == ['Test input']
|
||||
assert embedder.config == config
|
||||
assert embedder.config.embedding_model == 'custom-model'
|
||||
assert embedder.config.api_key == 'test_api_key'
|
||||
assert embedder.config.embedding_dim == 768
|
||||
|
||||
# Verify result is processed correctly
|
||||
assert result == mock_gemini_response.embeddings[0].values
|
||||
@patch('google.genai.Client')
|
||||
def test_init_without_config(self, mock_client):
|
||||
"""Test initialization without a config uses defaults."""
|
||||
embedder = GeminiEmbedder()
|
||||
|
||||
assert embedder.config is not None
|
||||
assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
|
||||
|
||||
@patch('google.genai.Client')
|
||||
def test_init_with_partial_config(self, mock_client):
|
||||
"""Test initialization with partial config."""
|
||||
config = GeminiEmbedderConfig(api_key='test_api_key')
|
||||
embedder = GeminiEmbedder(config=config)
|
||||
|
||||
assert embedder.config.api_key == 'test_api_key'
|
||||
assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_batch_processes_multiple_inputs(
|
||||
gemini_embedder: GeminiEmbedder, mock_gemini_client: Any, mock_gemini_batch_response: MagicMock
|
||||
) -> None:
|
||||
"""Test that create_batch method correctly processes multiple inputs."""
|
||||
# Setup
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response
|
||||
input_batch = ['Input 1', 'Input 2', 'Input 3']
|
||||
class TestGeminiEmbedderCreate:
|
||||
"""Tests for GeminiEmbedder create method."""
|
||||
|
||||
# Call method
|
||||
result = await gemini_embedder.create_batch(input_batch)
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_calls_api_correctly(
|
||||
self,
|
||||
gemini_embedder: GeminiEmbedder,
|
||||
mock_gemini_client: Any,
|
||||
mock_gemini_response: MagicMock,
|
||||
) -> None:
|
||||
"""Test that create method correctly calls the API and processes the response."""
|
||||
# Setup
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
|
||||
|
||||
# Verify API is called with correct parameters
|
||||
mock_gemini_client.aio.models.embed_content.assert_called_once()
|
||||
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
|
||||
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
|
||||
assert kwargs['contents'] == input_batch
|
||||
# Call method
|
||||
result = await gemini_embedder.create('Test input')
|
||||
|
||||
# Verify all results are processed correctly
|
||||
assert len(result) == 3
|
||||
assert result == [
|
||||
mock_gemini_batch_response.embeddings[0].values,
|
||||
mock_gemini_batch_response.embeddings[1].values,
|
||||
mock_gemini_batch_response.embeddings[2].values,
|
||||
]
|
||||
# Verify API is called with correct parameters
|
||||
mock_gemini_client.aio.models.embed_content.assert_called_once()
|
||||
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
|
||||
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
|
||||
assert kwargs['contents'] == ['Test input']
|
||||
|
||||
# Verify result is processed correctly
|
||||
assert result == mock_gemini_response.embeddings[0].values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('google.genai.Client')
|
||||
async def test_create_with_custom_model(
|
||||
self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock
|
||||
) -> None:
|
||||
"""Test create method with custom embedding model."""
|
||||
# Setup embedder with custom model
|
||||
config = GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model')
|
||||
embedder = GeminiEmbedder(config=config)
|
||||
embedder.client = mock_gemini_client
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
|
||||
|
||||
# Call method
|
||||
await embedder.create('Test input')
|
||||
|
||||
# Verify custom model is used
|
||||
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
|
||||
assert kwargs['model'] == 'custom-model'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('google.genai.Client')
|
||||
async def test_create_with_custom_dimension(
|
||||
self, mock_client_class, mock_gemini_client: Any
|
||||
) -> None:
|
||||
"""Test create method with custom embedding dimension."""
|
||||
# Setup embedder with custom dimension
|
||||
config = GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768)
|
||||
embedder = GeminiEmbedder(config=config)
|
||||
embedder.client = mock_gemini_client
|
||||
|
||||
# Setup mock response with custom dimension
|
||||
mock_response = MagicMock()
|
||||
mock_response.embeddings = [create_gemini_embedding(0.1, 768)]
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_response
|
||||
|
||||
# Call method
|
||||
result = await embedder.create('Test input')
|
||||
|
||||
# Verify custom dimension is used in config
|
||||
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
|
||||
assert kwargs['config'].output_dimensionality == 768
|
||||
|
||||
# Verify result has correct dimension
|
||||
assert len(result) == 768
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_with_different_input_types(
|
||||
self,
|
||||
gemini_embedder: GeminiEmbedder,
|
||||
mock_gemini_client: Any,
|
||||
mock_gemini_response: MagicMock,
|
||||
) -> None:
|
||||
"""Test create method with different input types."""
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
|
||||
|
||||
# Test with string
|
||||
await gemini_embedder.create('Test string')
|
||||
|
||||
# Test with list of strings
|
||||
await gemini_embedder.create(['Test', 'List'])
|
||||
|
||||
# Test with iterable of integers
|
||||
await gemini_embedder.create([1, 2, 3])
|
||||
|
||||
# Verify all calls were made
|
||||
assert mock_gemini_client.aio.models.embed_content.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_no_embeddings_error(
|
||||
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
|
||||
) -> None:
|
||||
"""Test create method handling of no embeddings response."""
|
||||
# Setup mock response with no embeddings
|
||||
mock_response = MagicMock()
|
||||
mock_response.embeddings = []
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_response
|
||||
|
||||
# Call method and expect exception
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await gemini_embedder.create('Test input')
|
||||
|
||||
assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_no_values_error(
|
||||
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
|
||||
) -> None:
|
||||
"""Test create method handling of embeddings with no values."""
|
||||
# Setup mock response with embedding but no values
|
||||
mock_embedding = MagicMock()
|
||||
mock_embedding.values = None
|
||||
mock_response = MagicMock()
|
||||
mock_response.embeddings = [mock_embedding]
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_response
|
||||
|
||||
# Call method and expect exception
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await gemini_embedder.create('Test input')
|
||||
|
||||
assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
|
||||
|
||||
|
||||
class TestGeminiEmbedderCreateBatch:
|
||||
"""Tests for GeminiEmbedder create_batch method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_batch_processes_multiple_inputs(
|
||||
self,
|
||||
gemini_embedder: GeminiEmbedder,
|
||||
mock_gemini_client: Any,
|
||||
mock_gemini_batch_response: MagicMock,
|
||||
) -> None:
|
||||
"""Test that create_batch method correctly processes multiple inputs."""
|
||||
# Setup
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response
|
||||
input_batch = ['Input 1', 'Input 2', 'Input 3']
|
||||
|
||||
# Call method
|
||||
result = await gemini_embedder.create_batch(input_batch)
|
||||
|
||||
# Verify API is called with correct parameters
|
||||
mock_gemini_client.aio.models.embed_content.assert_called_once()
|
||||
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
|
||||
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
|
||||
assert kwargs['contents'] == input_batch
|
||||
|
||||
# Verify all results are processed correctly
|
||||
assert len(result) == 3
|
||||
assert result == [
|
||||
mock_gemini_batch_response.embeddings[0].values,
|
||||
mock_gemini_batch_response.embeddings[1].values,
|
||||
mock_gemini_batch_response.embeddings[2].values,
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_batch_single_input(
|
||||
self,
|
||||
gemini_embedder: GeminiEmbedder,
|
||||
mock_gemini_client: Any,
|
||||
mock_gemini_response: MagicMock,
|
||||
) -> None:
|
||||
"""Test create_batch method with single input."""
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
|
||||
input_batch = ['Single input']
|
||||
|
||||
result = await gemini_embedder.create_batch(input_batch)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0] == mock_gemini_response.embeddings[0].values
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_batch_empty_input(
|
||||
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
|
||||
) -> None:
|
||||
"""Test create_batch method with empty input."""
|
||||
# Setup mock response with no embeddings
|
||||
mock_response = MagicMock()
|
||||
mock_response.embeddings = []
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_response
|
||||
|
||||
input_batch = []
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await gemini_embedder.create_batch(input_batch)
|
||||
|
||||
assert 'No embeddings returned' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_batch_no_embeddings_error(
|
||||
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
|
||||
) -> None:
|
||||
"""Test create_batch method handling of no embeddings response."""
|
||||
# Setup mock response with no embeddings
|
||||
mock_response = MagicMock()
|
||||
mock_response.embeddings = []
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_response
|
||||
|
||||
input_batch = ['Input 1', 'Input 2']
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await gemini_embedder.create_batch(input_batch)
|
||||
|
||||
assert 'No embeddings returned' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_batch_empty_values_error(
|
||||
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
|
||||
) -> None:
|
||||
"""Test create_batch method handling of embeddings with empty values."""
|
||||
# Setup mock response with embeddings but empty values
|
||||
mock_embedding1 = MagicMock()
|
||||
mock_embedding1.values = [0.1, 0.2, 0.3] # Valid values
|
||||
mock_embedding2 = MagicMock()
|
||||
mock_embedding2.values = None # Empty values
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.embeddings = [mock_embedding1, mock_embedding2]
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_response
|
||||
|
||||
input_batch = ['Input 1', 'Input 2']
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await gemini_embedder.create_batch(input_batch)
|
||||
|
||||
assert 'Empty embedding values returned' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('google.genai.Client')
|
||||
async def test_create_batch_with_custom_model_and_dimension(
|
||||
self, mock_client_class, mock_gemini_client: Any
|
||||
) -> None:
|
||||
"""Test create_batch method with custom model and dimension."""
|
||||
# Setup embedder with custom settings
|
||||
config = GeminiEmbedderConfig(
|
||||
api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512
|
||||
)
|
||||
embedder = GeminiEmbedder(config=config)
|
||||
embedder.client = mock_gemini_client
|
||||
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.embeddings = [
|
||||
create_gemini_embedding(0.1, 512),
|
||||
create_gemini_embedding(0.2, 512),
|
||||
]
|
||||
mock_gemini_client.aio.models.embed_content.return_value = mock_response
|
||||
|
||||
input_batch = ['Input 1', 'Input 2']
|
||||
result = await embedder.create_batch(input_batch)
|
||||
|
||||
# Verify custom settings are used
|
||||
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
|
||||
assert kwargs['model'] == 'custom-batch-model'
|
||||
assert kwargs['config'].output_dimensionality == 512
|
||||
|
||||
# Verify results have correct dimension
|
||||
assert len(result) == 2
|
||||
assert all(len(embedding) == 512 for embedding in result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|

tests/llm_client/test_gemini_client.py (new file, 393 lines)

@@ -0,0 +1,393 @@
|
|||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# Running tests: pytest -xvs tests/llm_client/test_gemini_client.py
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graphiti_core.llm_client.config import LLMConfig, ModelSize
|
||||
from graphiti_core.llm_client.errors import RateLimitError
|
||||
from graphiti_core.llm_client.gemini_client import DEFAULT_MODEL, DEFAULT_SMALL_MODEL, GeminiClient
|
||||
from graphiti_core.prompts.models import Message
|
||||
|
||||
|
||||
# Test model for response testing
|
||||
class ResponseModel(BaseModel):
|
||||
"""Test model for response testing."""
|
||||
|
||||
test_field: str
|
||||
optional_field: int = 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gemini_client():
|
||||
"""Fixture to mock the Google Gemini client."""
|
||||
with patch('google.genai.Client') as mock_client:
|
||||
# Setup mock instance and its methods
|
||||
mock_instance = mock_client.return_value
|
||||
mock_instance.aio = MagicMock()
|
||||
mock_instance.aio.models = MagicMock()
|
||||
mock_instance.aio.models.generate_content = AsyncMock()
|
||||
yield mock_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gemini_client(mock_gemini_client):
|
||||
"""Fixture to create a GeminiClient with a mocked client."""
|
||||
config = LLMConfig(api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000)
|
||||
client = GeminiClient(config=config, cache=False)
|
||||
# Replace the client's client with our mock to ensure we're using the mock
|
||||
client.client = mock_gemini_client
|
||||
return client
|
||||
|
||||
|
||||
class TestGeminiClientInitialization:
|
||||
"""Tests for GeminiClient initialization."""
|
||||
|
||||
@patch('google.genai.Client')
|
||||
def test_init_with_config(self, mock_client):
|
||||
"""Test initialization with a config object."""
|
||||
config = LLMConfig(
|
||||
api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
|
||||
)
|
||||
client = GeminiClient(config=config, cache=False, max_tokens=1000)
|
||||
|
||||
assert client.config == config
|
||||
assert client.model == 'test-model'
|
||||
assert client.temperature == 0.5
|
||||
assert client.max_tokens == 1000
|
||||
|
||||
@patch('google.genai.Client')
|
||||
def test_init_with_default_model(self, mock_client):
|
||||
"""Test initialization with default model when none is provided."""
        config = LLMConfig(api_key='test_api_key', model=DEFAULT_MODEL)
        client = GeminiClient(config=config, cache=False)

        assert client.model == DEFAULT_MODEL

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Test initialization without a config uses defaults."""
        client = GeminiClient(cache=False)

        assert client.config is not None
        # When no config.model is set, it will be None, not DEFAULT_MODEL
        assert client.model is None

    @patch('google.genai.Client')
    def test_init_with_thinking_config(self, mock_client):
        """Test initialization with thinking config."""
        with patch('google.genai.types.ThinkingConfig') as mock_thinking_config:
            thinking_config = mock_thinking_config.return_value
            client = GeminiClient(thinking_config=thinking_config)
            assert client.thinking_config == thinking_config


class TestGeminiClientGenerateResponse:
    """Tests for GeminiClient generate_response method."""

    @pytest.mark.asyncio
    async def test_generate_response_simple_text(self, gemini_client, mock_gemini_client):
        """Test successful response generation with simple text."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Test response text'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method
        messages = [Message(role='user', content='Test message')]
        result = await gemini_client.generate_response(messages)

        # Assertions
        assert isinstance(result, dict)
        assert result['content'] == 'Test response text'
        mock_gemini_client.aio.models.generate_content.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_response_with_structured_output(
        self, gemini_client, mock_gemini_client
    ):
        """Test response generation with structured output."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = '{"test_field": "test_value", "optional_field": 42}'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method
        messages = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        result = await gemini_client.generate_response(
            messages=messages, response_model=ResponseModel
        )

        # Assertions
        assert isinstance(result, dict)
        assert result['test_field'] == 'test_value'
        assert result['optional_field'] == 42
        mock_gemini_client.aio.models.generate_content.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_response_with_system_message(self, gemini_client, mock_gemini_client):
        """Test response generation with system message handling."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Response with system context'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method
        messages = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        await gemini_client.generate_response(messages)

        # Verify system message is processed correctly
        call_args = mock_gemini_client.aio.models.generate_content.call_args
        config = call_args[1]['config']
        assert 'System message' in config.system_instruction

    @pytest.mark.asyncio
    async def test_get_model_for_size(self, gemini_client):
        """Test model selection based on size."""
        # Test small model
        small_model = gemini_client._get_model_for_size(ModelSize.small)
        assert small_model == DEFAULT_SMALL_MODEL

        # Test medium/large model
        medium_model = gemini_client._get_model_for_size(ModelSize.medium)
        assert medium_model == gemini_client.model

    @pytest.mark.asyncio
    async def test_rate_limit_error_handling(self, gemini_client, mock_gemini_client):
        """Test handling of rate limit errors."""
        # Setup mock to raise rate limit error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'Rate limit exceeded'
        )

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_quota_error_handling(self, gemini_client, mock_gemini_client):
        """Test handling of quota errors."""
        # Setup mock to raise quota error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'Quota exceeded for requests'
        )

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_resource_exhausted_error_handling(self, gemini_client, mock_gemini_client):
        """Test handling of resource exhausted errors."""
        # Setup mock to raise resource exhausted error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'resource_exhausted: Request limit exceeded'
        )

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_safety_block_handling(self, gemini_client, mock_gemini_client):
        """Test handling of safety blocks."""
        # Setup mock response with safety block
        mock_candidate = MagicMock()
        mock_candidate.finish_reason = 'SAFETY'
        mock_candidate.safety_ratings = [
            MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
        ]

        mock_response = MagicMock()
        mock_response.candidates = [mock_candidate]
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Response blocked by Gemini safety filters'):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_prompt_block_handling(self, gemini_client, mock_gemini_client):
        """Test handling of prompt blocks."""
        # Setup mock response with prompt block
        mock_prompt_feedback = MagicMock()
        mock_prompt_feedback.block_reason = 'BLOCKED_REASON_OTHER'

        mock_response = MagicMock()
        mock_response.candidates = []
        mock_response.prompt_feedback = mock_prompt_feedback
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Prompt blocked by Gemini'):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_structured_output_parsing_error(self, gemini_client, mock_gemini_client):
        """Test handling of structured output parsing errors."""
        # Setup mock response with invalid JSON
        mock_response = MagicMock()
        mock_response.text = 'Invalid JSON response'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Failed to parse structured response'):
            await gemini_client.generate_response(messages, response_model=ResponseModel)

    @pytest.mark.asyncio
    async def test_retry_logic_with_safety_block(self, gemini_client, mock_gemini_client):
        """Test that safety blocks are not retried."""
        # Setup mock to raise safety error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'Content blocked by safety filters'
        )

        # Call method and check that it doesn't retry
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(messages)

        # Should only be called once (no retries for safety blocks)
        assert mock_gemini_client.aio.models.generate_content.call_count == 1

    @pytest.mark.asyncio
    async def test_retry_logic_with_validation_error(self, gemini_client, mock_gemini_client):
        """Test retry behavior on validation error."""
        # First call returns invalid data, second call returns valid data
        mock_response1 = MagicMock()
        mock_response1.text = '{"wrong_field": "wrong_value"}'
        mock_response1.candidates = []
        mock_response1.prompt_feedback = None

        mock_response2 = MagicMock()
        mock_response2.text = '{"test_field": "correct_value"}'
        mock_response2.candidates = []
        mock_response2.prompt_feedback = None

        mock_gemini_client.aio.models.generate_content.side_effect = [
            mock_response1,
            mock_response2,
        ]

        # Call method
        messages = [Message(role='user', content='Test message')]
        result = await gemini_client.generate_response(messages, response_model=ResponseModel)

        # Should have called generate_content twice due to retry
        assert mock_gemini_client.aio.models.generate_content.call_count == 2
        assert result['test_field'] == 'correct_value'

    @pytest.mark.asyncio
    async def test_max_retries_exceeded(self, gemini_client, mock_gemini_client):
        """Test behavior when max retries are exceeded."""
        # Setup mock to always return invalid data
        mock_response = MagicMock()
        mock_response.text = '{"wrong_field": "wrong_value"}'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Failed to parse structured response'):
            await gemini_client.generate_response(messages, response_model=ResponseModel)

        # Should have called generate_content MAX_RETRIES + 1 times
        assert (
            mock_gemini_client.aio.models.generate_content.call_count
            == GeminiClient.MAX_RETRIES + 1
        )

    @pytest.mark.asyncio
    async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
        """Test handling of empty responses."""
        # Setup mock response with no text
        mock_response = MagicMock()
        mock_response.text = ''
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method with structured output and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Failed to parse structured response'):
            await gemini_client.generate_response(messages, response_model=ResponseModel)

    @pytest.mark.asyncio
    async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
        """Test response generation with custom max tokens."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Test response'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method with custom max tokens
        messages = [Message(role='user', content='Test message')]
        await gemini_client.generate_response(messages, max_tokens=500)

        # Verify max tokens is passed in config
        call_args = mock_gemini_client.aio.models.generate_content.call_args
        config = call_args[1]['config']
        assert config.max_output_tokens == 500

    @pytest.mark.asyncio
    async def test_model_size_selection(self, gemini_client, mock_gemini_client):
        """Test that the correct model is selected based on model size."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Test response'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method with small model size
        messages = [Message(role='user', content='Test message')]
        await gemini_client.generate_response(messages, model_size=ModelSize.small)

        # Verify correct model is used
        call_args = mock_gemini_client.aio.models.generate_content.call_args
        assert call_args[1]['model'] == DEFAULT_SMALL_MODEL


if __name__ == '__main__':
    pytest.main(['-v', 'test_gemini_client.py'])
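The tests above rely on three definitions that appear earlier in `test_gemini_client.py`: the `gemini_client` and `mock_gemini_client` fixtures and the `ResponseModel` Pydantic model. Those earlier definitions are authoritative; the following is only a rough, hypothetical sketch of what they could look like, inferred from the assertions above. The fixture bodies, the `client.client` attribute, and the `optional_field` default are assumptions.

```python
# Hypothetical sketch only; the real fixtures live earlier in test_gemini_client.py.
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel

from graphiti_core.llm_client.gemini_client import GeminiClient, LLMConfig


class ResponseModel(BaseModel):
    """Shape implied by the structured-output assertions above."""

    test_field: str
    optional_field: int = 0


@pytest.fixture
def mock_gemini_client():
    # Every test stubs the async generate_content call reached via client.aio.models.
    mock = MagicMock()
    mock.aio.models.generate_content = AsyncMock()
    return mock


@pytest.fixture
def gemini_client(mock_gemini_client):
    # Patch the real SDK client so constructing GeminiClient never touches the network,
    # then swap in the mock (assuming the SDK handle is stored on the `client` attribute).
    with patch('google.genai.Client'):
        client = GeminiClient(config=LLMConfig(api_key='test_api_key'), cache=False)
    client.client = mock_gemini_client
    return client
```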