feat: Add ability to handle custom connection resolution to avoid storing security critical data in rel dbx
This commit is contained in:
parent
5f3b776406
commit
69777ef0a5
7 changed files with 103 additions and 11 deletions
|
|
@ -7,6 +7,7 @@ from cognee.base_config import get_base_config
|
|||
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
||||
from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info
|
||||
from cognee.infrastructure.files.storage.config import file_storage_config
|
||||
from cognee.modules.users.methods import get_user
|
||||
|
||||
|
|
@ -108,6 +109,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
# To ensure permissions are enforced properly all datasets will have their own databases
|
||||
dataset_database = await get_or_create_dataset_database(dataset, user)
|
||||
# Ensure that all connection info is resolved properly
|
||||
dataset_database = await resolve_dataset_database_connection_info(dataset_database)
|
||||
|
||||
base_config = get_base_config()
|
||||
data_root_directory = os.path.join(
|
||||
|
|
@ -133,8 +136,12 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
"graph_file_path": os.path.join(
|
||||
databases_directory_path, dataset_database.graph_database_name
|
||||
),
|
||||
"graph_database_username": dataset_database.graph_database_username,
|
||||
"graph_database_password": dataset_database.graph_database_password,
|
||||
"graph_database_username": dataset_database.graph_database_connection_info.get(
|
||||
"graph_database_username", ""
|
||||
),
|
||||
"graph_database_password": dataset_database.graph_database_connection_info.get(
|
||||
"graph_database_password", ""
|
||||
),
|
||||
}
|
||||
|
||||
storage_config = {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from uuid import UUID
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from cognee.modules.users.models.User import User
|
||||
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||
|
||||
|
||||
class DatasetDatabaseHandlerInterface(ABC):
|
||||
|
|
@ -10,7 +11,7 @@ class DatasetDatabaseHandlerInterface(ABC):
|
|||
@abstractmethod
|
||||
async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
|
||||
"""
|
||||
Return a dictionary with connection info for a graph or vector database for the given dataset.
|
||||
Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset.
|
||||
Function can auto handle deploying of the actual database if needed, but is not necessary.
|
||||
Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
|
||||
Needed for Cognee multi-tenant/multi-user and backend access control support.
|
||||
|
|
@ -18,6 +19,10 @@ class DatasetDatabaseHandlerInterface(ABC):
|
|||
Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
|
||||
From which internal mapping of dataset -> database connection info will be done.
|
||||
|
||||
The returned dictionary is stored verbatim in the relational database and is later passed to
|
||||
resolve_dataset_connection_info() at connection time. For safe credential handling, prefer
|
||||
returning only references to secrets or role identifiers, not plaintext credentials.
|
||||
|
||||
Each dataset needs to map to a unique graph or vector database when backend access control is enabled to facilitate a separation of concern for data.
|
||||
|
||||
Args:
|
||||
|
|
@ -28,6 +33,39 @@ class DatasetDatabaseHandlerInterface(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def resolve_dataset_connection_info(
|
||||
cls, dataset_database: DatasetDatabase
|
||||
) -> DatasetDatabase:
|
||||
"""
|
||||
Resolve runtime connection details for a dataset’s backing graph/vector database.
|
||||
Function is intended to be overwritten to implement custom logic for resolving connection info.
|
||||
|
||||
This method is invoked right before the application opens a connection for a given dataset.
|
||||
It receives the DatasetDatabase row that was persisted when create_dataset() ran and must
|
||||
return a modified instance of DatasetDatabase with concrete connection parameters that the client/driver can use.
|
||||
Do not update these new DatasetDatabase values in the relational database to avoid storing secure credentials.
|
||||
|
||||
In case of separate graph and vector database handlers, each handler should implement its own logic for resolving
|
||||
connection info and only change parameters related to its appropriate database, the resolution function will then
|
||||
be called one after another with the updated DatasetDatabase value from the previous function as the input.
|
||||
|
||||
Typical behavior:
|
||||
- If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password
|
||||
or api_url/api_key), return them as-is.
|
||||
- If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM
|
||||
roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider
|
||||
API to obtain short-lived credentials and assemble the final connection DatasetDatabase object.
|
||||
- Do not persist any resolved or decrypted secrets back to the relational database. Return them only
|
||||
to the caller.
|
||||
|
||||
Args:
|
||||
dataset_database: DatasetDatabase row from the relational database
|
||||
Returns:
|
||||
DatasetDatabase: Updated instance with resolved connection info
|
||||
"""
|
||||
return dataset_database
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def delete_dataset(cls, dataset_id: UUID, user: User) -> None:
|
||||
|
|
|
|||
|
|
@ -48,8 +48,10 @@ class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
|||
"graph_database_url": graph_db_url,
|
||||
"graph_database_provider": graph_config.graph_database_provider,
|
||||
"graph_database_key": graph_db_key,
|
||||
"graph_database_username": graph_db_username,
|
||||
"graph_database_password": graph_db_password,
|
||||
"graph_database_connection_info": {
|
||||
"graph_database_username": graph_db_username,
|
||||
"graph_database_password": graph_db_password,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -108,9 +108,11 @@ class Neo4jAuraDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
|||
"graph_database_name": graph_db_name,
|
||||
"graph_database_url": graph_db_url,
|
||||
"graph_database_provider": "neo4j",
|
||||
"graph_database_key": graph_db_key, # TODO: Hashing of keys/passwords in relational DB
|
||||
"graph_database_username": graph_db_username,
|
||||
"graph_database_password": graph_db_password,
|
||||
"graph_database_key": graph_db_key,
|
||||
"graph_database_connection_info": { # TODO: Hashing of keys/passwords in relational DB
|
||||
"graph_database_username": graph_db_username,
|
||||
"graph_database_password": graph_db_password,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
from .get_or_create_dataset_database import get_or_create_dataset_database
|
||||
from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
|
||||
|
||||
|
||||
async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
    """Delegate vector-database connection resolution to the configured handler.

    Looks up the handler registered under the configured
    ``vector_dataset_database_handler`` key and lets its
    ``resolve_dataset_connection_info`` update the given row.
    """
    # Local import — presumably avoids an import cycle at module load; confirm.
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    config = get_vectordb_config()
    handler_entry = supported_dataset_database_handlers[config.vector_dataset_database_handler]
    resolved = await handler_entry["handler_instance"].resolve_dataset_connection_info(
        dataset_database
    )
    return resolved
|
||||
|
||||
|
||||
async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase:
    """Delegate graph-database connection resolution to the configured handler.

    Looks up the handler registered under the configured
    ``graph_dataset_database_handler`` key and lets its
    ``resolve_dataset_connection_info`` update the given row.
    """
    # Local import — presumably avoids an import cycle at module load; confirm.
    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
        supported_dataset_database_handlers,
    )

    config = get_graph_config()
    handler_entry = supported_dataset_database_handlers[config.graph_dataset_database_handler]
    resolved = await handler_entry["handler_instance"].resolve_dataset_connection_info(
        dataset_database
    )
    return resolved
|
||||
|
||||
|
||||
async def resolve_dataset_database_connection_info(
    dataset_database: DatasetDatabase,
) -> DatasetDatabase:
    """Resolve all connection info for the given DatasetDatabase instance.

    Runs the vector-database handler's resolution first, then feeds its
    result to the graph-database handler, so each handler operates on the
    row as updated by the previous one.

    Args:
        dataset_database: DatasetDatabase instance

    Returns:
        DatasetDatabase instance with resolved connection info
    """
    resolved = await _get_vector_db_connection_info(dataset_database)
    resolved = await _get_graph_db_connection_info(resolved)
    return resolved
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey
|
||||
from sqlalchemy import Column, DateTime, String, UUID, ForeignKey, JSON
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
|
|
@ -27,8 +27,8 @@ class DatasetDatabase(Base):
|
|||
# TODO: Instead of specifying and forwarding all these individual fields, consider using a JSON field to store
|
||||
# configuration details for different database types. This would make it more flexible to add new database types
|
||||
# without changing the database schema.
|
||||
graph_database_username = Column(String, unique=False, nullable=True)
|
||||
graph_database_password = Column(String, unique=False, nullable=True)
|
||||
graph_database_connection_info = Column(JSON, unique=False, nullable=True)
|
||||
vector_database_connection_info = Column(JSON, unique=False, nullable=True)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue