fix MCP config handling and docker compose setup (#332)
fix config handling and docker compose setup
This commit is contained in:
parent
807a402ba4
commit
6b12896723
2 changed files with 191 additions and 95 deletions
|
|
@ -34,7 +34,6 @@ services:
|
|||
- NEO4J_USER=neo4j
|
||||
- NEO4J_PASSWORD=demodemo
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- OPENAI_BASE_URL=${OPENAI_BASE_URL}
|
||||
- MODEL_NAME=${MODEL_NAME}
|
||||
- PATH=/root/.local/bin:${PATH}
|
||||
ports:
|
||||
|
|
|
|||
|
|
@ -154,39 +154,152 @@ class StatusResponse(TypedDict):
|
|||
|
||||
|
||||
# Server configuration classes
|
||||
# The configuration system has a hierarchy:
|
||||
# - GraphitiConfig is the top-level configuration
|
||||
# - LLMConfig handles all OpenAI/LLM related settings
|
||||
# - Neo4jConfig manages database connection details
|
||||
# - Various other settings like group_id and feature flags
|
||||
# Configuration values are loaded from:
|
||||
# 1. Default values in the class definitions
|
||||
# 2. Environment variables (loaded via load_dotenv())
|
||||
# 3. Command line arguments (which override environment variables)
|
||||
class GraphitiLLMConfig(BaseModel):
|
||||
"""Configuration for the LLM client.
|
||||
|
||||
Centralizes all LLM-specific configuration parameters including API keys and model selection.
|
||||
"""
|
||||
|
||||
api_key: Optional[str] = None
|
||||
model: str = DEFAULT_LLM_MODEL
|
||||
temperature: float = 0.0
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> 'GraphitiLLMConfig':
|
||||
"""Create LLM configuration from environment variables."""
|
||||
# Get model from environment, or use default if not set or empty
|
||||
model_env = os.environ.get('MODEL_NAME', '')
|
||||
model = model_env if model_env.strip() else DEFAULT_LLM_MODEL
|
||||
|
||||
# Log if empty model was provided
|
||||
if model_env == '':
|
||||
logger.debug(
|
||||
f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}'
|
||||
)
|
||||
elif not model_env.strip():
|
||||
logger.warning(
|
||||
f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}'
|
||||
)
|
||||
|
||||
return cls(
|
||||
api_key=os.environ.get('OPENAI_API_KEY'),
|
||||
model=model,
|
||||
temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig':
|
||||
"""Create LLM configuration from CLI arguments, falling back to environment variables."""
|
||||
# Start with environment-based config
|
||||
config = cls.from_env()
|
||||
|
||||
# CLI arguments override environment variables when provided
|
||||
if hasattr(args, 'model') and args.model:
|
||||
# Only use CLI model if it's not empty
|
||||
if args.model.strip():
|
||||
config.model = args.model
|
||||
else:
|
||||
# Log that empty model was provided and default is used
|
||||
logger.warning(f'Empty model name provided, using default: {DEFAULT_LLM_MODEL}')
|
||||
|
||||
if hasattr(args, 'temperature') and args.temperature is not None:
|
||||
config.temperature = args.temperature
|
||||
|
||||
return config
|
||||
|
||||
def create_client(self) -> Optional[LLMClient]:
|
||||
"""Create an LLM client based on this configuration.
|
||||
|
||||
Returns:
|
||||
LLMClient instance if API key is available, None otherwise
|
||||
"""
|
||||
if not self.api_key:
|
||||
return None
|
||||
|
||||
llm_client_config = LLMConfig(api_key=self.api_key, model=self.model)
|
||||
|
||||
# Set temperature
|
||||
llm_client_config.temperature = self.temperature
|
||||
|
||||
return OpenAIClient(config=llm_client_config)
|
||||
|
||||
|
||||
class Neo4jConfig(BaseModel):
|
||||
"""Configuration for Neo4j database connection."""
|
||||
|
||||
uri: str = 'bolt://localhost:7687'
|
||||
user: str = 'neo4j'
|
||||
password: str = 'password'
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> 'Neo4jConfig':
|
||||
"""Create Neo4j configuration from environment variables."""
|
||||
return cls(
|
||||
uri=os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
|
||||
user=os.environ.get('NEO4J_USER', 'neo4j'),
|
||||
password=os.environ.get('NEO4J_PASSWORD', 'password'),
|
||||
)
|
||||
|
||||
|
||||
class GraphitiConfig(BaseModel):
|
||||
"""Configuration for Graphiti client.
|
||||
|
||||
Centralizes all configuration parameters for the Graphiti client,
|
||||
including database connection details and LLM settings.
|
||||
Centralizes all configuration parameters for the Graphiti client.
|
||||
"""
|
||||
|
||||
neo4j_uri: str = 'bolt://localhost:7687'
|
||||
neo4j_user: str = 'neo4j'
|
||||
neo4j_password: str = 'password'
|
||||
openai_api_key: Optional[str] = None
|
||||
openai_base_url: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig)
|
||||
neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig)
|
||||
group_id: Optional[str] = None
|
||||
use_custom_entities: bool = False
|
||||
destroy_graph: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> 'GraphitiConfig':
|
||||
"""Create a configuration instance from environment variables."""
|
||||
return cls(
|
||||
neo4j_uri=os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
|
||||
neo4j_user=os.environ.get('NEO4J_USER', 'neo4j'),
|
||||
neo4j_password=os.environ.get('NEO4J_PASSWORD', 'password'),
|
||||
openai_api_key=os.environ.get('OPENAI_API_KEY'),
|
||||
openai_base_url=os.environ.get('OPENAI_BASE_URL'),
|
||||
model_name=os.environ.get('MODEL_NAME'),
|
||||
llm=GraphitiLLMConfig.from_env(),
|
||||
neo4j=Neo4jConfig.from_env(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiConfig':
|
||||
"""Create configuration from CLI arguments, falling back to environment variables."""
|
||||
# Start with environment configuration
|
||||
config = cls.from_env()
|
||||
|
||||
# Apply CLI overrides
|
||||
if args.group_id:
|
||||
config.group_id = args.group_id
|
||||
else:
|
||||
config.group_id = f'graph_{uuid.uuid4().hex[:8]}'
|
||||
|
||||
config.use_custom_entities = args.use_custom_entities
|
||||
config.destroy_graph = args.destroy_graph
|
||||
|
||||
# Update LLM config using CLI args
|
||||
config.llm = GraphitiLLMConfig.from_cli_and_env(args)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
class MCPConfig(BaseModel):
|
||||
"""Configuration for MCP server."""
|
||||
|
||||
transport: str
|
||||
transport: str = 'sse' # Default to SSE transport
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig':
|
||||
"""Create MCP configuration from CLI arguments."""
|
||||
return cls(transport=args.transport)
|
||||
|
||||
|
||||
# Configure logging
|
||||
|
|
@ -197,8 +310,8 @@ logging.basicConfig(
|
|||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create global config instance
|
||||
config = GraphitiConfig.from_env()
|
||||
# Create global config instance - will be properly initialized later
|
||||
config = GraphitiConfig()
|
||||
|
||||
# MCP server instructions
|
||||
GRAPHITI_MCP_INSTRUCTIONS = """
|
||||
|
|
@ -243,44 +356,53 @@ mcp = FastMCP(
|
|||
graphiti_client: Optional[Graphiti] = None
|
||||
|
||||
|
||||
async def initialize_graphiti(llm_client: Optional[LLMClient] = None, destroy_graph: bool = False):
|
||||
"""Initialize the Graphiti client with the provided settings.
|
||||
async def initialize_graphiti():
|
||||
"""Initialize the Graphiti client with the configured settings."""
|
||||
global graphiti_client, config
|
||||
|
||||
Args:
|
||||
llm_client: Optional LLMClient instance to use for LLM operations
|
||||
destroy_graph: Optional boolean to destroy all Graphiti graphs
|
||||
"""
|
||||
global graphiti_client
|
||||
try:
|
||||
# Create LLM client if possible
|
||||
llm_client = config.llm.create_client()
|
||||
if not llm_client and config.use_custom_entities:
|
||||
# If custom entities are enabled, we must have an LLM client
|
||||
raise ValueError('OPENAI_API_KEY must be set when custom entities are enabled')
|
||||
|
||||
# If no client is provided, create a default OpenAI client
|
||||
if not llm_client:
|
||||
if config.openai_api_key:
|
||||
llm_config = LLMConfig(api_key=config.openai_api_key)
|
||||
if config.openai_base_url:
|
||||
llm_config.base_url = config.openai_base_url
|
||||
if config.model_name:
|
||||
llm_config.model = config.model_name
|
||||
llm_client = OpenAIClient(config=llm_config)
|
||||
# Validate Neo4j configuration
|
||||
if not config.neo4j.uri or not config.neo4j.user or not config.neo4j.password:
|
||||
raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
|
||||
|
||||
# Initialize Graphiti client
|
||||
graphiti_client = Graphiti(
|
||||
uri=config.neo4j.uri,
|
||||
user=config.neo4j.user,
|
||||
password=config.neo4j.password,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
|
||||
# Destroy graph if requested
|
||||
if config.destroy_graph:
|
||||
logger.info('Destroying graph...')
|
||||
await clear_data(graphiti_client.driver)
|
||||
|
||||
# Initialize the graph database with Graphiti's indices
|
||||
await graphiti_client.build_indices_and_constraints()
|
||||
logger.info('Graphiti client initialized successfully')
|
||||
|
||||
# Log configuration details for transparency
|
||||
if llm_client:
|
||||
logger.info(f'Using OpenAI model: {config.llm.model}')
|
||||
logger.info(f'Using temperature: {config.llm.temperature}')
|
||||
else:
|
||||
raise ValueError('OPENAI_API_KEY must be set when not using a custom LLM client')
|
||||
logger.info('No LLM client configured - entity extraction will be limited')
|
||||
|
||||
if not config.neo4j_uri or not config.neo4j_user or not config.neo4j_password:
|
||||
raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
|
||||
logger.info(f'Using group_id: {config.group_id}')
|
||||
logger.info(
|
||||
f'Custom entity extraction: {"enabled" if config.use_custom_entities else "disabled"}'
|
||||
)
|
||||
|
||||
graphiti_client = Graphiti(
|
||||
uri=config.neo4j_uri,
|
||||
user=config.neo4j_user,
|
||||
password=config.neo4j_password,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
|
||||
if destroy_graph:
|
||||
logger.info('Destroying graph...')
|
||||
await clear_data(graphiti_client.driver)
|
||||
|
||||
# Initialize the graph database with Graphiti's indices
|
||||
await graphiti_client.build_indices_and_constraints()
|
||||
logger.info('Graphiti client initialized successfully')
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to initialize Graphiti: {str(e)}')
|
||||
raise
|
||||
|
||||
|
||||
def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
|
||||
|
|
@ -808,29 +930,8 @@ async def get_status() -> StatusResponse:
|
|||
}
|
||||
|
||||
|
||||
def create_llm_client(api_key: Optional[str] = None, model: Optional[str] = None) -> LLMClient:
|
||||
"""Create an OpenAI LLM client.
|
||||
|
||||
Args:
|
||||
api_key: API key for the OpenAI service
|
||||
model: Model name to use
|
||||
|
||||
Returns:
|
||||
An instance of the OpenAI LLM client
|
||||
"""
|
||||
# Create config with provided API key and model
|
||||
llm_config = LLMConfig(api_key=api_key)
|
||||
|
||||
# Set model if provided
|
||||
if model:
|
||||
llm_config.model = model
|
||||
|
||||
# Create and return the client
|
||||
return OpenAIClient(config=llm_config)
|
||||
|
||||
|
||||
async def initialize_server() -> MCPConfig:
|
||||
"""Initialize the Graphiti server with the specified LLM client."""
|
||||
"""Parse CLI arguments and initialize the Graphiti server configuration."""
|
||||
global config
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
@ -847,8 +948,14 @@ async def initialize_server() -> MCPConfig:
|
|||
default='sse',
|
||||
help='Transport to use for communication with the client. (default: sse)',
|
||||
)
|
||||
# OpenAI is the only supported LLM client
|
||||
parser.add_argument('--model', help='Model name to use with the LLM client')
|
||||
parser.add_argument(
|
||||
'--model', help=f'Model name to use with the LLM client. (default: {DEFAULT_LLM_MODEL})'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--temperature',
|
||||
type=float,
|
||||
help='Temperature setting for the LLM (0.0-2.0). Lower values make output more deterministic. (default: 0.7)',
|
||||
)
|
||||
parser.add_argument('--destroy-graph', action='store_true', help='Destroy all Graphiti graphs')
|
||||
parser.add_argument(
|
||||
'--use-custom-entities',
|
||||
|
|
@ -858,36 +965,26 @@ async def initialize_server() -> MCPConfig:
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set the group_id from CLI argument or generate a random one
|
||||
# Build configuration from CLI arguments and environment variables
|
||||
config = GraphitiConfig.from_cli_and_env(args)
|
||||
|
||||
# Log the group ID configuration
|
||||
if args.group_id:
|
||||
config.group_id = args.group_id
|
||||
logger.info(f'Using provided group_id: {config.group_id}')
|
||||
else:
|
||||
config.group_id = f'graph_{uuid.uuid4().hex[:8]}'
|
||||
logger.info(f'Generated random group_id: {config.group_id}')
|
||||
|
||||
# Set use_custom_entities flag if specified
|
||||
if args.use_custom_entities:
|
||||
config.use_custom_entities = True
|
||||
# Log entity extraction configuration
|
||||
if config.use_custom_entities:
|
||||
logger.info('Entity extraction enabled using predefined ENTITY_TYPES')
|
||||
else:
|
||||
logger.info('Entity extraction disabled (no custom entities will be used)')
|
||||
|
||||
llm_client = None
|
||||
# Initialize Graphiti
|
||||
await initialize_graphiti()
|
||||
|
||||
# Create OpenAI client if model is specified or if OPENAI_API_KEY is available
|
||||
if args.model or config.openai_api_key:
|
||||
# Override model from command line if specified
|
||||
|
||||
config.model_name = args.model or DEFAULT_LLM_MODEL
|
||||
|
||||
# Create the OpenAI client
|
||||
llm_client = create_llm_client(api_key=config.openai_api_key, model=config.model_name)
|
||||
|
||||
# Initialize Graphiti with the specified LLM client
|
||||
await initialize_graphiti(llm_client, destroy_graph=args.destroy_graph)
|
||||
|
||||
return MCPConfig(transport=args.transport)
|
||||
# Return MCP configuration
|
||||
return MCPConfig.from_cli(args)
|
||||
|
||||
|
||||
async def run_mcp_server():
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue