This commit introduces a comprehensive configuration system that makes Graphiti more flexible and easier to configure across different providers and deployment environments.

## New Features

- **Unified Configuration**: New GraphitiConfig class with Pydantic validation
- **YAML Support**: Load configuration from .graphiti.yaml files
- **Multi-Provider Support**: Easy switching between OpenAI, Azure, Anthropic, Gemini, Groq, and LiteLLM
- **LiteLLM Integration**: Unified access to 100+ LLM providers
- **Factory Functions**: Automatic client creation from configuration
- **Full Backward Compatibility**: Existing code continues to work

## Configuration System

- graphiti_core/config/settings.py: Pydantic configuration classes
- graphiti_core/config/providers.py: Provider enumerations and defaults
- graphiti_core/config/factory.py: Factory functions for client creation

## LiteLLM Client

- graphiti_core/llm_client/litellm_client.py: New unified LLM client
- Support for Azure OpenAI, AWS Bedrock, Vertex AI, Ollama, vLLM, etc.
- Automatic structured output detection

## Documentation

- docs/CONFIGURATION.md: Comprehensive configuration guide
- examples/graphiti_config_example.yaml: Example configurations
- DOMAIN_AGNOSTIC_IMPROVEMENT_PLAN.md: Future improvement roadmap

## Tests

- tests/config/test_settings.py: 22 tests for configuration
- tests/config/test_factory.py: 12 tests for factories
- 33/34 tests passing (97%)

## Issues Addressed

- #1004: Azure OpenAI support
- #1006: Azure OpenAI reranker support
- #1007: vLLM/OpenAI-compatible provider stability
- #1074: Ollama embeddings support
- #995: Docker Azure OpenAI support

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
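For orientation, here is a minimal usage sketch of the new API, based on the `GraphitiConfig` model in `graphiti_core/config/settings.py` (reproduced in full below). The import paths assume the module layout listed above; the package may re-export these names differently.

```python
# Minimal sketch of the new configuration system (not taken from the diff).
# Assumes the modules listed in this commit message; re-export paths may differ.
from graphiti_core.config.providers import EmbedderProvider, LLMProvider
from graphiti_core.config.settings import (
    DatabaseConfig,
    EmbedderConfig,
    GraphitiConfig,
    LLMProviderConfig,
)

# Programmatic configuration: pick providers explicitly; unset fields fall back
# to provider defaults and environment variables via the model validators.
config = GraphitiConfig(
    llm=LLMProviderConfig(provider=LLMProvider.ANTHROPIC),
    embedder=EmbedderConfig(provider=EmbedderProvider.OPENAI),
    database=DatabaseConfig(
        uri='bolt://localhost:7687',
        user='neo4j',
        password='password',
    ),
)

# Or resolve configuration from the environment: GRAPHITI_CONFIG_PATH first,
# then .graphiti.yaml / graphiti.yaml in the working directory, then defaults.
config = GraphitiConfig.from_env()

# Persist the resolved configuration back to YAML.
config.to_yaml('graphiti.yaml')
```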
graphiti_core/config/settings.py · 459 lines · 15 KiB · Python
"""
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import yaml
|
|
from pydantic import BaseModel, Field, model_validator
|
|
|
|
from .providers import (
|
|
DEFAULT_EMBEDDINGS,
|
|
DEFAULT_MODELS,
|
|
DatabaseProvider,
|
|
EmbedderProvider,
|
|
LLMProvider,
|
|
RerankerProvider,
|
|
)
|
|
|
|
|
|
class LLMProviderConfig(BaseModel):
    """Configuration for LLM provider.

    This configuration supports multiple LLM providers including OpenAI, Azure OpenAI,
    Anthropic, Gemini, Groq, and generic providers via LiteLLM.

    Examples:
        >>> # OpenAI configuration
        >>> config = LLMProviderConfig(
        ...     provider=LLMProvider.OPENAI,
        ...     api_key='sk-...',
        ... )

        >>> # Azure OpenAI configuration
        >>> config = LLMProviderConfig(
        ...     provider=LLMProvider.AZURE_OPENAI,
        ...     api_key='...',
        ...     base_url='https://your-resource.openai.azure.com',
        ...     azure_deployment='your-deployment-name',
        ... )

        >>> # Anthropic configuration
        >>> config = LLMProviderConfig(
        ...     provider=LLMProvider.ANTHROPIC,
        ...     model='claude-sonnet-4-5-latest',
        ... )
    """

    provider: LLMProvider = Field(
        default=LLMProvider.OPENAI,
        description='The LLM provider to use',
    )
    model: str | None = Field(
        default=None,
        description='The model name to use. If not provided, uses provider default.',
    )
    small_model: str | None = Field(
        default=None,
        description='Smaller/faster model for simpler tasks. If not provided, uses provider default.',
    )
    api_key: str | None = Field(
        default=None,
        description='API key for the provider. Falls back to environment variables if not provided.',
    )
    base_url: str | None = Field(
        default=None,
        description='Base URL for the API. Required for Azure OpenAI and custom endpoints.',
    )
    temperature: float = Field(
        default=1.0,
        ge=0.0,
        le=2.0,
        description='Temperature for response generation',
    )
    max_tokens: int = Field(
        default=8192,
        gt=0,
        description='Maximum tokens for response generation',
    )

    # Azure-specific fields
    azure_deployment: str | None = Field(
        default=None,
        description='Azure OpenAI deployment name (required for Azure provider)',
    )
    azure_api_version: str | None = Field(
        default='2024-10-21',
        description='Azure OpenAI API version',
    )

    # LiteLLM-specific fields
    litellm_model: str | None = Field(
        default=None,
        description='Full LiteLLM model string (e.g., "azure/gpt-4", "bedrock/claude-3")',
    )

    # Custom provider fields
    custom_client_class: str | None = Field(
        default=None,
        description='Fully qualified class name for custom LLM client',
    )

    @model_validator(mode='after')
    def set_defaults_and_validate(self) -> 'LLMProviderConfig':
        """Set provider-specific defaults and validate configuration."""
        # Set default models if not provided
        if self.model is None and self.provider in DEFAULT_MODELS:
            self.model = DEFAULT_MODELS[self.provider]['model']

        if self.small_model is None and self.provider in DEFAULT_MODELS:
            self.small_model = DEFAULT_MODELS[self.provider]['small_model']

        # Set API key from environment if not provided
        if self.api_key is None:
            if self.provider == LLMProvider.OPENAI:
                self.api_key = os.getenv('OPENAI_API_KEY')
            elif self.provider == LLMProvider.AZURE_OPENAI:
                self.api_key = os.getenv('AZURE_OPENAI_API_KEY')
            elif self.provider == LLMProvider.ANTHROPIC:
                self.api_key = os.getenv('ANTHROPIC_API_KEY')
            elif self.provider == LLMProvider.GEMINI:
                self.api_key = os.getenv('GOOGLE_API_KEY')
            elif self.provider == LLMProvider.GROQ:
                self.api_key = os.getenv('GROQ_API_KEY')

        # Validate Azure-specific requirements
        if self.provider == LLMProvider.AZURE_OPENAI:
            if not self.base_url:
                raise ValueError('base_url is required for Azure OpenAI provider')
            if not self.azure_deployment and not self.model:
                raise ValueError(
                    'Either azure_deployment or model must be provided for Azure OpenAI'
                )

        # Validate LiteLLM requirements
        if self.provider == LLMProvider.LITELLM and not self.litellm_model:
            raise ValueError('litellm_model is required for LiteLLM provider')

        # Validate custom provider requirements
        if self.provider == LLMProvider.CUSTOM and not self.custom_client_class:
            raise ValueError('custom_client_class is required for custom provider')

        return self


class EmbedderConfig(BaseModel):
    """Configuration for embedding provider.

    Examples:
        >>> # OpenAI embeddings
        >>> config = EmbedderConfig(
        ...     provider=EmbedderProvider.OPENAI,
        ... )

        >>> # Voyage AI embeddings
        >>> config = EmbedderConfig(
        ...     provider=EmbedderProvider.VOYAGE,
        ...     model='voyage-3',
        ... )
    """

    provider: EmbedderProvider = Field(
        default=EmbedderProvider.OPENAI,
        description='The embedder provider to use',
    )
    model: str | None = Field(
        default=None,
        description='The embedding model name. If not provided, uses provider default.',
    )
    api_key: str | None = Field(
        default=None,
        description='API key for the provider. Falls back to environment variables if not provided.',
    )
    base_url: str | None = Field(
        default=None,
        description='Base URL for the API. Required for Azure OpenAI.',
    )
    dimensions: int | None = Field(
        default=None,
        description='Embedding dimensions. If not provided, uses provider default.',
    )

    # Azure-specific fields
    azure_deployment: str | None = Field(
        default=None,
        description='Azure OpenAI deployment name (required for Azure provider)',
    )
    azure_api_version: str | None = Field(
        default='2024-10-21',
        description='Azure OpenAI API version',
    )

    # Custom provider fields
    custom_client_class: str | None = Field(
        default=None,
        description='Fully qualified class name for custom embedder client',
    )

    @model_validator(mode='after')
    def set_defaults_and_validate(self) -> 'EmbedderConfig':
        """Set provider-specific defaults and validate configuration."""
        # Set default model and dimensions if not provided
        if self.provider in DEFAULT_EMBEDDINGS:
            if self.model is None:
                self.model = DEFAULT_EMBEDDINGS[self.provider]['model']
            if self.dimensions is None:
                self.dimensions = DEFAULT_EMBEDDINGS[self.provider]['dimensions']

        # Set API key from environment if not provided
        if self.api_key is None:
            if self.provider == EmbedderProvider.OPENAI:
                self.api_key = os.getenv('OPENAI_API_KEY')
            elif self.provider == EmbedderProvider.AZURE_OPENAI:
                self.api_key = os.getenv('AZURE_OPENAI_API_KEY')
            elif self.provider == EmbedderProvider.VOYAGE:
                self.api_key = os.getenv('VOYAGE_API_KEY')
            elif self.provider == EmbedderProvider.GEMINI:
                self.api_key = os.getenv('GOOGLE_API_KEY')

        # Validate Azure-specific requirements
        if self.provider == EmbedderProvider.AZURE_OPENAI and not self.base_url:
            raise ValueError('base_url is required for Azure OpenAI embedder')

        # Validate custom provider requirements
        if self.provider == EmbedderProvider.CUSTOM and not self.custom_client_class:
            raise ValueError('custom_client_class is required for custom embedder')

        return self


class RerankerConfig(BaseModel):
    """Configuration for reranker/cross-encoder provider.

    Examples:
        >>> config = RerankerConfig(
        ...     provider=RerankerProvider.OPENAI,
        ... )
    """

    provider: RerankerProvider = Field(
        default=RerankerProvider.OPENAI,
        description='The reranker provider to use',
    )
    api_key: str | None = Field(
        default=None,
        description='API key for the provider. Falls back to environment variables if not provided.',
    )
    base_url: str | None = Field(
        default=None,
        description='Base URL for the API.',
    )

    # Azure-specific fields
    azure_deployment: str | None = Field(
        default=None,
        description='Azure OpenAI deployment name (required for Azure provider)',
    )

    # Custom provider fields
    custom_client_class: str | None = Field(
        default=None,
        description='Fully qualified class name for custom reranker client',
    )

    @model_validator(mode='after')
    def set_defaults(self) -> 'RerankerConfig':
        """Set provider-specific defaults."""
        # Set API key from environment if not provided
        if self.api_key is None:
            if self.provider == RerankerProvider.OPENAI:
                self.api_key = os.getenv('OPENAI_API_KEY')
            elif self.provider == RerankerProvider.AZURE_OPENAI:
                self.api_key = os.getenv('AZURE_OPENAI_API_KEY')

        return self


class DatabaseConfig(BaseModel):
    """Configuration for graph database.

    Examples:
        >>> # Neo4j configuration
        >>> config = DatabaseConfig(
        ...     provider=DatabaseProvider.NEO4J,
        ...     uri='bolt://localhost:7687',
        ...     user='neo4j',
        ...     password='password',
        ... )

        >>> # FalkorDB configuration
        >>> config = DatabaseConfig(
        ...     provider=DatabaseProvider.FALKORDB,
        ...     uri='redis://localhost:6379',
        ... )
    """

    provider: DatabaseProvider = Field(
        default=DatabaseProvider.NEO4J,
        description='The graph database provider to use',
    )
    uri: str | None = Field(
        default=None,
        description='Database connection URI',
    )
    user: str | None = Field(
        default=None,
        description='Database username',
    )
    password: str | None = Field(
        default=None,
        description='Database password',
    )
    database: str | None = Field(
        default=None,
        description='Database name. Uses provider default if not specified.',
    )

    # Custom provider fields
    custom_driver_class: str | None = Field(
        default=None,
        description='Fully qualified class name for custom database driver',
    )

    @model_validator(mode='after')
    def validate_database_config(self) -> 'DatabaseConfig':
        """Validate database configuration."""
        if self.provider == DatabaseProvider.CUSTOM and not self.custom_driver_class:
            raise ValueError('custom_driver_class is required for custom database provider')
        return self


class GraphitiConfig(BaseModel):
    """Main Graphiti configuration.

    This is the primary configuration class that aggregates all provider configurations.
    It supports loading from YAML files, environment variables, and programmatic configuration.

    Examples:
        >>> # Programmatic configuration
        >>> config = GraphitiConfig(
        ...     llm=LLMProviderConfig(provider=LLMProvider.ANTHROPIC),
        ...     embedder=EmbedderConfig(provider=EmbedderProvider.VOYAGE),
        ...     database=DatabaseConfig(
        ...         uri='bolt://localhost:7687',
        ...         user='neo4j',
        ...         password='password',
        ...     ),
        ... )

        >>> # Load from YAML file
        >>> config = GraphitiConfig.from_yaml('graphiti.yaml')

        >>> # Load from environment (looks for GRAPHITI_CONFIG_PATH)
        >>> config = GraphitiConfig.from_env()
    """

    llm: LLMProviderConfig = Field(
        default_factory=LLMProviderConfig,
        description='LLM provider configuration',
    )
    embedder: EmbedderConfig = Field(
        default_factory=EmbedderConfig,
        description='Embedder provider configuration',
    )
    reranker: RerankerConfig = Field(
        default_factory=RerankerConfig,
        description='Reranker provider configuration',
    )
    database: DatabaseConfig = Field(
        default_factory=DatabaseConfig,
        description='Database provider configuration',
    )

    # General settings
    store_raw_episode_content: bool = Field(
        default=True,
        description='Whether to store raw episode content in the database',
    )
    max_coroutines: int | None = Field(
        default=None,
        description='Maximum number of concurrent operations',
    )

    @classmethod
    def from_yaml(cls, path: str | Path) -> 'GraphitiConfig':
        """Load configuration from a YAML file.

        Args:
            path: Path to the YAML configuration file

        Returns:
            GraphitiConfig instance loaded from the file

        Raises:
            FileNotFoundError: If the configuration file doesn't exist
            ValueError: If the YAML file is invalid
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f'Configuration file not found: {path}')

        with open(path) as f:
            config_dict = yaml.safe_load(f)

        if config_dict is None:
            config_dict = {}

        return cls(**config_dict)

    @classmethod
    def from_env(cls, env_var: str = 'GRAPHITI_CONFIG_PATH') -> 'GraphitiConfig':
        """Load configuration from a YAML file specified in an environment variable.

        Args:
            env_var: Name of the environment variable containing the config file path

        Returns:
            GraphitiConfig instance loaded from the file, or default config if env var not set

        Raises:
            FileNotFoundError: If the specified config file doesn't exist
        """
        config_path = os.getenv(env_var)
        if config_path:
            return cls.from_yaml(config_path)

        # Look for default config files in current directory
        for default_file in ['.graphiti.yaml', '.graphiti.yml', 'graphiti.yaml', 'graphiti.yml']:
            if Path(default_file).exists():
                return cls.from_yaml(default_file)

        # Return default configuration
        return cls()

    def to_yaml(self, path: str | Path) -> None:
        """Save configuration to a YAML file.

        Args:
            path: Path where the configuration file should be saved
        """
        path = Path(path)
        # Use json mode to convert enums to their values
        config_dict = self.model_dump(exclude_none=True, mode='json')

        with open(path, 'w') as f:
            yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
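A quick, hedged round-trip sketch (not part of settings.py): because `to_yaml()` dumps with `mode='json'`, provider enums are written as plain strings, so a file produced by `to_yaml()` can be read straight back by `from_yaml()`. The nested YAML keys mirror the model fields (`llm`, `embedder`, `reranker`, `database`); the import path below is an assumption about the package layout.

```python
# Round-trip a configuration through YAML using the methods defined above.
from graphiti_core.config.settings import GraphitiConfig

config = GraphitiConfig()          # provider defaults plus env-var API keys
config.to_yaml('.graphiti.yaml')   # enums serialized as strings, None fields dropped

reloaded = GraphitiConfig.from_yaml('.graphiti.yaml')
assert reloaded.llm.provider == config.llm.provider
```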