feat: enhance GeminiClient with max tokens management (#712)
* feat: enhance GeminiClient with max tokens management

  - Introduced a mapping of maximum output tokens for the various Gemini models.
  - Added methods to resolve max tokens based on precedence rules, allowing more flexible token management.
  - Updated tests to verify max tokens behavior, ensuring that an explicit parameter takes precedence and that the fallback mechanisms work correctly.

  This change improves the handling of token limits for different models, making the client more configurable and easier to use.

* refactor: streamline max tokens retrieval in GeminiClient

  - Removed the fallback to DEFAULT_MAX_TOKENS in favor of using the model-specific maximums directly.
  - Simplified the logic for determining max tokens, improving code clarity and maintainability.
parent e16740be9d
commit 4481702c9f
2 changed files with 139 additions and 11 deletions
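Before the diff, a minimal standalone sketch of the precedence the commit message describes. The names here are illustrative only; the committed logic lives in GeminiClient._resolve_max_tokens below.

# Hypothetical sketch of the precedence rules; not the committed code.
MODEL_MAX_TOKENS = {'gemini-2.5-flash': 65536, 'gemini-2.0-flash': 8192}
FALLBACK_MAX_TOKENS = 8192  # used when a model is missing from the mapping


def resolve_max_tokens(explicit: int | None, instance_default: int | None, model: str) -> int:
    # 1. An explicit per-call value wins.
    if explicit is not None:
        return explicit
    # 2. Then a value set when the client was constructed.
    if instance_default is not None:
        return instance_default
    # 3. Otherwise fall back to the model-specific limit, or a global default.
    return MODEL_MAX_TOKENS.get(model, FALLBACK_MAX_TOKENS)


assert resolve_max_tokens(500, 2000, 'gemini-2.5-flash') == 500
assert resolve_max_tokens(None, 2000, 'gemini-2.5-flash') == 2000
assert resolve_max_tokens(None, None, 'gemini-2.5-flash') == 65536
assert resolve_max_tokens(None, None, 'unknown-model') == 8192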
@@ -24,7 +24,7 @@ from pydantic import BaseModel
 
 from ..prompts.models import Message
 from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
-from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
+from .config import LLMConfig, ModelSize
 from .errors import RateLimitError
 
 if TYPE_CHECKING:
@@ -47,6 +47,25 @@ logger = logging.getLogger(__name__)
 DEFAULT_MODEL = 'gemini-2.5-flash'
 DEFAULT_SMALL_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
 
+# Maximum output tokens for different Gemini models
+GEMINI_MODEL_MAX_TOKENS = {
+    # Gemini 2.5 models
+    'gemini-2.5-pro': 65536,
+    'gemini-2.5-flash': 65536,
+    'gemini-2.5-flash-lite': 64000,
+    'models/gemini-2.5-flash-lite-preview-06-17': 64000,
+    # Gemini 2.0 models
+    'gemini-2.0-flash': 8192,
+    'gemini-2.0-flash-lite': 8192,
+    # Gemini 1.5 models
+    'gemini-1.5-pro': 8192,
+    'gemini-1.5-flash': 8192,
+    'gemini-1.5-flash-8b': 8192,
+}
+
+# Default max tokens for models not in the mapping
+DEFAULT_GEMINI_MAX_TOKENS = 8192
+
 
 class GeminiClient(LLMClient):
     """
@@ -75,7 +94,7 @@ class GeminiClient(LLMClient):
         self,
         config: LLMConfig | None = None,
         cache: bool = False,
-        max_tokens: int = DEFAULT_MAX_TOKENS,
+        max_tokens: int | None = None,
         thinking_config: types.ThinkingConfig | None = None,
         client: 'genai.Client | None' = None,
     ):
@@ -147,6 +166,38 @@ class GeminiClient(LLMClient):
         else:
             return self.model or DEFAULT_MODEL
 
+    def _get_max_tokens_for_model(self, model: str) -> int:
+        """Get the maximum output tokens for a specific Gemini model."""
+        return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)
+
+    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
+        """
+        Resolve the maximum output tokens to use based on precedence rules.
+
+        Precedence order (highest to lowest):
+        1. Explicit max_tokens parameter passed to generate_response()
+        2. Instance max_tokens set during client initialization
+        3. Model-specific maximum tokens from GEMINI_MODEL_MAX_TOKENS mapping
+        4. DEFAULT_MAX_TOKENS as final fallback
+
+        Args:
+            requested_max_tokens: The max_tokens parameter passed to generate_response()
+            model: The model name to look up model-specific limits
+
+        Returns:
+            int: The resolved maximum tokens to use
+        """
+        # 1. Use explicit parameter if provided
+        if requested_max_tokens is not None:
+            return requested_max_tokens
+
+        # 2. Use instance max_tokens if set during initialization
+        if self.max_tokens is not None:
+            return self.max_tokens
+
+        # 3. Use model-specific maximum or return DEFAULT_GEMINI_MAX_TOKENS
+        return self._get_max_tokens_for_model(model)
+
     def salvage_json(self, raw_output: str) -> dict[str, typing.Any] | None:
         """
         Attempt to salvage a JSON object if the raw output is truncated.
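Worth noting about the hunk above: _get_max_tokens_for_model is an exact-key lookup, so any name that is not literally a key in GEMINI_MODEL_MAX_TOKENS resolves to DEFAULT_GEMINI_MAX_TOKENS. A small self-contained check of that behavior, using an abbreviated copy of the mapping (illustrative only, not the committed module):

# Abbreviated copy of the mapping from the diff, for illustration only.
GEMINI_MODEL_MAX_TOKENS = {
    'gemini-2.5-flash': 65536,
    'models/gemini-2.5-flash-lite-preview-06-17': 64000,
}
DEFAULT_GEMINI_MAX_TOKENS = 8192


def get_max_tokens_for_model(model: str) -> int:
    # Exact string match; unknown names fall back to the default.
    return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)


assert get_max_tokens_for_model('gemini-2.5-flash') == 65536
assert get_max_tokens_for_model('models/gemini-2.5-flash-lite-preview-06-17') == 64000
# The un-prefixed preview name is not a key in the mapping, so it gets the default.
assert get_max_tokens_for_model('gemini-2.5-flash-lite-preview-06-17') == 8192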
@@ -184,7 +235,7 @@ class GeminiClient(LLMClient):
         self,
         messages: list[Message],
         response_model: type[BaseModel] | None = None,
-        max_tokens: int = DEFAULT_MAX_TOKENS,
+        max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
     ) -> dict[str, typing.Any]:
         """
@@ -193,7 +244,7 @@ class GeminiClient(LLMClient):
         Args:
             messages (list[Message]): A list of messages to send to the language model.
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
-            max_tokens (int): The maximum number of tokens to generate in the response.
+            max_tokens (int | None): The maximum number of tokens to generate in the response. If None, uses precedence rules.
             model_size (ModelSize): The size of the model to use (small or medium).
 
         Returns:
@@ -233,10 +284,13 @@ class GeminiClient(LLMClient):
         # Get the appropriate model for the requested size
         model = self._get_model_for_size(model_size)
 
+        # Resolve max_tokens using precedence rules (see _resolve_max_tokens for details)
+        resolved_max_tokens = self._resolve_max_tokens(max_tokens, model)
+
         # Create generation config
         generation_config = types.GenerateContentConfig(
             temperature=self.temperature,
-            max_output_tokens=max_tokens or self.max_tokens,
+            max_output_tokens=resolved_max_tokens,
             response_mime_type='application/json' if response_model else None,
             response_schema=response_model if response_model else None,
             system_instruction=system_prompt,
@@ -315,9 +369,6 @@ class GeminiClient(LLMClient):
         Returns:
             dict[str, typing.Any]: The response from the language model.
         """
-        if max_tokens is None:
-            max_tokens = self.max_tokens
-
         retry_count = 0
         last_error = None
         last_output = None
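Taken together, the client-side changes above give callers three ways to control the output budget, which the test-file diff below exercises. A hedged usage sketch (the absolute import paths are assumed from the package layout implied by the relative imports, and a real API key is needed to actually run it):

import asyncio

# Assumed module paths; adjust to the actual package layout.
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient
from graphiti_core.prompts.models import Message


async def main() -> None:
    messages = [Message(role='user', content='Say hello.')]

    # No max_tokens anywhere: the limit comes from the model mapping (65536 for gemini-2.5-flash).
    client = GeminiClient(config=LLMConfig(api_key='YOUR_KEY', model='gemini-2.5-flash'))
    await client.generate_response(messages)

    # Instance-level budget: applies to every call made through this client.
    capped = GeminiClient(config=LLMConfig(api_key='YOUR_KEY', model='gemini-2.5-flash'), max_tokens=2000)
    await capped.generate_response(messages)

    # Per-call override: takes precedence over both of the above.
    await capped.generate_response(messages, max_tokens=500)


if __name__ == '__main__':
    asyncio.run(main())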
@@ -369,7 +369,7 @@ class TestGeminiClientGenerateResponse:
 
     @pytest.mark.asyncio
     async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
-        """Test response generation with custom max tokens."""
+        """Test that explicit max_tokens parameter takes precedence over all other values."""
         # Setup mock response
         mock_response = MagicMock()
         mock_response.text = 'Test response'
@@ -377,15 +377,54 @@ class TestGeminiClientGenerateResponse:
         mock_response.prompt_feedback = None
         mock_gemini_client.aio.models.generate_content.return_value = mock_response
 
-        # Call method with custom max tokens
+        # Call method with custom max tokens (should take precedence)
         messages = [Message(role='user', content='Test message')]
         await gemini_client.generate_response(messages, max_tokens=500)
 
-        # Verify max tokens is passed in config
+        # Verify explicit max_tokens parameter takes precedence
         call_args = mock_gemini_client.aio.models.generate_content.call_args
         config = call_args[1]['config']
+        # Explicit parameter should override everything else
         assert config.max_output_tokens == 500
 
+    @pytest.mark.asyncio
+    async def test_max_tokens_precedence_fallback(self, mock_gemini_client):
+        """Test max_tokens precedence when no explicit parameter is provided."""
+        # Setup mock response
+        mock_response = MagicMock()
+        mock_response.text = 'Test response'
+        mock_response.candidates = []
+        mock_response.prompt_feedback = None
+        mock_gemini_client.aio.models.generate_content.return_value = mock_response
+
+        # Test case 1: No explicit max_tokens, has instance max_tokens
+        config = LLMConfig(
+            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
+        )
+        client = GeminiClient(
+            config=config, cache=False, max_tokens=2000, client=mock_gemini_client
+        )
+
+        messages = [Message(role='user', content='Test message')]
+        await client.generate_response(messages)
+
+        call_args = mock_gemini_client.aio.models.generate_content.call_args
+        config = call_args[1]['config']
+        # Instance max_tokens should be used
+        assert config.max_output_tokens == 2000
+
+        # Test case 2: No explicit max_tokens, no instance max_tokens, uses model mapping
+        config = LLMConfig(api_key='test_api_key', model='gemini-2.5-flash', temperature=0.5)
+        client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
+
+        messages = [Message(role='user', content='Test message')]
+        await client.generate_response(messages)
+
+        call_args = mock_gemini_client.aio.models.generate_content.call_args
+        config = call_args[1]['config']
+        # Model mapping should be used
+        assert config.max_output_tokens == 65536
+
     @pytest.mark.asyncio
     async def test_model_size_selection(self, gemini_client, mock_gemini_client):
         """Test that the correct model is selected based on model size."""
@@ -404,6 +443,44 @@ class TestGeminiClientGenerateResponse:
         call_args = mock_gemini_client.aio.models.generate_content.call_args
         assert call_args[1]['model'] == DEFAULT_SMALL_MODEL
 
+    @pytest.mark.asyncio
+    async def test_gemini_model_max_tokens_mapping(self, mock_gemini_client):
+        """Test that different Gemini models use their correct max tokens."""
+        # Setup mock response
+        mock_response = MagicMock()
+        mock_response.text = 'Test response'
+        mock_response.candidates = []
+        mock_response.prompt_feedback = None
+        mock_gemini_client.aio.models.generate_content.return_value = mock_response
+
+        # Test data: (model_name, expected_max_tokens)
+        test_cases = [
+            ('gemini-2.5-flash', 65536),
+            ('gemini-2.5-pro', 65536),
+            ('gemini-2.5-flash-lite', 64000),
+            ('models/gemini-2.5-flash-lite-preview-06-17', 64000),
+            ('gemini-2.0-flash', 8192),
+            ('gemini-1.5-pro', 8192),
+            ('gemini-1.5-flash', 8192),
+            ('unknown-model', 8192),  # Fallback case
+        ]
+
+        for model_name, expected_max_tokens in test_cases:
+            # Create client with specific model, no explicit max_tokens to test mapping
+            config = LLMConfig(api_key='test_api_key', model=model_name, temperature=0.5)
+            client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
+
+            # Call method without explicit max_tokens to test model mapping fallback
+            messages = [Message(role='user', content='Test message')]
+            await client.generate_response(messages)
+
+            # Verify correct max tokens is used from model mapping
+            call_args = mock_gemini_client.aio.models.generate_content.call_args
+            config = call_args[1]['config']
+            assert config.max_output_tokens == expected_max_tokens, (
+                f'Model {model_name} should use {expected_max_tokens} tokens'
+            )
+
 
 if __name__ == '__main__':
     pytest.main(['-v', 'test_gemini_client.py'])