add support for Gemini 2.5 model thinking budget
parent 19772aa5a1
commit 756734be01
1 changed file with 27 additions and 3 deletions
@@ -30,6 +30,14 @@ from .errors import RateLimitError
 logger = logging.getLogger(__name__)
 
 DEFAULT_MODEL = 'gemini-2.0-flash'
+DEFAULT_THINKING_BUDGET = 0
+
+# Gemini models that support thinking capabilities
+GEMINI_THINKING_MODELS = [
+    'gemini-2.5-pro',
+    'gemini-2.5-flash',
+    'gemini-2.5-flash-lite',
+]
 
 
 class GeminiClient(LLMClient):
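
The gate added here is an exact-string membership test against this list. As a minimal, self-contained sketch of that behavior (the supports_thinking helper is hypothetical, not part of the diff):

DEFAULT_THINKING_BUDGET = 0

GEMINI_THINKING_MODELS = [
    'gemini-2.5-pro',
    'gemini-2.5-flash',
    'gemini-2.5-flash-lite',
]

def supports_thinking(model: str) -> bool:
    # Exact match only; the 2.0 default model falls through to False.
    return model in GEMINI_THINKING_MODELS

print(supports_thinking('gemini-2.5-flash'))  # True
print(supports_thinking('gemini-2.0-flash'))  # False
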
@@ -43,7 +51,7 @@ class GeminiClient(LLMClient):
         model (str): The model name to use for generating responses.
         temperature (float): The temperature to use for generating responses.
         max_tokens (int): The maximum number of tokens to generate in a response.
-
+        thinking_budget (int): The maximum number of tokens to spend on thinking for 2.5 version models. 0 disables thinking.
     Methods:
         __init__(config: LLMConfig | None = None, cache: bool = False):
             Initializes the GeminiClient with the provided configuration and cache setting.
@@ -57,6 +65,7 @@ class GeminiClient(LLMClient):
         config: LLMConfig | None = None,
         cache: bool = False,
         max_tokens: int = DEFAULT_MAX_TOKENS,
+        thinking_budget: int = DEFAULT_THINKING_BUDGET,
     ):
         """
         Initialize the GeminiClient with the provided configuration and cache setting.
@@ -64,6 +73,8 @@ class GeminiClient(LLMClient):
         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
             cache (bool): Whether to use caching for responses. Defaults to False.
+            thinking_budget (int): The maximum number of tokens to spend on thinking for 2.5 version models. 0 disables thinking.
+
         """
         if config is None:
             config = LLMConfig()
@@ -76,6 +87,7 @@ class GeminiClient(LLMClient):
             api_key=config.api_key,
         )
         self.max_tokens = max_tokens
+        self.thinking_budget = thinking_budget
 
     async def _generate_response(
         self,
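
With the budget stored on the client, enabling thinking becomes one extra keyword argument at construction time. A usage sketch, assuming LLMConfig accepts the api_key, model, and temperature fields its docstring above describes (all values are placeholders):

config = LLMConfig(
    api_key='YOUR_API_KEY',    # placeholder
    model='gemini-2.5-flash',  # one of GEMINI_THINKING_MODELS
    temperature=0.5,
)

# thinking_budget defaults to DEFAULT_THINKING_BUDGET (0), i.e. thinking off.
client = GeminiClient(config=config, thinking_budget=1024)
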
@@ -127,6 +139,17 @@ class GeminiClient(LLMClient):
                 types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
             )
 
+        # Determine the model to be used
+        model_to_use = self.model or DEFAULT_MODEL
+
+        # Conditionally create thinking_config for models that support thinking
+        thinking_config_arg = None
+        if model_to_use in GEMINI_THINKING_MODELS:
+            thinking_config_arg = types.ThinkingConfig(
+                include_thoughts=False,
+                thinking_budget=self.thinking_budget,
+            )
+
         # Create generation config
         generation_config = types.GenerateContentConfig(
             temperature=self.temperature,
@@ -134,12 +157,13 @@ class GeminiClient(LLMClient):
             response_mime_type='application/json' if response_model else None,
             response_schema=response_model if response_model else None,
             system_instruction=system_prompt,
+            thinking_config=thinking_config_arg,
         )
 
         # Generate content using the simple string approach
         response = await self.client.aio.models.generate_content(
-            model=self.model or DEFAULT_MODEL,
-            contents=gemini_messages,  # type: ignore[arg-type]  # mypy fails on broad union type
+            model=model_to_use,
+            contents=gemini_messages,
             config=generation_config,
         )
 
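
For comparison, the same thinking controls expressed as a direct google-genai call; this is a sketch against the SDK's synchronous client (the diff itself goes through the async client.aio path), with model name, prompt, and budget as placeholders:

from google import genai
from google.genai import types

client = genai.Client(api_key='YOUR_API_KEY')  # placeholder key

response = client.models.generate_content(
    model='gemini-2.5-flash',
    contents='Summarize the plot of Hamlet in two sentences.',
    config=types.GenerateContentConfig(
        temperature=0.5,
        thinking_config=types.ThinkingConfig(
            include_thoughts=False,  # mirror the diff: spend budget, omit thought parts
            thinking_budget=256,     # 0 disables thinking where the model allows it
        ),
    ),
)
print(response.text)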