diff --git a/mcp_server/src/config/schema.py b/mcp_server/src/config/schema.py index a67f733d..76e802ff 100644 --- a/mcp_server/src/config/schema.py +++ b/mcp_server/src/config/schema.py @@ -201,7 +201,20 @@ class NeptuneProviderConfig(BaseModel): region: str | None = None def model_post_init(self, __context) -> None: - """Validate Neptune-specific requirements.""" + """Validate and normalize Neptune-specific requirements.""" + # Auto-detect and add protocol if missing + if not self.host.startswith(('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://')): + # Check if it looks like a Neptune Analytics graph ID (starts with 'g-') + if self.host.startswith('g-'): + self.host = f'neptune-graph://{self.host}' + # Check if it contains 'neptune.amazonaws.com' (Neptune Database cluster) + elif 'neptune.amazonaws.com' in self.host: + self.host = f'neptune-db://{self.host}' + # Otherwise default to Neptune Database protocol + else: + self.host = f'neptune-db://{self.host}' + + # Validate protocol is correct if not self.host.startswith(('neptune-db://', 'neptune-graph://')): raise ValueError( 'Neptune host must start with neptune-db:// or neptune-graph://\n' diff --git a/mcp_server/src/services/factories.py b/mcp_server/src/services/factories.py index 5c7bf4f1..2edd53f4 100644 --- a/mcp_server/src/services/factories.py +++ b/mcp_server/src/services/factories.py @@ -475,22 +475,49 @@ class DatabaseDriverFactory: except Exception as e: raise ValueError(f'AWS credential error: {e}') from e - # Load Neptune config - if config.providers.neptune: - neptune_config = config.providers.neptune - else: - from config.schema import NeptuneProviderConfig - - neptune_config = NeptuneProviderConfig() - - # Environment overrides + # Load Neptune config and environment variables import os - host = os.environ.get('NEPTUNE_HOST', neptune_config.host) - aoss_host = os.environ.get('AOSS_HOST', neptune_config.aoss_host) - port = int(os.environ.get('NEPTUNE_PORT', str(neptune_config.port))) - aoss_port = int(os.environ.get('AOSS_PORT', str(neptune_config.aoss_port))) - region_override = os.environ.get('AWS_REGION', region or neptune_config.region) + # Read environment variables first + env_host = os.environ.get('NEPTUNE_HOST') + env_aoss_host = os.environ.get('AOSS_HOST') + env_port = os.environ.get('NEPTUNE_PORT') + env_aoss_port = os.environ.get('AOSS_PORT') + env_region = os.environ.get('AWS_REGION') + + if config.providers.neptune: + neptune_config = config.providers.neptune + # Apply environment overrides + host = env_host or neptune_config.host + aoss_host = env_aoss_host or neptune_config.aoss_host + port = int(env_port) if env_port else neptune_config.port + aoss_port = int(env_aoss_port) if env_aoss_port else neptune_config.aoss_port + region_override = env_region or region or neptune_config.region + else: + # No config provided, use environment variables with defaults + from config.schema import NeptuneProviderConfig + + host = env_host or 'neptune-db://localhost' + aoss_host = env_aoss_host + port = int(env_port) if env_port else 8182 + aoss_port = int(env_aoss_port) if env_aoss_port else 443 + region_override = env_region or region + + # Create config with values to trigger validation + neptune_config = NeptuneProviderConfig( + host=host, + aoss_host=aoss_host, + port=port, + aoss_port=aoss_port, + region=region_override, + ) + + # Use normalized values from config (protocol may have been auto-added) + host = neptune_config.host + aoss_host = neptune_config.aoss_host + port = neptune_config.port + aoss_port = neptune_config.aoss_port + region_override = neptune_config.region if not aoss_host: raise ValueError(