feat: pipeline tasks needs mapping (#690)
## Description

Maps explicit dependencies between pipeline tasks. A task now declares the tasks whose output it consumes via `TaskConfig(needs=[...])` (with `merge_needs` to combine several upstream results), replacing the previous implicit chaining by list order and `task_config={"batch_size": ...}`. `Task.run` becomes an async generator of lifecycle events (`TaskExecutionStarted`, `TaskExecutionInfo`, `TaskExecutionCompleted`, `TaskExecutionErrored`), and `run_tasks_base` executes tasks in dependency order, raising `WrongTaskOrderException` when the declared graph is disconnected or cyclic.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
Parent: 567b45efa6
Commit: 0ce6fad24a

27 changed files with 803 additions and 618 deletions
@@ -1,7 +1,8 @@
 from typing import Union, BinaryIO
 from cognee.modules.users.models import User
 from cognee.modules.users.methods import get_default_user
-from cognee.modules.pipelines import run_tasks, Task
+from cognee.modules.pipelines import run_tasks
+from cognee.modules.pipelines.tasks import TaskConfig, Task
 from cognee.tasks.ingestion import ingest_data, resolve_data_directories
 from cognee.infrastructure.databases.relational import (
     create_db_and_tables as create_relational_db_and_tables,
@@ -36,7 +37,15 @@ async def add(
     if user is None:
         user = await get_default_user()

-    tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
+    tasks = [
+        Task(resolve_data_directories),
+        Task(
+            ingest_data,
+            dataset_name,
+            user,
+            task_config=TaskConfig(needs=[resolve_data_directories]),
+        ),
+    ]

     dataset_id = uuid5(NAMESPACE_OID, dataset_name)
     pipeline = run_tasks(
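Editor's sketch of the new dependency API used above (illustrative names, not part of the diff): with `TaskConfig(needs=[...])`, a task's inputs are the recorded results of the tasks it names, rather than whatever the previous task in the list happened to yield.

    from cognee.modules.pipelines.tasks import Task, TaskConfig

    def produce():
        return [1, 2, 3]

    def consume(items):  # receives produce's result via needs, not list position
        return [item * 2 for item in items]

    tasks = [
        Task(produce),
        Task(consume, task_config=TaskConfig(needs=[produce])),
    ]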
@@ -1,18 +1,19 @@
 import os
 import pathlib
 import asyncio
-from cognee.shared.logging_utils import get_logger
 from uuid import NAMESPACE_OID, uuid5

+from cognee.shared.logging_utils import get_logger
 from cognee.api.v1.search import SearchType, search
 from cognee.api.v1.visualize.visualize import visualize_graph
 from cognee.base_config import get_base_config
 from cognee.modules.cognify.config import get_cognify_config
 from cognee.modules.pipelines import run_tasks
-from cognee.modules.pipelines.tasks.Task import Task
+from cognee.modules.pipelines.tasks.Task import Task, TaskConfig
+from cognee.modules.pipelines.operations.needs import merge_needs
 from cognee.modules.users.methods import get_default_user
 from cognee.shared.data_models import KnowledgeGraph, MonitoringTool
 from cognee.shared.utils import render_graph

 from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
 from cognee.tasks.graph import extract_graph_from_data
 from cognee.tasks.ingestion import ingest_data
@@ -45,25 +46,46 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
     detailed_extraction = True

     tasks = [
-        Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction),
-        # Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
-        Task(add_data_points, task_config={"batch_size": 500}),
+        Task(
+            get_repo_file_dependencies,
+            detailed_extraction=detailed_extraction,
+            task_config=TaskConfig(output_batch_size=500),
+        ),
+        # Task(summarize_code, task_config=TaskConfig(output_batch_size=500)), # This task takes a long time to complete
+        Task(add_data_points, task_config=TaskConfig(needs=[get_repo_file_dependencies])),
     ]

     if include_docs:
         # This tasks take a long time to complete
         non_code_tasks = [
-            Task(get_non_py_files, task_config={"batch_size": 50}),
-            Task(ingest_data, dataset_name="repo_docs", user=user),
-            Task(classify_documents),
-            Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
+            Task(get_non_py_files),
             Task(
-                extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
+                ingest_data,
+                dataset_name="repo_docs",
+                user=user,
+                task_config=TaskConfig(needs=[get_non_py_files]),
             ),
+            Task(classify_documents, task_config=TaskConfig(needs=[ingest_data])),
+            Task(
+                extract_chunks_from_documents,
+                max_chunk_size=get_max_chunk_tokens(),
+                task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),
+            ),
+            Task(
+                extract_graph_from_data,
+                graph_model=KnowledgeGraph,
+                task_config=TaskConfig(needs=[extract_chunks_from_documents]),
+            ),
             Task(
                 summarize_text,
                 summarization_model=cognee_config.summarization_model,
-                task_config={"batch_size": 50},
+                task_config=TaskConfig(needs=[extract_chunks_from_documents]),
             ),
+            Task(
+                add_data_points,
+                task_config=TaskConfig(
+                    needs=[merge_needs(summarize_text, extract_graph_from_data)]
+                ),
+            ),
         ]
@@ -71,11 +93,11 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):

     if include_docs:
         non_code_pipeline_run = run_tasks(non_code_tasks, dataset_id, repo_path, "cognify_pipeline")
-        async for run_status in non_code_pipeline_run:
-            yield run_status
+        async for run_info in non_code_pipeline_run:
+            yield run_info

-    async for run_status in run_tasks(tasks, dataset_id, repo_path, "cognify_code_pipeline"):
-        yield run_status
+    async for run_info in run_tasks(tasks, dataset_id, repo_path, "cognify_code_pipeline"):
+        yield run_info


 if __name__ == "__main__":
@@ -10,10 +10,10 @@ from cognee.modules.cognify.config import get_cognify_config
 from cognee.modules.data.methods import get_datasets, get_datasets_by_name
 from cognee.modules.data.methods.get_dataset_data import get_dataset_data
 from cognee.modules.data.models import Data, Dataset
-from cognee.modules.pipelines import run_tasks
+from cognee.modules.pipelines import run_tasks, merge_needs
+from cognee.modules.pipelines.tasks import Task, TaskConfig
 from cognee.modules.pipelines.models import PipelineRunStatus
 from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
-from cognee.modules.pipelines.tasks.Task import Task
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.users.models import User
 from cognee.shared.data_models import KnowledgeGraph
@@ -92,7 +92,9 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, tasks: list[Task]):
         if not isinstance(task, Task):
             raise ValueError(f"Task {task} is not an instance of Task")

-    pipeline_run = run_tasks(tasks, dataset.id, data_documents, "cognify_pipeline")
+    pipeline_run = run_tasks(
+        tasks, dataset.id, data_documents, "cognify_pipeline", context={"user": user}
+    )
     pipeline_run_status = None

     async for run_status in pipeline_run:
@@ -121,24 +123,33 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's

     default_tasks = [
         Task(classify_documents),
-        Task(check_permissions_on_documents, user=user, permissions=["write"]),
+        Task(
+            check_permissions_on_documents,
+            user=user,
+            permissions=["write"],
+            task_config=TaskConfig(needs=[classify_documents]),
+        ),
-        Task(
+        Task( # Extract text chunks based on the document type.
             extract_chunks_from_documents,
             max_chunk_size=chunk_size or get_max_chunk_tokens(),
             chunker=chunker,
-        ), # Extract text chunks based on the document type.
+            task_config=TaskConfig(needs=[check_permissions_on_documents], output_batch_size=10),
+        ),
-        Task(
+        Task( # Generate knowledge graphs from the document chunks.
             extract_graph_from_data,
             graph_model=graph_model,
             ontology_adapter=ontology_adapter,
-            task_config={"batch_size": 10},
-        ), # Generate knowledge graphs from the document chunks.
+            task_config=TaskConfig(needs=[extract_chunks_from_documents]),
+        ),
         Task(
             summarize_text,
             summarization_model=cognee_config.summarization_model,
-            task_config={"batch_size": 10},
+            task_config=TaskConfig(needs=[extract_chunks_from_documents]),
         ),
-        Task(add_data_points, task_config={"batch_size": 10}),
+        Task(
+            add_data_points,
+            task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),
+        ),
     ]

     return default_tasks
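Read as a graph, the `needs` declarations above give the default cognify pipeline the following shape:

    classify_documents
        -> check_permissions_on_documents
        -> extract_chunks_from_documents
        -> extract_graph_from_data and summarize_text (fan-out)
        -> add_data_points (fan-in via merge_needs)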
@@ -2,11 +2,11 @@ from typing import List
 from pydantic import BaseModel

 from cognee.modules.cognify.config import get_cognify_config
-from cognee.modules.pipelines.tasks.Task import Task
+from cognee.modules.pipelines.operations.needs import merge_needs
+from cognee.modules.pipelines.tasks import Task, TaskConfig
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.users.models import User
 from cognee.shared.data_models import KnowledgeGraph
 from cognee.shared.utils import send_telemetry
 from cognee.tasks.documents import (
     check_permissions_on_documents,
     classify_documents,
@@ -27,25 +27,32 @@ async def get_cascade_graph_tasks(
     if user is None:
         user = await get_default_user()

-    try:
-        cognee_config = get_cognify_config()
-        default_tasks = [
-            Task(classify_documents),
-            Task(check_permissions_on_documents, user=user, permissions=["write"]),
-            Task(
-                extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
-            ), # Extract text chunks based on the document type.
-            Task(
-                extract_graph_from_data, task_config={"batch_size": 10}
-            ), # Generate knowledge graphs using cascade extraction
-            Task(
-                summarize_text,
-                summarization_model=cognee_config.summarization_model,
-                task_config={"batch_size": 10},
-            ),
-            Task(add_data_points, task_config={"batch_size": 10}),
-        ]
-    except Exception as error:
-        send_telemetry("cognee.cognify DEFAULT TASKS CREATION ERRORED", user.id)
-        raise error
+    cognee_config = get_cognify_config()
+    default_tasks = [
+        Task(classify_documents),
+        Task(
+            check_permissions_on_documents,
+            user=user,
+            permissions=["write"],
+            task_config=TaskConfig(needs=[classify_documents]),
+        ),
+        Task( # Extract text chunks based on the document type.
+            extract_chunks_from_documents,
+            max_chunk_tokens=get_max_chunk_tokens(),
+            task_config=TaskConfig(needs=[check_permissions_on_documents], output_batch_size=50),
+        ),
+        Task(
+            extract_graph_from_data,
+            task_config=TaskConfig(needs=[extract_chunks_from_documents]),
+        ), # Generate knowledge graphs using cascade extraction
+        Task(
+            summarize_text,
+            summarization_model=cognee_config.summarization_model,
+            task_config=TaskConfig(needs=[extract_chunks_from_documents]),
+        ),
+        Task(
+            add_data_points,
+            task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),
+        ),
+    ]
     return default_tasks
@@ -1,3 +1,3 @@
 from .tasks.Task import Task
 from .operations.run_tasks import run_tasks
-from .operations.run_parallel import run_tasks_parallel
+from .operations.needs import merge_needs, MergeNeeds
cognee/modules/pipelines/exceptions.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+class WrongTaskOrderException(Exception):
+    message: str
+
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(message)
+
+
+class TaskExecutionException(Exception):
+    type: str
+    message: str
+    traceback: str
+
+    def __init__(self, type: str, message: str, traceback: str):
+        self.message = message
+        self.type = type
+        self.traceback = traceback
+        super().__init__(message)
cognee/modules/pipelines/operations/needs.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+from typing import Any
+from pydantic import BaseModel
+
+from ..tasks import Task
+
+
+class MergeNeeds(BaseModel):
+    needs: list[Any]
+
+
+def merge_needs(*args):
+    return MergeNeeds(needs=args)
+
+
+def get_task_needs(tasks: list[Task]):
+    input_tasks = []
+
+    for task in tasks:
+        if isinstance(task, MergeNeeds):
+            input_tasks.extend(task.needs)
+        else:
+            input_tasks.append(task)
+
+    return input_tasks
+
+
+def get_need_task_results(results, task: Task):
+    input_results = []
+
+    for task_dependency in task.task_config.needs:
+        if isinstance(task_dependency, MergeNeeds):
+            task_results = []
+            max_result_length = 0
+
+            for task_need in task_dependency.needs:
+                result = results[task_need]
+                task_results.append(result)
+
+                if isinstance(result, tuple):
+                    max_result_length = max(max_result_length, len(result))
+
+            final_results = [[] for _ in range(max_result_length)]
+
+            for result in task_results:
+                if isinstance(result, tuple):
+                    for i, value in enumerate(result):
+                        final_results[i].extend(value)
+                else:
+                    final_results[0].extend(result)
+
+            input_results.extend(final_results)
+        else:
+            result = results[task_dependency]
+
+            if isinstance(result, tuple):
+                input_results.extend(result)
+            else:
+                input_results.append(result)
+
+    return input_results
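A sketch of how `get_need_task_results` assembles inputs for a `merge_needs` dependency (hypothetical executables `a`, `b`, `c`; `results` is the scheduler's result map): tuple results are spread column-wise, plain list results land in the first column, and the downstream task receives one positional argument per column.

    task = Task(c, task_config=TaskConfig(needs=[merge_needs(a, b)]))

    # given results == {a: ([1, 2], ["x"]), b: ([3], ["y"])}
    # get_need_task_results(results, task) returns [[1, 2, 3], ["x", "y"]]
    # so the scheduler calls c([1, 2, 3], ["x", "y"])

This is the mechanism by which `add_data_points` in the pipelines above receives summaries plus graph nodes as `data_points` and the graph edges as `data_point_connections`.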
@@ -1,227 +1,26 @@
-import inspect
 import json
-from cognee.shared.logging_utils import get_logger
-from uuid import UUID, uuid4
+
+from typing import Any
+from uuid import UUID, NAMESPACE_OID, uuid4, uuid5

 from cognee.modules.pipelines.operations import (
     log_pipeline_run_start,
     log_pipeline_run_complete,
     log_pipeline_run_error,
 )
-from cognee.modules.settings import get_current_settings
 from cognee.modules.users.methods import get_default_user
 from cognee.modules.users.models import User
+from cognee.modules.settings import get_current_settings
 from cognee.shared.utils import send_telemetry
-from uuid import uuid5, NAMESPACE_OID
+from cognee.shared.logging_utils import get_logger

-from ..tasks.Task import Task
+from ..tasks.Task import Task, TaskExecutionCompleted, TaskExecutionErrored, TaskExecutionStarted
+from .run_tasks_base import run_tasks_base

 logger = get_logger("run_tasks(tasks: [Task], data)")


-async def run_tasks_base(tasks: list[Task], data=None, user: User = None):
-    if len(tasks) == 0:
-        yield data
-        return
-
-    args = [data] if data is not None else []
-
-    running_task = tasks[0]
-    leftover_tasks = tasks[1:]
-    next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
-    next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
-
-    if inspect.isasyncgenfunction(running_task.executable):
-        logger.info("Async generator task started: `%s`", running_task.executable.__name__)
-        send_telemetry(
-            "Async Generator Task Started",
-            user.id,
-            {
-                "task_name": running_task.executable.__name__,
-            },
-        )
-        try:
-            results = []
-
-            async_iterator = running_task.run(*args)
-
-            async for partial_result in async_iterator:
-                results.append(partial_result)
-
-                if len(results) == next_task_batch_size:
-                    async for result in run_tasks_base(
-                        leftover_tasks,
-                        results[0] if next_task_batch_size == 1 else results,
-                        user=user,
-                    ):
-                        yield result
-
-                    results = []
-
-            if len(results) > 0:
-                async for result in run_tasks_base(leftover_tasks, results, user):
-                    yield result
-
-                results = []
-
-            logger.info("Async generator task completed: `%s`", running_task.executable.__name__)
-            send_telemetry(
-                "Async Generator Task Completed",
-                user.id,
-                {
-                    "task_name": running_task.executable.__name__,
-                },
-            )
-        except Exception as error:
-            logger.error(
-                "Async generator task errored: `%s`\n%s\n",
-                running_task.executable.__name__,
-                str(error),
-                exc_info=True,
-            )
-            send_telemetry(
-                "Async Generator Task Errored",
-                user.id,
-                {
-                    "task_name": running_task.executable.__name__,
-                },
-            )
-            raise error
-
-    elif inspect.isgeneratorfunction(running_task.executable):
-        logger.info("Generator task started: `%s`", running_task.executable.__name__)
-        send_telemetry(
-            "Generator Task Started",
-            user.id,
-            {
-                "task_name": running_task.executable.__name__,
-            },
-        )
-        try:
-            results = []
-
-            for partial_result in running_task.run(*args):
-                results.append(partial_result)
-
-                if len(results) == next_task_batch_size:
-                    async for result in run_tasks_base(
-                        leftover_tasks, results[0] if next_task_batch_size == 1 else results, user
-                    ):
-                        yield result
-
-                    results = []
-
-            if len(results) > 0:
-                async for result in run_tasks_base(leftover_tasks, results, user):
-                    yield result
-
-                results = []
-
-            logger.info("Generator task completed: `%s`", running_task.executable.__name__)
-            send_telemetry(
-                "Generator Task Completed",
-                user_id=user.id,
-                additional_properties={
-                    "task_name": running_task.executable.__name__,
-                },
-            )
-        except Exception as error:
-            logger.error(
-                "Generator task errored: `%s`\n%s\n",
-                running_task.executable.__name__,
-                str(error),
-                exc_info=True,
-            )
-            send_telemetry(
-                "Generator Task Errored",
-                user_id=user.id,
-                additional_properties={
-                    "task_name": running_task.executable.__name__,
-                },
-            )
-            raise error
-
-    elif inspect.iscoroutinefunction(running_task.executable):
-        logger.info("Coroutine task started: `%s`", running_task.executable.__name__)
-        send_telemetry(
-            "Coroutine Task Started",
-            user_id=user.id,
-            additional_properties={
-                "task_name": running_task.executable.__name__,
-            },
-        )
-        try:
-            task_result = await running_task.run(*args)
-
-            async for result in run_tasks_base(leftover_tasks, task_result, user):
-                yield result
-
-            logger.info("Coroutine task completed: `%s`", running_task.executable.__name__)
-            send_telemetry(
-                "Coroutine Task Completed",
-                user.id,
-                {
-                    "task_name": running_task.executable.__name__,
-                },
-            )
-        except Exception as error:
-            logger.error(
-                "Coroutine task errored: `%s`\n%s\n",
-                running_task.executable.__name__,
-                str(error),
-                exc_info=True,
-            )
-            send_telemetry(
-                "Coroutine Task Errored",
-                user.id,
-                {
-                    "task_name": running_task.executable.__name__,
-                },
-            )
-            raise error
-
-    elif inspect.isfunction(running_task.executable):
-        logger.info("Function task started: `%s`", running_task.executable.__name__)
-        send_telemetry(
-            "Function Task Started",
-            user.id,
-            {
-                "task_name": running_task.executable.__name__,
-            },
-        )
-        try:
-            task_result = running_task.run(*args)
-
-            async for result in run_tasks_base(leftover_tasks, task_result, user):
-                yield result
-
-            logger.info("Function task completed: `%s`", running_task.executable.__name__)
-            send_telemetry(
-                "Function Task Completed",
-                user.id,
-                {
-                    "task_name": running_task.executable.__name__,
-                },
-            )
-        except Exception as error:
-            logger.error(
-                "Function task errored: `%s`\n%s\n",
-                running_task.executable.__name__,
-                str(error),
-                exc_info=True,
-            )
-            send_telemetry(
-                "Function Task Errored",
-                user.id,
-                {
-                    "task_name": running_task.executable.__name__,
-                },
-            )
-            raise error
-
-
-async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
+async def run_tasks_with_telemetry(
+    tasks: list[Task], data, pipeline_name: str, context: dict = None
+):
     config = get_current_settings()

     logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent=1))
@@ -239,8 +38,45 @@ async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
         | config,
     )

-    async for result in run_tasks_base(tasks, data, user):
-        yield result
+    async for run_task_info in run_tasks_base(tasks, data, context):
+        if isinstance(run_task_info, TaskExecutionStarted):
+            send_telemetry(
+                "Task Run Started",
+                user.id,
+                additional_properties={
+                    "task_name": run_task_info.task.__name__,
+                }
+                | config,
+            )
+
+        if isinstance(run_task_info, TaskExecutionCompleted):
+            send_telemetry(
+                "Task Run Completed",
+                user.id,
+                additional_properties={
+                    "task_name": run_task_info.task.__name__,
+                }
+                | config,
+            )
+
+        if isinstance(run_task_info, TaskExecutionErrored):
+            send_telemetry(
+                "Task Run Errored",
+                user.id,
+                additional_properties={
+                    "task_name": run_task_info.task.__name__,
+                    "error": str(run_task_info.error),
+                }
+                | config,
+            )
+            logger.error(
+                "Task run errored: `%s`\n%s\n",
+                run_task_info.task.__name__,
+                str(run_task_info.error),
+                exc_info=True,
+            )
+
+        yield run_task_info

     logger.info("Pipeline run completed: `%s`", pipeline_name)
     send_telemetry(
@@ -271,19 +107,22 @@ async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):

 async def run_tasks(
     tasks: list[Task],
-    dataset_id: UUID = uuid4(),
+    dataset_id: UUID = None,
     data: Any = None,
     pipeline_name: str = "unknown_pipeline",
+    context: dict = None,
 ):
+    dataset_id = dataset_id or uuid4()
     pipeline_id = uuid5(NAMESPACE_OID, pipeline_name)

     pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)

     yield pipeline_run

     pipeline_run_id = pipeline_run.pipeline_run_id

     try:
-        async for _ in run_tasks_with_telemetry(tasks, data, pipeline_id):
+        async for _ in run_tasks_with_telemetry(tasks, data, pipeline_id, context):
             pass

         yield await log_pipeline_run_complete(
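A sketch of driving the reworked entry point, mirroring the cognify call earlier in this diff (`tasks`, `dataset_id`, `documents`, and `user` assumed in scope):

    pipeline_run = run_tasks(tasks, dataset_id, documents, "cognify_pipeline", context={"user": user})

    async for run_info in pipeline_run:
        ...  # yields the pipeline-run start record, then the completion (or error) record

Per-task events are consumed inside `run_tasks_with_telemetry`; only pipeline-level run records surface here.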
cognee/modules/pipelines/operations/run_tasks_base.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+from collections import deque
+
+from cognee.shared.logging_utils import get_logger
+
+from .needs import get_need_task_results, get_task_needs
+from ..tasks.Task import Task, TaskExecutionCompleted, TaskExecutionInfo
+from ..exceptions import WrongTaskOrderException
+
+logger = get_logger("run_tasks_base(tasks: list[Task], data)")
+
+
+async def run_tasks_base(tasks: list[Task], data=None, context=None):
+    if len(tasks) == 0:
+        return
+
+    pipeline_input = [data] if data is not None else []
+
+    """Run tasks in dependency order and return results."""
+    task_graph = {}  # Map task to its dependencies
+    dependents = {}  # Reverse dependencies (who depends on whom)
+    results = {}
+    number_of_executed_tasks = 0
+
+    tasks_map = {task.executable: task for task in tasks}  # Map task executable to task object
+
+    # Build task dependency graph
+    for task in tasks:
+        task_graph[task.executable] = get_task_needs(task.task_config.needs)
+        for dependent_task in task_graph[task.executable]:
+            dependents.setdefault(dependent_task, []).append(task.executable)
+
+    # Find tasks without dependencies
+    ready_queue = deque([task for task in tasks if not task_graph[task.executable]])
+
+    # Execute tasks in order
+    while ready_queue:
+        task = ready_queue.popleft()
+        task_inputs = (
+            get_need_task_results(results, task) if task.task_config.needs else pipeline_input
+        )
+
+        async for task_execution_info in task.run(*task_inputs):  # Run task and store result
+            if isinstance(task_execution_info, TaskExecutionInfo):  # Update result as it comes
+                results[task.executable] = task_execution_info.result
+
+            if isinstance(task_execution_info, TaskExecutionCompleted):
+                if task.executable not in results:  # If result not already set, set it
+                    results[task.executable] = task_execution_info.result
+
+                number_of_executed_tasks += 1
+
+            yield task_execution_info
+
+        # Process tasks depending on this task
+        for dependent_task in dependents.get(task.executable, []):
+            task_graph[dependent_task].remove(task.executable)  # Mark dependency as resolved
+            if not task_graph[dependent_task]:  # If all dependencies resolved, add to queue
+                ready_queue.append(tasks_map[dependent_task])
+
+    if number_of_executed_tasks != len(tasks):
+        raise WrongTaskOrderException(
+            f"{number_of_executed_tasks}/{len(tasks)} tasks executed. You likely have some disconnected tasks or circular dependency."
+        )
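Editor's trace of the scheduler above on a diamond graph (`a`; `b` needs `a`; `c` needs `a`; `d` needs `[b, c]`): the ready queue starts as `[a]`; finishing `a` makes `b` and `c` ready; `d` runs only once both `b` and `c` have resolved. If some declared need can never resolve (a cycle, or a task referenced in `needs` but missing from the list), fewer than `len(tasks)` tasks execute and `WrongTaskOrderException` is raised.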
@@ -1,30 +1,136 @@
-from typing import Union, Callable, Any, Coroutine, Generator, AsyncGenerator
+import inspect
+from typing import Callable, Any, Union
+
+from pydantic import BaseModel
+
+from ..tasks.types import TaskExecutable
+from ..operations.needs import MergeNeeds
+from ..exceptions import TaskExecutionException
+
+
+class TaskExecutionStarted(BaseModel):
+    task: Callable
+
+
+class TaskExecutionCompleted(BaseModel):
+    task: Callable
+    result: Any = None
+
+
+class TaskExecutionErrored(BaseModel):
+    task: TaskExecutable
+    error: TaskExecutionException
+
+    model_config = {"arbitrary_types_allowed": True}
+
+
+class TaskExecutionInfo(BaseModel):
+    result: Any = None
+    task: Callable
+
+
+class TaskConfig(BaseModel):
+    output_batch_size: int = 1
+    needs: list[Union[Callable, MergeNeeds]] = []


 class Task:
-    executable: Union[
-        Callable[..., Any],
-        Callable[..., Coroutine[Any, Any, Any]],
-        Generator[Any, Any, Any],
-        AsyncGenerator[Any, Any],
-    ]
-    task_config: dict[str, Any] = {
-        "batch_size": 1,
-    }
+    executable: TaskExecutable
+    task_config: TaskConfig
     default_params: dict[str, Any] = {}

-    def __init__(self, executable, *args, task_config=None, **kwargs):
+    def __init__(self, executable, *args, task_config: TaskConfig = None, **kwargs):
         self.executable = executable
         self.default_params = {"args": args, "kwargs": kwargs}
+        self.result = None

-        if task_config is not None:
-            self.task_config = task_config
+        self.task_config = task_config or TaskConfig()

-            if "batch_size" not in task_config:
-                self.task_config["batch_size"] = 1
-
-    def run(self, *args, **kwargs):
+    async def run(self, *args, **kwargs):
         combined_args = args + self.default_params["args"]
-        combined_kwargs = {**self.default_params["kwargs"], **kwargs}
+        combined_kwargs = {
+            **self.default_params["kwargs"],
+            **kwargs,
+        }

-        return self.executable(*combined_args, **combined_kwargs)
+        yield TaskExecutionStarted(
+            task=self.executable,
+        )
+
+        try:
+            if inspect.iscoroutinefunction(self.executable):  # Async function
+                end_result = await self.executable(*combined_args, **combined_kwargs)
+
+            elif inspect.isgeneratorfunction(self.executable):  # Generator
+                task_result = []
+                end_result = []
+
+                for value in self.executable(*combined_args, **combined_kwargs):
+                    task_result.append(value)  # Store the last yielded value
+                    end_result.append(value)
+
+                    if self.task_config.output_batch_size == 1:
+                        yield TaskExecutionInfo(
+                            result=value,
+                            task=self.executable,
+                        )
+                    elif self.task_config.output_batch_size == len(task_result):
+                        yield TaskExecutionInfo(
+                            result=task_result,
+                            task=self.executable,
+                        )
+                        task_result = []  # Reset for the next batch
+
+                # Yield any remaining items in the final batch if it's not empty
+                if task_result and self.task_config.output_batch_size > 1:
+                    yield TaskExecutionInfo(
+                        result=task_result,
+                        task=self.executable,
+                    )
+
+            elif inspect.isasyncgenfunction(self.executable):  # Async Generator
+                task_result = []
+                end_result = []
+
+                async for value in self.executable(*combined_args, **combined_kwargs):
+                    task_result.append(value)  # Store the last yielded value
+                    end_result.append(value)
+
+                    if self.task_config.output_batch_size == 1:
+                        yield TaskExecutionInfo(
+                            result=value,
+                            task=self.executable,
+                        )
+                    elif self.task_config.output_batch_size == len(task_result):
+                        yield TaskExecutionInfo(
+                            result=task_result,
+                            task=self.executable,
+                        )
+                        task_result = []  # Reset for the next batch
+
+                # Yield any remaining items in the final batch if it's not empty
+                if task_result and self.task_config.output_batch_size > 1:
+                    yield TaskExecutionInfo(
+                        result=task_result,
+                        task=self.executable,
+                    )
+
+            else:  # Regular function
+                end_result = self.executable(*combined_args, **combined_kwargs)
+
+            yield TaskExecutionCompleted(
+                task=self.executable,
+                result=end_result,
+            )
+
+        except Exception as error:
+            import traceback
+
+            error_details = TaskExecutionException(
+                type=type(error).__name__,
+                message=str(error),
+                traceback=traceback.format_exc(),
+            )
+
+            yield TaskExecutionErrored(
+                task=self.executable,
+                error=error_details,
+            )
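Since `Task.run` is now an async generator of lifecycle events rather than a direct call, consuming a task by hand looks like this (minimal sketch with an illustrative plain function):

    task = Task(lambda num: num + 1)

    async for event in task.run(41):
        if isinstance(event, TaskExecutionStarted):
            pass  # emitted before the executable is invoked
        if isinstance(event, TaskExecutionCompleted):
            assert event.result == 42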
cognee/modules/pipelines/tasks/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+from .Task import (
+    Task,
+    TaskConfig,
+    TaskExecutionInfo,
+    TaskExecutionCompleted,
+    TaskExecutionStarted,
+    TaskExecutionErrored,
+)
cognee/modules/pipelines/tasks/types.py (new file, 9 lines)
@@ -0,0 +1,9 @@
+from typing import Any, AsyncGenerator, Callable, Coroutine, Generator, Union
+
+
+TaskExecutable = Union[
+    Callable[..., Any],
+    Callable[..., Coroutine[Any, Any, Any]],
+    AsyncGenerator[Any, Any],
+    Generator[Any, Any, Any],
+]
@@ -12,7 +12,6 @@ from cognee.modules.graph.utils import (
     retrieve_existing_edges,
 )
 from cognee.shared.data_models import KnowledgeGraph
-from cognee.tasks.storage import add_data_points


 async def integrate_chunk_graphs(
@@ -28,7 +27,6 @@ async def integrate_chunk_graphs(
         for chunk_index, chunk_graph in enumerate(chunk_graphs):
             data_chunks[chunk_index].contains = chunk_graph

-        await add_data_points(chunk_graphs)
         return data_chunks

     existing_edges_map = await retrieve_existing_edges(
@@ -41,13 +39,7 @@ async def integrate_chunk_graphs(
         data_chunks, chunk_graphs, ontology_adapter, existing_edges_map
     )

-    if len(graph_nodes) > 0:
-        await add_data_points(graph_nodes)
-
-    if len(graph_edges) > 0:
-        await graph_engine.add_edges(graph_edges)
-
-    return data_chunks
+    return graph_nodes, graph_edges


 async def extract_graph_from_data(
@@ -41,4 +41,5 @@ async def resolve_data_directories(
             resolved_data.append(item)
         else:  # If it's not a string add it directly
             resolved_data.append(item)
+
     return resolved_data
@@ -6,7 +6,8 @@ from .index_data_points import index_data_points
 from .index_graph_edges import index_graph_edges


-async def add_data_points(data_points: list[DataPoint]):
+async def add_data_points(data_points: list[DataPoint], data_point_connections: list = None):
+    data_point_connections = data_point_connections or []
     nodes = []
     edges = []

@@ -38,6 +39,8 @@ async def add_data_points(data_points: list[DataPoint]):

     await graph_engine.add_nodes(nodes)
     await graph_engine.add_edges(edges)
+    if data_point_connections:
+        await graph_engine.add_edges(data_point_connections)

     # This step has to happen after adding nodes and edges because we query the graph.
     await index_graph_edges()
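The new parameter exists because `integrate_chunk_graphs` above now returns `(graph_nodes, graph_edges)` instead of persisting them itself; a caller can hand both to storage in one call (sketch; variable names hypothetical):

    nodes, edges = await integrate_chunk_graphs(...)  # returns (graph_nodes, graph_edges) after this PR
    await add_data_points(nodes, data_point_connections=edges)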
@@ -2,12 +2,12 @@ import asyncio
 from typing import Type
 from uuid import uuid5
 from pydantic import BaseModel
+from cognee.infrastructure.engine import DataPoint
 from cognee.modules.data.extraction.extract_summary import extract_summary
-from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
 from .models import TextSummary


-async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):
+async def summarize_text(data_chunks: list[DataPoint], summarization_model: Type[BaseModel]):
     if len(data_chunks) == 0:
         return data_chunks

@@ -1,66 +0,0 @@
-import asyncio
-from queue import Queue
-
-import cognee
-from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
-from cognee.modules.pipelines.tasks.Task import Task
-from cognee.modules.users.methods import get_default_user
-from cognee.infrastructure.databases.relational import create_db_and_tables
-
-
-async def pipeline(data_queue):
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
-
-    async def queue_consumer():
-        while not data_queue.is_closed:
-            if not data_queue.empty():
-                yield data_queue.get()
-            else:
-                await asyncio.sleep(0.3)
-
-    async def add_one(num):
-        yield num + 1
-
-    async def multiply_by_two(num):
-        yield num * 2
-
-    await create_db_and_tables()
-    user = await get_default_user()
-
-    tasks_run = run_tasks_base(
-        [
-            Task(queue_consumer),
-            Task(add_one),
-            Task(multiply_by_two),
-        ],
-        data=None,
-        user=user,
-    )
-
-    results = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
-    index = 0
-    async for result in tasks_run:
-        assert result == results[index], f"at {index = }: {result = } != {results[index] = }"
-        index += 1
-
-
-async def run_queue():
-    data_queue = Queue()
-    data_queue.is_closed = False
-
-    async def queue_producer():
-        for i in range(0, 10):
-            data_queue.put(i)
-            await asyncio.sleep(0.1)
-        data_queue.is_closed = True
-
-    await asyncio.gather(pipeline(data_queue), queue_producer())
-
-
-def test_run_tasks_from_queue():
-    asyncio.run(run_queue())
-
-
-if __name__ == "__main__":
-    asyncio.run(run_queue())
@@ -0,0 +1,41 @@
+import asyncio
+
+from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
+from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
+
+
+async def run_and_check_tasks():
+    def number_generator(num, context=None):
+        for i in range(num):
+            yield i + 1
+
+    async def add_one(num, context=None):
+        yield num + 1
+
+    async def multiply_by_two(num, context=None):
+        yield num * 2
+
+    index = 0
+    expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20]
+
+    async for task_run_info in run_tasks_base(
+        [
+            Task(number_generator),
+            Task(add_one, task_config=TaskConfig(needs=[number_generator])),
+            Task(multiply_by_two, task_config=TaskConfig(needs=[number_generator])),
+        ],
+        data=10,
+    ):
+        if isinstance(task_run_info, TaskExecutionInfo):
+            assert task_run_info.result == expected_results[index], (
+                f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
+            )
+            index += 1
+
+
+def test_run_tasks():
+    asyncio.run(run_and_check_tasks())
+
+
+if __name__ == "__main__":
+    test_run_tasks()
@@ -0,0 +1,47 @@
+import pytest
+import asyncio
+
+from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
+from cognee.modules.pipelines.exceptions import WrongTaskOrderException
+from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
+
+
+async def run_and_check_tasks():
+    def number_generator(num, context=None):
+        for i in range(num):
+            yield i + 1
+
+    async def add_one(num, context=None):
+        yield num + 1
+
+    async def multiply_by_two(num, context=None):
+        yield num * 2
+
+    index = 0
+    expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 22]
+
+    with pytest.raises(
+        WrongTaskOrderException,
+        match="1/3 tasks executed. You likely have some disconnected tasks or circular dependency.",
+    ):
+        async for task_run_info in run_tasks_base(
+            [
+                Task(number_generator),
+                Task(add_one, task_config=TaskConfig(needs=[number_generator, multiply_by_two])),
+                Task(multiply_by_two, task_config=TaskConfig(needs=[add_one])),
+            ],
+            data=10,
+        ):
+            if isinstance(task_run_info, TaskExecutionInfo):
+                assert task_run_info.result == expected_results[index], (
+                    f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
+                )
+                index += 1
+
+
+def test_run_tasks():
+    asyncio.run(run_and_check_tasks())
+
+
+if __name__ == "__main__":
+    test_run_tasks()
@@ -0,0 +1,46 @@
+import asyncio
+
+from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
+from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
+
+
+async def run_and_check_tasks():
+    def number_generator(num, context=None):
+        for i in range(num):
+            yield i + 1
+
+    async def add_one(num, context=None):
+        yield num + 1
+
+    async def add_two(num, context=None):
+        yield num + 2
+
+    async def multiply_by_two(num1, num2, context=None):
+        yield num1 * 2
+        yield num2 * 2
+
+    index = 0
+    expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 22, 24]
+
+    async for task_run_info in run_tasks_base(
+        [
+            Task(number_generator),
+            Task(add_one, task_config=TaskConfig(needs=[number_generator])),
+            Task(add_two, task_config=TaskConfig(needs=[number_generator])),
+            Task(multiply_by_two, task_config=TaskConfig(needs=[add_one, add_two])),
+        ],
+        data=10,
+    ):
+        if isinstance(task_run_info, TaskExecutionInfo):
+            assert task_run_info.result == expected_results[index], (
+                f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
+            )
+            index += 1
+
+
+def test_run_tasks():
+    asyncio.run(run_and_check_tasks())
+
+
+if __name__ == "__main__":
+    test_run_tasks()
@@ -0,0 +1,41 @@
+import asyncio
+
+from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
+from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
+
+
+async def run_and_check_tasks():
+    def number_generator(num, context=None):
+        for i in range(num):
+            yield i + 1
+
+    async def add_one(num, context=None):
+        yield num + 1
+
+    async def multiply_by_two(num, context=None):
+        yield num * 2
+
+    index = 0
+    expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 22]
+
+    async for task_run_info in run_tasks_base(
+        [
+            Task(number_generator),
+            Task(add_one, task_config=TaskConfig(needs=[number_generator])),
+            Task(multiply_by_two, task_config=TaskConfig(needs=[add_one])),
+        ],
+        data=10,
+    ):
+        if isinstance(task_run_info, TaskExecutionInfo):
+            assert task_run_info.result == expected_results[index], (
+                f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
+            )
+            index += 1
+
+
+def test_run_tasks():
+    asyncio.run(run_and_check_tasks())
+
+
+if __name__ == "__main__":
+    test_run_tasks()
@@ -1,50 +0,0 @@
-import asyncio
-
-import cognee
-from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
-from cognee.modules.pipelines.tasks.Task import Task
-from cognee.modules.users.methods import get_default_user
-from cognee.infrastructure.databases.relational import create_db_and_tables
-
-
-async def run_and_check_tasks():
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
-
-    def number_generator(num):
-        for i in range(num):
-            yield i + 1
-
-    async def add_one(nums):
-        for num in nums:
-            yield num + 1
-
-    async def multiply_by_two(num):
-        yield num * 2
-
-    async def add_one_single(num):
-        yield num + 1
-
-    await create_db_and_tables()
-    user = await get_default_user()
-
-    pipeline = run_tasks_base(
-        [
-            Task(number_generator),
-            Task(add_one, task_config={"batch_size": 5}),
-            Task(multiply_by_two, task_config={"batch_size": 1}),
-            Task(add_one_single),
-        ],
-        data=10,
-        user=user,
-    )
-
-    results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
-    index = 0
-    async for result in pipeline:
-        assert result == results[index], f"at {index = }: {result = } != {results[index] = }"
-        index += 1
-
-
-def test_run_tasks():
-    asyncio.run(run_and_check_tasks())
@@ -1,6 +1,7 @@
 import os
 import pathlib
 import cognee
+from cognee.api.v1.visualize.visualize import visualize_graph
 from cognee.modules.search.operations import get_history
 from cognee.modules.users.methods import get_default_user
 from cognee.shared.logging_utils import get_logger
@@ -76,6 +77,12 @@ async def main():

     assert len(history) == 6, "Search history is not correct."

     # await render_graph()
+    graph_file_path = os.path.join(
+        pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
+    )
+    await visualize_graph(graph_file_path)
+
     # Assert local data files are cleaned properly
     await cognee.prune.prune_data()
     assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
@@ -74,13 +74,13 @@
     "    get_repo_file_dependencies,\n",
     ")\n",
     "from cognee.tasks.storage import add_data_points\n",
-    "from cognee.modules.pipelines.tasks.Task import Task\n",
+    "from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
     "\n",
     "detailed_extraction = True\n",
     "\n",
     "tasks = [\n",
     "    Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction),\n",
-    "    Task(add_data_points, task_config={\"batch_size\": 100 if detailed_extraction else 500}),\n",
+    "    Task(add_data_points, task_config=TaskConfig(needs=[get_repo_file_dependencies], output_batch_size=100 if detailed_extraction else 500)),\n",
     "]"
    ]
   },
@@ -518,11 +518,11 @@
     "from cognee.modules.data.models import Dataset, Data\n",
     "from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
     "from cognee.modules.cognify.config import get_cognify_config\n",
-    "from cognee.modules.pipelines.tasks.Task import Task\n",
     "from cognee.modules.pipelines import run_tasks\n",
+    "from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
+    "from cognee.modules.pipelines.operations.needs import merge_needs\n",
     "from cognee.modules.users.models import User\n",
     "from cognee.tasks.documents import (\n",
     "    check_permissions_on_documents,\n",
     "    classify_documents,\n",
     "    extract_chunks_from_documents,\n",
     ")\n",
@@ -540,19 +540,25 @@
     "\n",
     "    tasks = [\n",
     "        Task(classify_documents),\n",
     "        Task(check_permissions_on_documents, user=user, permissions=[\"write\"]),\n",
-    "        Task(\n",
-    "            extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()\n",
-    "        ), # Extract text chunks based on the document type.\n",
-    "        Task(\n",
-    "            extract_graph_from_data, graph_model=KnowledgeGraph, task_config={\"batch_size\": 10}\n",
-    "        ), # Generate knowledge graphs from the document chunks.\n",
+    "        Task( # Extract text chunks based on the document type.\n",
+    "            extract_chunks_from_documents,\n",
+    "            max_chunk_size=get_max_chunk_tokens(),\n",
+    "            task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),\n",
+    "        ),\n",
+    "        Task( # Generate knowledge graphs from the document chunks.\n",
+    "            extract_graph_from_data,\n",
+    "            graph_model=KnowledgeGraph,\n",
+    "            task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
+    "        ),\n",
     "        Task(\n",
     "            summarize_text,\n",
     "            summarization_model=cognee_config.summarization_model,\n",
-    "            task_config={\"batch_size\": 10},\n",
+    "            task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
     "        ),\n",
-    "        Task(add_data_points, task_config={\"batch_size\": 10}),\n",
+    "        Task(\n",
+    "            add_data_points,\n",
+    "            task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),\n",
+    "        ),\n",
     "    ]\n",
     "\n",
     "    pipeline_run = run_tasks(tasks, dataset.id, data_documents, \"cognify_pipeline\")\n",
@@ -1041,7 +1047,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "py312",
+   "display_name": ".venv",
    "language": "python",
    "name": "python3"
   },
@@ -1055,7 +1061,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.12.8"
+   "version": "3.11.8"
   }
  },
 "nbformat": 4,
@@ -397,14 +397,15 @@
     "from cognee.modules.data.models import Dataset, Data\n",
     "from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
     "from cognee.modules.cognify.config import get_cognify_config\n",
-    "from cognee.modules.pipelines.tasks.Task import Task\n",
     "from cognee.modules.pipelines import run_tasks\n",
+    "from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
+    "from cognee.modules.pipelines.operations.needs import merge_needs\n",
     "from cognee.modules.users.models import User\n",
     "from cognee.tasks.documents import (\n",
     "    check_permissions_on_documents,\n",
     "    classify_documents,\n",
     "    extract_chunks_from_documents,\n",
     ")\n",
     "from cognee.infrastructure.llm import get_max_chunk_tokens\n",
     "from cognee.tasks.graph import extract_graph_from_data\n",
     "from cognee.tasks.storage import add_data_points\n",
     "from cognee.tasks.summarization import summarize_text\n",
@@ -418,17 +419,25 @@
     "\n",
     "    tasks = [\n",
     "        Task(classify_documents),\n",
     "        Task(check_permissions_on_documents, user=user, permissions=[\"write\"]),\n",
-    "        Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n",
-    "        Task(\n",
-    "            extract_graph_from_data, graph_model=KnowledgeGraph, task_config={\"batch_size\": 10}\n",
-    "        ), # Generate knowledge graphs from the document chunks.\n",
+    "        Task( # Extract text chunks based on the document type.\n",
+    "            extract_chunks_from_documents,\n",
+    "            max_chunk_size=get_max_chunk_tokens(),\n",
+    "            task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),\n",
+    "        ),\n",
+    "        Task( # Generate knowledge graphs from the document chunks.\n",
+    "            extract_graph_from_data,\n",
+    "            graph_model=KnowledgeGraph,\n",
+    "            task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
+    "        ),\n",
     "        Task(\n",
     "            summarize_text,\n",
     "            summarization_model=cognee_config.summarization_model,\n",
-    "            task_config={\"batch_size\": 10},\n",
+    "            task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
     "        ),\n",
-    "        Task(add_data_points, task_config={\"batch_size\": 10}),\n",
+    "        Task(\n",
+    "            add_data_points,\n",
+    "            task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),\n",
+    "        ),\n",
     "    ]\n",
     "\n",
     "    pipeline = run_tasks(tasks, data_documents)\n",
File diff suppressed because one or more lines are too long