Merge branch 'dev' into feature/cog-3502-tool-logging-with-redis

This commit is contained in:
hajdul88 2026-01-16 13:00:11 +01:00 committed by GitHub
commit 2e9f646edd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,11 +1,13 @@
import os import os
import aiohttp
import asyncio import asyncio
import requests
import base64 import base64
import hashlib import hashlib
from uuid import UUID from uuid import UUID
from typing import Optional from typing import Optional
from urllib.parse import urlparse
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from aiohttp import BasicAuth
from cognee.infrastructure.databases.graph import get_graph_config from cognee.infrastructure.databases.graph import get_graph_config
from cognee.modules.users.models import User, DatasetDatabase from cognee.modules.users.models import User, DatasetDatabase
@ -23,7 +25,6 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
Quality of life improvements: Quality of life improvements:
- Allow configuration of different Neo4j Aura plans and regions. - Allow configuration of different Neo4j Aura plans and regions.
- Requests should be made async, currently a blocking requests library is used.
""" """
@classmethod @classmethod
@ -49,6 +50,7 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
graph_db_name = f"{dataset_id}" graph_db_name = f"{dataset_id}"
# Client credentials and encryption # Client credentials and encryption
# Note: These should not be stored as class variables, so that they are not kept in memory longer than needed
client_id = os.environ.get("NEO4J_CLIENT_ID", None) client_id = os.environ.get("NEO4J_CLIENT_ID", None)
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None) client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
tenant_id = os.environ.get("NEO4J_TENANT_ID", None) tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
@ -63,22 +65,13 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
"NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling." "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
) )
# Make the request with HTTP Basic Auth resp_token = await cls._get_aura_token(client_id, client_secret)
def get_aura_token(client_id: str, client_secret: str) -> dict:
url = "https://api.neo4j.io/oauth/token"
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
resp = requests.post(url, data=data, auth=(client_id, client_secret))
resp.raise_for_status() # raises if the request failed
return resp.json()
resp = get_aura_token(client_id, client_secret)
url = "https://api.neo4j.io/v1/instances" url = "https://api.neo4j.io/v1/instances"
headers = { headers = {
"accept": "application/json", "accept": "application/json",
"Authorization": f"Bearer {resp['access_token']}", "Authorization": f"Bearer {resp_token['access_token']}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
@ -96,31 +89,38 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
"cloud_provider": "gcp", "cloud_provider": "gcp",
} }
async def _create_database_instance_request():
    """Send the instance creation request and return the decoded JSON body.

    Uses ``url``, ``headers`` and ``payload`` from the enclosing scope.
    An HTTP error status is raised by ``raise_for_status()`` before the
    body is decoded.
    """
    async with aiohttp.ClientSession() as http:
        async with http.post(url, headers=headers, json=payload) as create_response:
            create_response.raise_for_status()
            created = await create_response.json()
    return created
resp_create = await _create_database_instance_request()
graph_db_name = "neo4j" # Has to be 'neo4j' for Aura graph_db_name = "neo4j" # Has to be 'neo4j' for Aura
graph_db_url = response.json()["data"]["connection_url"] graph_db_url = resp_create["data"]["connection_url"]
graph_db_key = resp["access_token"] graph_db_key = resp_token["access_token"]
graph_db_username = response.json()["data"]["username"] graph_db_username = resp_create["data"]["username"]
graph_db_password = response.json()["data"]["password"] graph_db_password = resp_create["data"]["password"]
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
    """Poll the Neo4j Aura API until the instance reports a "running" status.

    Args:
        instance_id: Id of the instance; used to build the status URL.
        headers: HTTP headers sent with every status request.

    Raises:
        aiohttp.ClientResponseError: If a status request returns an HTTP
            error status (raised by ``raise_for_status()``).
        TimeoutError: If the status is still not "running" after 30 polls.
    """
    status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
    status = ""

    # One session is shared by every poll instead of opening (and tearing
    # down) a new session on each attempt.
    async with aiohttp.ClientSession() as session:
        for _ in range(30):  # Try for up to ~5 minutes (30 polls, 10 s apart)
            async with session.get(status_url, headers=headers) as resp:
                resp.raise_for_status()
                status_resp = await resp.json()

            status = status_resp["data"]["status"]
            if status.lower() == "running":
                return

            # The response has already been released here, so no request is
            # held open while waiting for the next poll.
            await asyncio.sleep(10)

    # NOTE(review): graph_db_name comes from the enclosing scope, where it is
    # set to "neo4j" before this helper is awaited, so the message does not
    # identify the specific instance. Consider reporting instance_id instead.
    raise TimeoutError(
        f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
    )
instance_id = response.json()["data"]["id"] instance_id = resp_create["data"]["id"]
await _wait_for_neo4j_instance_provisioning(instance_id, headers) await _wait_for_neo4j_instance_provisioning(instance_id, headers)
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode()) encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
@ -165,4 +165,39 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
@classmethod
async def delete_dataset(cls, dataset_database: DatasetDatabase):
    """Delete the Neo4j Aura instance referenced by ``dataset_database``.

    Args:
        dataset_database: Record whose ``graph_database_url`` points at the
            Aura instance. It is first passed through
            ``cls.resolve_dataset_connection_info``.

    Returns:
        The decoded JSON body of the Aura API DELETE response.

    Raises:
        ValueError: If the connection URL has no hostname, or if
            NEO4J_CLIENT_ID / NEO4J_CLIENT_SECRET are not set.
        aiohttp.ClientResponseError: If the token or delete request returns
            an HTTP error status.
    """
    # Get dataset database information and credentials
    dataset_database = await cls.resolve_dataset_connection_info(dataset_database)

    # The instance id is read from the first dot-separated label of the host.
    # Assumes Aura connection URLs look like <instance-id>.<domain> - TODO confirm.
    hostname = urlparse(dataset_database.graph_database_url).hostname
    if not hostname:
        raise ValueError(
            "Cannot determine the Neo4j Aura instance id: the dataset graph database URL has no hostname."
        )
    instance_id = hostname.split(".")[0]
    url = f"https://api.neo4j.io/v1/instances/{instance_id}"

    # Client credentials are read per call, not cached on the class.
    client_id = os.environ.get("NEO4J_CLIENT_ID", None)
    client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
    if not client_id or not client_secret:
        # Fail with a clear message before any network call is made.
        raise ValueError(
            "NEO4J_CLIENT_ID and NEO4J_CLIENT_SECRET environment variables must be set to delete a Neo4j Aura instance."
        )

    # Get access token for Neo4j Aura API
    token = await cls._get_aura_token(client_id, client_secret)

    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {token['access_token']}",
        "Content-Type": "application/json",
    }

    async with aiohttp.ClientSession() as session:
        async with session.delete(url, headers=headers) as resp:
            resp.raise_for_status()
            return await resp.json()
@classmethod
async def _get_aura_token(cls, client_id: str, client_secret: str) -> dict:
    """Request a token from the Neo4j Aura OAuth endpoint.

    The request is a ``client_credentials`` grant authenticated with HTTP
    Basic Auth built from ``client_id`` and ``client_secret``.

    Returns:
        The decoded JSON body of the token response.
    """
    token_endpoint = "https://api.neo4j.io/oauth/token"
    # sent as application/x-www-form-urlencoded
    form_fields = {"grant_type": "client_credentials"}

    async with aiohttp.ClientSession() as http:
        credentials = BasicAuth(client_id, client_secret)
        async with http.post(token_endpoint, data=form_fields, auth=credentials) as token_response:
            # An HTTP error status raises before the body is decoded.
            token_response.raise_for_status()
            return await token_response.json()