auto parse schema
This commit is contained in:
parent
07e2e40cab
commit
a9a8c398eb
2 changed files with 55 additions and 15 deletions
|
|
@ -201,7 +201,20 @@ class NeptuneProviderConfig(BaseModel):
|
||||||
region: str | None = None
|
region: str | None = None
|
||||||
|
|
||||||
def model_post_init(self, __context) -> None:
|
def model_post_init(self, __context) -> None:
|
||||||
"""Validate Neptune-specific requirements."""
|
"""Validate and normalize Neptune-specific requirements."""
|
||||||
|
# Auto-detect and add protocol if missing
|
||||||
|
if not self.host.startswith(('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://')):
|
||||||
|
# Check if it looks like a Neptune Analytics graph ID (starts with 'g-')
|
||||||
|
if self.host.startswith('g-'):
|
||||||
|
self.host = f'neptune-graph://{self.host}'
|
||||||
|
# Check if it contains 'neptune.amazonaws.com' (Neptune Database cluster)
|
||||||
|
elif 'neptune.amazonaws.com' in self.host:
|
||||||
|
self.host = f'neptune-db://{self.host}'
|
||||||
|
# Otherwise default to Neptune Database protocol
|
||||||
|
else:
|
||||||
|
self.host = f'neptune-db://{self.host}'
|
||||||
|
|
||||||
|
# Validate protocol is correct
|
||||||
if not self.host.startswith(('neptune-db://', 'neptune-graph://')):
|
if not self.host.startswith(('neptune-db://', 'neptune-graph://')):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Neptune host must start with neptune-db:// or neptune-graph://\n'
|
'Neptune host must start with neptune-db:// or neptune-graph://\n'
|
||||||
|
|
|
||||||
|
|
@ -475,22 +475,49 @@ class DatabaseDriverFactory:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f'AWS credential error: {e}') from e
|
raise ValueError(f'AWS credential error: {e}') from e
|
||||||
|
|
||||||
# Load Neptune config
|
# Load Neptune config and environment variables
|
||||||
if config.providers.neptune:
|
|
||||||
neptune_config = config.providers.neptune
|
|
||||||
else:
|
|
||||||
from config.schema import NeptuneProviderConfig
|
|
||||||
|
|
||||||
neptune_config = NeptuneProviderConfig()
|
|
||||||
|
|
||||||
# Environment overrides
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
host = os.environ.get('NEPTUNE_HOST', neptune_config.host)
|
# Read environment variables first
|
||||||
aoss_host = os.environ.get('AOSS_HOST', neptune_config.aoss_host)
|
env_host = os.environ.get('NEPTUNE_HOST')
|
||||||
port = int(os.environ.get('NEPTUNE_PORT', str(neptune_config.port)))
|
env_aoss_host = os.environ.get('AOSS_HOST')
|
||||||
aoss_port = int(os.environ.get('AOSS_PORT', str(neptune_config.aoss_port)))
|
env_port = os.environ.get('NEPTUNE_PORT')
|
||||||
region_override = os.environ.get('AWS_REGION', region or neptune_config.region)
|
env_aoss_port = os.environ.get('AOSS_PORT')
|
||||||
|
env_region = os.environ.get('AWS_REGION')
|
||||||
|
|
||||||
|
if config.providers.neptune:
|
||||||
|
neptune_config = config.providers.neptune
|
||||||
|
# Apply environment overrides
|
||||||
|
host = env_host or neptune_config.host
|
||||||
|
aoss_host = env_aoss_host or neptune_config.aoss_host
|
||||||
|
port = int(env_port) if env_port else neptune_config.port
|
||||||
|
aoss_port = int(env_aoss_port) if env_aoss_port else neptune_config.aoss_port
|
||||||
|
region_override = env_region or region or neptune_config.region
|
||||||
|
else:
|
||||||
|
# No config provided, use environment variables with defaults
|
||||||
|
from config.schema import NeptuneProviderConfig
|
||||||
|
|
||||||
|
host = env_host or 'neptune-db://localhost'
|
||||||
|
aoss_host = env_aoss_host
|
||||||
|
port = int(env_port) if env_port else 8182
|
||||||
|
aoss_port = int(env_aoss_port) if env_aoss_port else 443
|
||||||
|
region_override = env_region or region
|
||||||
|
|
||||||
|
# Create config with values to trigger validation
|
||||||
|
neptune_config = NeptuneProviderConfig(
|
||||||
|
host=host,
|
||||||
|
aoss_host=aoss_host,
|
||||||
|
port=port,
|
||||||
|
aoss_port=aoss_port,
|
||||||
|
region=region_override,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use normalized values from config (protocol may have been auto-added)
|
||||||
|
host = neptune_config.host
|
||||||
|
aoss_host = neptune_config.aoss_host
|
||||||
|
port = neptune_config.port
|
||||||
|
aoss_port = neptune_config.aoss_port
|
||||||
|
region_override = neptune_config.region
|
||||||
|
|
||||||
if not aoss_host:
|
if not aoss_host:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue