feat: add password encryption for Neo4j
This commit is contained in:
parent
92448767fe
commit
1282905888
2 changed files with 47 additions and 8 deletions
|
|
@ -1,12 +1,14 @@
|
|||
import os
|
||||
import asyncio
|
||||
import requests
|
||||
import base64
|
||||
import hashlib
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_config
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
from cognee.modules.users.models import User, DatasetDatabase
|
||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||
|
||||
|
||||
|
|
@ -37,10 +39,15 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
|||
|
||||
graph_db_name = f"{dataset_id}"
|
||||
|
||||
# Client credentials
|
||||
# Client credentials and encryption
|
||||
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
||||
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
||||
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
|
||||
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
|
||||
encryption_key = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(encryption_env_key.encode()).digest()
|
||||
)
|
||||
cipher = Fernet(encryption_key)
|
||||
|
||||
if client_id is None or client_secret is None or tenant_id is None:
|
||||
raise ValueError(
|
||||
|
|
@ -93,7 +100,9 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
|||
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
||||
status = ""
|
||||
for attempt in range(30): # Try for up to ~5 minutes
|
||||
status_resp = requests.get(status_url, headers=headers)
|
||||
status_resp = requests.get(
|
||||
status_url, headers=headers
|
||||
) # TODO: Use async requests with httpx
|
||||
status = status_resp.json()["data"]["status"]
|
||||
if status.lower() == "running":
|
||||
return
|
||||
|
|
@ -104,17 +113,45 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
|||
|
||||
instance_id = response.json()["data"]["id"]
|
||||
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
|
||||
|
||||
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
|
||||
encrypted_db_password_string = encrypted_db_password_bytes.decode()
|
||||
|
||||
return {
|
||||
"graph_database_name": graph_db_name,
|
||||
"graph_database_url": graph_db_url,
|
||||
"graph_database_provider": "neo4j",
|
||||
"graph_database_key": graph_db_key,
|
||||
"graph_database_connection_info": { # TODO: Hashing of keys/passwords in relational DB
|
||||
"graph_database_connection_info": {
|
||||
"graph_database_username": graph_db_username,
|
||||
"graph_database_password": graph_db_password,
|
||||
"graph_database_password": encrypted_db_password_string,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
async def resolve_dataset_connection_info(
    cls, dataset_database: DatasetDatabase
) -> DatasetDatabase:
    """
    Decrypt the stored Neo4j password on a dataset database record.

    A Fernet key is derived from the NEO4J_ENCRYPTION_KEY environment variable
    (falling back to "test_key" when the variable is unset) by taking the
    SHA-256 digest of its value and urlsafe-base64 encoding that digest.

    Args:
        dataset_database: DatasetDatabase instance containing encrypted connection info.

    Returns:
        The same DatasetDatabase instance, with the "graph_database_password"
        entry of its connection info replaced by the decrypted password.
    """
    # NOTE(review): the "test_key" fallback means a missing environment variable
    # silently yields a publicly known key - confirm this is intended outside tests.
    secret = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
    digest = hashlib.sha256(secret.encode()).digest()
    fernet = Fernet(base64.urlsafe_b64encode(digest))

    connection_info = dataset_database.graph_database_connection_info
    encrypted_password = connection_info["graph_database_password"]

    # The password is stored as text, so convert to bytes for Fernet and back to str.
    connection_info["graph_database_password"] = fernet.decrypt(
        encrypted_password.encode()
    ).decode()
    return dataset_database
|
||||
|
||||
@classmethod
async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]):
    """
    Placeholder for deleting a dataset's Neo4j database; currently does nothing.

    Args:
        dataset_id: Identifier of the dataset (unused).
        user: User associated with the dataset (unused).

    Returns:
        None.
    """
    # TODO: not implemented yet - no deletion is performed for the dataset.
    return None
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
|
@ -5,11 +6,12 @@ from cognee.shared.cache import delete_cache
|
|||
|
||||
|
||||
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
|
||||
if graph:
|
||||
# TODO: prune_system should work with multi-user access control mode enabled
|
||||
if graph and not backend_access_control_enabled():
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.delete_graph()
|
||||
|
||||
if vector:
|
||||
if vector and not backend_access_control_enabled():
|
||||
vector_engine = get_vector_engine()
|
||||
await vector_engine.prune()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue