fix: pass context to distributed cognee tasks
This commit is contained in:
parent
5a0500254b
commit
fda0edc075
4 changed files with 20 additions and 10 deletions
|
|
@@ -6,7 +6,7 @@ from typing import Any
|
|||
|
||||
|
||||
async def log_pipeline_run_complete(
|
||||
pipeline_run_id: UUID, pipeline_id: str, pipeline_name: str, dataset_id: UUID, data: Any
|
||||
pipeline_run_id: UUID, pipeline_id: UUID, pipeline_name: str, dataset_id: UUID, data: Any
|
||||
):
|
||||
if not data:
|
||||
data_info = "None"
|
||||
|
|
|
|||
|
|
@@ -7,7 +7,7 @@ from typing import Any
|
|||
|
||||
async def log_pipeline_run_error(
|
||||
pipeline_run_id: UUID,
|
||||
pipeline_id: str,
|
||||
pipeline_id: UUID,
|
||||
pipeline_name: str,
|
||||
dataset_id: UUID,
|
||||
data: Any,
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,4 @@
|
|||
from uuid import UUID, uuid4
|
||||
from uuid import UUID
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.pipelines.models import PipelineRun, PipelineRunStatus
|
||||
|
|
@@ -7,7 +7,9 @@ from typing import Any
|
|||
from cognee.modules.pipelines.utils import generate_pipeline_run_id
|
||||
|
||||
|
||||
async def log_pipeline_run_start(pipeline_id: str, pipeline_name: str, dataset_id: UUID, data: Any):
|
||||
async def log_pipeline_run_start(
|
||||
pipeline_id: UUID, pipeline_name: str, dataset_id: UUID, data: Any
|
||||
):
|
||||
if not data:
|
||||
data_info = "None"
|
||||
elif isinstance(data, list) and all(isinstance(item, Data) for item in data):
|
||||
|
|
|
|||
|
|
@@ -6,9 +6,9 @@ except ModuleNotFoundError:
|
|||
from typing import Any, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.modules.data.models import Dataset
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.pipelines.models import (
|
||||
PipelineRunStarted,
|
||||
PipelineRunCompleted,
|
||||
|
|
@@ -72,7 +72,12 @@ if modal:
|
|||
pipeline_name=pipeline_name,
|
||||
pipeline_id=pipeline_id,
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
context=context,
|
||||
context={
|
||||
**(context or {}),
|
||||
"user": user,
|
||||
"data": data_item,
|
||||
"dataset": dataset,
|
||||
},
|
||||
user=user,
|
||||
incremental_loading=incremental_loading,
|
||||
)
|
||||
|
|
@@ -92,16 +97,19 @@ async def run_tasks_distributed(
|
|||
if not user:
|
||||
user = await get_default_user()
|
||||
|
||||
dataset: Optional[Dataset] = None
|
||||
|
||||
# Get dataset object
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
from cognee.modules.data.models import Dataset
|
||||
|
||||
dataset = await session.get(Dataset, dataset_id)
|
||||
|
||||
pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
|
||||
if not dataset:
|
||||
raise ValueError(f"Dataset ({dataset_id}) not found.")
|
||||
|
||||
pipeline_id: UUID = generate_pipeline_id(user.id, dataset.id, pipeline_name)
|
||||
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
|
||||
pipeline_run_id = pipeline_run.pipeline_run_id
|
||||
pipeline_run_id: UUID = pipeline_run.pipeline_run_id
|
||||
|
||||
yield PipelineRunStarted(
|
||||
pipeline_run_id=pipeline_run_id,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue