diff --git a/README.md b/README.md index fdf4b0b4..c458249e 100644 --- a/README.md +++ b/README.md @@ -248,7 +248,6 @@ graphiti = Graphiti( ), client=azure_openai_client ), - # Optional: Configure the OpenAI cross encoder with Azure OpenAI cross_encoder=OpenAIRerankerClient( llm_config=azure_llm_config, client=azure_openai_client @@ -262,7 +261,7 @@ Make sure to replace the placeholder values with your actual Azure OpenAI creden ## Using Graphiti with Google Gemini -Graphiti supports Google's Gemini models for both LLM inference and embeddings. To use Gemini, you'll need to configure both the LLM client and embedder with your Google API key. +Graphiti supports Google's Gemini models for LLM inference, embeddings, and cross-encoding/reranking. To use Gemini, you'll need to configure the LLM client, embedder, and cross-encoder with your Google API key. Install Graphiti: @@ -278,6 +277,7 @@ pip install "graphiti-core[google-genai]" from graphiti_core import Graphiti from graphiti_core.llm_client.gemini_client import GeminiClient, LLMConfig from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig +from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient # Google API key configuration api_key = "" @@ -298,12 +298,20 @@ graphiti = Graphiti( api_key=api_key, embedding_model="embedding-001" ) + ), + cross_encoder=GeminiRerankerClient( + config=LLMConfig( + api_key=api_key, + model="gemini-2.5-flash-lite-preview-06-17" + ) ) ) -# Now you can use Graphiti with Google Gemini +# Now you can use Graphiti with Google Gemini for all components ``` +The Gemini reranker defaults to the `gemini-2.5-flash-lite-preview-06-17` model, which is optimized for cost-effective, low-latency scoring. Because the Gemini Developer API does not yet expose log probabilities, it does not use the OpenAI reranker's boolean-classification approach; instead, it prompts the model to rate each passage's relevance to the query on a 0-100 scale and normalizes these scores to rank the passages. + ## Using Graphiti with Ollama (Local LLM) Graphiti supports Ollama for running local LLMs and embedding models via Ollama's OpenAI-compatible API. This is ideal for privacy-focused applications or when you want to avoid API costs. diff --git a/graphiti_core/cross_encoder/__init__.py b/graphiti_core/cross_encoder/__init__.py index 64a231cf..d4fb7281 100644 --- a/graphiti_core/cross_encoder/__init__.py +++ b/graphiti_core/cross_encoder/__init__.py @@ -15,6 +15,7 @@ limitations under the License. """ from .client import CrossEncoderClient +from .gemini_reranker_client import GeminiRerankerClient from .openai_reranker_client import OpenAIRerankerClient -__all__ = ['CrossEncoderClient', 'OpenAIRerankerClient'] +__all__ = ['CrossEncoderClient', 'GeminiRerankerClient', 'OpenAIRerankerClient'] diff --git a/graphiti_core/cross_encoder/gemini_reranker_client.py b/graphiti_core/cross_encoder/gemini_reranker_client.py new file mode 100644 index 00000000..99c6e9b0 --- /dev/null +++ b/graphiti_core/cross_encoder/gemini_reranker_client.py @@ -0,0 +1,146 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import logging +import re + +from google import genai # type: ignore +from google.genai import types # type: ignore + +from ..helpers import semaphore_gather +from ..llm_client import LLMConfig, RateLimitError +from .client import CrossEncoderClient + +logger = logging.getLogger(__name__) + +DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17' + + +class GeminiRerankerClient(CrossEncoderClient): + def __init__( + self, + config: LLMConfig | None = None, + client: genai.Client | None = None, + ): + """ + Initialize the GeminiRerankerClient with the provided configuration and client. + + The Gemini Developer API does not yet support logprobs. Unlike the OpenAI reranker, + this reranker uses the Gemini API to perform direct relevance scoring of passages. + Each passage is scored individually on a 0-100 scale. + + Args: + config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens. + client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created. + """ + if config is None: + config = LLMConfig() + + self.config = config + if client is None: + self.client = genai.Client(api_key=config.api_key) + else: + self.client = client + + async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]: + """ + Rank passages based on their relevance to the query using direct scoring. + + Each passage is scored individually on a 0-100 scale, then normalized to [0,1]. + """ + if len(passages) <= 1: + return [(passage, 1.0) for passage in passages] + + # Generate scoring prompts for each passage + scoring_prompts = [] + for passage in passages: + prompt = f"""Rate how well this passage answers or relates to the query. Use a scale from 0 to 100. + +Query: {query} + +Passage: {passage} + +Provide only a number between 0 and 100 (no explanation, just the number):""" + + scoring_prompts.append( + [ + types.Content( + role='user', + parts=[types.Part.from_text(text=prompt)], + ), + ] + ) + + try: + # Execute all scoring requests concurrently - O(n) API calls + responses = await semaphore_gather( + *[ + self.client.aio.models.generate_content( + model=self.config.model or DEFAULT_MODEL, + contents=prompt_messages, # type: ignore + config=types.GenerateContentConfig( + system_instruction='You are an expert at rating passage relevance. 
Respond with only a number from 0-100.', + temperature=0.0, + max_output_tokens=3, + ), + ) + for prompt_messages in scoring_prompts + ] + ) + + # Extract scores and create results + results = [] + for passage, response in zip(passages, responses, strict=True): + try: + if hasattr(response, 'text') and response.text: + # Extract numeric score from response + score_text = response.text.strip() + # Handle cases where model might return non-numeric text + score_match = re.search(r'\b(\d{1,3})\b', score_text) + if score_match: + score = float(score_match.group(1)) + # Normalize to [0, 1] range and clamp to valid range + normalized_score = max(0.0, min(1.0, score / 100.0)) + results.append((passage, normalized_score)) + else: + logger.warning( + f'Could not extract numeric score from response: {score_text}' + ) + results.append((passage, 0.0)) + else: + logger.warning('Empty response from Gemini for passage scoring') + results.append((passage, 0.0)) + except (ValueError, AttributeError) as e: + logger.warning(f'Error parsing score from Gemini response: {e}') + results.append((passage, 0.0)) + + # Sort by score in descending order (highest relevance first) + results.sort(reverse=True, key=lambda x: x[1]) + return results + + except Exception as e: + # Check if it's a rate limit error based on Gemini API error codes + error_message = str(e).lower() + if ( + 'rate limit' in error_message + or 'quota' in error_message + or 'resource_exhausted' in error_message + or '429' in str(e) + ): + raise RateLimitError from e + + logger.error(f'Error in generating LLM response: {e}') + raise diff --git a/graphiti_core/llm_client/gemini_client.py b/graphiti_core/llm_client/gemini_client.py index 107c4c60..f76800c1 100644 --- a/graphiti_core/llm_client/gemini_client.py +++ b/graphiti_core/llm_client/gemini_client.py @@ -17,19 +17,21 @@ limitations under the License. import json import logging import typing +from typing import ClassVar from google import genai # type: ignore from google.genai import types # type: ignore from pydantic import BaseModel from ..prompts.models import Message -from .client import LLMClient +from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize from .errors import RateLimitError logger = logging.getLogger(__name__) -DEFAULT_MODEL = 'gemini-2.0-flash' +DEFAULT_MODEL = 'gemini-2.5-flash' +DEFAULT_SMALL_MODEL = 'models/gemini-2.5-flash-lite-preview-06-17' class GeminiClient(LLMClient): @@ -43,27 +45,34 @@ class GeminiClient(LLMClient): model (str): The model name to use for generating responses. temperature (float): The temperature to use for generating responses. max_tokens (int): The maximum number of tokens to generate in a response. - + thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it. Methods: - __init__(config: LLMConfig | None = None, cache: bool = False): - Initializes the GeminiClient with the provided configuration and cache setting. + __init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None): + Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config. _generate_response(messages: list[Message]) -> dict[str, typing.Any]: Generates a response from the language model based on the provided messages. 
""" + # Class-level constants + MAX_RETRIES: ClassVar[int] = 2 + def __init__( self, config: LLMConfig | None = None, cache: bool = False, max_tokens: int = DEFAULT_MAX_TOKENS, + thinking_config: types.ThinkingConfig | None = None, ): """ - Initialize the GeminiClient with the provided configuration and cache setting. + Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config. Args: config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens. cache (bool): Whether to use caching for responses. Defaults to False. + thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it. + Only use with models that support thinking (gemini-2.5+). Defaults to None. + """ if config is None: config = LLMConfig() @@ -76,6 +85,50 @@ class GeminiClient(LLMClient): api_key=config.api_key, ) self.max_tokens = max_tokens + self.thinking_config = thinking_config + + def _check_safety_blocks(self, response) -> None: + """Check if response was blocked for safety reasons and raise appropriate exceptions.""" + # Check if the response was blocked for safety reasons + if not (hasattr(response, 'candidates') and response.candidates): + return + + candidate = response.candidates[0] + if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'): + return + + # Content was blocked for safety reasons - collect safety details + safety_info = [] + safety_ratings = getattr(candidate, 'safety_ratings', None) + + if safety_ratings: + for rating in safety_ratings: + if getattr(rating, 'blocked', False): + category = getattr(rating, 'category', 'Unknown') + probability = getattr(rating, 'probability', 'Unknown') + safety_info.append(f'{category}: {probability}') + + safety_details = ( + ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons' + ) + raise Exception(f'Response blocked by Gemini safety filters: {safety_details}') + + def _check_prompt_blocks(self, response) -> None: + """Check if prompt was blocked and raise appropriate exceptions.""" + prompt_feedback = getattr(response, 'prompt_feedback', None) + if not prompt_feedback: + return + + block_reason = getattr(prompt_feedback, 'block_reason', None) + if block_reason: + raise Exception(f'Prompt blocked by Gemini: {block_reason}') + + def _get_model_for_size(self, model_size: ModelSize) -> str: + """Get the appropriate model name based on the requested size.""" + if model_size == ModelSize.small: + return self.small_model or DEFAULT_SMALL_MODEL + else: + return self.model or DEFAULT_MODEL async def _generate_response( self, @@ -91,17 +144,17 @@ class GeminiClient(LLMClient): messages (list[Message]): A list of messages to send to the language model. response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into. max_tokens (int): The maximum number of tokens to generate in the response. + model_size (ModelSize): The size of the model to use (small or medium). Returns: dict[str, typing.Any]: The response from the language model. Raises: RateLimitError: If the API rate limit is exceeded. - RefusalError: If the content is blocked by the model. - Exception: If there is an error generating the response. + Exception: If there is an error generating the response or content is blocked. 
""" try: - gemini_messages: list[types.Content] = [] + gemini_messages: typing.Any = [] # If a response model is provided, add schema for structured output system_prompt = '' if response_model is not None: @@ -127,6 +180,9 @@ class GeminiClient(LLMClient): types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)]) ) + # Get the appropriate model for the requested size + model = self._get_model_for_size(model_size) + # Create generation config generation_config = types.GenerateContentConfig( temperature=self.temperature, @@ -134,15 +190,20 @@ class GeminiClient(LLMClient): response_mime_type='application/json' if response_model else None, response_schema=response_model if response_model else None, system_instruction=system_prompt, + thinking_config=self.thinking_config, ) # Generate content using the simple string approach response = await self.client.aio.models.generate_content( - model=self.model or DEFAULT_MODEL, - contents=gemini_messages, # type: ignore[arg-type] # mypy fails on broad union type + model=model, + contents=gemini_messages, config=generation_config, ) + # Check for safety and prompt blocks + self._check_safety_blocks(response) + self._check_prompt_blocks(response) + # If this was a structured output request, parse the response into the Pydantic model if response_model is not None: try: @@ -160,9 +221,16 @@ class GeminiClient(LLMClient): return {'content': response.text} except Exception as e: - # Check if it's a rate limit error - if 'rate limit' in str(e).lower() or 'quota' in str(e).lower(): + # Check if it's a rate limit error based on Gemini API error codes + error_message = str(e).lower() + if ( + 'rate limit' in error_message + or 'quota' in error_message + or 'resource_exhausted' in error_message + or '429' in str(e) + ): raise RateLimitError from e + logger.error(f'Error in generating LLM response: {e}') raise @@ -174,13 +242,14 @@ class GeminiClient(LLMClient): model_size: ModelSize = ModelSize.medium, ) -> dict[str, typing.Any]: """ - Generate a response from the Gemini language model. - This method overrides the parent class method to provide a direct implementation. + Generate a response from the Gemini language model with retry logic and error handling. + This method overrides the parent class method to provide a direct implementation with advanced retry logic. Args: messages (list[Message]): A list of messages to send to the language model. response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into. - max_tokens (int): The maximum number of tokens to generate in the response. + max_tokens (int | None): The maximum number of tokens to generate in the response. + model_size (ModelSize): The size of the model to use (small or medium). Returns: dict[str, typing.Any]: The response from the language model. 
@@ -188,10 +257,53 @@ class GeminiClient(LLMClient): if max_tokens is None: max_tokens = self.max_tokens - # Call the internal _generate_response method - return await self._generate_response( - messages=messages, - response_model=response_model, - max_tokens=max_tokens, - model_size=model_size, - ) + retry_count = 0 + last_error = None + + # Add multilingual extraction instructions + messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES + + while retry_count <= self.MAX_RETRIES: + try: + response = await self._generate_response( + messages=messages, + response_model=response_model, + max_tokens=max_tokens, + model_size=model_size, + ) + return response + except RateLimitError: + # Rate limit errors should not trigger retries (fail fast) + raise + except Exception as e: + last_error = e + + # Check if this is a safety block - these typically shouldn't be retried + if 'safety' in str(e).lower() or 'blocked' in str(e).lower(): + logger.warning(f'Content blocked by safety filters: {e}') + raise + + # Don't retry if we've hit the max retries + if retry_count >= self.MAX_RETRIES: + logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}') + raise + + retry_count += 1 + + # Construct a detailed error message for the LLM + error_context = ( + f'The previous response attempt was invalid. ' + f'Error type: {e.__class__.__name__}. ' + f'Error details: {str(e)}. ' + f'Please try again with a valid response, ensuring the output matches ' + f'the expected format and constraints.' + ) + + error_message = Message(role='user', content=error_context) + messages.append(error_message) + logger.warning( + f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}' + ) + + # If we somehow get here, raise the last error + raise last_error or Exception('Max retries exceeded with no specific error') diff --git a/tests/cross_encoder/test_gemini_reranker_client.py b/tests/cross_encoder/test_gemini_reranker_client.py new file mode 100644 index 00000000..8107d2e5 --- /dev/null +++ b/tests/cross_encoder/test_gemini_reranker_client.py @@ -0,0 +1,353 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# Running tests: pytest -xvs tests/cross_encoder/test_gemini_reranker_client.py + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient +from graphiti_core.llm_client import LLMConfig, RateLimitError + + +@pytest.fixture +def mock_gemini_client(): + """Fixture to mock the Google Gemini client.""" + with patch('google.genai.Client') as mock_client: + # Setup mock instance and its methods + mock_instance = mock_client.return_value + mock_instance.aio = MagicMock() + mock_instance.aio.models = MagicMock() + mock_instance.aio.models.generate_content = AsyncMock() + yield mock_instance + + +@pytest.fixture +def gemini_reranker_client(mock_gemini_client): + """Fixture to create a GeminiRerankerClient with a mocked client.""" + config = LLMConfig(api_key='test_api_key', model='test-model') + client = GeminiRerankerClient(config=config) + # Replace the client's client with our mock to ensure we're using the mock + client.client = mock_gemini_client + return client + + +def create_mock_response(score_text: str) -> MagicMock: + """Helper function to create a mock Gemini response.""" + mock_response = MagicMock() + mock_response.text = score_text + return mock_response + + +class TestGeminiRerankerClientInitialization: + """Tests for GeminiRerankerClient initialization.""" + + def test_init_with_config(self): + """Test initialization with a config object.""" + config = LLMConfig(api_key='test_api_key', model='test-model') + client = GeminiRerankerClient(config=config) + + assert client.config == config + + @patch('google.genai.Client') + def test_init_without_config(self, mock_client): + """Test initialization without a config uses defaults.""" + client = GeminiRerankerClient() + + assert client.config is not None + + def test_init_with_custom_client(self): + """Test initialization with a custom client.""" + mock_client = MagicMock() + client = GeminiRerankerClient(client=mock_client) + + assert client.client == mock_client + + +class TestGeminiRerankerClientRanking: + """Tests for GeminiRerankerClient rank method.""" + + @pytest.mark.asyncio + async def test_rank_basic_functionality(self, gemini_reranker_client, mock_gemini_client): + """Test basic ranking functionality.""" + # Setup mock responses with different scores + mock_responses = [ + create_mock_response('85'), # High relevance + create_mock_response('45'), # Medium relevance + create_mock_response('20'), # Low relevance + ] + mock_gemini_client.aio.models.generate_content.side_effect = mock_responses + + # Test data + query = 'What is the capital of France?' 
+ passages = [ + 'Paris is the capital and most populous city of France.', + 'London is the capital city of England and the United Kingdom.', + 'Berlin is the capital and largest city of Germany.', + ] + + # Call method + result = await gemini_reranker_client.rank(query, passages) + + # Assertions + assert len(result) == 3 + assert all(isinstance(item, tuple) for item in result) + assert all( + isinstance(passage, str) and isinstance(score, float) for passage, score in result + ) + + # Check scores are normalized to [0, 1] and sorted in descending order + scores = [score for _, score in result] + assert all(0.0 <= score <= 1.0 for score in scores) + assert scores == sorted(scores, reverse=True) + + # Check that the highest scoring passage is first + assert result[0][1] == 0.85 # 85/100 + assert result[1][1] == 0.45 # 45/100 + assert result[2][1] == 0.20 # 20/100 + + @pytest.mark.asyncio + async def test_rank_empty_passages(self, gemini_reranker_client): + """Test ranking with empty passages list.""" + query = 'Test query' + passages = [] + + result = await gemini_reranker_client.rank(query, passages) + + assert result == [] + + @pytest.mark.asyncio + async def test_rank_single_passage(self, gemini_reranker_client, mock_gemini_client): + """Test ranking with a single passage.""" + # Setup mock response + mock_gemini_client.aio.models.generate_content.return_value = create_mock_response('75') + + query = 'Test query' + passages = ['Single test passage'] + + result = await gemini_reranker_client.rank(query, passages) + + assert len(result) == 1 + assert result[0][0] == 'Single test passage' + assert result[0][1] == 1.0 # Single passage gets full score + + @pytest.mark.asyncio + async def test_rank_score_extraction_with_regex( + self, gemini_reranker_client, mock_gemini_client + ): + """Test score extraction from various response formats.""" + # Setup mock responses with different formats + mock_responses = [ + create_mock_response('Score: 90'), # Contains text before number + create_mock_response('The relevance is 65 out of 100'), # Contains text around number + create_mock_response('8'), # Just the number + ] + mock_gemini_client.aio.models.generate_content.side_effect = mock_responses + + query = 'Test query' + passages = ['Passage 1', 'Passage 2', 'Passage 3'] + + result = await gemini_reranker_client.rank(query, passages) + + # Check that scores were extracted correctly and normalized + scores = [score for _, score in result] + assert 0.90 in scores # 90/100 + assert 0.65 in scores # 65/100 + assert 0.08 in scores # 8/100 + + @pytest.mark.asyncio + async def test_rank_invalid_score_handling(self, gemini_reranker_client, mock_gemini_client): + """Test handling of invalid or non-numeric scores.""" + # Setup mock responses with invalid scores + mock_responses = [ + create_mock_response('Not a number'), # Invalid response + create_mock_response(''), # Empty response + create_mock_response('95'), # Valid response + ] + mock_gemini_client.aio.models.generate_content.side_effect = mock_responses + + query = 'Test query' + passages = ['Passage 1', 'Passage 2', 'Passage 3'] + + result = await gemini_reranker_client.rank(query, passages) + + # Check that invalid scores are handled gracefully (assigned 0.0) + scores = [score for _, score in result] + assert 0.95 in scores # Valid score + assert scores.count(0.0) == 2 # Two invalid scores assigned 0.0 + + @pytest.mark.asyncio + async def test_rank_score_clamping(self, gemini_reranker_client, mock_gemini_client): + """Test that scores are properly 
clamped to [0, 1] range.""" + # Setup mock responses with extreme scores + # Note: regex only matches 1-3 digits, so negative numbers won't match + mock_responses = [ + create_mock_response('999'), # Above 100 but within regex range + create_mock_response('invalid'), # Invalid response becomes 0.0 + create_mock_response('50'), # Normal score + ] + mock_gemini_client.aio.models.generate_content.side_effect = mock_responses + + query = 'Test query' + passages = ['Passage 1', 'Passage 2', 'Passage 3'] + + result = await gemini_reranker_client.rank(query, passages) + + # Check that scores are normalized and clamped + scores = [score for _, score in result] + assert all(0.0 <= score <= 1.0 for score in scores) + # 999 should be clamped to 1.0 (999/100 = 9.99, clamped to 1.0) + assert 1.0 in scores + # Invalid response should be 0.0 + assert 0.0 in scores + # Normal score should be normalized (50/100 = 0.5) + assert 0.5 in scores + + @pytest.mark.asyncio + async def test_rank_rate_limit_error(self, gemini_reranker_client, mock_gemini_client): + """Test handling of rate limit errors.""" + # Setup mock to raise rate limit error + mock_gemini_client.aio.models.generate_content.side_effect = Exception( + 'Rate limit exceeded' + ) + + query = 'Test query' + passages = ['Passage 1', 'Passage 2'] + + with pytest.raises(RateLimitError): + await gemini_reranker_client.rank(query, passages) + + @pytest.mark.asyncio + async def test_rank_quota_error(self, gemini_reranker_client, mock_gemini_client): + """Test handling of quota errors.""" + # Setup mock to raise quota error + mock_gemini_client.aio.models.generate_content.side_effect = Exception('Quota exceeded') + + query = 'Test query' + passages = ['Passage 1', 'Passage 2'] + + with pytest.raises(RateLimitError): + await gemini_reranker_client.rank(query, passages) + + @pytest.mark.asyncio + async def test_rank_resource_exhausted_error(self, gemini_reranker_client, mock_gemini_client): + """Test handling of resource exhausted errors.""" + # Setup mock to raise resource exhausted error + mock_gemini_client.aio.models.generate_content.side_effect = Exception('resource_exhausted') + + query = 'Test query' + passages = ['Passage 1', 'Passage 2'] + + with pytest.raises(RateLimitError): + await gemini_reranker_client.rank(query, passages) + + @pytest.mark.asyncio + async def test_rank_429_error(self, gemini_reranker_client, mock_gemini_client): + """Test handling of HTTP 429 errors.""" + # Setup mock to raise 429 error + mock_gemini_client.aio.models.generate_content.side_effect = Exception( + 'HTTP 429 Too Many Requests' + ) + + query = 'Test query' + passages = ['Passage 1', 'Passage 2'] + + with pytest.raises(RateLimitError): + await gemini_reranker_client.rank(query, passages) + + @pytest.mark.asyncio + async def test_rank_generic_error(self, gemini_reranker_client, mock_gemini_client): + """Test handling of generic errors.""" + # Setup mock to raise generic error + mock_gemini_client.aio.models.generate_content.side_effect = Exception('Generic error') + + query = 'Test query' + passages = ['Passage 1', 'Passage 2'] + + with pytest.raises(Exception) as exc_info: + await gemini_reranker_client.rank(query, passages) + + assert 'Generic error' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_rank_concurrent_requests(self, gemini_reranker_client, mock_gemini_client): + """Test that multiple passages are scored concurrently.""" + # Setup mock responses + mock_responses = [ + create_mock_response('80'), + create_mock_response('60'), + 
create_mock_response('40'), + ] + mock_gemini_client.aio.models.generate_content.side_effect = mock_responses + + query = 'Test query' + passages = ['Passage 1', 'Passage 2', 'Passage 3'] + + await gemini_reranker_client.rank(query, passages) + + # Verify that generate_content was called for each passage + assert mock_gemini_client.aio.models.generate_content.call_count == 3 + + # Verify that all calls were made with correct parameters + calls = mock_gemini_client.aio.models.generate_content.call_args_list + for call in calls: + args, kwargs = call + assert kwargs['model'] == gemini_reranker_client.config.model + assert kwargs['config'].temperature == 0.0 + assert kwargs['config'].max_output_tokens == 3 + + @pytest.mark.asyncio + async def test_rank_response_parsing_error(self, gemini_reranker_client, mock_gemini_client): + """Test handling of response parsing errors.""" + # Setup mock responses that will trigger ValueError during parsing + mock_responses = [ + create_mock_response('not a number at all'), # Will fail regex match + create_mock_response('also invalid text'), # Will fail regex match + ] + mock_gemini_client.aio.models.generate_content.side_effect = mock_responses + + query = 'Test query' + # Use multiple passages to avoid the single passage special case + passages = ['Passage 1', 'Passage 2'] + + result = await gemini_reranker_client.rank(query, passages) + + # Should handle the error gracefully and assign 0.0 score to both + assert len(result) == 2 + assert all(score == 0.0 for _, score in result) + + @pytest.mark.asyncio + async def test_rank_empty_response_text(self, gemini_reranker_client, mock_gemini_client): + """Test handling of empty response text.""" + # Setup mock response with empty text + mock_response = MagicMock() + mock_response.text = '' # Empty string instead of None + mock_gemini_client.aio.models.generate_content.return_value = mock_response + + query = 'Test query' + # Use multiple passages to avoid the single passage special case + passages = ['Passage 1', 'Passage 2'] + + result = await gemini_reranker_client.rank(query, passages) + + # Should handle empty text gracefully and assign 0.0 score to both + assert len(result) == 2 + assert all(score == 0.0 for _, score in result) + + +if __name__ == '__main__': + pytest.main(['-v', 'test_gemini_reranker_client.py']) diff --git a/tests/embedder/test_gemini.py b/tests/embedder/test_gemini.py index 649c1a57..a4d3730b 100644 --- a/tests/embedder/test_gemini.py +++ b/tests/embedder/test_gemini.py @@ -14,6 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +# Running tests: pytest -xvs tests/embedder/test_gemini.py + from collections.abc import Generator from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -28,10 +30,10 @@ from graphiti_core.embedder.gemini import ( from tests.embedder.embedder_fixtures import create_embedding_values -def create_gemini_embedding(multiplier: float = 0.1) -> MagicMock: - """Create a mock Gemini embedding with specified value multiplier.""" +def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock: + """Create a mock Gemini embedding with specified value multiplier and dimension.""" mock_embedding = MagicMock() - mock_embedding.values = create_embedding_values(multiplier) + mock_embedding.values = create_embedding_values(multiplier, dimension) return mock_embedding @@ -75,52 +77,304 @@ def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder: return client -@pytest.mark.asyncio -async def test_create_calls_api_correctly( - gemini_embedder: GeminiEmbedder, mock_gemini_client: Any, mock_gemini_response: MagicMock -) -> None: - """Test that create method correctly calls the API and processes the response.""" - # Setup - mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response +class TestGeminiEmbedderInitialization: + """Tests for GeminiEmbedder initialization.""" - # Call method - result = await gemini_embedder.create('Test input') + @patch('google.genai.Client') + def test_init_with_config(self, mock_client): + """Test initialization with a config object.""" + config = GeminiEmbedderConfig( + api_key='test_api_key', embedding_model='custom-model', embedding_dim=768 + ) + embedder = GeminiEmbedder(config=config) - # Verify API is called with correct parameters - mock_gemini_client.aio.models.embed_content.assert_called_once() - _, kwargs = mock_gemini_client.aio.models.embed_content.call_args - assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL - assert kwargs['contents'] == ['Test input'] + assert embedder.config == config + assert embedder.config.embedding_model == 'custom-model' + assert embedder.config.api_key == 'test_api_key' + assert embedder.config.embedding_dim == 768 - # Verify result is processed correctly - assert result == mock_gemini_response.embeddings[0].values + @patch('google.genai.Client') + def test_init_without_config(self, mock_client): + """Test initialization without a config uses defaults.""" + embedder = GeminiEmbedder() + + assert embedder.config is not None + assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL + + @patch('google.genai.Client') + def test_init_with_partial_config(self, mock_client): + """Test initialization with partial config.""" + config = GeminiEmbedderConfig(api_key='test_api_key') + embedder = GeminiEmbedder(config=config) + + assert embedder.config.api_key == 'test_api_key' + assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL -@pytest.mark.asyncio -async def test_create_batch_processes_multiple_inputs( - gemini_embedder: GeminiEmbedder, mock_gemini_client: Any, mock_gemini_batch_response: MagicMock -) -> None: - """Test that create_batch method correctly processes multiple inputs.""" - # Setup - mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response - input_batch = ['Input 1', 'Input 2', 'Input 3'] +class TestGeminiEmbedderCreate: + """Tests for GeminiEmbedder create method.""" - # Call method - result = await gemini_embedder.create_batch(input_batch) + @pytest.mark.asyncio + async def 
test_create_calls_api_correctly( + self, + gemini_embedder: GeminiEmbedder, + mock_gemini_client: Any, + mock_gemini_response: MagicMock, + ) -> None: + """Test that create method correctly calls the API and processes the response.""" + # Setup + mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response - # Verify API is called with correct parameters - mock_gemini_client.aio.models.embed_content.assert_called_once() - _, kwargs = mock_gemini_client.aio.models.embed_content.call_args - assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL - assert kwargs['contents'] == input_batch + # Call method + result = await gemini_embedder.create('Test input') - # Verify all results are processed correctly - assert len(result) == 3 - assert result == [ - mock_gemini_batch_response.embeddings[0].values, - mock_gemini_batch_response.embeddings[1].values, - mock_gemini_batch_response.embeddings[2].values, - ] + # Verify API is called with correct parameters + mock_gemini_client.aio.models.embed_content.assert_called_once() + _, kwargs = mock_gemini_client.aio.models.embed_content.call_args + assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL + assert kwargs['contents'] == ['Test input'] + + # Verify result is processed correctly + assert result == mock_gemini_response.embeddings[0].values + + @pytest.mark.asyncio + @patch('google.genai.Client') + async def test_create_with_custom_model( + self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock + ) -> None: + """Test create method with custom embedding model.""" + # Setup embedder with custom model + config = GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model') + embedder = GeminiEmbedder(config=config) + embedder.client = mock_gemini_client + mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response + + # Call method + await embedder.create('Test input') + + # Verify custom model is used + _, kwargs = mock_gemini_client.aio.models.embed_content.call_args + assert kwargs['model'] == 'custom-model' + + @pytest.mark.asyncio + @patch('google.genai.Client') + async def test_create_with_custom_dimension( + self, mock_client_class, mock_gemini_client: Any + ) -> None: + """Test create method with custom embedding dimension.""" + # Setup embedder with custom dimension + config = GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768) + embedder = GeminiEmbedder(config=config) + embedder.client = mock_gemini_client + + # Setup mock response with custom dimension + mock_response = MagicMock() + mock_response.embeddings = [create_gemini_embedding(0.1, 768)] + mock_gemini_client.aio.models.embed_content.return_value = mock_response + + # Call method + result = await embedder.create('Test input') + + # Verify custom dimension is used in config + _, kwargs = mock_gemini_client.aio.models.embed_content.call_args + assert kwargs['config'].output_dimensionality == 768 + + # Verify result has correct dimension + assert len(result) == 768 + + @pytest.mark.asyncio + async def test_create_with_different_input_types( + self, + gemini_embedder: GeminiEmbedder, + mock_gemini_client: Any, + mock_gemini_response: MagicMock, + ) -> None: + """Test create method with different input types.""" + mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response + + # Test with string + await gemini_embedder.create('Test string') + + # Test with list of strings + await gemini_embedder.create(['Test', 'List']) + + # Test with iterable of integers + await 
gemini_embedder.create([1, 2, 3]) + + # Verify all calls were made + assert mock_gemini_client.aio.models.embed_content.call_count == 3 + + @pytest.mark.asyncio + async def test_create_no_embeddings_error( + self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any + ) -> None: + """Test create method handling of no embeddings response.""" + # Setup mock response with no embeddings + mock_response = MagicMock() + mock_response.embeddings = [] + mock_gemini_client.aio.models.embed_content.return_value = mock_response + + # Call method and expect exception + with pytest.raises(ValueError) as exc_info: + await gemini_embedder.create('Test input') + + assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_create_no_values_error( + self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any + ) -> None: + """Test create method handling of embeddings with no values.""" + # Setup mock response with embedding but no values + mock_embedding = MagicMock() + mock_embedding.values = None + mock_response = MagicMock() + mock_response.embeddings = [mock_embedding] + mock_gemini_client.aio.models.embed_content.return_value = mock_response + + # Call method and expect exception + with pytest.raises(ValueError) as exc_info: + await gemini_embedder.create('Test input') + + assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value) + + +class TestGeminiEmbedderCreateBatch: + """Tests for GeminiEmbedder create_batch method.""" + + @pytest.mark.asyncio + async def test_create_batch_processes_multiple_inputs( + self, + gemini_embedder: GeminiEmbedder, + mock_gemini_client: Any, + mock_gemini_batch_response: MagicMock, + ) -> None: + """Test that create_batch method correctly processes multiple inputs.""" + # Setup + mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response + input_batch = ['Input 1', 'Input 2', 'Input 3'] + + # Call method + result = await gemini_embedder.create_batch(input_batch) + + # Verify API is called with correct parameters + mock_gemini_client.aio.models.embed_content.assert_called_once() + _, kwargs = mock_gemini_client.aio.models.embed_content.call_args + assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL + assert kwargs['contents'] == input_batch + + # Verify all results are processed correctly + assert len(result) == 3 + assert result == [ + mock_gemini_batch_response.embeddings[0].values, + mock_gemini_batch_response.embeddings[1].values, + mock_gemini_batch_response.embeddings[2].values, + ] + + @pytest.mark.asyncio + async def test_create_batch_single_input( + self, + gemini_embedder: GeminiEmbedder, + mock_gemini_client: Any, + mock_gemini_response: MagicMock, + ) -> None: + """Test create_batch method with single input.""" + mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response + input_batch = ['Single input'] + + result = await gemini_embedder.create_batch(input_batch) + + assert len(result) == 1 + assert result[0] == mock_gemini_response.embeddings[0].values + + @pytest.mark.asyncio + async def test_create_batch_empty_input( + self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any + ) -> None: + """Test create_batch method with empty input.""" + # Setup mock response with no embeddings + mock_response = MagicMock() + mock_response.embeddings = [] + mock_gemini_client.aio.models.embed_content.return_value = mock_response + + input_batch = [] + + with pytest.raises(Exception) as exc_info: + await 
gemini_embedder.create_batch(input_batch) + + assert 'No embeddings returned' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_create_batch_no_embeddings_error( + self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any + ) -> None: + """Test create_batch method handling of no embeddings response.""" + # Setup mock response with no embeddings + mock_response = MagicMock() + mock_response.embeddings = [] + mock_gemini_client.aio.models.embed_content.return_value = mock_response + + input_batch = ['Input 1', 'Input 2'] + + with pytest.raises(Exception) as exc_info: + await gemini_embedder.create_batch(input_batch) + + assert 'No embeddings returned' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_create_batch_empty_values_error( + self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any + ) -> None: + """Test create_batch method handling of embeddings with empty values.""" + # Setup mock response with embeddings but empty values + mock_embedding1 = MagicMock() + mock_embedding1.values = [0.1, 0.2, 0.3] # Valid values + mock_embedding2 = MagicMock() + mock_embedding2.values = None # Empty values + + mock_response = MagicMock() + mock_response.embeddings = [mock_embedding1, mock_embedding2] + mock_gemini_client.aio.models.embed_content.return_value = mock_response + + input_batch = ['Input 1', 'Input 2'] + + with pytest.raises(ValueError) as exc_info: + await gemini_embedder.create_batch(input_batch) + + assert 'Empty embedding values returned' in str(exc_info.value) + + @pytest.mark.asyncio + @patch('google.genai.Client') + async def test_create_batch_with_custom_model_and_dimension( + self, mock_client_class, mock_gemini_client: Any + ) -> None: + """Test create_batch method with custom model and dimension.""" + # Setup embedder with custom settings + config = GeminiEmbedderConfig( + api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512 + ) + embedder = GeminiEmbedder(config=config) + embedder.client = mock_gemini_client + + # Setup mock response + mock_response = MagicMock() + mock_response.embeddings = [ + create_gemini_embedding(0.1, 512), + create_gemini_embedding(0.2, 512), + ] + mock_gemini_client.aio.models.embed_content.return_value = mock_response + + input_batch = ['Input 1', 'Input 2'] + result = await embedder.create_batch(input_batch) + + # Verify custom settings are used + _, kwargs = mock_gemini_client.aio.models.embed_content.call_args + assert kwargs['model'] == 'custom-batch-model' + assert kwargs['config'].output_dimensionality == 512 + + # Verify results have correct dimension + assert len(result) == 2 + assert all(len(embedding) == 512 for embedding in result) if __name__ == '__main__': diff --git a/tests/llm_client/test_gemini_client.py b/tests/llm_client/test_gemini_client.py new file mode 100644 index 00000000..2179897e --- /dev/null +++ b/tests/llm_client/test_gemini_client.py @@ -0,0 +1,393 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# Running tests: pytest -xvs tests/llm_client/test_gemini_client.py + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import BaseModel + +from graphiti_core.llm_client.config import LLMConfig, ModelSize +from graphiti_core.llm_client.errors import RateLimitError +from graphiti_core.llm_client.gemini_client import DEFAULT_MODEL, DEFAULT_SMALL_MODEL, GeminiClient +from graphiti_core.prompts.models import Message + + +# Test model for response testing +class ResponseModel(BaseModel): + """Test model for response testing.""" + + test_field: str + optional_field: int = 0 + + +@pytest.fixture +def mock_gemini_client(): + """Fixture to mock the Google Gemini client.""" + with patch('google.genai.Client') as mock_client: + # Setup mock instance and its methods + mock_instance = mock_client.return_value + mock_instance.aio = MagicMock() + mock_instance.aio.models = MagicMock() + mock_instance.aio.models.generate_content = AsyncMock() + yield mock_instance + + +@pytest.fixture +def gemini_client(mock_gemini_client): + """Fixture to create a GeminiClient with a mocked client.""" + config = LLMConfig(api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000) + client = GeminiClient(config=config, cache=False) + # Replace the client's client with our mock to ensure we're using the mock + client.client = mock_gemini_client + return client + + +class TestGeminiClientInitialization: + """Tests for GeminiClient initialization.""" + + @patch('google.genai.Client') + def test_init_with_config(self, mock_client): + """Test initialization with a config object.""" + config = LLMConfig( + api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000 + ) + client = GeminiClient(config=config, cache=False, max_tokens=1000) + + assert client.config == config + assert client.model == 'test-model' + assert client.temperature == 0.5 + assert client.max_tokens == 1000 + + @patch('google.genai.Client') + def test_init_with_default_model(self, mock_client): + """Test initialization with default model when none is provided.""" + config = LLMConfig(api_key='test_api_key', model=DEFAULT_MODEL) + client = GeminiClient(config=config, cache=False) + + assert client.model == DEFAULT_MODEL + + @patch('google.genai.Client') + def test_init_without_config(self, mock_client): + """Test initialization without a config uses defaults.""" + client = GeminiClient(cache=False) + + assert client.config is not None + # When no config.model is set, it will be None, not DEFAULT_MODEL + assert client.model is None + + @patch('google.genai.Client') + def test_init_with_thinking_config(self, mock_client): + """Test initialization with thinking config.""" + with patch('google.genai.types.ThinkingConfig') as mock_thinking_config: + thinking_config = mock_thinking_config.return_value + client = GeminiClient(thinking_config=thinking_config) + assert client.thinking_config == thinking_config + + +class TestGeminiClientGenerateResponse: + """Tests for GeminiClient generate_response method.""" + + @pytest.mark.asyncio + async def test_generate_response_simple_text(self, gemini_client, mock_gemini_client): + """Test successful response generation with simple text.""" + # Setup mock response + mock_response = MagicMock() + mock_response.text = 'Test response text' + mock_response.candidates = [] + mock_response.prompt_feedback = None + mock_gemini_client.aio.models.generate_content.return_value = mock_response + + # Call method + messages = [Message(role='user', 
content='Test message')] + result = await gemini_client.generate_response(messages) + + # Assertions + assert isinstance(result, dict) + assert result['content'] == 'Test response text' + mock_gemini_client.aio.models.generate_content.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_response_with_structured_output( + self, gemini_client, mock_gemini_client + ): + """Test response generation with structured output.""" + # Setup mock response + mock_response = MagicMock() + mock_response.text = '{"test_field": "test_value", "optional_field": 42}' + mock_response.candidates = [] + mock_response.prompt_feedback = None + mock_gemini_client.aio.models.generate_content.return_value = mock_response + + # Call method + messages = [ + Message(role='system', content='System message'), + Message(role='user', content='User message'), + ] + result = await gemini_client.generate_response( + messages=messages, response_model=ResponseModel + ) + + # Assertions + assert isinstance(result, dict) + assert result['test_field'] == 'test_value' + assert result['optional_field'] == 42 + mock_gemini_client.aio.models.generate_content.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_response_with_system_message(self, gemini_client, mock_gemini_client): + """Test response generation with system message handling.""" + # Setup mock response + mock_response = MagicMock() + mock_response.text = 'Response with system context' + mock_response.candidates = [] + mock_response.prompt_feedback = None + mock_gemini_client.aio.models.generate_content.return_value = mock_response + + # Call method + messages = [ + Message(role='system', content='System message'), + Message(role='user', content='User message'), + ] + await gemini_client.generate_response(messages) + + # Verify system message is processed correctly + call_args = mock_gemini_client.aio.models.generate_content.call_args + config = call_args[1]['config'] + assert 'System message' in config.system_instruction + + @pytest.mark.asyncio + async def test_get_model_for_size(self, gemini_client): + """Test model selection based on size.""" + # Test small model + small_model = gemini_client._get_model_for_size(ModelSize.small) + assert small_model == DEFAULT_SMALL_MODEL + + # Test medium/large model + medium_model = gemini_client._get_model_for_size(ModelSize.medium) + assert medium_model == gemini_client.model + + @pytest.mark.asyncio + async def test_rate_limit_error_handling(self, gemini_client, mock_gemini_client): + """Test handling of rate limit errors.""" + # Setup mock to raise rate limit error + mock_gemini_client.aio.models.generate_content.side_effect = Exception( + 'Rate limit exceeded' + ) + + # Call method and check exception + messages = [Message(role='user', content='Test message')] + with pytest.raises(RateLimitError): + await gemini_client.generate_response(messages) + + @pytest.mark.asyncio + async def test_quota_error_handling(self, gemini_client, mock_gemini_client): + """Test handling of quota errors.""" + # Setup mock to raise quota error + mock_gemini_client.aio.models.generate_content.side_effect = Exception( + 'Quota exceeded for requests' + ) + + # Call method and check exception + messages = [Message(role='user', content='Test message')] + with pytest.raises(RateLimitError): + await gemini_client.generate_response(messages) + + @pytest.mark.asyncio + async def test_resource_exhausted_error_handling(self, gemini_client, mock_gemini_client): + """Test handling of resource exhausted errors.""" + # Setup 
mock to raise resource exhausted error + mock_gemini_client.aio.models.generate_content.side_effect = Exception( + 'resource_exhausted: Request limit exceeded' + ) + + # Call method and check exception + messages = [Message(role='user', content='Test message')] + with pytest.raises(RateLimitError): + await gemini_client.generate_response(messages) + + @pytest.mark.asyncio + async def test_safety_block_handling(self, gemini_client, mock_gemini_client): + """Test handling of safety blocks.""" + # Setup mock response with safety block + mock_candidate = MagicMock() + mock_candidate.finish_reason = 'SAFETY' + mock_candidate.safety_ratings = [ + MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH') + ] + + mock_response = MagicMock() + mock_response.candidates = [mock_candidate] + mock_response.prompt_feedback = None + mock_gemini_client.aio.models.generate_content.return_value = mock_response + + # Call method and check exception + messages = [Message(role='user', content='Test message')] + with pytest.raises(Exception, match='Response blocked by Gemini safety filters'): + await gemini_client.generate_response(messages) + + @pytest.mark.asyncio + async def test_prompt_block_handling(self, gemini_client, mock_gemini_client): + """Test handling of prompt blocks.""" + # Setup mock response with prompt block + mock_prompt_feedback = MagicMock() + mock_prompt_feedback.block_reason = 'BLOCKED_REASON_OTHER' + + mock_response = MagicMock() + mock_response.candidates = [] + mock_response.prompt_feedback = mock_prompt_feedback + mock_gemini_client.aio.models.generate_content.return_value = mock_response + + # Call method and check exception + messages = [Message(role='user', content='Test message')] + with pytest.raises(Exception, match='Prompt blocked by Gemini'): + await gemini_client.generate_response(messages) + + @pytest.mark.asyncio + async def test_structured_output_parsing_error(self, gemini_client, mock_gemini_client): + """Test handling of structured output parsing errors.""" + # Setup mock response with invalid JSON + mock_response = MagicMock() + mock_response.text = 'Invalid JSON response' + mock_response.candidates = [] + mock_response.prompt_feedback = None + mock_gemini_client.aio.models.generate_content.return_value = mock_response + + # Call method and check exception + messages = [Message(role='user', content='Test message')] + with pytest.raises(Exception, match='Failed to parse structured response'): + await gemini_client.generate_response(messages, response_model=ResponseModel) + + @pytest.mark.asyncio + async def test_retry_logic_with_safety_block(self, gemini_client, mock_gemini_client): + """Test that safety blocks are not retried.""" + # Setup mock to raise safety error + mock_gemini_client.aio.models.generate_content.side_effect = Exception( + 'Content blocked by safety filters' + ) + + # Call method and check that it doesn't retry + messages = [Message(role='user', content='Test message')] + with pytest.raises(Exception, match='Content blocked by safety filters'): + await gemini_client.generate_response(messages) + + # Should only be called once (no retries for safety blocks) + assert mock_gemini_client.aio.models.generate_content.call_count == 1 + + @pytest.mark.asyncio + async def test_retry_logic_with_validation_error(self, gemini_client, mock_gemini_client): + """Test retry behavior on validation error.""" + # First call returns invalid data, second call returns valid data + mock_response1 = MagicMock() + mock_response1.text = '{"wrong_field": 
"wrong_value"}' + mock_response1.candidates = [] + mock_response1.prompt_feedback = None + + mock_response2 = MagicMock() + mock_response2.text = '{"test_field": "correct_value"}' + mock_response2.candidates = [] + mock_response2.prompt_feedback = None + + mock_gemini_client.aio.models.generate_content.side_effect = [ + mock_response1, + mock_response2, + ] + + # Call method + messages = [Message(role='user', content='Test message')] + result = await gemini_client.generate_response(messages, response_model=ResponseModel) + + # Should have called generate_content twice due to retry + assert mock_gemini_client.aio.models.generate_content.call_count == 2 + assert result['test_field'] == 'correct_value' + + @pytest.mark.asyncio + async def test_max_retries_exceeded(self, gemini_client, mock_gemini_client): + """Test behavior when max retries are exceeded.""" + # Setup mock to always return invalid data + mock_response = MagicMock() + mock_response.text = '{"wrong_field": "wrong_value"}' + mock_response.candidates = [] + mock_response.prompt_feedback = None + mock_gemini_client.aio.models.generate_content.return_value = mock_response + + # Call method and check exception + messages = [Message(role='user', content='Test message')] + with pytest.raises(Exception, match='Failed to parse structured response'): + await gemini_client.generate_response(messages, response_model=ResponseModel) + + # Should have called generate_content MAX_RETRIES + 1 times + assert ( + mock_gemini_client.aio.models.generate_content.call_count + == GeminiClient.MAX_RETRIES + 1 + ) + + @pytest.mark.asyncio + async def test_empty_response_handling(self, gemini_client, mock_gemini_client): + """Test handling of empty responses.""" + # Setup mock response with no text + mock_response = MagicMock() + mock_response.text = '' + mock_response.candidates = [] + mock_response.prompt_feedback = None + mock_gemini_client.aio.models.generate_content.return_value = mock_response + + # Call method with structured output and check exception + messages = [Message(role='user', content='Test message')] + with pytest.raises(Exception, match='Failed to parse structured response'): + await gemini_client.generate_response(messages, response_model=ResponseModel) + + @pytest.mark.asyncio + async def test_custom_max_tokens(self, gemini_client, mock_gemini_client): + """Test response generation with custom max tokens.""" + # Setup mock response + mock_response = MagicMock() + mock_response.text = 'Test response' + mock_response.candidates = [] + mock_response.prompt_feedback = None + mock_gemini_client.aio.models.generate_content.return_value = mock_response + + # Call method with custom max tokens + messages = [Message(role='user', content='Test message')] + await gemini_client.generate_response(messages, max_tokens=500) + + # Verify max tokens is passed in config + call_args = mock_gemini_client.aio.models.generate_content.call_args + config = call_args[1]['config'] + assert config.max_output_tokens == 500 + + @pytest.mark.asyncio + async def test_model_size_selection(self, gemini_client, mock_gemini_client): + """Test that the correct model is selected based on model size.""" + # Setup mock response + mock_response = MagicMock() + mock_response.text = 'Test response' + mock_response.candidates = [] + mock_response.prompt_feedback = None + mock_gemini_client.aio.models.generate_content.return_value = mock_response + + # Call method with small model size + messages = [Message(role='user', content='Test message')] + await 
gemini_client.generate_response(messages, model_size=ModelSize.small) + + # Verify correct model is used + call_args = mock_gemini_client.aio.models.generate_content.call_args + assert call_args[1]['model'] == DEFAULT_SMALL_MODEL + + +if __name__ == '__main__': + pytest.main(['-v', 'test_gemini_client.py'])