diff --git a/.env.template b/.env.template index 61853b983..7defaee09 100644 --- a/.env.template +++ b/.env.template @@ -97,6 +97,8 @@ DB_NAME=cognee_db # Default (local file-based) GRAPH_DATABASE_PROVIDER="kuzu" +# Handler for multi-user access control mode; determines how separate databases are mapped/created for each Cognee dataset +GRAPH_DATASET_DATABASE_HANDLER="kuzu" # -- To switch to Remote Kuzu uncomment and fill these: ------------------------------------------------------------- #GRAPH_DATABASE_PROVIDER="kuzu" @@ -121,6 +123,8 @@ VECTOR_DB_PROVIDER="lancedb" # Not needed if a cloud vector database is not used VECTOR_DB_URL= VECTOR_DB_KEY= +# Handler for multi-user access control mode; determines how separate databases are mapped/created for each Cognee dataset +VECTOR_DATASET_DATABASE_HANDLER="lancedb" ################################################################################ # đŸ§© Ontology resolver settings diff --git a/.github/workflows/db_examples_tests.yml b/.github/workflows/db_examples_tests.yml index c58bc48ef..5062982d8 100644 --- a/.github/workflows/db_examples_tests.yml +++ b/.github/workflows/db_examples_tests.yml @@ -61,6 +61,7 @@ jobs: - name: Run Neo4j Example env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} @@ -142,6 +143,7 @@ jobs: - name: Run PGVector Example env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} diff --git a/.github/workflows/distributed_test.yml b/.github/workflows/distributed_test.yml index 57bbb7459..3c9debfdf 100644 --- a/.github/workflows/distributed_test.yml +++ b/.github/workflows/distributed_test.yml @@ -47,6 +47,7 @@ jobs: - name: Run Distributed Cognee (Modal) env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index 676699c2a..cb69e9ef6 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -147,6 +147,7 @@ jobs: - name: Run Deduplication Example env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} # Test needs OpenAI endpoint to handle multimedia OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} @@ -211,6 +212,31 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_parallel_databases.py + test-dataset-database-handler: + name: Test dataset database handlers in Cognee + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run dataset database handler test + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/test_dataset_database_handler.py + 
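Editor's note: the workflow jobs above pin ENABLE_BACKEND_ACCESS_CONTROL to 'false', while the new .env.template keys choose which handler creates the per-dataset databases when the mode is on. A minimal sketch of enabling the mode from a script, assuming the variables are exported before cognee reads its configuration (the values mirror the template defaults):

import os

# Mirror the .env.template defaults above; set these before cognee loads its settings.
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "true"
os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "kuzu"      # per-dataset Kuzu graph databases
os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "lancedb"  # per-dataset LanceDB stores

import cognee  # imported only after the environment is configured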
test-permissions: name: Test permissions with different situations in Cognee runs-on: ubuntu-22.04 @@ -556,3 +582,30 @@ jobs: DB_USERNAME: cognee DB_PASSWORD: cognee run: uv run python ./cognee/tests/test_conversation_history.py + + run-pipeline-cache-test: + name: Test Pipeline Caching + runs-on: ubuntu-22.04 + steps: + - name: Check out + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Pipeline Cache Test + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/test_pipeline_cache.py diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml index 1a3d868c4..f8f3e5aa3 100644 --- a/.github/workflows/examples_tests.yml +++ b/.github/workflows/examples_tests.yml @@ -72,6 +72,7 @@ jobs: - name: Run Descriptive Graph Metrics Example env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} diff --git a/.github/workflows/graph_db_tests.yml b/.github/workflows/graph_db_tests.yml index b07f6232f..e9fd7f4c2 100644 --- a/.github/workflows/graph_db_tests.yml +++ b/.github/workflows/graph_db_tests.yml @@ -78,6 +78,7 @@ jobs: - name: Run default Neo4j env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} diff --git a/.github/workflows/temporal_graph_tests.yml b/.github/workflows/temporal_graph_tests.yml index 8917e432a..60e6fe7ef 100644 --- a/.github/workflows/temporal_graph_tests.yml +++ b/.github/workflows/temporal_graph_tests.yml @@ -72,6 +72,7 @@ jobs: - name: Run Temporal Graph with Neo4j (lancedb + sqlite) env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.OPENAI_MODEL }} LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }} LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -123,6 +124,7 @@ jobs: - name: Run Temporal Graph with Kuzu (postgres + pgvector) env: ENV: dev + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.OPENAI_MODEL }} LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }} LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -189,6 +191,7 @@ jobs: - name: Run Temporal Graph with Neo4j (postgres + pgvector) env: ENV: dev + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.OPENAI_MODEL }} LLM_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }} LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }} diff --git a/.github/workflows/vector_db_tests.yml b/.github/workflows/vector_db_tests.yml index 06b58c962..65b70abe5 100644 --- a/.github/workflows/vector_db_tests.yml +++ b/.github/workflows/vector_db_tests.yml @@ -92,6 +92,7 @@ jobs: - name: Run PGVector Tests env: ENV: 'dev' + ENABLE_BACKEND_ACCESS_CONTROL: 'false' LLM_MODEL: ${{ secrets.LLM_MODEL }} LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} LLM_API_KEY: ${{ secrets.LLM_API_KEY }} @@ -127,4 +128,4 @@ jobs: EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION 
}} - run: uv run python ./cognee/tests/test_lancedb.py \ No newline at end of file + run: uv run python ./cognee/tests/test_lancedb.py diff --git a/.github/workflows/weighted_edges_tests.yml b/.github/workflows/weighted_edges_tests.yml index 2b4a043bf..1c43187ad 100644 --- a/.github/workflows/weighted_edges_tests.yml +++ b/.github/workflows/weighted_edges_tests.yml @@ -94,6 +94,7 @@ jobs: - name: Run Weighted Edges Tests env: + ENABLE_BACKEND_ACCESS_CONTROL: 'false' GRAPH_DATABASE_PROVIDER: ${{ matrix.graph_db_provider }} GRAPH_DATABASE_URL: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-url || '' }} GRAPH_DATABASE_USERNAME: ${{ matrix.graph_db_provider == 'neo4j' && steps.neo4j.outputs.neo4j-username || '' }} @@ -165,5 +166,3 @@ jobs: uses: astral-sh/ruff-action@v2 with: args: "format --check cognee/modules/graph/utils/get_graph_from_model.py cognee/tests/unit/interfaces/graph/test_weighted_edges.py examples/python/weighted_edges_example.py" - - \ No newline at end of file diff --git a/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py b/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py new file mode 100644 index 000000000..25b94a724 --- /dev/null +++ b/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py @@ -0,0 +1,333 @@ +"""Expand dataset database with json connection field + +Revision ID: 46a6ce2bd2b2 +Revises: 76625596c5c3 +Create Date: 2025-11-25 17:56:28.938931 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "46a6ce2bd2b2" +down_revision: Union[str, None] = "76625596c5c3" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +graph_constraint_name = "dataset_database_graph_database_name_key" +vector_constraint_name = "dataset_database_vector_database_name_key" +TABLE_NAME = "dataset_database" + + +def _get_column(inspector, table, name, schema=None): + for col in inspector.get_columns(table, schema=schema): + if col["name"] == name: + return col + return None + + +def _recreate_table_without_unique_constraint_sqlite(op, insp): + """ + SQLite cannot drop unique constraints on individual columns. We must: + 1. Create a new table without the unique constraints. + 2. Copy data from the old table. + 3. Drop the old table. + 4. Rename the new table. 
+ """ + conn = op.get_bind() + + # Create new table definition (without unique constraints) + op.create_table( + f"{TABLE_NAME}_new", + sa.Column("owner_id", sa.UUID()), + sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False), + sa.Column("vector_database_name", sa.String(), nullable=False), + sa.Column("graph_database_name", sa.String(), nullable=False), + sa.Column("vector_database_provider", sa.String(), nullable=False), + sa.Column("graph_database_provider", sa.String(), nullable=False), + sa.Column( + "vector_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="lancedb", + ), + sa.Column( + "graph_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="kuzu", + ), + sa.Column("vector_database_url", sa.String()), + sa.Column("graph_database_url", sa.String()), + sa.Column("vector_database_key", sa.String()), + sa.Column("graph_database_key", sa.String()), + sa.Column( + "graph_database_connection_info", + sa.JSON(), + nullable=False, + server_default=sa.text("'{}'"), + ), + sa.Column( + "vector_database_connection_info", + sa.JSON(), + nullable=False, + server_default=sa.text("'{}'"), + ), + sa.Column("created_at", sa.DateTime()), + sa.Column("updated_at", sa.DateTime()), + sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"), + ) + + # Copy data into new table + conn.execute( + sa.text(f""" + INSERT INTO {TABLE_NAME}_new + SELECT + owner_id, + dataset_id, + vector_database_name, + graph_database_name, + vector_database_provider, + graph_database_provider, + vector_dataset_database_handler, + graph_dataset_database_handler, + vector_database_url, + graph_database_url, + vector_database_key, + graph_database_key, + COALESCE(graph_database_connection_info, '{{}}'), + COALESCE(vector_database_connection_info, '{{}}'), + created_at, + updated_at + FROM {TABLE_NAME} + """) + ) + + # Drop old table + op.drop_table(TABLE_NAME) + + # Rename new table + op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME) + + +def _recreate_table_with_unique_constraint_sqlite(op, insp): + """ + SQLite cannot drop unique constraints on individual columns. We must: + 1. Create a new table without the unique constraints. + 2. Copy data from the old table. + 3. Drop the old table. + 4. Rename the new table. 
+ """ + conn = op.get_bind() + + # Create new table definition (without unique constraints) + op.create_table( + f"{TABLE_NAME}_new", + sa.Column("owner_id", sa.UUID()), + sa.Column("dataset_id", sa.UUID(), primary_key=True, nullable=False), + sa.Column("vector_database_name", sa.String(), nullable=False, unique=True), + sa.Column("graph_database_name", sa.String(), nullable=False, unique=True), + sa.Column("vector_database_provider", sa.String(), nullable=False), + sa.Column("graph_database_provider", sa.String(), nullable=False), + sa.Column( + "vector_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="lancedb", + ), + sa.Column( + "graph_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="kuzu", + ), + sa.Column("vector_database_url", sa.String()), + sa.Column("graph_database_url", sa.String()), + sa.Column("vector_database_key", sa.String()), + sa.Column("graph_database_key", sa.String()), + sa.Column( + "graph_database_connection_info", + sa.JSON(), + nullable=False, + server_default=sa.text("'{}'"), + ), + sa.Column( + "vector_database_connection_info", + sa.JSON(), + nullable=False, + server_default=sa.text("'{}'"), + ), + sa.Column("created_at", sa.DateTime()), + sa.Column("updated_at", sa.DateTime()), + sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["owner_id"], ["principals.id"], ondelete="CASCADE"), + ) + + # Copy data into new table + conn.execute( + sa.text(f""" + INSERT INTO {TABLE_NAME}_new + SELECT + owner_id, + dataset_id, + vector_database_name, + graph_database_name, + vector_database_provider, + graph_database_provider, + vector_dataset_database_handler, + graph_dataset_database_handler, + vector_database_url, + graph_database_url, + vector_database_key, + graph_database_key, + COALESCE(graph_database_connection_info, '{{}}'), + COALESCE(vector_database_connection_info, '{{}}'), + created_at, + updated_at + FROM {TABLE_NAME} + """) + ) + + # Drop old table + op.drop_table(TABLE_NAME) + + # Rename new table + op.rename_table(f"{TABLE_NAME}_new", TABLE_NAME) + + +def upgrade() -> None: + conn = op.get_bind() + insp = sa.inspect(conn) + + unique_constraints = insp.get_unique_constraints(TABLE_NAME) + + vector_database_connection_info_column = _get_column( + insp, "dataset_database", "vector_database_connection_info" + ) + if not vector_database_connection_info_column: + op.add_column( + "dataset_database", + sa.Column( + "vector_database_connection_info", + sa.JSON(), + unique=False, + nullable=False, + server_default=sa.text("'{}'"), + ), + ) + + vector_dataset_database_handler = _get_column( + insp, "dataset_database", "vector_dataset_database_handler" + ) + if not vector_dataset_database_handler: + # Add LanceDB as the default graph dataset database handler + op.add_column( + "dataset_database", + sa.Column( + "vector_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="lancedb", + ), + ) + + graph_database_connection_info_column = _get_column( + insp, "dataset_database", "graph_database_connection_info" + ) + if not graph_database_connection_info_column: + op.add_column( + "dataset_database", + sa.Column( + "graph_database_connection_info", + sa.JSON(), + unique=False, + nullable=False, + server_default=sa.text("'{}'"), + ), + ) + + graph_dataset_database_handler = _get_column( + insp, "dataset_database", "graph_dataset_database_handler" + ) + if not graph_dataset_database_handler: + # 
Add Kuzu as the default graph dataset database handler + op.add_column( + "dataset_database", + sa.Column( + "graph_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="kuzu", + ), + ) + + with op.batch_alter_table("dataset_database", schema=None) as batch_op: + # Drop the unique constraint to make unique=False + graph_constraint_to_drop = None + for uc in unique_constraints: + # Check if the constraint covers ONLY the target column + if uc["name"] == graph_constraint_name: + graph_constraint_to_drop = uc["name"] + break + + vector_constraint_to_drop = None + for uc in unique_constraints: + # Check if the constraint covers ONLY the target column + if uc["name"] == vector_constraint_name: + vector_constraint_to_drop = uc["name"] + break + + if ( + vector_constraint_to_drop + and graph_constraint_to_drop + and op.get_context().dialect.name == "postgresql" + ): + # PostgreSQL + batch_op.drop_constraint(graph_constraint_name, type_="unique") + batch_op.drop_constraint(vector_constraint_name, type_="unique") + + if op.get_context().dialect.name == "sqlite": + conn = op.get_bind() + # Fun fact: SQLite has hidden auto indexes for unique constraints that can't be dropped or accessed directly + # So we need to check for them and drop them by recreating the table (altering column also won't work) + result = conn.execute(sa.text("PRAGMA index_list('dataset_database')")) + rows = result.fetchall() + unique_auto_indexes = [row for row in rows if row[3] == "u"] + for row in unique_auto_indexes: + result = conn.execute(sa.text(f"PRAGMA index_info('{row[1]}')")) + index_info = result.fetchall() + if index_info[0][2] == "vector_database_name": + # In case a unique index exists on vector_database_name, drop it and the graph_database_name one + _recreate_table_without_unique_constraint_sqlite(op, insp) + + +def downgrade() -> None: + conn = op.get_bind() + insp = sa.inspect(conn) + + if op.get_context().dialect.name == "sqlite": + _recreate_table_with_unique_constraint_sqlite(op, insp) + elif op.get_context().dialect.name == "postgresql": + with op.batch_alter_table("dataset_database", schema=None) as batch_op: + # Re-add the unique constraint to return to unique=True + batch_op.create_unique_constraint(graph_constraint_name, ["graph_database_name"]) + + with op.batch_alter_table("dataset_database", schema=None) as batch_op: + # Re-add the unique constraint to return to unique=True + batch_op.create_unique_constraint(vector_constraint_name, ["vector_database_name"]) + + op.drop_column("dataset_database", "vector_database_connection_info") + op.drop_column("dataset_database", "graph_database_connection_info") + op.drop_column("dataset_database", "vector_dataset_database_handler") + op.drop_column("dataset_database", "graph_dataset_database_handler") diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py index a521b316b..1ea4caca4 100644 --- a/cognee/api/v1/add/add.py +++ b/cognee/api/v1/add/add.py @@ -205,6 +205,7 @@ async def add( pipeline_name="add_pipeline", vector_db_config=vector_db_config, graph_db_config=graph_db_config, + use_pipeline_cache=True, incremental_loading=incremental_loading, data_per_batch=data_per_batch, ): diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 8a7c97050..9862edd49 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -20,7 +20,6 @@ from cognee.modules.ontology.get_default_ontology_resolver import ( from cognee.modules.users.models import User from 
cognee.tasks.documents import ( - check_permissions_on_dataset, classify_documents, extract_chunks_from_documents, ) @@ -79,12 +78,11 @@ async def cognify( Processing Pipeline: 1. **Document Classification**: Identifies document types and structures - 2. **Permission Validation**: Ensures user has processing rights - 3. **Text Chunking**: Breaks content into semantically meaningful segments - 4. **Entity Extraction**: Identifies key concepts, people, places, organizations - 5. **Relationship Detection**: Discovers connections between entities - 6. **Graph Construction**: Builds semantic knowledge graph with embeddings - 7. **Content Summarization**: Creates hierarchical summaries for navigation + 2. **Text Chunking**: Breaks content into semantically meaningful segments + 3. **Entity Extraction**: Identifies key concepts, people, places, organizations + 4. **Relationship Detection**: Discovers connections between entities + 5. **Graph Construction**: Builds semantic knowledge graph with embeddings + 6. **Content Summarization**: Creates hierarchical summaries for navigation Graph Model Customization: The `graph_model` parameter allows custom knowledge structures: @@ -239,6 +237,7 @@ async def cognify( vector_db_config=vector_db_config, graph_db_config=graph_db_config, incremental_loading=incremental_loading, + use_pipeline_cache=True, pipeline_name="cognify_pipeline", data_per_batch=data_per_batch, ) @@ -278,7 +277,6 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's default_tasks = [ Task(classify_documents), - Task(check_permissions_on_dataset, user=user, permissions=["write"]), Task( extract_chunks_from_documents, max_chunk_size=chunk_size or get_max_chunk_tokens(), @@ -313,14 +311,13 @@ async def get_temporal_tasks( The pipeline includes: 1. Document classification. - 2. Dataset permission checks (requires "write" access). - 3. Document chunking with a specified or default chunk size. - 4. Event and timestamp extraction from chunks. - 5. Knowledge graph extraction from events. - 6. Batched insertion of data points. + 2. Document chunking with a specified or default chunk size. + 3. Event and timestamp extraction from chunks. + 4. Knowledge graph extraction from events. + 5. Batched insertion of data points. Args: - user (User, optional): The user requesting task execution, used for permission checks. + user (User, optional): The user requesting task execution. chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker. chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default. 
chunks_per_batch (int, optional): Number of chunks to process in a single batch in Cognify @@ -333,7 +330,6 @@ async def get_temporal_tasks( temporal_tasks = [ Task(classify_documents), - Task(check_permissions_on_dataset, user=user, permissions=["write"]), Task( extract_chunks_from_documents, max_chunk_size=chunk_size or get_max_chunk_tokens(), diff --git a/cognee/api/v1/ontologies/ontologies.py b/cognee/api/v1/ontologies/ontologies.py index 130b4a862..2a133bf8a 100644 --- a/cognee/api/v1/ontologies/ontologies.py +++ b/cognee/api/v1/ontologies/ontologies.py @@ -5,6 +5,7 @@ from pathlib import Path from datetime import datetime, timezone from typing import Optional, List from dataclasses import dataclass +from fastapi import UploadFile @dataclass @@ -45,8 +46,10 @@ class OntologyService: json.dump(metadata, f, indent=2) async def upload_ontology( - self, ontology_key: str, file, user, description: Optional[str] = None + self, ontology_key: str, file: UploadFile, user, description: Optional[str] = None ) -> OntologyMetadata: + if not file.filename: + raise ValueError("File must have a filename") if not file.filename.lower().endswith(".owl"): raise ValueError("File must be in .owl format") @@ -57,8 +60,6 @@ class OntologyService: raise ValueError(f"Ontology key '{ontology_key}' already exists") content = await file.read() - if len(content) > 10 * 1024 * 1024: - raise ValueError("File size exceeds 10MB limit") file_path = user_dir / f"{ontology_key}.owl" with open(file_path, "wb") as f: @@ -82,7 +83,11 @@ class OntologyService: ) async def upload_ontologies( - self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None + self, + ontology_key: List[str], + files: List[UploadFile], + user, + descriptions: Optional[List[str]] = None, ) -> List[OntologyMetadata]: """ Upload ontology files with their respective keys. 
@@ -105,47 +110,17 @@ class OntologyService: if len(set(ontology_key)) != len(ontology_key): raise ValueError("Duplicate ontology keys not allowed") - if descriptions and len(descriptions) != len(files): - raise ValueError("Number of descriptions must match number of files") - results = [] - user_dir = self._get_user_dir(str(user.id)) - metadata = self._load_metadata(user_dir) for i, (key, file) in enumerate(zip(ontology_key, files)): - if key in metadata: - raise ValueError(f"Ontology key '{key}' already exists") - - if not file.filename.lower().endswith(".owl"): - raise ValueError(f"File '{file.filename}' must be in .owl format") - - content = await file.read() - if len(content) > 10 * 1024 * 1024: - raise ValueError(f"File '{file.filename}' exceeds 10MB limit") - - file_path = user_dir / f"{key}.owl" - with open(file_path, "wb") as f: - f.write(content) - - ontology_metadata = { - "filename": file.filename, - "size_bytes": len(content), - "uploaded_at": datetime.now(timezone.utc).isoformat(), - "description": descriptions[i] if descriptions else None, - } - metadata[key] = ontology_metadata - results.append( - OntologyMetadata( + await self.upload_ontology( ontology_key=key, - filename=file.filename, - size_bytes=len(content), - uploaded_at=ontology_metadata["uploaded_at"], + file=file, + user=user, description=descriptions[i] if descriptions else None, ) ) - - self._save_metadata(user_dir, metadata) return results def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]: diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index 62e06fc64..6417f34f7 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -4,9 +4,10 @@ from typing import Union from uuid import UUID from cognee.base_config import get_base_config -from cognee.infrastructure.databases.vector.config import get_vectordb_context_config -from cognee.infrastructure.databases.graph.config import get_graph_context_config +from cognee.infrastructure.databases.vector.config import get_vectordb_config +from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.utils import get_or_create_dataset_database +from cognee.infrastructure.databases.utils import resolve_dataset_database_connection_info from cognee.infrastructure.files.storage.config import file_storage_config from cognee.modules.users.methods import get_user @@ -16,22 +17,59 @@ vector_db_config = ContextVar("vector_db_config", default=None) graph_db_config = ContextVar("graph_db_config", default=None) session_user = ContextVar("session_user", default=None) -VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"] -GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"] - async def set_session_user_context_variable(user): session_user.set(user) def multi_user_support_possible(): - graph_db_config = get_graph_context_config() - vector_db_config = get_vectordb_context_config() - return ( - graph_db_config["graph_database_provider"] in GRAPH_DBS_WITH_MULTI_USER_SUPPORT - and vector_db_config["vector_db_provider"] in VECTOR_DBS_WITH_MULTI_USER_SUPPORT + graph_db_config = get_graph_config() + vector_db_config = get_vectordb_config() + + graph_handler = graph_db_config.graph_dataset_database_handler + vector_handler = vector_db_config.vector_dataset_database_handler + from cognee.infrastructure.databases.dataset_database_handler import ( + supported_dataset_database_handlers, ) + if graph_handler not in supported_dataset_database_handlers: + 
raise EnvironmentError( + "Unsupported graph dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n" + f"Selected graph dataset to database handler: {graph_handler}\n" + f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n" + ) + + if vector_handler not in supported_dataset_database_handlers: + raise EnvironmentError( + "Unsupported vector dataset to database handler configured. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n" + f"Selected vector dataset to database handler: {vector_handler}\n" + f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n" + ) + + if ( + supported_dataset_database_handlers[graph_handler]["handler_provider"] + != graph_db_config.graph_database_provider + ): + raise EnvironmentError( + "The selected graph dataset to database handler does not work with the configured graph database provider. Cannot add support for multi-user access control mode. Please use a supported graph dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n" + f"Selected graph database provider: {graph_db_config.graph_database_provider}\n" + f"Selected graph dataset to database handler: {graph_handler}\n" + f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n" + ) + + if ( + supported_dataset_database_handlers[vector_handler]["handler_provider"] + != vector_db_config.vector_db_provider + ): + raise EnvironmentError( + "The selected vector dataset to database handler does not work with the configured vector database provider. Cannot add support for multi-user access control mode. Please use a supported vector dataset to database handler or set the environment variables ENABLE_BACKEND_ACCESS_CONTROL to false to switch off multi-user access control mode.\n" + f"Selected vector database provider: {vector_db_config.vector_db_provider}\n" + f"Selected vector dataset to database handler: {vector_handler}\n" + f"Supported dataset to database handlers: {list(supported_dataset_database_handlers.keys())}\n" + ) + + return True + def backend_access_control_enabled(): backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None) @@ -41,12 +79,7 @@ def backend_access_control_enabled(): return multi_user_support_possible() elif backend_access_control.lower() == "true": # If enabled, ensure that the current graph and vector DBs can support it - multi_user_support = multi_user_support_possible() - if not multi_user_support: - raise EnvironmentError( - "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control." 
- ) - return True + return multi_user_support_possible() return False @@ -76,6 +109,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ # To ensure permissions are enforced properly all datasets will have their own databases dataset_database = await get_or_create_dataset_database(dataset, user) + # Ensure that all connection info is resolved properly + dataset_database = await resolve_dataset_database_connection_info(dataset_database) base_config = get_base_config() data_root_directory = os.path.join( @@ -86,6 +121,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ ) # Set vector and graph database configuration based on dataset database information + # TODO: Add better handling of vector and graph config across Cognee. + # LRU_CACHE takes the order of inputs into account; if the order of inputs changes, it will be registered as a new DB adapter vector_config = { "vector_db_provider": dataset_database.vector_database_provider, "vector_db_url": dataset_database.vector_database_url, @@ -101,6 +138,14 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ "graph_file_path": os.path.join( databases_directory_path, dataset_database.graph_database_name ), + "graph_database_username": dataset_database.graph_database_connection_info.get( + "graph_database_username", "" + ), + "graph_database_password": dataset_database.graph_database_connection_info.get( + "graph_database_password", "" + ), + "graph_dataset_database_handler": "", + "graph_database_port": "", } storage_config = { diff --git a/cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py b/cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py index edac15015..1fbc31c02 100644 --- a/cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +++ b/cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py @@ -8,7 +8,6 @@ from cognee.modules.users.models import User from cognee.shared.data_models import KnowledgeGraph from cognee.shared.utils import send_telemetry from cognee.tasks.documents import ( - check_permissions_on_dataset, classify_documents, extract_chunks_from_documents, ) @@ -31,7 +30,6 @@ async def get_cascade_graph_tasks( cognee_config = get_cognify_config() default_tasks = [ Task(classify_documents), - Task(check_permissions_on_dataset, user=user, permissions=["write"]), Task( extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens() ), # Extract text chunks based on the document type.
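Editor's note: the LRU_CACHE comment above refers to the fact that functools.lru_cache keys on keyword-argument order, so the per-dataset config dicts must always be unpacked in a consistent order. A small illustration, with make_engine standing in (hypothetically) for the cached engine factory:

from functools import lru_cache

@lru_cache(maxsize=None)
def make_engine(url="", name=""):
    # Stand-in for a cached DB adapter factory; each cache miss yields a new object.
    return object()

a = make_engine(url="db", name="x")
b = make_engine(name="x", url="db")
assert a is not b  # same values, different keyword order -> separate cache entries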
diff --git a/cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py b/cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py index fb10c7eed..6a39a67cf 100644 --- a/cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +++ b/cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py @@ -30,8 +30,8 @@ async def get_no_summary_tasks( ontology_file_path=None, ) -> List[Task]: """Returns default tasks without summarization tasks.""" - # Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks) - base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker) + # Get base tasks (0=classify, 1=extract_chunks) + base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker) ontology_adapter = RDFLibOntologyResolver(ontology_file=ontology_file_path) @@ -51,8 +51,8 @@ async def get_just_chunks_tasks( chunk_size: int = None, chunker=TextChunker, user=None ) -> List[Task]: """Returns default tasks with only chunk extraction and data points addition.""" - # Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks) - base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker) + # Get base tasks (0=classify, 1=extract_chunks) + base_tasks = await get_default_tasks_by_indices([0, 1], chunk_size, chunker) add_data_points_task = Task(add_data_points, task_config={"batch_size": 10}) diff --git a/cognee/infrastructure/databases/dataset_database_handler/__init__.py b/cognee/infrastructure/databases/dataset_database_handler/__init__.py new file mode 100644 index 000000000..a74017113 --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/__init__.py @@ -0,0 +1,3 @@ +from .dataset_database_handler_interface import DatasetDatabaseHandlerInterface +from .supported_dataset_database_handlers import supported_dataset_database_handlers +from .use_dataset_database_handler import use_dataset_database_handler diff --git a/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py b/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py new file mode 100644 index 000000000..a0b68e497 --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py @@ -0,0 +1,80 @@ +from typing import Optional +from uuid import UUID +from abc import ABC, abstractmethod + +from cognee.modules.users.models.User import User +from cognee.modules.users.models.DatasetDatabase import DatasetDatabase + + +class DatasetDatabaseHandlerInterface(ABC): + @classmethod + @abstractmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + """ + Return a dictionary with database connection/resolution info for a graph or vector database for the given dataset. + Function can auto handle deploying of the actual database if needed, but is not necessary. + Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future. + Needed for Cognee multi-tenant/multi-user and backend access control support. + + Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database. + From which internal mapping of dataset -> database connection info will be done. 
+ + The returned dictionary is stored verbatim in the relational database and is later passed to + resolve_dataset_connection_info() at connection time. For safe credential handling, prefer + returning only references to secrets or role identifiers, not plaintext credentials. + + Each dataset needs to map to a unique graph or vector database when backend access control is enabled to facilitate a separation of concern for data. + + Args: + dataset_id: UUID of the dataset if needed by the database creation logic + user: User object if needed by the database creation logic + Returns: + dict: Connection info for the created graph or vector database instance. + """ + pass + + @classmethod + async def resolve_dataset_connection_info( + cls, dataset_database: DatasetDatabase + ) -> DatasetDatabase: + """ + Resolve runtime connection details for a dataset’s backing graph/vector database. + Function is intended to be overwritten to implement custom logic for resolving connection info. + + This method is invoked right before the application opens a connection for a given dataset. + It receives the DatasetDatabase row that was persisted when create_dataset() ran and must + return a modified instance of DatasetDatabase with concrete connection parameters that the client/driver can use. + Do not update these new DatasetDatabase values in the relational database to avoid storing secure credentials. + + In case of separate graph and vector database handlers, each handler should implement its own logic for resolving + connection info and only change parameters related to its appropriate database, the resolution function will then + be called one after another with the updated DatasetDatabase value from the previous function as the input. + + Typical behavior: + - If the DatasetDatabase row already contains raw connection fields (e.g., host/port/db/user/password + or api_url/api_key), return them as-is. + - If the row stores only references (e.g., secret IDs, vault paths, cloud resource ARNs/IDs, IAM + roles, SSO tokens), resolve those references by calling the appropriate secret manager or provider + API to obtain short-lived credentials and assemble the final connection DatasetDatabase object. + - Do not persist any resolved or decrypted secrets back to the relational database. Return them only + to the caller. + + Args: + dataset_database: DatasetDatabase row from the relational database + Returns: + DatasetDatabase: Updated instance with resolved connection info + """ + return dataset_database + + @classmethod + @abstractmethod + async def delete_dataset(cls, dataset_database: DatasetDatabase) -> None: + """ + Delete the graph or vector database for the given dataset. + Function should auto handle deleting of the actual database or send a request to the proper service to delete/mark the database as not needed for the given dataset. + Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control. + + Args: + dataset_database: DatasetDatabase row containing connection/resolution info for the graph or vector database to delete. 
+ """ + pass diff --git a/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py b/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py new file mode 100644 index 000000000..225e9732e --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py @@ -0,0 +1,18 @@ +from cognee.infrastructure.databases.graph.neo4j_driver.Neo4jAuraDevDatasetDatabaseHandler import ( + Neo4jAuraDevDatasetDatabaseHandler, +) +from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandler import ( + LanceDBDatasetDatabaseHandler, +) +from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import ( + KuzuDatasetDatabaseHandler, +) + +supported_dataset_database_handlers = { + "neo4j_aura_dev": { + "handler_instance": Neo4jAuraDevDatasetDatabaseHandler, + "handler_provider": "neo4j", + }, + "lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"}, + "kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"}, +} diff --git a/cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py b/cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py new file mode 100644 index 000000000..bca2128ee --- /dev/null +++ b/cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py @@ -0,0 +1,10 @@ +from .supported_dataset_database_handlers import supported_dataset_database_handlers + + +def use_dataset_database_handler( + dataset_database_handler_name, dataset_database_handler, dataset_database_provider +): + supported_dataset_database_handlers[dataset_database_handler_name] = { + "handler_instance": dataset_database_handler, + "handler_provider": dataset_database_provider, + } diff --git a/cognee/infrastructure/databases/graph/config.py b/cognee/infrastructure/databases/graph/config.py index 23687b359..bcf97ebfa 100644 --- a/cognee/infrastructure/databases/graph/config.py +++ b/cognee/infrastructure/databases/graph/config.py @@ -47,6 +47,7 @@ class GraphConfig(BaseSettings): graph_filename: str = "" graph_model: object = KnowledgeGraph graph_topology: object = KnowledgeGraph + graph_dataset_database_handler: str = "kuzu" model_config = SettingsConfigDict(env_file=".env", extra="allow", populate_by_name=True) # Model validator updates graph_filename and path dynamically after class creation based on current database provider @@ -97,6 +98,7 @@ class GraphConfig(BaseSettings): "graph_model": self.graph_model, "graph_topology": self.graph_topology, "model_config": self.model_config, + "graph_dataset_database_handler": self.graph_dataset_database_handler, } def to_hashable_dict(self) -> dict: @@ -121,6 +123,7 @@ class GraphConfig(BaseSettings): "graph_database_port": self.graph_database_port, "graph_database_key": self.graph_database_key, "graph_file_path": self.graph_file_path, + "graph_dataset_database_handler": self.graph_dataset_database_handler, } diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index 82e3cad6e..c37af2102 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -34,6 +34,7 @@ def create_graph_engine( graph_database_password="", graph_database_port="", graph_database_key="", + graph_dataset_database_handler="", ): """ Create a graph 
engine based on the specified provider type. diff --git a/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py new file mode 100644 index 000000000..61ff84870 --- /dev/null +++ b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py @@ -0,0 +1,81 @@ +import os +from uuid import UUID +from typing import Optional + +from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine +from cognee.base_config import get_base_config +from cognee.modules.users.models import User +from cognee.modules.users.models import DatasetDatabase +from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface + + +class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): + """ + Handler for interacting with Kuzu Dataset databases. + """ + + @classmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + """ + Create a new Kuzu instance for the dataset. Return connection info that will be mapped to the dataset. + + Args: + dataset_id: Dataset UUID + user: User object who owns the dataset and is making the request + + Returns: + dict: Connection details for the created Kuzu instance + + """ + from cognee.infrastructure.databases.graph.config import get_graph_config + + graph_config = get_graph_config() + + if graph_config.graph_database_provider != "kuzu": + raise ValueError( + "KuzuDatasetDatabaseHandler can only be used with Kuzu graph database provider." + ) + + graph_db_name = f"{dataset_id}.pkl" + graph_db_url = graph_config.graph_database_url + graph_db_key = graph_config.graph_database_key + graph_db_username = graph_config.graph_database_username + graph_db_password = graph_config.graph_database_password + + return { + "graph_database_name": graph_db_name, + "graph_database_url": graph_db_url, + "graph_database_provider": graph_config.graph_database_provider, + "graph_database_key": graph_db_key, + "graph_dataset_database_handler": "kuzu", + "graph_database_connection_info": { + "graph_database_username": graph_db_username, + "graph_database_password": graph_db_password, + }, + } + + @classmethod + async def delete_dataset(cls, dataset_database: DatasetDatabase): + base_config = get_base_config() + databases_directory_path = os.path.join( + base_config.system_root_directory, "databases", str(dataset_database.owner_id) + ) + graph_file_path = os.path.join( + databases_directory_path, dataset_database.graph_database_name + ) + graph_engine = create_graph_engine( + graph_database_provider=dataset_database.graph_database_provider, + graph_database_url=dataset_database.graph_database_url, + graph_database_name=dataset_database.graph_database_name, + graph_database_key=dataset_database.graph_database_key, + graph_file_path=graph_file_path, + graph_database_username=dataset_database.graph_database_connection_info.get( + "graph_database_username", "" + ), + graph_database_password=dataset_database.graph_database_connection_info.get( + "graph_database_password", "" + ), + graph_dataset_database_handler="", + graph_database_port="", + ) + await graph_engine.delete_graph() diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py new file mode 100644 index 000000000..eb6cbc55a --- /dev/null +++ 
b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py @@ -0,0 +1,168 @@ +import os +import asyncio +import requests +import base64 +import hashlib +from uuid import UUID +from typing import Optional +from cryptography.fernet import Fernet + +from cognee.infrastructure.databases.graph import get_graph_config +from cognee.modules.users.models import User, DatasetDatabase +from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface + + +class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): + """ + Handler for a quick development PoC integration of Cognee multi-user and permission mode with Neo4j Aura databases. + This handler creates a new Neo4j Aura instance for each Cognee dataset created. + + Improvements needed to be production ready: + - Secret management for client credentials, currently secrets are encrypted and stored in the Cognee relational database, + a secret manager or a similar system should be used instead. + + Quality of life improvements: + - Allow configuration of different Neo4j Aura plans and regions. + - Requests should be made async, currently a blocking requests library is used. + """ + + @classmethod + async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict: + """ + Create a new Neo4j Aura instance for the dataset. Return connection info that will be mapped to the dataset. + + Args: + dataset_id: Dataset UUID + user: User object who owns the dataset and is making the request + + Returns: + dict: Connection details for the created Neo4j instance + + """ + graph_config = get_graph_config() + + if graph_config.graph_database_provider != "neo4j": + raise ValueError( + "Neo4jAuraDevDatasetDatabaseHandler can only be used with Neo4j graph database provider." + ) + + graph_db_name = f"{dataset_id}" + + # Client credentials and encryption + client_id = os.environ.get("NEO4J_CLIENT_ID", None) + client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None) + tenant_id = os.environ.get("NEO4J_TENANT_ID", None) + encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key") + encryption_key = base64.urlsafe_b64encode( + hashlib.sha256(encryption_env_key.encode()).digest() + ) + cipher = Fernet(encryption_key) + + if client_id is None or client_secret is None or tenant_id is None: + raise ValueError( + "NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling." 
+ ) + + # Make the request with HTTP Basic Auth + def get_aura_token(client_id: str, client_secret: str) -> dict: + url = "https://api.neo4j.io/oauth/token" + data = {"grant_type": "client_credentials"}  # sent as application/x-www-form-urlencoded + + resp = requests.post(url, data=data, auth=(client_id, client_secret)) + resp.raise_for_status()  # raises if the request failed + return resp.json() + + resp = get_aura_token(client_id, client_secret) + + url = "https://api.neo4j.io/v1/instances" + + headers = { + "accept": "application/json", + "Authorization": f"Bearer {resp['access_token']}", + "Content-Type": "application/json", + } + + # TODO: Maybe we can allow **kwargs parameter forwarding for cases like these + # to allow different configurations between datasets + payload = { + "version": "5", + "region": "europe-west1", + "memory": "1GB", + "name": graph_db_name[ + 0:29 + ],  # TODO: Find a better Neo4j instance name within the 30-character limit + "type": "professional-db", + "tenant_id": tenant_id, + "cloud_provider": "gcp", + } + + response = requests.post(url, headers=headers, json=payload) + + graph_db_name = "neo4j"  # Has to be 'neo4j' for Aura + graph_db_url = response.json()["data"]["connection_url"] + graph_db_key = resp["access_token"] + graph_db_username = response.json()["data"]["username"] + graph_db_password = response.json()["data"]["password"] + + async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict): + # Poll until the instance is running + status_url = f"https://api.neo4j.io/v1/instances/{instance_id}" + status = "" + for attempt in range(30):  # Try for up to ~5 minutes + status_resp = requests.get( + status_url, headers=headers + )  # TODO: Use async requests with httpx + status = status_resp.json()["data"]["status"] + if status.lower() == "running": + return + await asyncio.sleep(10) + raise TimeoutError( + f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}" + ) + + instance_id = response.json()["data"]["id"] + await _wait_for_neo4j_instance_provisioning(instance_id, headers) + + encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode()) + encrypted_db_password_string = encrypted_db_password_bytes.decode() + + return { + "graph_database_name": graph_db_name, + "graph_database_url": graph_db_url, + "graph_database_provider": "neo4j", + "graph_database_key": graph_db_key, + "graph_dataset_database_handler": "neo4j_aura_dev", + "graph_database_connection_info": { + "graph_database_username": graph_db_username, + "graph_database_password": encrypted_db_password_string, + }, + } + + @classmethod + async def resolve_dataset_connection_info( + cls, dataset_database: DatasetDatabase + ) -> DatasetDatabase: + """ + Resolve and decrypt connection info for the Neo4j dataset database. + In this case, decrypt the password stored in the database. + + Args: + dataset_database: DatasetDatabase instance containing encrypted connection info.
+ """ + encryption_env_key = os.environ.get("NEO4J_ENCRYPTION_KEY", "test_key") + encryption_key = base64.urlsafe_b64encode( + hashlib.sha256(encryption_env_key.encode()).digest() + ) + cipher = Fernet(encryption_key) + graph_db_password = cipher.decrypt( + dataset_database.graph_database_connection_info["graph_database_password"].encode() + ).decode() + + dataset_database.graph_database_connection_info["graph_database_password"] = ( + graph_db_password + ) + return dataset_database + + @classmethod + async def delete_dataset(cls, dataset_database: DatasetDatabase): + pass diff --git a/cognee/infrastructure/databases/utils/__init__.py b/cognee/infrastructure/databases/utils/__init__.py index 1dfa15640..f31d1e0dc 100644 --- a/cognee/infrastructure/databases/utils/__init__.py +++ b/cognee/infrastructure/databases/utils/__init__.py @@ -1 +1,2 @@ from .get_or_create_dataset_database import get_or_create_dataset_database +from .resolve_dataset_database_connection_info import resolve_dataset_database_connection_info diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py index 3684bb100..3d03a699e 100644 --- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py @@ -1,11 +1,9 @@ -import os from uuid import UUID -from typing import Union +from typing import Union, Optional from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from cognee.base_config import get_base_config from cognee.modules.data.methods import create_dataset from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.vector import get_vectordb_config @@ -15,6 +13,53 @@ from cognee.modules.users.models import DatasetDatabase from cognee.modules.users.models import User +async def _get_vector_db_info(dataset_id: UUID, user: User) -> dict: + vector_config = get_vectordb_config() + + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler] + return await handler["handler_instance"].create_dataset(dataset_id, user) + + +async def _get_graph_db_info(dataset_id: UUID, user: User) -> dict: + graph_config = get_graph_config() + + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler] + return await handler["handler_instance"].create_dataset(dataset_id, user) + + +async def _existing_dataset_database( + dataset_id: UUID, + user: User, +) -> Optional[DatasetDatabase]: + """ + Check if a DatasetDatabase row already exists for the given owner + dataset. + Return None if it doesn't exist, return the row if it does. 
+ Args: + dataset_id: + user: + + Returns: + DatasetDatabase or None + """ + db_engine = get_relational_engine() + + async with db_engine.get_async_session() as session: + stmt = select(DatasetDatabase).where( + DatasetDatabase.owner_id == user.id, + DatasetDatabase.dataset_id == dataset_id, + ) + existing: DatasetDatabase = await session.scalar(stmt) + return existing + + async def get_or_create_dataset_database( dataset: Union[str, UUID], user: User, @@ -25,6 +70,8 @@ async def get_or_create_dataset_database( • If the row already exists, it is fetched and returned. • Otherwise a new one is created atomically and returned. + DatasetDatabase row contains connection and provider info for vector and graph databases. + Parameters ---------- user : User @@ -36,59 +83,26 @@ async def get_or_create_dataset_database( dataset_id = await get_unique_dataset_id(dataset, user) - vector_config = get_vectordb_config() - graph_config = get_graph_config() + # If dataset is given as name make sure the dataset is created first + if isinstance(dataset, str): + async with db_engine.get_async_session() as session: + await create_dataset(dataset, user, session) - # Note: for hybrid databases both graph and vector DB name have to be the same - if graph_config.graph_database_provider == "kuzu": - graph_db_name = f"{dataset_id}.pkl" - else: - graph_db_name = f"{dataset_id}" + # If dataset database already exists return it + existing_dataset_database = await _existing_dataset_database(dataset_id, user) + if existing_dataset_database: + return existing_dataset_database - if vector_config.vector_db_provider == "lancedb": - vector_db_name = f"{dataset_id}.lance.db" - else: - vector_db_name = f"{dataset_id}" - - base_config = get_base_config() - databases_directory_path = os.path.join( - base_config.system_root_directory, "databases", str(user.id) - ) - - # Determine vector database URL - if vector_config.vector_db_provider == "lancedb": - vector_db_url = os.path.join(databases_directory_path, vector_config.vector_db_name) - else: - vector_db_url = vector_config.vector_database_url - - # Determine graph database URL + graph_config_dict = await _get_graph_db_info(dataset_id, user) + vector_config_dict = await _get_vector_db_info(dataset_id, user) async with db_engine.get_async_session() as session: - # Create dataset if it doesn't exist - if isinstance(dataset, str): - dataset = await create_dataset(dataset, user, session) - - # Try to fetch an existing row first - stmt = select(DatasetDatabase).where( - DatasetDatabase.owner_id == user.id, - DatasetDatabase.dataset_id == dataset_id, - ) - existing: DatasetDatabase = await session.scalar(stmt) - if existing: - return existing - # If there are no existing rows build a new row record = DatasetDatabase( owner_id=user.id, dataset_id=dataset_id, - vector_database_name=vector_db_name, - graph_database_name=graph_db_name, - vector_database_provider=vector_config.vector_db_provider, - graph_database_provider=graph_config.graph_database_provider, - vector_database_url=vector_db_url, - graph_database_url=graph_config.graph_database_url, - vector_database_key=vector_config.vector_db_key, - graph_database_key=graph_config.graph_database_key, + **graph_config_dict, # Unpack graph db config + **vector_config_dict, # Unpack vector db config ) try: diff --git a/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py new file mode 100644 index 000000000..d33169642 --- /dev/null 
+++ b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py @@ -0,0 +1,36 @@ +from cognee.modules.users.models.DatasetDatabase import DatasetDatabase + + +async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler] + return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) + + +async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: + from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( + supported_dataset_database_handlers, + ) + + handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler] + return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) + + +async def resolve_dataset_database_connection_info( + dataset_database: DatasetDatabase, +) -> DatasetDatabase: + """ + Resolve the connection info for the given DatasetDatabase instance. + Resolve both vector and graph database connection info and return the updated DatasetDatabase instance. + + Args: + dataset_database: DatasetDatabase instance + Returns: + DatasetDatabase instance with resolved connection info + """ + dataset_database = await _get_vector_db_connection_info(dataset_database) + dataset_database = await _get_graph_db_connection_info(dataset_database) + return dataset_database diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py index 7d28f1668..86b2a0fce 100644 --- a/cognee/infrastructure/databases/vector/config.py +++ b/cognee/infrastructure/databases/vector/config.py @@ -28,6 +28,7 @@ class VectorConfig(BaseSettings): vector_db_name: str = "" vector_db_key: str = "" vector_db_provider: str = "lancedb" + vector_dataset_database_handler: str = "lancedb" model_config = SettingsConfigDict(env_file=".env", extra="allow") @@ -63,6 +64,7 @@ class VectorConfig(BaseSettings): "vector_db_name": self.vector_db_name, "vector_db_key": self.vector_db_key, "vector_db_provider": self.vector_db_provider, + "vector_dataset_database_handler": self.vector_dataset_database_handler, } diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index b182f084b..02e01e288 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -12,6 +12,7 @@ def create_vector_engine( vector_db_name: str, vector_db_port: str = "", vector_db_key: str = "", + vector_dataset_database_handler: str = "", ): """ Create a vector database engine based on the specified provider. 
diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py
new file mode 100644
index 000000000..e392b7eb8
--- /dev/null
+++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py
@@ -0,0 +1,50 @@
+import os
+from uuid import UUID
+from typing import Optional
+
+from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
+from cognee.modules.users.models import User
+from cognee.modules.users.models import DatasetDatabase
+from cognee.base_config import get_base_config
+from cognee.infrastructure.databases.vector import get_vectordb_config
+from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
+
+
+class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
+    """
+    Handler for managing per-dataset LanceDB databases.
+    """
+
+    @classmethod
+    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
+        vector_config = get_vectordb_config()
+        base_config = get_base_config()
+
+        if vector_config.vector_db_provider != "lancedb":
+            raise ValueError(
+                "LanceDBDatasetDatabaseHandler can only be used with the LanceDB vector database provider."
+            )
+
+        databases_directory_path = os.path.join(
+            base_config.system_root_directory, "databases", str(user.id)
+        )
+
+        vector_db_name = f"{dataset_id}.lance.db"
+
+        return {
+            "vector_database_provider": vector_config.vector_db_provider,
+            "vector_database_url": os.path.join(databases_directory_path, vector_db_name),
+            "vector_database_key": vector_config.vector_db_key,
+            "vector_database_name": vector_db_name,
+            "vector_dataset_database_handler": "lancedb",
+        }
+
+    @classmethod
+    async def delete_dataset(cls, dataset_database: DatasetDatabase):
+        vector_engine = create_vector_engine(
+            vector_db_provider=dataset_database.vector_database_provider,
+            vector_db_url=dataset_database.vector_database_url,
+            vector_db_key=dataset_database.vector_database_key,
+            vector_db_name=dataset_database.vector_database_name,
+        )
+        await vector_engine.prune()
diff --git a/cognee/infrastructure/databases/vector/vector_db_interface.py b/cognee/infrastructure/databases/vector/vector_db_interface.py
index 3a3df62eb..12ace1a6c 100644
--- a/cognee/infrastructure/databases/vector/vector_db_interface.py
+++ b/cognee/infrastructure/databases/vector/vector_db_interface.py
@@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
 from abc import abstractmethod
 from cognee.infrastructure.engine import DataPoint
 from .models.PayloadSchema import PayloadSchema
+from uuid import UUID
+from cognee.modules.users.models import User
 
 
 class VectorDBInterface(Protocol):
@@ -217,3 +219,36 @@ class VectorDBInterface(Protocol):
         - Any: The schema object suitable for this vector database
         """
         return model_type
+
+    @classmethod
+    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
+        """
+        Return a dictionary with connection info for a vector database for the given dataset.
+        The function may also deploy the actual database if needed, but that is not required:
+        providing connection info is sufficient, and this info will be used whenever the given dataset is connected to in the future.
+        Needed for Cognee multi-tenant/multi-user and backend access control support.
+
+        The dictionary returned from this function is used to create a DatasetDatabase row in the relational database.
+        Cognee uses this row to maintain the internal mapping of dataset -> database connection info.
+
+        When backend access control is enabled, each dataset must map to a unique vector database so that each dataset's data stays isolated.
+
+        Args:
+            dataset_id: UUID of the dataset if needed by the database creation logic
+            user: User object if needed by the database creation logic
+        Returns:
+            dict: Connection info for the created vector database instance.
+        """
+        pass
+
+    async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
+        """
+        Delete the vector database for the given dataset.
+        The function should delete the actual database itself, or send a request to the appropriate service to delete it.
+        Needed to maintain per-dataset databases for Cognee multi-tenant/multi-user and backend access control support.
+
+        Args:
+            dataset_id: UUID of the dataset
+            user: User object
+        """
+        pass
diff --git a/cognee/modules/data/deletion/prune_system.py b/cognee/modules/data/deletion/prune_system.py
index a1b60988f..645e1a223 100644
--- a/cognee/modules/data/deletion/prune_system.py
+++ b/cognee/modules/data/deletion/prune_system.py
@@ -1,17 +1,81 @@
+from sqlalchemy.exc import OperationalError
+
+from cognee.infrastructure.databases.exceptions import EntityNotFoundError
+from cognee.context_global_variables import backend_access_control_enabled
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
 from cognee.infrastructure.databases.relational import get_relational_engine
 from cognee.shared.cache import delete_cache
+from cognee.modules.users.models import DatasetDatabase
+from cognee.shared.logging_utils import get_logger
+
+logger = get_logger()
+
+
+async def prune_graph_databases():
+    async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict:
+        from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
+            supported_dataset_database_handlers,
+        )
+
+        handler = supported_dataset_database_handlers[
+            dataset_database.graph_dataset_database_handler
+        ]
+        return await handler["handler_instance"].delete_dataset(dataset_database)
+
+    db_engine = get_relational_engine()
+    try:
+        data = await db_engine.get_all_data_from_table("dataset_database")
+        # Go through each dataset database and delete the graph database
+        for data_item in data:
+            await _prune_graph_db(data_item)
+    except (OperationalError, EntityNotFoundError) as e:
+        logger.debug(
+            "Skipping pruning of graph DB. Error when accessing dataset_database table: %s",
+            e,
+        )
+        return
+
+
+async def prune_vector_databases():
+    async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict:
+        from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
+            supported_dataset_database_handlers,
+        )
+
+        handler = supported_dataset_database_handlers[
+            dataset_database.vector_dataset_database_handler
+        ]
+        return await handler["handler_instance"].delete_dataset(dataset_database)
+
+    db_engine = get_relational_engine()
+    try:
+        data = await db_engine.get_all_data_from_table("dataset_database")
+        # Go through each dataset database and delete the vector database
+        for data_item in data:
+            await _prune_vector_db(data_item)
+    except (OperationalError, EntityNotFoundError) as e:
+        logger.debug(
+            "Skipping pruning of vector DB. 
Error when accessing dataset_database table: %s", + e, + ) + return async def prune_system(graph=True, vector=True, metadata=True, cache=True): - if graph: + # Note: prune system should not be available through the API, it has no permission checks and will + # delete all graph and vector databases if called. It should only be used in development or testing environments. + if graph and not backend_access_control_enabled(): graph_engine = await get_graph_engine() await graph_engine.delete_graph() + elif graph and backend_access_control_enabled(): + await prune_graph_databases() - if vector: + if vector and not backend_access_control_enabled(): vector_engine = get_vector_engine() await vector_engine.prune() + elif vector and backend_access_control_enabled(): + await prune_vector_databases() if metadata: db_engine = get_relational_engine() diff --git a/cognee/modules/memify/memify.py b/cognee/modules/memify/memify.py index 2d9b32a1b..e60eb5a4e 100644 --- a/cognee/modules/memify/memify.py +++ b/cognee/modules/memify/memify.py @@ -12,9 +12,6 @@ from cognee.modules.users.models import User from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import ( resolve_authorized_user_datasets, ) -from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import ( - reset_dataset_pipeline_run_status, -) from cognee.modules.engine.operations.setup import setup from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor from cognee.tasks.memify.extract_subgraph_chunks import extract_subgraph_chunks @@ -97,10 +94,6 @@ async def memify( *enrichment_tasks, ] - await reset_dataset_pipeline_run_status( - authorized_dataset.id, user, pipeline_names=["memify_pipeline"] - ) - # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background) @@ -113,6 +106,7 @@ async def memify( datasets=authorized_dataset.id, vector_db_config=vector_db_config, graph_db_config=graph_db_config, + use_pipeline_cache=False, incremental_loading=False, pipeline_name="memify_pipeline", ) diff --git a/cognee/modules/pipelines/operations/pipeline.py b/cognee/modules/pipelines/operations/pipeline.py index eb0ebe8bd..6641d3a4c 100644 --- a/cognee/modules/pipelines/operations/pipeline.py +++ b/cognee/modules/pipelines/operations/pipeline.py @@ -20,6 +20,9 @@ from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import ( from cognee.modules.pipelines.layers.check_pipeline_run_qualification import ( check_pipeline_run_qualification, ) +from cognee.modules.pipelines.models.PipelineRunInfo import ( + PipelineRunStarted, +) from typing import Any logger = get_logger("cognee.pipeline") @@ -35,6 +38,7 @@ async def run_pipeline( pipeline_name: str = "custom_pipeline", vector_db_config: dict = None, graph_db_config: dict = None, + use_pipeline_cache: bool = False, incremental_loading: bool = False, data_per_batch: int = 20, ): @@ -51,6 +55,7 @@ async def run_pipeline( data=data, pipeline_name=pipeline_name, context={"dataset": dataset}, + use_pipeline_cache=use_pipeline_cache, incremental_loading=incremental_loading, data_per_batch=data_per_batch, ): @@ -64,6 +69,7 @@ async def run_pipeline_per_dataset( data=None, pipeline_name: str = "custom_pipeline", context: dict = None, + use_pipeline_cache=False, incremental_loading=False, data_per_batch: int = 20, ): @@ -77,8 +83,18 @@ async def run_pipeline_per_dataset( if 
process_pipeline_status:
            # If pipeline was already processed or is currently being processed
            # return status information to async generator and finish execution
-            yield process_pipeline_status
-            return
+            if use_pipeline_cache:
+                # If pipeline caching is enabled, do not re-process the dataset
+                yield process_pipeline_status
+                return
+            else:
+                # If pipeline caching is disabled, always yield pipeline-started info and proceed with re-processing
+                yield PipelineRunStarted(
+                    pipeline_run_id=process_pipeline_status.pipeline_run_id,
+                    dataset_id=dataset.id,
+                    dataset_name=dataset.name,
+                    payload=data,
+                )
 
     pipeline_run = run_tasks(
         tasks,
diff --git a/cognee/modules/run_custom_pipeline/run_custom_pipeline.py b/cognee/modules/run_custom_pipeline/run_custom_pipeline.py
index d3df1c060..269238503 100644
--- a/cognee/modules/run_custom_pipeline/run_custom_pipeline.py
+++ b/cognee/modules/run_custom_pipeline/run_custom_pipeline.py
@@ -18,6 +18,8 @@ async def run_custom_pipeline(
     user: User = None,
     vector_db_config: Optional[dict] = None,
     graph_db_config: Optional[dict] = None,
+    use_pipeline_cache: bool = False,
+    incremental_loading: bool = False,
     data_per_batch: int = 20,
     run_in_background: bool = False,
     pipeline_name: str = "custom_pipeline",
@@ -40,6 +42,10 @@ async def run_custom_pipeline(
         user: User context for authentication and data access. Uses default if None.
         vector_db_config: Custom vector database configuration for embeddings storage.
         graph_db_config: Custom graph database configuration for relationship storage.
+        use_pipeline_cache: If True, a pipeline run with the same ID that is currently executing or has already completed won't process the data again.
+            The pipeline ID is created by the generate_pipeline_id function. The pipeline run status can be reset manually with the reset_dataset_pipeline_run_status function.
+        incremental_loading: If True, only new or modified data will be processed to avoid duplication. (Only works when the data is used with the Cognee Python Data model.)
+            The incremental system stores and compares hashes of processed data in the Data model and skips data with the same content hash.
         data_per_batch: Number of data items to be processed in parallel.
         run_in_background: If True, starts processing asynchronously and returns immediately.
             If False, waits for completion before returning.
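A minimal usage sketch of the caching behaviour documented above, mirroring the signatures exercised in cognee/tests/test_pipeline_cache.py; the dataset name, pipeline name, and echo_task body are illustrative placeholders rather than part of this change, and database setup (create_db_and_tables, pruning) is assumed to have happened already.

import asyncio

from cognee.modules.pipelines import run_pipeline
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
    reset_dataset_pipeline_run_status,
)
from cognee.modules.users.methods import get_default_user


async def echo_task(data):
    # Illustrative task body: pass the payload through unchanged.
    return data


async def demo_pipeline_cache():
    user = await get_default_user()
    tasks = [Task(echo_task)]

    # First run executes the tasks and records the run for this dataset.
    runs = []
    async for run_info in run_pipeline(
        tasks=tasks,
        datasets="cache_demo_dataset",
        data=["sample data"],
        user=user,
        pipeline_name="cache_demo_pipeline",
        use_pipeline_cache=True,
    ):
        runs.append(run_info)

    # While use_pipeline_cache=True, an identical second run is skipped.
    # Resetting the run status (or passing use_pipeline_cache=False)
    # makes the pipeline process the dataset again.
    await reset_dataset_pipeline_run_status(
        runs[0].dataset_id, user, pipeline_names=["cache_demo_pipeline"]
    )


if __name__ == "__main__":
    asyncio.run(demo_pipeline_cache())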
@@ -63,7 +69,8 @@ async def run_custom_pipeline( datasets=dataset, vector_db_config=vector_db_config, graph_db_config=graph_db_config, - incremental_loading=False, + use_pipeline_cache=use_pipeline_cache, + incremental_loading=incremental_loading, data_per_batch=data_per_batch, pipeline_name=pipeline_name, ) diff --git a/cognee/modules/users/methods/get_authenticated_user.py b/cognee/modules/users/methods/get_authenticated_user.py index d6d701737..7dc721d7e 100644 --- a/cognee/modules/users/methods/get_authenticated_user.py +++ b/cognee/modules/users/methods/get_authenticated_user.py @@ -12,8 +12,8 @@ logger = get_logger("get_authenticated_user") # Check environment variable to determine authentication requirement REQUIRE_AUTHENTICATION = ( - os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true" - or backend_access_control_enabled() + os.getenv("REQUIRE_AUTHENTICATION", "true").lower() == "true" + or os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", "true").lower() == "true" ) fastapi_users = get_fastapi_users() diff --git a/cognee/modules/users/models/DatasetDatabase.py b/cognee/modules/users/models/DatasetDatabase.py index 25d610ab9..08c4b5311 100644 --- a/cognee/modules/users/models/DatasetDatabase.py +++ b/cognee/modules/users/models/DatasetDatabase.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone -from sqlalchemy import Column, DateTime, String, UUID, ForeignKey +from sqlalchemy import Column, DateTime, String, UUID, ForeignKey, JSON, text from cognee.infrastructure.databases.relational import Base @@ -12,17 +12,29 @@ class DatasetDatabase(Base): UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key=True, index=True ) - vector_database_name = Column(String, unique=True, nullable=False) - graph_database_name = Column(String, unique=True, nullable=False) + vector_database_name = Column(String, unique=False, nullable=False) + graph_database_name = Column(String, unique=False, nullable=False) vector_database_provider = Column(String, unique=False, nullable=False) graph_database_provider = Column(String, unique=False, nullable=False) + graph_dataset_database_handler = Column(String, unique=False, nullable=False) + vector_dataset_database_handler = Column(String, unique=False, nullable=False) + vector_database_url = Column(String, unique=False, nullable=True) graph_database_url = Column(String, unique=False, nullable=True) vector_database_key = Column(String, unique=False, nullable=True) graph_database_key = Column(String, unique=False, nullable=True) + # configuration details for different database types. This would make it more flexible to add new database types + # without changing the database schema. 
+    graph_database_connection_info = Column(
+        JSON, unique=False, nullable=False, server_default=text("'{}'")
+    )
+    vector_database_connection_info = Column(
+        JSON, unique=False, nullable=False, server_default=text("'{}'")
+    )
+
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
     updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
diff --git a/cognee/shared/logging_utils.py b/cognee/shared/logging_utils.py
index e8efde72c..70a0bd37e 100644
--- a/cognee/shared/logging_utils.py
+++ b/cognee/shared/logging_utils.py
@@ -534,6 +534,10 @@ def setup_logging(log_level=None, name=None):
     # Get a configured logger and log system information
     logger = structlog.get_logger(name if name else __name__)
 
+    logger.warning(
+        "From version 0.5.0 onwards, Cognee will run with multi-user access control mode set to on by default. Data isolation between different users and datasets will be enforced, and data created before multi-user access control mode was turned on won't be accessible by default. To disable multi-user access control mode and regain access to old data, set the environment variable ENABLE_BACKEND_ACCESS_CONTROL to false before starting Cognee. For more information, please refer to the Cognee documentation."
+    )
+
     if logs_dir is not None:
         logger.info(f"Log file created at: {log_file_path}", log_file=log_file_path)
diff --git a/cognee/tasks/documents/__init__.py b/cognee/tasks/documents/__init__.py
index f4582fbe0..043625f35 100644
--- a/cognee/tasks/documents/__init__.py
+++ b/cognee/tasks/documents/__init__.py
@@ -1,3 +1,2 @@
 from .classify_documents import classify_documents
 from .extract_chunks_from_documents import extract_chunks_from_documents
-from .check_permissions_on_dataset import check_permissions_on_dataset
diff --git a/cognee/tasks/documents/check_permissions_on_dataset.py b/cognee/tasks/documents/check_permissions_on_dataset.py
deleted file mode 100644
index 01a03de5f..000000000
--- a/cognee/tasks/documents/check_permissions_on_dataset.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from cognee.modules.data.processing.document_types import Document
-from cognee.modules.users.permissions.methods import check_permission_on_dataset
-from typing import List
-
-
-async def check_permissions_on_dataset(
-    documents: List[Document], context: dict, user, permissions
-) -> List[Document]:
-    """
-    Validates a user's permissions on a list of documents.
-
-    Notes:
-    - This function assumes that `check_permission_on_documents` raises an exception if the permission check fails.
-    - It is designed to validate multiple permissions in a sequential manner for the same set of documents.
-    - Ensure that the `Document` and `user` objects conform to the expected structure and interfaces.
- """ - - for permission in permissions: - await check_permission_on_dataset( - user, - permission, - # TODO: pass dataset through argument instead of context - context["dataset"].id, - ) - - return documents diff --git a/cognee/tests/test_dataset_database_handler.py b/cognee/tests/test_dataset_database_handler.py new file mode 100644 index 000000000..e4c9b0177 --- /dev/null +++ b/cognee/tests/test_dataset_database_handler.py @@ -0,0 +1,137 @@ +import asyncio +import os + +# Set custom dataset database handler environment variable +os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "custom_lancedb_handler" +os.environ["GRAPH_DATASET_DATABASE_HANDLER"] = "custom_kuzu_handler" + +import cognee +from cognee.modules.users.methods import get_default_user +from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface +from cognee.shared.logging_utils import setup_logging, ERROR +from cognee.api.v1.search import SearchType + + +class LanceDBTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): + @classmethod + async def create_dataset(cls, dataset_id, user): + import pathlib + + cognee_directory_path = str( + pathlib.Path( + os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler" + ) + ).resolve() + ) + databases_directory_path = os.path.join(cognee_directory_path, "databases", str(user.id)) + os.makedirs(databases_directory_path, exist_ok=True) + + vector_db_name = "test.lance.db" + + return { + "vector_dataset_database_handler": "custom_lancedb_handler", + "vector_database_name": vector_db_name, + "vector_database_url": os.path.join(databases_directory_path, vector_db_name), + "vector_database_provider": "lancedb", + } + + +class KuzuTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): + @classmethod + async def create_dataset(cls, dataset_id, user): + databases_directory_path = os.path.join("databases", str(user.id)) + os.makedirs(databases_directory_path, exist_ok=True) + + graph_db_name = "test.kuzu" + return { + "graph_dataset_database_handler": "custom_kuzu_handler", + "graph_database_name": graph_db_name, + "graph_database_url": os.path.join(databases_directory_path, graph_db_name), + "graph_database_provider": "kuzu", + } + + +async def main(): + import pathlib + + data_directory_path = str( + pathlib.Path( + os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_dataset_database_handler" + ) + ).resolve() + ) + cognee.config.data_root_directory(data_directory_path) + cognee_directory_path = str( + pathlib.Path( + os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_dataset_database_handler" + ) + ).resolve() + ) + cognee.config.system_root_directory(cognee_directory_path) + + # Add custom dataset database handler + from cognee.infrastructure.databases.dataset_database_handler.use_dataset_database_handler import ( + use_dataset_database_handler, + ) + + use_dataset_database_handler( + "custom_lancedb_handler", LanceDBTestDatasetDatabaseHandler, "lancedb" + ) + use_dataset_database_handler("custom_kuzu_handler", KuzuTestDatasetDatabaseHandler, "kuzu") + + # Create a clean slate for cognee -- reset data and system state + print("Resetting cognee data...") + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + print("Data reset complete.\n") + + # cognee knowledge graph will be created based on this text + text = """ + Natural language processing (NLP) is an interdisciplinary + subfield of computer science and information retrieval. 
+ """ + + print("Adding text to cognee:") + print(text.strip()) + + # Add the text, and make it available for cognify + await cognee.add(text) + print("Text added successfully.\n") + + # Use LLMs and cognee to create knowledge graph + await cognee.cognify() + print("Cognify process complete.\n") + + query_text = "Tell me about NLP" + print(f"Searching cognee for insights with query: '{query_text}'") + # Query cognee for insights on the added text + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, query_text=query_text + ) + + print("Search results:") + # Display results + for result_text in search_results: + print(result_text) + + default_user = await get_default_user() + # Assert that the custom database files were created based on the custom dataset database handlers + assert os.path.exists( + os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.kuzu") + ), "Graph database file not found." + assert os.path.exists( + os.path.join(cognee_directory_path, "databases", str(default_user.id), "test.lance.db") + ), "Vector database file not found." + + +if __name__ == "__main__": + logger = setup_logging(log_level=ERROR) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(main()) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) diff --git a/cognee/tests/test_pipeline_cache.py b/cognee/tests/test_pipeline_cache.py new file mode 100644 index 000000000..8cdd6aa3c --- /dev/null +++ b/cognee/tests/test_pipeline_cache.py @@ -0,0 +1,164 @@ +""" +Test suite for the pipeline_cache feature in Cognee pipelines. + +This module tests the behavior of the `pipeline_cache` parameter which controls +whether a pipeline should skip re-execution when it has already been completed +for the same dataset. + +Architecture Overview: +--------------------- +The pipeline_cache mechanism works at the dataset level: +1. When a pipeline runs, it logs its status (INITIATED -> STARTED -> COMPLETED) +2. Before each run, `check_pipeline_run_qualification()` checks the pipeline status +3. If `use_pipeline_cache=True` and status is COMPLETED/STARTED, the pipeline skips +4. If `use_pipeline_cache=False`, the pipeline always re-executes regardless of status +""" + +import pytest + +import cognee +from cognee.modules.pipelines.tasks.task import Task +from cognee.modules.pipelines import run_pipeline +from cognee.modules.users.methods import get_default_user + +from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import ( + reset_dataset_pipeline_run_status, +) +from cognee.infrastructure.databases.relational import create_db_and_tables + + +class ExecutionCounter: + """Helper class to track task execution counts.""" + + def __init__(self): + self.count = 0 + + +async def create_counting_task(data, counter: ExecutionCounter): + """Create a task that increments a counter from the ExecutionCounter instance when executed.""" + counter.count += 1 + return counter + + +class TestPipelineCache: + """Tests for basic pipeline_cache on/off behavior.""" + + @pytest.mark.asyncio + async def test_pipeline_cache_off_allows_reexecution(self): + """ + Test that with use_pipeline_cache=False, the pipeline re-executes + even when it has already completed for the dataset. 
+ + Expected behavior: + - First run: Pipeline executes fully, task runs once + - Second run: Pipeline executes again, task runs again (total: 2 times) + """ + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await create_db_and_tables() + + counter = ExecutionCounter() + user = await get_default_user() + + tasks = [Task(create_counting_task, counter=counter)] + + # First run + pipeline_results_1 = [] + async for result in run_pipeline( + tasks=tasks, + datasets="test_dataset_cache_off", + data=["sample data"], # Data is necessary to trigger processing + user=user, + pipeline_name="test_cache_off_pipeline", + use_pipeline_cache=False, + ): + pipeline_results_1.append(result) + + first_run_count = counter.count + assert first_run_count >= 1, "Task should have executed at least once on first run" + + # Second run with pipeline_cache=False + pipeline_results_2 = [] + async for result in run_pipeline( + tasks=tasks, + datasets="test_dataset_cache_off", + data=["sample data"], # Data is necessary to trigger processing + user=user, + pipeline_name="test_cache_off_pipeline", + use_pipeline_cache=False, + ): + pipeline_results_2.append(result) + + second_run_count = counter.count + assert second_run_count > first_run_count, ( + f"With pipeline_cache=False, task should re-execute. " + f"First run: {first_run_count}, After second run: {second_run_count}" + ) + + @pytest.mark.asyncio + async def test_reset_pipeline_status_allows_reexecution_with_cache(self): + """ + Test that resetting pipeline status allows re-execution even with + pipeline_cache=True. + """ + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await create_db_and_tables() + + counter = ExecutionCounter() + user = await get_default_user() + dataset_name = "reset_status_test" + pipeline_name = "test_reset_pipeline" + + tasks = [Task(create_counting_task, counter=counter)] + + # First run + pipeline_result = [] + async for result in run_pipeline( + tasks=tasks, + datasets=dataset_name, + user=user, + data=["sample data"], # Data is necessary to trigger processing + pipeline_name=pipeline_name, + use_pipeline_cache=True, + ): + pipeline_result.append(result) + + first_run_count = counter.count + assert first_run_count >= 1 + + # Second run without reset - should skip + async for _ in run_pipeline( + tasks=tasks, + datasets=dataset_name, + user=user, + data=["sample data"], # Data is necessary to trigger processing + pipeline_name=pipeline_name, + use_pipeline_cache=True, + ): + pass + + after_second_run = counter.count + assert after_second_run == first_run_count, "Should have skipped due to cache" + + # Reset the pipeline status + await reset_dataset_pipeline_run_status( + pipeline_result[0].dataset_id, user, pipeline_names=[pipeline_name] + ) + + # Third run after reset - should execute + async for _ in run_pipeline( + tasks=tasks, + datasets=dataset_name, + user=user, + data=["sample data"], # Data is necessary to trigger processing + pipeline_name=pipeline_name, + use_pipeline_cache=True, + ): + pass + + after_reset_run = counter.count + assert after_reset_run > after_second_run, ( + f"After reset, pipeline should re-execute. 
" + f"Before reset: {after_second_run}, After reset run: {after_reset_run}" + ) diff --git a/examples/python/simple_example.py b/examples/python/simple_example.py index 237a8295e..9d817561a 100644 --- a/examples/python/simple_example.py +++ b/examples/python/simple_example.py @@ -32,16 +32,13 @@ async def main(): print("Cognify process steps:") print("1. Classifying the document: Determining the type and category of the input text.") print( - "2. Checking permissions: Ensuring the user has the necessary rights to process the text." + "2. Extracting text chunks: Breaking down the text into sentences or phrases for analysis." ) print( - "3. Extracting text chunks: Breaking down the text into sentences or phrases for analysis." + "3. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph." ) - print("4. Adding data points: Storing the extracted chunks for processing.") - print( - "5. Generating knowledge graph: Extracting entities and relationships to form a knowledge graph." - ) - print("6. Summarizing text: Creating concise summaries of the content for quick insights.\n") + print("4. Summarizing text: Creating concise summaries of the content for quick insights.") + print("5. Adding data points: Storing the extracted chunks for processing.\n") # Use LLMs and cognee to create knowledge graph await cognee.cognify() diff --git a/notebooks/cognee_demo.ipynb b/notebooks/cognee_demo.ipynb index 09c4c89be..fe6ae50ae 100644 --- a/notebooks/cognee_demo.ipynb +++ b/notebooks/cognee_demo.ipynb @@ -591,7 +591,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "7c431fdef4921ae0", "metadata": { "ExecuteTime": { @@ -609,7 +609,6 @@ "from cognee.modules.pipelines import run_tasks\n", "from cognee.modules.users.models import User\n", "from cognee.tasks.documents import (\n", - " check_permissions_on_dataset,\n", " classify_documents,\n", " extract_chunks_from_documents,\n", ")\n", @@ -627,7 +626,6 @@ "\n", " tasks = [\n", " Task(classify_documents),\n", - " Task(check_permissions_on_dataset, user=user, permissions=[\"write\"]),\n", " Task(\n", " extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()\n", " ), # Extract text chunks based on the document type.\n",