neptune schema

This commit is contained in:
Aidan Petti 2025-11-24 19:15:46 -07:00
parent b5dc7ab698
commit 97b0bbc7a8
3 changed files with 56 additions and 28 deletions

View file

@ -29,6 +29,38 @@ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphPr
logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
class GraphitiNeptuneGraph(NeptuneGraph):
"""
Custom NeptuneGraph subclass that uses pre-defined Graphiti schema
instead of calling Neptune's expensive statistics API.
"""
# Define Graphiti schema to avoid expensive statistics API calls
GRAPHITI_SCHEMA = """
Node labels: Episodic, Entity, Community
Relationship types: MENTIONS, RELATES_TO, HAS_MEMBER
Node properties:
Episodic {uuid: string, name: string, group_id: string, source: string, source_description: string, content: string, valid_at: datetime, created_at: datetime, entity_edges: list}
Entity {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string, labels: list}
Community {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string}
Relationship properties:
MENTIONS {created_at: datetime}
RELATES_TO {uuid: string, name: string, group_id: string, fact: string, fact_embedding: string, episodes: list, created_at: datetime, expired_at: datetime, valid_at: datetime, invalid_at: datetime}
HAS_MEMBER {uuid: string, created_at: datetime}
"""
def _refresh_schema(self) -> None:
"""
Override to use pre-defined schema instead of calling statistics API.
This avoids the expensive Neptune statistics API call that requires
statistics to be enabled on the Neptune instance.
"""
self.schema = self.GRAPHITI_SCHEMA
logger.debug('Using pre-defined Graphiti schema, skipping Neptune statistics API')
aoss_indices = [
{
'index_name': 'node_name_and_summary',
@ -109,7 +141,7 @@ aoss_indices = [
class NeptuneDriver(GraphDriver):
provider: GraphProvider = GraphProvider.NEPTUNE
def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443, database: str = 'default'):
"""This initializes a NeptuneDriver for use with Neptune as a backend
Args:
@ -117,29 +149,20 @@ class NeptuneDriver(GraphDriver):
aoss_host (str): The OpenSearch host value
port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
aoss_port (int, optional): The OpenSearch port. Defaults to 443.
database (str, optional): The database name (for compatibility with base class). Defaults to 'default'.
"""
if not host:
raise ValueError('You must provide an endpoint to create a NeptuneDriver')
# Define Graphiti schema to avoid expensive statistics API calls
graphiti_schema = """
Node labels: Episodic, Entity, Community
Relationship types: MENTIONS, RELATES_TO, HAS_MEMBER
Node properties:
Episodic {uuid: string, name: string, group_id: string, source: string, source_description: string, content: string, valid_at: datetime, created_at: datetime, entity_edges: list}
Entity {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string, labels: list}
Community {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string}
Relationship properties:
MENTIONS {created_at: datetime}
RELATES_TO {uuid: string, name: string, group_id: string, fact: string, fact_embedding: string, episodes: list, created_at: datetime, expired_at: datetime, valid_at: datetime, invalid_at: datetime}
HAS_MEMBER {uuid: string, created_at: datetime}
"""
# Set the database attribute required by the base GraphDriver class
self._database = database
if host.startswith('neptune-db://'):
# This is a Neptune Database Cluster
endpoint = host.replace('neptune-db://', '')
self.client = NeptuneGraph(endpoint, port, schema=graphiti_schema)
logger.debug('Creating Neptune Database session for %s', host)
# Use custom GraphitiNeptuneGraph to avoid expensive statistics API calls
self.client = GraphitiNeptuneGraph(endpoint, port)
logger.debug('Creating Neptune Database session for %s with pre-defined schema', host)
elif host.startswith('neptune-graph://'):
# This is a Neptune Analytics Graph
graphId = host.replace('neptune-graph://', '')
@ -153,9 +176,12 @@ Relationship properties:
if not aoss_host:
raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')
# Strip protocol prefix from aoss_host if present (OpenSearch expects just the hostname)
aoss_hostname = aoss_host.replace('https://', '').replace('http://', '')
session = boto3.Session()
self.aoss_client = OpenSearch(
hosts=[{'host': aoss_host, 'port': aoss_port}],
hosts=[{'host': aoss_hostname, 'port': aoss_port}],
http_auth=Urllib3AWSV4SignerAuth(
session.get_credentials(), session.region_name, 'aoss'
),

View file

@ -203,7 +203,9 @@ class NeptuneProviderConfig(BaseModel):
def model_post_init(self, __context) -> None:
"""Validate and normalize Neptune-specific requirements."""
# Auto-detect and add protocol if missing
if not self.host.startswith(('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://')):
if not self.host.startswith(
('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://')
):
# Check if it looks like a Neptune Analytics graph ID (starts with 'g-')
if self.host.startswith('g-'):
self.host = f'neptune-graph://{self.host}'
@ -230,9 +232,14 @@ class NeptuneProviderConfig(BaseModel):
' database:\n'
' providers:\n'
' neptune:\n'
' aoss_host: "your-aoss-endpoint.us-east-1.aoss.amazonaws.com"'
' aoss_host: "your-aoss-endpoint.us-east-1.aoss.amazonaws.com"\n'
'Note: Provide hostname only, without https:// prefix'
)
# Strip protocol prefix from aoss_host if present (OpenSearch expects just the hostname)
if self.aoss_host.startswith(('https://', 'http://')):
self.aoss_host = self.aoss_host.replace('https://', '').replace('http://', '')
class DatabaseProvidersConfig(BaseModel):
"""Database providers configuration."""

View file

@ -11,7 +11,6 @@ These tests validate Neptune-specific configuration requirements including:
"""
import os
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
@ -107,17 +106,13 @@ def test_neptune_environment_overrides():
config_path = Path(__file__).parent.parent / 'config' / 'config.yaml'
config = GraphitiConfig(_env_file=None, config_path=str(config_path))
print(f'✓ Loaded configuration with environment overrides')
print('✓ Loaded configuration with environment overrides')
# Verify environment variables were applied
if config.database.providers and config.database.providers.neptune:
neptune_config = config.database.providers.neptune
print(
f' Neptune host: {neptune_config.host}'
)
print(
f' AOSS host: {neptune_config.aoss_host}'
)
print(f' Neptune host: {neptune_config.host}')
print(f' AOSS host: {neptune_config.aoss_host}')
print(f' Neptune port: {neptune_config.port}')
print(f' AOSS port: {neptune_config.aoss_port}')
print(f' Region: {neptune_config.region}')
@ -224,7 +219,7 @@ def test_neptune_factory_missing_credentials():
with patch('boto3.Session', return_value=mock_session):
try:
db_config = DatabaseDriverFactory.create_config(test_config)
DatabaseDriverFactory.create_config(test_config)
print('✗ Factory should have failed with missing credentials')
raise AssertionError('Expected ValueError for missing AWS credentials')
except ValueError as e: