Replace hardcoded DEFAULT_MODEL and DEFAULT_SMALL_MODEL constants across all LLM clients with a centralized, configurable provider defaults system.
Key changes:
- Created provider_defaults.py with centralized configuration for all providers
- Added environment variable support for easy customization (e.g., GEMINI_DEFAULT_MODEL)
- Updated all LLM clients to use configurable defaults instead of hardcoded constants
- Made the max_tokens limit for edge-extraction operations configurable via EXTRACT_EDGES_MAX_TOKENS
- Updated cross-encoder reranker clients to use provider defaults
- Maintained full backward compatibility with existing configurations
This resolves an issue where Gemini's flash-lite model is subject to different Vertex AI location constraints than the regular flash model, and users could not override the default model without editing source code.
Environment variables now supported:
- {PROVIDER}_DEFAULT_MODEL
- {PROVIDER}_DEFAULT_SMALL_MODEL
- {PROVIDER}_DEFAULT_MAX_TOKENS
- {PROVIDER}_DEFAULT_TEMPERATURE
- {PROVIDER}_EXTRACT_EDGES_MAX_TOKENS
- EXTRACT_EDGES_MAX_TOKENS (global fallback)
Fixes #681
Co-authored-by: Daniel Chalef <danielchalef@users.noreply.github.com>
162 lines · 5.5 KiB · Python (no newline at end of file)
"""
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import os
|
|
from dataclasses import dataclass
|
|
from typing import Dict, Optional
|
|
|
|
|
|
@dataclass()
class ProviderDefaults:
    """
    Default model settings for a single LLM provider.

    Instances stand in for the per-client hardcoded DEFAULT_MODEL constants.
    The built-in values can be overridden through environment variables when
    the defaults for a provider are looked up.
    """

    # Model name returned for regular (non-small) requests.
    model: str
    # Model name returned when the caller asks for the 'small' model size.
    small_model: str
    # Default max_tokens value.
    max_tokens: int = 8_192
    # Default sampling temperature.
    temperature: float = 0.0
    # max_tokens value used for edge-extraction operations.
    extract_edges_max_tokens: int = 16_384
|
|
|
|
|
|
# Provider-specific default configurations.
# These can be overridden via environment variables (see get_provider_defaults).
PROVIDER_DEFAULTS: Dict[str, ProviderDefaults] = {
    'openai': ProviderDefaults(
        model='gpt-4.1-mini',
        small_model='gpt-4.1-nano',
        extract_edges_max_tokens=16384,
    ),
    'gemini': ProviderDefaults(
        model='gemini-2.5-flash',
        # flash-lite has different Vertex AI location constraints than the
        # regular flash model; override with GEMINI_DEFAULT_SMALL_MODEL if the
        # deployment region does not offer it.
        small_model='models/gemini-2.5-flash-lite-preview-06-17',
        extract_edges_max_tokens=16384,
    ),
    'anthropic': ProviderDefaults(
        model='claude-3-7-sonnet-latest',
        # Anthropic does not publish a "claude-3-7-haiku" model; the previous
        # value here was not a valid model ID. 3.5 Haiku is the small-model alias.
        small_model='claude-3-5-haiku-latest',
        extract_edges_max_tokens=16384,
    ),
    'groq': ProviderDefaults(
        # NOTE(review): Groq has deprecated llama-3.1-70b-versatile in favour of
        # llama-3.3-70b-versatile — confirm this default is still served.
        model='llama-3.1-70b-versatile',
        small_model='llama-3.1-8b-instant',
        extract_edges_max_tokens=16384,
    ),
    'azure_openai': ProviderDefaults(
        model='gpt-4.1-mini',
        small_model='gpt-4.1-nano',
        extract_edges_max_tokens=16384,
    ),
}
|
|
|
|
|
|
def _env_str(name: str, default: str) -> str:
    """Return env var *name* stripped of surrounding whitespace, or *default* if unset or blank."""
    value = os.getenv(name, '').strip()
    return value or default


def _env_int(name: str, default: int) -> int:
    """
    Parse env var *name* as an int, returning *default* if it is unset or blank.

    Raises:
        ValueError: If the variable is set but is not a valid integer. The
            message names the variable so the misconfiguration is easy to find.
    """
    value = os.getenv(name, '').strip()
    if not value:
        return default
    try:
        return int(value)
    except ValueError as err:
        raise ValueError(f'Environment variable {name} must be an integer, got {value!r}') from err


def _env_float(name: str, default: float) -> float:
    """
    Parse env var *name* as a float, returning *default* if it is unset or blank.

    Raises:
        ValueError: If the variable is set but is not a valid number.
    """
    value = os.getenv(name, '').strip()
    if not value:
        return default
    try:
        return float(value)
    except ValueError as err:
        raise ValueError(f'Environment variable {name} must be a number, got {value!r}') from err


def get_provider_defaults(provider: str) -> ProviderDefaults:
    """
    Get provider-specific defaults with optional environment variable overrides.

    Environment variables can override defaults using the pattern below, where
    {PROVIDER} is the provider name upper-cased (e.g. AZURE_OPENAI):
    - {PROVIDER}_DEFAULT_MODEL
    - {PROVIDER}_DEFAULT_SMALL_MODEL
    - {PROVIDER}_DEFAULT_MAX_TOKENS
    - {PROVIDER}_DEFAULT_TEMPERATURE
    - {PROVIDER}_EXTRACT_EDGES_MAX_TOKENS, falling back to the global
      EXTRACT_EDGES_MAX_TOKENS when the provider-specific variable is not set

    Variables that are unset or contain only whitespace are ignored, so an empty
    assignment such as ``GEMINI_DEFAULT_MODEL=`` keeps the built-in default
    instead of producing an empty model name.

    Args:
        provider: The provider name (e.g., 'openai', 'gemini', 'anthropic', etc.)

    Returns:
        A new ProviderDefaults object with defaults for the specified provider;
        the shared PROVIDER_DEFAULTS entry is not modified.

    Raises:
        ValueError: If the provider is not supported, or if a numeric
            environment override cannot be parsed.
    """
    if provider not in PROVIDER_DEFAULTS:
        raise ValueError(f"Unsupported provider: {provider}. Supported providers: {list(PROVIDER_DEFAULTS.keys())}")

    defaults = PROVIDER_DEFAULTS[provider]
    env_prefix = provider.upper()

    # Edge-extraction limit precedence: provider-specific variable, then the
    # global EXTRACT_EDGES_MAX_TOKENS fallback, then the built-in default.
    extract_edges_max_tokens = _env_int(
        f'{env_prefix}_EXTRACT_EDGES_MAX_TOKENS',
        _env_int('EXTRACT_EDGES_MAX_TOKENS', defaults.extract_edges_max_tokens),
    )

    return ProviderDefaults(
        model=_env_str(f'{env_prefix}_DEFAULT_MODEL', defaults.model),
        small_model=_env_str(f'{env_prefix}_DEFAULT_SMALL_MODEL', defaults.small_model),
        max_tokens=_env_int(f'{env_prefix}_DEFAULT_MAX_TOKENS', defaults.max_tokens),
        temperature=_env_float(f'{env_prefix}_DEFAULT_TEMPERATURE', defaults.temperature),
        extract_edges_max_tokens=extract_edges_max_tokens,
    )
|
|
|
|
|
|
def get_model_for_size(
    provider: str,
    model_size: str,
    user_model: Optional[str] = None,
    user_small_model: Optional[str] = None,
) -> str:
    """
    Get the appropriate model name based on the requested size and provider.

    This function replaces the _get_model_for_size methods in individual clients
    with a centralized implementation that uses configurable provider defaults.

    Args:
        provider: The provider name (e.g., 'openai', 'gemini', 'anthropic', etc.)
        model_size: The size of the model requested ('small' or 'medium'). An
            Enum member whose ``value`` is one of these strings is also accepted.
            Any size other than 'small' selects the regular model.
        user_model: User-configured model override; a falsy value (None or '')
            falls back to the provider default.
        user_small_model: User-configured small model override; a falsy value
            falls back to the provider's small-model default.

    Returns:
        The appropriate model name for the requested size

    Raises:
        ValueError: If the provider is not supported (from get_provider_defaults).
    """
    # Callers migrating from client-level size helpers may pass an Enum member
    # rather than a plain string. Without unwrapping, the comparison below would
    # be False and the regular model would be returned for a small request.
    # Plain strings have no ``value`` attribute and pass through unchanged.
    size = getattr(model_size, 'value', model_size)

    defaults = get_provider_defaults(provider)
    if size == 'small':
        return user_small_model or defaults.small_model
    return user_model or defaults.model
|
|
|
|
|
|
def get_extract_edges_max_tokens(provider: str) -> int:
    """
    Return the max_tokens limit for edge extraction with the given provider.

    The value is taken from the provider defaults, so it reflects the built-in
    per-provider setting plus any {PROVIDER}_EXTRACT_EDGES_MAX_TOKENS override.

    Args:
        provider: The provider name (e.g., 'openai', 'gemini', 'anthropic', etc.)

    Returns:
        The maximum tokens for edge extraction operations

    Raises:
        ValueError: Propagated from get_provider_defaults for an unsupported provider.
    """
    return get_provider_defaults(provider).extract_edges_max_tokens
|
|
|
|
|
|
def get_extract_edges_max_tokens_default() -> int:
    """
    Get the provider-independent maximum tokens for edge extraction operations.

    Reads the EXTRACT_EDGES_MAX_TOKENS environment variable. If it is unset,
    empty, or whitespace-only, 16384 is returned.

    Returns:
        The maximum tokens for edge extraction operations

    Raises:
        ValueError: If EXTRACT_EDGES_MAX_TOKENS is set but is not a valid
            integer. The message names the variable.
    """
    # Keep in sync with the ProviderDefaults.extract_edges_max_tokens default.
    fallback = 16384

    raw = os.getenv('EXTRACT_EDGES_MAX_TOKENS', '').strip()
    if not raw:
        # Blank counts as "not configured" rather than crashing in int('').
        return fallback
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f'Environment variable EXTRACT_EDGES_MAX_TOKENS must be an integer, got {raw!r}'
        ) from err