* feat(gemini): embedding batch size & lite default The new `gemini-embedding-001` model only allows one embedding input per batch (instance), but has other impressive statistics: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api The -DEFAULT_SMALL_MODEL must not have the 'models/' prefix. * Refactor: Improve Gemini Client Error Handling and Reliability This commit introduces several improvements to the Gemini client to enhance its robustness and reliability. - Implemented more specific error handling for various Gemini API responses, including rate limits and safety blocks. - Added a JSON salvaging mechanism to gracefully handle incomplete or malformed JSON responses from the API. - Introduced detailed logging for failed LLM generations to simplify debugging and troubleshooting. - Refined the Gemini embedder to better handle empty or invalid embedding responses. - Updated and corrected tests to align with the improved error handling and reliability features. * fix: cleanup in _log_failed_generation() * fix: cleanup in _log_failed_generation() * Fix ruff B904 error in gemini_client.py * fix(gemini): correct retry logic and enhance error logging Updated the retry mechanism in the GeminiClient to ensure it retries the maximum number of times specified. Improved error logging to provide clearer insights when all retries are exhausted, including detailed information about the last error encountered. * fix(gemini): enhance error handling for safety blocks and update tests Refined error handling in the GeminiClient to improve detection of safety block conditions. Updated test cases to reflect changes in exception messages and ensure proper retry logic is enforced. Enhanced mock responses in tests to better simulate real-world scenarios, including handling of invalid JSON responses. * revert default gemini to text-embedding-001 --------- Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
395 lines
15 KiB
Python
395 lines
15 KiB
Python
"""
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
# Running tests: pytest -xvs tests/embedder/test_gemini.py
|
|
|
|
from collections.abc import Generator
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from embedder_fixtures import create_embedding_values
|
|
|
|
from graphiti_core.embedder.gemini import (
|
|
DEFAULT_EMBEDDING_MODEL,
|
|
GeminiEmbedder,
|
|
GeminiEmbedderConfig,
|
|
)
|
|
|
|
|
|
def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock:
    """Build a MagicMock standing in for one Gemini embedding object.

    Args:
        multiplier: Scale factor forwarded to ``create_embedding_values``.
        dimension: Vector length forwarded to ``create_embedding_values``.

    Returns:
        A MagicMock whose ``values`` attribute holds the generated vector.
    """
    # Mock(**kwargs) configures attributes, equivalent to assigning ``.values`` afterwards.
    return MagicMock(values=create_embedding_values(multiplier, dimension))
|
|
|
|
|
|
@pytest.fixture
def mock_gemini_response() -> MagicMock:
    """Fixture: a fake ``embed_content`` result carrying exactly one embedding."""
    single = create_gemini_embedding()
    return MagicMock(embeddings=[single])
|
|
|
|
|
|
@pytest.fixture
def mock_gemini_batch_response() -> MagicMock:
    """Fixture: a fake ``embed_content`` result carrying three embeddings.

    The embeddings are built with multipliers 0.1, 0.2 and 0.3 (in that order) so
    that each vector in the batch is generated from a different scale.
    """
    response = MagicMock()
    response.embeddings = [create_gemini_embedding(scale) for scale in (0.1, 0.2, 0.3)]
    return response
|
|
|
|
|
|
@pytest.fixture
def mock_gemini_client() -> Generator[Any, Any, None]:
    """Fixture: patch ``google.genai.Client`` and yield its mocked instance.

    The patch remains active while the requesting test runs. The instance exposes
    ``aio.models.embed_content`` as an ``AsyncMock`` so it can be awaited; tests
    set its ``return_value`` or ``side_effect`` themselves.
    """
    with patch('google.genai.Client') as client_cls:
        instance = client_cls.return_value

        models = MagicMock()
        models.embed_content = AsyncMock()

        instance.aio = MagicMock()
        instance.aio.models = models

        yield instance
|
|
|
|
|
|
@pytest.fixture
def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
    """Fixture: a ``GeminiEmbedder`` whose ``client`` is replaced by the mocked client."""
    embedder = GeminiEmbedder(config=GeminiEmbedderConfig(api_key='test_api_key'))
    # Swap in the mock so every API call goes through the AsyncMock.
    embedder.client = mock_gemini_client
    return embedder
|
|
|
|
|
|
class TestGeminiEmbedderInitialization:
    """Construction-time behaviour of ``GeminiEmbedder``."""

    @patch('google.genai.Client')
    def test_init_with_config(self, mock_client):
        """An explicit config object is kept as-is with every field preserved."""
        cfg = GeminiEmbedderConfig(
            api_key='test_api_key', embedding_model='custom-model', embedding_dim=768
        )

        stored = GeminiEmbedder(config=cfg).config

        assert stored == cfg
        assert stored.embedding_model == 'custom-model'
        assert stored.api_key == 'test_api_key'
        assert stored.embedding_dim == 768

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Omitting the config produces a config that uses the default model."""
        stored = GeminiEmbedder().config

        assert stored is not None
        assert stored.embedding_model == DEFAULT_EMBEDDING_MODEL

    @patch('google.genai.Client')
    def test_init_with_partial_config(self, mock_client):
        """Fields left unset in a partial config fall back to their defaults."""
        stored = GeminiEmbedder(config=GeminiEmbedderConfig(api_key='test_api_key')).config

        assert stored.api_key == 'test_api_key'
        assert stored.embedding_model == DEFAULT_EMBEDDING_MODEL
|
|
|
|
|
|
class TestGeminiEmbedderCreate:
    """Behaviour of ``GeminiEmbedder.create`` (single-input embedding)."""

    @pytest.mark.asyncio
    async def test_create_calls_api_correctly(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """A string input is sent as a one-element ``contents`` list and the first
        embedding's values are returned."""
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = mock_gemini_response

        result = await gemini_embedder.create('Test input')

        # Exactly one request, using the default model and the wrapped input.
        embed_content.assert_called_once()
        call_kwargs = embed_content.call_args.kwargs
        assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        assert call_kwargs['contents'] == ['Test input']

        # The vector handed back is the first embedding's values.
        assert result == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_model(
        self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock
    ) -> None:
        """A custom ``embedding_model`` in the config is forwarded as ``model``."""
        embedder = GeminiEmbedder(
            config=GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model')
        )
        embedder.client = mock_gemini_client
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = mock_gemini_response

        await embedder.create('Test input')

        assert embed_content.call_args.kwargs['model'] == 'custom-model'

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """A custom ``embedding_dim`` is sent as ``output_dimensionality``."""
        embedder = GeminiEmbedder(
            config=GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768)
        )
        embedder.client = mock_gemini_client

        # Canned response whose single embedding has 768 values.
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = MagicMock(embeddings=[create_gemini_embedding(0.1, 768)])

        result = await embedder.create('Test input')

        request_config = embed_content.call_args.kwargs['config']
        assert request_config.output_dimensionality == 768
        assert len(result) == 768

    @pytest.mark.asyncio
    async def test_create_with_different_input_types(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """``create`` accepts a string, a list of strings, and an iterable of ints."""
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = mock_gemini_response

        # Plain string, list of strings, iterable of integers - one call each.
        for payload in ('Test string', ['Test', 'List'], [1, 2, 3]):
            await gemini_embedder.create(payload)

        assert embed_content.call_count == 3

    @pytest.mark.asyncio
    async def test_create_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """A response with an empty ``embeddings`` list raises ``ValueError``."""
        mock_gemini_client.aio.models.embed_content.return_value = MagicMock(embeddings=[])

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        message = str(exc_info.value)
        assert 'No embeddings returned from Gemini API in create()' in message

    @pytest.mark.asyncio
    async def test_create_no_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An embedding whose ``values`` is ``None`` raises ``ValueError``."""
        valueless = MagicMock(values=None)
        mock_gemini_client.aio.models.embed_content.return_value = MagicMock(
            embeddings=[valueless]
        )

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        message = str(exc_info.value)
        assert 'No embeddings returned from Gemini API in create()' in message
|
|
|
|
|
|
class TestGeminiEmbedderCreateBatch:
    """Tests for GeminiEmbedder create_batch method."""

    @pytest.mark.asyncio
    async def test_create_batch_processes_multiple_inputs(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_batch_response: MagicMock,
    ) -> None:
        """A multi-input batch is sent in one API call and every embedding is
        returned in input order."""
        # Setup
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response
        input_batch = ['Input 1', 'Input 2', 'Input 3']

        # Call method
        result = await gemini_embedder.create_batch(input_batch)

        # Verify API is called once with the whole batch and the default model
        mock_gemini_client.aio.models.embed_content.assert_called_once()
        _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        assert kwargs['contents'] == input_batch

        # Verify all results are processed correctly and in order
        assert len(result) == 3
        assert result == [
            embedding.values for embedding in mock_gemini_batch_response.embeddings
        ]

    @pytest.mark.asyncio
    async def test_create_batch_single_input(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """A one-element batch yields a one-element result list."""
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
        input_batch = ['Single input']

        result = await gemini_embedder.create_batch(input_batch)

        assert len(result) == 1
        assert result[0] == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    async def test_create_batch_empty_input(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An empty batch returns ``[]`` without calling the API at all."""
        # Deliberately no return value is configured: the API must not be reached,
        # so any canned response here would never be consumed.
        input_batch: list[str] = []

        result = await gemini_embedder.create_batch(input_batch)

        assert result == []
        mock_gemini_client.aio.models.embed_content.assert_not_called()

    @pytest.mark.asyncio
    async def test_create_batch_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """A response with no embeddings makes create_batch raise ``ValueError``."""
        # Setup mock response with no embeddings. ``return_value`` (not
        # ``side_effect``) means every call, including any per-item retry the
        # embedder may perform, receives this same empty response.
        mock_response = MagicMock()
        mock_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = mock_response

        input_batch = ['Input 1', 'Input 2']

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(input_batch)

        assert 'No embeddings returned from Gemini API' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_batch_empty_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An embedding with ``values=None`` makes create_batch raise ``ValueError``."""
        # Setup mock response with embeddings but empty values
        mock_embedding1 = MagicMock()
        mock_embedding1.values = [0.1, 0.2, 0.3]  # Valid values

        mock_embedding2 = MagicMock()
        mock_embedding2.values = None  # Empty values

        # Mock response for the initial batch call
        mock_batch_response = MagicMock()
        mock_batch_response.embeddings = [mock_embedding1, mock_embedding2]

        # Mock response for individual processing of 'Input 1'
        mock_individual_response_1 = MagicMock()
        mock_individual_response_1.embeddings = [mock_embedding1]

        # Mock response for individual processing of 'Input 2' (which has empty values)
        mock_individual_response_2 = MagicMock()
        mock_individual_response_2.embeddings = [mock_embedding2]

        # Set side_effect for embed_content to control return values for each call.
        # This sequence assumes create_batch falls back to one call per input after
        # a batch response containing empty values.
        mock_gemini_client.aio.models.embed_content.side_effect = [
            mock_batch_response,  # First call for the batch
            mock_individual_response_1,  # Second call for individual item 1
            mock_individual_response_2,  # Third call for individual item 2
        ]

        input_batch = ['Input 1', 'Input 2']

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(input_batch)

        assert 'Empty embedding values returned' in str(exc_info.value)

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_batch_with_custom_model_and_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """Custom model and dimension in the config are forwarded to the batch call."""
        # Setup embedder with custom settings
        config = GeminiEmbedderConfig(
            api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512
        )
        embedder = GeminiEmbedder(config=config)
        embedder.client = mock_gemini_client

        # Setup mock response
        mock_response = MagicMock()
        mock_response.embeddings = [
            create_gemini_embedding(0.1, 512),
            create_gemini_embedding(0.2, 512),
        ]
        mock_gemini_client.aio.models.embed_content.return_value = mock_response

        input_batch = ['Input 1', 'Input 2']
        result = await embedder.create_batch(input_batch)

        # Verify custom settings are used
        _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert kwargs['model'] == 'custom-batch-model'
        assert kwargs['config'].output_dimensionality == 512

        # Verify results have correct dimension
        assert len(result) == 2
        assert all(len(embedding) == 512 for embedding in result)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run this module's tests directly. pytest.main() returns the exit code instead
    # of exiting, so propagate it; otherwise a failing run would still exit 0.
    raise SystemExit(pytest.main(['-xvs', __file__]))
|