Gemini client improvements; Gemini reranker (#645)

* add support for Gemini 2.5 model thinking budget

* allow adding thinking config to support current and future gemini models

* merge

* improve client; add reranker

* refactor: change type hint for gemini_messages to Any for flexibility

* refactor: update GeminiRerankerClient to use direct relevance scoring and improve ranking logic. Add tests

* fix fixtures

---------

Co-authored-by: realugbun <github.disorder751@passmail.net>
Daniel Chalef 2025-06-30 12:55:17 -07:00 committed by GitHub
parent daec70db65
commit 689d669559
7 changed files with 1335 additions and 68 deletions

README.md

@@ -248,7 +248,6 @@ graphiti = Graphiti(
),
client=azure_openai_client
),
# Optional: Configure the OpenAI cross encoder with Azure OpenAI
cross_encoder=OpenAIRerankerClient(
llm_config=azure_llm_config,
client=azure_openai_client
@@ -262,7 +261,7 @@ Make sure to replace the placeholder values with your actual Azure OpenAI creden
## Using Graphiti with Google Gemini
Graphiti supports Google's Gemini models for both LLM inference and embeddings. To use Gemini, you'll need to configure both the LLM client and embedder with your Google API key.
Graphiti supports Google's Gemini models for LLM inference, embeddings, and cross-encoding (reranking). To use Gemini, you'll need to configure the LLM client, embedder, and cross-encoder with your Google API key.
Install Graphiti:
@@ -278,6 +277,7 @@ pip install "graphiti-core[google-genai]"
from graphiti_core import Graphiti
from graphiti_core.llm_client.gemini_client import GeminiClient, LLMConfig
from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig
from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
# Google API key configuration
api_key = "<your-google-api-key>"
@@ -298,12 +298,20 @@ graphiti = Graphiti(
api_key=api_key,
embedding_model="embedding-001"
)
),
cross_encoder=GeminiRerankerClient(
config=LLMConfig(
api_key=api_key,
model="gemini-2.5-flash-lite-preview-06-17"
)
)
)
# Now you can use Graphiti with Google Gemini
# Now you can use Graphiti with Google Gemini for all components
```
The Gemini reranker uses the `gemini-2.5-flash-lite-preview-06-17` model by default, which is optimized for cost-effective, low-latency classification tasks. Because the Gemini API does not currently expose log probabilities, it takes a different approach from the OpenAI reranker: each passage is prompted for a direct relevance score from 0 to 100, the scores are normalized to [0, 1], and the passages are returned sorted by descending relevance.
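If you want to call the reranker directly (outside of a Graphiti instance), the following is a minimal sketch; the query and passages are illustrative, and an event loop plus a valid API key are assumed:
```python
import asyncio

from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client import LLMConfig


async def main() -> None:
    reranker = GeminiRerankerClient(config=LLMConfig(api_key="<your-google-api-key>"))
    # rank() prompts Gemini for a 0-100 relevance score per passage, normalizes
    # the scores to [0, 1], and returns (passage, score) pairs sorted by
    # descending relevance.
    ranked = await reranker.rank(
        query="What is the capital of France?",
        passages=[
            "Paris is the capital and most populous city of France.",
            "Berlin is the capital and largest city of Germany.",
        ],
    )
    for passage, score in ranked:
        print(f"{score:.2f}  {passage}")


asyncio.run(main())
```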
## Using Graphiti with Ollama (Local LLM)
Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible API. This is ideal for privacy-focused applications or when you want to avoid API costs.

graphiti_core/cross_encoder/__init__.py

@@ -15,6 +15,7 @@ limitations under the License.
"""
from .client import CrossEncoderClient
from .gemini_reranker_client import GeminiRerankerClient
from .openai_reranker_client import OpenAIRerankerClient
__all__ = ['CrossEncoderClient', 'OpenAIRerankerClient']
__all__ = ['CrossEncoderClient', 'GeminiRerankerClient', 'OpenAIRerankerClient']
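Since `GeminiRerankerClient` is now exported from this `__init__`, a usage note: the shorter package-level import below is equivalent to the full module path used in the README example.
```python
# Equivalent to: from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.cross_encoder import GeminiRerankerClient
```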

graphiti_core/cross_encoder/gemini_reranker_client.py

@@ -0,0 +1,146 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import re
from google import genai # type: ignore
from google.genai import types # type: ignore
from ..helpers import semaphore_gather
from ..llm_client import LLMConfig, RateLimitError
from .client import CrossEncoderClient
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
class GeminiRerankerClient(CrossEncoderClient):
def __init__(
self,
config: LLMConfig | None = None,
client: genai.Client | None = None,
):
"""
Initialize the GeminiRerankerClient with the provided configuration and client.
The Gemini Developer API does not yet support logprobs. Unlike the OpenAI reranker,
this reranker uses the Gemini API to perform direct relevance scoring of passages.
Each passage is scored individually on a 0-100 scale.
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
"""
if config is None:
config = LLMConfig()
self.config = config
if client is None:
self.client = genai.Client(api_key=config.api_key)
else:
self.client = client
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
"""
Rank passages based on their relevance to the query using direct scoring.
Each passage is scored individually on a 0-100 scale, then normalized to [0,1].
"""
if len(passages) <= 1:
return [(passage, 1.0) for passage in passages]
# Generate scoring prompts for each passage
scoring_prompts = []
for passage in passages:
prompt = f"""Rate how well this passage answers or relates to the query. Use a scale from 0 to 100.
Query: {query}
Passage: {passage}
Provide only a number between 0 and 100 (no explanation, just the number):"""
scoring_prompts.append(
[
types.Content(
role='user',
parts=[types.Part.from_text(text=prompt)],
),
]
)
try:
# Execute all scoring requests concurrently - O(n) API calls
responses = await semaphore_gather(
*[
self.client.aio.models.generate_content(
model=self.config.model or DEFAULT_MODEL,
contents=prompt_messages, # type: ignore
config=types.GenerateContentConfig(
system_instruction='You are an expert at rating passage relevance. Respond with only a number from 0-100.',
temperature=0.0,
max_output_tokens=3,
),
)
for prompt_messages in scoring_prompts
]
)
# Extract scores and create results
results = []
for passage, response in zip(passages, responses, strict=True):
try:
if hasattr(response, 'text') and response.text:
# Extract numeric score from response
score_text = response.text.strip()
# Handle cases where model might return non-numeric text
score_match = re.search(r'\b(\d{1,3})\b', score_text)
if score_match:
score = float(score_match.group(1))
# Normalize to [0, 1] range and clamp to valid range
normalized_score = max(0.0, min(1.0, score / 100.0))
results.append((passage, normalized_score))
else:
logger.warning(
f'Could not extract numeric score from response: {score_text}'
)
results.append((passage, 0.0))
else:
logger.warning('Empty response from Gemini for passage scoring')
results.append((passage, 0.0))
except (ValueError, AttributeError) as e:
logger.warning(f'Error parsing score from Gemini response: {e}')
results.append((passage, 0.0))
# Sort by score in descending order (highest relevance first)
results.sort(reverse=True, key=lambda x: x[1])
return results
except Exception as e:
# Check if it's a rate limit error based on Gemini API error codes
error_message = str(e).lower()
if (
'rate limit' in error_message
or 'quota' in error_message
or 'resource_exhausted' in error_message
or '429' in str(e)
):
raise RateLimitError from e
logger.error(f'Error in generating LLM response: {e}')
raise

graphiti_core/llm_client/gemini_client.py

@@ -17,19 +17,21 @@ limitations under the License.
import json
import logging
import typing
from typing import ClassVar
from google import genai # type: ignore
from google.genai import types # type: ignore
from pydantic import BaseModel
from ..prompts.models import Message
from .client import LLMClient
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gemini-2.0-flash'
DEFAULT_MODEL = 'gemini-2.5-flash'
DEFAULT_SMALL_MODEL = 'models/gemini-2.5-flash-lite-preview-06-17'
class GeminiClient(LLMClient):
@@ -43,27 +45,34 @@ class GeminiClient(LLMClient):
model (str): The model name to use for generating responses.
temperature (float): The temperature to use for generating responses.
max_tokens (int): The maximum number of tokens to generate in a response.
thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
Methods:
__init__(config: LLMConfig | None = None, cache: bool = False):
Initializes the GeminiClient with the provided configuration and cache setting.
__init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
Generates a response from the language model based on the provided messages.
"""
# Class-level constants
MAX_RETRIES: ClassVar[int] = 2
def __init__(
self,
config: LLMConfig | None = None,
cache: bool = False,
max_tokens: int = DEFAULT_MAX_TOKENS,
thinking_config: types.ThinkingConfig | None = None,
):
"""
Initialize the GeminiClient with the provided configuration and cache setting.
Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
cache (bool): Whether to use caching for responses. Defaults to False.
thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
Only use with models that support thinking (gemini-2.5+). Defaults to None.
"""
if config is None:
config = LLMConfig()
@@ -76,6 +85,50 @@ class GeminiClient(LLMClient):
api_key=config.api_key,
)
self.max_tokens = max_tokens
self.thinking_config = thinking_config
def _check_safety_blocks(self, response) -> None:
"""Check if response was blocked for safety reasons and raise appropriate exceptions."""
# Check if the response was blocked for safety reasons
if not (hasattr(response, 'candidates') and response.candidates):
return
candidate = response.candidates[0]
if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
return
# Content was blocked for safety reasons - collect safety details
safety_info = []
safety_ratings = getattr(candidate, 'safety_ratings', None)
if safety_ratings:
for rating in safety_ratings:
if getattr(rating, 'blocked', False):
category = getattr(rating, 'category', 'Unknown')
probability = getattr(rating, 'probability', 'Unknown')
safety_info.append(f'{category}: {probability}')
safety_details = (
', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
)
raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')
def _check_prompt_blocks(self, response) -> None:
"""Check if prompt was blocked and raise appropriate exceptions."""
prompt_feedback = getattr(response, 'prompt_feedback', None)
if not prompt_feedback:
return
block_reason = getattr(prompt_feedback, 'block_reason', None)
if block_reason:
raise Exception(f'Prompt blocked by Gemini: {block_reason}')
def _get_model_for_size(self, model_size: ModelSize) -> str:
"""Get the appropriate model name based on the requested size."""
if model_size == ModelSize.small:
return self.small_model or DEFAULT_SMALL_MODEL
else:
return self.model or DEFAULT_MODEL
async def _generate_response(
self,
@@ -91,17 +144,17 @@ class GeminiClient(LLMClient):
messages (list[Message]): A list of messages to send to the language model.
response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
max_tokens (int): The maximum number of tokens to generate in the response.
model_size (ModelSize): The size of the model to use (small or medium).
Returns:
dict[str, typing.Any]: The response from the language model.
Raises:
RateLimitError: If the API rate limit is exceeded.
RefusalError: If the content is blocked by the model.
Exception: If there is an error generating the response.
Exception: If there is an error generating the response or content is blocked.
"""
try:
gemini_messages: list[types.Content] = []
gemini_messages: typing.Any = []
# If a response model is provided, add schema for structured output
system_prompt = ''
if response_model is not None:
@@ -127,6 +180,9 @@ class GeminiClient(LLMClient):
types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
)
# Get the appropriate model for the requested size
model = self._get_model_for_size(model_size)
# Create generation config
generation_config = types.GenerateContentConfig(
temperature=self.temperature,
@@ -134,15 +190,20 @@ class GeminiClient(LLMClient):
response_mime_type='application/json' if response_model else None,
response_schema=response_model if response_model else None,
system_instruction=system_prompt,
thinking_config=self.thinking_config,
)
# Generate content using the simple string approach
response = await self.client.aio.models.generate_content(
model=self.model or DEFAULT_MODEL,
contents=gemini_messages, # type: ignore[arg-type] # mypy fails on broad union type
model=model,
contents=gemini_messages,
config=generation_config,
)
# Check for safety and prompt blocks
self._check_safety_blocks(response)
self._check_prompt_blocks(response)
# If this was a structured output request, parse the response into the Pydantic model
if response_model is not None:
try:
@@ -160,9 +221,16 @@ class GeminiClient(LLMClient):
return {'content': response.text}
except Exception as e:
# Check if it's a rate limit error
if 'rate limit' in str(e).lower() or 'quota' in str(e).lower():
# Check if it's a rate limit error based on Gemini API error codes
error_message = str(e).lower()
if (
'rate limit' in error_message
or 'quota' in error_message
or 'resource_exhausted' in error_message
or '429' in str(e)
):
raise RateLimitError from e
logger.error(f'Error in generating LLM response: {e}')
raise
@@ -174,13 +242,14 @@ class GeminiClient(LLMClient):
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
"""
Generate a response from the Gemini language model.
This method overrides the parent class method to provide a direct implementation.
Generate a response from the Gemini language model with retry logic and error handling.
This method overrides the parent class method to provide a direct implementation with advanced retry logic.
Args:
messages (list[Message]): A list of messages to send to the language model.
response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
max_tokens (int): The maximum number of tokens to generate in the response.
max_tokens (int | None): The maximum number of tokens to generate in the response.
model_size (ModelSize): The size of the model to use (small or medium).
Returns:
dict[str, typing.Any]: The response from the language model.
@@ -188,10 +257,53 @@ class GeminiClient(LLMClient):
if max_tokens is None:
max_tokens = self.max_tokens
# Call the internal _generate_response method
return await self._generate_response(
messages=messages,
response_model=response_model,
max_tokens=max_tokens,
model_size=model_size,
)
retry_count = 0
last_error = None
# Add multilingual extraction instructions
messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
while retry_count <= self.MAX_RETRIES:
try:
response = await self._generate_response(
messages=messages,
response_model=response_model,
max_tokens=max_tokens,
model_size=model_size,
)
return response
except RateLimitError:
# Rate limit errors should not trigger retries (fail fast)
raise
except Exception as e:
last_error = e
# Check if this is a safety block - these typically shouldn't be retried
if 'safety' in str(e).lower() or 'blocked' in str(e).lower():
logger.warning(f'Content blocked by safety filters: {e}')
raise
# Don't retry if we've hit the max retries
if retry_count >= self.MAX_RETRIES:
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
raise
retry_count += 1
# Construct a detailed error message for the LLM
error_context = (
f'The previous response attempt was invalid. '
f'Error type: {e.__class__.__name__}. '
f'Error details: {str(e)}. '
f'Please try again with a valid response, ensuring the output matches '
f'the expected format and constraints.'
)
error_message = Message(role='user', content=error_context)
messages.append(error_message)
logger.warning(
f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
)
# If we somehow get here, raise the last error
raise last_error or Exception('Max retries exceeded with no specific error')
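As a usage note for the new `thinking_config` parameter above: it is forwarded unchanged to `GenerateContentConfig` on every request. A minimal sketch of enabling a thinking budget on a Gemini 2.5 model follows; the budget value is an illustrative assumption, and a thinking config should only be passed for models that support thinking (gemini-2.5+).
```python
from google.genai import types

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient

# The thinking_config is applied to every generate_content call made by this client.
client = GeminiClient(
    config=LLMConfig(api_key="<your-google-api-key>", model="gemini-2.5-flash"),
    thinking_config=types.ThinkingConfig(thinking_budget=1024),
)
```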

tests/cross_encoder/test_gemini_reranker_client.py

@@ -0,0 +1,353 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Running tests: pytest -xvs tests/cross_encoder/test_gemini_reranker_client.py
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client import LLMConfig, RateLimitError
@pytest.fixture
def mock_gemini_client():
"""Fixture to mock the Google Gemini client."""
with patch('google.genai.Client') as mock_client:
# Setup mock instance and its methods
mock_instance = mock_client.return_value
mock_instance.aio = MagicMock()
mock_instance.aio.models = MagicMock()
mock_instance.aio.models.generate_content = AsyncMock()
yield mock_instance
@pytest.fixture
def gemini_reranker_client(mock_gemini_client):
"""Fixture to create a GeminiRerankerClient with a mocked client."""
config = LLMConfig(api_key='test_api_key', model='test-model')
client = GeminiRerankerClient(config=config)
# Replace the client's client with our mock to ensure we're using the mock
client.client = mock_gemini_client
return client
def create_mock_response(score_text: str) -> MagicMock:
"""Helper function to create a mock Gemini response."""
mock_response = MagicMock()
mock_response.text = score_text
return mock_response
class TestGeminiRerankerClientInitialization:
"""Tests for GeminiRerankerClient initialization."""
def test_init_with_config(self):
"""Test initialization with a config object."""
config = LLMConfig(api_key='test_api_key', model='test-model')
client = GeminiRerankerClient(config=config)
assert client.config == config
@patch('google.genai.Client')
def test_init_without_config(self, mock_client):
"""Test initialization without a config uses defaults."""
client = GeminiRerankerClient()
assert client.config is not None
def test_init_with_custom_client(self):
"""Test initialization with a custom client."""
mock_client = MagicMock()
client = GeminiRerankerClient(client=mock_client)
assert client.client == mock_client
class TestGeminiRerankerClientRanking:
"""Tests for GeminiRerankerClient rank method."""
@pytest.mark.asyncio
async def test_rank_basic_functionality(self, gemini_reranker_client, mock_gemini_client):
"""Test basic ranking functionality."""
# Setup mock responses with different scores
mock_responses = [
create_mock_response('85'), # High relevance
create_mock_response('45'), # Medium relevance
create_mock_response('20'), # Low relevance
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
# Test data
query = 'What is the capital of France?'
passages = [
'Paris is the capital and most populous city of France.',
'London is the capital city of England and the United Kingdom.',
'Berlin is the capital and largest city of Germany.',
]
# Call method
result = await gemini_reranker_client.rank(query, passages)
# Assertions
assert len(result) == 3
assert all(isinstance(item, tuple) for item in result)
assert all(
isinstance(passage, str) and isinstance(score, float) for passage, score in result
)
# Check scores are normalized to [0, 1] and sorted in descending order
scores = [score for _, score in result]
assert all(0.0 <= score <= 1.0 for score in scores)
assert scores == sorted(scores, reverse=True)
# Check that the highest scoring passage is first
assert result[0][1] == 0.85 # 85/100
assert result[1][1] == 0.45 # 45/100
assert result[2][1] == 0.20 # 20/100
@pytest.mark.asyncio
async def test_rank_empty_passages(self, gemini_reranker_client):
"""Test ranking with empty passages list."""
query = 'Test query'
passages = []
result = await gemini_reranker_client.rank(query, passages)
assert result == []
@pytest.mark.asyncio
async def test_rank_single_passage(self, gemini_reranker_client, mock_gemini_client):
"""Test ranking with a single passage."""
# Setup mock response
mock_gemini_client.aio.models.generate_content.return_value = create_mock_response('75')
query = 'Test query'
passages = ['Single test passage']
result = await gemini_reranker_client.rank(query, passages)
assert len(result) == 1
assert result[0][0] == 'Single test passage'
assert result[0][1] == 1.0 # Single passage gets full score
@pytest.mark.asyncio
async def test_rank_score_extraction_with_regex(
self, gemini_reranker_client, mock_gemini_client
):
"""Test score extraction from various response formats."""
# Setup mock responses with different formats
mock_responses = [
create_mock_response('Score: 90'), # Contains text before number
create_mock_response('The relevance is 65 out of 100'), # Contains text around number
create_mock_response('8'), # Just the number
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
result = await gemini_reranker_client.rank(query, passages)
# Check that scores were extracted correctly and normalized
scores = [score for _, score in result]
assert 0.90 in scores # 90/100
assert 0.65 in scores # 65/100
assert 0.08 in scores # 8/100
@pytest.mark.asyncio
async def test_rank_invalid_score_handling(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of invalid or non-numeric scores."""
# Setup mock responses with invalid scores
mock_responses = [
create_mock_response('Not a number'), # Invalid response
create_mock_response(''), # Empty response
create_mock_response('95'), # Valid response
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
result = await gemini_reranker_client.rank(query, passages)
# Check that invalid scores are handled gracefully (assigned 0.0)
scores = [score for _, score in result]
assert 0.95 in scores # Valid score
assert scores.count(0.0) == 2 # Two invalid scores assigned 0.0
@pytest.mark.asyncio
async def test_rank_score_clamping(self, gemini_reranker_client, mock_gemini_client):
"""Test that scores are properly clamped to [0, 1] range."""
# Setup mock responses with extreme scores
# Note: regex only matches 1-3 digits, so negative numbers won't match
mock_responses = [
create_mock_response('999'), # Above 100 but within regex range
create_mock_response('invalid'), # Invalid response becomes 0.0
create_mock_response('50'), # Normal score
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
result = await gemini_reranker_client.rank(query, passages)
# Check that scores are normalized and clamped
scores = [score for _, score in result]
assert all(0.0 <= score <= 1.0 for score in scores)
# 999 should be clamped to 1.0 (999/100 = 9.99, clamped to 1.0)
assert 1.0 in scores
# Invalid response should be 0.0
assert 0.0 in scores
# Normal score should be normalized (50/100 = 0.5)
assert 0.5 in scores
@pytest.mark.asyncio
async def test_rank_rate_limit_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of rate limit errors."""
# Setup mock to raise rate limit error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'Rate limit exceeded'
)
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_quota_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of quota errors."""
# Setup mock to raise quota error
mock_gemini_client.aio.models.generate_content.side_effect = Exception('Quota exceeded')
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_resource_exhausted_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of resource exhausted errors."""
# Setup mock to raise resource exhausted error
mock_gemini_client.aio.models.generate_content.side_effect = Exception('resource_exhausted')
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_429_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of HTTP 429 errors."""
# Setup mock to raise 429 error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'HTTP 429 Too Many Requests'
)
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_generic_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of generic errors."""
# Setup mock to raise generic error
mock_gemini_client.aio.models.generate_content.side_effect = Exception('Generic error')
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(Exception) as exc_info:
await gemini_reranker_client.rank(query, passages)
assert 'Generic error' in str(exc_info.value)
@pytest.mark.asyncio
async def test_rank_concurrent_requests(self, gemini_reranker_client, mock_gemini_client):
"""Test that multiple passages are scored concurrently."""
# Setup mock responses
mock_responses = [
create_mock_response('80'),
create_mock_response('60'),
create_mock_response('40'),
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
await gemini_reranker_client.rank(query, passages)
# Verify that generate_content was called for each passage
assert mock_gemini_client.aio.models.generate_content.call_count == 3
# Verify that all calls were made with correct parameters
calls = mock_gemini_client.aio.models.generate_content.call_args_list
for call in calls:
args, kwargs = call
assert kwargs['model'] == gemini_reranker_client.config.model
assert kwargs['config'].temperature == 0.0
assert kwargs['config'].max_output_tokens == 3
@pytest.mark.asyncio
async def test_rank_response_parsing_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of response parsing errors."""
# Setup mock responses that will trigger ValueError during parsing
mock_responses = [
create_mock_response('not a number at all'), # Will fail regex match
create_mock_response('also invalid text'), # Will fail regex match
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
# Use multiple passages to avoid the single passage special case
passages = ['Passage 1', 'Passage 2']
result = await gemini_reranker_client.rank(query, passages)
# Should handle the error gracefully and assign 0.0 score to both
assert len(result) == 2
assert all(score == 0.0 for _, score in result)
@pytest.mark.asyncio
async def test_rank_empty_response_text(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of empty response text."""
# Setup mock response with empty text
mock_response = MagicMock()
mock_response.text = '' # Empty string instead of None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
query = 'Test query'
# Use multiple passages to avoid the single passage special case
passages = ['Passage 1', 'Passage 2']
result = await gemini_reranker_client.rank(query, passages)
# Should handle empty text gracefully and assign 0.0 score to both
assert len(result) == 2
assert all(score == 0.0 for _, score in result)
if __name__ == '__main__':
pytest.main(['-v', 'test_gemini_reranker_client.py'])

tests/embedder/test_gemini.py

@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
# Running tests: pytest -xvs tests/embedder/test_gemini.py
from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
@@ -28,10 +30,10 @@ from graphiti_core.embedder.gemini import (
from tests.embedder.embedder_fixtures import create_embedding_values
def create_gemini_embedding(multiplier: float = 0.1) -> MagicMock:
"""Create a mock Gemini embedding with specified value multiplier."""
def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock:
"""Create a mock Gemini embedding with specified value multiplier and dimension."""
mock_embedding = MagicMock()
mock_embedding.values = create_embedding_values(multiplier)
mock_embedding.values = create_embedding_values(multiplier, dimension)
return mock_embedding
@@ -75,52 +77,304 @@ def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
return client
@pytest.mark.asyncio
async def test_create_calls_api_correctly(
gemini_embedder: GeminiEmbedder, mock_gemini_client: Any, mock_gemini_response: MagicMock
) -> None:
"""Test that create method correctly calls the API and processes the response."""
# Setup
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
class TestGeminiEmbedderInitialization:
"""Tests for GeminiEmbedder initialization."""
# Call method
result = await gemini_embedder.create('Test input')
@patch('google.genai.Client')
def test_init_with_config(self, mock_client):
"""Test initialization with a config object."""
config = GeminiEmbedderConfig(
api_key='test_api_key', embedding_model='custom-model', embedding_dim=768
)
embedder = GeminiEmbedder(config=config)
# Verify API is called with correct parameters
mock_gemini_client.aio.models.embed_content.assert_called_once()
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
assert kwargs['contents'] == ['Test input']
assert embedder.config == config
assert embedder.config.embedding_model == 'custom-model'
assert embedder.config.api_key == 'test_api_key'
assert embedder.config.embedding_dim == 768
# Verify result is processed correctly
assert result == mock_gemini_response.embeddings[0].values
@patch('google.genai.Client')
def test_init_without_config(self, mock_client):
"""Test initialization without a config uses defaults."""
embedder = GeminiEmbedder()
assert embedder.config is not None
assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
@patch('google.genai.Client')
def test_init_with_partial_config(self, mock_client):
"""Test initialization with partial config."""
config = GeminiEmbedderConfig(api_key='test_api_key')
embedder = GeminiEmbedder(config=config)
assert embedder.config.api_key == 'test_api_key'
assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
@pytest.mark.asyncio
async def test_create_batch_processes_multiple_inputs(
gemini_embedder: GeminiEmbedder, mock_gemini_client: Any, mock_gemini_batch_response: MagicMock
) -> None:
"""Test that create_batch method correctly processes multiple inputs."""
# Setup
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response
input_batch = ['Input 1', 'Input 2', 'Input 3']
class TestGeminiEmbedderCreate:
"""Tests for GeminiEmbedder create method."""
# Call method
result = await gemini_embedder.create_batch(input_batch)
@pytest.mark.asyncio
async def test_create_calls_api_correctly(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_response: MagicMock,
) -> None:
"""Test that create method correctly calls the API and processes the response."""
# Setup
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
# Verify API is called with correct parameters
mock_gemini_client.aio.models.embed_content.assert_called_once()
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
assert kwargs['contents'] == input_batch
# Call method
result = await gemini_embedder.create('Test input')
# Verify all results are processed correctly
assert len(result) == 3
assert result == [
mock_gemini_batch_response.embeddings[0].values,
mock_gemini_batch_response.embeddings[1].values,
mock_gemini_batch_response.embeddings[2].values,
]
# Verify API is called with correct parameters
mock_gemini_client.aio.models.embed_content.assert_called_once()
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
assert kwargs['contents'] == ['Test input']
# Verify result is processed correctly
assert result == mock_gemini_response.embeddings[0].values
@pytest.mark.asyncio
@patch('google.genai.Client')
async def test_create_with_custom_model(
self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock
) -> None:
"""Test create method with custom embedding model."""
# Setup embedder with custom model
config = GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model')
embedder = GeminiEmbedder(config=config)
embedder.client = mock_gemini_client
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
# Call method
await embedder.create('Test input')
# Verify custom model is used
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == 'custom-model'
@pytest.mark.asyncio
@patch('google.genai.Client')
async def test_create_with_custom_dimension(
self, mock_client_class, mock_gemini_client: Any
) -> None:
"""Test create method with custom embedding dimension."""
# Setup embedder with custom dimension
config = GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768)
embedder = GeminiEmbedder(config=config)
embedder.client = mock_gemini_client
# Setup mock response with custom dimension
mock_response = MagicMock()
mock_response.embeddings = [create_gemini_embedding(0.1, 768)]
mock_gemini_client.aio.models.embed_content.return_value = mock_response
# Call method
result = await embedder.create('Test input')
# Verify custom dimension is used in config
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['config'].output_dimensionality == 768
# Verify result has correct dimension
assert len(result) == 768
@pytest.mark.asyncio
async def test_create_with_different_input_types(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_response: MagicMock,
) -> None:
"""Test create method with different input types."""
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
# Test with string
await gemini_embedder.create('Test string')
# Test with list of strings
await gemini_embedder.create(['Test', 'List'])
# Test with iterable of integers
await gemini_embedder.create([1, 2, 3])
# Verify all calls were made
assert mock_gemini_client.aio.models.embed_content.call_count == 3
@pytest.mark.asyncio
async def test_create_no_embeddings_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create method handling of no embeddings response."""
# Setup mock response with no embeddings
mock_response = MagicMock()
mock_response.embeddings = []
mock_gemini_client.aio.models.embed_content.return_value = mock_response
# Call method and expect exception
with pytest.raises(ValueError) as exc_info:
await gemini_embedder.create('Test input')
assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
@pytest.mark.asyncio
async def test_create_no_values_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create method handling of embeddings with no values."""
# Setup mock response with embedding but no values
mock_embedding = MagicMock()
mock_embedding.values = None
mock_response = MagicMock()
mock_response.embeddings = [mock_embedding]
mock_gemini_client.aio.models.embed_content.return_value = mock_response
# Call method and expect exception
with pytest.raises(ValueError) as exc_info:
await gemini_embedder.create('Test input')
assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
class TestGeminiEmbedderCreateBatch:
"""Tests for GeminiEmbedder create_batch method."""
@pytest.mark.asyncio
async def test_create_batch_processes_multiple_inputs(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_batch_response: MagicMock,
) -> None:
"""Test that create_batch method correctly processes multiple inputs."""
# Setup
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response
input_batch = ['Input 1', 'Input 2', 'Input 3']
# Call method
result = await gemini_embedder.create_batch(input_batch)
# Verify API is called with correct parameters
mock_gemini_client.aio.models.embed_content.assert_called_once()
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
assert kwargs['contents'] == input_batch
# Verify all results are processed correctly
assert len(result) == 3
assert result == [
mock_gemini_batch_response.embeddings[0].values,
mock_gemini_batch_response.embeddings[1].values,
mock_gemini_batch_response.embeddings[2].values,
]
@pytest.mark.asyncio
async def test_create_batch_single_input(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_response: MagicMock,
) -> None:
"""Test create_batch method with single input."""
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
input_batch = ['Single input']
result = await gemini_embedder.create_batch(input_batch)
assert len(result) == 1
assert result[0] == mock_gemini_response.embeddings[0].values
@pytest.mark.asyncio
async def test_create_batch_empty_input(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create_batch method with empty input."""
# Setup mock response with no embeddings
mock_response = MagicMock()
mock_response.embeddings = []
mock_gemini_client.aio.models.embed_content.return_value = mock_response
input_batch = []
with pytest.raises(Exception) as exc_info:
await gemini_embedder.create_batch(input_batch)
assert 'No embeddings returned' in str(exc_info.value)
@pytest.mark.asyncio
async def test_create_batch_no_embeddings_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create_batch method handling of no embeddings response."""
# Setup mock response with no embeddings
mock_response = MagicMock()
mock_response.embeddings = []
mock_gemini_client.aio.models.embed_content.return_value = mock_response
input_batch = ['Input 1', 'Input 2']
with pytest.raises(Exception) as exc_info:
await gemini_embedder.create_batch(input_batch)
assert 'No embeddings returned' in str(exc_info.value)
@pytest.mark.asyncio
async def test_create_batch_empty_values_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create_batch method handling of embeddings with empty values."""
# Setup mock response with embeddings but empty values
mock_embedding1 = MagicMock()
mock_embedding1.values = [0.1, 0.2, 0.3] # Valid values
mock_embedding2 = MagicMock()
mock_embedding2.values = None # Empty values
mock_response = MagicMock()
mock_response.embeddings = [mock_embedding1, mock_embedding2]
mock_gemini_client.aio.models.embed_content.return_value = mock_response
input_batch = ['Input 1', 'Input 2']
with pytest.raises(ValueError) as exc_info:
await gemini_embedder.create_batch(input_batch)
assert 'Empty embedding values returned' in str(exc_info.value)
@pytest.mark.asyncio
@patch('google.genai.Client')
async def test_create_batch_with_custom_model_and_dimension(
self, mock_client_class, mock_gemini_client: Any
) -> None:
"""Test create_batch method with custom model and dimension."""
# Setup embedder with custom settings
config = GeminiEmbedderConfig(
api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512
)
embedder = GeminiEmbedder(config=config)
embedder.client = mock_gemini_client
# Setup mock response
mock_response = MagicMock()
mock_response.embeddings = [
create_gemini_embedding(0.1, 512),
create_gemini_embedding(0.2, 512),
]
mock_gemini_client.aio.models.embed_content.return_value = mock_response
input_batch = ['Input 1', 'Input 2']
result = await embedder.create_batch(input_batch)
# Verify custom settings are used
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == 'custom-batch-model'
assert kwargs['config'].output_dimensionality == 512
# Verify results have correct dimension
assert len(result) == 2
assert all(len(embedding) == 512 for embedding in result)
if __name__ == '__main__':

tests/llm_client/test_gemini_client.py

@@ -0,0 +1,393 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Running tests: pytest -xvs tests/llm_client/test_gemini_client.py
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import BaseModel
from graphiti_core.llm_client.config import LLMConfig, ModelSize
from graphiti_core.llm_client.errors import RateLimitError
from graphiti_core.llm_client.gemini_client import DEFAULT_MODEL, DEFAULT_SMALL_MODEL, GeminiClient
from graphiti_core.prompts.models import Message
# Test model for response testing
class ResponseModel(BaseModel):
"""Test model for response testing."""
test_field: str
optional_field: int = 0
@pytest.fixture
def mock_gemini_client():
"""Fixture to mock the Google Gemini client."""
with patch('google.genai.Client') as mock_client:
# Setup mock instance and its methods
mock_instance = mock_client.return_value
mock_instance.aio = MagicMock()
mock_instance.aio.models = MagicMock()
mock_instance.aio.models.generate_content = AsyncMock()
yield mock_instance
@pytest.fixture
def gemini_client(mock_gemini_client):
"""Fixture to create a GeminiClient with a mocked client."""
config = LLMConfig(api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000)
client = GeminiClient(config=config, cache=False)
# Replace the client's client with our mock to ensure we're using the mock
client.client = mock_gemini_client
return client
class TestGeminiClientInitialization:
"""Tests for GeminiClient initialization."""
@patch('google.genai.Client')
def test_init_with_config(self, mock_client):
"""Test initialization with a config object."""
config = LLMConfig(
api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
)
client = GeminiClient(config=config, cache=False, max_tokens=1000)
assert client.config == config
assert client.model == 'test-model'
assert client.temperature == 0.5
assert client.max_tokens == 1000
@patch('google.genai.Client')
def test_init_with_default_model(self, mock_client):
"""Test initialization with default model when none is provided."""
config = LLMConfig(api_key='test_api_key', model=DEFAULT_MODEL)
client = GeminiClient(config=config, cache=False)
assert client.model == DEFAULT_MODEL
@patch('google.genai.Client')
def test_init_without_config(self, mock_client):
"""Test initialization without a config uses defaults."""
client = GeminiClient(cache=False)
assert client.config is not None
# When no config.model is set, it will be None, not DEFAULT_MODEL
assert client.model is None
@patch('google.genai.Client')
def test_init_with_thinking_config(self, mock_client):
"""Test initialization with thinking config."""
with patch('google.genai.types.ThinkingConfig') as mock_thinking_config:
thinking_config = mock_thinking_config.return_value
client = GeminiClient(thinking_config=thinking_config)
assert client.thinking_config == thinking_config
class TestGeminiClientGenerateResponse:
"""Tests for GeminiClient generate_response method."""
@pytest.mark.asyncio
async def test_generate_response_simple_text(self, gemini_client, mock_gemini_client):
"""Test successful response generation with simple text."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = 'Test response text'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method
messages = [Message(role='user', content='Test message')]
result = await gemini_client.generate_response(messages)
# Assertions
assert isinstance(result, dict)
assert result['content'] == 'Test response text'
mock_gemini_client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_generate_response_with_structured_output(
self, gemini_client, mock_gemini_client
):
"""Test response generation with structured output."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = '{"test_field": "test_value", "optional_field": 42}'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method
messages = [
Message(role='system', content='System message'),
Message(role='user', content='User message'),
]
result = await gemini_client.generate_response(
messages=messages, response_model=ResponseModel
)
# Assertions
assert isinstance(result, dict)
assert result['test_field'] == 'test_value'
assert result['optional_field'] == 42
mock_gemini_client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_generate_response_with_system_message(self, gemini_client, mock_gemini_client):
"""Test response generation with system message handling."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = 'Response with system context'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method
messages = [
Message(role='system', content='System message'),
Message(role='user', content='User message'),
]
await gemini_client.generate_response(messages)
# Verify system message is processed correctly
call_args = mock_gemini_client.aio.models.generate_content.call_args
config = call_args[1]['config']
assert 'System message' in config.system_instruction
@pytest.mark.asyncio
async def test_get_model_for_size(self, gemini_client):
"""Test model selection based on size."""
# Test small model
small_model = gemini_client._get_model_for_size(ModelSize.small)
assert small_model == DEFAULT_SMALL_MODEL
# Test medium/large model
medium_model = gemini_client._get_model_for_size(ModelSize.medium)
assert medium_model == gemini_client.model
@pytest.mark.asyncio
async def test_rate_limit_error_handling(self, gemini_client, mock_gemini_client):
"""Test handling of rate limit errors."""
# Setup mock to raise rate limit error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'Rate limit exceeded'
)
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(RateLimitError):
await gemini_client.generate_response(messages)
@pytest.mark.asyncio
async def test_quota_error_handling(self, gemini_client, mock_gemini_client):
"""Test handling of quota errors."""
# Setup mock to raise quota error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'Quota exceeded for requests'
)
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(RateLimitError):
await gemini_client.generate_response(messages)
@pytest.mark.asyncio
async def test_resource_exhausted_error_handling(self, gemini_client, mock_gemini_client):
"""Test handling of resource exhausted errors."""
# Setup mock to raise resource exhausted error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'resource_exhausted: Request limit exceeded'
)
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(RateLimitError):
await gemini_client.generate_response(messages)
@pytest.mark.asyncio
async def test_safety_block_handling(self, gemini_client, mock_gemini_client):
"""Test handling of safety blocks."""
# Setup mock response with safety block
mock_candidate = MagicMock()
mock_candidate.finish_reason = 'SAFETY'
mock_candidate.safety_ratings = [
MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
]
mock_response = MagicMock()
mock_response.candidates = [mock_candidate]
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception, match='Response blocked by Gemini safety filters'):
await gemini_client.generate_response(messages)
@pytest.mark.asyncio
async def test_prompt_block_handling(self, gemini_client, mock_gemini_client):
"""Test handling of prompt blocks."""
# Setup mock response with prompt block
mock_prompt_feedback = MagicMock()
mock_prompt_feedback.block_reason = 'BLOCKED_REASON_OTHER'
mock_response = MagicMock()
mock_response.candidates = []
mock_response.prompt_feedback = mock_prompt_feedback
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception, match='Prompt blocked by Gemini'):
await gemini_client.generate_response(messages)
@pytest.mark.asyncio
async def test_structured_output_parsing_error(self, gemini_client, mock_gemini_client):
"""Test handling of structured output parsing errors."""
# Setup mock response with invalid JSON
mock_response = MagicMock()
mock_response.text = 'Invalid JSON response'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception, match='Failed to parse structured response'):
await gemini_client.generate_response(messages, response_model=ResponseModel)
@pytest.mark.asyncio
async def test_retry_logic_with_safety_block(self, gemini_client, mock_gemini_client):
"""Test that safety blocks are not retried."""
# Setup mock to raise safety error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'Content blocked by safety filters'
)
# Call method and check that it doesn't retry
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception, match='Content blocked by safety filters'):
await gemini_client.generate_response(messages)
# Should only be called once (no retries for safety blocks)
assert mock_gemini_client.aio.models.generate_content.call_count == 1
@pytest.mark.asyncio
async def test_retry_logic_with_validation_error(self, gemini_client, mock_gemini_client):
"""Test retry behavior on validation error."""
# First call returns invalid data, second call returns valid data
mock_response1 = MagicMock()
mock_response1.text = '{"wrong_field": "wrong_value"}'
mock_response1.candidates = []
mock_response1.prompt_feedback = None
mock_response2 = MagicMock()
mock_response2.text = '{"test_field": "correct_value"}'
mock_response2.candidates = []
mock_response2.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.side_effect = [
mock_response1,
mock_response2,
]
# Call method
messages = [Message(role='user', content='Test message')]
result = await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have called generate_content twice due to retry
assert mock_gemini_client.aio.models.generate_content.call_count == 2
assert result['test_field'] == 'correct_value'
@pytest.mark.asyncio
async def test_max_retries_exceeded(self, gemini_client, mock_gemini_client):
"""Test behavior when max retries are exceeded."""
# Setup mock to always return invalid data
mock_response = MagicMock()
mock_response.text = '{"wrong_field": "wrong_value"}'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception, match='Failed to parse structured response'):
await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have called generate_content MAX_RETRIES + 1 times
assert (
mock_gemini_client.aio.models.generate_content.call_count
== GeminiClient.MAX_RETRIES + 1
)
@pytest.mark.asyncio
async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
"""Test handling of empty responses."""
# Setup mock response with no text
mock_response = MagicMock()
mock_response.text = ''
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method with structured output and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception, match='Failed to parse structured response'):
await gemini_client.generate_response(messages, response_model=ResponseModel)
@pytest.mark.asyncio
async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
"""Test response generation with custom max tokens."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = 'Test response'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method with custom max tokens
messages = [Message(role='user', content='Test message')]
await gemini_client.generate_response(messages, max_tokens=500)
# Verify max tokens is passed in config
call_args = mock_gemini_client.aio.models.generate_content.call_args
config = call_args[1]['config']
assert config.max_output_tokens == 500
@pytest.mark.asyncio
async def test_model_size_selection(self, gemini_client, mock_gemini_client):
"""Test that the correct model is selected based on model size."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = 'Test response'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method with small model size
messages = [Message(role='user', content='Test message')]
await gemini_client.generate_response(messages, model_size=ModelSize.small)
# Verify correct model is used
call_args = mock_gemini_client.aio.models.generate_content.call_args
assert call_args[1]['model'] == DEFAULT_SMALL_MODEL
if __name__ == '__main__':
pytest.main(['-v', 'test_gemini_client.py'])