feat: add password encryption for Neo4j

This commit is contained in:
Igor Ilic 2025-12-02 16:34:16 +01:00
parent 92448767fe
commit 1282905888
2 changed files with 47 additions and 8 deletions

View file

@ -1,12 +1,14 @@
import os import os
import asyncio import asyncio
import requests import requests
import base64
import hashlib
from uuid import UUID from uuid import UUID
from typing import Optional from typing import Optional
from cryptography.fernet import Fernet
from cognee.infrastructure.databases.graph import get_graph_config from cognee.infrastructure.databases.graph import get_graph_config
from cognee.modules.users.models import User from cognee.modules.users.models import User, DatasetDatabase
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
@ -37,10 +39,15 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
graph_db_name = f"{dataset_id}" graph_db_name = f"{dataset_id}"
# Client credentials # Client credentials and encryption
client_id = os.environ.get("NEO4J_CLIENT_ID", None) client_id = os.environ.get("NEO4J_CLIENT_ID", None)
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None) client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
tenant_id = os.environ.get("NEO4J_TENANT_ID", None) tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
encryption_key = base64.urlsafe_b64encode(
hashlib.sha256(encryption_env_key.encode()).digest()
)
cipher = Fernet(encryption_key)
if client_id is None or client_secret is None or tenant_id is None: if client_id is None or client_secret is None or tenant_id is None:
raise ValueError( raise ValueError(
@ -93,7 +100,9 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}" status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
status = "" status = ""
for attempt in range(30): # Try for up to ~5 minutes for attempt in range(30): # Try for up to ~5 minutes
status_resp = requests.get(status_url, headers=headers) status_resp = requests.get(
status_url, headers=headers
) # TODO: Use async requests with httpx
status = status_resp.json()["data"]["status"] status = status_resp.json()["data"]["status"]
if status.lower() == "running": if status.lower() == "running":
return return
@ -104,17 +113,45 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
instance_id = response.json()["data"]["id"] instance_id = response.json()["data"]["id"]
await _wait_for_neo4j_instance_provisioning(instance_id, headers) await _wait_for_neo4j_instance_provisioning(instance_id, headers)
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
encrypted_db_password_string = encrypted_db_password_bytes.decode()
return { return {
"graph_database_name": graph_db_name, "graph_database_name": graph_db_name,
"graph_database_url": graph_db_url, "graph_database_url": graph_db_url,
"graph_database_provider": "neo4j", "graph_database_provider": "neo4j",
"graph_database_key": graph_db_key, "graph_database_key": graph_db_key,
"graph_database_connection_info": { # TODO: Hashing of keys/passwords in relational DB "graph_database_connection_info": {
"graph_database_username": graph_db_username, "graph_database_username": graph_db_username,
"graph_database_password": graph_db_password, "graph_database_password": encrypted_db_password_string,
}, },
} }
@classmethod
async def resolve_dataset_connection_info(
    cls, dataset_database: DatasetDatabase
) -> DatasetDatabase:
    """
    Decrypt the stored Neo4j password on a dataset database record.

    The Fernet key is derived by SHA-256 hashing the value of the
    NEO4J_ENCRYPTION_KEY environment variable ("test_key" when the variable
    is unset) and url-safe base64 encoding the digest.

    Args:
        dataset_database: DatasetDatabase instance whose
            graph_database_connection_info["graph_database_password"] entry
            holds the Fernet-encrypted password as a string.

    Returns:
        The same DatasetDatabase instance, updated in place so that the
        password entry contains the decrypted value.
    """
    # NOTE(review): the "test_key" fallback has to match whatever key was used
    # when the password was encrypted; confirm a real key is always configured
    # outside of tests.
    secret = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
    digest = hashlib.sha256(secret.encode()).digest()
    fernet = Fernet(base64.urlsafe_b64encode(digest))

    password_field = "graph_database_password"
    encrypted_token = dataset_database.graph_database_connection_info[password_field].encode()
    plaintext_password = fernet.decrypt(encrypted_token).decode()

    # Overwrite the encrypted value on the passed-in object itself.
    dataset_database.graph_database_connection_info[password_field] = plaintext_password
    return dataset_database
@classmethod @classmethod
async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]): async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]):
pass pass

View file

@ -1,3 +1,4 @@
from cognee.context_global_variables import backend_access_control_enabled
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
@ -5,11 +6,12 @@ from cognee.shared.cache import delete_cache
async def prune_system(graph=True, vector=True, metadata=True, cache=True): async def prune_system(graph=True, vector=True, metadata=True, cache=True):
if graph: # TODO: prune_system should work with multi-user access control mode enabled
if graph and not backend_access_control_enabled():
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
await graph_engine.delete_graph() await graph_engine.delete_graph()
if vector: if vector and not backend_access_control_enabled():
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
await vector_engine.prune() await vector_engine.prune()