From ede884e0b0215714a86afcd38ae5800332e861df Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Fri, 12 Dec 2025 13:11:31 +0100 Subject: [PATCH 1/2] feat: make pipeline processing cache optional (#1876) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Make the pipeline cache mechanism optional, have it turned off by default but use it for add and cognify like it has been used until now ## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit * **New Features** * Introduced pipeline caching across ingestion, processing, and custom pipeline flows with per-run controls to enable or disable caching. * Added an option for incremental loading in custom pipeline runs. * **Behavior Changes** * One pipeline path now explicitly bypasses caching by default to always re-run when invoked. 
* Disabling cache forces re-processing instead of early exit; cache reset still enables re-execution. * **Tests** * Added tests validating caching, non-caching, and cache-reset re-execution behavior. * **Chores** * Added CI job to run pipeline caching tests. ✏️ Tip: You can customize this high-level summary in your review settings. --- .github/workflows/e2e_tests.yml | 27 +++ cognee/api/v1/add/add.py | 1 + cognee/api/v1/cognify/cognify.py | 1 + cognee/modules/memify/memify.py | 8 +- .../modules/pipelines/operations/pipeline.py | 20 ++- .../run_custom_pipeline.py | 9 +- cognee/tests/test_pipeline_cache.py | 164 ++++++++++++++++++ 7 files changed, 220 insertions(+), 10 deletions(-) create mode 100644 cognee/tests/test_pipeline_cache.py diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index 520d93689..cb69e9ef6 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -582,3 +582,30 @@ jobs: DB_USERNAME: cognee DB_PASSWORD: cognee run: uv run python ./cognee/tests/test_conversation_history.py + + run-pipeline-cache-test: + name: Test Pipeline Caching + runs-on: ubuntu-22.04 + steps: + - name: Check out + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Pipeline Cache Test + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/test_pipeline_cache.py diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py index a521b316b..1ea4caca4 100644 --- a/cognee/api/v1/add/add.py +++ b/cognee/api/v1/add/add.py 
@@ -205,6 +205,7 @@ async def add( pipeline_name="add_pipeline", vector_db_config=vector_db_config, graph_db_config=graph_db_config, + use_pipeline_cache=True, incremental_loading=incremental_loading, data_per_batch=data_per_batch, ): diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 9d9f7d154..9862edd49 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -237,6 +237,7 @@ async def cognify( vector_db_config=vector_db_config, graph_db_config=graph_db_config, incremental_loading=incremental_loading, + use_pipeline_cache=True, pipeline_name="cognify_pipeline", data_per_batch=data_per_batch, ) diff --git a/cognee/modules/memify/memify.py b/cognee/modules/memify/memify.py index 2d9b32a1b..e60eb5a4e 100644 --- a/cognee/modules/memify/memify.py +++ b/cognee/modules/memify/memify.py @@ -12,9 +12,6 @@ from cognee.modules.users.models import User from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import ( resolve_authorized_user_datasets, ) -from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import ( - reset_dataset_pipeline_run_status, -) from cognee.modules.engine.operations.setup import setup from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor from cognee.tasks.memify.extract_subgraph_chunks import extract_subgraph_chunks @@ -97,10 +94,6 @@ async def memify( *enrichment_tasks, ] - await reset_dataset_pipeline_run_status( - authorized_dataset.id, user, pipeline_names=["memify_pipeline"] - ) - # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background) @@ -113,6 +106,7 @@ async def memify( datasets=authorized_dataset.id, vector_db_config=vector_db_config, graph_db_config=graph_db_config, + use_pipeline_cache=False, incremental_loading=False, 
pipeline_name="memify_pipeline", ) diff --git a/cognee/modules/pipelines/operations/pipeline.py b/cognee/modules/pipelines/operations/pipeline.py index eb0ebe8bd..6641d3a4c 100644 --- a/cognee/modules/pipelines/operations/pipeline.py +++ b/cognee/modules/pipelines/operations/pipeline.py @@ -20,6 +20,9 @@ from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import ( from cognee.modules.pipelines.layers.check_pipeline_run_qualification import ( check_pipeline_run_qualification, ) +from cognee.modules.pipelines.models.PipelineRunInfo import ( + PipelineRunStarted, +) from typing import Any logger = get_logger("cognee.pipeline") @@ -35,6 +38,7 @@ async def run_pipeline( pipeline_name: str = "custom_pipeline", vector_db_config: dict = None, graph_db_config: dict = None, + use_pipeline_cache: bool = False, incremental_loading: bool = False, data_per_batch: int = 20, ): @@ -51,6 +55,7 @@ async def run_pipeline( data=data, pipeline_name=pipeline_name, context={"dataset": dataset}, + use_pipeline_cache=use_pipeline_cache, incremental_loading=incremental_loading, data_per_batch=data_per_batch, ): @@ -64,6 +69,7 @@ async def run_pipeline_per_dataset( data=None, pipeline_name: str = "custom_pipeline", context: dict = None, + use_pipeline_cache=False, incremental_loading=False, data_per_batch: int = 20, ): @@ -77,8 +83,18 @@ async def run_pipeline_per_dataset( if process_pipeline_status: # If pipeline was already processed or is currently being processed # return status information to async generator and finish execution - yield process_pipeline_status - return + if use_pipeline_cache: + # If pipeline caching is enabled we do not proceed with re-processing + yield process_pipeline_status + return + else: + # If pipeline caching is disabled we always return pipeline started information and proceed with re-processing + yield PipelineRunStarted( + pipeline_run_id=process_pipeline_status.pipeline_run_id, + dataset_id=dataset.id, + dataset_name=dataset.name, + 
payload=data, + ) pipeline_run = run_tasks( tasks, diff --git a/cognee/modules/run_custom_pipeline/run_custom_pipeline.py b/cognee/modules/run_custom_pipeline/run_custom_pipeline.py index d3df1c060..269238503 100644 --- a/cognee/modules/run_custom_pipeline/run_custom_pipeline.py +++ b/cognee/modules/run_custom_pipeline/run_custom_pipeline.py @@ -18,6 +18,8 @@ async def run_custom_pipeline( user: User = None, vector_db_config: Optional[dict] = None, graph_db_config: Optional[dict] = None, + use_pipeline_cache: bool = False, + incremental_loading: bool = False, data_per_batch: int = 20, run_in_background: bool = False, pipeline_name: str = "custom_pipeline", @@ -40,6 +42,10 @@ async def run_custom_pipeline( user: User context for authentication and data access. Uses default if None. vector_db_config: Custom vector database configuration for embeddings storage. graph_db_config: Custom graph database configuration for relationship storage. + use_pipeline_cache: If True, pipelines with the same ID that are currently executing and pipelines with the same ID that were completed won't process data again. + Pipelines ID is created based on the generate_pipeline_id function. Pipeline status can be manually reset with the reset_dataset_pipeline_run_status function. + incremental_loading: If True, only new or modified data will be processed to avoid duplication. (Only works if data is used with the Cognee python Data model). + The incremental system stores and compares hashes of processed data in the Data model and skips data with the same content hash. data_per_batch: Number of data items to be processed in parallel. run_in_background: If True, starts processing asynchronously and returns immediately. If False, waits for completion before returning. 
@@ -63,7 +69,8 @@ async def run_custom_pipeline( datasets=dataset, vector_db_config=vector_db_config, graph_db_config=graph_db_config, - incremental_loading=False, + use_pipeline_cache=use_pipeline_cache, + incremental_loading=incremental_loading, data_per_batch=data_per_batch, pipeline_name=pipeline_name, ) diff --git a/cognee/tests/test_pipeline_cache.py b/cognee/tests/test_pipeline_cache.py new file mode 100644 index 000000000..8cdd6aa3c --- /dev/null +++ b/cognee/tests/test_pipeline_cache.py @@ -0,0 +1,164 @@ +""" +Test suite for the pipeline_cache feature in Cognee pipelines. + +This module tests the behavior of the `pipeline_cache` parameter which controls +whether a pipeline should skip re-execution when it has already been completed +for the same dataset. + +Architecture Overview: +--------------------- +The pipeline_cache mechanism works at the dataset level: +1. When a pipeline runs, it logs its status (INITIATED -> STARTED -> COMPLETED) +2. Before each run, `check_pipeline_run_qualification()` checks the pipeline status +3. If `use_pipeline_cache=True` and status is COMPLETED/STARTED, the pipeline skips +4. 
If `use_pipeline_cache=False`, the pipeline always re-executes regardless of status +""" + +import pytest + +import cognee +from cognee.modules.pipelines.tasks.task import Task +from cognee.modules.pipelines import run_pipeline +from cognee.modules.users.methods import get_default_user + +from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import ( + reset_dataset_pipeline_run_status, +) +from cognee.infrastructure.databases.relational import create_db_and_tables + + +class ExecutionCounter: + """Helper class to track task execution counts.""" + + def __init__(self): + self.count = 0 + + +async def create_counting_task(data, counter: ExecutionCounter): + """Create a task that increments a counter from the ExecutionCounter instance when executed.""" + counter.count += 1 + return counter + + +class TestPipelineCache: + """Tests for basic pipeline_cache on/off behavior.""" + + @pytest.mark.asyncio + async def test_pipeline_cache_off_allows_reexecution(self): + """ + Test that with use_pipeline_cache=False, the pipeline re-executes + even when it has already completed for the dataset. 
+ + Expected behavior: + - First run: Pipeline executes fully, task runs once + - Second run: Pipeline executes again, task runs again (total: 2 times) + """ + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await create_db_and_tables() + + counter = ExecutionCounter() + user = await get_default_user() + + tasks = [Task(create_counting_task, counter=counter)] + + # First run + pipeline_results_1 = [] + async for result in run_pipeline( + tasks=tasks, + datasets="test_dataset_cache_off", + data=["sample data"], # Data is necessary to trigger processing + user=user, + pipeline_name="test_cache_off_pipeline", + use_pipeline_cache=False, + ): + pipeline_results_1.append(result) + + first_run_count = counter.count + assert first_run_count >= 1, "Task should have executed at least once on first run" + + # Second run with pipeline_cache=False + pipeline_results_2 = [] + async for result in run_pipeline( + tasks=tasks, + datasets="test_dataset_cache_off", + data=["sample data"], # Data is necessary to trigger processing + user=user, + pipeline_name="test_cache_off_pipeline", + use_pipeline_cache=False, + ): + pipeline_results_2.append(result) + + second_run_count = counter.count + assert second_run_count > first_run_count, ( + f"With pipeline_cache=False, task should re-execute. " + f"First run: {first_run_count}, After second run: {second_run_count}" + ) + + @pytest.mark.asyncio + async def test_reset_pipeline_status_allows_reexecution_with_cache(self): + """ + Test that resetting pipeline status allows re-execution even with + pipeline_cache=True. 
+ """ + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await create_db_and_tables() + + counter = ExecutionCounter() + user = await get_default_user() + dataset_name = "reset_status_test" + pipeline_name = "test_reset_pipeline" + + tasks = [Task(create_counting_task, counter=counter)] + + # First run + pipeline_result = [] + async for result in run_pipeline( + tasks=tasks, + datasets=dataset_name, + user=user, + data=["sample data"], # Data is necessary to trigger processing + pipeline_name=pipeline_name, + use_pipeline_cache=True, + ): + pipeline_result.append(result) + + first_run_count = counter.count + assert first_run_count >= 1 + + # Second run without reset - should skip + async for _ in run_pipeline( + tasks=tasks, + datasets=dataset_name, + user=user, + data=["sample data"], # Data is necessary to trigger processing + pipeline_name=pipeline_name, + use_pipeline_cache=True, + ): + pass + + after_second_run = counter.count + assert after_second_run == first_run_count, "Should have skipped due to cache" + + # Reset the pipeline status + await reset_dataset_pipeline_run_status( + pipeline_result[0].dataset_id, user, pipeline_names=[pipeline_name] + ) + + # Third run after reset - should execute + async for _ in run_pipeline( + tasks=tasks, + datasets=dataset_name, + user=user, + data=["sample data"], # Data is necessary to trigger processing + pipeline_name=pipeline_name, + use_pipeline_cache=True, + ): + pass + + after_reset_run = counter.count + assert after_reset_run > after_second_run, ( + f"After reset, pipeline should re-execute. 
" + f"Before reset: {after_second_run}, After reset run: {after_reset_run}" + ) From 127d9860df55c857dd5ee1caa5dab3b3a6f43345 Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Fri, 12 Dec 2025 13:22:03 +0100 Subject: [PATCH 2/2] feat: Add dataset database handler info (#1887) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Add info on dataset database handler used for dataset database ## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [ ] **I have tested my changes thoroughly before submitting this PR** - [ ] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit * **New Features** * Datasets now record their assigned vector and graph database handlers, allowing per-dataset backend selection. * **Chores** * Database schema expanded to store handler identifiers per dataset. * Deletion/cleanup processes now use dataset-level handler info for accurate removal across backends. 
* **Tests** * Tests updated to include and validate the new handler fields in dataset creation outputs. ✏️ Tip: You can customize this high-level summary in your review settings. --- ...d2b2_expand_dataset_database_with_json_.py | 66 +++++++++++++++++++ .../graph/kuzu/KuzuDatasetDatabaseHandler.py | 1 + .../Neo4jAuraDevDatasetDatabaseHandler.py | 1 + ...esolve_dataset_database_connection_info.py | 10 +-- .../lancedb/LanceDBDatasetDatabaseHandler.py | 1 + cognee/modules/data/deletion/prune_system.py | 13 ++-- .../modules/users/models/DatasetDatabase.py | 3 + cognee/tests/test_dataset_database_handler.py | 2 + 8 files changed, 82 insertions(+), 15 deletions(-) diff --git a/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py b/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py index e15a98b7c..25b94a724 100644 --- a/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py +++ b/alembic/versions/46a6ce2bd2b2_expand_dataset_database_with_json_.py @@ -49,6 +49,20 @@ def _recreate_table_without_unique_constraint_sqlite(op, insp): sa.Column("graph_database_name", sa.String(), nullable=False), sa.Column("vector_database_provider", sa.String(), nullable=False), sa.Column("graph_database_provider", sa.String(), nullable=False), + sa.Column( + "vector_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="lancedb", + ), + sa.Column( + "graph_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="kuzu", + ), sa.Column("vector_database_url", sa.String()), sa.Column("graph_database_url", sa.String()), sa.Column("vector_database_key", sa.String()), @@ -82,6 +96,8 @@ def _recreate_table_without_unique_constraint_sqlite(op, insp): graph_database_name, vector_database_provider, graph_database_provider, + vector_dataset_database_handler, + graph_dataset_database_handler, vector_database_url, graph_database_url, vector_database_key, @@ -120,6 +136,20 @@ def 
_recreate_table_with_unique_constraint_sqlite(op, insp): sa.Column("graph_database_name", sa.String(), nullable=False, unique=True), sa.Column("vector_database_provider", sa.String(), nullable=False), sa.Column("graph_database_provider", sa.String(), nullable=False), + sa.Column( + "vector_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="lancedb", + ), + sa.Column( + "graph_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="kuzu", + ), sa.Column("vector_database_url", sa.String()), sa.Column("graph_database_url", sa.String()), sa.Column("vector_database_key", sa.String()), @@ -153,6 +183,8 @@ def _recreate_table_with_unique_constraint_sqlite(op, insp): graph_database_name, vector_database_provider, graph_database_provider, + vector_dataset_database_handler, + graph_dataset_database_handler, vector_database_url, graph_database_url, vector_database_key, @@ -193,6 +225,22 @@ def upgrade() -> None: ), ) + vector_dataset_database_handler = _get_column( + insp, "dataset_database", "vector_dataset_database_handler" + ) + if not vector_dataset_database_handler: + # Add LanceDB as the default vector dataset database handler + op.add_column( + "dataset_database", + sa.Column( + "vector_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="lancedb", + ), + ) + graph_database_connection_info_column = _get_column( insp, "dataset_database", "graph_database_connection_info" ) @@ -208,6 +256,22 @@ def upgrade() -> None: ), ) + graph_dataset_database_handler = _get_column( + insp, "dataset_database", "graph_dataset_database_handler" + ) + if not graph_dataset_database_handler: + # Add Kuzu as the default graph dataset database handler + op.add_column( + "dataset_database", + sa.Column( + "graph_dataset_database_handler", + sa.String(), + unique=False, + nullable=False, + server_default="kuzu", + ), + ) + with op.batch_alter_table("dataset_database", 
schema=None) as batch_op: # Drop the unique constraint to make unique=False graph_constraint_to_drop = None @@ -265,3 +329,5 @@ def downgrade() -> None: op.drop_column("dataset_database", "vector_database_connection_info") op.drop_column("dataset_database", "graph_database_connection_info") + op.drop_column("dataset_database", "vector_dataset_database_handler") + op.drop_column("dataset_database", "graph_dataset_database_handler") diff --git a/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py index edc6d5c39..61ff84870 100644 --- a/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +++ b/cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py @@ -47,6 +47,7 @@ class KuzuDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): "graph_database_url": graph_db_url, "graph_database_provider": graph_config.graph_database_provider, "graph_database_key": graph_db_key, + "graph_dataset_database_handler": "kuzu", "graph_database_connection_info": { "graph_database_username": graph_db_username, "graph_database_password": graph_db_password, diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py index 73f057fa8..eb6cbc55a 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py @@ -131,6 +131,7 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): "graph_database_url": graph_db_url, "graph_database_provider": "neo4j", "graph_database_key": graph_db_key, + "graph_dataset_database_handler": "neo4j_aura_dev", "graph_database_connection_info": { "graph_database_username": graph_db_username, "graph_database_password": encrypted_db_password_string, diff 
--git a/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py index 4d8c19403..d33169642 100644 --- a/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +++ b/cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py @@ -1,27 +1,21 @@ -from cognee.infrastructure.databases.vector import get_vectordb_config -from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.modules.users.models.DatasetDatabase import DatasetDatabase async def _get_vector_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: - vector_config = get_vectordb_config() - from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( supported_dataset_database_handlers, ) - handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler] + handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler] return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) async def _get_graph_db_connection_info(dataset_database: DatasetDatabase) -> DatasetDatabase: - graph_config = get_graph_config() - from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( supported_dataset_database_handlers, ) - handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler] + handler = supported_dataset_database_handlers[dataset_database.graph_dataset_database_handler] return await handler["handler_instance"].resolve_dataset_connection_info(dataset_database) diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py index f165a7ea4..e392b7eb8 100644 --- 
a/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py @@ -36,6 +36,7 @@ class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): "vector_database_url": os.path.join(databases_directory_path, vector_db_name), "vector_database_key": vector_config.vector_db_key, "vector_database_name": vector_db_name, + "vector_dataset_database_handler": "lancedb", } @classmethod diff --git a/cognee/modules/data/deletion/prune_system.py b/cognee/modules/data/deletion/prune_system.py index b43cab1f7..645e1a223 100644 --- a/cognee/modules/data/deletion/prune_system.py +++ b/cognee/modules/data/deletion/prune_system.py @@ -5,8 +5,6 @@ from cognee.context_global_variables import backend_access_control_enabled from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.relational import get_relational_engine -from cognee.infrastructure.databases.vector.config import get_vectordb_config -from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.shared.cache import delete_cache from cognee.modules.users.models import DatasetDatabase from cognee.shared.logging_utils import get_logger @@ -16,12 +14,13 @@ logger = get_logger() async def prune_graph_databases(): async def _prune_graph_db(dataset_database: DatasetDatabase) -> dict: - graph_config = get_graph_config() from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( supported_dataset_database_handlers, ) - handler = supported_dataset_database_handlers[graph_config.graph_dataset_database_handler] + handler = supported_dataset_database_handlers[ + dataset_database.graph_dataset_database_handler + ] return await handler["handler_instance"].delete_dataset(dataset_database) db_engine = get_relational_engine() @@ 
-40,13 +39,13 @@ async def prune_graph_databases(): async def prune_vector_databases(): async def _prune_vector_db(dataset_database: DatasetDatabase) -> dict: - vector_config = get_vectordb_config() - from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import ( supported_dataset_database_handlers, ) - handler = supported_dataset_database_handlers[vector_config.vector_dataset_database_handler] + handler = supported_dataset_database_handlers[ + dataset_database.vector_dataset_database_handler + ] return await handler["handler_instance"].delete_dataset(dataset_database) db_engine = get_relational_engine() diff --git a/cognee/modules/users/models/DatasetDatabase.py b/cognee/modules/users/models/DatasetDatabase.py index 15964f032..08c4b5311 100644 --- a/cognee/modules/users/models/DatasetDatabase.py +++ b/cognee/modules/users/models/DatasetDatabase.py @@ -18,6 +18,9 @@ class DatasetDatabase(Base): vector_database_provider = Column(String, unique=False, nullable=False) graph_database_provider = Column(String, unique=False, nullable=False) + graph_dataset_database_handler = Column(String, unique=False, nullable=False) + vector_dataset_database_handler = Column(String, unique=False, nullable=False) + vector_database_url = Column(String, unique=False, nullable=True) graph_database_url = Column(String, unique=False, nullable=True) diff --git a/cognee/tests/test_dataset_database_handler.py b/cognee/tests/test_dataset_database_handler.py index be1b249d2..e4c9b0177 100644 --- a/cognee/tests/test_dataset_database_handler.py +++ b/cognee/tests/test_dataset_database_handler.py @@ -30,6 +30,7 @@ class LanceDBTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): vector_db_name = "test.lance.db" return { + "vector_dataset_database_handler": "custom_lancedb_handler", "vector_database_name": vector_db_name, "vector_database_url": os.path.join(databases_directory_path, vector_db_name), "vector_database_provider": "lancedb", @@ -44,6 
+45,7 @@ class KuzuTestDatasetDatabaseHandler(DatasetDatabaseHandlerInterface): graph_db_name = "test.kuzu" return { + "graph_dataset_database_handler": "custom_kuzu_handler", "graph_database_name": graph_db_name, "graph_database_url": os.path.join(databases_directory_path, graph_db_name), "graph_database_provider": "kuzu",