feat: pipeline tasks needs mapping (#690)

<!-- .github/pull_request_template.md -->

## Description
This PR adds explicit dependency mapping between pipeline tasks. A task now takes a `TaskConfig` whose `needs` list names the upstream tasks it consumes results from, replacing the implicit sequential ordering and the ad-hoc `{"batch_size": ...}` dicts (batching moves to `TaskConfig.output_batch_size`). The new `run_tasks_base` builds a dependency graph from `needs`, executes tasks in dependency order, and raises `WrongTaskOrderException` when tasks are disconnected or form a cycle. `merge_needs` lets one task combine the outputs of several upstream tasks (for example, `add_data_points` consuming both `summarize_text` and `extract_graph_from_data`), and `Task.run` now yields `TaskExecutionStarted`, `TaskExecutionInfo`, `TaskExecutionCompleted`, and `TaskExecutionErrored` events that drive telemetry and result passing.
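
Below is a minimal usage sketch, modeled on the tests added in this PR; `number_generator`, `add_one`, and `multiply_by_two` are illustrative task functions, not part of the codebase.

```python
import asyncio

from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base


async def main():
    # Any plain function, coroutine, generator, or async generator can back a Task.
    def number_generator(num, context=None):
        for i in range(num):
            yield i + 1

    async def add_one(num, context=None):
        yield num + 1

    async def multiply_by_two(num, context=None):
        yield num * 2

    tasks = [
        Task(number_generator),
        # `needs` names the upstream task whose results this task consumes.
        Task(add_one, task_config=TaskConfig(needs=[number_generator])),
        Task(multiply_by_two, task_config=TaskConfig(needs=[add_one])),
    ]

    async for task_run_info in run_tasks_base(tasks, data=10):
        if isinstance(task_run_info, TaskExecutionInfo):
            print(task_run_info.task.__name__, task_run_info.result)


if __name__ == "__main__":
    asyncio.run(main())
```

`run_tasks_base` resolves the dependency graph and executes the tasks in order, yielding execution events instead of raw results.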

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
Commit 0ce6fad24a (parent 567b45efa6), authored by Boris on 2025-04-03 10:52:59 +02:00 and committed by GitHub.
27 changed files with 803 additions and 618 deletions

View file

@ -1,7 +1,8 @@
from typing import Union, BinaryIO
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines import run_tasks, Task
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks import TaskConfig, Task
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
@ -36,7 +37,15 @@ async def add(
if user is None:
user = await get_default_user()
tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
tasks = [
Task(resolve_data_directories),
Task(
ingest_data,
dataset_name,
user,
task_config=TaskConfig(needs=[resolve_data_directories]),
),
]
dataset_id = uuid5(NAMESPACE_OID, dataset_name)
pipeline = run_tasks(

View file

@ -1,18 +1,19 @@
import os
import pathlib
import asyncio
from cognee.shared.logging_utils import get_logger
from uuid import NAMESPACE_OID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.api.v1.search import SearchType, search
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.base_config import get_base_config
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.tasks.Task import Task, TaskConfig
from cognee.modules.pipelines.operations.needs import merge_needs
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph, MonitoringTool
from cognee.shared.utils import render_graph
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.ingestion import ingest_data
@ -45,25 +46,46 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
detailed_extraction = True
tasks = [
Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction),
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
Task(add_data_points, task_config={"batch_size": 500}),
Task(
get_repo_file_dependencies,
detailed_extraction=detailed_extraction,
task_config=TaskConfig(output_batch_size=500),
),
# Task(summarize_code, task_config=TaskConfig(output_batch_size=500)), # This task takes a long time to complete
Task(add_data_points, task_config=TaskConfig(needs=[get_repo_file_dependencies])),
]
if include_docs:
# These tasks take a long time to complete
non_code_tasks = [
Task(get_non_py_files, task_config={"batch_size": 50}),
Task(ingest_data, dataset_name="repo_docs", user=user),
Task(classify_documents),
Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
Task(get_non_py_files),
Task(
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
ingest_data,
dataset_name="repo_docs",
user=user,
task_config=TaskConfig(needs=[get_non_py_files]),
),
Task(classify_documents, task_config=TaskConfig(needs=[ingest_data])),
Task(
extract_chunks_from_documents,
max_chunk_size=get_max_chunk_tokens(),
task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),
),
Task(
extract_graph_from_data,
graph_model=KnowledgeGraph,
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
),
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 50},
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
),
Task(
add_data_points,
task_config=TaskConfig(
needs=[merge_needs(summarize_text, extract_graph_from_data)]
),
),
]
@ -71,11 +93,11 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
if include_docs:
non_code_pipeline_run = run_tasks(non_code_tasks, dataset_id, repo_path, "cognify_pipeline")
async for run_status in non_code_pipeline_run:
yield run_status
async for run_info in non_code_pipeline_run:
yield run_info
async for run_status in run_tasks(tasks, dataset_id, repo_path, "cognify_code_pipeline"):
yield run_status
async for run_info in run_tasks(tasks, dataset_id, repo_path, "cognify_code_pipeline"):
yield run_info
if __name__ == "__main__":

View file

@ -10,10 +10,10 @@ from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.models import Data, Dataset
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines import run_tasks, merge_needs
from cognee.modules.pipelines.tasks import Task, TaskConfig
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.data_models import KnowledgeGraph
@ -92,7 +92,9 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, tasks: list[Task]):
if not isinstance(task, Task):
raise ValueError(f"Task {task} is not an instance of Task")
pipeline_run = run_tasks(tasks, dataset.id, data_documents, "cognify_pipeline")
pipeline_run = run_tasks(
tasks, dataset.id, data_documents, "cognify_pipeline", context={"user": user}
)
pipeline_run_status = None
async for run_status in pipeline_run:
@ -121,24 +123,33 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
default_tasks = [
Task(classify_documents),
Task(check_permissions_on_documents, user=user, permissions=["write"]),
Task(
check_permissions_on_documents,
user=user,
permissions=["write"],
task_config=TaskConfig(needs=[classify_documents]),
),
Task( # Extract text chunks based on the document type.
extract_chunks_from_documents,
max_chunk_size=chunk_size or get_max_chunk_tokens(),
chunker=chunker,
), # Extract text chunks based on the document type.
Task(
task_config=TaskConfig(needs=[check_permissions_on_documents], output_batch_size=10),
),
Task( # Generate knowledge graphs from the document chunks.
extract_graph_from_data,
graph_model=graph_model,
ontology_adapter=ontology_adapter,
task_config={"batch_size": 10},
), # Generate knowledge graphs from the document chunks.
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
),
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 10},
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
),
Task(
add_data_points,
task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),
),
Task(add_data_points, task_config={"batch_size": 10}),
]
return default_tasks

View file

@ -2,11 +2,11 @@ from typing import List
from pydantic import BaseModel
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.pipelines.operations.needs import merge_needs
from cognee.modules.pipelines.tasks import Task, TaskConfig
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.data_models import KnowledgeGraph
from cognee.shared.utils import send_telemetry
from cognee.tasks.documents import (
check_permissions_on_documents,
classify_documents,
@ -27,25 +27,32 @@ async def get_cascade_graph_tasks(
if user is None:
user = await get_default_user()
try:
cognee_config = get_cognify_config()
default_tasks = [
Task(classify_documents),
Task(check_permissions_on_documents, user=user, permissions=["write"]),
Task(
extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
), # Extract text chunks based on the document type.
Task(
extract_graph_from_data, task_config={"batch_size": 10}
), # Generate knowledge graphs using cascade extraction
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 10}),
]
except Exception as error:
send_telemetry("cognee.cognify DEFAULT TASKS CREATION ERRORED", user.id)
raise error
cognee_config = get_cognify_config()
default_tasks = [
Task(classify_documents),
Task(
check_permissions_on_documents,
user=user,
permissions=["write"],
task_config=TaskConfig(needs=[classify_documents]),
),
Task( # Extract text chunks based on the document type.
extract_chunks_from_documents,
max_chunk_tokens=get_max_chunk_tokens(),
task_config=TaskConfig(needs=[check_permissions_on_documents], output_batch_size=50),
),
Task(
extract_graph_from_data,
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
), # Generate knowledge graphs using cascade extraction
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
),
Task(
add_data_points,
task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),
),
]
return default_tasks

View file

@ -1,3 +1,3 @@
from .tasks.Task import Task
from .operations.run_tasks import run_tasks
from .operations.run_parallel import run_tasks_parallel
from .operations.needs import merge_needs, MergeNeeds

View file

@ -0,0 +1,18 @@
class WrongTaskOrderException(Exception):
message: str
def __init__(self, message: str):
self.message = message
super().__init__(message)
class TaskExecutionException(Exception):
type: str
message: str
traceback: str
def __init__(self, type: str, message: str, traceback: str):
self.message = message
self.type = type
self.traceback = traceback
super().__init__(message)

View file

@ -0,0 +1,60 @@
from typing import Any
from pydantic import BaseModel
from ..tasks import Task
class MergeNeeds(BaseModel):
needs: list[Any]
def merge_needs(*args):
return MergeNeeds(needs=args)
def get_task_needs(tasks: list[Task]):
input_tasks = []
for task in tasks:
if isinstance(task, MergeNeeds):
input_tasks.extend(task.needs)
else:
input_tasks.append(task)
return input_tasks
def get_need_task_results(results, task: Task):
input_results = []
for task_dependency in task.task_config.needs:
if isinstance(task_dependency, MergeNeeds):
task_results = []
max_result_length = 0
for task_need in task_dependency.needs:
result = results[task_need]
task_results.append(result)
if isinstance(result, tuple):
max_result_length = max(max_result_length, len(result))
final_results = [[] for _ in range(max_result_length)]
for result in task_results:
if isinstance(result, tuple):
for i, value in enumerate(result):
final_results[i].extend(value)
else:
final_results[0].extend(result)
input_results.extend(final_results)
else:
result = results[task_dependency]
if isinstance(result, tuple):
input_results.extend(result)
else:
input_results.append(result)
return input_results
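
For reference, here is a sketch of how `merge_needs` fans several upstream results into a single downstream task under the `get_need_task_results` semantics above. Every task function in it (`make_chunks`, `make_summaries`, `make_graph`, `store`) is hypothetical and only stands in for the real cognify tasks.

```python
import asyncio

from cognee.modules.pipelines.operations.needs import merge_needs
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionCompleted


async def main():
    # Hypothetical stand-ins for the real pipeline tasks.
    async def make_chunks(text, context=None):
        return text.split()

    async def make_summaries(chunks, context=None):
        return [f"summary of {chunk}" for chunk in chunks]

    async def make_graph(chunks, context=None):
        # Returns a tuple (nodes, edges); tuples are spread across the
        # downstream task's positional arguments.
        return [f"node:{chunk}" for chunk in chunks], [("node:alpha", "node:beta")]

    async def store(data_points, connections, context=None):
        # Receives summaries and nodes merged into one list, plus the edges.
        return {"data_points": data_points, "connections": connections}

    tasks = [
        Task(make_chunks),
        Task(make_summaries, task_config=TaskConfig(needs=[make_chunks])),
        Task(make_graph, task_config=TaskConfig(needs=[make_chunks])),
        Task(
            store,
            task_config=TaskConfig(needs=[merge_needs(make_summaries, make_graph)]),
        ),
    ]

    async for event in run_tasks_base(tasks, data="alpha beta"):
        if isinstance(event, TaskExecutionCompleted) and event.task is store:
            print(event.result)


if __name__ == "__main__":
    asyncio.run(main())
```

This mirrors how `add_data_points` consumes `merge_needs(summarize_text, extract_graph_from_data)` in the cognify pipeline: the summaries and graph nodes arrive merged as the first argument, and the graph edges arrive as the second.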

View file

@ -1,227 +1,26 @@
import inspect
import json
from cognee.shared.logging_utils import get_logger
from uuid import UUID, uuid4
from typing import Any
from uuid import UUID, NAMESPACE_OID, uuid4, uuid5
from cognee.modules.pipelines.operations import (
log_pipeline_run_start,
log_pipeline_run_complete,
log_pipeline_run_error,
)
from cognee.modules.settings import get_current_settings
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.modules.settings import get_current_settings
from cognee.shared.utils import send_telemetry
from uuid import uuid5, NAMESPACE_OID
from cognee.shared.logging_utils import get_logger
from ..tasks.Task import Task
from ..tasks.Task import Task, TaskExecutionCompleted, TaskExecutionErrored, TaskExecutionStarted
from .run_tasks_base import run_tasks_base
logger = get_logger("run_tasks(tasks: [Task], data)")
async def run_tasks_base(tasks: list[Task], data=None, user: User = None):
if len(tasks) == 0:
yield data
return
args = [data] if data is not None else []
running_task = tasks[0]
leftover_tasks = tasks[1:]
next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
if inspect.isasyncgenfunction(running_task.executable):
logger.info("Async generator task started: `%s`", running_task.executable.__name__)
send_telemetry(
"Async Generator Task Started",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
try:
results = []
async_iterator = running_task.run(*args)
async for partial_result in async_iterator:
results.append(partial_result)
if len(results) == next_task_batch_size:
async for result in run_tasks_base(
leftover_tasks,
results[0] if next_task_batch_size == 1 else results,
user=user,
):
yield result
results = []
if len(results) > 0:
async for result in run_tasks_base(leftover_tasks, results, user):
yield result
results = []
logger.info("Async generator task completed: `%s`", running_task.executable.__name__)
send_telemetry(
"Async Generator Task Completed",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
except Exception as error:
logger.error(
"Async generator task errored: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info=True,
)
send_telemetry(
"Async Generator Task Errored",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
raise error
elif inspect.isgeneratorfunction(running_task.executable):
logger.info("Generator task started: `%s`", running_task.executable.__name__)
send_telemetry(
"Generator Task Started",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
try:
results = []
for partial_result in running_task.run(*args):
results.append(partial_result)
if len(results) == next_task_batch_size:
async for result in run_tasks_base(
leftover_tasks, results[0] if next_task_batch_size == 1 else results, user
):
yield result
results = []
if len(results) > 0:
async for result in run_tasks_base(leftover_tasks, results, user):
yield result
results = []
logger.info("Generator task completed: `%s`", running_task.executable.__name__)
send_telemetry(
"Generator Task Completed",
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
},
)
except Exception as error:
logger.error(
"Generator task errored: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info=True,
)
send_telemetry(
"Generator Task Errored",
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
},
)
raise error
elif inspect.iscoroutinefunction(running_task.executable):
logger.info("Coroutine task started: `%s`", running_task.executable.__name__)
send_telemetry(
"Coroutine Task Started",
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
},
)
try:
task_result = await running_task.run(*args)
async for result in run_tasks_base(leftover_tasks, task_result, user):
yield result
logger.info("Coroutine task completed: `%s`", running_task.executable.__name__)
send_telemetry(
"Coroutine Task Completed",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
except Exception as error:
logger.error(
"Coroutine task errored: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info=True,
)
send_telemetry(
"Coroutine Task Errored",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
raise error
elif inspect.isfunction(running_task.executable):
logger.info("Function task started: `%s`", running_task.executable.__name__)
send_telemetry(
"Function Task Started",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
try:
task_result = running_task.run(*args)
async for result in run_tasks_base(leftover_tasks, task_result, user):
yield result
logger.info("Function task completed: `%s`", running_task.executable.__name__)
send_telemetry(
"Function Task Completed",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
except Exception as error:
logger.error(
"Function task errored: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info=True,
)
send_telemetry(
"Function Task Errored",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
raise error
async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
async def run_tasks_with_telemetry(
tasks: list[Task], data, pipeline_name: str, context: dict = None
):
config = get_current_settings()
logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent=1))
@ -239,8 +38,45 @@ async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
| config,
)
async for result in run_tasks_base(tasks, data, user):
yield result
async for run_task_info in run_tasks_base(tasks, data, context):
if isinstance(run_task_info, TaskExecutionStarted):
send_telemetry(
"Task Run Started",
user.id,
additional_properties={
"task_name": run_task_info.task.__name__,
}
| config,
)
if isinstance(run_task_info, TaskExecutionCompleted):
send_telemetry(
"Task Run Completed",
user.id,
additional_properties={
"task_name": run_task_info.task.__name__,
}
| config,
)
if isinstance(run_task_info, TaskExecutionErrored):
send_telemetry(
"Task Run Errored",
user.id,
additional_properties={
"task_name": run_task_info.task.__name__,
"error": str(run_task_info.error),
}
| config,
)
logger.error(
"Task run errored: `%s`\n%s\n",
run_task_info.task.__name__,
str(run_task_info.error),
exc_info=True,
)
yield run_task_info
logger.info("Pipeline run completed: `%s`", pipeline_name)
send_telemetry(
@ -271,19 +107,22 @@ async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
async def run_tasks(
tasks: list[Task],
dataset_id: UUID = uuid4(),
dataset_id: UUID = None,
data: Any = None,
pipeline_name: str = "unknown_pipeline",
context: dict = None,
):
dataset_id = dataset_id or uuid4()
pipeline_id = uuid5(NAMESPACE_OID, pipeline_name)
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
yield pipeline_run
pipeline_run_id = pipeline_run.pipeline_run_id
try:
async for _ in run_tasks_with_telemetry(tasks, data, pipeline_id):
async for _ in run_tasks_with_telemetry(tasks, data, pipeline_id, context):
pass
yield await log_pipeline_run_complete(

View file

@ -0,0 +1,63 @@
from collections import deque
from cognee.shared.logging_utils import get_logger
from .needs import get_need_task_results, get_task_needs
from ..tasks.Task import Task, TaskExecutionCompleted, TaskExecutionInfo
from ..exceptions import WrongTaskOrderException
logger = get_logger("run_tasks_base(tasks: list[Task], data)")
async def run_tasks_base(tasks: list[Task], data=None, context=None):
if len(tasks) == 0:
return
pipeline_input = [data] if data is not None else []
"""Run tasks in dependency order and return results."""
task_graph = {} # Map task to its dependencies
dependents = {} # Reverse dependencies (who depends on whom)
results = {}
number_of_executed_tasks = 0
tasks_map = {task.executable: task for task in tasks} # Map task executable to task object
# Build task dependency graph
for task in tasks:
task_graph[task.executable] = get_task_needs(task.task_config.needs)
for dependent_task in task_graph[task.executable]:
dependents.setdefault(dependent_task, []).append(task.executable)
# Find tasks without dependencies
ready_queue = deque([task for task in tasks if not task_graph[task.executable]])
# Execute tasks in order
while ready_queue:
task = ready_queue.popleft()
task_inputs = (
get_need_task_results(results, task) if task.task_config.needs else pipeline_input
)
async for task_execution_info in task.run(*task_inputs): # Run task and store result
if isinstance(task_execution_info, TaskExecutionInfo): # Update result as it comes
results[task.executable] = task_execution_info.result
if isinstance(task_execution_info, TaskExecutionCompleted):
if task.executable not in results: # If result not already set, set it
results[task.executable] = task_execution_info.result
number_of_executed_tasks += 1
yield task_execution_info
# Process tasks depending on this task
for dependent_task in dependents.get(task.executable, []):
task_graph[dependent_task].remove(task.executable) # Mark dependency as resolved
if not task_graph[dependent_task]: # If all dependencies resolved, add to queue
ready_queue.append(tasks_map[dependent_task])
if number_of_executed_tasks != len(tasks):
raise WrongTaskOrderException(
f"{number_of_executed_tasks}/{len(tasks)} tasks executed. You likely have some disconnected tasks or circular dependency."
)

View file

@ -1,30 +1,136 @@
from typing import Union, Callable, Any, Coroutine, Generator, AsyncGenerator
import inspect
from typing import Callable, Any, Union
from pydantic import BaseModel
from ..tasks.types import TaskExecutable
from ..operations.needs import MergeNeeds
from ..exceptions import TaskExecutionException
class TaskExecutionStarted(BaseModel):
task: Callable
class TaskExecutionCompleted(BaseModel):
task: Callable
result: Any = None
class TaskExecutionErrored(BaseModel):
task: TaskExecutable
error: TaskExecutionException
model_config = {"arbitrary_types_allowed": True}
class TaskExecutionInfo(BaseModel):
result: Any = None
task: Callable
class TaskConfig(BaseModel):
output_batch_size: int = 1
needs: list[Union[Callable, MergeNeeds]] = []
class Task:
executable: Union[
Callable[..., Any],
Callable[..., Coroutine[Any, Any, Any]],
Generator[Any, Any, Any],
AsyncGenerator[Any, Any],
]
task_config: dict[str, Any] = {
"batch_size": 1,
}
task_config: TaskConfig
default_params: dict[str, Any] = {}
def __init__(self, executable, *args, task_config=None, **kwargs):
def __init__(self, executable, *args, task_config: TaskConfig = None, **kwargs):
self.executable = executable
self.default_params = {"args": args, "kwargs": kwargs}
self.result = None
if task_config is not None:
self.task_config = task_config
self.task_config = task_config or TaskConfig()
if "batch_size" not in task_config:
self.task_config["batch_size"] = 1
def run(self, *args, **kwargs):
async def run(self, *args, **kwargs):
combined_args = args + self.default_params["args"]
combined_kwargs = {**self.default_params["kwargs"], **kwargs}
combined_kwargs = {
**self.default_params["kwargs"],
**kwargs,
}
return self.executable(*combined_args, **combined_kwargs)
yield TaskExecutionStarted(
task=self.executable,
)
try:
if inspect.iscoroutinefunction(self.executable): # Async function
end_result = await self.executable(*combined_args, **combined_kwargs)
elif inspect.isgeneratorfunction(self.executable): # Generator
task_result = []
end_result = []
for value in self.executable(*combined_args, **combined_kwargs):
task_result.append(value) # Store the last yielded value
end_result.append(value)
if self.task_config.output_batch_size == 1:
yield TaskExecutionInfo(
result=value,
task=self.executable,
)
elif self.task_config.output_batch_size == len(task_result):
yield TaskExecutionInfo(
result=task_result,
task=self.executable,
)
task_result = [] # Reset for the next batch
# Yield any remaining items in the final batch if it's not empty
if task_result and self.task_config.output_batch_size > 1:
yield TaskExecutionInfo(
result=task_result,
task=self.executable,
)
elif inspect.isasyncgenfunction(self.executable): # Async Generator
task_result = []
end_result = []
async for value in self.executable(*combined_args, **combined_kwargs):
task_result.append(value) # Store the last yielded value
end_result.append(value)
if self.task_config.output_batch_size == 1:
yield TaskExecutionInfo(
result=value,
task=self.executable,
)
elif self.task_config.output_batch_size == len(task_result):
yield TaskExecutionInfo(
result=task_result,
task=self.executable,
)
task_result = [] # Reset for the next batch
# Yield any remaining items in the final batch if it's not empty
if task_result and self.task_config.output_batch_size > 1:
yield TaskExecutionInfo(
result=task_result,
task=self.executable,
)
else: # Regular function
end_result = self.executable(*combined_args, **combined_kwargs)
yield TaskExecutionCompleted(
task=self.executable,
result=end_result,
)
except Exception as error:
import traceback
error_details = TaskExecutionException(
type=type(error).__name__,
message=str(error),
traceback=traceback.format_exc(),
)
yield TaskExecutionErrored(
task=self.executable,
error=error_details,
)
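
`Task.run` is now an async generator of execution events rather than a direct call. A short sketch of consuming those events directly, assuming an illustrative `squares` generator:

```python
import asyncio

from cognee.modules.pipelines.tasks import (
    Task,
    TaskConfig,
    TaskExecutionCompleted,
    TaskExecutionErrored,
    TaskExecutionInfo,
    TaskExecutionStarted,
)


async def main():
    # Illustrative executable; generator output is batched by output_batch_size.
    def squares(limit):
        for i in range(limit):
            yield i * i

    task = Task(squares, task_config=TaskConfig(output_batch_size=3))

    async for event in task.run(4):
        if isinstance(event, TaskExecutionStarted):
            print("started", event.task.__name__)
        elif isinstance(event, TaskExecutionInfo):
            print("batch", event.result)      # [0, 1, 4], then the remainder [9]
        elif isinstance(event, TaskExecutionCompleted):
            print("completed", event.result)  # all yielded values
        elif isinstance(event, TaskExecutionErrored):
            print("errored", event.error.message)


if __name__ == "__main__":
    asyncio.run(main())
```

Errors no longer propagate as raised exceptions from `run`; they surface as `TaskExecutionErrored` events wrapping a `TaskExecutionException`.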

View file

@ -0,0 +1,8 @@
from .Task import (
Task,
TaskConfig,
TaskExecutionInfo,
TaskExecutionCompleted,
TaskExecutionStarted,
TaskExecutionErrored,
)

View file

@ -0,0 +1,9 @@
from typing import Any, AsyncGenerator, Callable, Coroutine, Generator, Union
TaskExecutable = Union[
Callable[..., Any],
Callable[..., Coroutine[Any, Any, Any]],
AsyncGenerator[Any, Any],
Generator[Any, Any, Any],
]

View file

@ -12,7 +12,6 @@ from cognee.modules.graph.utils import (
retrieve_existing_edges,
)
from cognee.shared.data_models import KnowledgeGraph
from cognee.tasks.storage import add_data_points
async def integrate_chunk_graphs(
@ -28,7 +27,6 @@ async def integrate_chunk_graphs(
for chunk_index, chunk_graph in enumerate(chunk_graphs):
data_chunks[chunk_index].contains = chunk_graph
await add_data_points(chunk_graphs)
return data_chunks
existing_edges_map = await retrieve_existing_edges(
@ -41,13 +39,7 @@ async def integrate_chunk_graphs(
data_chunks, chunk_graphs, ontology_adapter, existing_edges_map
)
if len(graph_nodes) > 0:
await add_data_points(graph_nodes)
if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges)
return data_chunks
return graph_nodes, graph_edges
async def extract_graph_from_data(

View file

@ -41,4 +41,5 @@ async def resolve_data_directories(
resolved_data.append(item)
else: # If it's not a string add it directly
resolved_data.append(item)
return resolved_data

View file

@ -6,7 +6,8 @@ from .index_data_points import index_data_points
from .index_graph_edges import index_graph_edges
async def add_data_points(data_points: list[DataPoint]):
async def add_data_points(data_points: list[DataPoint], data_point_connections: list = None):
data_point_connections = data_point_connections or []
nodes = []
edges = []
@ -38,6 +39,8 @@ async def add_data_points(data_points: list[DataPoint]):
await graph_engine.add_nodes(nodes)
await graph_engine.add_edges(edges)
if data_point_connections:
await graph_engine.add_edges(data_point_connections)
# This step has to happen after adding nodes and edges because we query the graph.
await index_graph_edges()

View file

@ -2,12 +2,12 @@ import asyncio
from typing import Type
from uuid import uuid5
from pydantic import BaseModel
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.extraction.extract_summary import extract_summary
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from .models import TextSummary
async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):
async def summarize_text(data_chunks: list[DataPoint], summarization_model: Type[BaseModel]):
if len(data_chunks) == 0:
return data_chunks

View file

@ -1,66 +0,0 @@
import asyncio
from queue import Queue
import cognee
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.relational import create_db_and_tables
async def pipeline(data_queue):
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
async def queue_consumer():
while not data_queue.is_closed:
if not data_queue.empty():
yield data_queue.get()
else:
await asyncio.sleep(0.3)
async def add_one(num):
yield num + 1
async def multiply_by_two(num):
yield num * 2
await create_db_and_tables()
user = await get_default_user()
tasks_run = run_tasks_base(
[
Task(queue_consumer),
Task(add_one),
Task(multiply_by_two),
],
data=None,
user=user,
)
results = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
index = 0
async for result in tasks_run:
assert result == results[index], f"at {index = }: {result = } != {results[index] = }"
index += 1
async def run_queue():
data_queue = Queue()
data_queue.is_closed = False
async def queue_producer():
for i in range(0, 10):
data_queue.put(i)
await asyncio.sleep(0.1)
data_queue.is_closed = True
await asyncio.gather(pipeline(data_queue), queue_producer())
def test_run_tasks_from_queue():
asyncio.run(run_queue())
if __name__ == "__main__":
asyncio.run(run_queue())

View file

@ -0,0 +1,41 @@
import asyncio
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
async def run_and_check_tasks():
def number_generator(num, context=None):
for i in range(num):
yield i + 1
async def add_one(num, context=None):
yield num + 1
async def multiply_by_two(num, context=None):
yield num * 2
index = 0
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20]
async for task_run_info in run_tasks_base(
[
Task(number_generator),
Task(add_one, task_config=TaskConfig(needs=[number_generator])),
Task(multiply_by_two, task_config=TaskConfig(needs=[number_generator])),
],
data=10,
):
if isinstance(task_run_info, TaskExecutionInfo):
assert task_run_info.result == expected_results[index], (
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
)
index += 1
def test_run_tasks():
asyncio.run(run_and_check_tasks())
if __name__ == "__main__":
test_run_tasks()

View file

@ -0,0 +1,47 @@
import pytest
import asyncio
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
from cognee.modules.pipelines.exceptions import WrongTaskOrderException
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
async def run_and_check_tasks():
def number_generator(num, context=None):
for i in range(num):
yield i + 1
async def add_one(num, context=None):
yield num + 1
async def multiply_by_two(num, context=None):
yield num * 2
index = 0
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 22]
with pytest.raises(
WrongTaskOrderException,
match="1/3 tasks executed. You likely have some disconnected tasks or circular dependency.",
):
async for task_run_info in run_tasks_base(
[
Task(number_generator),
Task(add_one, task_config=TaskConfig(needs=[number_generator, multiply_by_two])),
Task(multiply_by_two, task_config=TaskConfig(needs=[add_one])),
],
data=10,
):
if isinstance(task_run_info, TaskExecutionInfo):
assert task_run_info.result == expected_results[index], (
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
)
index += 1
def test_run_tasks():
asyncio.run(run_and_check_tasks())
if __name__ == "__main__":
test_run_tasks()

View file

@ -0,0 +1,46 @@
import asyncio
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
async def run_and_check_tasks():
def number_generator(num, context=None):
for i in range(num):
yield i + 1
async def add_one(num, context=None):
yield num + 1
async def add_two(num, context=None):
yield num + 2
async def multiply_by_two(num1, num2, context=None):
yield num1 * 2
yield num2 * 2
index = 0
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 22, 24]
async for task_run_info in run_tasks_base(
[
Task(number_generator),
Task(add_one, task_config=TaskConfig(needs=[number_generator])),
Task(add_two, task_config=TaskConfig(needs=[number_generator])),
Task(multiply_by_two, task_config=TaskConfig(needs=[add_one, add_two])),
],
data=10,
):
if isinstance(task_run_info, TaskExecutionInfo):
assert task_run_info.result == expected_results[index], (
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
)
index += 1
def test_run_tasks():
asyncio.run(run_and_check_tasks())
if __name__ == "__main__":
test_run_tasks()

View file

@ -0,0 +1,41 @@
import asyncio
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
async def run_and_check_tasks():
def number_generator(num, context=None):
for i in range(num):
yield i + 1
async def add_one(num, context=None):
yield num + 1
async def multiply_by_two(num, context=None):
yield num * 2
index = 0
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 22]
async for task_run_info in run_tasks_base(
[
Task(number_generator),
Task(add_one, task_config=TaskConfig(needs=[number_generator])),
Task(multiply_by_two, task_config=TaskConfig(needs=[add_one])),
],
data=10,
):
if isinstance(task_run_info, TaskExecutionInfo):
assert task_run_info.result == expected_results[index], (
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
)
index += 1
def test_run_tasks():
asyncio.run(run_and_check_tasks())
if __name__ == "__main__":
test_run_tasks()

View file

@ -1,50 +0,0 @@
import asyncio
import cognee
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.relational import create_db_and_tables
async def run_and_check_tasks():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
def number_generator(num):
for i in range(num):
yield i + 1
async def add_one(nums):
for num in nums:
yield num + 1
async def multiply_by_two(num):
yield num * 2
async def add_one_single(num):
yield num + 1
await create_db_and_tables()
user = await get_default_user()
pipeline = run_tasks_base(
[
Task(number_generator),
Task(add_one, task_config={"batch_size": 5}),
Task(multiply_by_two, task_config={"batch_size": 1}),
Task(add_one_single),
],
data=10,
user=user,
)
results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
index = 0
async for result in pipeline:
assert result == results[index], f"at {index = }: {result = } != {results[index] = }"
index += 1
def test_run_tasks():
asyncio.run(run_and_check_tasks())

View file

@ -1,6 +1,7 @@
import os
import pathlib
import cognee
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
@ -76,6 +77,12 @@ async def main():
assert len(history) == 6, "Search history is not correct."
# await render_graph()
graph_file_path = os.path.join(
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
)
await visualize_graph(graph_file_path)
# Assert local data files are cleaned properly
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -74,13 +74,13 @@
" get_repo_file_dependencies,\n",
")\n",
"from cognee.tasks.storage import add_data_points\n",
"from cognee.modules.pipelines.tasks.Task import Task\n",
"from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
"\n",
"detailed_extraction = True\n",
"\n",
"tasks = [\n",
" Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction),\n",
" Task(add_data_points, task_config={\"batch_size\": 100 if detailed_extraction else 500}),\n",
" Task(add_data_points, task_config=TaskConfig(needs=[get_repo_file_dependencies], output_batch_size=100 if detailed_extraction else 500)),\n",
"]"
]
},

View file

@ -518,11 +518,11 @@
"from cognee.modules.data.models import Dataset, Data\n",
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
"from cognee.modules.cognify.config import get_cognify_config\n",
"from cognee.modules.pipelines.tasks.Task import Task\n",
"from cognee.modules.pipelines import run_tasks\n",
"from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
"from cognee.modules.pipelines.operations.needs import merge_needs\n",
"from cognee.modules.users.models import User\n",
"from cognee.tasks.documents import (\n",
" check_permissions_on_documents,\n",
" classify_documents,\n",
" extract_chunks_from_documents,\n",
")\n",
@ -540,19 +540,25 @@
"\n",
" tasks = [\n",
" Task(classify_documents),\n",
" Task(check_permissions_on_documents, user=user, permissions=[\"write\"]),\n",
" Task(\n",
" extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()\n",
" ), # Extract text chunks based on the document type.\n",
" Task(\n",
" extract_graph_from_data, graph_model=KnowledgeGraph, task_config={\"batch_size\": 10}\n",
" ), # Generate knowledge graphs from the document chunks.\n",
" Task( # Extract text chunks based on the document type.\n",
" extract_chunks_from_documents,\n",
" max_chunk_size=get_max_chunk_tokens(),\n",
" task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),\n",
" ),\n",
" Task( # Generate knowledge graphs from the document chunks.\n",
" extract_graph_from_data,\n",
" graph_model=KnowledgeGraph,\n",
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
" ),\n",
" Task(\n",
" summarize_text,\n",
" summarization_model=cognee_config.summarization_model,\n",
" task_config={\"batch_size\": 10},\n",
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
" ),\n",
" Task(\n",
" add_data_points,\n",
" task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),\n",
" ),\n",
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
" ]\n",
"\n",
" pipeline_run = run_tasks(tasks, dataset.id, data_documents, \"cognify_pipeline\")\n",
@ -1041,7 +1047,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "py312",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@ -1055,7 +1061,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
"version": "3.11.8"
}
},
"nbformat": 4,

View file

@ -397,14 +397,15 @@
"from cognee.modules.data.models import Dataset, Data\n",
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
"from cognee.modules.cognify.config import get_cognify_config\n",
"from cognee.modules.pipelines.tasks.Task import Task\n",
"from cognee.modules.pipelines import run_tasks\n",
"from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
"from cognee.modules.pipelines.operations.needs import merge_needs\n",
"from cognee.modules.users.models import User\n",
"from cognee.tasks.documents import (\n",
" check_permissions_on_documents,\n",
" classify_documents,\n",
" extract_chunks_from_documents,\n",
")\n",
"from cognee.infrastructure.llm import get_max_chunk_tokens\n",
"from cognee.tasks.graph import extract_graph_from_data\n",
"from cognee.tasks.storage import add_data_points\n",
"from cognee.tasks.summarization import summarize_text\n",
@ -418,17 +419,25 @@
"\n",
" tasks = [\n",
" Task(classify_documents),\n",
" Task(check_permissions_on_documents, user=user, permissions=[\"write\"]),\n",
" Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n",
" Task(\n",
" extract_graph_from_data, graph_model=KnowledgeGraph, task_config={\"batch_size\": 10}\n",
" ), # Generate knowledge graphs from the document chunks.\n",
" Task( # Extract text chunks based on the document type.\n",
" extract_chunks_from_documents,\n",
" max_chunk_size=get_max_chunk_tokens(),\n",
" task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),\n",
" ),\n",
" Task( # Generate knowledge graphs from the document chunks.\n",
" extract_graph_from_data,\n",
" graph_model=KnowledgeGraph,\n",
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
" ),\n",
" Task(\n",
" summarize_text,\n",
" summarization_model=cognee_config.summarization_model,\n",
" task_config={\"batch_size\": 10},\n",
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
" ),\n",
" Task(\n",
" add_data_points,\n",
" task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),\n",
" ),\n",
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
" ]\n",
"\n",
" pipeline = run_tasks(tasks, data_documents)\n",

File diff suppressed because one or more lines are too long