From 97b0bbc7a8d0300f225ae288b13ed4f7779af076 Mon Sep 17 00:00:00 2001 From: Aidan Petti Date: Mon, 24 Nov 2025 19:15:46 -0700 Subject: [PATCH] neptune schema --- graphiti_core/driver/neptune_driver.py | 60 +++++++++++++------ mcp_server/src/config/schema.py | 11 +++- .../tests/test_neptune_configuration.py | 13 ++-- 3 files changed, 56 insertions(+), 28 deletions(-) diff --git a/graphiti_core/driver/neptune_driver.py b/graphiti_core/driver/neptune_driver.py index 7fa3f169..bb802729 100644 --- a/graphiti_core/driver/neptune_driver.py +++ b/graphiti_core/driver/neptune_driver.py @@ -29,6 +29,38 @@ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphPr logger = logging.getLogger(__name__) DEFAULT_SIZE = 10 + +class GraphitiNeptuneGraph(NeptuneGraph): + """ + Custom NeptuneGraph subclass that uses pre-defined Graphiti schema + instead of calling Neptune's expensive statistics API. + """ + + # Define Graphiti schema to avoid expensive statistics API calls + GRAPHITI_SCHEMA = """ +Node labels: Episodic, Entity, Community +Relationship types: MENTIONS, RELATES_TO, HAS_MEMBER +Node properties: + Episodic {uuid: string, name: string, group_id: string, source: string, source_description: string, content: string, valid_at: datetime, created_at: datetime, entity_edges: list} + Entity {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string, labels: list} + Community {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string} +Relationship properties: + MENTIONS {created_at: datetime} + RELATES_TO {uuid: string, name: string, group_id: string, fact: string, fact_embedding: string, episodes: list, created_at: datetime, expired_at: datetime, valid_at: datetime, invalid_at: datetime} + HAS_MEMBER {uuid: string, created_at: datetime} +""" + + def _refresh_schema(self) -> None: + """ + Override to use pre-defined schema instead of calling statistics API. + + This avoids the expensive Neptune statistics API call that requires + statistics to be enabled on the Neptune instance. + """ + self.schema = self.GRAPHITI_SCHEMA + logger.debug('Using pre-defined Graphiti schema, skipping Neptune statistics API') + + aoss_indices = [ { 'index_name': 'node_name_and_summary', @@ -109,7 +141,7 @@ aoss_indices = [ class NeptuneDriver(GraphDriver): provider: GraphProvider = GraphProvider.NEPTUNE - def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443): + def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443, database: str = 'default'): """This initializes a NeptuneDriver for use with Neptune as a backend Args: @@ -117,29 +149,20 @@ class NeptuneDriver(GraphDriver): aoss_host (str): The OpenSearch host value port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182. aoss_port (int, optional): The OpenSearch port. Defaults to 443. + database (str, optional): The database name (for compatibility with base class). Defaults to 'default'. """ if not host: raise ValueError('You must provide an endpoint to create a NeptuneDriver') - # Define Graphiti schema to avoid expensive statistics API calls - graphiti_schema = """ -Node labels: Episodic, Entity, Community -Relationship types: MENTIONS, RELATES_TO, HAS_MEMBER -Node properties: - Episodic {uuid: string, name: string, group_id: string, source: string, source_description: string, content: string, valid_at: datetime, created_at: datetime, entity_edges: list} - Entity {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string, labels: list} - Community {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string} -Relationship properties: - MENTIONS {created_at: datetime} - RELATES_TO {uuid: string, name: string, group_id: string, fact: string, fact_embedding: string, episodes: list, created_at: datetime, expired_at: datetime, valid_at: datetime, invalid_at: datetime} - HAS_MEMBER {uuid: string, created_at: datetime} -""" + # Set the database attribute required by the base GraphDriver class + self._database = database if host.startswith('neptune-db://'): # This is a Neptune Database Cluster endpoint = host.replace('neptune-db://', '') - self.client = NeptuneGraph(endpoint, port, schema=graphiti_schema) - logger.debug('Creating Neptune Database session for %s', host) + # Use custom GraphitiNeptuneGraph to avoid expensive statistics API calls + self.client = GraphitiNeptuneGraph(endpoint, port) + logger.debug('Creating Neptune Database session for %s with pre-defined schema', host) elif host.startswith('neptune-graph://'): # This is a Neptune Analytics Graph graphId = host.replace('neptune-graph://', '') @@ -153,9 +176,12 @@ Relationship properties: if not aoss_host: raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.') + # Strip protocol prefix from aoss_host if present (OpenSearch expects just the hostname) + aoss_hostname = aoss_host.replace('https://', '').replace('http://', '') + session = boto3.Session() self.aoss_client = OpenSearch( - hosts=[{'host': aoss_host, 'port': aoss_port}], + hosts=[{'host': aoss_hostname, 'port': aoss_port}], http_auth=Urllib3AWSV4SignerAuth( session.get_credentials(), session.region_name, 'aoss' ), diff --git a/mcp_server/src/config/schema.py b/mcp_server/src/config/schema.py index 76e802ff..f81f63d7 100644 --- a/mcp_server/src/config/schema.py +++ b/mcp_server/src/config/schema.py @@ -203,7 +203,9 @@ class NeptuneProviderConfig(BaseModel): def model_post_init(self, __context) -> None: """Validate and normalize Neptune-specific requirements.""" # Auto-detect and add protocol if missing - if not self.host.startswith(('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://')): + if not self.host.startswith( + ('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://') + ): # Check if it looks like a Neptune Analytics graph ID (starts with 'g-') if self.host.startswith('g-'): self.host = f'neptune-graph://{self.host}' @@ -230,9 +232,14 @@ class NeptuneProviderConfig(BaseModel): ' database:\n' ' providers:\n' ' neptune:\n' - ' aoss_host: "your-aoss-endpoint.us-east-1.aoss.amazonaws.com"' + ' aoss_host: "your-aoss-endpoint.us-east-1.aoss.amazonaws.com"\n' + 'Note: Provide hostname only, without https:// prefix' ) + # Strip protocol prefix from aoss_host if present (OpenSearch expects just the hostname) + if self.aoss_host.startswith(('https://', 'http://')): + self.aoss_host = self.aoss_host.replace('https://', '').replace('http://', '') + class DatabaseProvidersConfig(BaseModel): """Database providers configuration.""" diff --git a/mcp_server/tests/test_neptune_configuration.py b/mcp_server/tests/test_neptune_configuration.py index bbc9b1f1..63ff7f79 100644 --- a/mcp_server/tests/test_neptune_configuration.py +++ b/mcp_server/tests/test_neptune_configuration.py @@ -11,7 +11,6 @@ These tests validate Neptune-specific configuration requirements including: """ import os -import sys from pathlib import Path from unittest.mock import MagicMock, patch @@ -107,17 +106,13 @@ def test_neptune_environment_overrides(): config_path = Path(__file__).parent.parent / 'config' / 'config.yaml' config = GraphitiConfig(_env_file=None, config_path=str(config_path)) - print(f'✓ Loaded configuration with environment overrides') + print('✓ Loaded configuration with environment overrides') # Verify environment variables were applied if config.database.providers and config.database.providers.neptune: neptune_config = config.database.providers.neptune - print( - f' Neptune host: {neptune_config.host}' - ) - print( - f' AOSS host: {neptune_config.aoss_host}' - ) + print(f' Neptune host: {neptune_config.host}') + print(f' AOSS host: {neptune_config.aoss_host}') print(f' Neptune port: {neptune_config.port}') print(f' AOSS port: {neptune_config.aoss_port}') print(f' Region: {neptune_config.region}') @@ -224,7 +219,7 @@ def test_neptune_factory_missing_credentials(): with patch('boto3.Session', return_value=mock_session): try: - db_config = DatabaseDriverFactory.create_config(test_config) + DatabaseDriverFactory.create_config(test_config) print('✗ Factory should have failed with missing credentials') raise AssertionError('Expected ValueError for missing AWS credentials') except ValueError as e: