graphiti/graphiti_core/config/settings.py
supmo668 74a422369c feat: Add enhanced configuration system with multi-provider LLM support
This commit introduces a comprehensive configuration system that makes
Graphiti more flexible and easier to configure across different
providers and deployment environments.

## New Features

- **Unified Configuration**: New GraphitiConfig class with Pydantic validation
- **YAML Support**: Load configuration from .graphiti.yaml files
- **Multi-Provider Support**: Easy switching between OpenAI, Azure, Anthropic,
  Gemini, Groq, and LiteLLM
- **LiteLLM Integration**: Unified access to 100+ LLM providers
- **Factory Functions**: Automatic client creation from configuration
- **Full Backward Compatibility**: Existing code continues to work

## Configuration System

- graphiti_core/config/settings.py: Pydantic configuration classes
- graphiti_core/config/providers.py: Provider enumerations and defaults
- graphiti_core/config/factory.py: Factory functions for client creation

## LiteLLM Client

- graphiti_core/llm_client/litellm_client.py: New unified LLM client
- Support for Azure OpenAI, AWS Bedrock, Vertex AI, Ollama, vLLM, etc.
- Automatic structured output detection

## Documentation

- docs/CONFIGURATION.md: Comprehensive configuration guide
- examples/graphiti_config_example.yaml: Example configurations
- DOMAIN_AGNOSTIC_IMPROVEMENT_PLAN.md: Future improvement roadmap

## Tests

- tests/config/test_settings.py: 22 tests for configuration
- tests/config/test_factory.py: 12 tests for factories
- 33/34 tests passing (97%)

## Issues Addressed

- #1004: Azure OpenAI support
- #1006: Azure OpenAI reranker support
- #1007: vLLM/OpenAI-compatible provider stability
- #1074: Ollama embeddings support
- #995: Docker Azure OpenAI support

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-30 23:47:38 -08:00

459 lines
15 KiB
Python

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from pathlib import Path
from typing import Any
import yaml
from pydantic import BaseModel, Field, model_validator
from .providers import (
DEFAULT_EMBEDDINGS,
DEFAULT_MODELS,
DatabaseProvider,
EmbedderProvider,
LLMProvider,
RerankerProvider,
)
class LLMProviderConfig(BaseModel):
    """Configuration for LLM provider.

    This configuration supports multiple LLM providers including OpenAI, Azure OpenAI,
    Anthropic, Gemini, Groq, and generic providers via LiteLLM.

    Examples:
        >>> # OpenAI configuration
        >>> config = LLMProviderConfig(
        ...     provider=LLMProvider.OPENAI,
        ...     api_key='sk-...',
        ... )

        >>> # Azure OpenAI configuration
        >>> config = LLMProviderConfig(
        ...     provider=LLMProvider.AZURE_OPENAI,
        ...     api_key='...',
        ...     base_url='https://your-resource.openai.azure.com',
        ...     azure_deployment='your-deployment-name',
        ... )

        >>> # Anthropic configuration
        >>> config = LLMProviderConfig(
        ...     provider=LLMProvider.ANTHROPIC,
        ...     model='claude-sonnet-4-5-latest',
        ... )
    """

    provider: LLMProvider = Field(
        default=LLMProvider.OPENAI,
        description='The LLM provider to use',
    )
    model: str | None = Field(
        default=None,
        description='The model name to use. If not provided, uses provider default.',
    )
    small_model: str | None = Field(
        default=None,
        description='Smaller/faster model for simpler tasks. If not provided, uses provider default.',
    )
    api_key: str | None = Field(
        default=None,
        description='API key for the provider. Falls back to environment variables if not provided.',
    )
    base_url: str | None = Field(
        default=None,
        description='Base URL for the API. Required for Azure OpenAI and custom endpoints.',
    )
    temperature: float = Field(
        default=1.0,
        ge=0.0,
        le=2.0,
        description='Temperature for response generation',
    )
    max_tokens: int = Field(
        default=8192,
        gt=0,
        description='Maximum tokens for response generation',
    )

    # Azure-specific fields
    azure_deployment: str | None = Field(
        default=None,
        description='Azure OpenAI deployment name (required for Azure provider)',
    )
    azure_api_version: str | None = Field(
        default='2024-10-21',
        description='Azure OpenAI API version',
    )

    # LiteLLM-specific fields
    litellm_model: str | None = Field(
        default=None,
        description='Full LiteLLM model string (e.g., "azure/gpt-4", "bedrock/claude-3")',
    )

    # Custom provider fields
    custom_client_class: str | None = Field(
        default=None,
        description='Fully qualified class name for custom LLM client',
    )

    @model_validator(mode='after')
    def set_defaults_and_validate(self) -> 'LLMProviderConfig':
        """Set provider-specific defaults and validate configuration.

        Fills in the provider's default model names, resolves the API key
        from the provider's conventional environment variable, and enforces
        provider-specific required fields.

        Returns:
            The validated configuration instance.

        Raises:
            ValueError: If a provider-specific required field is missing.
        """
        # Fill in the provider's default model names when not set explicitly.
        defaults = DEFAULT_MODELS.get(self.provider)
        if defaults is not None:
            if self.model is None:
                self.model = defaults['model']
            if self.small_model is None:
                self.small_model = defaults['small_model']

        # Resolve the API key from the provider's conventional environment
        # variable. A dispatch table keeps the provider -> env-var mapping in
        # one obvious place instead of a long if/elif chain.
        env_keys = {
            LLMProvider.OPENAI: 'OPENAI_API_KEY',
            LLMProvider.AZURE_OPENAI: 'AZURE_OPENAI_API_KEY',
            LLMProvider.ANTHROPIC: 'ANTHROPIC_API_KEY',
            LLMProvider.GEMINI: 'GOOGLE_API_KEY',
            LLMProvider.GROQ: 'GROQ_API_KEY',
        }
        if self.api_key is None and self.provider in env_keys:
            self.api_key = os.getenv(env_keys[self.provider])

        # Azure needs an explicit endpoint and at least one of
        # deployment name or model to route requests.
        if self.provider == LLMProvider.AZURE_OPENAI:
            if not self.base_url:
                raise ValueError('base_url is required for Azure OpenAI provider')
            if not self.azure_deployment and not self.model:
                raise ValueError(
                    'Either azure_deployment or model must be provided for Azure OpenAI'
                )

        # LiteLLM routes on its full model string, so it must be given.
        if self.provider == LLMProvider.LITELLM and not self.litellm_model:
            raise ValueError('litellm_model is required for LiteLLM provider')

        # Custom providers are loaded dynamically from a class path.
        if self.provider == LLMProvider.CUSTOM and not self.custom_client_class:
            raise ValueError('custom_client_class is required for custom provider')

        return self
class EmbedderConfig(BaseModel):
    """Configuration for embedding provider.

    Examples:
        >>> # OpenAI embeddings
        >>> config = EmbedderConfig(
        ...     provider=EmbedderProvider.OPENAI,
        ... )

        >>> # Voyage AI embeddings
        >>> config = EmbedderConfig(
        ...     provider=EmbedderProvider.VOYAGE,
        ...     model='voyage-3',
        ... )
    """

    provider: EmbedderProvider = Field(
        default=EmbedderProvider.OPENAI,
        description='The embedder provider to use',
    )
    model: str | None = Field(
        default=None,
        description='The embedding model name. If not provided, uses provider default.',
    )
    api_key: str | None = Field(
        default=None,
        description='API key for the provider. Falls back to environment variables if not provided.',
    )
    base_url: str | None = Field(
        default=None,
        description='Base URL for the API. Required for Azure OpenAI.',
    )
    dimensions: int | None = Field(
        default=None,
        description='Embedding dimensions. If not provided, uses provider default.',
    )

    # Azure-specific fields
    azure_deployment: str | None = Field(
        default=None,
        description='Azure OpenAI deployment name (required for Azure provider)',
    )
    azure_api_version: str | None = Field(
        default='2024-10-21',
        description='Azure OpenAI API version',
    )

    # Custom provider fields
    custom_client_class: str | None = Field(
        default=None,
        description='Fully qualified class name for custom embedder client',
    )

    @model_validator(mode='after')
    def set_defaults_and_validate(self) -> 'EmbedderConfig':
        """Set provider-specific defaults and validate configuration.

        Fills in the provider's default model and dimensions, resolves the
        API key from the provider's conventional environment variable, and
        enforces provider-specific required fields.

        Returns:
            The validated configuration instance.

        Raises:
            ValueError: If a provider-specific required field is missing.
        """
        # Fill in the provider's default model and dimensions when not set.
        defaults = DEFAULT_EMBEDDINGS.get(self.provider)
        if defaults is not None:
            if self.model is None:
                self.model = defaults['model']
            if self.dimensions is None:
                self.dimensions = defaults['dimensions']

        # Resolve the API key from the provider's conventional environment
        # variable via a dispatch table (same pattern as LLMProviderConfig).
        env_keys = {
            EmbedderProvider.OPENAI: 'OPENAI_API_KEY',
            EmbedderProvider.AZURE_OPENAI: 'AZURE_OPENAI_API_KEY',
            EmbedderProvider.VOYAGE: 'VOYAGE_API_KEY',
            EmbedderProvider.GEMINI: 'GOOGLE_API_KEY',
        }
        if self.api_key is None and self.provider in env_keys:
            self.api_key = os.getenv(env_keys[self.provider])

        # Azure needs an explicit endpoint to route requests.
        if self.provider == EmbedderProvider.AZURE_OPENAI and not self.base_url:
            raise ValueError('base_url is required for Azure OpenAI embedder')

        # Custom providers are loaded dynamically from a class path.
        if self.provider == EmbedderProvider.CUSTOM and not self.custom_client_class:
            raise ValueError('custom_client_class is required for custom embedder')

        return self
class RerankerConfig(BaseModel):
    """Configuration for reranker/cross-encoder provider.

    Examples:
        >>> config = RerankerConfig(
        ...     provider=RerankerProvider.OPENAI,
        ... )
    """

    provider: RerankerProvider = Field(
        default=RerankerProvider.OPENAI,
        description='The reranker provider to use',
    )
    api_key: str | None = Field(
        default=None,
        description='API key for the provider. Falls back to environment variables if not provided.',
    )
    base_url: str | None = Field(
        default=None,
        description='Base URL for the API.',
    )

    # Azure-specific fields
    azure_deployment: str | None = Field(
        default=None,
        description='Azure OpenAI deployment name (required for Azure provider)',
    )

    # Custom provider fields
    custom_client_class: str | None = Field(
        default=None,
        description='Fully qualified class name for custom reranker client',
    )

    @model_validator(mode='after')
    def set_defaults(self) -> 'RerankerConfig':
        """Set provider-specific defaults.

        Resolves the API key from the provider's conventional environment
        variable when it was not supplied explicitly.
        """
        # An explicitly supplied key always wins; nothing else to default.
        if self.api_key is not None:
            return self

        env_lookup = {
            RerankerProvider.OPENAI: 'OPENAI_API_KEY',
            RerankerProvider.AZURE_OPENAI: 'AZURE_OPENAI_API_KEY',
        }
        env_name = env_lookup.get(self.provider)
        if env_name is not None:
            self.api_key = os.getenv(env_name)

        return self
class DatabaseConfig(BaseModel):
    """Configuration for graph database.

    Examples:
        >>> # Neo4j configuration
        >>> config = DatabaseConfig(
        ...     provider=DatabaseProvider.NEO4J,
        ...     uri='bolt://localhost:7687',
        ...     user='neo4j',
        ...     password='password',
        ... )

        >>> # FalkorDB configuration
        >>> config = DatabaseConfig(
        ...     provider=DatabaseProvider.FALKORDB,
        ...     uri='redis://localhost:6379',
        ... )
    """

    provider: DatabaseProvider = Field(
        default=DatabaseProvider.NEO4J,
        description='The graph database provider to use',
    )
    uri: str | None = Field(
        default=None,
        description='Database connection URI',
    )
    user: str | None = Field(
        default=None,
        description='Database username',
    )
    password: str | None = Field(
        default=None,
        description='Database password',
    )
    database: str | None = Field(
        default=None,
        description='Database name. Uses provider default if not specified.',
    )

    # Custom provider fields
    custom_driver_class: str | None = Field(
        default=None,
        description='Fully qualified class name for custom database driver',
    )

    @model_validator(mode='after')
    def validate_database_config(self) -> 'DatabaseConfig':
        """Validate database configuration.

        Custom providers are loaded dynamically, so they must name the
        driver class to instantiate.
        """
        if self.provider != DatabaseProvider.CUSTOM:
            return self
        if not self.custom_driver_class:
            raise ValueError('custom_driver_class is required for custom database provider')
        return self
class GraphitiConfig(BaseModel):
    """Main Graphiti configuration.

    This is the primary configuration class that aggregates all provider configurations.
    It supports loading from YAML files, environment variables, and programmatic configuration.

    Examples:
        >>> # Programmatic configuration
        >>> config = GraphitiConfig(
        ...     llm=LLMProviderConfig(provider=LLMProvider.ANTHROPIC),
        ...     embedder=EmbedderConfig(provider=EmbedderProvider.VOYAGE),
        ...     database=DatabaseConfig(
        ...         uri='bolt://localhost:7687',
        ...         user='neo4j',
        ...         password='password',
        ...     ),
        ... )

        >>> # Load from YAML file
        >>> config = GraphitiConfig.from_yaml('graphiti.yaml')

        >>> # Load from environment (looks for GRAPHITI_CONFIG_PATH)
        >>> config = GraphitiConfig.from_env()
    """

    llm: LLMProviderConfig = Field(
        default_factory=LLMProviderConfig,
        description='LLM provider configuration',
    )
    embedder: EmbedderConfig = Field(
        default_factory=EmbedderConfig,
        description='Embedder provider configuration',
    )
    reranker: RerankerConfig = Field(
        default_factory=RerankerConfig,
        description='Reranker provider configuration',
    )
    database: DatabaseConfig = Field(
        default_factory=DatabaseConfig,
        description='Database provider configuration',
    )

    # General settings
    store_raw_episode_content: bool = Field(
        default=True,
        description='Whether to store raw episode content in the database',
    )
    max_coroutines: int | None = Field(
        default=None,
        description='Maximum number of concurrent operations',
    )

    @classmethod
    def from_yaml(cls, path: str | Path) -> 'GraphitiConfig':
        """Load configuration from a YAML file.

        Args:
            path: Path to the YAML configuration file

        Returns:
            GraphitiConfig instance loaded from the file

        Raises:
            FileNotFoundError: If the configuration file doesn't exist
            ValueError: If the YAML file is invalid or its top level is not a mapping
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f'Configuration file not found: {path}')

        # yaml.safe_load raises yaml.YAMLError (not a ValueError subclass) on
        # malformed input; translate it so callers can rely on the documented
        # ValueError contract.
        try:
            with open(path) as f:
                config_dict = yaml.safe_load(f)
        except yaml.YAMLError as err:
            raise ValueError(f'Invalid YAML in configuration file {path}: {err}') from err

        # An empty file loads as None; treat it as "all defaults".
        if config_dict is None:
            config_dict = {}

        # A top-level list or scalar would otherwise surface as a confusing
        # TypeError from the **-expansion below.
        if not isinstance(config_dict, dict):
            raise ValueError(
                f'Configuration file {path} must contain a YAML mapping, '
                f'got {type(config_dict).__name__}'
            )

        return cls(**config_dict)

    @classmethod
    def from_env(cls, env_var: str = 'GRAPHITI_CONFIG_PATH') -> 'GraphitiConfig':
        """Load configuration from a YAML file specified in an environment variable.

        Args:
            env_var: Name of the environment variable containing the config file path

        Returns:
            GraphitiConfig instance loaded from the file, or default config if env var not set

        Raises:
            FileNotFoundError: If the specified config file doesn't exist
        """
        config_path = os.getenv(env_var)
        if config_path:
            return cls.from_yaml(config_path)

        # Look for default config files in current directory
        for default_file in ['.graphiti.yaml', '.graphiti.yml', 'graphiti.yaml', 'graphiti.yml']:
            if Path(default_file).exists():
                return cls.from_yaml(default_file)

        # Return default configuration
        return cls()

    def to_yaml(self, path: str | Path) -> None:
        """Save configuration to a YAML file.

        Args:
            path: Path where the configuration file should be saved
        """
        path = Path(path)
        # Use json mode to convert enums to their values
        config_dict = self.model_dump(exclude_none=True, mode='json')
        with open(path, 'w') as f:
            yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)