<!-- .github/pull_request_template.md -->

## Description

Resolve missing argument for distributed pipeline runs.

## Type of Change
<!-- Please check the relevant option -->

- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->

- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the issue/feature**
- [ ] My code follows the project's coding standards and style guidelines
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
```python
try:
    import modal
except ModuleNotFoundError:
    modal = None

from typing import Any, List, Optional
from uuid import UUID

from cognee.modules.pipelines.tasks.task import Task
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.pipelines.models import (
    PipelineRunStarted,
    PipelineRunCompleted,
    PipelineRunErrored,
)
from cognee.modules.pipelines.operations import (
    log_pipeline_run_start,
    log_pipeline_run_complete,
    log_pipeline_run_error,
)
from cognee.modules.pipelines.utils import generate_pipeline_id
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.pipelines.exceptions import PipelineRunFailedError
from cognee.tasks.ingestion import resolve_data_directories
from .run_tasks_data_item import run_tasks_data_item

logger = get_logger("run_tasks_distributed()")

if modal:
    import os

    from distributed.app import app
    from distributed.modal_image import image

    secret_name = os.environ.get("MODAL_SECRET_NAME", "distributed_cognee")

    @app.function(
        retries=3,
        image=image,
        timeout=86400,
        max_containers=50,
        secrets=[modal.Secret.from_name(secret_name)],
    )
    async def run_tasks_on_modal(
        data_item,
        dataset_id: UUID,
        tasks: List[Task],
        pipeline_name: str,
        pipeline_id: str,
        pipeline_run_id: str,
        context: Optional[dict],
        user: User,
        incremental_loading: bool,
    ):
        """
        Wrapper that runs the run_tasks_data_item function.

        This is the code that runs on the Modal executor and produces the
        graph/vector database objects.
        """
        from cognee.infrastructure.databases.relational import get_relational_engine

        # Resolve the Dataset object from its id; the session is only
        # needed for this lookup.
        async with get_relational_engine().get_async_session() as session:
            from cognee.modules.data.models import Dataset

            dataset = await session.get(Dataset, dataset_id)

        result = await run_tasks_data_item(
            data_item=data_item,
            dataset=dataset,
            tasks=tasks,
            pipeline_name=pipeline_name,
            pipeline_id=pipeline_id,
            pipeline_run_id=pipeline_run_id,
            context=context,
            user=user,
            incremental_loading=incremental_loading,
        )

        return result
```
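The driver below fans each data item out to `run_tasks_on_modal` via Modal's `map`, which takes one positional argument list per function parameter and pairs them up element-wise, like Python's built-in `map`. That is why every scalar argument is replicated `number_of_data_items` times in `data_item_tasks`. A minimal standalone sketch of that pairing, using plain `zip` on placeholder values (no Modal required; all names here are illustrative, not part of this PR):

```python
# Illustration of map-style fan-out: each call receives one element from
# every argument list, so scalars are replicated to match the data items.
data = ["a.txt", "b.txt", "c.txt"]
n = len(data)

arg_lists = [
    data,                 # varies per call
    ["dataset-123"] * n,  # replicated scalar
    ["my_pipeline"] * n,  # replicated scalar
]

# zip(*arg_lists) mimics how arguments are paired up for each remote call.
for call_args in zip(*arg_lists):
    print(call_args)
# ('a.txt', 'dataset-123', 'my_pipeline')
# ('b.txt', 'dataset-123', 'my_pipeline')
# ('c.txt', 'dataset-123', 'my_pipeline')
```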
```python
async def run_tasks_distributed(
    tasks: List[Task],
    dataset_id: UUID,
    data: List[Any] = None,
    user: User = None,
    pipeline_name: str = "unknown_pipeline",
    context: dict = None,
    incremental_loading: bool = False,
    data_per_batch: int = 20,
):
    if not user:
        user = await get_default_user()

    # Get dataset object
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        from cognee.modules.data.models import Dataset

        dataset = await session.get(Dataset, dataset_id)

    pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)
    pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
    pipeline_run_id = pipeline_run.pipeline_run_id

    yield PipelineRunStarted(
        pipeline_run_id=pipeline_run_id,
        dataset_id=dataset.id,
        dataset_name=dataset.name,
        payload=data,
    )

    try:
        if not isinstance(data, list):
            data = [data]

        data = await resolve_data_directories(data)

        number_of_data_items = len(data) if isinstance(data, list) else 1

        # One positional argument list per run_tasks_on_modal parameter;
        # scalar arguments are replicated once per data item.
        data_item_tasks = [
            data,
            [dataset.id] * number_of_data_items,
            [tasks] * number_of_data_items,
            [pipeline_name] * number_of_data_items,
            [pipeline_id] * number_of_data_items,
            [pipeline_run_id] * number_of_data_items,
            [context] * number_of_data_items,
            [user] * number_of_data_items,
            [incremental_loading] * number_of_data_items,
        ]

        results = []
        async for result in run_tasks_on_modal.map.aio(*data_item_tasks):
            if not result:
                continue
            results.append(result)

        # Remove skipped results
        results = [r for r in results if r]

        # If any data item failed, raise PipelineRunFailedError
        errored = [
            r
            for r in results
            if r and r.get("run_info") and isinstance(r["run_info"], PipelineRunErrored)
        ]
        if errored:
            raise PipelineRunFailedError("Pipeline run failed. Data item could not be processed.")

        await log_pipeline_run_complete(
            pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data
        )

        yield PipelineRunCompleted(
            pipeline_run_id=pipeline_run_id,
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            data_ingestion_info=results,
        )

    except Exception as error:
        await log_pipeline_run_error(
            pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data, error
        )

        yield PipelineRunErrored(
            pipeline_run_id=pipeline_run_id,
            payload=repr(error),
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            data_ingestion_info=locals().get("results"),
        )

        if not isinstance(error, PipelineRunFailedError):
            raise
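```

`run_tasks_distributed` is an async generator: it yields a `PipelineRunStarted` event up front, then either `PipelineRunCompleted` or `PipelineRunErrored`. A minimal consumption sketch under stated assumptions: the `echo` task body and the dataset id are placeholders, the dataset is assumed to already exist in the relational store, and `Task` is assumed to wrap a callable as elsewhere in cognee. This is an illustration, not part of the PR:

```python
import asyncio
from uuid import UUID

# run_tasks_distributed, Task, and PipelineRunErrored are assumed importable
# from the modules shown above; adjust import paths to your installation.


async def echo(data):
    # Trivial task body used only for this sketch.
    return data


async def main():
    tasks = [Task(echo)]  # placeholder task list
    dataset_id = UUID("00000000-0000-0000-0000-000000000000")  # placeholder id

    async for run_info in run_tasks_distributed(
        tasks=tasks,
        dataset_id=dataset_id,
        data=["./data/example.txt"],
        pipeline_name="example_pipeline",
    ):
        # Events arrive in order: Started, then Completed or Errored.
        print(type(run_info).__name__, run_info.pipeline_run_id)


asyncio.run(main())
```

Note the error semantics in the `except` block: the generator yields a `PipelineRunErrored` event and then re-raises the underlying exception, except for `PipelineRunFailedError`, which is swallowed once the event has been yielded. Callers should therefore handle both the yielded event and a possible exception.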