Revert "feat: pipeline tasks needs mapping" (#717)

Reverts topoteretes/cognee#690

I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
Boris 2025-04-10 12:10:12 +02:00 committed by GitHub
parent c3d33e728e
commit 9536395468
28 changed files with 620 additions and 822 deletions


@@ -1,8 +1,7 @@
from typing import Union, BinaryIO
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks import TaskConfig, Task
from cognee.modules.pipelines import run_tasks, Task
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
from cognee.infrastructure.databases.relational import (
create_db_and_tables as create_relational_db_and_tables,
@@ -37,15 +36,7 @@ async def add(
if user is None:
user = await get_default_user()
tasks = [
Task(resolve_data_directories),
Task(
ingest_data,
dataset_name,
user,
task_config=TaskConfig(needs=[resolve_data_directories]),
),
]
tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
dataset_id = uuid5(NAMESPACE_OID, dataset_name)
pipeline = run_tasks(
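
With the needs mapping reverted, list position is again the only wiring between tasks: each task's output feeds the next task in the list. A minimal sketch of that contract (the Task stand-in and the toy functions below are illustrative, not cognee's API):

import asyncio

# Stand-in Task: holds a callable plus default arguments.
class Task:
    def __init__(self, executable, *args, **kwargs):
        self.executable = executable
        self.args = args
        self.kwargs = kwargs

async def run_pipeline(tasks, data):
    # List order is the only dependency information: the output of one
    # task becomes the input of the next.
    for task in tasks:
        data = await task.executable(data, *task.args, **task.kwargs)
    return data

async def resolve_paths(paths):
    return [p.strip() for p in paths]

async def ingest(paths, dataset_name):
    return {dataset_name: paths}

async def main():
    tasks = [Task(resolve_paths), Task(ingest, "my_dataset")]
    print(await run_pipeline(tasks, [" a.txt", "b.txt "]))  # {'my_dataset': ['a.txt', 'b.txt']}

asyncio.run(main())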


@@ -1,19 +1,18 @@
import os
import pathlib
import asyncio
from cognee.shared.logging_utils import get_logger
from uuid import NAMESPACE_OID, uuid5
from cognee.shared.logging_utils import get_logger
from cognee.api.v1.search import SearchType, search
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.base_config import get_base_config
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task, TaskConfig
from cognee.modules.pipelines.operations.needs import merge_needs
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph, MonitoringTool
from cognee.shared.utils import render_graph
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.ingestion import ingest_data
@@ -46,46 +45,25 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
detailed_extraction = True
tasks = [
Task(
get_repo_file_dependencies,
detailed_extraction=detailed_extraction,
task_config=TaskConfig(output_batch_size=500),
),
# Task(summarize_code, task_config=TaskConfig(output_batch_size=500)), # This task takes a long time to complete
Task(add_data_points, task_config=TaskConfig(needs=[get_repo_file_dependencies])),
Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction),
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
Task(add_data_points, task_config={"batch_size": 500}),
]
if include_docs:
# These tasks take a long time to complete
non_code_tasks = [
Task(get_non_py_files),
Task(get_non_py_files, task_config={"batch_size": 50}),
Task(ingest_data, dataset_name="repo_docs", user=user),
Task(classify_documents),
Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
Task(
ingest_data,
dataset_name="repo_docs",
user=user,
task_config=TaskConfig(needs=[get_non_py_files]),
),
Task(classify_documents, task_config=TaskConfig(needs=[ingest_data])),
Task(
extract_chunks_from_documents,
max_chunk_size=get_max_chunk_tokens(),
task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),
),
Task(
extract_graph_from_data,
graph_model=KnowledgeGraph,
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
),
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
),
Task(
add_data_points,
task_config=TaskConfig(
needs=[merge_needs(summarize_text, extract_graph_from_data)]
),
task_config={"batch_size": 50},
),
]
@@ -93,11 +71,11 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
if include_docs:
non_code_pipeline_run = run_tasks(non_code_tasks, dataset_id, repo_path, "cognify_pipeline")
async for run_info in non_code_pipeline_run:
yield run_info
async for run_status in non_code_pipeline_run:
yield run_status
async for run_info in run_tasks(tasks, dataset_id, repo_path, "cognify_code_pipeline"):
yield run_info
async for run_status in run_tasks(tasks, dataset_id, repo_path, "cognify_code_pipeline"):
yield run_status
if __name__ == "__main__":
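
The restored loops forward every status object from the inner pipeline by re-yielding it; Python has no async counterpart to yield from, so async for with a bare yield is the idiomatic delegation. A self-contained sketch:

import asyncio

async def inner_pipeline():
    for status in ("started", "running", "completed"):
        yield status

async def outer_pipeline():
    # No "yield from" exists for async generators, so re-yield each item.
    async for run_status in inner_pipeline():
        yield run_status

async def main():
    async for status in outer_pipeline():
        print(status)

asyncio.run(main())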


@@ -10,10 +10,10 @@ from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
from cognee.modules.data.models import Data, Dataset
from cognee.modules.pipelines import run_tasks, merge_needs
from cognee.modules.pipelines.tasks import Task, TaskConfig
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.models import PipelineRunStatus
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.data_models import KnowledgeGraph
@@ -92,9 +92,7 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, tasks: list[Task]):
if not isinstance(task, Task):
raise ValueError(f"Task {task} is not an instance of Task")
pipeline_run = run_tasks(
tasks, dataset.id, data_documents, "cognify_pipeline", context={"user": user}
)
pipeline_run = run_tasks(tasks, dataset.id, data_documents, "cognify_pipeline")
pipeline_run_status = None
async for run_status in pipeline_run:
@@ -123,33 +121,24 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
default_tasks = [
Task(classify_documents),
Task(check_permissions_on_documents, user=user, permissions=["write"]),
Task(
check_permissions_on_documents,
user=user,
permissions=["write"],
task_config=TaskConfig(needs=[classify_documents]),
),
Task( # Extract text chunks based on the document type.
extract_chunks_from_documents,
max_chunk_size=chunk_size or get_max_chunk_tokens(),
chunker=chunker,
task_config=TaskConfig(needs=[check_permissions_on_documents], output_batch_size=10),
),
Task( # Generate knowledge graphs from the document chunks.
), # Extract text chunks based on the document type.
Task(
extract_graph_from_data,
graph_model=graph_model,
ontology_adapter=ontology_adapter,
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
),
task_config={"batch_size": 10},
), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
),
Task(
add_data_points,
task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 10}),
]
return default_tasks


@@ -2,11 +2,11 @@ from typing import List
from pydantic import BaseModel
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines.operations.needs import merge_needs
from cognee.modules.pipelines.tasks import Task, TaskConfig
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.data_models import KnowledgeGraph
from cognee.shared.utils import send_telemetry
from cognee.tasks.documents import (
check_permissions_on_documents,
classify_documents,
@@ -27,32 +27,25 @@ async def get_cascade_graph_tasks(
if user is None:
user = await get_default_user()
cognee_config = get_cognify_config()
default_tasks = [
Task(classify_documents),
Task(
check_permissions_on_documents,
user=user,
permissions=["write"],
task_config=TaskConfig(needs=[classify_documents]),
),
Task( # Extract text chunks based on the document type.
extract_chunks_from_documents,
max_chunk_tokens=get_max_chunk_tokens(),
task_config=TaskConfig(needs=[check_permissions_on_documents], output_batch_size=50),
),
Task(
extract_graph_from_data,
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
), # Generate knowledge graphs using cascade extraction
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
),
Task(
add_data_points,
task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),
),
]
try:
cognee_config = get_cognify_config()
default_tasks = [
Task(classify_documents),
Task(check_permissions_on_documents, user=user, permissions=["write"]),
Task(
extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
), # Extract text chunks based on the document type.
Task(
extract_graph_from_data, task_config={"batch_size": 10}
), # Generate knowledge graphs using cascade extraction
Task(
summarize_text,
summarization_model=cognee_config.summarization_model,
task_config={"batch_size": 10},
),
Task(add_data_points, task_config={"batch_size": 10}),
]
except Exception as error:
send_telemetry("cognee.cognify DEFAULT TASKS CREATION ERRORED", user.id)
raise error
return default_tasks
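
The restored version also wraps task construction in a report-then-reraise block: telemetry records the failure and the original exception still propagates to the caller. A minimal sketch of the pattern, with a stand-in send_telemetry rather than cognee's helper:

def send_telemetry(event_name, user_id):
    print(f"telemetry: {event_name} (user={user_id})")  # stand-in only

def build_default_tasks(user_id):
    try:
        return ["classify", "chunk", "extract", "summarize"]  # placeholder tasks
    except Exception as error:
        # Record the failure, then let the original exception (and its
        # traceback) propagate to the caller.
        send_telemetry("DEFAULT TASKS CREATION ERRORED", user_id)
        raise error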


@@ -2,20 +2,10 @@ from typing import List
from cognee.api.v1.cognify.cognify import get_default_tasks
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.pipelines.tasks import TaskConfig
from cognee.tasks.documents import (
classify_documents,
check_permissions_on_documents,
extract_chunks_from_documents,
)
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.pipelines import run_tasks, merge_needs
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
from cognee.infrastructure.llm import get_max_chunk_tokens
async def get_default_tasks_by_indices(
@@ -49,13 +39,9 @@ async def get_no_summary_tasks(
extract_graph_from_data,
graph_model=graph_model,
ontology_adapter=ontology_adapter,
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
)
add_data_points_task = Task(
add_data_points,
task_config=TaskConfig(needs=[extract_graph_from_data]),
)
add_data_points_task = Task(add_data_points)
return base_tasks + [graph_task, add_data_points_task]
@@ -67,9 +53,6 @@ async def get_just_chunks_tasks(
# Get base tasks (0=classify, 1=check_permissions, 2=extract_chunks)
base_tasks = await get_default_tasks_by_indices([0, 1, 2], chunk_size, chunker)
add_data_points_task = Task(
add_data_points,
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
)
add_data_points_task = Task(add_data_points)
return base_tasks + [add_data_points_task]


@@ -1,3 +1,3 @@
from .tasks.Task import Task
from .operations.run_tasks import run_tasks
from .operations.needs import merge_needs, MergeNeeds
from .operations.run_parallel import run_tasks_parallel


@@ -1,18 +0,0 @@
class WrongTaskOrderException(Exception):
message: str
def __init__(self, message: str):
self.message = message
super().__init__(message)
class TaskExecutionException(Exception):
type: str
message: str
traceback: str
def __init__(self, type: str, message: str, traceback: str):
self.message = message
self.type = type
self.traceback = traceback
super().__init__(message)


@@ -1,60 +0,0 @@
from typing import Any
from pydantic import BaseModel
from ..tasks import Task
class MergeNeeds(BaseModel):
needs: list[Any]
def merge_needs(*args):
return MergeNeeds(needs=args)
def get_task_needs(tasks: list[Task]):
input_tasks = []
for task in tasks:
if isinstance(task, MergeNeeds):
input_tasks.extend(task.needs)
else:
input_tasks.append(task)
return input_tasks
def get_need_task_results(results, task: Task):
input_results = []
for task_dependency in task.task_config.needs:
if isinstance(task_dependency, MergeNeeds):
task_results = []
max_result_length = 0
for task_need in task_dependency.needs:
result = results[task_need]
task_results.append(result)
if isinstance(result, tuple):
max_result_length = max(max_result_length, len(result))
final_results = [[] for _ in range(max_result_length)]
for result in task_results:
if isinstance(result, tuple):
for i, value in enumerate(result):
final_results[i].extend(value)
else:
final_results[0].extend(result)
input_results.extend(final_results)
else:
result = results[task_dependency]
if isinstance(result, tuple):
input_results.extend(result)
else:
input_results.append(result)
return input_results
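
For reference, the removed merge logic can be exercised standalone: results from several upstream tasks are flattened into one positional argument list, and a tuple result is treated as a multi-output task whose slots are merged index-wise. A runnable reconstruction under those assumptions:

def merge_results(results):
    # Width is the largest tuple seen; plain (single-output) results land in slot 0.
    width = max((len(r) for r in results if isinstance(r, tuple)), default=1)
    merged = [[] for _ in range(width)]
    for result in results:
        if isinstance(result, tuple):
            for i, value in enumerate(result):
                merged[i].extend(value)
        else:
            merged[0].extend(result)
    return merged

# One single-output task (a list) merged with one two-output task (a tuple):
print(merge_results([["a"], (["b"], ["c"])]))  # [['a', 'b'], ['c']]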


@@ -1,26 +1,227 @@
import inspect
import json
from typing import Any
from uuid import UUID, NAMESPACE_OID, uuid4, uuid5
from cognee.shared.logging_utils import get_logger
from uuid import UUID, uuid4
from typing import Any
from cognee.modules.pipelines.operations import (
log_pipeline_run_start,
log_pipeline_run_complete,
log_pipeline_run_error,
)
from cognee.modules.users.methods import get_default_user
from cognee.modules.settings import get_current_settings
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee.shared.logging_utils import get_logger
from uuid import uuid5, NAMESPACE_OID
from ..tasks.Task import Task, TaskExecutionCompleted, TaskExecutionErrored, TaskExecutionStarted
from .run_tasks_base import run_tasks_base
from ..tasks.Task import Task
logger = get_logger("run_tasks(tasks: [Task], data)")
async def run_tasks_with_telemetry(
tasks: list[Task], data, pipeline_name: str, context: dict = None
):
async def run_tasks_base(tasks: list[Task], data=None, user: User = None):
if len(tasks) == 0:
yield data
return
args = [data] if data is not None else []
running_task = tasks[0]
leftover_tasks = tasks[1:]
next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
if inspect.isasyncgenfunction(running_task.executable):
logger.info("Async generator task started: `%s`", running_task.executable.__name__)
send_telemetry(
"Async Generator Task Started",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
try:
results = []
async_iterator = running_task.run(*args)
async for partial_result in async_iterator:
results.append(partial_result)
if len(results) == next_task_batch_size:
async for result in run_tasks_base(
leftover_tasks,
results[0] if next_task_batch_size == 1 else results,
user=user,
):
yield result
results = []
if len(results) > 0:
async for result in run_tasks_base(leftover_tasks, results, user):
yield result
results = []
logger.info("Async generator task completed: `%s`", running_task.executable.__name__)
send_telemetry(
"Async Generator Task Completed",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
except Exception as error:
logger.error(
"Async generator task errored: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info=True,
)
send_telemetry(
"Async Generator Task Errored",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
raise error
elif inspect.isgeneratorfunction(running_task.executable):
logger.info("Generator task started: `%s`", running_task.executable.__name__)
send_telemetry(
"Generator Task Started",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
try:
results = []
for partial_result in running_task.run(*args):
results.append(partial_result)
if len(results) == next_task_batch_size:
async for result in run_tasks_base(
leftover_tasks, results[0] if next_task_batch_size == 1 else results, user
):
yield result
results = []
if len(results) > 0:
async for result in run_tasks_base(leftover_tasks, results, user):
yield result
results = []
logger.info("Generator task completed: `%s`", running_task.executable.__name__)
send_telemetry(
"Generator Task Completed",
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
},
)
except Exception as error:
logger.error(
"Generator task errored: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info=True,
)
send_telemetry(
"Generator Task Errored",
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
},
)
raise error
elif inspect.iscoroutinefunction(running_task.executable):
logger.info("Coroutine task started: `%s`", running_task.executable.__name__)
send_telemetry(
"Coroutine Task Started",
user_id=user.id,
additional_properties={
"task_name": running_task.executable.__name__,
},
)
try:
task_result = await running_task.run(*args)
async for result in run_tasks_base(leftover_tasks, task_result, user):
yield result
logger.info("Coroutine task completed: `%s`", running_task.executable.__name__)
send_telemetry(
"Coroutine Task Completed",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
except Exception as error:
logger.error(
"Coroutine task errored: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info=True,
)
send_telemetry(
"Coroutine Task Errored",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
raise error
elif inspect.isfunction(running_task.executable):
logger.info("Function task started: `%s`", running_task.executable.__name__)
send_telemetry(
"Function Task Started",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
try:
task_result = running_task.run(*args)
async for result in run_tasks_base(leftover_tasks, task_result, user):
yield result
logger.info("Function task completed: `%s`", running_task.executable.__name__)
send_telemetry(
"Function Task Completed",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
except Exception as error:
logger.error(
"Function task errored: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info=True,
)
send_telemetry(
"Function Task Errored",
user.id,
{
"task_name": running_task.executable.__name__,
},
)
raise error
async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
config = get_current_settings()
logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent=1))
@@ -38,45 +239,8 @@ async def run_tasks_with_telemetry(
| config,
)
async for run_task_info in run_tasks_base(tasks, data, context):
if isinstance(run_task_info, TaskExecutionStarted):
send_telemetry(
"Task Run Started",
user.id,
additional_properties={
"task_name": run_task_info.task.__name__,
}
| config,
)
if isinstance(run_task_info, TaskExecutionCompleted):
send_telemetry(
"Task Run Completed",
user.id,
additional_properties={
"task_name": run_task_info.task.__name__,
}
| config,
)
if isinstance(run_task_info, TaskExecutionErrored):
send_telemetry(
"Task Run Errored",
user.id,
additional_properties={
"task_name": run_task_info.task.__name__,
"error": str(run_task_info.error),
}
| config,
)
logger.error(
"Task run errored: `%s`\n%s\n",
run_task_info.task.__name__,
str(run_task_info.error),
exc_info=True,
)
yield run_task_info
async for result in run_tasks_base(tasks, data, user):
yield result
logger.info("Pipeline run completed: `%s`", pipeline_name)
send_telemetry(
@@ -107,22 +271,19 @@ async def run_tasks_with_telemetry(
async def run_tasks(
tasks: list[Task],
dataset_id: UUID = None,
dataset_id: UUID = uuid4(),
data: Any = None,
pipeline_name: str = "unknown_pipeline",
context: dict = None,
):
dataset_id = dataset_id or uuid4()
pipeline_id = uuid5(NAMESPACE_OID, pipeline_name)
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
yield pipeline_run
pipeline_run_id = pipeline_run.pipeline_run_id
try:
async for _ in run_tasks_with_telemetry(tasks, data, pipeline_id, context):
async for _ in run_tasks_with_telemetry(tasks, data, pipeline_id):
pass
yield await log_pipeline_run_complete(
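
One side effect of the revert worth flagging: the restored default dataset_id: UUID = uuid4() is evaluated once, at function definition time, so every call that omits the argument reuses the same id, whereas the reverted dataset_id = dataset_id or uuid4() computed a fresh one per call. A short demonstration of the underlying Python rule:

from uuid import uuid4

def run(dataset_id=uuid4()):        # default computed once, at definition time
    return dataset_id

assert run() == run()               # same UUID on every call

def run_per_call(dataset_id=None):
    return dataset_id or uuid4()    # default computed per call

assert run_per_call() != run_per_call()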


@@ -1,63 +0,0 @@
from collections import deque
from cognee.shared.logging_utils import get_logger
from .needs import get_need_task_results, get_task_needs
from ..tasks.Task import Task, TaskExecutionCompleted, TaskExecutionInfo
from ..exceptions import WrongTaskOrderException
logger = get_logger("run_tasks_base(tasks: list[Task], data)")
async def run_tasks_base(tasks: list[Task], data=None, context=None):
if len(tasks) == 0:
return
pipeline_input = [data] if data is not None else []
"""Run tasks in dependency order and return results."""
task_graph = {} # Map task to its dependencies
dependents = {} # Reverse dependencies (who depends on whom)
results = {}
number_of_executed_tasks = 0
tasks_map = {task.executable: task for task in tasks} # Map task executable to task object
# Build task dependency graph
for task in tasks:
task_graph[task.executable] = get_task_needs(task.task_config.needs)
for dependent_task in task_graph[task.executable]:
dependents.setdefault(dependent_task, []).append(task.executable)
# Find tasks without dependencies
ready_queue = deque([task for task in tasks if not task_graph[task.executable]])
# Execute tasks in order
while ready_queue:
task = ready_queue.popleft()
task_inputs = (
get_need_task_results(results, task) if task.task_config.needs else pipeline_input
)
async for task_execution_info in task.run(*task_inputs): # Run task and store result
if isinstance(task_execution_info, TaskExecutionInfo): # Update result as it comes
results[task.executable] = task_execution_info.result
if isinstance(task_execution_info, TaskExecutionCompleted):
if task.executable not in results: # If result not already set, set it
results[task.executable] = task_execution_info.result
number_of_executed_tasks += 1
yield task_execution_info
# Process tasks depending on this task
for dependent_task in dependents.get(task.executable, []):
task_graph[dependent_task].remove(task.executable) # Mark dependency as resolved
if not task_graph[dependent_task]: # If all dependencies resolved, add to queue
ready_queue.append(tasks_map[dependent_task])
if number_of_executed_tasks != len(tasks):
raise WrongTaskOrderException(
f"{number_of_executed_tasks}/{len(tasks)} tasks executed. You likely have some disconnected tasks or circular dependency."
)
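
The removed scheduler above is essentially Kahn's algorithm over the needs graph: tasks with no unresolved dependencies run first, and each completion unlocks its dependents. A compact sketch of the same idea over plain names:

from collections import deque

def run_in_dependency_order(needs):
    # needs maps each task name to the names it depends on.
    pending = {task: len(deps) for task, deps in needs.items()}
    dependents = {}
    for task, deps in needs.items():
        for dep in deps:
            dependents.setdefault(dep, []).append(task)
    ready = deque(task for task, count in pending.items() if count == 0)
    order = []
    while ready:
        task = ready.popleft()
        order.append(task)
        for dependent in dependents.get(task, []):
            pending[dependent] -= 1   # one dependency resolved
            if pending[dependent] == 0:
                ready.append(dependent)
    if len(order) != len(needs):
        raise RuntimeError("disconnected tasks or circular dependency")
    return order

print(run_in_dependency_order({"a": [], "b": ["a"], "c": ["a"], "d": ["b", "c"]}))
# ['a', 'b', 'c', 'd']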


@@ -1,136 +1,30 @@
import inspect
from typing import Callable, Any, Union
from pydantic import BaseModel
from ..tasks.types import TaskExecutable
from ..operations.needs import MergeNeeds
from ..exceptions import TaskExecutionException
class TaskExecutionStarted(BaseModel):
task: Callable
class TaskExecutionCompleted(BaseModel):
task: Callable
result: Any = None
class TaskExecutionErrored(BaseModel):
task: TaskExecutable
error: TaskExecutionException
model_config = {"arbitrary_types_allowed": True}
class TaskExecutionInfo(BaseModel):
result: Any = None
task: Callable
class TaskConfig(BaseModel):
output_batch_size: int = 1
needs: list[Union[Callable, MergeNeeds]] = []
from typing import Union, Callable, Any, Coroutine, Generator, AsyncGenerator
class Task:
task_config: TaskConfig
executable: Union[
Callable[..., Any],
Callable[..., Coroutine[Any, Any, Any]],
Generator[Any, Any, Any],
AsyncGenerator[Any, Any],
]
task_config: dict[str, Any] = {
"batch_size": 1,
}
default_params: dict[str, Any] = {}
def __init__(self, executable, *args, task_config: TaskConfig = None, **kwargs):
def __init__(self, executable, *args, task_config=None, **kwargs):
self.executable = executable
self.default_params = {"args": args, "kwargs": kwargs}
self.result = None
self.task_config = task_config or TaskConfig()
if task_config is not None:
self.task_config = task_config
async def run(self, *args, **kwargs):
if "batch_size" not in task_config:
self.task_config["batch_size"] = 1
def run(self, *args, **kwargs):
combined_args = args + self.default_params["args"]
combined_kwargs = {
**self.default_params["kwargs"],
**kwargs,
}
combined_kwargs = {**self.default_params["kwargs"], **kwargs}
yield TaskExecutionStarted(
task=self.executable,
)
try:
if inspect.iscoroutinefunction(self.executable): # Async function
end_result = await self.executable(*combined_args, **combined_kwargs)
elif inspect.isgeneratorfunction(self.executable): # Generator
task_result = []
end_result = []
for value in self.executable(*combined_args, **combined_kwargs):
task_result.append(value) # Store the last yielded value
end_result.append(value)
if self.task_config.output_batch_size == 1:
yield TaskExecutionInfo(
result=value,
task=self.executable,
)
elif self.task_config.output_batch_size == len(task_result):
yield TaskExecutionInfo(
result=task_result,
task=self.executable,
)
task_result = [] # Reset for the next batch
# Yield any remaining items in the final batch if it's not empty
if task_result and self.task_config.output_batch_size > 1:
yield TaskExecutionInfo(
result=task_result,
task=self.executable,
)
elif inspect.isasyncgenfunction(self.executable): # Async Generator
task_result = []
end_result = []
async for value in self.executable(*combined_args, **combined_kwargs):
task_result.append(value) # Store the last yielded value
end_result.append(value)
if self.task_config.output_batch_size == 1:
yield TaskExecutionInfo(
result=value,
task=self.executable,
)
elif self.task_config.output_batch_size == len(task_result):
yield TaskExecutionInfo(
result=task_result,
task=self.executable,
)
task_result = [] # Reset for the next batch
# Yield any remaining items in the final batch if it's not empty
if task_result and self.task_config.output_batch_size > 1:
yield TaskExecutionInfo(
result=task_result,
task=self.executable,
)
else: # Regular function
end_result = self.executable(*combined_args, **combined_kwargs)
yield TaskExecutionCompleted(
task=self.executable,
result=end_result,
)
except Exception as error:
import traceback
error_details = TaskExecutionException(
type=type(error).__name__,
message=str(error),
traceback=traceback.format_exc(),
)
yield TaskExecutionErrored(
task=self.executable,
error=error_details,
)
return self.executable(*combined_args, **combined_kwargs)
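
Both versions of Task lean on inspect to decide how the wrapped executable must be driven: async generators and plain generators are iterated, coroutine functions awaited, anything else called directly. A minimal dispatcher covering the four cases:

import asyncio
import inspect

async def drive(fn, *args):
    if inspect.isasyncgenfunction(fn):
        return [value async for value in fn(*args)]  # iterate asynchronously
    if inspect.isgeneratorfunction(fn):
        return list(fn(*args))                       # iterate synchronously
    if inspect.iscoroutinefunction(fn):
        return await fn(*args)                       # await the coroutine
    return fn(*args)                                 # plain call

async def agen(n):
    for i in range(n):
        yield i

def gen(n):
    yield from range(n)

async def coro(n):
    return n * 2

async def main():
    print(await drive(agen, 3))   # [0, 1, 2]
    print(await drive(gen, 3))    # [0, 1, 2]
    print(await drive(coro, 3))   # 6
    print(await drive(len, [1]))  # 1

asyncio.run(main())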


@@ -1,8 +0,0 @@
from .Task import (
Task,
TaskConfig,
TaskExecutionInfo,
TaskExecutionCompleted,
TaskExecutionStarted,
TaskExecutionErrored,
)


@@ -1,9 +0,0 @@
from typing import Any, AsyncGenerator, Callable, Coroutine, Generator, Union
TaskExecutable = Union[
Callable[..., Any],
Callable[..., Coroutine[Any, Any, Any]],
AsyncGenerator[Any, Any],
Generator[Any, Any, Any],
]


@@ -12,6 +12,7 @@ from cognee.modules.graph.utils import (
retrieve_existing_edges,
)
from cognee.shared.data_models import KnowledgeGraph
from cognee.tasks.storage import add_data_points
async def integrate_chunk_graphs(
@@ -27,6 +28,7 @@ async def integrate_chunk_graphs(
for chunk_index, chunk_graph in enumerate(chunk_graphs):
data_chunks[chunk_index].contains = chunk_graph
await add_data_points(chunk_graphs)
return data_chunks
existing_edges_map = await retrieve_existing_edges(
@@ -39,7 +41,13 @@ async def integrate_chunk_graphs(
data_chunks, chunk_graphs, ontology_adapter, existing_edges_map
)
return graph_nodes, graph_edges
if len(graph_nodes) > 0:
await add_data_points(graph_nodes)
if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges)
return data_chunks
async def extract_graph_from_data(


@@ -41,5 +41,4 @@ async def resolve_data_directories(
resolved_data.append(item)
else: # If it's not a string add it directly
resolved_data.append(item)
return resolved_data


@@ -6,8 +6,7 @@ from .index_data_points import index_data_points
from .index_graph_edges import index_graph_edges
async def add_data_points(data_points: list[DataPoint], data_point_connections: list = None):
data_point_connections = data_point_connections or []
async def add_data_points(data_points: list[DataPoint]):
nodes = []
edges = []
@@ -39,8 +38,6 @@ async def add_data_points(data_points: list[DataPoint], data_point_connections:
await graph_engine.add_nodes(nodes)
await graph_engine.add_edges(edges)
if data_point_connections:
await graph_engine.add_edges(data_point_connections)
# This step has to happen after adding nodes and edges because we query the graph.
await index_graph_edges()


@@ -2,12 +2,12 @@ import asyncio
from typing import Type
from uuid import uuid5
from pydantic import BaseModel
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.extraction.extract_summary import extract_summary
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from .models import TextSummary
async def summarize_text(data_chunks: list[DataPoint], summarization_model: Type[BaseModel]):
async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):
if len(data_chunks) == 0:
return data_chunks


@@ -0,0 +1,66 @@
import asyncio
from queue import Queue
import cognee
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.relational import create_db_and_tables
async def pipeline(data_queue):
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
async def queue_consumer():
while not data_queue.is_closed:
if not data_queue.empty():
yield data_queue.get()
else:
await asyncio.sleep(0.3)
async def add_one(num):
yield num + 1
async def multiply_by_two(num):
yield num * 2
await create_db_and_tables()
user = await get_default_user()
tasks_run = run_tasks_base(
[
Task(queue_consumer),
Task(add_one),
Task(multiply_by_two),
],
data=None,
user=user,
)
results = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
index = 0
async for result in tasks_run:
assert result == results[index], f"at {index = }: {result = } != {results[index] = }"
index += 1
async def run_queue():
data_queue = Queue()
data_queue.is_closed = False
async def queue_producer():
for i in range(0, 10):
data_queue.put(i)
await asyncio.sleep(0.1)
data_queue.is_closed = True
await asyncio.gather(pipeline(data_queue), queue_producer())
def test_run_tasks_from_queue():
asyncio.run(run_queue())
if __name__ == "__main__":
asyncio.run(run_queue())
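
The consumer above polls a thread-safe queue.Queue: a bare Queue.get() would block the event loop, so it checks empty() and yields control with asyncio.sleep between polls. The same pattern, sketched with an asyncio.Event in place of the ad-hoc is_closed attribute:

import asyncio
from queue import Queue

async def consume(q, done):
    # Poll instead of blocking: Queue.get() would stall the whole event loop.
    while not (done.is_set() and q.empty()):
        if q.empty():
            await asyncio.sleep(0.05)  # hand control back to other coroutines
        else:
            yield q.get()

async def main():
    q, done = Queue(), asyncio.Event()

    async def produce():
        for i in range(3):
            q.put(i)
            await asyncio.sleep(0.01)
        done.set()

    async def collect():
        return [item async for item in consume(q, done)]

    _, items = await asyncio.gather(produce(), collect())
    print(items)  # [0, 1, 2]

asyncio.run(main())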


@@ -1,41 +0,0 @@
import asyncio
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
async def run_and_check_tasks():
def number_generator(num, context=None):
for i in range(num):
yield i + 1
async def add_one(num, context=None):
yield num + 1
async def multiply_by_two(num, context=None):
yield num * 2
index = 0
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20]
async for task_run_info in run_tasks_base(
[
Task(number_generator),
Task(add_one, task_config=TaskConfig(needs=[number_generator])),
Task(multiply_by_two, task_config=TaskConfig(needs=[number_generator])),
],
data=10,
):
if isinstance(task_run_info, TaskExecutionInfo):
assert task_run_info.result == expected_results[index], (
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
)
index += 1
def test_run_tasks():
asyncio.run(run_and_check_tasks())
if __name__ == "__main__":
test_run_tasks()


@@ -1,47 +0,0 @@
import pytest
import asyncio
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
from cognee.modules.pipelines.exceptions import WrongTaskOrderException
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
async def run_and_check_tasks():
def number_generator(num, context=None):
for i in range(num):
yield i + 1
async def add_one(num, context=None):
yield num + 1
async def multiply_by_two(num, context=None):
yield num * 2
index = 0
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 22]
with pytest.raises(
WrongTaskOrderException,
match="1/3 tasks executed. You likely have some disconnected tasks or circular dependency.",
):
async for task_run_info in run_tasks_base(
[
Task(number_generator),
Task(add_one, task_config=TaskConfig(needs=[number_generator, multiply_by_two])),
Task(multiply_by_two, task_config=TaskConfig(needs=[add_one])),
],
data=10,
):
if isinstance(task_run_info, TaskExecutionInfo):
assert task_run_info.result == expected_results[index], (
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
)
index += 1
def test_run_tasks():
asyncio.run(run_and_check_tasks())
if __name__ == "__main__":
test_run_tasks()


@@ -1,46 +0,0 @@
import asyncio
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
async def run_and_check_tasks():
def number_generator(num, context=None):
for i in range(num):
yield i + 1
async def add_one(num, context=None):
yield num + 1
async def add_two(num, context=None):
yield num + 2
async def multiply_by_two(num1, num2, context=None):
yield num1 * 2
yield num2 * 2
index = 0
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 22, 24]
async for task_run_info in run_tasks_base(
[
Task(number_generator),
Task(add_one, task_config=TaskConfig(needs=[number_generator])),
Task(add_two, task_config=TaskConfig(needs=[number_generator])),
Task(multiply_by_two, task_config=TaskConfig(needs=[add_one, add_two])),
],
data=10,
):
if isinstance(task_run_info, TaskExecutionInfo):
assert task_run_info.result == expected_results[index], (
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
)
index += 1
def test_run_tasks():
asyncio.run(run_and_check_tasks())
if __name__ == "__main__":
test_run_tasks()


@@ -1,41 +0,0 @@
import asyncio
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
async def run_and_check_tasks():
def number_generator(num, context=None):
for i in range(num):
yield i + 1
async def add_one(num, context=None):
yield num + 1
async def multiply_by_two(num, context=None):
yield num * 2
index = 0
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 22]
async for task_run_info in run_tasks_base(
[
Task(number_generator),
Task(add_one, task_config=TaskConfig(needs=[number_generator])),
Task(multiply_by_two, task_config=TaskConfig(needs=[add_one])),
],
data=10,
):
if isinstance(task_run_info, TaskExecutionInfo):
assert task_run_info.result == expected_results[index], (
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
)
index += 1
def test_run_tasks():
asyncio.run(run_and_check_tasks())
if __name__ == "__main__":
test_run_tasks()


@@ -0,0 +1,50 @@
import asyncio
import cognee
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.users.methods import get_default_user
from cognee.infrastructure.databases.relational import create_db_and_tables
async def run_and_check_tasks():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
def number_generator(num):
for i in range(num):
yield i + 1
async def add_one(nums):
for num in nums:
yield num + 1
async def multiply_by_two(num):
yield num * 2
async def add_one_single(num):
yield num + 1
await create_db_and_tables()
user = await get_default_user()
pipeline = run_tasks_base(
[
Task(number_generator),
Task(add_one, task_config={"batch_size": 5}),
Task(multiply_by_two, task_config={"batch_size": 1}),
Task(add_one_single),
],
data=10,
user=user,
)
results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
index = 0
async for result in pipeline:
assert result == results[index], f"at {index = }: {result = } != {results[index] = }"
index += 1
def test_run_tasks():
asyncio.run(run_and_check_tasks())
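
Note which way batch_size points in the restored runner: a task's batch_size describes how that task's own inputs are grouped, because run_tasks_base reads the next task's config (next_task_batch_size) when batching the current task's output. Tracing the test above under that rule:

# number_generator(10)      -> yields 1..10; grouped in fives for add_one
# add_one([1..5]), ([6..10]) -> yields 2..6, then 7..11, one value at a time
# multiply_by_two(n)        -> batch_size 1, so single values: 4, 6, ..., 22
# add_one_single(n)         -> 5, 7, 9, 11, 13, 15, 17, 19, 21, 23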


@@ -1,7 +1,6 @@
import os
import pathlib
import cognee
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
@@ -77,12 +76,6 @@ async def main():
assert len(history) == 6, "Search history is not correct."
# await render_graph()
graph_file_path = os.path.join(
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
)
await visualize_graph(graph_file_path)
# Assert local data files are cleaned properly
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"


@@ -74,13 +74,13 @@
" get_repo_file_dependencies,\n",
")\n",
"from cognee.tasks.storage import add_data_points\n",
"from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
"from cognee.modules.pipelines.tasks.Task import Task\n",
"\n",
"detailed_extraction = True\n",
"\n",
"tasks = [\n",
" Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction),\n",
" Task(add_data_points, task_config=TaskConfig(needs=[get_repo_file_dependencies], output_batch_size=100 if detailed_extraction else 500)),\n",
" Task(add_data_points, task_config={\"batch_size\": 100 if detailed_extraction else 500}),\n",
"]"
]
},


@@ -518,11 +518,11 @@
"from cognee.modules.data.models import Dataset, Data\n",
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
"from cognee.modules.cognify.config import get_cognify_config\n",
"from cognee.modules.pipelines.tasks.Task import Task\n",
"from cognee.modules.pipelines import run_tasks\n",
"from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
"from cognee.modules.pipelines.operations.needs import merge_needs\n",
"from cognee.modules.users.models import User\n",
"from cognee.tasks.documents import (\n",
" check_permissions_on_documents,\n",
" classify_documents,\n",
" extract_chunks_from_documents,\n",
")\n",
@@ -540,25 +540,19 @@
"\n",
" tasks = [\n",
" Task(classify_documents),\n",
" Task( # Extract text chunks based on the document type.\n",
" extract_chunks_from_documents,\n",
" max_chunk_size=get_max_chunk_tokens(),\n",
" task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),\n",
" ),\n",
" Task( # Generate knowledge graphs from the document chunks.\n",
" extract_graph_from_data,\n",
" graph_model=KnowledgeGraph,\n",
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
" ),\n",
" Task(check_permissions_on_documents, user=user, permissions=[\"write\"]),\n",
" Task(\n",
" extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()\n",
" ), # Extract text chunks based on the document type.\n",
" Task(\n",
" extract_graph_from_data, graph_model=KnowledgeGraph, task_config={\"batch_size\": 10}\n",
" ), # Generate knowledge graphs from the document chunks.\n",
" Task(\n",
" summarize_text,\n",
" summarization_model=cognee_config.summarization_model,\n",
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
" ),\n",
" Task(\n",
" add_data_points,\n",
" task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),\n",
" task_config={\"batch_size\": 10},\n",
" ),\n",
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
" ]\n",
"\n",
" pipeline_run = run_tasks(tasks, dataset.id, data_documents, \"cognify_pipeline\")\n",
@@ -1047,7 +1041,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "py312",
"language": "python",
"name": "python3"
},
@@ -1061,7 +1055,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.12.8"
}
},
"nbformat": 4,


@@ -397,15 +397,14 @@
"from cognee.modules.data.models import Dataset, Data\n",
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
"from cognee.modules.cognify.config import get_cognify_config\n",
"from cognee.modules.pipelines.tasks.Task import Task\n",
"from cognee.modules.pipelines import run_tasks\n",
"from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
"from cognee.modules.pipelines.operations.needs import merge_needs\n",
"from cognee.modules.users.models import User\n",
"from cognee.tasks.documents import (\n",
" check_permissions_on_documents,\n",
" classify_documents,\n",
" extract_chunks_from_documents,\n",
")\n",
"from cognee.infrastructure.llm import get_max_chunk_tokens\n",
"from cognee.tasks.graph import extract_graph_from_data\n",
"from cognee.tasks.storage import add_data_points\n",
"from cognee.tasks.summarization import summarize_text\n",
@@ -419,25 +418,17 @@
"\n",
" tasks = [\n",
" Task(classify_documents),\n",
" Task( # Extract text chunks based on the document type.\n",
" extract_chunks_from_documents,\n",
" max_chunk_size=get_max_chunk_tokens(),\n",
" task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),\n",
" ),\n",
" Task( # Generate knowledge graphs from the document chunks.\n",
" extract_graph_from_data,\n",
" graph_model=KnowledgeGraph,\n",
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
" ),\n",
" Task(check_permissions_on_documents, user=user, permissions=[\"write\"]),\n",
" Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n",
" Task(\n",
" extract_graph_from_data, graph_model=KnowledgeGraph, task_config={\"batch_size\": 10}\n",
" ), # Generate knowledge graphs from the document chunks.\n",
" Task(\n",
" summarize_text,\n",
" summarization_model=cognee_config.summarization_model,\n",
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
" ),\n",
" Task(\n",
" add_data_points,\n",
" task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),\n",
" task_config={\"batch_size\": 10},\n",
" ),\n",
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
" ]\n",
"\n",
" pipeline = run_tasks(tasks, data_documents)\n",

File diff suppressed because one or more lines are too long