From b036c38f0db774e38af8dc1cd3948f31f5884ffd Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Sun, 7 Sep 2025 22:40:36 -0400 Subject: [PATCH] update neptune for regression purposes --- graphiti_core/driver/neptune_driver.py | 36 +++++++++++++++++++------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/graphiti_core/driver/neptune_driver.py b/graphiti_core/driver/neptune_driver.py index e2a34827..99ea8a6d 100644 --- a/graphiti_core/driver/neptune_driver.py +++ b/graphiti_core/driver/neptune_driver.py @@ -24,13 +24,20 @@ import boto3 from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers -from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider, aoss_indices, DEFAULT_SIZE +from graphiti_core.driver.driver import ( + DEFAULT_SIZE, + GraphDriver, + GraphDriverSession, + GraphProvider, + aoss_indices, +) logger = logging.getLogger(__name__) neptune_aoss_indices = [ { 'index_name': 'node_name_and_summary', + 'alias_name': 'entities', 'body': { 'mappings': { 'properties': { @@ -48,6 +55,7 @@ neptune_aoss_indices = [ }, { 'index_name': 'community_name', + 'alias_name': 'communities', 'body': { 'mappings': { 'properties': { @@ -64,6 +72,7 @@ neptune_aoss_indices = [ }, { 'index_name': 'episode_content', + 'alias_name': 'episodes', 'body': { 'mappings': { 'properties': { @@ -87,6 +96,7 @@ neptune_aoss_indices = [ }, { 'index_name': 'edge_name_and_fact', + 'alias_name': 'facts', 'body': { 'mappings': { 'properties': { @@ -173,14 +183,14 @@ class NeptuneDriver(GraphDriver): if any(isinstance(item, str) and 'T' in item for item in v): # Create a new list expression with datetime() wrapped around each element datetime_list = ( - '[' - + ', '.join( - f'datetime("{item}")' - if isinstance(item, str) and 'T' in item - else repr(item) - for item in v - ) - + ']' + '[' + + ', '.join( + f'datetime("{item}")' + if isinstance(item, str) and 'T' in item + else repr(item) + for item in v + ) + + ']' ) query = str(query).replace(f'${k}', datetime_list) elif isinstance(v, dict): @@ -188,7 +198,7 @@ class NeptuneDriver(GraphDriver): return query async def execute_query( - self, cypher_query_, **kwargs: Any + self, cypher_query_, **kwargs: Any ) -> tuple[dict[str, Any], None, None]: params = dict(kwargs) if isinstance(cypher_query_, list): @@ -225,6 +235,12 @@ class NeptuneDriver(GraphDriver): client = self.aoss_client if not client.indices.exists(index=index_name): client.indices.create(index=index_name, body=index['body']) + + alias_name = index.get('alias_name', index_name) + + if not client.indices.exists_alias(name=alias_name, index=index_name): + client.indices.put_alias(index=index_name, name=alias_name) + # Sleep for 1 minute to let the index creation complete await asyncio.sleep(60)