auto parse schema

This commit is contained in:
Aidan Petti 2025-11-24 15:14:23 -07:00
parent 07e2e40cab
commit a9a8c398eb
2 changed files with 55 additions and 15 deletions

View file

@ -201,7 +201,20 @@ class NeptuneProviderConfig(BaseModel):
region: str | None = None region: str | None = None
def model_post_init(self, __context) -> None: def model_post_init(self, __context) -> None:
"""Validate Neptune-specific requirements.""" """Validate and normalize Neptune-specific requirements."""
# Auto-detect and add protocol if missing
if not self.host.startswith(('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://')):
# Check if it looks like a Neptune Analytics graph ID (starts with 'g-')
if self.host.startswith('g-'):
self.host = f'neptune-graph://{self.host}'
# Check if it contains 'neptune.amazonaws.com' (Neptune Database cluster)
elif 'neptune.amazonaws.com' in self.host:
self.host = f'neptune-db://{self.host}'
# Otherwise default to Neptune Database protocol
else:
self.host = f'neptune-db://{self.host}'
# Validate protocol is correct
if not self.host.startswith(('neptune-db://', 'neptune-graph://')): if not self.host.startswith(('neptune-db://', 'neptune-graph://')):
raise ValueError( raise ValueError(
'Neptune host must start with neptune-db:// or neptune-graph://\n' 'Neptune host must start with neptune-db:// or neptune-graph://\n'

View file

@ -475,22 +475,49 @@ class DatabaseDriverFactory:
except Exception as e: except Exception as e:
raise ValueError(f'AWS credential error: {e}') from e raise ValueError(f'AWS credential error: {e}') from e
# Load Neptune config # Load Neptune config and environment variables
if config.providers.neptune:
neptune_config = config.providers.neptune
else:
from config.schema import NeptuneProviderConfig
neptune_config = NeptuneProviderConfig()
# Environment overrides
import os import os
host = os.environ.get('NEPTUNE_HOST', neptune_config.host) # Read environment variables first
aoss_host = os.environ.get('AOSS_HOST', neptune_config.aoss_host) env_host = os.environ.get('NEPTUNE_HOST')
port = int(os.environ.get('NEPTUNE_PORT', str(neptune_config.port))) env_aoss_host = os.environ.get('AOSS_HOST')
aoss_port = int(os.environ.get('AOSS_PORT', str(neptune_config.aoss_port))) env_port = os.environ.get('NEPTUNE_PORT')
region_override = os.environ.get('AWS_REGION', region or neptune_config.region) env_aoss_port = os.environ.get('AOSS_PORT')
env_region = os.environ.get('AWS_REGION')
if config.providers.neptune:
neptune_config = config.providers.neptune
# Apply environment overrides
host = env_host or neptune_config.host
aoss_host = env_aoss_host or neptune_config.aoss_host
port = int(env_port) if env_port else neptune_config.port
aoss_port = int(env_aoss_port) if env_aoss_port else neptune_config.aoss_port
region_override = env_region or region or neptune_config.region
else:
# No config provided, use environment variables with defaults
from config.schema import NeptuneProviderConfig
host = env_host or 'neptune-db://localhost'
aoss_host = env_aoss_host
port = int(env_port) if env_port else 8182
aoss_port = int(env_aoss_port) if env_aoss_port else 443
region_override = env_region or region
# Create config with values to trigger validation
neptune_config = NeptuneProviderConfig(
host=host,
aoss_host=aoss_host,
port=port,
aoss_port=aoss_port,
region=region_override,
)
# Use normalized values from config (protocol may have been auto-added)
host = neptune_config.host
aoss_host = neptune_config.aoss_host
port = neptune_config.port
aoss_port = neptune_config.aoss_port
region_override = neptune_config.region
if not aoss_host: if not aoss_host:
raise ValueError( raise ValueError(