182 lines
7.2 KiB
Python
182 lines
7.2 KiB
Python
"""LLM configuration for Graphiti MCP Server."""
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
from typing import TYPE_CHECKING
|
|
|
|
from openai import AsyncAzureOpenAI
|
|
from pydantic import BaseModel
|
|
from utils import create_azure_credential_token_provider
|
|
|
|
from graphiti_core.llm_client import LLMClient
|
|
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
|
|
from graphiti_core.llm_client.config import LLMConfig
|
|
from graphiti_core.llm_client.openai_client import OpenAIClient
|
|
|
|
if TYPE_CHECKING:
|
|
pass
|
|
|
|
# Module-level logger for LLM configuration messages.
logger = logging.getLogger(__name__)

# Fallback model names used when MODEL_NAME / SMALL_MODEL_NAME are unset or empty.
DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
SMALL_LLM_MODEL = 'gpt-4.1-nano'
|
|
|
|
|
|
class GraphitiLLMConfig(BaseModel):
    """Configuration for the LLM client.

    Centralizes all LLM-specific configuration parameters including API keys and
    model selection. Supports two backends: the plain OpenAI API (default) and
    Azure OpenAI (selected by the presence of AZURE_OPENAI_ENDPOINT).
    """

    # API key for OpenAI / Azure OpenAI; None when using Azure managed identity.
    api_key: str | None = None
    # Primary model used for extraction and reasoning.
    model: str = DEFAULT_LLM_MODEL
    # Lighter-weight model for cheaper auxiliary calls.
    small_model: str = SMALL_LLM_MODEL
    temperature: float = 0.0
    # Azure-specific settings; all None/False when targeting plain OpenAI.
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @staticmethod
    def _temperature_from_env() -> float:
        """Parse LLM_TEMPERATURE, falling back to 0.0 on a malformed value.

        Previously a non-numeric value crashed config loading with an unhandled
        ValueError; a server should start with the default instead.
        """
        raw = os.environ.get('LLM_TEMPERATURE', '0.0')
        try:
            return float(raw)
        except ValueError:
            logger.warning(f'Invalid LLM_TEMPERATURE value {raw!r}, using default: 0.0')
            return 0.0

    @classmethod
    def from_env(cls) -> 'GraphitiLLMConfig':
        """Create LLM configuration from environment variables.

        Reads MODEL_NAME, SMALL_MODEL_NAME, LLM_TEMPERATURE, OPENAI_API_KEY and the
        AZURE_OPENAI_* variables. When AZURE_OPENAI_ENDPOINT is set, the Azure
        OpenAI backend is configured; otherwise the plain OpenAI API is used.

        Raises:
            ValueError: if AZURE_OPENAI_ENDPOINT is set but
                AZURE_OPENAI_DEPLOYMENT_NAME is missing.
        """
        # Distinguish "variable absent" (None) from "set but empty" ('') so the
        # debug vs. warning log messages are accurate. The old code defaulted to
        # '' and could log "not set" even when MODEL_NAME was explicitly empty.
        model_env = os.environ.get('MODEL_NAME')
        if model_env is None:
            model = DEFAULT_LLM_MODEL
            logger.debug(
                f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}'
            )
        elif not model_env.strip():
            model = DEFAULT_LLM_MODEL
            logger.warning(
                f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}'
            )
        else:
            model = model_env

        # Get small_model from environment, or use default if not set or empty.
        small_model_env = os.environ.get('SMALL_MODEL_NAME', '')
        small_model = small_model_env if small_model_env.strip() else SMALL_LLM_MODEL

        temperature = cls._temperature_from_env()

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_API_VERSION')
        azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME')
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )

        if azure_openai_endpoint is None:
            # Setup for OpenAI API.
            return cls(
                api_key=os.environ.get('OPENAI_API_KEY'),
                model=model,
                small_model=small_model,
                temperature=temperature,
            )

        # Setup for Azure OpenAI API; a deployment name is mandatory.
        if azure_openai_deployment_name is None:
            logger.error('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
            raise ValueError('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')

        # With managed identity no API key is needed; otherwise reuse OPENAI_API_KEY.
        api_key = (
            None if azure_openai_use_managed_identity else os.environ.get('OPENAI_API_KEY')
        )

        return cls(
            azure_openai_use_managed_identity=azure_openai_use_managed_identity,
            azure_openai_endpoint=azure_openai_endpoint,
            api_key=api_key,
            azure_openai_api_version=azure_openai_api_version,
            azure_openai_deployment_name=azure_openai_deployment_name,
            model=model,
            small_model=small_model,
            temperature=temperature,
        )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig':
        """Create LLM configuration from CLI arguments, falling back to environment variables.

        Args:
            args: parsed CLI namespace; recognized attributes are ``model``,
                ``small_model`` and ``temperature`` (all optional).

        Returns:
            GraphitiLLMConfig with CLI values overriding environment values.
        """
        # Start with environment-based config.
        config = cls.from_env()

        # CLI arguments override environment variables only when non-empty.
        if hasattr(args, 'model') and args.model:
            if args.model.strip():
                config.model = args.model
            else:
                # Log that empty model was provided and default is used.
                logger.warning(f'Empty model name provided, using default: {DEFAULT_LLM_MODEL}')

        if hasattr(args, 'small_model') and args.small_model:
            if args.small_model.strip():
                config.small_model = args.small_model
            else:
                logger.warning(f'Empty small_model name provided, using default: {SMALL_LLM_MODEL}')

        if hasattr(args, 'temperature') and args.temperature is not None:
            config.temperature = args.temperature

        return config

    def create_client(self) -> LLMClient:
        """Create an LLM client based on this configuration.

        Returns:
            LLMClient instance

        Raises:
            ValueError: if an API key is required for the selected backend but missing.
        """
        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup. Both auth modes share one LLMConfig,
            # so build it once instead of duplicating it per branch.
            azure_llm_config = LLMConfig(
                api_key=self.api_key,
                model=self.model,
                small_model=self.small_model,
                temperature=self.temperature,
            )

            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication.
                token_provider = create_azure_credential_token_provider()
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        azure_ad_token_provider=token_provider,
                    ),
                    config=azure_llm_config,
                )

            if self.api_key:
                # Use API key for authentication.
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        api_key=self.api_key,
                    ),
                    config=azure_llm_config,
                )

            raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')

        if not self.api_key:
            raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')

        # Plain OpenAI API setup; temperature is passed to the constructor
        # rather than assigned after construction.
        return OpenAIClient(
            config=LLMConfig(
                api_key=self.api_key,
                model=self.model,
                small_model=self.small_model,
                temperature=self.temperature,
            )
        )
|