cognee/cognee/modules/pipelines/operations/run_tasks_distributed.py
Igor Ilic 3e54b67b4d
fix: Resolve missing argument for distributed (#1563)
<!-- .github/pull_request_template.md -->

## Description
Resolve missing argument for distributed

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
2025-10-20 15:03:35 +02:00

177 lines
5.5 KiB
Python

try:
import modal
except ModuleNotFoundError:
modal = None
from typing import Any, List, Optional
from uuid import UUID
from cognee.modules.pipelines.tasks.task import Task
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.pipelines.models import (
PipelineRunStarted,
PipelineRunCompleted,
PipelineRunErrored,
)
from cognee.modules.pipelines.operations import (
log_pipeline_run_start,
log_pipeline_run_complete,
log_pipeline_run_error,
)
from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
from cognee.tasks.ingestion import resolve_data_directories
from .run_tasks_data_item import run_tasks_data_item
logger = get_logger("run_tasks_distributed()")
if modal:
    import os

    from distributed.app import app
    from distributed.modal_image import image

    # Name of the Modal secret holding the Cognee environment configuration.
    secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")

    @app.function(
        retries=3,
        image=image,
        timeout=86400,
        max_containers=50,
        secrets=[modal.Secret.from_name(secret_name)],
    )
    async def run_tasks_on_modal(
        data_item,
        dataset_id: UUID,
        tasks: List[Task],
        pipeline_name: str,
        pipeline_id: str,
        pipeline_run_id: str,
        context: Optional[dict],
        user: User,
        incremental_loading: bool,
    ):
        """
        Wrapper that runs the run_tasks_data_item function.
        This is the function/code that runs on modal executor and produces the graph/vector db objects
        """
        # Imported lazily so the Modal container resolves it against its own environment.
        from cognee.infrastructure.databases.relational import get_relational_engine

        # Re-fetch the Dataset row by id inside the executor — ORM objects are not
        # sent across the Modal boundary, only the UUID is.
        async with get_relational_engine().get_async_session() as session:
            from cognee.modules.data.models import Dataset

            dataset = await session.get(Dataset, dataset_id)

        return await run_tasks_data_item(
            data_item=data_item,
            dataset=dataset,
            tasks=tasks,
            pipeline_name=pipeline_name,
            pipeline_id=pipeline_id,
            pipeline_run_id=pipeline_run_id,
            context=context,
            user=user,
            incremental_loading=incremental_loading,
        )
async def run_tasks_distributed(
    tasks: List[Task],
    dataset_id: UUID,
    data: Optional[List[Any]] = None,
    user: Optional[User] = None,
    pipeline_name: str = "unknown_pipeline",
    context: Optional[dict] = None,
    incremental_loading: bool = False,
    data_per_batch: int = 20,
):
    """
    Run a pipeline over a dataset by fanning each data item out to a Modal executor.

    Yields pipeline lifecycle events in order: PipelineRunStarted, then either
    PipelineRunCompleted (all items processed) or PipelineRunErrored (any failure).

    Parameters:
        tasks: Ordered pipeline tasks applied to every data item.
        dataset_id: Id of the Dataset row to process.
        data: Data items to ingest; a non-list value is wrapped into a single-item list.
        user: Acting user; defaults to the system default user when not given.
        pipeline_name: Human-readable pipeline name used for run logging.
        context: Optional extra context forwarded to every data-item run.
        incremental_loading: Forwarded to each data-item run.
        data_per_batch: Kept for interface compatibility; not used in the fan-out below.

    Raises:
        PipelineRunFailedError: When any data item's run_info is a PipelineRunErrored
            (yielded as PipelineRunErrored first, then not re-raised past the yield).
        Exception: Any other failure is re-raised after the PipelineRunErrored yield.
    """
    if not user:
        user = await get_default_user()

    # Get dataset object
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        from cognee.modules.data.models import Dataset

        dataset = await session.get(Dataset, dataset_id)

    pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
    pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
    pipeline_run_id = pipeline_run.pipeline_run_id

    yield PipelineRunStarted(
        pipeline_run_id=pipeline_run_id,
        dataset_id=dataset.id,
        dataset_name=dataset.name,
        payload=data,
    )

    # Initialized before the try so the error path below always has a defined
    # value (previously recovered via locals().get("results"), which is fragile).
    results = []

    try:
        if not isinstance(data, list):
            data = [data]

        data = await resolve_data_directories(data)
        number_of_data_items = len(data)

        # One positional-argument column per run_tasks_on_modal parameter;
        # scalar arguments are repeated once per data item for .map.
        data_item_tasks = [
            data,
            [dataset.id] * number_of_data_items,
            [tasks] * number_of_data_items,
            [pipeline_name] * number_of_data_items,
            [pipeline_id] * number_of_data_items,
            [pipeline_run_id] * number_of_data_items,
            [context] * number_of_data_items,
            [user] * number_of_data_items,
            [incremental_loading] * number_of_data_items,
        ]

        # Skipped items come back falsy and are dropped here, so no second
        # filtering pass is needed afterwards.
        async for result in run_tasks_on_modal.map.aio(*data_item_tasks):
            if result:
                results.append(result)

        # If any data item failed, raise PipelineRunFailedError
        errored = [
            r for r in results if isinstance(r.get("run_info"), PipelineRunErrored)
        ]
        if errored:
            raise PipelineRunFailedError("Pipeline run failed. Data item could not be processed.")

        await log_pipeline_run_complete(
            pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
        )

        yield PipelineRunCompleted(
            pipeline_run_id=pipeline_run_id,
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            data_ingestion_info=results,
        )
    except Exception as error:
        await log_pipeline_run_error(
            pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
        )

        yield PipelineRunErrored(
            pipeline_run_id=pipeline_run_id,
            payload=repr(error),
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            data_ingestion_info=results,
        )

        # A per-item failure was already surfaced via the PipelineRunErrored
        # event; only unexpected errors propagate to the caller.
        if not isinstance(error, PipelineRunFailedError):
            raise