fix: pass context to distributed cognee tasks

This commit is contained in:
Boris Arzentar 2025-10-14 23:19:19 +02:00
parent 5a0500254b
commit fda0edc075
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
4 changed files with 20 additions and 10 deletions

View file

@ -6,7 +6,7 @@ from typing import Any
async def log_pipeline_run_complete( async def log_pipeline_run_complete(
pipeline_run_id: UUID, pipeline_id: str, pipeline_name: str, dataset_id: UUID, data: Any pipeline_run_id: UUID, pipeline_id: UUID, pipeline_name: str, dataset_id: UUID, data: Any
): ):
if not data: if not data:
data_info = "None" data_info = "None"

View file

@ -7,7 +7,7 @@ from typing import Any
async def log_pipeline_run_error( async def log_pipeline_run_error(
pipeline_run_id: UUID, pipeline_run_id: UUID,
pipeline_id: str, pipeline_id: UUID,
pipeline_name: str, pipeline_name: str,
dataset_id: UUID, dataset_id: UUID,
data: Any, data: Any,

View file

@ -1,4 +1,4 @@
from uuid import UUID, uuid4 from uuid import UUID
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Data from cognee.modules.data.models import Data
from cognee.modules.pipelines.models import PipelineRun, PipelineRunStatus from cognee.modules.pipelines.models import PipelineRun, PipelineRunStatus
@ -7,7 +7,9 @@ from typing import Any
from cognee.modules.pipelines.utils import generate_pipeline_run_id from cognee.modules.pipelines.utils import generate_pipeline_run_id
async def log_pipeline_run_start(pipeline_id: str, pipeline_name: str, dataset_id: UUID, data: Any): async def log_pipeline_run_start(
pipeline_id: UUID, pipeline_name: str, dataset_id: UUID, data: Any
):
if not data: if not data:
data_info = "None" data_info = "None"
elif isinstance(data, list) and all(isinstance(item, Data) for item in data): elif isinstance(data, list) and all(isinstance(item, Data) for item in data):

View file

@ -6,9 +6,9 @@ except ModuleNotFoundError:
from typing import Any, List, Optional from typing import Any, List, Optional
from uuid import UUID from uuid import UUID
from cognee.modules.data.models import Dataset
from cognee.modules.pipelines.tasks.task import Task from cognee.modules.pipelines.tasks.task import Task
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.pipelines.models import ( from cognee.modules.pipelines.models import (
PipelineRunStarted, PipelineRunStarted,
PipelineRunCompleted, PipelineRunCompleted,
@ -72,7 +72,12 @@ if modal:
pipeline_name=pipeline_name, pipeline_name=pipeline_name,
pipeline_id=pipeline_id, pipeline_id=pipeline_id,
pipeline_run_id=pipeline_run_id, pipeline_run_id=pipeline_run_id,
context=context, context={
**(context or {}),
"user": user,
"data": data_item,
"dataset": dataset,
},
user=user, user=user,
incremental_loading=incremental_loading, incremental_loading=incremental_loading,
) )
@ -92,16 +97,19 @@ async def run_tasks_distributed(
if not user: if not user:
user = await get_default_user() user = await get_default_user()
dataset: Optional[Dataset] = None
# Get dataset object # Get dataset object
db_engine = get_relational_engine() db_engine = get_relational_engine()
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
from cognee.modules.data.models import Dataset
dataset = await session.get(Dataset, dataset_id) dataset = await session.get(Dataset, dataset_id)
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name) if not dataset:
raise ValueError(f"Dataset ({dataset_id}) not found.")
pipeline_id: UUID = generate_pipeline_id(user.id, dataset.id, pipeline_name)
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data) pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
pipeline_run_id = pipeline_run.pipeline_run_id pipeline_run_id: UUID = pipeline_run.pipeline_run_id
yield PipelineRunStarted( yield PipelineRunStarted(
pipeline_run_id=pipeline_run_id, pipeline_run_id=pipeline_run_id,