feature: tighten run_tasks_base (#730)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->
- Extracted run_tasks_base function into a new file run_tasks_base.py.
- Extracted four executors that execute the core logic based on the task type.
- Extracted a task handler/wrapper that safely executes the core logic
with logging and telemetry.
- Fixed an inconsistency in the handling of batches of size 1.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
lxobr 2025-04-16 09:19:03 +02:00 committed by GitHub
parent 22b363b297
commit d1eab97102
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 195 additions and 256 deletions

View file

@ -9,7 +9,7 @@ from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.base_config import get_base_config
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph, MonitoringTool
from cognee.shared.utils import render_graph

View file

@ -13,7 +13,7 @@ from cognee.modules.data.models import Data, Dataset
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.data_models import KnowledgeGraph

View file

@ -4,7 +4,7 @@ from typing import Optional, Tuple, List, Dict, Union, Any, Callable, Awaitable
from cognee.eval_framework.benchmark_adapters.benchmark_adapters import BenchmarkAdapter
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.tasks.task import Task
logger = get_logger(level=ERROR)

View file

@ -1,7 +1,7 @@
from enum import Enum
from typing import Callable, Awaitable, List
from cognee.api.v1.cognify.cognify import get_default_tasks
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.tasks.task import Task
from cognee.eval_framework.corpus_builder.task_getters.get_cascade_graph_tasks import (
get_cascade_graph_tasks,
)

View file

@ -2,7 +2,7 @@ from typing import List
from pydantic import BaseModel
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.data_models import KnowledgeGraph

View file

@ -1,6 +1,6 @@
from typing import List
from cognee.api.v1.cognify.cognify import get_default_tasks
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
@ -39,9 +39,10 @@ async def get_no_summary_tasks(
extract_graph_from_data,
graph_model=graph_model,
ontology_adapter=ontology_adapter,
task_config={"batch_size": 10},
)
add_data_points_task = Task(add_data_points)
add_data_points_task = Task(add_data_points, task_config={"batch_size": 10})
return base_tasks + [graph_task, add_data_points_task]
@ -53,6 +54,6 @@ async def get_just_chunks_tasks(
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
add_data_points_task = Task(add_data_points)
add_data_points_task = Task(add_data_points, task_config={"batch_size": 10})
return base_tasks + [add_data_points_task]

View file

@ -1,3 +1,3 @@
from .tasks.Task import Task
from .tasks.task import Task
from .operations.run_tasks import run_tasks
from .operations.run_parallel import run_tasks_parallel

View file

@ -1,6 +1,6 @@
from typing import Any, Callable, Generator, List
import asyncio
from ..tasks.Task import Task
from ..tasks.task import Task
def run_tasks_parallel(tasks: List[Task]) -> Callable[[Any], Generator[Any, Any, Any]]:

View file

@ -1,4 +1,3 @@
import inspect
import json
from cognee.shared.logging_utils import get_logger
from uuid import UUID, uuid4
@ -15,212 +14,12 @@ from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from uuid import uuid5, NAMESPACE_OID
from ..tasks.Task import Task
from .run_tasks_base import run_tasks_base
from ..tasks.task import Task
logger = get_logger("run_tasks(tasks: [Task], data)")
async def run_tasks_base(tasks: list[Task], data=None, user: User = None):
    """Recursively run a chain of pipeline tasks, streaming results downstream.

    The head task is executed with ``data`` (when provided); its output is fed
    into the remaining tasks by recursing on ``tasks[1:]``. The execution
    strategy is picked by inspecting the task's callable: async generator,
    generator, coroutine, or plain function. For the generator flavours,
    yielded items are accumulated into batches sized by the *next* task's
    ``task_config["batch_size"]`` before being forwarded.

    Yields whatever the final task in the chain produces. ``user`` is only
    used here for telemetry attribution.
    """
    # Base case: nothing left to run — emit the accumulated data as-is.
    if len(tasks) == 0:
        yield data
        return
    # Forward ``data`` positionally only when it was actually supplied.
    args = [data] if data is not None else []
    running_task = tasks[0]
    leftover_tasks = tasks[1:]
    next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
    # The batch size is dictated by the consumer (the next task); default 1.
    next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
    if inspect.isasyncgenfunction(running_task.executable):
        # --- Async generator task: stream items, forwarding them in batches.
        logger.info("Async generator task started: `%s`", running_task.executable.__name__)
        send_telemetry(
            "Async Generator Task Started",
            user.id,
            {
                "task_name": running_task.executable.__name__,
            },
        )
        try:
            results = []
            async_iterator = running_task.run(*args)
            async for partial_result in async_iterator:
                results.append(partial_result)
                if len(results) == next_task_batch_size:
                    # NOTE(review): a full batch of size 1 is unwrapped to the
                    # bare item, while larger batches are forwarded as a list —
                    # downstream tasks see different shapes depending on config.
                    async for result in run_tasks_base(
                        leftover_tasks,
                        results[0] if next_task_batch_size == 1 else results,
                        user=user,
                    ):
                        yield result
                    results = []
            # Flush any partial batch; only reachable when batch size > 1,
            # since a batch size of 1 flushes on every item above.
            if len(results) > 0:
                async for result in run_tasks_base(leftover_tasks, results, user):
                    yield result
                results = []
            logger.info("Async generator task completed: `%s`", running_task.executable.__name__)
            send_telemetry(
                "Async Generator Task Completed",
                user.id,
                {
                    "task_name": running_task.executable.__name__,
                },
            )
        except Exception as error:
            logger.error(
                "Async generator task errored: `%s`\n%s\n",
                running_task.executable.__name__,
                str(error),
                exc_info=True,
            )
            send_telemetry(
                "Async Generator Task Errored",
                user.id,
                {
                    "task_name": running_task.executable.__name__,
                },
            )
            raise error
    elif inspect.isgeneratorfunction(running_task.executable):
        # --- Sync generator task: same batching as above, iterated synchronously.
        logger.info("Generator task started: `%s`", running_task.executable.__name__)
        send_telemetry(
            "Generator Task Started",
            user.id,
            {
                "task_name": running_task.executable.__name__,
            },
        )
        try:
            results = []
            for partial_result in running_task.run(*args):
                results.append(partial_result)
                if len(results) == next_task_batch_size:
                    async for result in run_tasks_base(
                        leftover_tasks, results[0] if next_task_batch_size == 1 else results, user
                    ):
                        yield result
                    results = []
            # Flush any partial batch (only reachable when batch size > 1).
            if len(results) > 0:
                async for result in run_tasks_base(leftover_tasks, results, user):
                    yield result
                results = []
            logger.info("Generator task completed: `%s`", running_task.executable.__name__)
            # NOTE(review): telemetry arguments are keyword-style here but
            # positional in the other branches — presumably equivalent; confirm
            # against send_telemetry's signature.
            send_telemetry(
                "Generator Task Completed",
                user_id=user.id,
                additional_properties={
                    "task_name": running_task.executable.__name__,
                },
            )
        except Exception as error:
            logger.error(
                "Generator task errored: `%s`\n%s\n",
                running_task.executable.__name__,
                str(error),
                exc_info=True,
            )
            send_telemetry(
                "Generator Task Errored",
                user_id=user.id,
                additional_properties={
                    "task_name": running_task.executable.__name__,
                },
            )
            raise error
    elif inspect.iscoroutinefunction(running_task.executable):
        # --- Coroutine task: single awaited result, forwarded without batching.
        logger.info("Coroutine task started: `%s`", running_task.executable.__name__)
        send_telemetry(
            "Coroutine Task Started",
            user_id=user.id,
            additional_properties={
                "task_name": running_task.executable.__name__,
            },
        )
        try:
            task_result = await running_task.run(*args)
            async for result in run_tasks_base(leftover_tasks, task_result, user):
                yield result
            logger.info("Coroutine task completed: `%s`", running_task.executable.__name__)
            send_telemetry(
                "Coroutine Task Completed",
                user.id,
                {
                    "task_name": running_task.executable.__name__,
                },
            )
        except Exception as error:
            logger.error(
                "Coroutine task errored: `%s`\n%s\n",
                running_task.executable.__name__,
                str(error),
                exc_info=True,
            )
            send_telemetry(
                "Coroutine Task Errored",
                user.id,
                {
                    "task_name": running_task.executable.__name__,
                },
            )
            raise error
    elif inspect.isfunction(running_task.executable):
        # --- Plain function task: single synchronous result, forwarded
        # without batching. Note callables that are not plain functions
        # (builtins, partials, bound methods) match none of the branches
        # and fall through silently.
        logger.info("Function task started: `%s`", running_task.executable.__name__)
        send_telemetry(
            "Function Task Started",
            user.id,
            {
                "task_name": running_task.executable.__name__,
            },
        )
        try:
            task_result = running_task.run(*args)
            async for result in run_tasks_base(leftover_tasks, task_result, user):
                yield result
            logger.info("Function task completed: `%s`", running_task.executable.__name__)
            send_telemetry(
                "Function Task Completed",
                user.id,
                {
                    "task_name": running_task.executable.__name__,
                },
            )
        except Exception as error:
            logger.error(
                "Function task errored: `%s`\n%s\n",
                running_task.executable.__name__,
                str(error),
                exc_info=True,
            )
            send_telemetry(
                "Function Task Errored",
                user.id,
                {
                    "task_name": running_task.executable.__name__,
                },
            )
            raise error
async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
config = get_current_settings()

View file

@ -0,0 +1,72 @@
import inspect
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from ..tasks.task import Task
logger = get_logger("run_tasks_base")
async def handle_task(
    running_task: Task,
    args: list,
    leftover_tasks: list[Task],
    next_task_batch_size: int,
    user: User,
):
    """Wrap one task's execution with logging, telemetry, and error reporting.

    Delegates the actual work to ``running_task.execute`` and recursively
    forwards each produced batch into the remaining pipeline tasks.
    """
    task_type = running_task.task_type
    task_name = running_task.executable.__name__

    def _report(event: str):
        # Single place for the telemetry payload shared by all three events.
        send_telemetry(
            f"{task_type} Task {event}",
            user_id=user.id,
            additional_properties={
                "task_name": task_name,
            },
        )

    logger.info(f"{task_type} task started: `{task_name}`")
    _report("Started")
    try:
        # Execute the task, then push every emitted batch through the rest
        # of the pipeline, yielding the downstream results to our caller.
        async for result_data in running_task.execute(args, next_task_batch_size):
            async for result in run_tasks_base(leftover_tasks, result_data, user):
                yield result
        logger.info(f"{task_type} task completed: `{task_name}`")
        _report("Completed")
    except Exception as error:
        logger.error(
            f"{task_type} task errored: `{task_name}`\n{str(error)}\n",
            exc_info=True,
        )
        _report("Errored")
        raise error
async def run_tasks_base(tasks: list[Task], data=None, user: User = None):
    """Execute a pipeline of tasks, dispatching each one via ``handle_task``.

    Recursion terminates when no tasks remain, at which point the data is
    yielded through to the caller unchanged.
    """
    if not tasks:
        yield data
        return

    current_task, *remaining_tasks = tasks
    # The next task's configured batch size decides how results are grouped;
    # with no consumer downstream, fall back to single-item batches.
    batch_size = remaining_tasks[0].task_config["batch_size"] if remaining_tasks else 1
    # Only pass data positionally when the caller actually supplied some.
    call_args = [] if data is None else [data]

    async for result in handle_task(current_task, call_args, remaining_tasks, batch_size, user):
        yield result

View file

@ -1,30 +0,0 @@
from typing import Union, Callable, Any, Coroutine, Generator, AsyncGenerator
class Task:
    """Wraps a pipeline step's callable together with its bound arguments
    and per-task configuration.

    ``task_config["batch_size"]`` controls how many upstream results are
    grouped together before being passed to this task.
    """

    executable: Union[
        Callable[..., Any],
        Callable[..., Coroutine[Any, Any, Any]],
        Generator[Any, Any, Any],
        AsyncGenerator[Any, Any],
    ]
    # Class-level defaults; __init__ always rebinds these on the instance.
    task_config: dict[str, Any] = {
        "batch_size": 1,
    }
    default_params: dict[str, Any] = {}

    def __init__(self, executable, *args, task_config=None, **kwargs):
        self.executable = executable
        # Extra positional/keyword arguments are appended to every run() call.
        self.default_params = {"args": args, "kwargs": kwargs}
        if task_config is not None:
            self.task_config = task_config
            if "batch_size" not in task_config:
                self.task_config["batch_size"] = 1
        else:
            # Bug fix: previously instances without an explicit task_config
            # shared the class-level dict, so mutating one task's config
            # (e.g. its batch_size) leaked into every other task. Give each
            # instance its own copy instead.
            self.task_config = {"batch_size": 1}

    def run(self, *args, **kwargs):
        """Call the wrapped executable with the given arguments followed by
        the defaults captured at construction time."""
        combined_args = args + self.default_params["args"]
        combined_kwargs = {**self.default_params["kwargs"], **kwargs}
        return self.executable(*combined_args, **combined_kwargs)

View file

@ -0,0 +1,97 @@
from typing import Union, Callable, Any, Coroutine, Generator, AsyncGenerator
import inspect
class Task:
    """Wraps a pipeline step's callable, detecting its kind at construction
    time and exposing a uniform async ``execute`` interface.

    ``task_config["batch_size"]`` controls how many upstream results are
    grouped together before being passed to this task.
    """

    executable: Union[
        Callable[..., Any],
        Callable[..., Coroutine[Any, Any, Any]],
        Generator[Any, Any, Any],
        AsyncGenerator[Any, Any],
    ]
    # Class-level defaults; __init__ always rebinds these on the instance.
    task_config: dict[str, Any] = {
        "batch_size": 1,
    }
    default_params: dict[str, Any] = {}
    # Human-readable kind used in logs/telemetry ("Async Generator",
    # "Generator", "Coroutine", or "Function").
    task_type: str = None
    _execute_method: Callable = None
    # Batch size of the *next* task in the pipeline; set via execute().
    _next_batch_size: int = 1

    def __init__(self, executable, *args, task_config=None, **kwargs):
        self.executable = executable
        # Extra positional/keyword arguments are appended to every run() call.
        self.default_params = {"args": args, "kwargs": kwargs}

        # Pick the execution strategy from the callable's kind.
        # NOTE(review): inspect.isfunction rejects builtins, functools.partial
        # objects, and bound methods — presumably intentional; confirm before
        # passing such callables.
        if inspect.isasyncgenfunction(executable):
            self.task_type = "Async Generator"
            self._execute_method = self.execute_async_generator
        elif inspect.isgeneratorfunction(executable):
            self.task_type = "Generator"
            self._execute_method = self.execute_generator
        elif inspect.iscoroutinefunction(executable):
            self.task_type = "Coroutine"
            self._execute_method = self.execute_coroutine
        elif inspect.isfunction(executable):
            self.task_type = "Function"
            self._execute_method = self.execute_function
        else:
            raise ValueError(f"Unsupported task type: {executable}")

        if task_config is not None:
            self.task_config = task_config
            if "batch_size" not in task_config:
                self.task_config["batch_size"] = 1
        else:
            # Bug fix: previously instances without an explicit task_config
            # shared the class-level dict, so mutating one task's config
            # (e.g. its batch_size) leaked into every other task. Give each
            # instance its own copy instead.
            self.task_config = {"batch_size": 1}

    def run(self, *args, **kwargs):
        """Execute the underlying task with given arguments followed by the
        defaults captured at construction time."""
        combined_args = args + self.default_params["args"]
        combined_kwargs = {**self.default_params["kwargs"], **kwargs}
        return self.executable(*combined_args, **combined_kwargs)

    async def execute_async_generator(self, args):
        """Execute async generator task and collect results in batches."""
        results = []
        async_iterator = self.run(*args)
        async for partial_result in async_iterator:
            results.append(partial_result)
            if len(results) == self._next_batch_size:
                yield results
                results = []
        # Flush the trailing partial batch, if any.
        if results:
            yield results

    async def execute_generator(self, args):
        """Execute generator task and collect results in batches."""
        results = []
        for partial_result in self.run(*args):
            results.append(partial_result)
            if len(results) == self._next_batch_size:
                yield results
                results = []
        # Flush the trailing partial batch, if any.
        if results:
            yield results

    async def execute_coroutine(self, args):
        """Execute coroutine task and yield the awaited result."""
        task_result = await self.run(*args)
        yield task_result

    async def execute_function(self, args):
        """Execute function task and yield the result."""
        task_result = self.run(*args)
        yield task_result

    async def execute(self, args, next_batch_size=None):
        """Execute the task based on its type and yield results, batching
        generator output by the next task's batch size when provided."""
        if next_batch_size is not None:
            self._next_batch_size = next_batch_size
        async for result in self._execute_method(args):
            yield result

View file

@ -3,7 +3,7 @@ from queue import Queue
import cognee
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.relational import create_db_and_tables
@ -19,11 +19,11 @@ async def pipeline(data_queue):
else:
await asyncio.sleep(0.3)
async def add_one(num):
yield num + 1
async def add_one(num_list):
yield num_list[0] + 1
async def multiply_by_two(num):
yield num * 2
async def multiply_by_two(num_list):
yield num_list[0] * 2
await create_db_and_tables()
user = await get_default_user()
@ -41,7 +41,7 @@ async def pipeline(data_queue):
results = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
index = 0
async for result in tasks_run:
assert result == results[index], f"at {index = }: {result = } != {results[index] = }"
assert result[0] == results[index], f"at {index = }: {result = } != {results[index] = }"
index += 1

View file

@ -2,7 +2,7 @@ import asyncio
import cognee
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.relational import create_db_and_tables
@ -19,11 +19,11 @@ async def run_and_check_tasks():
for num in nums:
yield num + 1
async def multiply_by_two(num):
yield num * 2
async def multiply_by_two(nums):
yield nums[0] * 2
async def add_one_single(num):
yield num + 1
async def add_one_single(nums):
yield nums[0] + 1
await create_db_and_tables()
user = await get_default_user()
@ -42,7 +42,7 @@ async def run_and_check_tasks():
results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
index = 0
async for result in pipeline:
assert result == results[index], f"at {index = }: {result = } != {results[index] = }"
assert result[0] == results[index], f"at {index = }: {result = } != {results[index] = }"
index += 1

View file

@ -13,7 +13,7 @@ import cognee
from cognee.low_level import DataPoint, setup as cognee_setup
from cognee.api.v1.search import SearchType
from cognee.tasks.storage import add_data_points
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.pipelines import run_tasks