refactor: use async requests library
This commit is contained in:
parent
268db003e3
commit
05961221f4
1 changed file with 49 additions and 33 deletions
|
|
@ -1,13 +1,13 @@
|
||||||
import os
|
import os
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import asyncio
|
import asyncio
|
||||||
import requests
|
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
|
from aiohttp import BasicAuth
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_config
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
from cognee.modules.users.models import User, DatasetDatabase
|
from cognee.modules.users.models import User, DatasetDatabase
|
||||||
|
|
@ -25,16 +25,8 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
|
||||||
Quality of life improvements:
|
Quality of life improvements:
|
||||||
- Allow configuration of different Neo4j Aura plans and regions.
|
- Allow configuration of different Neo4j Aura plans and regions.
|
||||||
- Requests should be made async, currently a blocking requests library is used.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Client credentials and encryption
|
|
||||||
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
|
||||||
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
|
||||||
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
|
|
||||||
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
|
|
||||||
encryption_key = base64.urlsafe_b64encode(hashlib.sha256(encryption_env_key.encode()).digest())
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|
@ -57,20 +49,29 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
|
||||||
graph_db_name = f"{dataset_id}"
|
graph_db_name = f"{dataset_id}"
|
||||||
|
|
||||||
cipher = Fernet(cls.encryption_key)
|
# Client credentials and encryption
|
||||||
|
# Note: These values are read per call instead of being stored as class variables, so that they are not persisted in memory longer than needed
|
||||||
|
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
||||||
|
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
||||||
|
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
|
||||||
|
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
|
||||||
|
encryption_key = base64.urlsafe_b64encode(
|
||||||
|
hashlib.sha256(encryption_env_key.encode()).digest()
|
||||||
|
)
|
||||||
|
cipher = Fernet(encryption_key)
|
||||||
|
|
||||||
if cls.client_id is None or cls.client_secret is None or cls.tenant_id is None:
|
if client_id is None or client_secret is None or tenant_id is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
|
"NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
|
||||||
)
|
)
|
||||||
|
|
||||||
resp = await cls._get_aura_token(cls.client_id, cls.client_secret)
|
resp_token = await cls._get_aura_token(client_id, client_secret)
|
||||||
|
|
||||||
url = "https://api.neo4j.io/v1/instances"
|
url = "https://api.neo4j.io/v1/instances"
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"accept": "application/json",
|
"accept": "application/json",
|
||||||
"Authorization": f"Bearer {resp['access_token']}",
|
"Authorization": f"Bearer {resp_token['access_token']}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -84,35 +85,42 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
0:29
|
0:29
|
||||||
], # TODO: Find a better way to name the Neo4j instance within the 30-character limit
|
], # TODO: Find a better way to name the Neo4j instance within the 30-character limit
|
||||||
"type": "professional-db",
|
"type": "professional-db",
|
||||||
"tenant_id": cls.tenant_id,
|
"tenant_id": tenant_id,
|
||||||
"cloud_provider": "gcp",
|
"cloud_provider": "gcp",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, json=payload)
|
async def _create_database_instance_request():
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(url, headers=headers, json=payload) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
|
||||||
|
resp_create = await _create_database_instance_request()
|
||||||
|
|
||||||
graph_db_name = "neo4j" # Has to be 'neo4j' for Aura
|
graph_db_name = "neo4j" # Has to be 'neo4j' for Aura
|
||||||
graph_db_url = response.json()["data"]["connection_url"]
|
graph_db_url = resp_create["data"]["connection_url"]
|
||||||
graph_db_key = resp["access_token"]
|
graph_db_key = resp_token["access_token"]
|
||||||
graph_db_username = response.json()["data"]["username"]
|
graph_db_username = resp_create["data"]["username"]
|
||||||
graph_db_password = response.json()["data"]["password"]
|
graph_db_password = resp_create["data"]["password"]
|
||||||
|
|
||||||
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
|
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
|
||||||
# Poll until the instance is running
|
# Poll until the instance is running
|
||||||
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
||||||
status = ""
|
status = ""
|
||||||
for attempt in range(30): # Try for up to ~5 minutes
|
for attempt in range(30): # Try for up to ~5 minutes
|
||||||
status_resp = requests.get(
|
async with aiohttp.ClientSession() as session:
|
||||||
status_url, headers=headers
|
async with session.get(status_url, headers=headers) as resp:
|
||||||
) # TODO: Use async requests with httpx
|
resp.raise_for_status()
|
||||||
status = status_resp.json()["data"]["status"]
|
status_resp = await resp.json()
|
||||||
if status.lower() == "running":
|
status = status_resp["data"]["status"]
|
||||||
return
|
if status.lower() == "running":
|
||||||
await asyncio.sleep(10)
|
return
|
||||||
|
await asyncio.sleep(10)
|
||||||
raise TimeoutError(
|
raise TimeoutError(
|
||||||
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
|
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
|
||||||
)
|
)
|
||||||
|
|
||||||
instance_id = response.json()["data"]["id"]
|
instance_id = resp_create["data"]["id"]
|
||||||
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
|
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
|
||||||
|
|
||||||
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
|
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
|
||||||
|
|
@ -166,7 +174,10 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
||||||
|
|
||||||
# Get access token for Neo4j Aura API
|
# Get access token for Neo4j Aura API
|
||||||
resp = await cls._get_aura_token(cls.client_id, cls.client_secret)
|
# Client credentials
|
||||||
|
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
||||||
|
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
||||||
|
resp = await cls._get_aura_token(client_id, client_secret)
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"accept": "application/json",
|
"accept": "application/json",
|
||||||
|
|
@ -174,14 +185,19 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.delete(url, headers=headers)
|
async with aiohttp.ClientSession() as session:
|
||||||
return response
|
async with session.delete(url, headers=headers) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _get_aura_token(cls, client_id: str, client_secret: str) -> dict:
|
async def _get_aura_token(cls, client_id: str, client_secret: str) -> dict:
|
||||||
url = "https://api.neo4j.io/oauth/token"
|
url = "https://api.neo4j.io/oauth/token"
|
||||||
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
|
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
|
||||||
|
|
||||||
resp = requests.post(url, data=data, auth=(client_id, client_secret))
|
async with aiohttp.ClientSession() as session:
|
||||||
resp.raise_for_status() # raises if the request failed
|
async with session.post(
|
||||||
return resp.json()
|
url, data=data, auth=BasicAuth(client_id, client_secret)
|
||||||
|
) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue