neptune regression update
This commit is contained in:
parent
6441be9934
commit
14e1248b5f
1 changed files with 95 additions and 10 deletions
|
|
@ -24,10 +24,86 @@ import boto3
|
||||||
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
||||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||||
|
|
||||||
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider, aoss_indices
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider, aoss_indices, DEFAULT_SIZE
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
neptune_aoss_indices = [
|
||||||
|
{
|
||||||
|
'index_name': 'node_name_and_summary',
|
||||||
|
'body': {
|
||||||
|
'mappings': {
|
||||||
|
'properties': {
|
||||||
|
'uuid': {'type': 'keyword'},
|
||||||
|
'name': {'type': 'text'},
|
||||||
|
'summary': {'type': 'text'},
|
||||||
|
'group_id': {'type': 'text'},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'query': {
|
||||||
|
'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
|
||||||
|
'size': DEFAULT_SIZE,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'index_name': 'community_name',
|
||||||
|
'body': {
|
||||||
|
'mappings': {
|
||||||
|
'properties': {
|
||||||
|
'uuid': {'type': 'keyword'},
|
||||||
|
'name': {'type': 'text'},
|
||||||
|
'group_id': {'type': 'text'},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'query': {
|
||||||
|
'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
|
||||||
|
'size': DEFAULT_SIZE,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'index_name': 'episode_content',
|
||||||
|
'body': {
|
||||||
|
'mappings': {
|
||||||
|
'properties': {
|
||||||
|
'uuid': {'type': 'keyword'},
|
||||||
|
'content': {'type': 'text'},
|
||||||
|
'source': {'type': 'text'},
|
||||||
|
'source_description': {'type': 'text'},
|
||||||
|
'group_id': {'type': 'text'},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'query': {
|
||||||
|
'query': {
|
||||||
|
'multi_match': {
|
||||||
|
'query': '',
|
||||||
|
'fields': ['content', 'source', 'source_description', 'group_id'],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'size': DEFAULT_SIZE,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'index_name': 'edge_name_and_fact',
|
||||||
|
'body': {
|
||||||
|
'mappings': {
|
||||||
|
'properties': {
|
||||||
|
'uuid': {'type': 'keyword'},
|
||||||
|
'name': {'type': 'text'},
|
||||||
|
'fact': {'type': 'text'},
|
||||||
|
'group_id': {'type': 'text'},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'query': {
|
||||||
|
'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
|
||||||
|
'size': DEFAULT_SIZE,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class NeptuneDriver(GraphDriver):
|
class NeptuneDriver(GraphDriver):
|
||||||
provider: GraphProvider = GraphProvider.NEPTUNE
|
provider: GraphProvider = GraphProvider.NEPTUNE
|
||||||
|
|
@ -97,14 +173,14 @@ class NeptuneDriver(GraphDriver):
|
||||||
if any(isinstance(item, str) and 'T' in item for item in v):
|
if any(isinstance(item, str) and 'T' in item for item in v):
|
||||||
# Create a new list expression with datetime() wrapped around each element
|
# Create a new list expression with datetime() wrapped around each element
|
||||||
datetime_list = (
|
datetime_list = (
|
||||||
'['
|
'['
|
||||||
+ ', '.join(
|
+ ', '.join(
|
||||||
f'datetime("{item}")'
|
f'datetime("{item}")'
|
||||||
if isinstance(item, str) and 'T' in item
|
if isinstance(item, str) and 'T' in item
|
||||||
else repr(item)
|
else repr(item)
|
||||||
for item in v
|
for item in v
|
||||||
)
|
)
|
||||||
+ ']'
|
+ ']'
|
||||||
)
|
)
|
||||||
query = str(query).replace(f'${k}', datetime_list)
|
query = str(query).replace(f'${k}', datetime_list)
|
||||||
elif isinstance(v, dict):
|
elif isinstance(v, dict):
|
||||||
|
|
@ -112,7 +188,7 @@ class NeptuneDriver(GraphDriver):
|
||||||
return query
|
return query
|
||||||
|
|
||||||
async def execute_query(
|
async def execute_query(
|
||||||
self, cypher_query_, **kwargs: Any
|
self, cypher_query_, **kwargs: Any
|
||||||
) -> tuple[dict[str, Any], None, None]:
|
) -> tuple[dict[str, Any], None, None]:
|
||||||
params = dict(kwargs)
|
params = dict(kwargs)
|
||||||
if isinstance(cypher_query_, list):
|
if isinstance(cypher_query_, list):
|
||||||
|
|
@ -143,6 +219,15 @@ class NeptuneDriver(GraphDriver):
|
||||||
async def _delete_all_data(self) -> Any:
|
async def _delete_all_data(self) -> Any:
|
||||||
return await self.execute_query('MATCH (n) DETACH DELETE n')
|
return await self.execute_query('MATCH (n) DETACH DELETE n')
|
||||||
|
|
||||||
|
async def create_aoss_indices(self):
|
||||||
|
for index in neptune_aoss_indices:
|
||||||
|
index_name = index['index_name']
|
||||||
|
client = self.aoss_client
|
||||||
|
if not client.indices.exists(index=index_name):
|
||||||
|
client.indices.create(index=index_name, body=index['body'])
|
||||||
|
# Sleep for 1 minute to let the index creation complete
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
||||||
return self.delete_all_indexes_impl()
|
return self.delete_all_indexes_impl()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue