"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import datetime
import logging
from collections.abc import Coroutine
from typing import Any

import boto3
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider

logger = logging.getLogger(__name__)

# Number of hits requested by the query templates in `aoss_indices`.
DEFAULT_SIZE = 10


class GraphitiNeptuneGraph(NeptuneGraph):
    """
    Custom NeptuneGraph subclass that uses pre-defined Graphiti schema instead of
    calling Neptune's expensive statistics API.
    """

    # Define Graphiti schema to avoid expensive statistics API calls
    GRAPHITI_SCHEMA = """
Node labels: Episodic, Entity, Community
Relationship types: MENTIONS, RELATES_TO, HAS_MEMBER
Node properties:
Episodic {uuid: string, name: string, group_id: string, source: string, source_description: string, content: string, valid_at: datetime, created_at: datetime, entity_edges: list}
Entity {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string, labels: list}
Community {uuid: string, name: string, group_id: string, summary: string, created_at: datetime, name_embedding: string}
Relationship properties:
MENTIONS {created_at: datetime}
RELATES_TO {uuid: string, name: string, group_id: string, fact: string, fact_embedding: string, episodes: list, created_at: datetime, expired_at: datetime, valid_at: datetime, invalid_at: datetime}
HAS_MEMBER {uuid: string, created_at: datetime}
"""

    def _refresh_schema(self) -> None:
        """
        Override to use pre-defined schema instead of calling statistics API.

        This avoids the expensive Neptune statistics API call that requires
        statistics to be enabled on the Neptune instance.
        """
        self.schema = self.GRAPHITI_SCHEMA
        logger.debug('Using pre-defined Graphiti schema, skipping Neptune statistics API')


# OpenSearch (AOSS) full-text index definitions used by NeptuneDriver.
# - 'index_name': the OpenSearch index name.
# - 'body': the mapping passed to `indices.create`; its property names are also
#   the document fields copied by `save_to_aoss`.
# - 'query': a read-only template; `run_aoss_query` reads its multi_match
#   'fields' and builds a fresh request body per call (the template is never
#   mutated).
aoss_indices = [
    {
        'index_name': 'node_name_and_summary',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'summary': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'community_name',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'episode_content',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'content': {'type': 'text'},
                    'source': {'type': 'text'},
                    'source_description': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {
                'multi_match': {
                    'query': '',
                    'fields': ['content', 'source', 'source_description', 'group_id'],
                }
            },
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'edge_name_and_fact',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'fact': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
]


class NeptuneDriver(GraphDriver):
    """Graph driver backed by Amazon Neptune, with full-text search in OpenSearch (AOSS)."""

    provider: GraphProvider = GraphProvider.NEPTUNE

    def __init__(
        self,
        host: str,
        aoss_host: str,
        port: int = 8182,
        aoss_port: int = 443,
        database: str = 'default',
    ):
        """This initializes a NeptuneDriver for use with Neptune as a backend

        Args:
            host (str): The Neptune Database or Neptune Analytics host, prefixed with
                either 'neptune-db://' or 'neptune-graph://'.
            aoss_host (str): The OpenSearch host value. An 'https://' or 'http://'
                prefix and any trailing '/' are stripped.
            port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
            aoss_port (int, optional): The OpenSearch port. Defaults to 443.
            database (str, optional): The database name (for compatibility with base class). Defaults to 'default'.

        Raises:
            ValueError: If `host` is empty or has an unsupported prefix, or if
                `aoss_host` is empty.
        """
        if not host:
            raise ValueError('You must provide an endpoint to create a NeptuneDriver')
        # Validate every endpoint before constructing any client, so a bad
        # configuration fails fast without first opening a Neptune connection.
        if not aoss_host:
            raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')

        # Set the database attribute required by the base GraphDriver class
        self._database = database

        if host.startswith('neptune-db://'):
            # This is a Neptune Database Cluster
            endpoint = host.removeprefix('neptune-db://')
            # Use custom GraphitiNeptuneGraph to avoid expensive statistics API calls
            self.client = GraphitiNeptuneGraph(endpoint, port)
            logger.debug('Creating Neptune Database session for %s with pre-defined schema', host)
        elif host.startswith('neptune-graph://'):
            # This is a Neptune Analytics Graph
            graph_id = host.removeprefix('neptune-graph://')
            self.client = NeptuneAnalyticsGraph(graph_id)
            logger.debug('Creating Neptune Graph session for %s', host)
        else:
            raise ValueError(
                'You must provide an endpoint to create a NeptuneDriver as either neptune-db:// or neptune-graph://'
            )

        # OpenSearch expects a bare hostname: strip a leading scheme (only as a
        # prefix, not anywhere in the string) and any trailing slash.
        aoss_hostname = aoss_host.removeprefix('https://').removeprefix('http://').rstrip('/')
        session = boto3.Session()
        self.aoss_client = OpenSearch(
            hosts=[{'host': aoss_hostname, 'port': aoss_port}],
            http_auth=Urllib3AWSV4SignerAuth(
                session.get_credentials(), session.region_name, 'aoss'
            ),
            use_ssl=True,
            verify_certs=True,
            connection_class=Urllib3HttpConnection,
            pool_maxsize=20,
        )

    def _sanitize_parameters(self, query, params: dict):
        """Convert datetime values in `params` to ISO-8601 strings, in place.

        Top-level datetime values, datetimes directly inside lists, and values in
        nested dicts (including dicts inside lists) are converted. Lists nested
        inside lists are not traversed.

        Args:
            query: A query string, or a list of queries.
            params: Query parameters; mutated in place.

        Returns:
            `query` unchanged (a new list of the same queries for list input).
        """
        if isinstance(query, list):
            return [self._sanitize_parameters(q, params) for q in query]

        for key, value in params.items():
            if isinstance(value, datetime.datetime):
                # Replacing the value of an existing key does not resize the
                # dict, so this is safe while iterating.
                params[key] = value.isoformat()
            elif isinstance(value, list):
                for i, item in enumerate(value):
                    if isinstance(item, datetime.datetime):
                        value[i] = item.isoformat()
                    elif isinstance(item, dict):
                        query = self._sanitize_parameters(query, item)
            elif isinstance(value, dict):
                query = self._sanitize_parameters(query, value)
        return query

    async def execute_query(
        self, cypher_query_, **kwargs: Any
    ) -> tuple[dict[str, Any], None, None]:
        """Run one openCypher query, or a list of (query, params) pairs.

        For a single query, keyword arguments are the query parameters. A nested
        `params` dict is merged into the top level for compatibility with the
        Neo4j driver interface. For a list, each element supplies its own
        parameters as `q[1]`, keyword arguments are ignored, and the result of
        the last query is returned.

        Returns:
            A `(result, None, None)` tuple.
        """
        params = dict(kwargs)

        # Flatten nested 'params' dict if present (for compatibility with Neo4j driver interface)
        if 'params' in params and isinstance(params['params'], dict):
            nested_params = params.pop('params')
            # Merge nested params into the top level, nested params take precedence
            params = {**params, **nested_params}

        if isinstance(cypher_query_, list):
            # Bound up front so an empty list returns instead of raising
            # UnboundLocalError.
            # NOTE(review): assumes an empty list is an acceptable "no rows"
            # value for callers - confirm against the client's query() result type.
            result: Any = []
            for q in cypher_query_:
                result, _, _ = self._run_query(q[0], q[1])
            return result, None, None
        return self._run_query(cypher_query_, params)

    def _run_query(self, cypher_query_, params):
        """Sanitize `params` in place, run the query, and log details on failure."""
        cypher_query_ = str(self._sanitize_parameters(cypher_query_, params))
        try:
            result = self.client.query(cypher_query_, params=params)
        except Exception as e:
            logger.error('Query: %s', cypher_query_)
            logger.error('Parameters: %s', params)
            logger.error('Error executing query: %s', e)
            # Bare raise preserves the original traceback unchanged.
            raise
        return result, None, None

    def session(self, database: str | None = None) -> GraphDriverSession:
        """Return a lightweight session bound to this driver; `database` is ignored."""
        return NeptuneDriverSession(driver=self)

    async def close(self) -> None:
        """Close the client held by the underlying langchain graph wrapper."""
        return self.client.client.close()

    async def _delete_all_data(self) -> Any:
        """Detach-delete every node (and therefore every relationship) in the graph."""
        return await self.execute_query('MATCH (n) DETACH DELETE n')

    def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
        """Return an awaitable that deletes all AOSS indices listed in `aoss_indices`."""
        return self.delete_all_indexes_impl()

    async def delete_all_indexes_impl(self) -> None:
        """Delete all AOSS indices listed in `aoss_indices`."""
        # Must be awaited here: returning the bare coroutine meant
        # `await driver.delete_all_indexes()` never actually deleted anything.
        await self.delete_aoss_indices()

    async def create_aoss_indices(self):
        """Create any missing AOSS indices, then wait if anything was created."""
        client = self.aoss_client
        created_any = False
        for index in aoss_indices:
            index_name = index['index_name']
            if not client.indices.exists(index=index_name):
                client.indices.create(index=index_name, body=index['body'])
                created_any = True
        if created_any:
            # Sleep for 1 minute to let the index creation complete
            await asyncio.sleep(60)

    async def delete_aoss_indices(self):
        """Delete each AOSS index listed in `aoss_indices` that currently exists."""
        client = self.aoss_client
        for index in aoss_indices:
            index_name = index['index_name']
            if client.indices.exists(index=index_name):
                client.indices.delete(index=index_name)

    async def build_indices_and_constraints(self, delete_existing: bool = False):
        """Ensure the AOSS indices exist, optionally dropping them first."""
        # Neptune uses OpenSearch (AOSS) for indexing
        if delete_existing:
            await self.delete_aoss_indices()
        await self.create_aoss_indices()

    def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
        """Run a multi_match full-text search against one AOSS index.

        Args:
            name: Index name (case-insensitive), one of the `aoss_indices` names.
            query_text: Text to match against the index's configured fields.
            limit: Maximum number of hits to return.

        Returns:
            The raw OpenSearch search response, or `{}` if `name` is unknown.
        """
        index_name = name.lower()
        for index in aoss_indices:
            if index['index_name'] != index_name:
                continue
            # Build a fresh body per call: this honours `limit` (the template's
            # size was previously sent instead) and leaves the shared
            # module-level template untouched.
            fields = index['query']['query']['multi_match']['fields']
            body = {
                'query': {'multi_match': {'query': query_text, 'fields': list(fields)}},
                'size': limit,
            }
            return self.aoss_client.search(body=body, index=index_name)
        return {}

    def save_to_aoss(self, name: str, data: list[dict]) -> int:
        """Bulk-index documents into an AOSS index.

        Each document's `_id` is its 'uuid'. Only keys that appear in the
        index's mapping properties are copied.

        Args:
            name: Index name (case-insensitive), one of the `aoss_indices` names.
            data: Documents to index; each must contain a 'uuid' key.

        Returns:
            The number of successfully indexed documents, or 0 if `name` is unknown.
        """
        index_name = name.lower()
        for index in aoss_indices:
            if index['index_name'] != index_name:
                continue
            properties = index['body']['mappings']['properties']
            to_index = []
            for d in data:
                # Use the canonical (lower-case) index name, the same one that
                # was matched above, rather than the caller's raw spelling.
                item = {'_index': index_name, '_id': d['uuid']}
                item.update({p: d[p] for p in properties if p in d})
                to_index.append(item)
            success, _failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
            return success
        return 0


class NeptuneDriverSession(GraphDriverSession):
    """Session facade over NeptuneDriver; holds no connection state of its own."""

    provider = GraphProvider.NEPTUNE

    def __init__(self, driver: NeptuneDriver):  # type: ignore[reportUnknownArgumentType]
        self.driver = driver

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # No cleanup needed for Neptune, but method must exist
        pass

    async def close(self):
        # No explicit close needed for Neptune, but method must exist
        pass

    async def execute_write(self, func, *args, **kwargs):
        # Directly await the provided async function with `self` as the transaction/session
        return await func(self, *args, **kwargs)

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        """Run a query, or each query in a list (returning the last result)."""
        if isinstance(query, list):
            res = None
            for q in query:
                res = await self.driver.execute_query(q, **kwargs)
            return res
        else:
            return await self.driver.execute_query(str(query), **kwargs)