auto parse schema
This commit is contained in:
parent
07e2e40cab
commit
a9a8c398eb
2 changed files with 55 additions and 15 deletions
|
|
@ -201,7 +201,20 @@ class NeptuneProviderConfig(BaseModel):
|
|||
region: str | None = None
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
"""Validate Neptune-specific requirements."""
|
||||
"""Validate and normalize Neptune-specific requirements."""
|
||||
# Auto-detect and add protocol if missing
|
||||
if not self.host.startswith(('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://')):
|
||||
# Check if it looks like a Neptune Analytics graph ID (starts with 'g-')
|
||||
if self.host.startswith('g-'):
|
||||
self.host = f'neptune-graph://{self.host}'
|
||||
# Check if it contains 'neptune.amazonaws.com' (Neptune Database cluster)
|
||||
elif 'neptune.amazonaws.com' in self.host:
|
||||
self.host = f'neptune-db://{self.host}'
|
||||
# Otherwise default to Neptune Database protocol
|
||||
else:
|
||||
self.host = f'neptune-db://{self.host}'
|
||||
|
||||
# Validate protocol is correct
|
||||
if not self.host.startswith(('neptune-db://', 'neptune-graph://')):
|
||||
raise ValueError(
|
||||
'Neptune host must start with neptune-db:// or neptune-graph://\n'
|
||||
|
|
|
|||
|
|
@ -475,22 +475,49 @@ class DatabaseDriverFactory:
|
|||
except Exception as e:
|
||||
raise ValueError(f'AWS credential error: {e}') from e
|
||||
|
||||
# Load Neptune config
|
||||
if config.providers.neptune:
|
||||
neptune_config = config.providers.neptune
|
||||
else:
|
||||
from config.schema import NeptuneProviderConfig
|
||||
|
||||
neptune_config = NeptuneProviderConfig()
|
||||
|
||||
# Environment overrides
|
||||
# Load Neptune config and environment variables
|
||||
import os
|
||||
|
||||
host = os.environ.get('NEPTUNE_HOST', neptune_config.host)
|
||||
aoss_host = os.environ.get('AOSS_HOST', neptune_config.aoss_host)
|
||||
port = int(os.environ.get('NEPTUNE_PORT', str(neptune_config.port)))
|
||||
aoss_port = int(os.environ.get('AOSS_PORT', str(neptune_config.aoss_port)))
|
||||
region_override = os.environ.get('AWS_REGION', region or neptune_config.region)
|
||||
# Read environment variables first
|
||||
env_host = os.environ.get('NEPTUNE_HOST')
|
||||
env_aoss_host = os.environ.get('AOSS_HOST')
|
||||
env_port = os.environ.get('NEPTUNE_PORT')
|
||||
env_aoss_port = os.environ.get('AOSS_PORT')
|
||||
env_region = os.environ.get('AWS_REGION')
|
||||
|
||||
if config.providers.neptune:
|
||||
neptune_config = config.providers.neptune
|
||||
# Apply environment overrides
|
||||
host = env_host or neptune_config.host
|
||||
aoss_host = env_aoss_host or neptune_config.aoss_host
|
||||
port = int(env_port) if env_port else neptune_config.port
|
||||
aoss_port = int(env_aoss_port) if env_aoss_port else neptune_config.aoss_port
|
||||
region_override = env_region or region or neptune_config.region
|
||||
else:
|
||||
# No config provided, use environment variables with defaults
|
||||
from config.schema import NeptuneProviderConfig
|
||||
|
||||
host = env_host or 'neptune-db://localhost'
|
||||
aoss_host = env_aoss_host
|
||||
port = int(env_port) if env_port else 8182
|
||||
aoss_port = int(env_aoss_port) if env_aoss_port else 443
|
||||
region_override = env_region or region
|
||||
|
||||
# Create config with values to trigger validation
|
||||
neptune_config = NeptuneProviderConfig(
|
||||
host=host,
|
||||
aoss_host=aoss_host,
|
||||
port=port,
|
||||
aoss_port=aoss_port,
|
||||
region=region_override,
|
||||
)
|
||||
|
||||
# Use normalized values from config (protocol may have been auto-added)
|
||||
host = neptune_config.host
|
||||
aoss_host = neptune_config.aoss_host
|
||||
port = neptune_config.port
|
||||
aoss_port = neptune_config.aoss_port
|
||||
region_override = neptune_config.region
|
||||
|
||||
if not aoss_host:
|
||||
raise ValueError(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue