diff --git a/cognee/modules/pipelines/layers/__init__.py b/cognee/modules/pipelines/layers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/cognee/modules/pipelines/layers/authorized_user_datasets.py b/cognee/modules/pipelines/layers/authorized_user_datasets.py
new file mode 100644
index 000000000..f0ba013ec
--- /dev/null
+++ b/cognee/modules/pipelines/layers/authorized_user_datasets.py
@@ -0,0 +1,40 @@
+from uuid import UUID
+from typing import Union
+
+from cognee.modules.users.methods import get_default_user
+from cognee.modules.users.models import User
+from cognee.modules.data.exceptions import DatasetNotFoundError
+from cognee.modules.data.methods import (
+    get_authorized_existing_datasets,
+    load_or_create_datasets,
+    check_dataset_name,
+)
+
+
+async def authorized_user_datasets(user: User, datasets: Union[str, list[str], list[UUID]]):
+    # If no user is provided use default user
+    if user is None:
+        user = await get_default_user()
+
+    # Convert datasets to list
+    if isinstance(datasets, str) or isinstance(datasets, UUID):
+        datasets = [datasets]
+
+    # Get datasets user wants write permissions for (verify user has permissions if datasets are provided as well)
+    # NOTE: If a user wants to write to a dataset he does not own it must be provided through UUID
+    existing_datasets = await get_authorized_existing_datasets(datasets, "write", user)
+
+    if not datasets:
+        # Get datasets from database if none sent.
+        authorized_datasets = existing_datasets
+    else:
+        # If dataset matches an existing Dataset (by name or id), reuse it. Otherwise, create a new Dataset.
+        authorized_datasets = await load_or_create_datasets(datasets, existing_datasets, user)
+
+    if not authorized_datasets:
+        raise DatasetNotFoundError("There are no datasets to work with.")
+
+    for dataset in authorized_datasets:
+        check_dataset_name(dataset.name)
+
+    return user, authorized_datasets
diff --git a/cognee/modules/pipelines/layers/pipeline_status_check.py b/cognee/modules/pipelines/layers/pipeline_status_check.py
new file mode 100644
index 000000000..ac8abc0df
--- /dev/null
+++ b/cognee/modules/pipelines/layers/pipeline_status_check.py
@@ -0,0 +1,43 @@
+from cognee.modules.data.models import Dataset
+from cognee.modules.pipelines.models import PipelineRunStatus
+from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
+from cognee.modules.pipelines.methods import get_pipeline_run_by_dataset
+from cognee.shared.logging_utils import get_logger
+
+from cognee.modules.pipelines.models.PipelineRunInfo import (
+    PipelineRunCompleted,
+    PipelineRunStarted,
+)
+
+logger = get_logger(__name__)
+
+
+async def pipeline_status_check(dataset, data, pipeline_name):
+    # async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
+    if isinstance(dataset, Dataset):
+        task_status = await get_pipeline_status([dataset.id], pipeline_name)
+    else:
+        task_status = [
+            PipelineRunStatus.DATASET_PROCESSING_COMPLETED
+        ]  # TODO: this is a random assignment, find permanent solution
+
+    if str(dataset.id) in task_status:
+        if task_status[str(dataset.id)] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
+            logger.info("Dataset %s is already being processed.", dataset.id)
+            pipeline_run = await get_pipeline_run_by_dataset(dataset.id, pipeline_name)
+            yield PipelineRunStarted(
+                pipeline_run_id=pipeline_run.pipeline_run_id,
+                dataset_id=dataset.id,
+                dataset_name=dataset.name,
+                payload=data,
+            )
+            return
+        elif task_status[str(dataset.id)] == PipelineRunStatus.DATASET_PROCESSING_COMPLETED:
+            logger.info("Dataset %s is already processed.", dataset.id)
+            pipeline_run = await get_pipeline_run_by_dataset(dataset.id, pipeline_name)
+            yield PipelineRunCompleted(
+                pipeline_run_id=pipeline_run.pipeline_run_id,
+                dataset_id=dataset.id,
+                dataset_name=dataset.name,
+            )
+            return
diff --git a/cognee/modules/pipelines/operations/pipeline.py b/cognee/modules/pipelines/operations/pipeline.py
index e52441101..e95340619 100644
--- a/cognee/modules/pipelines/operations/pipeline.py
+++ b/cognee/modules/pipelines/operations/pipeline.py
@@ -6,27 +6,14 @@ from cognee.shared.logging_utils import get_logger
 
 from cognee.modules.data.methods.get_dataset_data import get_dataset_data
 from cognee.modules.data.models import Data, Dataset
 from cognee.modules.pipelines.operations.run_tasks import run_tasks
-from cognee.modules.pipelines.models import PipelineRunStatus
 from cognee.modules.pipelines.utils import generate_pipeline_id
-from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
-from cognee.modules.pipelines.methods import get_pipeline_run_by_dataset
 from cognee.modules.pipelines.tasks.task import Task
-from cognee.modules.users.methods import get_default_user
 from cognee.modules.users.models import User
 from cognee.modules.pipelines.operations import log_pipeline_run_initiated
 from cognee.context_global_variables import set_database_global_context_variables
-from cognee.modules.data.exceptions import DatasetNotFoundError
-from cognee.modules.data.methods import (
-    get_authorized_existing_datasets,
-    load_or_create_datasets,
-    check_dataset_name,
-)
-
-from cognee.modules.pipelines.models.PipelineRunInfo import (
-    PipelineRunCompleted,
-    PipelineRunStarted,
-)
+from cognee.modules.pipelines.layers.authorized_user_datasets import authorized_user_datasets
+from cognee.modules.pipelines.layers.pipeline_status_check import pipeline_status_check
 
 from cognee.infrastructure.databases.relational import (
     create_db_and_tables as create_relational_db_and_tables,
@@ -80,29 +67,9 @@ async def cognee_pipeline(
         await test_embedding_connection()
         cognee_pipeline.first_run = False  # Update flag after first run
 
-    # If no user is provided use default user
-    if user is None:
-        user = await get_default_user()
+    user, authorized_datasets = await authorized_user_datasets(user, datasets)
 
-    # Convert datasets to list
-    if isinstance(datasets, str) or isinstance(datasets, UUID):
-        datasets = [datasets]
-
-    # Get datasets user wants write permissions for (verify user has permissions if datasets are provided as well)
-    # NOTE: If a user wants to write to a dataset he does not own it must be provided through UUID
-    existing_datasets = await get_authorized_existing_datasets(datasets, "write", user)
-
-    if not datasets:
-        # Get datasets from database if none sent.
-        datasets = existing_datasets
-    else:
-        # If dataset matches an existing Dataset (by name or id), reuse it. Otherwise, create a new Dataset.
-        datasets = await load_or_create_datasets(datasets, existing_datasets, user)
-
-    if not datasets:
-        raise DatasetNotFoundError("There are no datasets to work with.")
-
-    for dataset in datasets:
+    for dataset in authorized_datasets:
         async for run_info in run_pipeline(
             dataset=dataset,
             user=user,
@@ -124,8 +91,6 @@ async def run_pipeline(
     context: dict = None,
    incremental_loading=False,
 ):
-    check_dataset_name(dataset.name)
-
     # Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
     await set_database_global_context_variables(dataset.id, dataset.owner_id)
 
@@ -149,39 +114,11 @@ async def run_pipeline(
         dataset_id=dataset.id,
     )
 
-    dataset_id = dataset.id
-
     if not data:
-        data: list[Data] = await get_dataset_data(dataset_id=dataset_id)
+        data: list[Data] = await get_dataset_data(dataset_id=dataset.id)
 
-    # async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
-    if isinstance(dataset, Dataset):
-        task_status = await get_pipeline_status([dataset_id], pipeline_name)
-    else:
-        task_status = [
-            PipelineRunStatus.DATASET_PROCESSING_COMPLETED
-        ]  # TODO: this is a random assignment, find permanent solution
-
-    if str(dataset_id) in task_status:
-        if task_status[str(dataset_id)] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
-            logger.info("Dataset %s is already being processed.", dataset_id)
-            pipeline_run = await get_pipeline_run_by_dataset(dataset_id, pipeline_name)
-            yield PipelineRunStarted(
-                pipeline_run_id=pipeline_run.pipeline_run_id,
-                dataset_id=dataset.id,
-                dataset_name=dataset.name,
-                payload=data,
-            )
-            return
-        elif task_status[str(dataset_id)] == PipelineRunStatus.DATASET_PROCESSING_COMPLETED:
-            logger.info("Dataset %s is already processed.", dataset_id)
-            pipeline_run = await get_pipeline_run_by_dataset(dataset_id, pipeline_name)
-            yield PipelineRunCompleted(
-                pipeline_run_id=pipeline_run.pipeline_run_id,
-                dataset_id=dataset.id,
-                dataset_name=dataset.name,
-            )
-            return
+    async for pipeline_status in pipeline_status_check(dataset, data, pipeline_name):
+        yield pipeline_status
 
     if not isinstance(tasks, list):
         raise ValueError("Tasks must be a list")
@@ -191,7 +128,7 @@ async def run_pipeline(
            raise ValueError(f"Task {task} is not an instance of Task")
 
     pipeline_run = run_tasks(
-        tasks, dataset_id, data, user, pipeline_name, context, incremental_loading
+        tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
    )
 
     async for pipeline_run_info in pipeline_run:
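Review note: a minimal usage sketch of how the two extracted layers compose outside of cognee_pipeline. The driver function, the default pipeline_name value, and the example dataset name are illustrative assumptions and not part of this diff.

```python
import asyncio

from cognee.modules.pipelines.layers.authorized_user_datasets import authorized_user_datasets
from cognee.modules.pipelines.layers.pipeline_status_check import pipeline_status_check


async def preview_pipeline_status(datasets, pipeline_name="cognee_pipeline", user=None):
    # Resolve the acting user (falls back to the default user) and the datasets
    # the user is authorized to write to, creating new Dataset rows if needed.
    user, authorized_datasets = await authorized_user_datasets(user, datasets)

    # For each dataset, pipeline_status_check yields nothing when the pipeline has
    # not run yet, or a PipelineRunStarted / PipelineRunCompleted info object when a
    # run is in progress or finished; this is the early exit run_pipeline now relies on.
    for dataset in authorized_datasets:
        async for run_info in pipeline_status_check(dataset, data=None, pipeline_name=pipeline_name):
            print(type(run_info).__name__, run_info.dataset_id)


# Hypothetical invocation:
# asyncio.run(preview_pipeline_status(["my_dataset"]))
```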