feat: enhance GeminiClient with max tokens management (#712)

* feat: enhance GeminiClient with max tokens management

- Introduced a mapping for maximum output tokens for various Gemini models.
- Added methods to resolve max tokens based on precedence rules, allowing for more flexible token management.
- Updated tests to verify max tokens behavior, ensuring explicit parameters take precedence and fallback mechanisms work correctly.

This change improves how token limits are handled across models and makes the client's output limits configurable per model.

* refactor: streamline max tokens retrieval in GeminiClient

- Removed the fallback to DEFAULT_MAX_TOKENS in favor of directly using model-specific maximum tokens.
- Simplified the logic for determining max tokens, enhancing code clarity and maintainability.

This change simplifies token management within the GeminiClient.
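
A minimal, self-contained sketch of the precedence rules described above (names mirror the diff below; this is an illustration, not the shipped implementation):

# Illustration only: mirrors the precedence implemented by _resolve_max_tokens in the diff below.
GEMINI_MODEL_MAX_TOKENS = {'gemini-2.5-flash': 65536, 'gemini-2.5-flash-lite': 64000}
DEFAULT_GEMINI_MAX_TOKENS = 8192

def resolve_max_tokens(requested: int | None, instance_max: int | None, model: str) -> int:
    if requested is not None:      # 1. explicit per-call parameter wins
        return requested
    if instance_max is not None:   # 2. then the value set at client construction
        return instance_max
    # 3./4. then the model-specific limit, else the library-wide default
    return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)

assert resolve_max_tokens(500, 2000, 'gemini-2.5-flash') == 500     # explicit parameter
assert resolve_max_tokens(None, 2000, 'gemini-2.5-flash') == 2000   # instance default
assert resolve_max_tokens(None, None, 'gemini-2.5-flash') == 65536  # model mapping
assert resolve_max_tokens(None, None, 'unknown-model') == 8192      # final fallback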
Daniel Chalef 2025-07-13 14:37:55 -07:00 committed by GitHub
parent e16740be9d
commit 4481702c9f
2 changed files with 139 additions and 11 deletions

@@ -24,7 +24,7 @@ from pydantic import BaseModel
 from ..prompts.models import Message
 from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
-from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
+from .config import LLMConfig, ModelSize
 from .errors import RateLimitError
 
 if TYPE_CHECKING:
@@ -47,6 +47,25 @@ logger = logging.getLogger(__name__)
 DEFAULT_MODEL = 'gemini-2.5-flash'
 DEFAULT_SMALL_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
 
+# Maximum output tokens for different Gemini models
+GEMINI_MODEL_MAX_TOKENS = {
+    # Gemini 2.5 models
+    'gemini-2.5-pro': 65536,
+    'gemini-2.5-flash': 65536,
+    'gemini-2.5-flash-lite': 64000,
+    'models/gemini-2.5-flash-lite-preview-06-17': 64000,
+    # Gemini 2.0 models
+    'gemini-2.0-flash': 8192,
+    'gemini-2.0-flash-lite': 8192,
+    # Gemini 1.5 models
+    'gemini-1.5-pro': 8192,
+    'gemini-1.5-flash': 8192,
+    'gemini-1.5-flash-8b': 8192,
+}
+
+# Default max tokens for models not in the mapping
+DEFAULT_GEMINI_MAX_TOKENS = 8192
+
 
 class GeminiClient(LLMClient):
     """
@@ -75,7 +94,7 @@
         self,
         config: LLMConfig | None = None,
         cache: bool = False,
-        max_tokens: int = DEFAULT_MAX_TOKENS,
+        max_tokens: int | None = None,
         thinking_config: types.ThinkingConfig | None = None,
         client: 'genai.Client | None' = None,
     ):
@@ -147,6 +166,38 @@
         else:
             return self.model or DEFAULT_MODEL
 
+    def _get_max_tokens_for_model(self, model: str) -> int:
+        """Get the maximum output tokens for a specific Gemini model."""
+        return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)
+
+    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
+        """
+        Resolve the maximum output tokens to use based on precedence rules.
+
+        Precedence order (highest to lowest):
+        1. Explicit max_tokens parameter passed to generate_response()
+        2. Instance max_tokens set during client initialization
+        3. Model-specific maximum tokens from the GEMINI_MODEL_MAX_TOKENS mapping
+        4. DEFAULT_GEMINI_MAX_TOKENS as the final fallback
+
+        Args:
+            requested_max_tokens: The max_tokens parameter passed to generate_response()
+            model: The model name to look up model-specific limits
+
+        Returns:
+            int: The resolved maximum tokens to use
+        """
+        # 1. Use explicit parameter if provided
+        if requested_max_tokens is not None:
+            return requested_max_tokens
+
+        # 2. Use instance max_tokens if set during initialization
+        if self.max_tokens is not None:
+            return self.max_tokens
+
+        # 3. Use the model-specific maximum, falling back to DEFAULT_GEMINI_MAX_TOKENS
+        return self._get_max_tokens_for_model(model)
+
     def salvage_json(self, raw_output: str) -> dict[str, typing.Any] | None:
         """
         Attempt to salvage a JSON object if the raw output is truncated.
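
For context on the salvage step mentioned above, here is a hedged sketch of one way truncated JSON can be salvaged (retrying the parse on progressively shorter prefixes that end at a closing brace); this is illustrative only and not necessarily how GeminiClient.salvage_json is implemented:

import json

def salvage_truncated_json(raw: str) -> dict | None:
    # Retry json.loads on progressively shorter prefixes ending at a '}'.
    end = raw.rfind('}')
    while end != -1:
        try:
            parsed = json.loads(raw[: end + 1])
            return parsed if isinstance(parsed, dict) else None
        except json.JSONDecodeError:
            end = raw.rfind('}', 0, end)
    return None

# Trailing junk after a complete object is dropped.
assert salvage_truncated_json('{"name": "alice"} extra tokens') == {'name': 'alice'}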
@@ -184,7 +235,7 @@
         self,
         messages: list[Message],
         response_model: type[BaseModel] | None = None,
-        max_tokens: int = DEFAULT_MAX_TOKENS,
+        max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
     ) -> dict[str, typing.Any]:
         """
@@ -193,7 +244,7 @@
         Args:
             messages (list[Message]): A list of messages to send to the language model.
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
-            max_tokens (int): The maximum number of tokens to generate in the response.
+            max_tokens (int | None): The maximum number of tokens to generate in the response. If None, the limit is resolved via the precedence rules in _resolve_max_tokens().
             model_size (ModelSize): The size of the model to use (small or medium).
 
         Returns:
@@ -233,10 +284,13 @@
         # Get the appropriate model for the requested size
         model = self._get_model_for_size(model_size)
 
+        # Resolve max_tokens using precedence rules (see _resolve_max_tokens for details)
+        resolved_max_tokens = self._resolve_max_tokens(max_tokens, model)
+
         # Create generation config
         generation_config = types.GenerateContentConfig(
             temperature=self.temperature,
-            max_output_tokens=max_tokens or self.max_tokens,
+            max_output_tokens=resolved_max_tokens,
             response_mime_type='application/json' if response_model else None,
             response_schema=response_model if response_model else None,
             system_instruction=system_prompt,
@@ -315,9 +369,6 @@
         Returns:
             dict[str, typing.Any]: The response from the language model.
         """
-        if max_tokens is None:
-            max_tokens = self.max_tokens
-
         retry_count = 0
         last_error = None
         last_output = None
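
Assuming the repository's existing module layout (the import paths below are an assumption, not taken from this diff), a caller can now omit max_tokens entirely and rely on the model-specific limits:

# Hypothetical caller-side sketch; constructor arguments match the diff above,
# but the import paths are assumed.
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient

config = LLMConfig(api_key='your-api-key', model='gemini-2.5-flash')

client = GeminiClient(config=config)                    # no max_tokens: model mapping gives 65536
capped = GeminiClient(config=config, max_tokens=2000)   # instance default gives 2000
# An explicit generate_response(..., max_tokens=500) call still overrides both.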

@@ -369,7 +369,7 @@ class TestGeminiClientGenerateResponse:
     @pytest.mark.asyncio
     async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
-        """Test response generation with custom max tokens."""
+        """Test that explicit max_tokens parameter takes precedence over all other values."""
         # Setup mock response
         mock_response = MagicMock()
         mock_response.text = 'Test response'
@@ -377,15 +377,54 @@ class TestGeminiClientGenerateResponse:
         mock_response.prompt_feedback = None
         mock_gemini_client.aio.models.generate_content.return_value = mock_response
 
-        # Call method with custom max tokens
+        # Call method with custom max tokens (should take precedence)
         messages = [Message(role='user', content='Test message')]
         await gemini_client.generate_response(messages, max_tokens=500)
 
-        # Verify max tokens is passed in config
+        # Verify explicit max_tokens parameter takes precedence
         call_args = mock_gemini_client.aio.models.generate_content.call_args
         config = call_args[1]['config']
+        # Explicit parameter should override everything else
         assert config.max_output_tokens == 500
 
+    @pytest.mark.asyncio
+    async def test_max_tokens_precedence_fallback(self, mock_gemini_client):
+        """Test max_tokens precedence when no explicit parameter is provided."""
+        # Setup mock response
+        mock_response = MagicMock()
+        mock_response.text = 'Test response'
+        mock_response.candidates = []
+        mock_response.prompt_feedback = None
+        mock_gemini_client.aio.models.generate_content.return_value = mock_response
+
+        # Test case 1: No explicit max_tokens, has instance max_tokens
+        config = LLMConfig(
+            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
+        )
+        client = GeminiClient(
+            config=config, cache=False, max_tokens=2000, client=mock_gemini_client
+        )
+
+        messages = [Message(role='user', content='Test message')]
+        await client.generate_response(messages)
+
+        call_args = mock_gemini_client.aio.models.generate_content.call_args
+        config = call_args[1]['config']
+        # Instance max_tokens should be used
+        assert config.max_output_tokens == 2000
+
+        # Test case 2: No explicit max_tokens, no instance max_tokens, uses model mapping
+        config = LLMConfig(api_key='test_api_key', model='gemini-2.5-flash', temperature=0.5)
+        client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
+
+        messages = [Message(role='user', content='Test message')]
+        await client.generate_response(messages)
+
+        call_args = mock_gemini_client.aio.models.generate_content.call_args
+        config = call_args[1]['config']
+        # Model mapping should be used
+        assert config.max_output_tokens == 65536
+
     @pytest.mark.asyncio
     async def test_model_size_selection(self, gemini_client, mock_gemini_client):
         """Test that the correct model is selected based on model size."""
@@ -404,6 +443,44 @@
         call_args = mock_gemini_client.aio.models.generate_content.call_args
         assert call_args[1]['model'] == DEFAULT_SMALL_MODEL
 
+    @pytest.mark.asyncio
+    async def test_gemini_model_max_tokens_mapping(self, mock_gemini_client):
+        """Test that different Gemini models use their correct max tokens."""
+        # Setup mock response
+        mock_response = MagicMock()
+        mock_response.text = 'Test response'
+        mock_response.candidates = []
+        mock_response.prompt_feedback = None
+        mock_gemini_client.aio.models.generate_content.return_value = mock_response
+
+        # Test data: (model_name, expected_max_tokens)
+        test_cases = [
+            ('gemini-2.5-flash', 65536),
+            ('gemini-2.5-pro', 65536),
+            ('gemini-2.5-flash-lite', 64000),
+            ('models/gemini-2.5-flash-lite-preview-06-17', 64000),
+            ('gemini-2.0-flash', 8192),
+            ('gemini-1.5-pro', 8192),
+            ('gemini-1.5-flash', 8192),
+            ('unknown-model', 8192),  # Fallback case
+        ]
+
+        for model_name, expected_max_tokens in test_cases:
+            # Create client with specific model, no explicit max_tokens to test mapping
+            config = LLMConfig(api_key='test_api_key', model=model_name, temperature=0.5)
+            client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
+
+            # Call method without explicit max_tokens to test model mapping fallback
+            messages = [Message(role='user', content='Test message')]
+            await client.generate_response(messages)
+
+            # Verify correct max tokens is used from model mapping
+            call_args = mock_gemini_client.aio.models.generate_content.call_args
+            config = call_args[1]['config']
+            assert config.max_output_tokens == expected_max_tokens, (
+                f'Model {model_name} should use {expected_max_tokens} tokens'
+            )
+
 
 if __name__ == '__main__':
     pytest.main(['-v', 'test_gemini_client.py'])