update neptune for regression purposes

This commit is contained in:
prestonrasmussen 2025-09-07 22:40:36 -04:00
parent 14e1248b5f
commit b036c38f0d

View file

@ -24,13 +24,20 @@ import boto3
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider, aoss_indices, DEFAULT_SIZE from graphiti_core.driver.driver import (
DEFAULT_SIZE,
GraphDriver,
GraphDriverSession,
GraphProvider,
aoss_indices,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
neptune_aoss_indices = [ neptune_aoss_indices = [
{ {
'index_name': 'node_name_and_summary', 'index_name': 'node_name_and_summary',
'alias_name': 'entities',
'body': { 'body': {
'mappings': { 'mappings': {
'properties': { 'properties': {
@ -48,6 +55,7 @@ neptune_aoss_indices = [
}, },
{ {
'index_name': 'community_name', 'index_name': 'community_name',
'alias_name': 'communities',
'body': { 'body': {
'mappings': { 'mappings': {
'properties': { 'properties': {
@ -64,6 +72,7 @@ neptune_aoss_indices = [
}, },
{ {
'index_name': 'episode_content', 'index_name': 'episode_content',
'alias_name': 'episodes',
'body': { 'body': {
'mappings': { 'mappings': {
'properties': { 'properties': {
@ -87,6 +96,7 @@ neptune_aoss_indices = [
}, },
{ {
'index_name': 'edge_name_and_fact', 'index_name': 'edge_name_and_fact',
'alias_name': 'facts',
'body': { 'body': {
'mappings': { 'mappings': {
'properties': { 'properties': {
@ -173,14 +183,14 @@ class NeptuneDriver(GraphDriver):
if any(isinstance(item, str) and 'T' in item for item in v): if any(isinstance(item, str) and 'T' in item for item in v):
# Create a new list expression with datetime() wrapped around each element # Create a new list expression with datetime() wrapped around each element
datetime_list = ( datetime_list = (
'[' '['
+ ', '.join( + ', '.join(
f'datetime("{item}")' f'datetime("{item}")'
if isinstance(item, str) and 'T' in item if isinstance(item, str) and 'T' in item
else repr(item) else repr(item)
for item in v for item in v
) )
+ ']' + ']'
) )
query = str(query).replace(f'${k}', datetime_list) query = str(query).replace(f'${k}', datetime_list)
elif isinstance(v, dict): elif isinstance(v, dict):
@ -188,7 +198,7 @@ class NeptuneDriver(GraphDriver):
return query return query
async def execute_query( async def execute_query(
self, cypher_query_, **kwargs: Any self, cypher_query_, **kwargs: Any
) -> tuple[dict[str, Any], None, None]: ) -> tuple[dict[str, Any], None, None]:
params = dict(kwargs) params = dict(kwargs)
if isinstance(cypher_query_, list): if isinstance(cypher_query_, list):
@ -225,6 +235,12 @@ class NeptuneDriver(GraphDriver):
client = self.aoss_client client = self.aoss_client
if not client.indices.exists(index=index_name): if not client.indices.exists(index=index_name):
client.indices.create(index=index_name, body=index['body']) client.indices.create(index=index_name, body=index['body'])
alias_name = index.get('alias_name', index_name)
if not client.indices.exists_alias(name=alias_name, index=index_name):
client.indices.put_alias(index=index_name, name=alias_name)
# Sleep for 1 minute to let the index creation complete # Sleep for 1 minute to let the index creation complete
await asyncio.sleep(60) await asyncio.sleep(60)