feat: add password encryption for Neo4j

This commit is contained in:
Igor Ilic 2025-12-02 16:34:16 +01:00
parent 92448767fe
commit 1282905888
2 changed files with 47 additions and 8 deletions

View file

@ -1,12 +1,14 @@
import os
import asyncio
import requests
import base64
import hashlib
from uuid import UUID
from typing import Optional
from cryptography.fernet import Fernet
from cognee.infrastructure.databases.graph import get_graph_config
from cognee.modules.users.models import User
from cognee.modules.users.models import User, DatasetDatabase
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
@ -37,10 +39,15 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
graph_db_name = f"{dataset_id}"
# Client credentials
# Client credentials and encryption
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
encryption_key = base64.urlsafe_b64encode(
hashlib.sha256(encryption_env_key.encode()).digest()
)
cipher = Fernet(encryption_key)
if client_id is None or client_secret is None or tenant_id is None:
raise ValueError(
@ -93,7 +100,9 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
status = ""
for attempt in range(30): # Try for up to ~5 minutes
status_resp = requests.get(status_url, headers=headers)
status_resp = requests.get(
status_url, headers=headers
) # TODO: Use async requests with httpx
status = status_resp.json()["data"]["status"]
if status.lower() == "running":
return
@ -104,17 +113,45 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
instance_id = response.json()["data"]["id"]
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
encrypted_db_password_string = encrypted_db_password_bytes.decode()
return {
"graph_database_name": graph_db_name,
"graph_database_url": graph_db_url,
"graph_database_provider": "neo4j",
"graph_database_key": graph_db_key,
"graph_database_connection_info": { # TODO: Hashing of keys/passwords in relational DB
"graph_database_connection_info": {
"graph_database_username": graph_db_username,
"graph_database_password": graph_db_password,
"graph_database_password": encrypted_db_password_string,
},
}
@classmethod
async def resolve_dataset_connection_info(
cls, dataset_database: DatasetDatabase
) -> DatasetDatabase:
"""
Resolve and decrypt connection info for the Neo4j dataset database.
Args:
dataset_database: DatasetDatabase instance containing encrypted connection info.
"""
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
encryption_key = base64.urlsafe_b64encode(
hashlib.sha256(encryption_env_key.encode()).digest()
)
cipher = Fernet(encryption_key)
graph_db_password = cipher.decrypt(
dataset_database.graph_database_connection_info["graph_database_password"].encode()
).decode()
dataset_database.graph_database_connection_info["graph_database_password"] = (
graph_db_password
)
return dataset_database
@classmethod
async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]):
pass

View file

@ -1,3 +1,4 @@
from cognee.context_global_variables import backend_access_control_enabled
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
@ -5,11 +6,12 @@ from cognee.shared.cache import delete_cache
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
if graph:
# TODO: prune_system should work with multi-user access control mode enabled
if graph and not backend_access_control_enabled():
graph_engine = await get_graph_engine()
await graph_engine.delete_graph()
if vector:
if vector and not backend_access_control_enabled():
vector_engine = get_vector_engine()
await vector_engine.prune()