feat: migrate pipeline input validation to a layer (#1284)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
5771b36c4c
commit
ea4f58e8fa
4 changed files with 42 additions and 8 deletions
18
cognee/modules/pipelines/exceptions/tasks.py
Normal file
18
cognee/modules/pipelines/exceptions/tasks.py
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
from fastapi import status
|
||||||
|
from cognee.exceptions import CogneeValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class WrongTaskTypeError(CogneeValidationError):
    """
    Raised when the tasks argument is not a list of Task class instances.

    Args:
        message: Human-readable description of the validation failure.
        name: Identifier stored on the exception as ``name``.
        status_code: HTTP status stored on the exception; defaults to
            400 Bad Request.
    """

    def __init__(
        self,
        message: str = "tasks argument must be a list, containing Task class instances.",
        name: str = "WrongTaskTypeError",
        status_code=status.HTTP_400_BAD_REQUEST,
    ):
        # Chain to the base initializer so Exception.args is populated
        # (otherwise repr()/pickling of the exception carry no message) and
        # any setup done by the base class is not skipped.
        # NOTE(review): assumes CogneeValidationError.__init__ accepts
        # (message, name, status_code) positionally - confirm against
        # cognee.exceptions.
        super().__init__(message, name, status_code)

        # Kept explicit so these attributes are set regardless of what the
        # base initializer chooses to store.
        self.message = message
        self.name = name
        self.status_code = status_code
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .validate_pipeline_tasks import validate_pipeline_tasks
|
||||||
20
cognee/modules/pipelines/layers/validate_pipeline_tasks.py
Normal file
20
cognee/modules/pipelines/layers/validate_pipeline_tasks.py
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
from ..tasks.task import Task
|
||||||
|
from ..exceptions.tasks import WrongTaskTypeError
|
||||||
|
|
||||||
|
|
||||||
|
def validate_pipeline_tasks(tasks: list[Task]):
    """
    Check that ``tasks`` is a list and that every element is a Task instance.

    Args:
        tasks (list[Task]): The candidate pipeline tasks to validate.

    Raises:
        WrongTaskTypeError: If ``tasks`` is not a list, or if an element of
            the list is not a Task instance. The message names the offending
            type.
    """
    if not isinstance(tasks, list):
        container_type = type(tasks).__name__
        raise WrongTaskTypeError(f"tasks argument must be a list, got {container_type}.")

    # Lazily scan for the first element that is not a Task; the generator
    # stops at the first offender, matching a loop that raises immediately.
    offending_type_names = (
        type(item).__name__ for item in tasks if not isinstance(item, Task)
    )
    first_offender = next(offending_type_names, None)

    if first_offender is not None:
        raise WrongTaskTypeError(
            f"tasks argument must be a list of Task class instances, got {first_offender} in the list."
        )
|
||||||
|
|
@ -10,7 +10,7 @@ from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||||
from cognee.modules.data.models import Data, Dataset
|
from cognee.modules.data.models import Data, Dataset
|
||||||
from cognee.modules.pipelines.operations.run_tasks import run_tasks
|
from cognee.modules.pipelines.operations.run_tasks import run_tasks
|
||||||
from cognee.modules.pipelines.utils import generate_pipeline_id
|
from cognee.modules.pipelines.utils import generate_pipeline_id
|
||||||
|
from cognee.modules.pipelines.layers import validate_pipeline_tasks
|
||||||
from cognee.modules.pipelines.tasks.task import Task
|
from cognee.modules.pipelines.tasks.task import Task
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.pipelines.operations import log_pipeline_run_initiated
|
from cognee.modules.pipelines.operations import log_pipeline_run_initiated
|
||||||
|
|
@ -61,6 +61,8 @@ async def run_pipeline(
|
||||||
context: dict = None,
|
context: dict = None,
|
||||||
incremental_loading=False,
|
incremental_loading=False,
|
||||||
):
|
):
|
||||||
|
validate_pipeline_tasks(tasks)
|
||||||
|
|
||||||
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
||||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||||
|
|
||||||
|
|
@ -94,13 +96,6 @@ async def run_pipeline(
|
||||||
yield process_pipeline_status
|
yield process_pipeline_status
|
||||||
return
|
return
|
||||||
|
|
||||||
if not isinstance(tasks, list):
|
|
||||||
raise ValueError("Tasks must be a list")
|
|
||||||
|
|
||||||
for task in tasks:
|
|
||||||
if not isinstance(task, Task):
|
|
||||||
raise ValueError(f"Task {task} is not an instance of Task")
|
|
||||||
|
|
||||||
pipeline_run = run_tasks(
|
pipeline_run = run_tasks(
|
||||||
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
|
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue