update neptune for regression purposes
This commit is contained in:
parent
14e1248b5f
commit
b036c38f0d
1 changed files with 26 additions and 10 deletions
|
|
@ -24,13 +24,20 @@ import boto3
|
|||
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||
|
||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider, aoss_indices, DEFAULT_SIZE
|
||||
from graphiti_core.driver.driver import (
|
||||
DEFAULT_SIZE,
|
||||
GraphDriver,
|
||||
GraphDriverSession,
|
||||
GraphProvider,
|
||||
aoss_indices,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
neptune_aoss_indices = [
|
||||
{
|
||||
'index_name': 'node_name_and_summary',
|
||||
'alias_name': 'entities',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -48,6 +55,7 @@ neptune_aoss_indices = [
|
|||
},
|
||||
{
|
||||
'index_name': 'community_name',
|
||||
'alias_name': 'communities',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -64,6 +72,7 @@ neptune_aoss_indices = [
|
|||
},
|
||||
{
|
||||
'index_name': 'episode_content',
|
||||
'alias_name': 'episodes',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -87,6 +96,7 @@ neptune_aoss_indices = [
|
|||
},
|
||||
{
|
||||
'index_name': 'edge_name_and_fact',
|
||||
'alias_name': 'facts',
|
||||
'body': {
|
||||
'mappings': {
|
||||
'properties': {
|
||||
|
|
@ -173,14 +183,14 @@ class NeptuneDriver(GraphDriver):
|
|||
if any(isinstance(item, str) and 'T' in item for item in v):
|
||||
# Create a new list expression with datetime() wrapped around each element
|
||||
datetime_list = (
|
||||
'['
|
||||
+ ', '.join(
|
||||
f'datetime("{item}")'
|
||||
if isinstance(item, str) and 'T' in item
|
||||
else repr(item)
|
||||
for item in v
|
||||
)
|
||||
+ ']'
|
||||
'['
|
||||
+ ', '.join(
|
||||
f'datetime("{item}")'
|
||||
if isinstance(item, str) and 'T' in item
|
||||
else repr(item)
|
||||
for item in v
|
||||
)
|
||||
+ ']'
|
||||
)
|
||||
query = str(query).replace(f'${k}', datetime_list)
|
||||
elif isinstance(v, dict):
|
||||
|
|
@ -188,7 +198,7 @@ class NeptuneDriver(GraphDriver):
|
|||
return query
|
||||
|
||||
async def execute_query(
|
||||
self, cypher_query_, **kwargs: Any
|
||||
self, cypher_query_, **kwargs: Any
|
||||
) -> tuple[dict[str, Any], None, None]:
|
||||
params = dict(kwargs)
|
||||
if isinstance(cypher_query_, list):
|
||||
|
|
@ -225,6 +235,12 @@ class NeptuneDriver(GraphDriver):
|
|||
client = self.aoss_client
|
||||
if not client.indices.exists(index=index_name):
|
||||
client.indices.create(index=index_name, body=index['body'])
|
||||
|
||||
alias_name = index.get('alias_name', index_name)
|
||||
|
||||
if not client.indices.exists_alias(name=alias_name, index=index_name):
|
||||
client.indices.put_alias(index=index_name, name=alias_name)
|
||||
|
||||
# Sleep for 1 minute to let the index creation complete
|
||||
await asyncio.sleep(60)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue