neptune schema
This commit is contained in:
parent
b5dc7ab698
commit
97b0bbc7a8
3 changed files with 56 additions and 28 deletions
|
|
@ -29,6 +29,38 @@ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphPr
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
DEFAULT_SIZE = 10
|
DEFAULT_SIZE = 10
|
||||||
|
|
||||||
|
|
||||||
|
class GraphitiNeptuneGraph(NeptuneGraph):
    """
    NeptuneGraph variant that carries a fixed, pre-defined Graphiti schema.

    The schema text lives on the class and is assigned directly when the
    schema is refreshed, instead of calling Neptune's expensive statistics
    API.
    """

    # Static description of Graphiti's node labels, relationship types and
    # their properties, used in place of the statistics-derived schema.
    # NOTE(review): the leading whitespace of the original literal is not
    # recoverable from this view; the text is reproduced as displayed.
    GRAPHITI_SCHEMA = (
        '\n'
        'Node labels: Episodic, Entity, Community\n'
        'Relationship types: MENTIONS, RELATES_TO, HAS_MEMBER\n'
        'Node properties:\n'
        'Episodic {uuid: string, name: string, group_id: string, source: string, source_description: string, content: string, valid_at: datetime, created_at: datetime, entity_edges: list}\n'
        'Entity {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string, labels: list}\n'
        'Community {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string}\n'
        'Relationship properties:\n'
        'MENTIONS {created_at: datetime}\n'
        'RELATES_TO {uuid: string, name: string, group_id: string, fact: string, fact_embedding: string, episodes: list, created_at: datetime, expired_at: datetime, valid_at: datetime, invalid_at: datetime}\n'
        'HAS_MEMBER {uuid: string, created_at: datetime}\n'
    )

    def _refresh_schema(self) -> None:
        """
        Install the pre-defined Graphiti schema on this instance.

        Overrides the parent hook so that no statistics API call is made;
        that call is expensive and requires statistics to be enabled on the
        Neptune instance.
        """
        predefined_schema = self.GRAPHITI_SCHEMA
        self.schema = predefined_schema
        logger.debug('Using pre-defined Graphiti schema, skipping Neptune statistics API')
|
||||||
|
|
||||||
|
|
||||||
aoss_indices = [
|
aoss_indices = [
|
||||||
{
|
{
|
||||||
'index_name': 'node_name_and_summary',
|
'index_name': 'node_name_and_summary',
|
||||||
|
|
@ -109,7 +141,7 @@ aoss_indices = [
|
||||||
class NeptuneDriver(GraphDriver):
|
class NeptuneDriver(GraphDriver):
|
||||||
provider: GraphProvider = GraphProvider.NEPTUNE
|
provider: GraphProvider = GraphProvider.NEPTUNE
|
||||||
|
|
||||||
def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
|
def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443, database: str = 'default'):
|
||||||
"""This initializes a NeptuneDriver for use with Neptune as a backend
|
"""This initializes a NeptuneDriver for use with Neptune as a backend
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -117,29 +149,20 @@ class NeptuneDriver(GraphDriver):
|
||||||
aoss_host (str): The OpenSearch host value
|
aoss_host (str): The OpenSearch host value
|
||||||
port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
|
port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
|
||||||
aoss_port (int, optional): The OpenSearch port. Defaults to 443.
|
aoss_port (int, optional): The OpenSearch port. Defaults to 443.
|
||||||
|
database (str, optional): The database name (for compatibility with base class). Defaults to 'default'.
|
||||||
"""
|
"""
|
||||||
if not host:
|
if not host:
|
||||||
raise ValueError('You must provide an endpoint to create a NeptuneDriver')
|
raise ValueError('You must provide an endpoint to create a NeptuneDriver')
|
||||||
|
|
||||||
# Define Graphiti schema to avoid expensive statistics API calls
|
# Set the database attribute required by the base GraphDriver class
|
||||||
graphiti_schema = """
|
self._database = database
|
||||||
Node labels: Episodic, Entity, Community
|
|
||||||
Relationship types: MENTIONS, RELATES_TO, HAS_MEMBER
|
|
||||||
Node properties:
|
|
||||||
Episodic {uuid: string, name: string, group_id: string, source: string, source_description: string, content: string, valid_at: datetime, created_at: datetime, entity_edges: list}
|
|
||||||
Entity {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string, labels: list}
|
|
||||||
Community {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string}
|
|
||||||
Relationship properties:
|
|
||||||
MENTIONS {created_at: datetime}
|
|
||||||
RELATES_TO {uuid: string, name: string, group_id: string, fact: string, fact_embedding: string, episodes: list, created_at: datetime, expired_at: datetime, valid_at: datetime, invalid_at: datetime}
|
|
||||||
HAS_MEMBER {uuid: string, created_at: datetime}
|
|
||||||
"""
|
|
||||||
|
|
||||||
if host.startswith('neptune-db://'):
|
if host.startswith('neptune-db://'):
|
||||||
# This is a Neptune Database Cluster
|
# This is a Neptune Database Cluster
|
||||||
endpoint = host.replace('neptune-db://', '')
|
endpoint = host.replace('neptune-db://', '')
|
||||||
self.client = NeptuneGraph(endpoint, port, schema=graphiti_schema)
|
# Use custom GraphitiNeptuneGraph to avoid expensive statistics API calls
|
||||||
logger.debug('Creating Neptune Database session for %s', host)
|
self.client = GraphitiNeptuneGraph(endpoint, port)
|
||||||
|
logger.debug('Creating Neptune Database session for %s with pre-defined schema', host)
|
||||||
elif host.startswith('neptune-graph://'):
|
elif host.startswith('neptune-graph://'):
|
||||||
# This is a Neptune Analytics Graph
|
# This is a Neptune Analytics Graph
|
||||||
graphId = host.replace('neptune-graph://', '')
|
graphId = host.replace('neptune-graph://', '')
|
||||||
|
|
@ -153,9 +176,12 @@ Relationship properties:
|
||||||
if not aoss_host:
|
if not aoss_host:
|
||||||
raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')
|
raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')
|
||||||
|
|
||||||
|
# Strip protocol prefix from aoss_host if present (OpenSearch expects just the hostname)
|
||||||
|
aoss_hostname = aoss_host.replace('https://', '').replace('http://', '')
|
||||||
|
|
||||||
session = boto3.Session()
|
session = boto3.Session()
|
||||||
self.aoss_client = OpenSearch(
|
self.aoss_client = OpenSearch(
|
||||||
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
hosts=[{'host': aoss_hostname, 'port': aoss_port}],
|
||||||
http_auth=Urllib3AWSV4SignerAuth(
|
http_auth=Urllib3AWSV4SignerAuth(
|
||||||
session.get_credentials(), session.region_name, 'aoss'
|
session.get_credentials(), session.region_name, 'aoss'
|
||||||
),
|
),
|
||||||
|
|
|
||||||
|
|
@ -203,7 +203,9 @@ class NeptuneProviderConfig(BaseModel):
|
||||||
def model_post_init(self, __context) -> None:
|
def model_post_init(self, __context) -> None:
|
||||||
"""Validate and normalize Neptune-specific requirements."""
|
"""Validate and normalize Neptune-specific requirements."""
|
||||||
# Auto-detect and add protocol if missing
|
# Auto-detect and add protocol if missing
|
||||||
if not self.host.startswith(('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://')):
|
if not self.host.startswith(
|
||||||
|
('neptune-db://', 'neptune-graph://', 'bolt://', 'http://', 'https://')
|
||||||
|
):
|
||||||
# Check if it looks like a Neptune Analytics graph ID (starts with 'g-')
|
# Check if it looks like a Neptune Analytics graph ID (starts with 'g-')
|
||||||
if self.host.startswith('g-'):
|
if self.host.startswith('g-'):
|
||||||
self.host = f'neptune-graph://{self.host}'
|
self.host = f'neptune-graph://{self.host}'
|
||||||
|
|
@ -230,9 +232,14 @@ class NeptuneProviderConfig(BaseModel):
|
||||||
' database:\n'
|
' database:\n'
|
||||||
' providers:\n'
|
' providers:\n'
|
||||||
' neptune:\n'
|
' neptune:\n'
|
||||||
' aoss_host: "your-aoss-endpoint.us-east-1.aoss.amazonaws.com"'
|
' aoss_host: "your-aoss-endpoint.us-east-1.aoss.amazonaws.com"\n'
|
||||||
|
'Note: Provide hostname only, without https:// prefix'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Strip protocol prefix from aoss_host if present (OpenSearch expects just the hostname)
|
||||||
|
if self.aoss_host.startswith(('https://', 'http://')):
|
||||||
|
self.aoss_host = self.aoss_host.replace('https://', '').replace('http://', '')
|
||||||
|
|
||||||
|
|
||||||
class DatabaseProvidersConfig(BaseModel):
|
class DatabaseProvidersConfig(BaseModel):
|
||||||
"""Database providers configuration."""
|
"""Database providers configuration."""
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ These tests validate Neptune-specific configuration requirements including:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
@ -107,17 +106,13 @@ def test_neptune_environment_overrides():
|
||||||
config_path = Path(__file__).parent.parent / 'config' / 'config.yaml'
|
config_path = Path(__file__).parent.parent / 'config' / 'config.yaml'
|
||||||
config = GraphitiConfig(_env_file=None, config_path=str(config_path))
|
config = GraphitiConfig(_env_file=None, config_path=str(config_path))
|
||||||
|
|
||||||
print(f'✓ Loaded configuration with environment overrides')
|
print('✓ Loaded configuration with environment overrides')
|
||||||
|
|
||||||
# Verify environment variables were applied
|
# Verify environment variables were applied
|
||||||
if config.database.providers and config.database.providers.neptune:
|
if config.database.providers and config.database.providers.neptune:
|
||||||
neptune_config = config.database.providers.neptune
|
neptune_config = config.database.providers.neptune
|
||||||
print(
|
print(f' Neptune host: {neptune_config.host}')
|
||||||
f' Neptune host: {neptune_config.host}'
|
print(f' AOSS host: {neptune_config.aoss_host}')
|
||||||
)
|
|
||||||
print(
|
|
||||||
f' AOSS host: {neptune_config.aoss_host}'
|
|
||||||
)
|
|
||||||
print(f' Neptune port: {neptune_config.port}')
|
print(f' Neptune port: {neptune_config.port}')
|
||||||
print(f' AOSS port: {neptune_config.aoss_port}')
|
print(f' AOSS port: {neptune_config.aoss_port}')
|
||||||
print(f' Region: {neptune_config.region}')
|
print(f' Region: {neptune_config.region}')
|
||||||
|
|
@ -224,7 +219,7 @@ def test_neptune_factory_missing_credentials():
|
||||||
|
|
||||||
with patch('boto3.Session', return_value=mock_session):
|
with patch('boto3.Session', return_value=mock_session):
|
||||||
try:
|
try:
|
||||||
db_config = DatabaseDriverFactory.create_config(test_config)
|
DatabaseDriverFactory.create_config(test_config)
|
||||||
print('✗ Factory should have failed with missing credentials')
|
print('✗ Factory should have failed with missing credentials')
|
||||||
raise AssertionError('Expected ValueError for missing AWS credentials')
|
raise AssertionError('Expected ValueError for missing AWS credentials')
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue