From ea4f58e8fa4cbae5dc896c8f8e54d815481221bc Mon Sep 17 00:00:00 2001 From: Boris Date: Tue, 26 Aug 2025 18:25:48 +0200 Subject: [PATCH] feat: migrate pipeline input validation to a layer (#1284) ## Description ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --- cognee/modules/pipelines/exceptions/tasks.py | 18 +++++++++++++++++ cognee/modules/pipelines/layers/__init__.py | 1 + .../layers/validate_pipeline_tasks.py | 20 +++++++++++++++++++ .../modules/pipelines/operations/pipeline.py | 11 +++------- 4 files changed, 42 insertions(+), 8 deletions(-) create mode 100644 cognee/modules/pipelines/exceptions/tasks.py create mode 100644 cognee/modules/pipelines/layers/validate_pipeline_tasks.py diff --git a/cognee/modules/pipelines/exceptions/tasks.py b/cognee/modules/pipelines/exceptions/tasks.py new file mode 100644 index 000000000..42ad0a112 --- /dev/null +++ b/cognee/modules/pipelines/exceptions/tasks.py @@ -0,0 +1,18 @@ +from fastapi import status +from cognee.exceptions import CogneeValidationError + + +class WrongTaskTypeError(CogneeValidationError): + """ + Raised when the tasks argument is not a list of Task class instances. + """ + + def __init__( + self, + message: str = "tasks argument must be a list, containing Task class instances.", + name: str = "WrongTaskTypeError", + status_code=status.HTTP_400_BAD_REQUEST, + ): + self.message = message + self.name = name + self.status_code = status_code diff --git a/cognee/modules/pipelines/layers/__init__.py b/cognee/modules/pipelines/layers/__init__.py index e69de29bb..d0c3bd7c8 100644 --- a/cognee/modules/pipelines/layers/__init__.py +++ b/cognee/modules/pipelines/layers/__init__.py @@ -0,0 +1 @@ +from .validate_pipeline_tasks import validate_pipeline_tasks diff --git a/cognee/modules/pipelines/layers/validate_pipeline_tasks.py b/cognee/modules/pipelines/layers/validate_pipeline_tasks.py new file mode 100644 index 000000000..9342e54b7 --- /dev/null +++ b/cognee/modules/pipelines/layers/validate_pipeline_tasks.py @@ -0,0 +1,20 @@ +from ..tasks.task import Task +from ..exceptions.tasks import WrongTaskTypeError + + +def validate_pipeline_tasks(tasks: list[Task]): + """ + Validates the tasks argument to ensure it is a list of Task class instances. + + Args: + tasks (list[Task]): The list of tasks to be validated. + """ + + if not isinstance(tasks, list): + raise WrongTaskTypeError(f"tasks argument must be a list, got {type(tasks).__name__}.") + + for task in tasks: + if not isinstance(task, Task): + raise WrongTaskTypeError( + f"tasks argument must be a list of Task class instances, got {type(task).__name__} in the list." + ) diff --git a/cognee/modules/pipelines/operations/pipeline.py b/cognee/modules/pipelines/operations/pipeline.py index 7a53e0f9a..a36fd5cee 100644 --- a/cognee/modules/pipelines/operations/pipeline.py +++ b/cognee/modules/pipelines/operations/pipeline.py @@ -10,7 +10,7 @@ from cognee.modules.data.methods.get_dataset_data import get_dataset_data from cognee.modules.data.models import Data, Dataset from cognee.modules.pipelines.operations.run_tasks import run_tasks from cognee.modules.pipelines.utils import generate_pipeline_id - +from cognee.modules.pipelines.layers import validate_pipeline_tasks from cognee.modules.pipelines.tasks.task import Task from cognee.modules.users.models import User from cognee.modules.pipelines.operations import log_pipeline_run_initiated @@ -61,6 +61,8 @@ async def run_pipeline( context: dict = None, incremental_loading=False, ): + validate_pipeline_tasks(tasks) + # Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True await set_database_global_context_variables(dataset.id, dataset.owner_id) @@ -94,13 +96,6 @@ async def run_pipeline( yield process_pipeline_status return - if not isinstance(tasks, list): - raise ValueError("Tasks must be a list") - - for task in tasks: - if not isinstance(task, Task): - raise ValueError(f"Task {task} is not an instance of Task") - pipeline_run = run_tasks( tasks, dataset.id, data, user, pipeline_name, context, incremental_loading )