feat: migrate pipeline input validation to a layer (#1284)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
5771b36c4c
commit
ea4f58e8fa
4 changed files with 42 additions and 8 deletions
18
cognee/modules/pipelines/exceptions/tasks.py
Normal file
18
cognee/modules/pipelines/exceptions/tasks.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from fastapi import status
|
||||
from cognee.exceptions import CogneeValidationError
|
||||
|
||||
|
||||
class WrongTaskTypeError(CogneeValidationError):
|
||||
"""
|
||||
Raised when the tasks argument is not a list of Task class instances.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "tasks argument must be a list, containing Task class instances.",
|
||||
name: str = "WrongTaskTypeError",
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
):
|
||||
self.message = message
|
||||
self.name = name
|
||||
self.status_code = status_code
|
||||
|
|
@ -0,0 +1 @@
|
|||
from .validate_pipeline_tasks import validate_pipeline_tasks
|
||||
20
cognee/modules/pipelines/layers/validate_pipeline_tasks.py
Normal file
20
cognee/modules/pipelines/layers/validate_pipeline_tasks.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
from ..tasks.task import Task
|
||||
from ..exceptions.tasks import WrongTaskTypeError
|
||||
|
||||
|
||||
def validate_pipeline_tasks(tasks: list[Task]):
|
||||
"""
|
||||
Validates the tasks argument to ensure it is a list of Task class instances.
|
||||
|
||||
Args:
|
||||
tasks (list[Task]): The list of tasks to be validated.
|
||||
"""
|
||||
|
||||
if not isinstance(tasks, list):
|
||||
raise WrongTaskTypeError(f"tasks argument must be a list, got {type(tasks).__name__}.")
|
||||
|
||||
for task in tasks:
|
||||
if not isinstance(task, Task):
|
||||
raise WrongTaskTypeError(
|
||||
f"tasks argument must be a list of Task class instances, got {type(task).__name__} in the list."
|
||||
)
|
||||
|
|
@ -10,7 +10,7 @@ from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
|||
from cognee.modules.data.models import Data, Dataset
|
||||
from cognee.modules.pipelines.operations.run_tasks import run_tasks
|
||||
from cognee.modules.pipelines.utils import generate_pipeline_id
|
||||
|
||||
from cognee.modules.pipelines.layers import validate_pipeline_tasks
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.pipelines.operations import log_pipeline_run_initiated
|
||||
|
|
@ -61,6 +61,8 @@ async def run_pipeline(
|
|||
context: dict = None,
|
||||
incremental_loading=False,
|
||||
):
|
||||
validate_pipeline_tasks(tasks)
|
||||
|
||||
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||
|
||||
|
|
@ -94,13 +96,6 @@ async def run_pipeline(
|
|||
yield process_pipeline_status
|
||||
return
|
||||
|
||||
if not isinstance(tasks, list):
|
||||
raise ValueError("Tasks must be a list")
|
||||
|
||||
for task in tasks:
|
||||
if not isinstance(task, Task):
|
||||
raise ValueError(f"Task {task} is not an instance of Task")
|
||||
|
||||
pipeline_run = run_tasks(
|
||||
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue