cognee/cognee/modules/pipelines/operations/run_tasks_distributed.py
Igor Ilic 59594e01ac
fix: add missing await for getting default user (#1131)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-07-23 06:47:38 +02:00

94 lines
2.8 KiB
Python

# `modal` is an optional dependency. When it is not installed, the name is bound
# to None so this module can still be imported; the `run_tasks_on_modal`
# definition below is guarded on it.
try:
    import modal
except ModuleNotFoundError:
    modal = None
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.pipelines.models import (
PipelineRunStarted,
PipelineRunYield,
PipelineRunCompleted,
)
from cognee.modules.pipelines.operations import log_pipeline_run_start, log_pipeline_run_complete
from cognee.modules.pipelines.utils.generate_pipeline_id import generate_pipeline_id
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from .run_tasks_with_telemetry import run_tasks_with_telemetry
logger = get_logger("run_tasks_distributed()")


# Modal-specific objects (app, image, secrets) are only defined when the optional
# `modal` package imported successfully; otherwise `run_tasks_on_modal` is absent.
if modal is not None:
    from distributed.app import app
    from distributed.modal_image import image

    @app.function(
        retries=3,
        image=image,
        timeout=86400,  # Modal timeouts are in seconds: 24 hours per call.
        max_containers=50,
        secrets=[modal.Secret.from_name("distributed_cognee")],
    )
    async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):
        """Execute the task pipeline for a single data item inside a Modal container.

        Fully consumes the async generator produced by
        ``run_tasks_with_telemetry`` and returns only the last value it yielded,
        or ``None`` when it yielded nothing.
        """
        last_yielded = None
        # The loop target is rebound on every step, so after the loop it holds
        # the final yielded value (or stays None for an empty generator).
        async for last_yielded in run_tasks_with_telemetry(
            tasks, data_item, user, pipeline_name, context
        ):
            pass
        return last_yielded
def _build_modal_map_arguments(tasks, data, user, pipeline_name, context):
    """Build the per-parameter input lists passed to ``run_tasks_on_modal.map``.

    Each returned list supplies one positional parameter of
    ``run_tasks_on_modal``, and all lists have the same length, so the i-th
    entries together form one call.

    A list ``data`` is split into one single-element list per item, giving one
    call per item; an empty list produces no calls. Any other ``data`` value is
    sent unchanged as the single data item of one call.
    """
    if isinstance(data, list):
        data_items = [[data_item] for data_item in data]
    else:
        data_items = [data]

    call_count = len(data_items)
    return [
        [tasks] * call_count,
        data_items,
        [user] * call_count,
        [pipeline_name] * call_count,
        [context] * call_count,
    ]


async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, context):
    """Run a task pipeline over ``data`` by fanning the work out to Modal containers.

    Args:
        tasks: Tasks forwarded unchanged to every ``run_tasks_on_modal`` call.
        dataset_id: Id of the ``Dataset`` row the pipeline run is recorded against.
        data: A list (one Modal call per element) or a single item (one call).
        user: Owner of the run; when falsy, the default user is used.
        pipeline_name: Name used for the pipeline id and run logging.
        context: Forwarded unchanged to every ``run_tasks_on_modal`` call.

    Yields:
        ``PipelineRunStarted`` once, then one ``PipelineRunYield`` per result
        received from Modal, then ``PipelineRunCompleted``.

    Raises:
        ModuleNotFoundError: If the optional ``modal`` package is not installed.
        ValueError: If no dataset with ``dataset_id`` exists.
    """
    # Fail before touching the database or logging a run start: without modal,
    # `run_tasks_on_modal` is never defined and the run could not proceed.
    if modal is None:
        raise ModuleNotFoundError(
            "run_tasks_distributed requires the 'modal' package, which is not installed."
        )

    if not user:
        user = await get_default_user()

    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        # Function-level import, kept as in the original (presumably to avoid an
        # import cycle at module load time — confirm).
        from cognee.modules.data.models import Dataset

        dataset = await session.get(Dataset, dataset_id)

    if dataset is None:
        raise ValueError(f"Dataset with id {dataset_id} not found.")

    pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
    pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
    pipeline_run_id = pipeline_run.pipeline_run_id

    yield PipelineRunStarted(
        pipeline_run_id=pipeline_run_id,
        dataset_id=dataset.id,
        dataset_name=dataset.name,
        payload=data,
    )

    arguments = _build_modal_map_arguments(tasks, data, user, pipeline_name, context)

    # NOTE(review): if a Modal call raises, the exception propagates out of this
    # generator and the run is never marked complete — consider recording an
    # error state for the pipeline run.
    async for result in run_tasks_on_modal.map.aio(*arguments):
        logger.info(f"Received result: {result}")

        yield PipelineRunYield(
            pipeline_run_id=pipeline_run_id,
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            payload=result,
        )

    await log_pipeline_run_complete(pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data)

    yield PipelineRunCompleted(
        pipeline_run_id=pipeline_run_id,
        dataset_id=dataset.id,
        dataset_name=dataset.name,
    )