neptune schema
This commit is contained in:
parent
b5dc7ab698
commit
97b0bbc7a8
3 changed files with 56 additions and 28 deletions
|
|
@ -29,6 +29,38 @@ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphPr
|
|||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_SIZE = 10
|
||||
|
||||
|
||||
class GraphitiNeptuneGraph(NeptuneGraph):
|
||||
"""
|
||||
Custom NeptuneGraph subclass that uses pre-defined Graphiti schema
|
||||
instead of calling Neptune's expensive statistics API.
|
||||
"""
|
||||
|
||||
# Define Graphiti schema to avoid expensive statistics API calls
|
||||
GRAPHITI_SCHEMA = """
|
||||
Node labels: Episodic, Entity, Community
|
||||
Relationship types: MENTIONS, RELATES_TO, HAS_MEMBER
|
||||
Node properties:
|
||||
Episodic {uuid: string, name: string, group_id: string, source: string, source_description: string, content: string, valid_at: datetime, created_at: datetime, entity_edges: list}
|
||||
Entity {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string, labels: list}
|
||||
Community {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string}
|
||||
Relationship properties:
|
||||
MENTIONS {created_at: datetime}
|
||||
RELATES_TO {uuid: string, name: string, group_id: string, fact: string, fact_embedding: string, episodes: list, created_at: datetime, expired_at: datetime, valid_at: datetime, invalid_at: datetime}
|
||||
HAS_MEMBER {uuid: string, created_at: datetime}
|
||||
"""
|
||||
|
||||
def _refresh_schema(self) -> None:
|
||||
"""
|
||||
Override to use pre-defined schema instead of calling statistics API.
|
||||
|
||||
This avoids the expensive Neptune statistics API call that requires
|
||||
statistics to be enabled on the Neptune instance.
|
||||
"""
|
||||
self.schema = self.GRAPHITI_SCHEMA
|
||||
logger.debug('Using pre-defined Graphiti schema, skipping Neptune statistics API')
|
||||
|
||||
|
||||
aoss_indices = [
|
||||
{
|
||||
'index_name': 'node_name_and_summary',
|
||||
|
|
@ -109,7 +141,7 @@ aoss_indices = [
|
|||
class NeptuneDriver(GraphDriver):
|
||||
provider: GraphProvider = GraphProvider.NEPTUNE
|
||||
|
||||
def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
|
||||
def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443, database: str = 'default'):
|
||||
"""This initializes a NeptuneDriver for use with Neptune as a backend
|
||||
|
||||
Args:
|
||||
|
|
@ -117,29 +149,20 @@ class NeptuneDriver(GraphDriver):
|
|||
aoss_host (str): The OpenSearch host value
|
||||
port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
|
||||
aoss_port (int, optional): The OpenSearch port. Defaults to 443.
|
||||
database (str, optional): The database name (for compatibility with base class). Defaults to 'default'.
|
||||
"""
|
||||
if not host:
|
||||
raise ValueError('You must provide an endpoint to create a NeptuneDriver')
|
||||
|
||||
# Define Graphiti schema to avoid expensive statistics API calls
|
||||
graphiti_schema = """
|
||||
Node labels: Episodic, Entity, Community
|
||||
Relationship types: MENTIONS, RELATES_TO, HAS_MEMBER
|
||||
Node properties:
|
||||
Episodic {uuid: string, name: string, group_id: string, source: string, source_description: string, content: string, valid_at: datetime, created_at: datetime, entity_edges: list}
|
||||
Entity {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string, labels: list}
|
||||
Community {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string}
|
||||
Relationship properties:
|
||||
MENTIONS {created_at: datetime}
|
||||
RELATES_TO {uuid: string, name: string, group_id: string, fact: string, fact_embedding: string, episodes: list, created_at: datetime, expired_at: datetime, valid_at: datetime, invalid_at: datetime}
|
||||
HAS_MEMBER {uuid: string, created_at: datetime}
|
||||
"""
|
||||
# Set the database attribute required by the base GraphDriver class
|
||||
self._database = database
|
||||
|
||||
if host.startswith('neptune-db://'):
|
||||
# This is a Neptune Database Cluster
|
||||
endpoint = host.replace('neptune-db://', '')
|
||||
self.client = NeptuneGraph(endpoint, port, schema=graphiti_schema)
|
||||
logger.debug('Creating Neptune Database session for %s', host)
|
||||
# Use custom GraphitiNeptuneGraph to avoid expensive statistics API calls
|
||||
self.client = GraphitiNeptuneGraph(endpoint, port)
|
||||
logger.debug('Creating Neptune Database session for %s with pre-defined schema', host)
|
||||
elif host.startswith('neptune-graph://'):
|
||||
# This is a Neptune Analytics Graph
|
||||
graphId = host.replace('neptune-graph://', '')
|
||||
|
|
@ -153,9 +176,12 @@ Relationship properties:
|
|||
if not aoss_host:
|
||||
raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')
|
||||
|
||||
# Strip protocol prefix from aoss_host if present (OpenSearch expects just the hostname)
|
||||
aoss_hostname = aoss_host.replace('https://', '').replace('http://', '')
|
||||
|
||||
session = boto3.Session()
|
||||
self.aoss_client = OpenSearch(
|
||||
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
||||
hosts=[{'host': aoss_hostname, 'port': aoss_port}],
|
||||
http_auth=Urllib3AWSV4SignerAuth(
|
||||
session.get_credentials(), session.region_name, 'aoss'
|
||||
),
|
||||
|
|
|
|||
|
|
@ -203,7 +203,9 @@ class NeptuneProviderConfig(BaseModel):
|
|||
def model_post_init(self, __context) -> None:
|
||||
"""Validate and normalize Neptune-specific requirements."""
|
||||
# Auto-detect and add protocol if missing
|
||||
if not self.host.startswith(('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://')):
|
||||
if not self.host.startswith(
|
||||
('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://')
|
||||
):
|
||||
# Check if it looks like a Neptune Analytics graph ID (starts with 'g-')
|
||||
if self.host.startswith('g-'):
|
||||
self.host = f'neptune-graph://{self.host}'
|
||||
|
|
@ -230,9 +232,14 @@ class NeptuneProviderConfig(BaseModel):
|
|||
' database:\n'
|
||||
' providers:\n'
|
||||
' neptune:\n'
|
||||
' aoss_host: "your-aoss-endpoint.us-east-1.aoss.amazonaws.com"'
|
||||
' aoss_host: "your-aoss-endpoint.us-east-1.aoss.amazonaws.com"\n'
|
||||
'Note: Provide hostname only, without https:// prefix'
|
||||
)
|
||||
|
||||
# Strip protocol prefix from aoss_host if present (OpenSearch expects just the hostname)
|
||||
if self.aoss_host.startswith(('https://', 'http://')):
|
||||
self.aoss_host = self.aoss_host.replace('https://', '').replace('http://', '')
|
||||
|
||||
|
||||
class DatabaseProvidersConfig(BaseModel):
|
||||
"""Database providers configuration."""
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ These tests validate Neptune-specific configuration requirements including:
|
|||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
|
@ -107,17 +106,13 @@ def test_neptune_environment_overrides():
|
|||
config_path = Path(__file__).parent.parent / 'config' / 'config.yaml'
|
||||
config = GraphitiConfig(_env_file=None, config_path=str(config_path))
|
||||
|
||||
print(f'✓ Loaded configuration with environment overrides')
|
||||
print('✓ Loaded configuration with environment overrides')
|
||||
|
||||
# Verify environment variables were applied
|
||||
if config.database.providers and config.database.providers.neptune:
|
||||
neptune_config = config.database.providers.neptune
|
||||
print(
|
||||
f' Neptune host: {neptune_config.host}'
|
||||
)
|
||||
print(
|
||||
f' AOSS host: {neptune_config.aoss_host}'
|
||||
)
|
||||
print(f' Neptune host: {neptune_config.host}')
|
||||
print(f' AOSS host: {neptune_config.aoss_host}')
|
||||
print(f' Neptune port: {neptune_config.port}')
|
||||
print(f' AOSS port: {neptune_config.aoss_port}')
|
||||
print(f' Region: {neptune_config.region}')
|
||||
|
|
@ -224,7 +219,7 @@ def test_neptune_factory_missing_credentials():
|
|||
|
||||
with patch('boto3.Session', return_value=mock_session):
|
||||
try:
|
||||
db_config = DatabaseDriverFactory.create_config(test_config)
|
||||
DatabaseDriverFactory.create_config(test_config)
|
||||
print('✗ Factory should have failed with missing credentials')
|
||||
raise AssertionError('Expected ValueError for missing AWS credentials')
|
||||
except ValueError as e:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue