diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py index eb6cbc55a..bccf5020e 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py @@ -1,11 +1,13 @@ import os +import aiohttp import asyncio -import requests import base64 import hashlib from uuid import UUID from typing import Optional +from urllib.parse import urlparse from cryptography.fernet import Fernet +from aiohttp import BasicAuth from cognee.infrastructure.databases.graph import get_graph_config from cognee.modules.users.models import User, DatasetDatabase @@ -23,7 +25,6 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): Quality of life improvements: - Allow configuration of different Neo4j Aura plans and regions. - - Requests should be made async, currently a blocking requests library is used. """ @classmethod @@ -49,6 +50,7 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): graph_db_name = f"{dataset_id}" # Client credentials and encryption + # Note: Should not be used as class variables so that they are not persisted in memory longer than needed client_id = os.environ.get("NEO4J_CLIENT_ID", None) client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None) tenant_id = os.environ.get("NEO4J_TENANT_ID", None) @@ -63,22 +65,13 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling." ) - # Make the request with HTTP Basic Auth - def get_aura_token(client_id: str, client_secret: str) -> dict: - url = "https://api.neo4j.io/oauth/token" - data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded - - resp = requests.post(url, data=data, auth=(client_id, client_secret)) - resp.raise_for_status() # raises if the request failed - return resp.json() - - resp = get_aura_token(client_id, client_secret) + resp_token = await cls._get_aura_token(client_id, client_secret) url = "https://api.neo4j.io/v1/instances" headers = { "accept": "application/json", - "Authorization": f"Bearer {resp['access_token']}", + "Authorization": f"Bearer {resp_token['access_token']}", "Content-Type": "application/json", } @@ -96,31 +89,38 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): "cloud_provider": "gcp", } - response = requests.post(url, headers=headers, json=payload) + async def _create_database_instance_request(): + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, json=payload) as resp: + resp.raise_for_status() + return await resp.json() + + resp_create = await _create_database_instance_request() graph_db_name = "neo4j" # Has to be 'neo4j' for Aura - graph_db_url = response.json()["data"]["connection_url"] - graph_db_key = resp["access_token"] - graph_db_username = response.json()["data"]["username"] - graph_db_password = response.json()["data"]["password"] + graph_db_url = resp_create["data"]["connection_url"] + graph_db_key = resp_token["access_token"] + graph_db_username = resp_create["data"]["username"] + graph_db_password = resp_create["data"]["password"] async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict): # Poll until the instance is running status_url = f"https://api.neo4j.io/v1/instances/{instance_id}" status = "" for attempt in range(30): # Try for up to ~5 minutes - status_resp = requests.get( - status_url, headers=headers - ) # TODO: Use async requests with httpx - status = status_resp.json()["data"]["status"] - if status.lower() == "running": - return - await asyncio.sleep(10) + async with aiohttp.ClientSession() as session: + async with session.get(status_url, headers=headers) as resp: + resp.raise_for_status() + status_resp = await resp.json() + status = status_resp["data"]["status"] + if status.lower() == "running": + return + await asyncio.sleep(10) raise TimeoutError( f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}" ) - instance_id = response.json()["data"]["id"] + instance_id = resp_create["data"]["id"] await _wait_for_neo4j_instance_provisioning(instance_id, headers) encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode()) @@ -165,4 +165,39 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): @classmethod async def delete_dataset(cls, dataset_database: DatasetDatabase): - pass + # Get dataset database information and credentials + dataset_database = await cls.resolve_dataset_connection_info(dataset_database) + + parsed_url = urlparse(dataset_database.graph_database_url) + instance_id = parsed_url.hostname.split(".")[0] + + url = f"https://api.neo4j.io/v1/instances/{instance_id}" + + # Get access token for Neo4j Aura API + # Client credentials + client_id = os.environ.get("NEO4J_CLIENT_ID", None) + client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None) + resp = await cls._get_aura_token(client_id, client_secret) + + headers = { + "accept": "application/json", + "Authorization": f"Bearer {resp['access_token']}", + "Content-Type": "application/json", + } + + async with aiohttp.ClientSession() as session: + async with session.delete(url, headers=headers) as resp: + resp.raise_for_status() + return await resp.json() + + @classmethod + async def _get_aura_token(cls, client_id: str, client_secret: str) -> dict: + url = "https://api.neo4j.io/oauth/token" + data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded + + async with aiohttp.ClientSession() as session: + async with session.post( + url, data=data, auth=BasicAuth(client_id, client_secret) + ) as resp: + resp.raise_for_status() + return await resp.json()