feature: adds pipeline execution layer to cognify (#1291)
<!-- .github/pull_request_template.md --> ## Description feature: adds pipeline execution layer to cognify ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
6f230c5a38
commit
d91b0f6aa3
2 changed files with 135 additions and 104 deletions
|
|
@@ -11,8 +11,6 @@ from cognee.modules.pipelines import cognee_pipeline
|
|||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.chunking.TextChunker import TextChunker
|
||||
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import PipelineRunCompleted, PipelineRunErrored
|
||||
from cognee.modules.pipelines.queues.pipeline_run_info_queues import push_to_queue
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
from cognee.tasks.documents import (
|
||||
|
|
@@ -23,6 +21,7 @@ from cognee.tasks.documents import (
|
|||
from cognee.tasks.graph import extract_graph_from_data
|
||||
from cognee.tasks.storage import add_data_points
|
||||
from cognee.tasks.summarization import summarize_text
|
||||
from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
|
||||
|
||||
logger = get_logger("cognify")
|
||||
|
||||
|
|
@@ -180,113 +179,18 @@ async def cognify(
|
|||
"""
|
||||
tasks = await get_default_tasks(user, graph_model, chunker, chunk_size, ontology_file_path)
|
||||
|
||||
if run_in_background:
|
||||
return await run_cognify_as_background_process(
|
||||
tasks=tasks,
|
||||
user=user,
|
||||
datasets=datasets,
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
)
|
||||
else:
|
||||
return await run_cognify_blocking(
|
||||
tasks=tasks,
|
||||
user=user,
|
||||
datasets=datasets,
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
)
|
||||
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
|
||||
|
||||
|
||||
async def run_cognify_blocking(
|
||||
tasks,
|
||||
user,
|
||||
datasets,
|
||||
graph_db_config: dict = None,
|
||||
vector_db_config: dict = False,
|
||||
incremental_loading: bool = True,
|
||||
):
|
||||
total_run_info = {}
|
||||
|
||||
async for run_info in cognee_pipeline(
|
||||
return await pipeline_executor_func(
|
||||
pipeline=cognee_pipeline,
|
||||
tasks=tasks,
|
||||
datasets=datasets,
|
||||
user=user,
|
||||
pipeline_name="cognify_pipeline",
|
||||
graph_db_config=graph_db_config,
|
||||
datasets=datasets,
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=incremental_loading,
|
||||
):
|
||||
if run_info.dataset_id:
|
||||
total_run_info[run_info.dataset_id] = run_info
|
||||
else:
|
||||
total_run_info = run_info
|
||||
|
||||
return total_run_info
|
||||
|
||||
|
||||
async def run_cognify_as_background_process(
    tasks,
    user,
    datasets,
    graph_db_config: dict = None,
    vector_db_config: dict = None,  # was `False` — bool default contradicted the dict annotation; None is equally falsy
    incremental_loading: bool = True,
):
    """
    Start one cognify pipeline per dataset and return immediately.

    Returns the initial "started" run info for each dataset while the
    pipelines keep executing in a background asyncio task, pushing their
    progress updates onto the run-info queue.

    Args:
        tasks: Task list to execute in each pipeline.
        user: User on whose behalf the pipelines run.
        datasets: A dataset ID (str) or a list of dataset IDs.
        graph_db_config (dict): Optional graph database configuration.
        vector_db_config (dict): Optional vector database configuration.
        incremental_loading (bool): Forwarded to the pipeline; defaults to True.

    Returns:
        dict: Mapping of dataset_id -> initial run_info (payload stripped
        so the result stays serializable).
    """
    # Convert dataset to list if it's a string
    if isinstance(datasets, str):
        datasets = [datasets]

    # Store pipeline status for all pipelines
    pipeline_run_started_info = {}

    async def handle_rest_of_the_run(pipeline_list):
        # Execute all provided pipelines one by one to avoid database write conflicts
        # TODO: Convert to async gather task instead of for loop when Queue mechanism for database is created
        for pipeline_run in pipeline_list:
            while True:
                try:
                    pipeline_run_info = await anext(pipeline_run)

                    # Publish progress so observers can follow the run.
                    push_to_queue(pipeline_run_info.pipeline_run_id, pipeline_run_info)

                    if isinstance(pipeline_run_info, (PipelineRunCompleted, PipelineRunErrored)):
                        break
                except StopAsyncIteration:
                    break

    # Start all pipelines to get started status
    pipeline_list = []
    for dataset in datasets:
        pipeline_run = cognee_pipeline(
            tasks=tasks,
            user=user,
            datasets=dataset,
            pipeline_name="cognify_pipeline",
            graph_db_config=graph_db_config,
            vector_db_config=vector_db_config,
            incremental_loading=incremental_loading,
        )

        # Save dataset Pipeline run started info
        run_info = await anext(pipeline_run)
        pipeline_run_started_info[run_info.dataset_id] = run_info

        if pipeline_run_started_info[run_info.dataset_id].payload:
            # Remove payload info to avoid serialization
            # TODO: Handle payload serialization
            pipeline_run_started_info[run_info.dataset_id].payload = []

        pipeline_list.append(pipeline_run)

    # Send all started pipelines to execute one by one in background
    asyncio.create_task(handle_rest_of_the_run(pipeline_list=pipeline_list))

    return pipeline_run_started_info
|
||||
|
||||
|
||||
async def get_default_tasks( # TODO: Find out a better way to do this (Boris's comment)
|
||||
|
|
|
|||
127
cognee/modules/pipelines/layers/pipeline_execution_mode.py
Normal file
127
cognee/modules/pipelines/layers/pipeline_execution_mode.py
Normal file
|
|
@@ -0,0 +1,127 @@
|
|||
import asyncio
|
||||
from typing import Any, AsyncIterable, AsyncGenerator, Callable, Dict, Union, Awaitable
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import PipelineRunCompleted, PipelineRunErrored
|
||||
from cognee.modules.pipelines.queues.pipeline_run_info_queues import push_to_queue
|
||||
|
||||
# A ready-made async stream of run info...
_Stream = Union[AsyncIterable[Any], AsyncGenerator[Any, None]]

# ...or anything we can drive as a pipeline: the stream itself, or a factory
# (callable) that produces one when invoked with keyword parameters.
# Union flattens nested unions, so this is identical to listing all four members.
AsyncGenLike = Union[
    _Stream,
    Callable[..., AsyncIterable[Any]],
    Callable[..., AsyncGenerator[Any, None]],
]
|
||||
|
||||
|
||||
async def run_pipeline_blocking(pipeline: AsyncGenLike, **params) -> Dict[str, Any]:
    """
    Run a pipeline to completion, blocking until every result is consumed.

    Drains the pipeline (an async generator/iterable, or a callable that
    builds one from ``params``) and keeps the last run information seen
    for each dataset.

    Args:
        pipeline (AsyncGenLike): Async generator/iterable, or a callable producing one.
        **params: Keyword arguments forwarded to ``pipeline`` when it is callable.

    Returns:
        Dict[str, Any]:
            - Mapping of dataset_id -> last run_info for each dataset processed.
            - When a run_info carries no dataset_id, that run_info itself
              replaces the mapping and is returned.
    """
    # A callable is treated as a factory; otherwise it is already a stream.
    stream = pipeline if not callable(pipeline) else pipeline(**params)

    collected: Dict[str, Any] = {}

    async for info in stream:
        key = getattr(info, "dataset_id", None)
        if not key:
            # No dataset attribution: the bare run_info becomes the result.
            collected = info
        else:
            collected[key] = info

    return collected
|
||||
|
||||
|
||||
async def run_pipeline_as_background_process(
    pipeline: AsyncGenLike,
    **params,
) -> Dict[str, Any]:
    """
    Execute one or more pipelines as background tasks.

    This function:
    1. Starts one pipeline per dataset (if multiple datasets are provided).
    2. Returns the initial "started" run information immediately.
    3. Continues executing the pipelines in the background,
       pushing run updates to a queue until each completes.

    Args:
        pipeline (AsyncGenLike): The pipeline generator or callable producing async run information.
        **params: Arbitrary keyword arguments to be passed to the pipeline if it is callable.
            Expected to include "datasets", which may be a single dataset ID (str)
            or a list of dataset IDs.

    Returns:
        Dict[str, Any]: A mapping of dataset_id -> initial run_info (with payload
        removed for serialization). Empty dict when no datasets are supplied.
    """
    datasets = params.get("datasets", None)

    # Tolerate a missing/None "datasets" instead of crashing with an opaque
    # TypeError when iterating None below.
    if datasets is None:
        datasets = []

    if isinstance(datasets, str):
        datasets = [datasets]

    pipeline_run_started_info = {}

    async def handle_rest_of_the_run(pipeline_list):
        # Execute all provided pipelines one by one to avoid database write conflicts
        # TODO: Convert to async gather task instead of for loop when Queue mechanism for database is created
        # NOTE: loop variable renamed from `pipeline` — it shadowed the outer parameter.
        for pipeline_run in pipeline_list:
            while True:
                try:
                    pipeline_run_info = await anext(pipeline_run)

                    # Publish progress so observers can follow the run.
                    push_to_queue(pipeline_run_info.pipeline_run_id, pipeline_run_info)

                    if isinstance(pipeline_run_info, (PipelineRunCompleted, PipelineRunErrored)):
                        break
                except StopAsyncIteration:
                    break

    # Start all pipelines to get started status
    pipeline_list = []
    for dataset in datasets:
        call_params = dict(params)
        if "datasets" in call_params:
            call_params["datasets"] = dataset

        pipeline_run = pipeline(**call_params) if callable(pipeline) else pipeline

        # Save dataset Pipeline run started info
        run_info = await anext(pipeline_run)
        pipeline_run_started_info[run_info.dataset_id] = run_info

        if pipeline_run_started_info[run_info.dataset_id].payload:
            # Remove payload info to avoid serialization
            # TODO: Handle payload serialization
            pipeline_run_started_info[run_info.dataset_id].payload = []

        pipeline_list.append(pipeline_run)

    # Send all started pipelines to execute one by one in background
    asyncio.create_task(handle_rest_of_the_run(pipeline_list=pipeline_list))

    return pipeline_run_started_info
|
||||
|
||||
|
||||
def get_pipeline_executor(
    run_in_background: bool = False,
) -> Callable[..., Awaitable[Dict[str, Any]]]:
    """
    Return the appropriate pipeline runner.

    Args:
        run_in_background (bool): When True, return the background executor
            that yields control right after the pipelines have started;
            otherwise return the blocking executor that drains the
            pipelines to completion.

    Usage:
        run_fn = get_pipeline_executor(run_in_background=True)
        result = await run_fn(pipeline, **params)
    """
    # Docstring example previously referenced a nonexistent `get_run_pipeline_fn`.
    return run_pipeline_as_background_process if run_in_background else run_pipeline_blocking
|
||||
Loading…
Add table
Reference in a new issue