diff --git a/.env.template b/.env.template index 61853b983..7defaee09 100644 --- a/.env.template +++ b/.env.template @@ -97,6 +97,8 @@ DB_NAME=cognee_db # Default (local file-based) GRAPH_DATABASE_PROVIDER="kuzu" +# Handler for multi-user access control mode, it handles how should the mapping/creation of separate DBs be handled per Cognee dataset +GRAPH_DATASET_DATABASE_HANDLER="kuzu" # -- To switch to Remote Kuzu uncomment and fill these: ------------------------------------------------------------- #GRAPH_DATABASE_PROVIDER="kuzu" @@ -121,6 +123,8 @@ VECTOR_DB_PROVIDER="lancedb" # Not needed if a cloud vector database is not used VECTOR_DB_URL= VECTOR_DB_KEY= +# Handler for multi-user access control mode, it handles how should the mapping/creation of separate DBs be handled per Cognee dataset +VECTOR_DATASET_DATABASE_HANDLER="lancedb" ################################################################################ # đŸ§© Ontology resolver settings diff --git a/.github/workflows/db_examples_tests.yml b/.github/workflows/db_examples_tests.yml index c58bc48ef..5062982d8 100644 --- a/.github/workflows/db_examples_tests.yml +++ b/.github/workflows/db_examples_tests.yml @@ -61,6 +61,7 @@ jobs: - name: Run Neo4j Example env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} @@ -142,6 +143,7 @@ jobs: - name: Run PGVector Example env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} diff --git a/.github/workflows/distributed_test.yml b/.github/workflows/distributed_test.yml index 57bbb7459..3c9debfdf 100644 --- a/.github/workflows/distributed_test.yml +++ b/.github/workflows/distributed_test.yml @@ -47,6 +47,7 @@ jobs: - name: Run Distributed Cognee (Modal) env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index 676699c2a..520d93689 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -147,6 +147,7 @@ jobs: - name: Run Deduplication Example env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} @@ -211,6 +212,31 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_parallel_databases.py + test-dataset-database-handler: + name: Test dataset database handlers in Cognee + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run dataset databases handler test + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/test_dataset_database_handler.py + 
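For reference, the new handler settings and the access-control toggle introduced above are meant to be combined as in the following minimal sketch. It only uses variables that appear in this change; as in the new `test_dataset_database_handler.py`, the values have to be present in the process environment before `cognee` is imported so the settings classes pick them up:

```python
import os

# Turn multi-user access control mode on (or set "false" to opt out, as the CI jobs above do).
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "true"

# Choose which handlers map/create a separate database per Cognee dataset.
# "kuzu" and "lancedb" are the built-in defaults registered in
# supported_dataset_database_handlers and must match the configured providers.
os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "kuzu"
os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "lancedb"

import cognee  # noqa: E402 -- imported after the environment is configured on purpose
```
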
test-permissions: name: Test permissions with different situations in Cognee runs-on: ubuntu-22.04 diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml index f7cc278cb..a9332cf25 100644 --- a/.github/workflows/examples_tests.yml +++ b/.github/workflows/examples_tests.yml @@ -72,6 +72,7 @@ jobs: - name: Run Descriptive Graph Metrics Example env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} diff --git a/.github/workflows/graph_db_tests.yml b/.github/workflows/graph_db_tests.yml index b07f6232f..e9fd7f4c2 100644 --- a/.github/workflows/graph_db_tests.yml +++ b/.github/workflows/graph_db_tests.yml @@ -78,6 +78,7 @@ jobs: - name: Run default Neo4j env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} diff --git a/.github/workflows/temporal_graph_tests.yml b/.github/workflows/temporal_graph_tests.yml index 8917e432a..60e6fe7ef 100644 --- a/.github/workflows/temporal_graph_tests.yml +++ b/.github/workflows/temporal_graph_tests.yml @@ -72,6 +72,7 @@ jobs: - name: Run Temporal Graph with Neo4j (lancedb + sqlite) env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.OPENAI_MODEL }} LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }} LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -123,6 +124,7 @@ jobs: - name: Run Temporal Graph with Kuzu (postgres + pgvector) env: ENV: dev + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.OPENAI_MODEL }} LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }} LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -189,6 +191,7 @@ jobs: - name: Run Temporal Graph with Neo4j (postgres + pgvector) env: ENV: dev + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.OPENAI_MODEL }} LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }} LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} diff --git a/.github/workflows/vector_db_tests.yml b/.github/workflows/vector_db_tests.yml index 06b58c962..65b70abe5 100644 --- a/.github/workflows/vector_db_tests.yml +++ b/.github/workflows/vector_db_tests.yml @@ -92,6 +92,7 @@ jobs: - name: Run PGVector Tests env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} @@ -127,4 +128,4 @@ jobs: EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - run: uv run python ./cognee/tests/test_lancedb.py \ No newline at end of file + run: uv run python ./cognee/tests/test_lancedb.py diff --git a/.github/workflows/weighted_edges_tests.yml b/.github/workflows/weighted_edges_tests.yml index 2b4a043bf..1c43187ad 100644 --- a/.github/workflows/weighted_edges_tests.yml +++ b/.github/workflows/weighted_edges_tests.yml @@ -94,6 +94,7 @@ jobs: - name: Run Weighted Edges Tests env: + ENABLE_BACKEND_ACCESS_CONTROL: 'false' GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }} GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }} GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }} @@ -165,5 +166,3 @@ jobs: uses: astral-sh/ruff-action@v2 with: args: "format --check cognee/modules/graph/utils/get_graph_from_model.py 
cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py" - - \ No newline at end of file diff --git a/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py b/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py new file mode 100644 index 000000000..e15a98b7c --- /dev/null +++ b/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py @@ -0,0 +1,267 @@ +"""Expand dataset database with json connection field + +Revision ID: 46a6ce2bd2b2 +Revises: 76625596c5c3 +Create Date: 2025-11-25 17:56:28.938931 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "46a6ce2bd2b2" +down_revision: Union[str, None] = "76625596c5c3" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +graph_constraint_name = "dataset_database_graph_database_name_key" +vector_constraint_name = "dataset_database_vector_database_name_key" +TABLE_NAME = "dataset_database" + + +def _get_column(inspector, table, name, schema=None): + for col in inspector.get_columns(table, schema=schema): + if col["name"] == name: + return col + return None + + +def _recreate_table_without_unique_constraint_sqlite(op, insp): + """ + SQLite cannot drop unique constraints on individual columns. We must: + 1. Create a new table without the unique constraints. + 2. Copy data from the old table. + 3. Drop the old table. + 4. Rename the new table. + """ + conn = op.get_bind() + + # Create new table definition (without unique constraints) + op.create_table( + f"{TABLE_NAME}_new", + sa.Column("owner_id", sa.UUID()), + sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False), + sa.Column("vector_database_name", sa.String(), nullable=False), + sa.Column("graph_database_name", sa.String(), nullable=False), + sa.Column("vector_database_provider", sa.String(), nullable=False), + sa.Column("graph_database_provider", sa.String(), nullable=False), + sa.Column("vector_database_url", sa.String()), + sa.Column("graph_database_url", sa.String()), + sa.Column("vector_database_key", sa.String()), + sa.Column("graph_database_key", sa.String()), + sa.Column( + "graph_database_connection_info", + sa.JSON(), + nullable=False, + server_default=sa.text("'{}'"), + ), + sa.Column( + "vector_database_connection_info", + sa.JSON(), + nullable=False, + server_default=sa.text("'{}'"), + ), + sa.Column("created_at", sa.DateTime()), + sa.Column("updated_at", sa.DateTime()), + sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"), + ) + + # Copy data into new table + conn.execute( + sa.text(f""" + INSERT INTO {TABLE_NAME}_new + SELECT + owner_id, + dataset_id, + vector_database_name, + graph_database_name, + vector_database_provider, + graph_database_provider, + vector_database_url, + graph_database_url, + vector_database_key, + graph_database_key, + COALESCE(graph_database_connection_info, '{{}}'), + COALESCE(vector_database_connection_info, '{{}}'), + created_at, + updated_at + FROM {TABLE_NAME} + """) + ) + + # Drop old table + op.drop_table(TABLE_NAME) + + # Rename new table + op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME) + + +def _recreate_table_with_unique_constraint_sqlite(op, insp): + """ + SQLite cannot drop unique constraints on individual columns. We must: + 1. Create a new table without the unique constraints. + 2. 
Copy data from the old table. + 3. Drop the old table. + 4. Rename the new table. + """ + conn = op.get_bind() + + # Create new table definition (without unique constraints) + op.create_table( + f"{TABLE_NAME}_new", + sa.Column("owner_id", sa.UUID()), + sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False), + sa.Column("vector_database_name", sa.String(), nullable=False, unique=True), + sa.Column("graph_database_name", sa.String(), nullable=False, unique=True), + sa.Column("vector_database_provider", sa.String(), nullable=False), + sa.Column("graph_database_provider", sa.String(), nullable=False), + sa.Column("vector_database_url", sa.String()), + sa.Column("graph_database_url", sa.String()), + sa.Column("vector_database_key", sa.String()), + sa.Column("graph_database_key", sa.String()), + sa.Column( + "graph_database_connection_info", + sa.JSON(), + nullable=False, + server_default=sa.text("'{}'"), + ), + sa.Column( + "vector_database_connection_info", + sa.JSON(), + nullable=False, + server_default=sa.text("'{}'"), + ), + sa.Column("created_at", sa.DateTime()), + sa.Column("updated_at", sa.DateTime()), + sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"), + ) + + # Copy data into new table + conn.execute( + sa.text(f""" + INSERT INTO {TABLE_NAME}_new + SELECT + owner_id, + dataset_id, + vector_database_name, + graph_database_name, + vector_database_provider, + graph_database_provider, + vector_database_url, + graph_database_url, + vector_database_key, + graph_database_key, + COALESCE(graph_database_connection_info, '{{}}'), + COALESCE(vector_database_connection_info, '{{}}'), + created_at, + updated_at + FROM {TABLE_NAME} + """) + ) + + # Drop old table + op.drop_table(TABLE_NAME) + + # Rename new table + op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME) + + +def upgrade() -> None: + conn = op.get_bind() + insp = sa.inspect(conn) + + unique_constraints = insp.get_unique_constraints(TABLE_NAME) + + vector_database_connection_info_column = _get_column( + insp, "dataset_database", "vector_database_connection_info" + ) + if not vector_database_connection_info_column: + op.add_column( + "dataset_database", + sa.Column( + "vector_database_connection_info", + sa.JSON(), + unique=False, + nullable=False, + server_default=sa.text("'{}'"), + ), + ) + + graph_database_connection_info_column = _get_column( + insp, "dataset_database", "graph_database_connection_info" + ) + if not graph_database_connection_info_column: + op.add_column( + "dataset_database", + sa.Column( + "graph_database_connection_info", + sa.JSON(), + unique=False, + nullable=False, + server_default=sa.text("'{}'"), + ), + ) + + with op.batch_alter_table("dataset_database", schema=None) as batch_op: + # Drop the unique constraint to make unique=False + graph_constraint_to_drop = None + for uc in unique_constraints: + # Check if the constraint covers ONLY the target column + if uc["name"] == graph_constraint_name: + graph_constraint_to_drop = uc["name"] + break + + vector_constraint_to_drop = None + for uc in unique_constraints: + # Check if the constraint covers ONLY the target column + if uc["name"] == vector_constraint_name: + vector_constraint_to_drop = uc["name"] + break + + if ( + vector_constraint_to_drop + and graph_constraint_to_drop + and op.get_context().dialect.name == "postgresql" + ): + # PostgreSQL + batch_op.drop_constraint(graph_constraint_name, type_="unique") + 
batch_op.drop_constraint(vector_constraint_name, type_="unique") + + if op.get_context().dialect.name == "sqlite": + conn = op.get_bind() + # Fun fact: SQLite has hidden auto indexes for unique constraints that can't be dropped or accessed directly + # So we need to check for them and drop them by recreating the table (altering column also won't work) + result = conn.execute(sa.text("PRAGMA index_list('dataset_database')")) + rows = result.fetchall() + unique_auto_indexes = [row for row in rows if row[3] == "u"] + for row in unique_auto_indexes: + result = conn.execute(sa.text(f"PRAGMA index_info('{row[1]}')")) + index_info = result.fetchall() + if index_info[0][2] == "vector_database_name": + # In case a unique index exists on vector_database_name, drop it and the graph_database_name one + _recreate_table_without_unique_constraint_sqlite(op, insp) + + +def downgrade() -> None: + conn = op.get_bind() + insp = sa.inspect(conn) + + if op.get_context().dialect.name == "sqlite": + _recreate_table_with_unique_constraint_sqlite(op, insp) + elif op.get_context().dialect.name == "postgresql": + with op.batch_alter_table("dataset_database", schema=None) as batch_op: + # Re-add the unique constraint to return to unique=True + batch_op.create_unique_constraint(graph_constraint_name, ["graph_database_name"]) + + with op.batch_alter_table("dataset_database", schema=None) as batch_op: + # Re-add the unique constraint to return to unique=True + batch_op.create_unique_constraint(vector_constraint_name, ["vector_database_name"]) + + op.drop_column("dataset_database", "vector_database_connection_info") + op.drop_column("dataset_database", "graph_database_connection_info") diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index 62e06fc64..6417f34f7 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -4,9 +4,10 @@ from typing import Union from uuid import UUID from cognee.base_config import get_base_config -from cognee.infrastructure.databases.vector.config import get_vectordb_context_config -from cognee.infrastructure.databases.graph.config import get_graph_context_config +from cognee.infrastructure.databases.vector.config import get_vectordb_config +from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.utils import get_or_create_dataset_database +from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info from cognee.infrastructure.files.storage.config import file_storage_config from cognee.modules.users.methods import get_user @@ -16,22 +17,59 @@ vector_db_config = ContextVar("vector_db_config", default=None) graph_db_config = ContextVar("graph_db_config", default=None) session_user = ContextVar("session_user", default=None) -VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"] -GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"] - async def set_session_user_context_variable(user): session_user.set(user) def multi_user_support_possible(): - graph_db_config = get_graph_context_config() - vector_db_config = get_vectordb_context_config() - return ( - graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT - and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT + graph_db_config = get_graph_config() + vector_db_config = get_vectordb_config() + + graph_handler = graph_db_config.graph_dataset_database_handler + vector_handler = vector_db_config.vector_dataset_database_handler + from 
cognee.infrastructure.databases.dataset_database_handler import ( + supported_dataset_database_handlers, ) + if graph_handler not in supported_dataset_database_handlers: + raise EnvironmentError( + "Unsupported graph dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n" + f"Selected graph dataset to database handler: {graph_handler}\n" + f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n" + ) + + if vector_handler not in supported_dataset_database_handlers: + raise EnvironmentError( + "Unsupported vector dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n" + f"Selected vector dataset to database handler: {vector_handler}\n" + f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n" + ) + + if ( + supported_dataset_database_handlers[graph_handler]["handler_provider"] + != graph_db_config.graph_database_provider + ): + raise EnvironmentError( + "The selected graph dataset to database handler does not work with the configured graph database provider. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n" + f"Selected graph database provider: {graph_db_config.graph_database_provider}\n" + f"Selected graph dataset to database handler: {graph_handler}\n" + f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n" + ) + + if ( + supported_dataset_database_handlers[vector_handler]["handler_provider"] + != vector_db_config.vector_db_provider + ): + raise EnvironmentError( + "The selected vector dataset to database handler does not work with the configured vector database provider. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n" + f"Selected vector database provider: {vector_db_config.vector_db_provider}\n" + f"Selected vector dataset to database handler: {vector_handler}\n" + f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n" + ) + + return True + def backend_access_control_enabled(): backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None) @@ -41,12 +79,7 @@ def backend_access_control_enabled(): return multi_user_support_possible() elif backend_access_control.lower() == "true": # If enabled, ensure that the current graph and vector DBs can support it - multi_user_support = multi_user_support_possible() - if not multi_user_support: - raise EnvironmentError( - "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control." 
- ) - return True + return multi_user_support_possible() return False @@ -76,6 +109,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ # To ensure permissions are enforced properly all datasets will have their own databases dataset_database = await get_or_create_dataset_database(dataset, user) + # Ensure that all connection info is resolved properly + dataset_database = await resolve_dataset_database_connection_info(dataset_database) base_config = get_base_config() data_root_directory = os.path.join( @@ -86,6 +121,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ ) # Set vector and graph database configuration based on dataset database information + # TODO: Add better handling of vector and graph config accross Cognee. + # LRU_CACHE takes into account order of inputs, if order of inputs is changed it will be registered as a new DB adapter vector_config = { "vector_db_provider": dataset_database.vector_database_provider, "vector_db_url": dataset_database.vector_database_url, @@ -101,6 +138,14 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ "graph_file_path": os.path.join( databases_directory_path, dataset_database.graph_database_name ), + "graph_database_username": dataset_database.graph_database_connection_info.get( + "graph_database_username", "" + ), + "graph_database_password": dataset_database.graph_database_connection_info.get( + "graph_database_password", "" + ), + "graph_dataset_database_handler": "", + "graph_database_port": "", } storage_config = { diff --git a/cognee/infrastructure/databases/dataset_database_handler/__init__.py b/cognee/infrastructure/databases/dataset_database_handler/__init__.py new file mode 100644 index 000000000..a74017113 --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/__init__.py @@ -0,0 +1,3 @@ +from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface +from .supported_dataset_database_handlers import supported_dataset_database_handlers +from .use_dataset_database_handler import use_dataset_database_handler diff --git a/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py b/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py new file mode 100644 index 000000000..a0b68e497 --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py @@ -0,0 +1,80 @@ +from typing import Optional +from uuid import UUID +from abc import ABC, abstractmethod + +from cognee.modules.users.models.User import User +from cognee.modules.users.models.DatasetDatabase import DatasetDatabase + + +class DatasetDatabaseHandlerInterface(ABC): + @classmethod + @abstractmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + """ + Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset. + Function can auto handle deploying of the actual database if needed, but is not necessary. + Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future. + Needed for Cognee multi-tenant/multi-user and backend access control support. + + Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database. + From which internal mapping of dataset -> database connection info will be done. 
+ + The returned dictionary is stored verbatim in the relational database and is later passed to + resolve_dataset_connection_info() at connection time. For safe credential handling, prefer + returning only references to secrets or role identifiers, not plaintext credentials. + + Each dataset needs to map to a unique graph or vector database when backend access control is enabled to facilitate a separation of concern for data. + + Args: + dataset_id: UUID of the dataset if needed by the database creation logic + user: User object if needed by the database creation logic + Returns: + dict: Connection info for the created graph or vector database instance. + """ + pass + + @classmethod + async def resolve_dataset_connection_info( + cls, dataset_database: DatasetDatabase + ) -> DatasetDatabase: + """ + Resolve runtime connection details for a dataset’s backing graph/vector database. + Function is intended to be overwritten to implement custom logic for resolving connection info. + + This method is invoked right before the application opens a connection for a given dataset. + It receives the DatasetDatabase row that was persisted when create_dataset() ran and must + return a modified instance of DatasetDatabase with concrete connection parameters that the client/driver can use. + Do not update these new DatasetDatabase values in the relational database to avoid storing secure credentials. + + In case of separate graph and vector database handlers, each handler should implement its own logic for resolving + connection info and only change parameters related to its appropriate database, the resolution function will then + be called one after another with the updated DatasetDatabase value from the previous function as the input. + + Typical behavior: + - If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password + or api_url/api_key), return them as-is. + - If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM + roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider + API to obtain short-lived credentials and assemble the final connection DatasetDatabase object. + - Do not persist any resolved or decrypted secrets back to the relational database. Return them only + to the caller. + + Args: + dataset_database: DatasetDatabase row from the relational database + Returns: + DatasetDatabase: Updated instance with resolved connection info + """ + return dataset_database + + @classmethod + @abstractmethod + async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None: + """ + Delete the graph or vector database for the given dataset. + Function should auto handle deleting of the actual database or send a request to the proper service to delete/mark the database as not needed for the given dataset. + Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control. + + Args: + dataset_database: DatasetDatabase row containing connection/resolution info for the graph or vector database to delete. 
+ """ + pass diff --git a/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py b/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py new file mode 100644 index 000000000..225e9732e --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py @@ -0,0 +1,18 @@ +from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDevDatasetDatabaseHandler import ( + Neo4jAuraDevDatasetDatabaseHandler, +) +from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import ( + LanceDBDatasetDatabaseHandler, +) +from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import ( + KuzuDatasetDatabaseHandler, +) + +supported_dataset_database_handlers = { + "neo4j_aura_dev": { + "handler_instance": Neo4jAuraDevDatasetDatabaseHandler, + "handler_provider": "neo4j", + }, + "lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"}, + "kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"}, +} diff --git a/cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py b/cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py new file mode 100644 index 000000000..bca2128ee --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py @@ -0,0 +1,10 @@ +from .supported_dataset_database_handlers import supported_dataset_database_handlers + + +def use_dataset_database_handler( + dataset_database_handler_name, dataset_database_handler, dataset_database_provider +): + supported_dataset_database_handlers[dataset_database_handler_name] = { + "handler_instance": dataset_database_handler, + "handler_provider": dataset_database_provider, + } diff --git a/cognee/infrastructure/databases/graph/config.py b/cognee/infrastructure/databases/graph/config.py index 23687b359..bcf97ebfa 100644 --- a/cognee/infrastructure/databases/graph/config.py +++ b/cognee/infrastructure/databases/graph/config.py @@ -47,6 +47,7 @@ class GraphConfig(BaseSettings): graph_filename: str = "" graph_model: object = KnowledgeGraph graph_topology: object = KnowledgeGraph + graph_dataset_database_handler: str = "kuzu" model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True) # Model validator updates graph_filename and path dynamically after class creation based on current database provider @@ -97,6 +98,7 @@ class GraphConfig(BaseSettings): "graph_model": self.graph_model, "graph_topology": self.graph_topology, "model_config": self.model_config, + "graph_dataset_database_handler": self.graph_dataset_database_handler, } def to_hashable_dict(self) -> dict: @@ -121,6 +123,7 @@ class GraphConfig(BaseSettings): "graph_database_port": self.graph_database_port, "graph_database_key": self.graph_database_key, "graph_file_path": self.graph_file_path, + "graph_dataset_database_handler": self.graph_dataset_database_handler, } diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index 82e3cad6e..c37af2102 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -34,6 +34,7 @@ def create_graph_engine( graph_database_password="", graph_database_port="", graph_database_key="", + graph_dataset_database_handler="", ): """ Create a graph 
engine based on the specified provider type. diff --git a/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py new file mode 100644 index 000000000..edc6d5c39 --- /dev/null +++ b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py @@ -0,0 +1,80 @@ +import os +from uuid import UUID +from typing import Optional + +from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine +from cognee.base_config import get_base_config +from cognee.modules.users.models import User +from cognee.modules.users.models import DatasetDatabase +from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface + + +class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): + """ + Handler for interacting with Kuzu Dataset databases. + """ + + @classmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + """ + Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset. + + Args: + dataset_id: Dataset UUID + user: User object who owns the dataset and is making the request + + Returns: + dict: Connection details for the created Kuzu instance + + """ + from cognee.infrastructure.databases.graph.config import get_graph_config + + graph_config = get_graph_config() + + if graph_config.graph_database_provider != "kuzu": + raise ValueError( + "KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider." + ) + + graph_db_name = f"{dataset_id}.pkl" + graph_db_url = graph_config.graph_database_url + graph_db_key = graph_config.graph_database_key + graph_db_username = graph_config.graph_database_username + graph_db_password = graph_config.graph_database_password + + return { + "graph_database_name": graph_db_name, + "graph_database_url": graph_db_url, + "graph_database_provider": graph_config.graph_database_provider, + "graph_database_key": graph_db_key, + "graph_database_connection_info": { + "graph_database_username": graph_db_username, + "graph_database_password": graph_db_password, + }, + } + + @classmethod + async def delete_dataset(cls, dataset_database: DatasetDatabase): + base_config = get_base_config() + databases_directory_path = os.path.join( + base_config.system_root_directory, "databases", str(dataset_database.owner_id) + ) + graph_file_path = os.path.join( + databases_directory_path, dataset_database.graph_database_name + ) + graph_engine = create_graph_engine( + graph_database_provider=dataset_database.graph_database_provider, + graph_database_url=dataset_database.graph_database_url, + graph_database_name=dataset_database.graph_database_name, + graph_database_key=dataset_database.graph_database_key, + graph_file_path=graph_file_path, + graph_database_username=dataset_database.graph_database_connection_info.get( + "graph_database_username", "" + ), + graph_database_password=dataset_database.graph_database_connection_info.get( + "graph_database_password", "" + ), + graph_dataset_database_handler="", + graph_database_port="", + ) + await graph_engine.delete_graph() diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py new file mode 100644 index 000000000..73f057fa8 --- /dev/null +++ b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py @@ -0,0 +1,167 @@ 
+import os +import asyncio +import requests +import base64 +import hashlib +from uuid import UUID +from typing import Optional +from cryptography.fernet import Fernet + +from cognee.infrastructure.databases.graph import get_graph_config +from cognee.modules.users.models import User, DatasetDatabase +from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface + + +class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): + """ + Handler for a quick development PoC integration of Cognee multi-user and permission mode with Neo4j Aura databases. + This handler creates a new Neo4j Aura instance for each Cognee dataset created. + + Improvements needed to be production ready: + - Secret management for client credentials, currently secrets are encrypted and stored in the Cognee relational database, + a secret manager or a similar system should be used instead. + + Quality of life improvements: + - Allow configuration of different Neo4j Aura plans and regions. + - Requests should be made async, currently a blocking requests library is used. + """ + + @classmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + """ + Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset. + + Args: + dataset_id: Dataset UUID + user: User object who owns the dataset and is making the request + + Returns: + dict: Connection details for the created Neo4j instance + + """ + graph_config = get_graph_config() + + if graph_config.graph_database_provider != "neo4j": + raise ValueError( + "Neo4jAuraDevDatasetDatabaseHandler can only be used with Neo4j graph database provider." + ) + + graph_db_name = f"{dataset_id}" + + # Client credentials and encryption + client_id = os.environ.get("NEO4J_CLIENT_ID", None) + client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None) + tenant_id = os.environ.get("NEO4J_TENANT_ID", None) + encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key") + encryption_key = base64.urlsafe_b64encode( + hashlib.sha256(encryption_env_key.encode()).digest() + ) + cipher = Fernet(encryption_key) + + if client_id is None or client_secret is None or tenant_id is None: + raise ValueError( + "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling." 
+ ) + + # Make the request with HTTP Basic Auth + def get_aura_token(client_id: str, client_secret: str) -> dict: + url = "https://api.neo4j.io/oauth/token" + data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded + + resp = requests.post(url, data=data, auth=(client_id, client_secret)) + resp.raise_for_status() # raises if the request failed + return resp.json() + + resp = get_aura_token(client_id, client_secret) + + url = "https://api.neo4j.io/v1/instances" + + headers = { + "accept": "application/json", + "Authorization": f"Bearer {resp['access_token']}", + "Content-Type": "application/json", + } + + # TODO: Maybe we can allow **kwargs parameter forwarding for cases like these + # Too allow different configurations between datasets + payload = { + "version": "5", + "region": "europe-west1", + "memory": "1GB", + "name": graph_db_name[ + 0:29 + ], # TODO: Find better name to name Neo4j instance within 30 character limit + "type": "professional-db", + "tenant_id": tenant_id, + "cloud_provider": "gcp", + } + + response = requests.post(url, headers=headers, json=payload) + + graph_db_name = "neo4j" # Has to be 'neo4j' for Aura + graph_db_url = response.json()["data"]["connection_url"] + graph_db_key = resp["access_token"] + graph_db_username = response.json()["data"]["username"] + graph_db_password = response.json()["data"]["password"] + + async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict): + # Poll until the instance is running + status_url = f"https://api.neo4j.io/v1/instances/{instance_id}" + status = "" + for attempt in range(30): # Try for up to ~5 minutes + status_resp = requests.get( + status_url, headers=headers + ) # TODO: Use async requests with httpx + status = status_resp.json()["data"]["status"] + if status.lower() == "running": + return + await asyncio.sleep(10) + raise TimeoutError( + f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}" + ) + + instance_id = response.json()["data"]["id"] + await _wait_for_neo4j_instance_provisioning(instance_id, headers) + + encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode()) + encrypted_db_password_string = encrypted_db_password_bytes.decode() + + return { + "graph_database_name": graph_db_name, + "graph_database_url": graph_db_url, + "graph_database_provider": "neo4j", + "graph_database_key": graph_db_key, + "graph_database_connection_info": { + "graph_database_username": graph_db_username, + "graph_database_password": encrypted_db_password_string, + }, + } + + @classmethod + async def resolve_dataset_connection_info( + cls, dataset_database: DatasetDatabase + ) -> DatasetDatabase: + """ + Resolve and decrypt connection info for the Neo4j dataset database. + In this case, decrypt the password stored in the database. + + Args: + dataset_database: DatasetDatabase instance containing encrypted connection info. 
+ """ + encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key") + encryption_key = base64.urlsafe_b64encode( + hashlib.sha256(encryption_env_key.encode()).digest() + ) + cipher = Fernet(encryption_key) + graph_db_password = cipher.decrypt( + dataset_database.graph_database_connection_info["graph_database_password"].encode() + ).decode() + + dataset_database.graph_database_connection_info["graph_database_password"] = ( + graph_db_password + ) + return dataset_database + + @classmethod + async def delete_dataset(cls, dataset_database: DatasetDatabase): + pass diff --git a/cognee/infrastructure/databases/utils/__init__.py b/cognee/infrastructure/databases/utils/__init__.py index 1dfa15640..f31d1e0dc 100644 --- a/cognee/infrastructure/databases/utils/__init__.py +++ b/cognee/infrastructure/databases/utils/__init__.py @@ -1 +1,2 @@ from .get_or_create_dataset_database import get_or_create_dataset_database +from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py index 3684bb100..3d03a699e 100644 --- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py @@ -1,11 +1,9 @@ -import os from uuid import UUID -from typing import Union +from typing import Union, Optional from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from cognee.base_config import get_base_config from cognee.modules.data.methods import create_dataset from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.vector import get_vectordb_config @@ -15,6 +13,53 @@ from cognee.modules.users.models import DatasetDatabase from cognee.modules.users.models import User +async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict: + vector_config = get_vectordb_config() + + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler] + return await handler["handler_instance"].create_dataset(dataset_id, user) + + +async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict: + graph_config = get_graph_config() + + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler] + return await handler["handler_instance"].create_dataset(dataset_id, user) + + +async def _existing_dataset_database( + dataset_id: UUID, + user: User, +) -> Optional[DatasetDatabase]: + """ + Check if a DatasetDatabase row already exists for the given owner + dataset. + Return None if it doesn't exist, return the row if it does. 
+ Args: + dataset_id: + user: + + Returns: + DatasetDatabase or None + """ + db_engine = get_relational_engine() + + async with db_engine.get_async_session() as session: + stmt = select(DatasetDatabase).where( + DatasetDatabase.owner_id == user.id, + DatasetDatabase.dataset_id == dataset_id, + ) + existing: DatasetDatabase = await session.scalar(stmt) + return existing + + async def get_or_create_dataset_database( dataset: Union[str, UUID], user: User, @@ -25,6 +70,8 @@ async def get_or_create_dataset_database( • If the row already exists, it is fetched and returned. • Otherwise a new one is created atomically and returned. + DatasetDatabase row contains connection and provider info for vector and graph databases. + Parameters ---------- user : User @@ -36,59 +83,26 @@ async def get_or_create_dataset_database( dataset_id = await get_unique_dataset_id(dataset, user) - vector_config = get_vectordb_config() - graph_config = get_graph_config() + # If dataset is given as name make sure the dataset is created first + if isinstance(dataset, str): + async with db_engine.get_async_session() as session: + await create_dataset(dataset, user, session) - # Note: for hybrid databases both graph and vector DB name have to be the same - if graph_config.graph_database_provider == "kuzu": - graph_db_name = f"{dataset_id}.pkl" - else: - graph_db_name = f"{dataset_id}" + # If dataset database already exists return it + existing_dataset_database = await _existing_dataset_database(dataset_id, user) + if existing_dataset_database: + return existing_dataset_database - if vector_config.vector_db_provider == "lancedb": - vector_db_name = f"{dataset_id}.lance.db" - else: - vector_db_name = f"{dataset_id}" - - base_config = get_base_config() - databases_directory_path = os.path.join( - base_config.system_root_directory, "databases", str(user.id) - ) - - # Determine vector database URL - if vector_config.vector_db_provider == "lancedb": - vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name) - else: - vector_db_url = vector_config.vector_database_url - - # Determine graph database URL + graph_config_dict = await _get_graph_db_info(dataset_id, user) + vector_config_dict = await _get_vector_db_info(dataset_id, user) async with db_engine.get_async_session() as session: - # Create dataset if it doesn't exist - if isinstance(dataset, str): - dataset = await create_dataset(dataset, user, session) - - # Try to fetch an existing row first - stmt = select(DatasetDatabase).where( - DatasetDatabase.owner_id == user.id, - DatasetDatabase.dataset_id == dataset_id, - ) - existing: DatasetDatabase = await session.scalar(stmt) - if existing: - return existing - # If there are no existing rows build a new row record = DatasetDatabase( owner_id=user.id, dataset_id=dataset_id, - vector_database_name=vector_db_name, - graph_database_name=graph_db_name, - vector_database_provider=vector_config.vector_db_provider, - graph_database_provider=graph_config.graph_database_provider, - vector_database_url=vector_db_url, - graph_database_url=graph_config.graph_database_url, - vector_database_key=vector_config.vector_db_key, - graph_database_key=graph_config.graph_database_key, + **graph_config_dict, # Unpack graph db config + **vector_config_dict, # Unpack vector db config ) try: diff --git a/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py new file mode 100644 index 000000000..4d8c19403 --- /dev/null 
+++ b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py @@ -0,0 +1,42 @@ +from cognee.infrastructure.databases.vector import get_vectordb_config +from cognee.infrastructure.databases.graph.config import get_graph_config +from cognee.modules.users.models.DatasetDatabase import DatasetDatabase + + +async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: + vector_config = get_vectordb_config() + + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler] + return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) + + +async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: + graph_config = get_graph_config() + + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler] + return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) + + +async def resolve_dataset_database_connection_info( + dataset_database: DatasetDatabase, +) -> DatasetDatabase: + """ + Resolve the connection info for the given DatasetDatabase instance. + Resolve both vector and graph database connection info and return the updated DatasetDatabase instance. + + Args: + dataset_database: DatasetDatabase instance + Returns: + DatasetDatabase instance with resolved connection info + """ + dataset_database = await _get_vector_db_connection_info(dataset_database) + dataset_database = await _get_graph_db_connection_info(dataset_database) + return dataset_database diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py index 7d28f1668..86b2a0fce 100644 --- a/cognee/infrastructure/databases/vector/config.py +++ b/cognee/infrastructure/databases/vector/config.py @@ -28,6 +28,7 @@ class VectorConfig(BaseSettings): vector_db_name: str = "" vector_db_key: str = "" vector_db_provider: str = "lancedb" + vector_dataset_database_handler: str = "lancedb" model_config = SettingsConfigDict(env_file=".env", extra="allow") @@ -63,6 +64,7 @@ class VectorConfig(BaseSettings): "vector_db_name": self.vector_db_name, "vector_db_key": self.vector_db_key, "vector_db_provider": self.vector_db_provider, + "vector_dataset_database_handler": self.vector_dataset_database_handler, } diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index b182f084b..02e01e288 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -12,6 +12,7 @@ def create_vector_engine( vector_db_name: str, vector_db_port: str = "", vector_db_key: str = "", + vector_dataset_database_handler: str = "", ): """ Create a vector database engine based on the specified provider. 
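To make the resolution hook above concrete, here is a minimal sketch of a third-party handler following the pattern the interface docstrings recommend: `create_dataset()` persists only a secret reference, and `resolve_dataset_connection_info()` turns it into a usable credential at connection time. The class name, the `GRAPH_DB_PASSWORD` reference, and the use of `os.environ` as a stand-in secret store are illustrative assumptions, not part of this change:

```python
import os
from typing import Optional
from uuid import UUID

from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
from cognee.modules.users.models import DatasetDatabase, User


class EnvSecretGraphDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """Hypothetical graph handler that never stores plaintext credentials."""

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        # Persist only connection metadata and a *reference* to the secret.
        return {
            "graph_database_name": f"{dataset_id}",
            "graph_database_provider": "neo4j",
            "graph_database_url": os.environ.get("GRAPH_DATABASE_URL", ""),
            "graph_database_connection_info": {
                "graph_database_username": "neo4j",
                # Name of the secret, not its value.
                "graph_database_password_ref": "GRAPH_DB_PASSWORD",
            },
        }

    @classmethod
    async def resolve_dataset_connection_info(
        cls, dataset_database: DatasetDatabase
    ) -> DatasetDatabase:
        info = dataset_database.graph_database_connection_info
        # Resolve the reference just in time; the resolved value is returned to the
        # caller only and must not be written back to the relational database.
        info["graph_database_password"] = os.environ[info["graph_database_password_ref"]]
        return dataset_database

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None:
        # Deprovisioning the remote instance is out of scope for this sketch.
        pass
```

Registered via `use_dataset_database_handler("env_secret_neo4j", EnvSecretGraphDatasetDatabaseHandler, "neo4j")`, such a handler would then be selected by setting `GRAPH_DATASET_DATABASE_HANDLER=env_secret_neo4j`.
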
diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py new file mode 100644 index 000000000..f165a7ea4 --- /dev/null +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py @@ -0,0 +1,49 @@ +import os +from uuid import UUID +from typing import Optional + +from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine +from cognee.modules.users.models import User +from cognee.modules.users.models import DatasetDatabase +from cognee.base_config import get_base_config +from cognee.infrastructure.databases.vector import get_vectordb_config +from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface + + +class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): + """ + Handler for interacting with LanceDB Dataset databases. + """ + + @classmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + vector_config = get_vectordb_config() + base_config = get_base_config() + + if vector_config.vector_db_provider != "lancedb": + raise ValueError( + "LanceDBDatasetDatabaseHandler can only be used with LanceDB vector database provider." + ) + + databases_directory_path = os.path.join( + base_config.system_root_directory, "databases", str(user.id) + ) + + vector_db_name = f"{dataset_id}.lance.db" + + return { + "vector_database_provider": vector_config.vector_db_provider, + "vector_database_url": os.path.join(databases_directory_path, vector_db_name), + "vector_database_key": vector_config.vector_db_key, + "vector_database_name": vector_db_name, + } + + @classmethod + async def delete_dataset(cls, dataset_database: DatasetDatabase): + vector_engine = create_vector_engine( + vector_db_provider=dataset_database.vector_database_provider, + vector_db_url=dataset_database.vector_database_url, + vector_db_key=dataset_database.vector_database_key, + vector_db_name=dataset_database.vector_database_name, + ) + await vector_engine.prune() diff --git a/cognee/infrastructure/databases/vector/vector_db_interface.py b/cognee/infrastructure/databases/vector/vector_db_interface.py index 3a3df62eb..12ace1a6c 100644 --- a/cognee/infrastructure/databases/vector/vector_db_interface.py +++ b/cognee/infrastructure/databases/vector/vector_db_interface.py @@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any from abc import abstractmethod from cognee.infrastructure.engine import DataPoint from .models.PayloadSchema import PayloadSchema +from uuid import UUID +from cognee.modules.users.models import User class VectorDBInterface(Protocol): @@ -217,3 +219,36 @@ class VectorDBInterface(Protocol): - Any: The schema object suitable for this vector database """ return model_type + + @classmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + """ + Return a dictionary with connection info for a vector database for the given dataset. + Function can auto handle deploying of the actual database if needed, but is not necessary. + Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future. + Needed for Cognee multi-tenant/multi-user and backend access control support. + + Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database. 
+ From which internal mapping of dataset -> database connection info will be done. + + Each dataset needs to map to a unique vector database when backend access control is enabled to facilitate a separation of concern for data. + + Args: + dataset_id: UUID of the dataset if needed by the database creation logic + user: User object if needed by the database creation logic + Returns: + dict: Connection info for the created vector database instance. + """ + pass + + async def delete_dataset(self, dataset_id: UUID, user: User) -> None: + """ + Delete the vector database for the given dataset. + Function should auto handle deleting of the actual database or send a request to the proper service to delete the database. + Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control. + + Args: + dataset_id: UUID of the dataset + user: User object + """ + pass diff --git a/cognee/modules/data/deletion/prune_system.py b/cognee/modules/data/deletion/prune_system.py index a1b60988f..b43cab1f7 100644 --- a/cognee/modules/data/deletion/prune_system.py +++ b/cognee/modules/data/deletion/prune_system.py @@ -1,17 +1,82 @@ +from sqlalchemy.exc import OperationalError + +from cognee.infrastructure.databases.exceptions import EntityNotFoundError +from cognee.context_global_variables import backend_access_control_enabled from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.infrastructure.databases.vector.config import get_vectordb_config +from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.shared.cache import delete_cache +from cognee.modules.users.models import DatasetDatabase +from cognee.shared.logging_utils import get_logger + +logger = get_logger() + + +async def prune_graph_databases(): + async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict: + graph_config = get_graph_config() + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler] + return await handler["handler_instance"].delete_dataset(dataset_database) + + db_engine = get_relational_engine() + try: + data = await db_engine.get_all_data_from_table("dataset_database") + # Go through each dataset database and delete the graph database + for data_item in data: + await _prune_graph_db(data_item) + except (OperationalError, EntityNotFoundError) as e: + logger.debug( + "Skipping pruning of graph DB. 
Error when accessing dataset_database table: %s", + e, + ) + return + + +async def prune_vector_databases(): + async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict: + vector_config = get_vectordb_config() + + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler] + return await handler["handler_instance"].delete_dataset(dataset_database) + + db_engine = get_relational_engine() + try: + data = await db_engine.get_all_data_from_table("dataset_database") + # Go through each dataset database and delete the vector database + for data_item in data: + await _prune_vector_db(data_item) + except (OperationalError, EntityNotFoundError) as e: + logger.debug( + "Skipping pruning of vector DB. Error when accessing dataset_database table: %s", + e, + ) + return async def prune_system(graph=True, vector=True, metadata=True, cache=True): - if graph: + # Note: prune system should not be available through the API, it has no permission checks and will + # delete all graph and vector databases if called. It should only be used in development or testing environments. + if graph and not backend_access_control_enabled(): graph_engine = await get_graph_engine() await graph_engine.delete_graph() + elif graph and backend_access_control_enabled(): + await prune_graph_databases() - if vector: + if vector and not backend_access_control_enabled(): vector_engine = get_vector_engine() await vector_engine.prune() + elif vector and backend_access_control_enabled(): + await prune_vector_databases() if metadata: db_engine = get_relational_engine() diff --git a/cognee/modules/users/methods/get_authenticated_user.py b/cognee/modules/users/methods/get_authenticated_user.py index d6d701737..7dc721d7e 100644 --- a/cognee/modules/users/methods/get_authenticated_user.py +++ b/cognee/modules/users/methods/get_authenticated_user.py @@ -12,8 +12,8 @@ logger = get_logger("get_authenticated_user") # Check environment variable to determine authentication requirement REQUIRE_AUTHENTICATION = ( - os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true" - or backend_access_control_enabled() + os.getenv("REQUIRE_AUTHENTICATION", "true").lower() == "true" + or os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", "true").lower() == "true" ) fastapi_users = get_fastapi_users() diff --git a/cognee/modules/users/models/DatasetDatabase.py b/cognee/modules/users/models/DatasetDatabase.py index 25d610ab9..15964f032 100644 --- a/cognee/modules/users/models/DatasetDatabase.py +++ b/cognee/modules/users/models/DatasetDatabase.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone -from sqlalchemy import Column, DateTime, String, UUID, ForeignKey +from sqlalchemy import Column, DateTime, String, UUID, ForeignKey, JSON, text from cognee.infrastructure.databases.relational import Base @@ -12,8 +12,8 @@ class DatasetDatabase(Base): UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True ) - vector_database_name = Column(String, unique=True, nullable=False) - graph_database_name = Column(String, unique=True, nullable=False) + vector_database_name = Column(String, unique=False, nullable=False) + graph_database_name = Column(String, unique=False, nullable=False) vector_database_provider = Column(String, unique=False, nullable=False) graph_database_provider = Column(String, unique=False, nullable=False) @@ -24,5 +24,14 @@ 
diff --git a/cognee/shared/logging_utils.py b/cognee/shared/logging_utils.py
index e8efde72c..70a0bd37e 100644
--- a/cognee/shared/logging_utils.py
+++ b/cognee/shared/logging_utils.py
@@ -534,6 +534,10 @@ def setup_logging(log_level=None, name=None):
     # Get a configured logger and log system information
     logger = structlog.get_logger(name if name else __name__)
 
+    logger.warning(
+        "From version 0.5.0 onwards, Cognee will run with multi-user access control mode set to on by default. Data isolation between different users and datasets will be enforced, and data created before multi-user access control mode was turned on won't be accessible by default. To disable multi-user access control mode and regain access to old data, set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false before starting Cognee. For more information, please refer to the Cognee documentation."
+    )
+
     if logs_dir is not None:
         logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)
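For anyone following the warning's opt-out, the flag is an ordinary environment variable and has to be in place before Cognee starts; a minimal sketch (setting it in .env or in the shell works just as well):

    import os

    # Opt out of the new multi-user access control default and restore access to pre-0.5.0 data.
    # Must be set before cognee is imported, like the handler variables in the test below.
    os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"

    import cognee  # noqa: E402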
diff --git a/cognee/tests/test_dataset_database_handler.py b/cognee/tests/test_dataset_database_handler.py
new file mode 100644
index 000000000..be1b249d2
--- /dev/null
+++ b/cognee/tests/test_dataset_database_handler.py
@@ -0,0 +1,135 @@
+import asyncio
+import os
+
+# Set custom dataset database handler environment variables
+os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "custom_lancedb_handler"
+os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "custom_kuzu_handler"
+
+import cognee
+from cognee.modules.users.methods import get_default_user
+from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
+from cognee.shared.logging_utils import setup_logging, ERROR
+from cognee.api.v1.search import SearchType
+
+
+class LanceDBTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
+    @classmethod
+    async def create_dataset(cls, dataset_id, user):
+        import pathlib
+
+        cognee_directory_path = str(
+            pathlib.Path(
+                os.path.join(
+                    pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
+                )
+            ).resolve()
+        )
+        databases_directory_path = os.path.join(cognee_directory_path, "databases", str(user.id))
+        os.makedirs(databases_directory_path, exist_ok=True)
+
+        vector_db_name = "test.lance.db"
+
+        return {
+            "vector_database_name": vector_db_name,
+            "vector_database_url": os.path.join(databases_directory_path, vector_db_name),
+            "vector_database_provider": "lancedb",
+        }
+
+
+class KuzuTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
+    @classmethod
+    async def create_dataset(cls, dataset_id, user):
+        databases_directory_path = os.path.join("databases", str(user.id))
+        os.makedirs(databases_directory_path, exist_ok=True)
+
+        graph_db_name = "test.kuzu"
+        return {
+            "graph_database_name": graph_db_name,
+            "graph_database_url": os.path.join(databases_directory_path, graph_db_name),
+            "graph_database_provider": "kuzu",
+        }
+
+
+async def main():
+    import pathlib
+
+    data_directory_path = str(
+        pathlib.Path(
+            os.path.join(
+                pathlib.Path(__file__).parent, ".data_storage/test_dataset_database_handler"
+            )
+        ).resolve()
+    )
+    cognee.config.data_root_directory(data_directory_path)
+    cognee_directory_path = str(
+        pathlib.Path(
+            os.path.join(
+                pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler"
+            )
+        ).resolve()
+    )
+    cognee.config.system_root_directory(cognee_directory_path)
+
+    # Add custom dataset database handlers
+    from cognee.infrastructure.databases.dataset_database_handler.use_dataset_database_handler import (
+        use_dataset_database_handler,
+    )
+
+    use_dataset_database_handler(
+        "custom_lancedb_handler", LanceDBTestDatasetDatabaseHandler, "lancedb"
+    )
+    use_dataset_database_handler("custom_kuzu_handler", KuzuTestDatasetDatabaseHandler, "kuzu")
+
+    # Create a clean slate for cognee -- reset data and system state
+    print("Resetting cognee data...")
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+    print("Data reset complete.\n")
+
+    # cognee knowledge graph will be created based on this text
+    text = """
+    Natural language processing (NLP) is an interdisciplinary
+    subfield of computer science and information retrieval.
+    """
+
+    print("Adding text to cognee:")
+    print(text.strip())
+
+    # Add the text, and make it available for cognify
+    await cognee.add(text)
+    print("Text added successfully.\n")
+
+    # Use LLMs and cognee to create knowledge graph
+    await cognee.cognify()
+    print("Cognify process complete.\n")
+
+    query_text = "Tell me about NLP"
+    print(f"Searching cognee for insights with query: '{query_text}'")
+    # Query cognee for insights on the added text
+    search_results = await cognee.search(
+        query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
+    )
+
+    print("Search results:")
+    # Display results
+    for result_text in search_results:
+        print(result_text)
+
+    default_user = await get_default_user()
+    # Assert that the custom database files were created based on the custom dataset database handlers
+    assert os.path.exists(
+        os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.kuzu")
+    ), "Graph database file not found."
+    assert os.path.exists(
+        os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.lance.db")
+    ), "Vector database file not found."
+
+
+if __name__ == "__main__":
+    logger = setup_logging(log_level=ERROR)
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main())
+    finally:
+        loop.run_until_complete(loop.shutdown_asyncgens())
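The test registers local, file-based handlers; the same registration path should also cover deployments that provision databases through an external service. The sketch below is hypothetical: the RemoteVectorHandler class, its URL scheme, and the service call it stands in for are illustrative, while the environment variable, the base class, and the use_dataset_database_handler call mirror the test above.

    import os

    # Select the custom handler before cognee is imported.
    os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "managed_vector_handler"

    from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
    from cognee.infrastructure.databases.dataset_database_handler.use_dataset_database_handler import (
        use_dataset_database_handler,
    )


    class RemoteVectorHandler(DatasetDatabaseHandlerInterface):
        @classmethod
        async def create_dataset(cls, dataset_id, user):
            # Stand-in for a call to a managed vector database service that provisions
            # one database per dataset and returns its endpoint.
            database_url = f"https://vectors.example.com/{user.id}/{dataset_id}"
            return {
                "vector_database_name": str(dataset_id),
                "vector_database_url": database_url,
                "vector_database_provider": "lancedb",
            }


    use_dataset_database_handler("managed_vector_handler", RemoteVectorHandler, "lancedb")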