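"""Run cognee pipeline tasks distributed across Modal containers.

Falls back gracefully at import time when the optional ``modal`` dependency
is not installed.
"""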
# Modal is an optional dependency: the module still imports cleanly without
# it, and only the distributed code path becomes unavailable.
try:
    import modal
except ModuleNotFoundError:
    modal = None

from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.pipelines.models import (
    PipelineRunStarted,
    PipelineRunYield,
    PipelineRunCompleted,
)
from cognee.modules.pipelines.operations import log_pipeline_run_start, log_pipeline_run_complete
from cognee.modules.pipelines.utils.generate_pipeline_id import generate_pipeline_id
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger

from .run_tasks_with_telemetry import run_tasks_with_telemetry
logger = get_logger("run_tasks_distributed()")


# run_tasks_on_modal is only defined when modal is importable; the
# distributed runner below therefore requires the Modal extra at runtime.
if modal:
    from distributed.app import app
    from distributed.modal_image import image

    # Remote entry point: each call runs one data item's pipeline inside a
    # Modal container (86400 s = 24 h timeout, at most 50 containers, with
    # credentials injected from the "distributed_cognee" secret).
    @app.function(
        image=image,
        timeout=86400,
        max_containers=50,
        secrets=[modal.Secret.from_name("distributed_cognee")],
    )
    async def run_tasks_on_modal(tasks, data_item, user, pipeline_name, context):
        pipeline_run = run_tasks_with_telemetry(tasks, data_item, user, pipeline_name, context)

        # Drain the async generator, keeping only the final event; that is
        # the value shipped back to the coordinating process.
        run_info = None

        async for pipeline_run_info in pipeline_run:
            run_info = pipeline_run_info

        return run_info
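# Usage sketch (hypothetical caller; `tasks`, `data_items`, `user`, and
# `context` are assumed to be supplied by the surrounding pipeline code):
#
#     async for event in run_tasks_distributed(
#         tasks, dataset_id, data_items, user, "cognee_pipeline", context
#     ):
#         if isinstance(event, PipelineRunCompleted):
#             logger.info(f"Pipeline run {event.pipeline_run_id} finished")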
async def run_tasks_distributed(tasks, dataset_id, data, user, pipeline_name, context):
    if not user:
        # get_default_user is a coroutine; without the await, user.id below
        # would fail on the returned coroutine object.
        user = await get_default_user()

    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        # Imported locally, presumably to avoid a circular import at module load.
        from cognee.modules.data.models import Dataset

        dataset = await session.get(Dataset, dataset_id)

    pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name)

    pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)

    pipeline_run_id = pipeline_run.pipeline_run_id
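    # The generator reports progress through three event types:
    # PipelineRunStarted once up front, PipelineRunYield once per remote
    # result, and PipelineRunCompleted at the end.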
    yield PipelineRunStarted(
        pipeline_run_id=pipeline_run_id,
        dataset_id=dataset.id,
        dataset_name=dataset.name,
        payload=data,
    )

    data_count = len(data) if isinstance(data, list) else 1
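    # Build parallel argument lists for map(): each position is one argument
    # of run_tasks_on_modal, repeated per data item, so container i receives
    # (tasks, [data[i]], user, pipeline_name, context).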
    arguments = [
        [tasks] * data_count,
        [[data_item] for data_item in data[:data_count]] if data_count > 1 else [data],
        [user] * data_count,
        [pipeline_name] * data_count,
        [context] * data_count,
    ]

    # Fan out across Modal containers and stream results back as the
    # containers finish their single-item pipelines.
    async for result in run_tasks_on_modal.map.aio(*arguments):
        logger.info(f"Received result: {result}")

        yield PipelineRunYield(
            pipeline_run_id=pipeline_run_id,
            dataset_id=dataset.id,
            dataset_name=dataset.name,
            payload=result,
        )
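    # NOTE (editorial): an exception raised inside a container propagates out
    # of map.aio() and ends the run; passing return_exceptions=True to map
    # would collect partial results instead.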
    # producer_futures = []

    # for data_item in data[:5]:
    #     producer_future = run_tasks_distributed(
    #         run_tasks_with_telemetry, tasks, [data_item], user, pipeline_name, context
    #     )
    #     producer_futures.append(producer_future)

    # batch_results = []
    # for producer_future in producer_futures:
    #     try:
    #         result = producer_future.get()
    #     except Exception as e:
    #         result = e
    #     batch_results.append(result)

    #     yield PipelineRunYield(
    #         pipeline_run_id=pipeline_run_id,
    #         dataset_id=dataset.id,
    #         dataset_name=dataset.name,
    #         payload=result,
    #     )
    await log_pipeline_run_complete(pipeline_run_id, pipeline_id, pipeline_name, dataset_id, data)

    yield PipelineRunCompleted(
        pipeline_run_id=pipeline_run_id,
        dataset_id=dataset.id,
        dataset_name=dataset.name,
    )