diff --git a/cognee/modules/pipelines/operations/run_tasks_distributed.py b/cognee/modules/pipelines/operations/run_tasks_distributed.py index 26e9375b9..52689c021 100644 --- a/cognee/modules/pipelines/operations/run_tasks_distributed.py +++ b/cognee/modules/pipelines/operations/run_tasks_distributed.py @@ -45,7 +45,7 @@ if modal: ) async def run_tasks_on_modal( data_item, - dataset_id: UUID, + dataset: Dataset, tasks: List[Task], pipeline_name: str, pipeline_id: str, @@ -60,9 +60,6 @@ if modal: """ from cognee.infrastructure.databases.relational import get_relational_engine - async with get_relational_engine().get_async_session() as session: - dataset = await session.get(Dataset, dataset_id) - result = await run_tasks_data_item( data_item=data_item, dataset=dataset, @@ -85,28 +82,18 @@ if modal: async def run_tasks_distributed( tasks: List[Task], - dataset_id: UUID, - data: List[Any] = None, - user: User = None, + dataset: Dataset, + data: Optional[List[Any]] = None, + user: Optional[User] = None, pipeline_name: str = "unknown_pipeline", - context: dict = None, + context: Optional[dict] = None, incremental_loading: bool = False, ): if not user: user = await get_default_user() - dataset: Optional[Dataset] = None - - # Get dataset object - db_engine = get_relational_engine() - async with db_engine.get_async_session() as session: - dataset = await session.get(Dataset, dataset_id) - - if not dataset: - raise ValueError(f"Dataset ({dataset_id}) not found.") - pipeline_id: UUID = generate_pipeline_id(user.id, dataset.id, pipeline_name) - pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data) + pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset.id, data) pipeline_run_id: UUID = pipeline_run.pipeline_run_id yield PipelineRunStarted( @@ -126,7 +113,7 @@ async def run_tasks_distributed( data_item_tasks = [ data, - [dataset.id] * number_of_data_items, + [dataset] * number_of_data_items, [tasks] * number_of_data_items, [pipeline_name] * number_of_data_items, [pipeline_id] * number_of_data_items, @@ -155,7 +142,7 @@ async def run_tasks_distributed( raise PipelineRunFailedError("Pipeline run failed. Data item could not be processed.") await log_pipeline_run_complete( - pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data + pipeline_run_id, pipeline_id, pipeline_name, dataset.id, data ) yield PipelineRunCompleted( @@ -167,7 +154,7 @@ async def run_tasks_distributed( except Exception as error: await log_pipeline_run_error( - pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error + pipeline_run_id, pipeline_id, pipeline_name, dataset.id, data, error ) yield PipelineRunErrored(