From 69777ef0a5d80b3a2a10d91d59a9e4f051d019ca Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 25 Nov 2025 17:53:21 +0100 Subject: [PATCH] feat: Add ability to handle custom connection resolution to avoid storing security critical data in rel dbx --- cognee/context_global_variables.py | 11 ++++- .../dataset_database_handler_interface.py | 40 +++++++++++++++++- .../graph/kuzu/KuzuDatasetDatabaseHandler.py | 6 ++- .../Neo4jAuraDatasetDatabaseHandler.py | 8 ++-- .../databases/utils/__init__.py | 1 + ...esolve_dataset_database_connection_info.py | 42 +++++++++++++++++++ .../modules/users/models/DatasetDatabase.py | 6 +-- 7 files changed, 103 insertions(+), 11 deletions(-) create mode 100644 cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index 0e7e16178..58fff2dff 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -7,6 +7,7 @@ from cognee.base_config import get_base_config from cognee.infrastructure.databases.vector.config import get_vectordb_config from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.utils import get_or_create_dataset_database +from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info from cognee.infrastructure.files.storage.config import file_storage_config from cognee.modules.users.methods import get_user @@ -108,6 +109,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ # To ensure permissions are enforced properly all datasets will have their own databases dataset_database = await get_or_create_dataset_database(dataset, user) + # Ensure that all connection info is resolved properly + dataset_database = await resolve_dataset_database_connection_info(dataset_database) base_config = get_base_config() data_root_directory = os.path.join( @@ -133,8 +136,12 @@ async def 
set_database_global_context_variables(dataset: Union[str, UUID], user_ "graph_file_path": os.path.join( databases_directory_path, dataset_database.graph_database_name ), - "graph_database_username": dataset_database.graph_database_username, - "graph_database_password": dataset_database.graph_database_password, + "graph_database_username": dataset_database.graph_database_connection_info.get( + "graph_database_username", "" + ), + "graph_database_password": dataset_database.graph_database_connection_info.get( + "graph_database_password", "" + ), } storage_config = { diff --git a/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py b/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py index 6dadee6cf..01ee46c48 100644 --- a/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +++ b/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py @@ -3,6 +3,7 @@ from uuid import UUID from abc import ABC, abstractmethod from cognee.modules.users.models.User import User +from cognee.modules.users.models.DatasetDatabase import DatasetDatabase class DatasetDatabaseHandlerInterface(ABC): @@ -10,7 +11,7 @@ class DatasetDatabaseHandlerInterface(ABC): @abstractmethod async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: """ - Return a dictionary with connection info for a graph or vector database for the given dataset. + Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset. Function can auto handle deploying of the actual database if needed, but is not necessary. Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future. Needed for Cognee multi-tenant/multi-user and backend access control support. 
@@ -18,6 +19,10 @@ class DatasetDatabaseHandlerInterface(ABC): Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database. From which internal mapping of dataset -> database connection info will be done. + The returned dictionary is stored verbatim in the relational database and is later passed to + resolve_dataset_connection_info() at connection time. For safe credential handling, prefer + returning only references to secrets or role identifiers, not plaintext credentials. + Each dataset needs to map to a unique graph or vector database when backend access control is enabled to facilitate a separation of concern for data. Args: @@ -28,6 +33,39 @@ """ pass + @classmethod + async def resolve_dataset_connection_info( + cls, dataset_database: DatasetDatabase + ) -> DatasetDatabase: + """ + Resolve runtime connection details for a dataset’s backing graph/vector database. + Function is intended to be overridden to implement custom logic for resolving connection info. + + This method is invoked right before the application opens a connection for a given dataset. + It receives the DatasetDatabase row that was persisted when create_dataset() ran and must + return a modified instance of DatasetDatabase with concrete connection parameters that the client/driver can use. + Do not update these new DatasetDatabase values in the relational database to avoid storing secure credentials. + + In case of separate graph and vector database handlers, each handler should implement its own logic for resolving + connection info and only change parameters related to its appropriate database; the resolution functions will then + be called one after another with the updated DatasetDatabase value from the previous function as the input.
+ + Typical behavior: + - If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password + or api_url/api_key), return them as-is. + - If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM + roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider + API to obtain short-lived credentials and assemble the final DatasetDatabase connection object. + - Do not persist any resolved or decrypted secrets back to the relational database. Return them only + to the caller. + + Args: + dataset_database: DatasetDatabase row from the relational database + Returns: + DatasetDatabase: Updated instance with resolved connection info + """ + return dataset_database + @classmethod + @abstractmethod + async def delete_dataset(cls, dataset_id: UUID, user: User) -> None: diff --git a/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py index 8859422f9..a2b2da8f4 100644 --- a/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +++ b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py @@ -48,8 +48,10 @@ class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): "graph_database_url": graph_db_url, "graph_database_provider": graph_config.graph_database_provider, "graph_database_key": graph_db_key, - "graph_database_username": graph_db_username, - "graph_database_password": graph_db_password, + "graph_database_connection_info": { + "graph_database_username": graph_db_username, + "graph_database_password": graph_db_password, + }, } @classmethod diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDatasetDatabaseHandler.py index cc38abed0..d1e5eee6f 100644 ---
a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDatasetDatabaseHandler.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDatasetDatabaseHandler.py @@ -108,9 +108,11 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): "graph_database_name": graph_db_name, "graph_database_url": graph_db_url, "graph_database_provider": "neo4j", - "graph_database_key": graph_db_key, # TODO: Hashing of keys/passwords in relational DB - "graph_database_username": graph_db_username, - "graph_database_password": graph_db_password, + "graph_database_key": graph_db_key, + "graph_database_connection_info": { # TODO: Hashing of keys/passwords in relational DB + "graph_database_username": graph_db_username, + "graph_database_password": graph_db_password, + }, } @classmethod diff --git a/cognee/infrastructure/databases/utils/__init__.py b/cognee/infrastructure/databases/utils/__init__.py index 1dfa15640..f31d1e0dc 100644 --- a/cognee/infrastructure/databases/utils/__init__.py +++ b/cognee/infrastructure/databases/utils/__init__.py @@ -1 +1,2 @@ from .get_or_create_dataset_database import get_or_create_dataset_database +from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info diff --git a/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py new file mode 100644 index 000000000..4d8c19403 --- /dev/null +++ b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py @@ -0,0 +1,42 @@ +from cognee.infrastructure.databases.vector import get_vectordb_config +from cognee.infrastructure.databases.graph.config import get_graph_config +from cognee.modules.users.models.DatasetDatabase import DatasetDatabase + + +async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: + vector_config = get_vectordb_config() + + from 
cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler] + return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) + + +async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: + graph_config = get_graph_config() + + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler] + return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) + + +async def resolve_dataset_database_connection_info( + dataset_database: DatasetDatabase, +) -> DatasetDatabase: + """ + Resolve the connection info for the given DatasetDatabase instance. + Resolve both vector and graph database connection info and return the updated DatasetDatabase instance. 
+ + Args: + dataset_database: DatasetDatabase instance + Returns: + DatasetDatabase instance with resolved connection info + """ + dataset_database = await _get_vector_db_connection_info(dataset_database) + dataset_database = await _get_graph_db_connection_info(dataset_database) + return dataset_database diff --git a/cognee/modules/users/models/DatasetDatabase.py b/cognee/modules/users/models/DatasetDatabase.py index 75e650bcd..b864fb951 100644 --- a/cognee/modules/users/models/DatasetDatabase.py +++ b/cognee/modules/users/models/DatasetDatabase.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone -from sqlalchemy import Column, DateTime, String, UUID, ForeignKey +from sqlalchemy import Column, DateTime, String, UUID, ForeignKey, JSON from cognee.infrastructure.databases.relational import Base @@ -27,8 +27,8 @@ class DatasetDatabase(Base): # TODO: Instead of specifying and forwawrding all these individual fields, consider using a JSON field to store # configuration details for different database types. This would make it more flexible to add new database types # without changing the database schema. - graph_database_username = Column(String, unique=False, nullable=True) - graph_database_password = Column(String, unique=False, nullable=True) + graph_database_connection_info = Column(JSON, unique=False, nullable=True) + vector_database_connection_info = Column(JSON, unique=False, nullable=True) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))