graphiti/graphiti_core/llm_client/provider_defaults.py
claude[bot] 93ab7375cd
feat: eliminate ghost variables with configurable provider defaults system
Replace hardcoded DEFAULT_MODEL and DEFAULT_SMALL_MODEL constants across all LLM clients with a centralized, configurable provider defaults system.

Key changes:
- Created provider_defaults.py with centralized configuration for all providers
- Added environment variable support for easy customization (e.g., GEMINI_DEFAULT_MODEL)
- Updated all LLM clients to use configurable defaults instead of hardcoded constants
- Made edge operations max_tokens configurable via EXTRACT_EDGES_MAX_TOKENS
- Updated cross-encoder reranker clients to use provider defaults
- Maintained full backward compatibility with existing configurations

This resolves the issue where Gemini's flash-lite model has location constraints in Vertex AI that differ from the regular flash model, and users couldn't easily override these without editing source code.

Environment variables now supported:
- {PROVIDER}_DEFAULT_MODEL
- {PROVIDER}_DEFAULT_SMALL_MODEL  
- {PROVIDER}_DEFAULT_MAX_TOKENS
- {PROVIDER}_DEFAULT_TEMPERATURE
- {PROVIDER}_EXTRACT_EDGES_MAX_TOKENS
- EXTRACT_EDGES_MAX_TOKENS (global fallback)

Fixes #681

Co-authored-by: Daniel Chalef <danielchalef@users.noreply.github.com>
2025-07-09 14:59:17 +00:00

162 lines
No EOL
5.5 KiB
Python

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from dataclasses import dataclass
from typing import Dict, Optional
@dataclass
class ProviderDefaults:
    """
    Configuration for provider-specific model defaults.

    This class replaces hardcoded DEFAULT_MODEL constants with configurable
    provider-specific defaults that can be overridden via environment
    variables (the override mechanism lives in get_provider_defaults, which
    returns a fresh instance with any overrides applied).
    """

    # Primary (medium-size) model identifier for the provider.
    model: str
    # Smaller/cheaper model identifier used for lightweight requests.
    small_model: str
    # Default maximum completion tokens for general requests.
    max_tokens: int = 8192
    # Default sampling temperature; 0.0 keeps outputs deterministic.
    temperature: float = 0.0
    # Maximum completion tokens for edge-extraction operations, which tend to
    # produce larger outputs than typical requests.
    extract_edges_max_tokens: int = 16384
# Built-in default configurations, keyed by provider name.
# Any of these values may be overridden at lookup time via environment
# variables (see get_provider_defaults); edit here only to change the
# shipped defaults. extract_edges_max_tokens relies on the dataclass
# default of 16384 for every provider.
PROVIDER_DEFAULTS: Dict[str, ProviderDefaults] = {
    'openai': ProviderDefaults(
        model='gpt-4.1-mini',
        small_model='gpt-4.1-nano',
    ),
    'gemini': ProviderDefaults(
        model='gemini-2.5-flash',
        small_model='models/gemini-2.5-flash-lite-preview-06-17',
    ),
    'anthropic': ProviderDefaults(
        model='claude-3-7-sonnet-latest',
        small_model='claude-3-7-haiku-latest',
    ),
    'groq': ProviderDefaults(
        model='llama-3.1-70b-versatile',
        small_model='llama-3.1-8b-instant',
    ),
    'azure_openai': ProviderDefaults(
        model='gpt-4.1-mini',
        small_model='gpt-4.1-nano',
    ),
}
def get_provider_defaults(provider: str) -> ProviderDefaults:
    """
    Get provider-specific defaults with optional environment variable overrides.

    Environment variables can override defaults using the pattern:
      - {PROVIDER}_DEFAULT_MODEL
      - {PROVIDER}_DEFAULT_SMALL_MODEL
      - {PROVIDER}_DEFAULT_MAX_TOKENS
      - {PROVIDER}_DEFAULT_TEMPERATURE
      - {PROVIDER}_EXTRACT_EDGES_MAX_TOKENS

    A variable that is set but empty (e.g. ``export OPENAI_DEFAULT_MAX_TOKENS=``)
    is treated as unset, so the int()/float() conversions never receive ''.

    Args:
        provider: The provider name (e.g., 'openai', 'gemini', 'anthropic', etc.)

    Returns:
        A new ProviderDefaults with defaults for the specified provider and
        any environment overrides applied.

    Raises:
        ValueError: If the provider is not supported.
    """
    if provider not in PROVIDER_DEFAULTS:
        raise ValueError(
            f'Unsupported provider: {provider}. Supported providers: {list(PROVIDER_DEFAULTS.keys())}'
        )

    defaults = PROVIDER_DEFAULTS[provider]
    env_prefix = provider.upper()

    def _env(suffix: str) -> Optional[str]:
        # Return None for both unset and empty variables so `or` falls back
        # to the built-in default instead of crashing on int('') / float('').
        value = os.getenv(f'{env_prefix}_{suffix}')
        return value if value else None

    return ProviderDefaults(
        model=_env('DEFAULT_MODEL') or defaults.model,
        small_model=_env('DEFAULT_SMALL_MODEL') or defaults.small_model,
        max_tokens=int(_env('DEFAULT_MAX_TOKENS') or defaults.max_tokens),
        temperature=float(_env('DEFAULT_TEMPERATURE') or defaults.temperature),
        extract_edges_max_tokens=int(
            _env('EXTRACT_EDGES_MAX_TOKENS') or defaults.extract_edges_max_tokens
        ),
    )
def get_model_for_size(provider: str, model_size: str, user_model: Optional[str] = None, user_small_model: Optional[str] = None) -> str:
    """
    Get the appropriate model name based on the requested size and provider.

    This centralizes the _get_model_for_size logic that previously lived in
    each individual client, resolving against configurable provider defaults.

    Args:
        provider: The provider name (e.g., 'openai', 'gemini', 'anthropic', etc.)
        model_size: The size of the model requested ('small' or 'medium')
        user_model: User-configured model override
        user_small_model: User-configured small model override

    Returns:
        The appropriate model name for the requested size
    """
    defaults = get_provider_defaults(provider)
    # Pick the (user override, provider fallback) pair for the requested
    # size; anything other than 'small' resolves to the primary model.
    override, fallback = (
        (user_small_model, defaults.small_model)
        if model_size == 'small'
        else (user_model, defaults.model)
    )
    return override or fallback
def get_extract_edges_max_tokens(provider: str) -> int:
    """
    Get the maximum tokens for edge extraction operations.

    This replaces hardcoded extract_edges_max_tokens constants with the
    configurable provider-specific value (including any environment
    overrides applied by get_provider_defaults).

    Args:
        provider: The provider name (e.g., 'openai', 'gemini', 'anthropic', etc.)

    Returns:
        The maximum tokens for edge extraction operations

    Raises:
        ValueError: If the provider is not supported (propagated from
            get_provider_defaults).
    """
    return get_provider_defaults(provider).extract_edges_max_tokens
def get_extract_edges_max_tokens_default() -> int:
    """
    Get the default maximum tokens for edge extraction operations.

    The value can be overridden via the EXTRACT_EDGES_MAX_TOKENS environment
    variable. A variable that is set but empty is treated as unset (falling
    back to the built-in 16384) rather than crashing the int() conversion.

    Returns:
        The maximum tokens for edge extraction operations.
    """
    # `or` (not a getenv default) so an empty-string value also falls back.
    return int(os.getenv('EXTRACT_EDGES_MAX_TOKENS') or '16384')