diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py index eb6cbc55a..568bea528 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py @@ -1,10 +1,12 @@ import os +import aiohttp import asyncio import requests import base64 import hashlib from uuid import UUID from typing import Optional +from urllib.parse import urlparse from cryptography.fernet import Fernet from cognee.infrastructure.databases.graph import get_graph_config @@ -26,6 +28,13 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): - Requests should be made async, currently a blocking requests library is used. """ + # Client credentials and encryption + client_id = os.environ.get("NEO4J_CLIENT_ID", None) + client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None) + tenant_id = os.environ.get("NEO4J_TENANT_ID", None) + encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key") + encryption_key = base64.urlsafe_b64encode(hashlib.sha256(encryption_env_key.encode()).digest()) + @classmethod async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: """ @@ -48,31 +57,14 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): graph_db_name = f"{dataset_id}" - # Client credentials and encryption - client_id = os.environ.get("NEO4J_CLIENT_ID", None) - client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None) - tenant_id = os.environ.get("NEO4J_TENANT_ID", None) - encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key") - encryption_key = base64.urlsafe_b64encode( - hashlib.sha256(encryption_env_key.encode()).digest() - ) - cipher = Fernet(encryption_key) + cipher = Fernet(cls.encryption_key) - if client_id is None or client_secret is None or tenant_id is None: + if cls.client_id is None or cls.client_secret is None or cls.tenant_id is None: raise ValueError( "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling." ) - # Make the request with HTTP Basic Auth - def get_aura_token(client_id: str, client_secret: str) -> dict: - url = "https://api.neo4j.io/oauth/token" - data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded - - resp = requests.post(url, data=data, auth=(client_id, client_secret)) - resp.raise_for_status() # raises if the request failed - return resp.json() - - resp = get_aura_token(client_id, client_secret) + resp = await cls._get_aura_token(cls.client_id, cls.client_secret) url = "https://api.neo4j.io/v1/instances" @@ -92,7 +84,7 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): 0:29 ], # TODO: Find better name to name Neo4j instance within 30 character limit "type": "professional-db", - "tenant_id": tenant_id, + "tenant_id": cls.tenant_id, "cloud_provider": "gcp", } @@ -165,4 +157,31 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): @classmethod async def delete_dataset(cls, dataset_database: DatasetDatabase): - pass + # Get dataset database information and credentials + dataset_database = await cls.resolve_dataset_connection_info(dataset_database) + + parsed_url = urlparse(dataset_database.graph_database_url) + instance_id = parsed_url.hostname.split(".")[0] + + url = f"https://api.neo4j.io/v1/instances/{instance_id}" + + # Get access token for Neo4j Aura API + resp = await cls._get_aura_token(cls.client_id, cls.client_secret) + + headers = { + "accept": "application/json", + "Authorization": f"Bearer {resp['access_token']}", + "Content-Type": "application/json", + } + + response = requests.delete(url, headers=headers) + return response + + @classmethod + async def _get_aura_token(cls, client_id: str, client_secret: str) -> dict: + url = "https://api.neo4j.io/oauth/token" + data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded + + resp = requests.post(url, data=data, auth=(client_id, client_secret)) + resp.raise_for_status() # raises if the request failed + return resp.json()