feat: add password encryption for Neo4j
This commit is contained in:
parent
92448767fe
commit
1282905888
2 changed files with 47 additions and 8 deletions
|
|
@ -1,12 +1,14 @@
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
import requests
|
import requests
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_config
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User, DatasetDatabase
|
||||||
|
|
||||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -37,10 +39,15 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
|
||||||
graph_db_name = f"{dataset_id}"
|
graph_db_name = f"{dataset_id}"
|
||||||
|
|
||||||
# Client credentials
|
# Client credentials and encryption
|
||||||
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
||||||
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
||||||
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
|
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
|
||||||
|
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
|
||||||
|
encryption_key = base64.urlsafe_b64encode(
|
||||||
|
hashlib.sha256(encryption_env_key.encode()).digest()
|
||||||
|
)
|
||||||
|
cipher = Fernet(encryption_key)
|
||||||
|
|
||||||
if client_id is None or client_secret is None or tenant_id is None:
|
if client_id is None or client_secret is None or tenant_id is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
@ -93,7 +100,9 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
||||||
status = ""
|
status = ""
|
||||||
for attempt in range(30): # Try for up to ~5 minutes
|
for attempt in range(30): # Try for up to ~5 minutes
|
||||||
status_resp = requests.get(status_url, headers=headers)
|
status_resp = requests.get(
|
||||||
|
status_url, headers=headers
|
||||||
|
) # TODO: Use async requests with httpx
|
||||||
status = status_resp.json()["data"]["status"]
|
status = status_resp.json()["data"]["status"]
|
||||||
if status.lower() == "running":
|
if status.lower() == "running":
|
||||||
return
|
return
|
||||||
|
|
@ -104,17 +113,45 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
|
||||||
instance_id = response.json()["data"]["id"]
|
instance_id = response.json()["data"]["id"]
|
||||||
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
|
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
|
||||||
|
|
||||||
|
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
|
||||||
|
encrypted_db_password_string = encrypted_db_password_bytes.decode()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"graph_database_name": graph_db_name,
|
"graph_database_name": graph_db_name,
|
||||||
"graph_database_url": graph_db_url,
|
"graph_database_url": graph_db_url,
|
||||||
"graph_database_provider": "neo4j",
|
"graph_database_provider": "neo4j",
|
||||||
"graph_database_key": graph_db_key,
|
"graph_database_key": graph_db_key,
|
||||||
"graph_database_connection_info": { # TODO: Hashing of keys/passwords in relational DB
|
"graph_database_connection_info": {
|
||||||
"graph_database_username": graph_db_username,
|
"graph_database_username": graph_db_username,
|
||||||
"graph_database_password": graph_db_password,
|
"graph_database_password": encrypted_db_password_string,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def resolve_dataset_connection_info(
|
||||||
|
cls, dataset_database: DatasetDatabase
|
||||||
|
) -> DatasetDatabase:
|
||||||
|
"""
|
||||||
|
Resolve and decrypt connection info for the Neo4j dataset database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_database: DatasetDatabase instance containing encrypted connection info.
|
||||||
|
"""
|
||||||
|
encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key")
|
||||||
|
encryption_key = base64.urlsafe_b64encode(
|
||||||
|
hashlib.sha256(encryption_env_key.encode()).digest()
|
||||||
|
)
|
||||||
|
cipher = Fernet(encryption_key)
|
||||||
|
graph_db_password = cipher.decrypt(
|
||||||
|
dataset_database.graph_database_connection_info["graph_database_password"].encode()
|
||||||
|
).decode()
|
||||||
|
|
||||||
|
dataset_database.graph_database_connection_info["graph_database_password"] = (
|
||||||
|
graph_db_password
|
||||||
|
)
|
||||||
|
return dataset_database
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]):
|
async def delete_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
@ -5,11 +6,12 @@ from cognee.shared.cache import delete_cache
|
||||||
|
|
||||||
|
|
||||||
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
|
async def prune_system(graph=True, vector=True, metadata=True, cache=True):
|
||||||
if graph:
|
# TODO: prune_system should work with multi-user access control mode enabled
|
||||||
|
if graph and not backend_access_control_enabled():
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
await graph_engine.delete_graph()
|
await graph_engine.delete_graph()
|
||||||
|
|
||||||
if vector:
|
if vector and not backend_access_control_enabled():
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
await vector_engine.prune()
|
await vector_engine.prune()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue