cognee/cognee/modules/pipelines/operations/run_tasks.py
hajdul88 faeca138d9
fix: fixes distributed pipeline (#1454)
<!-- .github/pull_request_template.md -->

## Description
This PR fixes the distributed pipeline and updates core parts of the distributed execution
logic.

## Type of Change
<!-- Please check the relevant option -->
- [x] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Changes Made
Fixes the distributed pipeline:
- Changed the spawning logic and added incremental loading to `run_tasks_distributed`
- Added batching to consumer nodes
- Fixed the consumer stopping criteria by adding a stop signal and its handling (see the sketch below)
- Changed the edge embedding solution to avoid heavy network load in multi-container environments
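
For context, here is a minimal sketch of the stop-signal / batching pattern referenced above. The names (`STOP`, `consumer`, `process_batch`, the queue) are illustrative assumptions only and are not the actual implementation in `run_tasks_distributed`:

```python
import asyncio
from typing import Any, List

STOP = object()  # sentinel the producer enqueues once all work has been dispatched


async def process_batch(batch: List[Any]) -> None:
    # Hypothetical stand-in for the real per-batch work a consumer node performs.
    print(f"processed {len(batch)} items")


async def consumer(queue: asyncio.Queue, batch_size: int = 10) -> None:
    """Drain the queue in batches and exit cleanly when the stop signal arrives."""
    batch: List[Any] = []
    while True:
        item = await queue.get()
        if item is STOP:
            break
        batch.append(item)
        if len(batch) >= batch_size:
            await process_batch(batch)
            batch = []
    # Flush any final partial batch before the consumer shuts down.
    if batch:
        await process_batch(batch)
```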

## Testing
Tested by running 1 GB of data on Modal, plus manual verification.

## Screenshots/Videos (if applicable)
None

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues
None

## Additional Notes
None

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Boris <boris@topoteretes.com>
Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
2025-10-09 14:06:25 +02:00

import os
import asyncio
from uuid import UUID
from typing import Any, List
from functools import wraps
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.pipelines.operations.run_tasks_distributed import run_tasks_distributed
from cognee.modules.users.models import User
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
from cognee.tasks.ingestion import resolve_data_directories
from cognee.modules.pipelines.models.PipelineRunInfo import (
    PipelineRunCompleted,
    PipelineRunErrored,
    PipelineRunStarted,
)
from cognee.modules.pipelines.operations import (
    log_pipeline_run_start,
    log_pipeline_run_complete,
    log_pipeline_run_error,
)
from .run_tasks_with_telemetry import run_tasks_with_telemetry
from .run_tasks_data_item import run_tasks_data_item
from ..tasks.task import Task

logger = get_logger("run_tasks(tasks: [Task], data)")
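

# `override_run_tasks` swaps in an alternative implementation of the decorated generator:
# the distributed implementation (run_tasks_distributed) is used when `distributed=True`
# is passed or when the COGNEE_DISTRIBUTED environment variable is set to "true";
# otherwise the original local generator runs.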
def override_run_tasks(new_gen):
    def decorator(original_gen):
        @wraps(original_gen)
        async def wrapper(*args, distributed=None, **kwargs):
            default_distributed_value = os.getenv("COGNEE_DISTRIBUTED", "False").lower() == "true"
            distributed = default_distributed_value if distributed is None else distributed

            if distributed:
                async for run_info in new_gen(*args, **kwargs):
                    yield run_info
            else:
                async for run_info in original_gen(*args, **kwargs):
                    yield run_info

        return wrapper

    return decorator


@override_run_tasks(run_tasks_distributed)
async def run_tasks(
    tasks: List[Task],
    dataset_id: UUID,
    data: List[Any] = None,
    user: User = None,
    pipeline_name: str = "unknown_pipeline",
    context: dict = None,
    incremental_loading: bool = False,
):
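    """
    Run the given tasks over every item of `data` for the dataset identified by `dataset_id`.

    Yields PipelineRunStarted first, then PipelineRunCompleted on success or
    PipelineRunErrored if the run fails.
    """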
    if not user:
        user = await get_default_user()

    # Get Dataset object
    db_engine = get_relational_engine()

    async with db_engine.get_async_session() as session:
        from cognee.modules.data.models import Dataset

        dataset = await session.get(Dataset, dataset_id)

    pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)

    pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
    pipeline_run_id = pipeline_run.pipeline_run_id

    yield PipelineRunStarted(
        pipeline_run_id=pipeline_run_id,
        dataset_id=dataset.id,
        dataset_name=dataset.name,
        payload=data,
    )

    try:
        if not isinstance(data, list):
            data = [data]

        if incremental_loading:
            data = await resolve_data_directories(data)

        # Create async tasks per data item that will run the pipeline for the data item
        data_item_tasks = [
            asyncio.create_task(
                run_tasks_data_item(
                    data_item,
                    dataset,
                    tasks,
                    pipeline_name,
                    pipeline_id,
                    pipeline_run_id,
                    context,
                    user,
                    incremental_loading,
                )
            )
            for data_item in data
        ]

        results = await asyncio.gather(*data_item_tasks)

        # Remove skipped data items from results
        results = [result for result in results if result]

        # If any data item could not be processed propagate error
        errored_results = [
            result for result in results if isinstance(result["run_info"], PipelineRunErrored)
        ]

        if errored_results:
            raise PipelineRunFailedError(
                message="Pipeline run failed. Data item could not be processed."
            )

        await log_pipeline_run_complete(
            pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
        )
        yield PipelineRunCompleted(
            pipeline_run_id=pipeline_run_id,
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            data_ingestion_info=results,
        )
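
        # Engines that expose push_to_s3 (S3-backed storage) sync their local state
        # back to S3 after a successful run, e.g. in multi-container distributed deployments.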
        graph_engine = await get_graph_engine()
        if hasattr(graph_engine, "push_to_s3"):
            await graph_engine.push_to_s3()

        relational_engine = get_relational_engine()
        if hasattr(relational_engine, "push_to_s3"):
            await relational_engine.push_to_s3()
    except Exception as error:
        await log_pipeline_run_error(
            pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
        )

        yield PipelineRunErrored(
            pipeline_run_id=pipeline_run_id,
            payload=repr(error),
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            data_ingestion_info=locals().get(
                "results"
            ),  # Returns results if they exist or returns None
        )

        # In case of error during incremental loading of data just let the user know the pipeline Errored, don't raise error
        if not isinstance(error, PipelineRunFailedError):
            raise error
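

# Illustrative usage sketch (not part of this module): a caller iterates the async
# generator and reacts to the yielded run-info objects. The dataset id, payload, and
# pipeline name below are placeholders, not values used elsewhere in the codebase.
#
#     async def example(tasks: List[Task], dataset_id: UUID):
#         async for run_info in run_tasks(
#             tasks, dataset_id, data=["document.txt"], pipeline_name="my_pipeline"
#         ):
#             if isinstance(run_info, PipelineRunErrored):
#                 logger.error("Pipeline errored: %s", run_info.payload)
#             elif isinstance(run_info, PipelineRunCompleted):
#                 logger.info("Pipeline completed: %s", run_info.pipeline_run_id)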