feat: pipeline tasks needs mapping (#690)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
567b45efa6
commit
0ce6fad24a
27 changed files with 803 additions and 618 deletions
|
|
@ -1,7 +1,8 @@
|
||||||
from typing import Union, BinaryIO
|
from typing import Union, BinaryIO
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.pipelines import run_tasks, Task
|
from cognee.modules.pipelines import run_tasks
|
||||||
|
from cognee.modules.pipelines.tasks import TaskConfig, Task
|
||||||
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
|
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
|
||||||
from cognee.infrastructure.databases.relational import (
|
from cognee.infrastructure.databases.relational import (
|
||||||
create_db_and_tables as create_relational_db_and_tables,
|
create_db_and_tables as create_relational_db_and_tables,
|
||||||
|
|
@ -36,7 +37,15 @@ async def add(
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
tasks = [Task(resolve_data_directories), Task(ingest_data, dataset_name, user)]
|
tasks = [
|
||||||
|
Task(resolve_data_directories),
|
||||||
|
Task(
|
||||||
|
ingest_data,
|
||||||
|
dataset_name,
|
||||||
|
user,
|
||||||
|
task_config=TaskConfig(needs=[resolve_data_directories]),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
dataset_id = uuid5(NAMESPACE_OID, dataset_name)
|
dataset_id = uuid5(NAMESPACE_OID, dataset_name)
|
||||||
pipeline = run_tasks(
|
pipeline = run_tasks(
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,19 @@
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import asyncio
|
import asyncio
|
||||||
from cognee.shared.logging_utils import get_logger
|
|
||||||
from uuid import NAMESPACE_OID, uuid5
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
|
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.api.v1.search import SearchType, search
|
from cognee.api.v1.search import SearchType, search
|
||||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||||
from cognee.base_config import get_base_config
|
from cognee.base_config import get_base_config
|
||||||
from cognee.modules.cognify.config import get_cognify_config
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
from cognee.modules.pipelines import run_tasks
|
from cognee.modules.pipelines import run_tasks
|
||||||
from cognee.modules.pipelines.tasks.Task import Task
|
from cognee.modules.pipelines.tasks.Task import Task, TaskConfig
|
||||||
|
from cognee.modules.pipelines.operations.needs import merge_needs
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.shared.data_models import KnowledgeGraph, MonitoringTool
|
from cognee.shared.data_models import KnowledgeGraph, MonitoringTool
|
||||||
from cognee.shared.utils import render_graph
|
|
||||||
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
|
from cognee.tasks.documents import classify_documents, extract_chunks_from_documents
|
||||||
from cognee.tasks.graph import extract_graph_from_data
|
from cognee.tasks.graph import extract_graph_from_data
|
||||||
from cognee.tasks.ingestion import ingest_data
|
from cognee.tasks.ingestion import ingest_data
|
||||||
|
|
@ -45,25 +46,46 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
|
||||||
detailed_extraction = True
|
detailed_extraction = True
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction),
|
Task(
|
||||||
# Task(summarize_code, task_config={"batch_size": 500}), # This task takes a long time to complete
|
get_repo_file_dependencies,
|
||||||
Task(add_data_points, task_config={"batch_size": 500}),
|
detailed_extraction=detailed_extraction,
|
||||||
|
task_config=TaskConfig(output_batch_size=500),
|
||||||
|
),
|
||||||
|
# Task(summarize_code, task_config=TaskConfig(output_batch_size=500)), # This task takes a long time to complete
|
||||||
|
Task(add_data_points, task_config=TaskConfig(needs=[get_repo_file_dependencies])),
|
||||||
]
|
]
|
||||||
|
|
||||||
if include_docs:
|
if include_docs:
|
||||||
# This tasks take a long time to complete
|
# This tasks take a long time to complete
|
||||||
non_code_tasks = [
|
non_code_tasks = [
|
||||||
Task(get_non_py_files, task_config={"batch_size": 50}),
|
Task(get_non_py_files),
|
||||||
Task(ingest_data, dataset_name="repo_docs", user=user),
|
|
||||||
Task(classify_documents),
|
|
||||||
Task(extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()),
|
|
||||||
Task(
|
Task(
|
||||||
extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
|
ingest_data,
|
||||||
|
dataset_name="repo_docs",
|
||||||
|
user=user,
|
||||||
|
task_config=TaskConfig(needs=[get_non_py_files]),
|
||||||
|
),
|
||||||
|
Task(classify_documents, task_config=TaskConfig(needs=[ingest_data])),
|
||||||
|
Task(
|
||||||
|
extract_chunks_from_documents,
|
||||||
|
max_chunk_size=get_max_chunk_tokens(),
|
||||||
|
task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),
|
||||||
|
),
|
||||||
|
Task(
|
||||||
|
extract_graph_from_data,
|
||||||
|
graph_model=KnowledgeGraph,
|
||||||
|
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
|
||||||
),
|
),
|
||||||
Task(
|
Task(
|
||||||
summarize_text,
|
summarize_text,
|
||||||
summarization_model=cognee_config.summarization_model,
|
summarization_model=cognee_config.summarization_model,
|
||||||
task_config={"batch_size": 50},
|
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
|
||||||
|
),
|
||||||
|
Task(
|
||||||
|
add_data_points,
|
||||||
|
task_config=TaskConfig(
|
||||||
|
needs=[merge_needs(summarize_text, extract_graph_from_data)]
|
||||||
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -71,11 +93,11 @@ async def run_code_graph_pipeline(repo_path, include_docs=False):
|
||||||
|
|
||||||
if include_docs:
|
if include_docs:
|
||||||
non_code_pipeline_run = run_tasks(non_code_tasks, dataset_id, repo_path, "cognify_pipeline")
|
non_code_pipeline_run = run_tasks(non_code_tasks, dataset_id, repo_path, "cognify_pipeline")
|
||||||
async for run_status in non_code_pipeline_run:
|
async for run_info in non_code_pipeline_run:
|
||||||
yield run_status
|
yield run_info
|
||||||
|
|
||||||
async for run_status in run_tasks(tasks, dataset_id, repo_path, "cognify_code_pipeline"):
|
async for run_info in run_tasks(tasks, dataset_id, repo_path, "cognify_code_pipeline"):
|
||||||
yield run_status
|
yield run_info
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,10 @@ from cognee.modules.cognify.config import get_cognify_config
|
||||||
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
from cognee.modules.data.methods import get_datasets, get_datasets_by_name
|
||||||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||||
from cognee.modules.data.models import Data, Dataset
|
from cognee.modules.data.models import Data, Dataset
|
||||||
from cognee.modules.pipelines import run_tasks
|
from cognee.modules.pipelines import run_tasks, merge_needs
|
||||||
|
from cognee.modules.pipelines.tasks import Task, TaskConfig
|
||||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||||
from cognee.modules.pipelines.tasks.Task import Task
|
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
|
|
@ -92,7 +92,9 @@ async def run_cognify_pipeline(dataset: Dataset, user: User, tasks: list[Task]):
|
||||||
if not isinstance(task, Task):
|
if not isinstance(task, Task):
|
||||||
raise ValueError(f"Task {task} is not an instance of Task")
|
raise ValueError(f"Task {task} is not an instance of Task")
|
||||||
|
|
||||||
pipeline_run = run_tasks(tasks, dataset.id, data_documents, "cognify_pipeline")
|
pipeline_run = run_tasks(
|
||||||
|
tasks, dataset.id, data_documents, "cognify_pipeline", context={"user": user}
|
||||||
|
)
|
||||||
pipeline_run_status = None
|
pipeline_run_status = None
|
||||||
|
|
||||||
async for run_status in pipeline_run:
|
async for run_status in pipeline_run:
|
||||||
|
|
@ -121,24 +123,33 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
|
|
||||||
default_tasks = [
|
default_tasks = [
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(check_permissions_on_documents, user=user, permissions=["write"]),
|
|
||||||
Task(
|
Task(
|
||||||
|
check_permissions_on_documents,
|
||||||
|
user=user,
|
||||||
|
permissions=["write"],
|
||||||
|
task_config=TaskConfig(needs=[classify_documents]),
|
||||||
|
),
|
||||||
|
Task( # Extract text chunks based on the document type.
|
||||||
extract_chunks_from_documents,
|
extract_chunks_from_documents,
|
||||||
max_chunk_size=chunk_size or get_max_chunk_tokens(),
|
max_chunk_size=chunk_size or get_max_chunk_tokens(),
|
||||||
chunker=chunker,
|
chunker=chunker,
|
||||||
), # Extract text chunks based on the document type.
|
task_config=TaskConfig(needs=[check_permissions_on_documents], output_batch_size=10),
|
||||||
Task(
|
),
|
||||||
|
Task( # Generate knowledge graphs from the document chunks.
|
||||||
extract_graph_from_data,
|
extract_graph_from_data,
|
||||||
graph_model=graph_model,
|
graph_model=graph_model,
|
||||||
ontology_adapter=ontology_adapter,
|
ontology_adapter=ontology_adapter,
|
||||||
task_config={"batch_size": 10},
|
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
|
||||||
), # Generate knowledge graphs from the document chunks.
|
),
|
||||||
Task(
|
Task(
|
||||||
summarize_text,
|
summarize_text,
|
||||||
summarization_model=cognee_config.summarization_model,
|
summarization_model=cognee_config.summarization_model,
|
||||||
task_config={"batch_size": 10},
|
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
|
||||||
|
),
|
||||||
|
Task(
|
||||||
|
add_data_points,
|
||||||
|
task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),
|
||||||
),
|
),
|
||||||
Task(add_data_points, task_config={"batch_size": 10}),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return default_tasks
|
return default_tasks
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,11 @@ from typing import List
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee.modules.cognify.config import get_cognify_config
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
from cognee.modules.pipelines.tasks.Task import Task
|
from cognee.modules.pipelines.operations.needs import merge_needs
|
||||||
|
from cognee.modules.pipelines.tasks import Task, TaskConfig
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
from cognee.shared.utils import send_telemetry
|
|
||||||
from cognee.tasks.documents import (
|
from cognee.tasks.documents import (
|
||||||
check_permissions_on_documents,
|
check_permissions_on_documents,
|
||||||
classify_documents,
|
classify_documents,
|
||||||
|
|
@ -27,25 +27,32 @@ async def get_cascade_graph_tasks(
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
try:
|
|
||||||
cognee_config = get_cognify_config()
|
cognee_config = get_cognify_config()
|
||||||
default_tasks = [
|
default_tasks = [
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(check_permissions_on_documents, user=user, permissions=["write"]),
|
|
||||||
Task(
|
Task(
|
||||||
extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
|
check_permissions_on_documents,
|
||||||
), # Extract text chunks based on the document type.
|
user=user,
|
||||||
|
permissions=["write"],
|
||||||
|
task_config=TaskConfig(needs=[classify_documents]),
|
||||||
|
),
|
||||||
|
Task( # Extract text chunks based on the document type.
|
||||||
|
extract_chunks_from_documents,
|
||||||
|
max_chunk_tokens=get_max_chunk_tokens(),
|
||||||
|
task_config=TaskConfig(needs=[check_permissions_on_documents], output_batch_size=50),
|
||||||
|
),
|
||||||
Task(
|
Task(
|
||||||
extract_graph_from_data, task_config={"batch_size": 10}
|
extract_graph_from_data,
|
||||||
|
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
|
||||||
), # Generate knowledge graphs using cascade extraction
|
), # Generate knowledge graphs using cascade extraction
|
||||||
Task(
|
Task(
|
||||||
summarize_text,
|
summarize_text,
|
||||||
summarization_model=cognee_config.summarization_model,
|
summarization_model=cognee_config.summarization_model,
|
||||||
task_config={"batch_size": 10},
|
task_config=TaskConfig(needs=[extract_chunks_from_documents]),
|
||||||
|
),
|
||||||
|
Task(
|
||||||
|
add_data_points,
|
||||||
|
task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),
|
||||||
),
|
),
|
||||||
Task(add_data_points, task_config={"batch_size": 10}),
|
|
||||||
]
|
]
|
||||||
except Exception as error:
|
|
||||||
send_telemetry("cognee.cognify DEFAULT TASKS CREATION ERRORED", user.id)
|
|
||||||
raise error
|
|
||||||
return default_tasks
|
return default_tasks
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,3 @@
|
||||||
from .tasks.Task import Task
|
from .tasks.Task import Task
|
||||||
from .operations.run_tasks import run_tasks
|
from .operations.run_tasks import run_tasks
|
||||||
from .operations.run_parallel import run_tasks_parallel
|
from .operations.needs import merge_needs, MergeNeeds
|
||||||
|
|
|
||||||
18
cognee/modules/pipelines/exceptions.py
Normal file
18
cognee/modules/pipelines/exceptions.py
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
class WrongTaskOrderException(Exception):
|
||||||
|
message: str
|
||||||
|
|
||||||
|
def __init__(self, message: str):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskExecutionException(Exception):
|
||||||
|
type: str
|
||||||
|
message: str
|
||||||
|
traceback: str
|
||||||
|
|
||||||
|
def __init__(self, type: str, message: str, traceback: str):
|
||||||
|
self.message = message
|
||||||
|
self.type = type
|
||||||
|
self.traceback = traceback
|
||||||
|
super().__init__(message)
|
||||||
60
cognee/modules/pipelines/operations/needs.py
Normal file
60
cognee/modules/pipelines/operations/needs.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
from typing import Any
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..tasks import Task
|
||||||
|
|
||||||
|
|
||||||
|
class MergeNeeds(BaseModel):
|
||||||
|
needs: list[Any]
|
||||||
|
|
||||||
|
|
||||||
|
def merge_needs(*args):
|
||||||
|
return MergeNeeds(needs=args)
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_needs(tasks: list[Task]):
|
||||||
|
input_tasks = []
|
||||||
|
|
||||||
|
for task in tasks:
|
||||||
|
if isinstance(task, MergeNeeds):
|
||||||
|
input_tasks.extend(task.needs)
|
||||||
|
else:
|
||||||
|
input_tasks.append(task)
|
||||||
|
|
||||||
|
return input_tasks
|
||||||
|
|
||||||
|
|
||||||
|
def get_need_task_results(results, task: Task):
|
||||||
|
input_results = []
|
||||||
|
|
||||||
|
for task_dependency in task.task_config.needs:
|
||||||
|
if isinstance(task_dependency, MergeNeeds):
|
||||||
|
task_results = []
|
||||||
|
max_result_length = 0
|
||||||
|
|
||||||
|
for task_need in task_dependency.needs:
|
||||||
|
result = results[task_need]
|
||||||
|
task_results.append(result)
|
||||||
|
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
max_result_length = max(max_result_length, len(result))
|
||||||
|
|
||||||
|
final_results = [[] for _ in range(max_result_length)]
|
||||||
|
|
||||||
|
for result in task_results:
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
for i, value in enumerate(result):
|
||||||
|
final_results[i].extend(value)
|
||||||
|
else:
|
||||||
|
final_results[0].extend(result)
|
||||||
|
|
||||||
|
input_results.extend(final_results)
|
||||||
|
else:
|
||||||
|
result = results[task_dependency]
|
||||||
|
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
input_results.extend(result)
|
||||||
|
else:
|
||||||
|
input_results.append(result)
|
||||||
|
|
||||||
|
return input_results
|
||||||
|
|
@ -1,227 +1,26 @@
|
||||||
import inspect
|
|
||||||
import json
|
import json
|
||||||
from cognee.shared.logging_utils import get_logger
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from uuid import UUID, NAMESPACE_OID, uuid4, uuid5
|
||||||
|
|
||||||
from cognee.modules.pipelines.operations import (
|
from cognee.modules.pipelines.operations import (
|
||||||
log_pipeline_run_start,
|
log_pipeline_run_start,
|
||||||
log_pipeline_run_complete,
|
log_pipeline_run_complete,
|
||||||
log_pipeline_run_error,
|
log_pipeline_run_error,
|
||||||
)
|
)
|
||||||
from cognee.modules.settings import get_current_settings
|
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.settings import get_current_settings
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
from uuid import uuid5, NAMESPACE_OID
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
from ..tasks.Task import Task
|
from ..tasks.Task import Task, TaskExecutionCompleted, TaskExecutionErrored, TaskExecutionStarted
|
||||||
|
from .run_tasks_base import run_tasks_base
|
||||||
|
|
||||||
logger = get_logger("run_tasks(tasks: [Task], data)")
|
logger = get_logger("run_tasks(tasks: [Task], data)")
|
||||||
|
|
||||||
|
|
||||||
async def run_tasks_base(tasks: list[Task], data=None, user: User = None):
|
async def run_tasks_with_telemetry(
|
||||||
if len(tasks) == 0:
|
tasks: list[Task], data, pipeline_name: str, context: dict = None
|
||||||
yield data
|
|
||||||
return
|
|
||||||
|
|
||||||
args = [data] if data is not None else []
|
|
||||||
|
|
||||||
running_task = tasks[0]
|
|
||||||
leftover_tasks = tasks[1:]
|
|
||||||
next_task = leftover_tasks[0] if len(leftover_tasks) > 0 else None
|
|
||||||
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
|
|
||||||
|
|
||||||
if inspect.isasyncgenfunction(running_task.executable):
|
|
||||||
logger.info("Async generator task started: `%s`", running_task.executable.__name__)
|
|
||||||
send_telemetry(
|
|
||||||
"Async Generator Task Started",
|
|
||||||
user.id,
|
|
||||||
{
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
results = []
|
|
||||||
|
|
||||||
async_iterator = running_task.run(*args)
|
|
||||||
|
|
||||||
async for partial_result in async_iterator:
|
|
||||||
results.append(partial_result)
|
|
||||||
|
|
||||||
if len(results) == next_task_batch_size:
|
|
||||||
async for result in run_tasks_base(
|
|
||||||
leftover_tasks,
|
|
||||||
results[0] if next_task_batch_size == 1 else results,
|
|
||||||
user=user,
|
|
||||||
):
|
):
|
||||||
yield result
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
if len(results) > 0:
|
|
||||||
async for result in run_tasks_base(leftover_tasks, results, user):
|
|
||||||
yield result
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
logger.info("Async generator task completed: `%s`", running_task.executable.__name__)
|
|
||||||
send_telemetry(
|
|
||||||
"Async Generator Task Completed",
|
|
||||||
user.id,
|
|
||||||
{
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as error:
|
|
||||||
logger.error(
|
|
||||||
"Async generator task errored: `%s`\n%s\n",
|
|
||||||
running_task.executable.__name__,
|
|
||||||
str(error),
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
send_telemetry(
|
|
||||||
"Async Generator Task Errored",
|
|
||||||
user.id,
|
|
||||||
{
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
raise error
|
|
||||||
|
|
||||||
elif inspect.isgeneratorfunction(running_task.executable):
|
|
||||||
logger.info("Generator task started: `%s`", running_task.executable.__name__)
|
|
||||||
send_telemetry(
|
|
||||||
"Generator Task Started",
|
|
||||||
user.id,
|
|
||||||
{
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for partial_result in running_task.run(*args):
|
|
||||||
results.append(partial_result)
|
|
||||||
|
|
||||||
if len(results) == next_task_batch_size:
|
|
||||||
async for result in run_tasks_base(
|
|
||||||
leftover_tasks, results[0] if next_task_batch_size == 1 else results, user
|
|
||||||
):
|
|
||||||
yield result
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
if len(results) > 0:
|
|
||||||
async for result in run_tasks_base(leftover_tasks, results, user):
|
|
||||||
yield result
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
logger.info("Generator task completed: `%s`", running_task.executable.__name__)
|
|
||||||
send_telemetry(
|
|
||||||
"Generator Task Completed",
|
|
||||||
user_id=user.id,
|
|
||||||
additional_properties={
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as error:
|
|
||||||
logger.error(
|
|
||||||
"Generator task errored: `%s`\n%s\n",
|
|
||||||
running_task.executable.__name__,
|
|
||||||
str(error),
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
send_telemetry(
|
|
||||||
"Generator Task Errored",
|
|
||||||
user_id=user.id,
|
|
||||||
additional_properties={
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
raise error
|
|
||||||
|
|
||||||
elif inspect.iscoroutinefunction(running_task.executable):
|
|
||||||
logger.info("Coroutine task started: `%s`", running_task.executable.__name__)
|
|
||||||
send_telemetry(
|
|
||||||
"Coroutine Task Started",
|
|
||||||
user_id=user.id,
|
|
||||||
additional_properties={
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
task_result = await running_task.run(*args)
|
|
||||||
|
|
||||||
async for result in run_tasks_base(leftover_tasks, task_result, user):
|
|
||||||
yield result
|
|
||||||
|
|
||||||
logger.info("Coroutine task completed: `%s`", running_task.executable.__name__)
|
|
||||||
send_telemetry(
|
|
||||||
"Coroutine Task Completed",
|
|
||||||
user.id,
|
|
||||||
{
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as error:
|
|
||||||
logger.error(
|
|
||||||
"Coroutine task errored: `%s`\n%s\n",
|
|
||||||
running_task.executable.__name__,
|
|
||||||
str(error),
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
send_telemetry(
|
|
||||||
"Coroutine Task Errored",
|
|
||||||
user.id,
|
|
||||||
{
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
raise error
|
|
||||||
|
|
||||||
elif inspect.isfunction(running_task.executable):
|
|
||||||
logger.info("Function task started: `%s`", running_task.executable.__name__)
|
|
||||||
send_telemetry(
|
|
||||||
"Function Task Started",
|
|
||||||
user.id,
|
|
||||||
{
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
task_result = running_task.run(*args)
|
|
||||||
|
|
||||||
async for result in run_tasks_base(leftover_tasks, task_result, user):
|
|
||||||
yield result
|
|
||||||
|
|
||||||
logger.info("Function task completed: `%s`", running_task.executable.__name__)
|
|
||||||
send_telemetry(
|
|
||||||
"Function Task Completed",
|
|
||||||
user.id,
|
|
||||||
{
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as error:
|
|
||||||
logger.error(
|
|
||||||
"Function task errored: `%s`\n%s\n",
|
|
||||||
running_task.executable.__name__,
|
|
||||||
str(error),
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
send_telemetry(
|
|
||||||
"Function Task Errored",
|
|
||||||
user.id,
|
|
||||||
{
|
|
||||||
"task_name": running_task.executable.__name__,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
raise error
|
|
||||||
|
|
||||||
|
|
||||||
async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
|
|
||||||
config = get_current_settings()
|
config = get_current_settings()
|
||||||
|
|
||||||
logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent=1))
|
logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent=1))
|
||||||
|
|
@ -239,8 +38,45 @@ async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
|
||||||
| config,
|
| config,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for result in run_tasks_base(tasks, data, user):
|
async for run_task_info in run_tasks_base(tasks, data, context):
|
||||||
yield result
|
if isinstance(run_task_info, TaskExecutionStarted):
|
||||||
|
send_telemetry(
|
||||||
|
"Task Run Started",
|
||||||
|
user.id,
|
||||||
|
additional_properties={
|
||||||
|
"task_name": run_task_info.task.__name__,
|
||||||
|
}
|
||||||
|
| config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(run_task_info, TaskExecutionCompleted):
|
||||||
|
send_telemetry(
|
||||||
|
"Task Run Completed",
|
||||||
|
user.id,
|
||||||
|
additional_properties={
|
||||||
|
"task_name": run_task_info.task.__name__,
|
||||||
|
}
|
||||||
|
| config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(run_task_info, TaskExecutionErrored):
|
||||||
|
send_telemetry(
|
||||||
|
"Task Run Errored",
|
||||||
|
user.id,
|
||||||
|
additional_properties={
|
||||||
|
"task_name": run_task_info.task.__name__,
|
||||||
|
"error": str(run_task_info.error),
|
||||||
|
}
|
||||||
|
| config,
|
||||||
|
)
|
||||||
|
logger.error(
|
||||||
|
"Task run errored: `%s`\n%s\n",
|
||||||
|
run_task_info.task.__name__,
|
||||||
|
str(run_task_info.error),
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield run_task_info
|
||||||
|
|
||||||
logger.info("Pipeline run completed: `%s`", pipeline_name)
|
logger.info("Pipeline run completed: `%s`", pipeline_name)
|
||||||
send_telemetry(
|
send_telemetry(
|
||||||
|
|
@ -271,19 +107,22 @@ async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
|
||||||
|
|
||||||
async def run_tasks(
|
async def run_tasks(
|
||||||
tasks: list[Task],
|
tasks: list[Task],
|
||||||
dataset_id: UUID = uuid4(),
|
dataset_id: UUID = None,
|
||||||
data: Any = None,
|
data: Any = None,
|
||||||
pipeline_name: str = "unknown_pipeline",
|
pipeline_name: str = "unknown_pipeline",
|
||||||
|
context: dict = None,
|
||||||
):
|
):
|
||||||
|
dataset_id = dataset_id or uuid4()
|
||||||
pipeline_id = uuid5(NAMESPACE_OID, pipeline_name)
|
pipeline_id = uuid5(NAMESPACE_OID, pipeline_name)
|
||||||
|
|
||||||
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
|
pipeline_run = await log_pipeline_run_start(pipeline_id, pipeline_name, dataset_id, data)
|
||||||
|
|
||||||
yield pipeline_run
|
yield pipeline_run
|
||||||
|
|
||||||
pipeline_run_id = pipeline_run.pipeline_run_id
|
pipeline_run_id = pipeline_run.pipeline_run_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for _ in run_tasks_with_telemetry(tasks, data, pipeline_id):
|
async for _ in run_tasks_with_telemetry(tasks, data, pipeline_id, context):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
yield await log_pipeline_run_complete(
|
yield await log_pipeline_run_complete(
|
||||||
|
|
|
||||||
63
cognee/modules/pipelines/operations/run_tasks_base.py
Normal file
63
cognee/modules/pipelines/operations/run_tasks_base.py
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
from .needs import get_need_task_results, get_task_needs
|
||||||
|
from ..tasks.Task import Task, TaskExecutionCompleted, TaskExecutionInfo
|
||||||
|
from ..exceptions import WrongTaskOrderException
|
||||||
|
|
||||||
|
logger = get_logger("run_tasks_base(tasks: list[Task], data)")
|
||||||
|
|
||||||
|
|
||||||
|
async def run_tasks_base(tasks: list[Task], data=None, context=None):
|
||||||
|
if len(tasks) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
pipeline_input = [data] if data is not None else []
|
||||||
|
|
||||||
|
"""Run tasks in dependency order and return results."""
|
||||||
|
task_graph = {} # Map task to its dependencies
|
||||||
|
dependents = {} # Reverse dependencies (who depends on whom)
|
||||||
|
results = {}
|
||||||
|
number_of_executed_tasks = 0
|
||||||
|
|
||||||
|
tasks_map = {task.executable: task for task in tasks} # Map task executable to task object
|
||||||
|
|
||||||
|
# Build task dependency graph
|
||||||
|
for task in tasks:
|
||||||
|
task_graph[task.executable] = get_task_needs(task.task_config.needs)
|
||||||
|
for dependent_task in task_graph[task.executable]:
|
||||||
|
dependents.setdefault(dependent_task, []).append(task.executable)
|
||||||
|
|
||||||
|
# Find tasks without dependencies
|
||||||
|
ready_queue = deque([task for task in tasks if not task_graph[task.executable]])
|
||||||
|
|
||||||
|
# Execute tasks in order
|
||||||
|
while ready_queue:
|
||||||
|
task = ready_queue.popleft()
|
||||||
|
task_inputs = (
|
||||||
|
get_need_task_results(results, task) if task.task_config.needs else pipeline_input
|
||||||
|
)
|
||||||
|
|
||||||
|
async for task_execution_info in task.run(*task_inputs): # Run task and store result
|
||||||
|
if isinstance(task_execution_info, TaskExecutionInfo): # Update result as it comes
|
||||||
|
results[task.executable] = task_execution_info.result
|
||||||
|
|
||||||
|
if isinstance(task_execution_info, TaskExecutionCompleted):
|
||||||
|
if task.executable not in results: # If result not already set, set it
|
||||||
|
results[task.executable] = task_execution_info.result
|
||||||
|
|
||||||
|
number_of_executed_tasks += 1
|
||||||
|
|
||||||
|
yield task_execution_info
|
||||||
|
|
||||||
|
# Process tasks depending on this task
|
||||||
|
for dependent_task in dependents.get(task.executable, []):
|
||||||
|
task_graph[dependent_task].remove(task.executable) # Mark dependency as resolved
|
||||||
|
if not task_graph[dependent_task]: # If all dependencies resolved, add to queue
|
||||||
|
ready_queue.append(tasks_map[dependent_task])
|
||||||
|
|
||||||
|
if number_of_executed_tasks != len(tasks):
|
||||||
|
raise WrongTaskOrderException(
|
||||||
|
f"{number_of_executed_tasks}/{len(tasks)} tasks executed. You likely have some disconnected tasks or circular dependency."
|
||||||
|
)
|
||||||
|
|
@ -1,30 +1,136 @@
|
||||||
from typing import Union, Callable, Any, Coroutine, Generator, AsyncGenerator
|
import inspect
|
||||||
|
from typing import Callable, Any, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..tasks.types import TaskExecutable
|
||||||
|
from ..operations.needs import MergeNeeds
|
||||||
|
from ..exceptions import TaskExecutionException
|
||||||
|
|
||||||
|
|
||||||
|
class TaskExecutionStarted(BaseModel):
|
||||||
|
task: Callable
|
||||||
|
|
||||||
|
|
||||||
|
class TaskExecutionCompleted(BaseModel):
|
||||||
|
task: Callable
|
||||||
|
result: Any = None
|
||||||
|
|
||||||
|
|
||||||
|
class TaskExecutionErrored(BaseModel):
|
||||||
|
task: TaskExecutable
|
||||||
|
error: TaskExecutionException
|
||||||
|
|
||||||
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
|
|
||||||
|
class TaskExecutionInfo(BaseModel):
|
||||||
|
result: Any = None
|
||||||
|
task: Callable
|
||||||
|
|
||||||
|
|
||||||
|
class TaskConfig(BaseModel):
|
||||||
|
output_batch_size: int = 1
|
||||||
|
needs: list[Union[Callable, MergeNeeds]] = []
|
||||||
|
|
||||||
|
|
||||||
class Task:
|
class Task:
|
||||||
executable: Union[
|
task_config: TaskConfig
|
||||||
Callable[..., Any],
|
|
||||||
Callable[..., Coroutine[Any, Any, Any]],
|
|
||||||
Generator[Any, Any, Any],
|
|
||||||
AsyncGenerator[Any, Any],
|
|
||||||
]
|
|
||||||
task_config: dict[str, Any] = {
|
|
||||||
"batch_size": 1,
|
|
||||||
}
|
|
||||||
default_params: dict[str, Any] = {}
|
default_params: dict[str, Any] = {}
|
||||||
|
|
||||||
def __init__(self, executable, *args, task_config=None, **kwargs):
|
def __init__(self, executable, *args, task_config: TaskConfig = None, **kwargs):
|
||||||
self.executable = executable
|
self.executable = executable
|
||||||
self.default_params = {"args": args, "kwargs": kwargs}
|
self.default_params = {"args": args, "kwargs": kwargs}
|
||||||
|
self.result = None
|
||||||
|
|
||||||
if task_config is not None:
|
self.task_config = task_config or TaskConfig()
|
||||||
self.task_config = task_config
|
|
||||||
|
|
||||||
if "batch_size" not in task_config:
|
async def run(self, *args, **kwargs):
|
||||||
self.task_config["batch_size"] = 1
|
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
|
|
||||||
combined_args = args + self.default_params["args"]
|
combined_args = args + self.default_params["args"]
|
||||||
combined_kwargs = {**self.default_params["kwargs"], **kwargs}
|
combined_kwargs = {
|
||||||
|
**self.default_params["kwargs"],
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
return self.executable(*combined_args, **combined_kwargs)
|
yield TaskExecutionStarted(
|
||||||
|
task=self.executable,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if inspect.iscoroutinefunction(self.executable): # Async function
|
||||||
|
end_result = await self.executable(*combined_args, **combined_kwargs)
|
||||||
|
|
||||||
|
elif inspect.isgeneratorfunction(self.executable): # Generator
|
||||||
|
task_result = []
|
||||||
|
end_result = []
|
||||||
|
|
||||||
|
for value in self.executable(*combined_args, **combined_kwargs):
|
||||||
|
task_result.append(value) # Store the last yielded value
|
||||||
|
end_result.append(value)
|
||||||
|
|
||||||
|
if self.task_config.output_batch_size == 1:
|
||||||
|
yield TaskExecutionInfo(
|
||||||
|
result=value,
|
||||||
|
task=self.executable,
|
||||||
|
)
|
||||||
|
elif self.task_config.output_batch_size == len(task_result):
|
||||||
|
yield TaskExecutionInfo(
|
||||||
|
result=task_result,
|
||||||
|
task=self.executable,
|
||||||
|
)
|
||||||
|
task_result = [] # Reset for the next batch
|
||||||
|
|
||||||
|
# Yield any remaining items in the final batch if it's not empty
|
||||||
|
if task_result and self.task_config.output_batch_size > 1:
|
||||||
|
yield TaskExecutionInfo(
|
||||||
|
result=task_result,
|
||||||
|
task=self.executable,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif inspect.isasyncgenfunction(self.executable): # Async Generator
|
||||||
|
task_result = []
|
||||||
|
end_result = []
|
||||||
|
|
||||||
|
async for value in self.executable(*combined_args, **combined_kwargs):
|
||||||
|
task_result.append(value) # Store the last yielded value
|
||||||
|
end_result.append(value)
|
||||||
|
|
||||||
|
if self.task_config.output_batch_size == 1:
|
||||||
|
yield TaskExecutionInfo(
|
||||||
|
result=value,
|
||||||
|
task=self.executable,
|
||||||
|
)
|
||||||
|
elif self.task_config.output_batch_size == len(task_result):
|
||||||
|
yield TaskExecutionInfo(
|
||||||
|
result=task_result,
|
||||||
|
task=self.executable,
|
||||||
|
)
|
||||||
|
task_result = [] # Reset for the next batch
|
||||||
|
|
||||||
|
# Yield any remaining items in the final batch if it's not empty
|
||||||
|
if task_result and self.task_config.output_batch_size > 1:
|
||||||
|
yield TaskExecutionInfo(
|
||||||
|
result=task_result,
|
||||||
|
task=self.executable,
|
||||||
|
)
|
||||||
|
else: # Regular function
|
||||||
|
end_result = self.executable(*combined_args, **combined_kwargs)
|
||||||
|
|
||||||
|
yield TaskExecutionCompleted(
|
||||||
|
task=self.executable,
|
||||||
|
result=end_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as error:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
error_details = TaskExecutionException(
|
||||||
|
type=type(error).__name__,
|
||||||
|
message=str(error),
|
||||||
|
traceback=traceback.format_exc(),
|
||||||
|
)
|
||||||
|
|
||||||
|
yield TaskExecutionErrored(
|
||||||
|
task=self.executable,
|
||||||
|
error=error_details,
|
||||||
|
)
|
||||||
|
|
|
||||||
8
cognee/modules/pipelines/tasks/__init__.py
Normal file
8
cognee/modules/pipelines/tasks/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
from .Task import (
|
||||||
|
Task,
|
||||||
|
TaskConfig,
|
||||||
|
TaskExecutionInfo,
|
||||||
|
TaskExecutionCompleted,
|
||||||
|
TaskExecutionStarted,
|
||||||
|
TaskExecutionErrored,
|
||||||
|
)
|
||||||
9
cognee/modules/pipelines/tasks/types.py
Normal file
9
cognee/modules/pipelines/tasks/types.py
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
from typing import Any, AsyncGenerator, Callable, Coroutine, Generator, Union
|
||||||
|
|
||||||
|
|
||||||
|
TaskExecutable = Union[
|
||||||
|
Callable[..., Any],
|
||||||
|
Callable[..., Coroutine[Any, Any, Any]],
|
||||||
|
AsyncGenerator[Any, Any],
|
||||||
|
Generator[Any, Any, Any],
|
||||||
|
]
|
||||||
|
|
@ -12,7 +12,6 @@ from cognee.modules.graph.utils import (
|
||||||
retrieve_existing_edges,
|
retrieve_existing_edges,
|
||||||
)
|
)
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
from cognee.tasks.storage import add_data_points
|
|
||||||
|
|
||||||
|
|
||||||
async def integrate_chunk_graphs(
|
async def integrate_chunk_graphs(
|
||||||
|
|
@ -28,7 +27,6 @@ async def integrate_chunk_graphs(
|
||||||
for chunk_index, chunk_graph in enumerate(chunk_graphs):
|
for chunk_index, chunk_graph in enumerate(chunk_graphs):
|
||||||
data_chunks[chunk_index].contains = chunk_graph
|
data_chunks[chunk_index].contains = chunk_graph
|
||||||
|
|
||||||
await add_data_points(chunk_graphs)
|
|
||||||
return data_chunks
|
return data_chunks
|
||||||
|
|
||||||
existing_edges_map = await retrieve_existing_edges(
|
existing_edges_map = await retrieve_existing_edges(
|
||||||
|
|
@ -41,13 +39,7 @@ async def integrate_chunk_graphs(
|
||||||
data_chunks, chunk_graphs, ontology_adapter, existing_edges_map
|
data_chunks, chunk_graphs, ontology_adapter, existing_edges_map
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(graph_nodes) > 0:
|
return graph_nodes, graph_edges
|
||||||
await add_data_points(graph_nodes)
|
|
||||||
|
|
||||||
if len(graph_edges) > 0:
|
|
||||||
await graph_engine.add_edges(graph_edges)
|
|
||||||
|
|
||||||
return data_chunks
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_graph_from_data(
|
async def extract_graph_from_data(
|
||||||
|
|
|
||||||
|
|
@ -41,4 +41,5 @@ async def resolve_data_directories(
|
||||||
resolved_data.append(item)
|
resolved_data.append(item)
|
||||||
else: # If it's not a string add it directly
|
else: # If it's not a string add it directly
|
||||||
resolved_data.append(item)
|
resolved_data.append(item)
|
||||||
|
|
||||||
return resolved_data
|
return resolved_data
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,8 @@ from .index_data_points import index_data_points
|
||||||
from .index_graph_edges import index_graph_edges
|
from .index_graph_edges import index_graph_edges
|
||||||
|
|
||||||
|
|
||||||
async def add_data_points(data_points: list[DataPoint]):
|
async def add_data_points(data_points: list[DataPoint], data_point_connections: list = None):
|
||||||
|
data_point_connections = data_point_connections or []
|
||||||
nodes = []
|
nodes = []
|
||||||
edges = []
|
edges = []
|
||||||
|
|
||||||
|
|
@ -38,6 +39,8 @@ async def add_data_points(data_points: list[DataPoint]):
|
||||||
|
|
||||||
await graph_engine.add_nodes(nodes)
|
await graph_engine.add_nodes(nodes)
|
||||||
await graph_engine.add_edges(edges)
|
await graph_engine.add_edges(edges)
|
||||||
|
if data_point_connections:
|
||||||
|
await graph_engine.add_edges(data_point_connections)
|
||||||
|
|
||||||
# This step has to happen after adding nodes and edges because we query the graph.
|
# This step has to happen after adding nodes and edges because we query the graph.
|
||||||
await index_graph_edges()
|
await index_graph_edges()
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,12 @@ import asyncio
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from uuid import uuid5
|
from uuid import uuid5
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.data.extraction.extract_summary import extract_summary
|
from cognee.modules.data.extraction.extract_summary import extract_summary
|
||||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
|
||||||
from .models import TextSummary
|
from .models import TextSummary
|
||||||
|
|
||||||
|
|
||||||
async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):
|
async def summarize_text(data_chunks: list[DataPoint], summarization_model: Type[BaseModel]):
|
||||||
if len(data_chunks) == 0:
|
if len(data_chunks) == 0:
|
||||||
return data_chunks
|
return data_chunks
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,66 +0,0 @@
|
||||||
import asyncio
|
|
||||||
from queue import Queue
|
|
||||||
|
|
||||||
import cognee
|
|
||||||
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
|
|
||||||
from cognee.modules.pipelines.tasks.Task import Task
|
|
||||||
from cognee.modules.users.methods import get_default_user
|
|
||||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
|
||||||
|
|
||||||
|
|
||||||
async def pipeline(data_queue):
|
|
||||||
await cognee.prune.prune_data()
|
|
||||||
await cognee.prune.prune_system(metadata=True)
|
|
||||||
|
|
||||||
async def queue_consumer():
|
|
||||||
while not data_queue.is_closed:
|
|
||||||
if not data_queue.empty():
|
|
||||||
yield data_queue.get()
|
|
||||||
else:
|
|
||||||
await asyncio.sleep(0.3)
|
|
||||||
|
|
||||||
async def add_one(num):
|
|
||||||
yield num + 1
|
|
||||||
|
|
||||||
async def multiply_by_two(num):
|
|
||||||
yield num * 2
|
|
||||||
|
|
||||||
await create_db_and_tables()
|
|
||||||
user = await get_default_user()
|
|
||||||
|
|
||||||
tasks_run = run_tasks_base(
|
|
||||||
[
|
|
||||||
Task(queue_consumer),
|
|
||||||
Task(add_one),
|
|
||||||
Task(multiply_by_two),
|
|
||||||
],
|
|
||||||
data=None,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
|
|
||||||
results = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
|
|
||||||
index = 0
|
|
||||||
async for result in tasks_run:
|
|
||||||
assert result == results[index], f"at {index = }: {result = } != {results[index] = }"
|
|
||||||
index += 1
|
|
||||||
|
|
||||||
|
|
||||||
async def run_queue():
|
|
||||||
data_queue = Queue()
|
|
||||||
data_queue.is_closed = False
|
|
||||||
|
|
||||||
async def queue_producer():
|
|
||||||
for i in range(0, 10):
|
|
||||||
data_queue.put(i)
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
data_queue.is_closed = True
|
|
||||||
|
|
||||||
await asyncio.gather(pipeline(data_queue), queue_producer())
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_tasks_from_queue():
|
|
||||||
asyncio.run(run_queue())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(run_queue())
|
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
|
||||||
|
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
|
||||||
|
|
||||||
|
|
||||||
|
async def run_and_check_tasks():
|
||||||
|
def number_generator(num, context=None):
|
||||||
|
for i in range(num):
|
||||||
|
yield i + 1
|
||||||
|
|
||||||
|
async def add_one(num, context=None):
|
||||||
|
yield num + 1
|
||||||
|
|
||||||
|
async def multiply_by_two(num, context=None):
|
||||||
|
yield num * 2
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20]
|
||||||
|
|
||||||
|
async for task_run_info in run_tasks_base(
|
||||||
|
[
|
||||||
|
Task(number_generator),
|
||||||
|
Task(add_one, task_config=TaskConfig(needs=[number_generator])),
|
||||||
|
Task(multiply_by_two, task_config=TaskConfig(needs=[number_generator])),
|
||||||
|
],
|
||||||
|
data=10,
|
||||||
|
):
|
||||||
|
if isinstance(task_run_info, TaskExecutionInfo):
|
||||||
|
assert task_run_info.result == expected_results[index], (
|
||||||
|
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
|
||||||
|
)
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_tasks():
|
||||||
|
asyncio.run(run_and_check_tasks())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_run_tasks()
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
|
||||||
|
from cognee.modules.pipelines.exceptions import WrongTaskOrderException
|
||||||
|
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
|
||||||
|
|
||||||
|
|
||||||
|
async def run_and_check_tasks():
|
||||||
|
def number_generator(num, context=None):
|
||||||
|
for i in range(num):
|
||||||
|
yield i + 1
|
||||||
|
|
||||||
|
async def add_one(num, context=None):
|
||||||
|
yield num + 1
|
||||||
|
|
||||||
|
async def multiply_by_two(num, context=None):
|
||||||
|
yield num * 2
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 22]
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
WrongTaskOrderException,
|
||||||
|
match="1/3 tasks executed. You likely have some disconnected tasks or circular dependency.",
|
||||||
|
):
|
||||||
|
async for task_run_info in run_tasks_base(
|
||||||
|
[
|
||||||
|
Task(number_generator),
|
||||||
|
Task(add_one, task_config=TaskConfig(needs=[number_generator, multiply_by_two])),
|
||||||
|
Task(multiply_by_two, task_config=TaskConfig(needs=[add_one])),
|
||||||
|
],
|
||||||
|
data=10,
|
||||||
|
):
|
||||||
|
if isinstance(task_run_info, TaskExecutionInfo):
|
||||||
|
assert task_run_info.result == expected_results[index], (
|
||||||
|
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
|
||||||
|
)
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_tasks():
|
||||||
|
asyncio.run(run_and_check_tasks())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_run_tasks()
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
|
||||||
|
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
|
||||||
|
|
||||||
|
|
||||||
|
async def run_and_check_tasks():
|
||||||
|
def number_generator(num, context=None):
|
||||||
|
for i in range(num):
|
||||||
|
yield i + 1
|
||||||
|
|
||||||
|
async def add_one(num, context=None):
|
||||||
|
yield num + 1
|
||||||
|
|
||||||
|
async def add_two(num, context=None):
|
||||||
|
yield num + 2
|
||||||
|
|
||||||
|
async def multiply_by_two(num1, num2, context=None):
|
||||||
|
yield num1 * 2
|
||||||
|
yield num2 * 2
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 22, 24]
|
||||||
|
|
||||||
|
async for task_run_info in run_tasks_base(
|
||||||
|
[
|
||||||
|
Task(number_generator),
|
||||||
|
Task(add_one, task_config=TaskConfig(needs=[number_generator])),
|
||||||
|
Task(add_two, task_config=TaskConfig(needs=[number_generator])),
|
||||||
|
Task(multiply_by_two, task_config=TaskConfig(needs=[add_one, add_two])),
|
||||||
|
],
|
||||||
|
data=10,
|
||||||
|
):
|
||||||
|
if isinstance(task_run_info, TaskExecutionInfo):
|
||||||
|
assert task_run_info.result == expected_results[index], (
|
||||||
|
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
|
||||||
|
)
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_tasks():
|
||||||
|
asyncio.run(run_and_check_tasks())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_run_tasks()
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from cognee.modules.pipelines.tasks import Task, TaskConfig, TaskExecutionInfo
|
||||||
|
from cognee.modules.pipelines.operations.run_tasks_base import run_tasks_base
|
||||||
|
|
||||||
|
|
||||||
|
async def run_and_check_tasks():
|
||||||
|
def number_generator(num, context=None):
|
||||||
|
for i in range(num):
|
||||||
|
yield i + 1
|
||||||
|
|
||||||
|
async def add_one(num, context=None):
|
||||||
|
yield num + 1
|
||||||
|
|
||||||
|
async def multiply_by_two(num, context=None):
|
||||||
|
yield num * 2
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
expected_results = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 22]
|
||||||
|
|
||||||
|
async for task_run_info in run_tasks_base(
|
||||||
|
[
|
||||||
|
Task(number_generator),
|
||||||
|
Task(add_one, task_config=TaskConfig(needs=[number_generator])),
|
||||||
|
Task(multiply_by_two, task_config=TaskConfig(needs=[add_one])),
|
||||||
|
],
|
||||||
|
data=10,
|
||||||
|
):
|
||||||
|
if isinstance(task_run_info, TaskExecutionInfo):
|
||||||
|
assert task_run_info.result == expected_results[index], (
|
||||||
|
f"at {index = }: {task_run_info.result = } != {expected_results[index] = }"
|
||||||
|
)
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_tasks():
|
||||||
|
asyncio.run(run_and_check_tasks())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_run_tasks()
|
||||||
|
|
@ -1,50 +0,0 @@
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import cognee
|
|
||||||
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
|
|
||||||
from cognee.modules.pipelines.tasks.Task import Task
|
|
||||||
from cognee.modules.users.methods import get_default_user
|
|
||||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
|
||||||
|
|
||||||
|
|
||||||
async def run_and_check_tasks():
|
|
||||||
await cognee.prune.prune_data()
|
|
||||||
await cognee.prune.prune_system(metadata=True)
|
|
||||||
|
|
||||||
def number_generator(num):
|
|
||||||
for i in range(num):
|
|
||||||
yield i + 1
|
|
||||||
|
|
||||||
async def add_one(nums):
|
|
||||||
for num in nums:
|
|
||||||
yield num + 1
|
|
||||||
|
|
||||||
async def multiply_by_two(num):
|
|
||||||
yield num * 2
|
|
||||||
|
|
||||||
async def add_one_single(num):
|
|
||||||
yield num + 1
|
|
||||||
|
|
||||||
await create_db_and_tables()
|
|
||||||
user = await get_default_user()
|
|
||||||
|
|
||||||
pipeline = run_tasks_base(
|
|
||||||
[
|
|
||||||
Task(number_generator),
|
|
||||||
Task(add_one, task_config={"batch_size": 5}),
|
|
||||||
Task(multiply_by_two, task_config={"batch_size": 1}),
|
|
||||||
Task(add_one_single),
|
|
||||||
],
|
|
||||||
data=10,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
|
|
||||||
results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
|
|
||||||
index = 0
|
|
||||||
async for result in pipeline:
|
|
||||||
assert result == results[index], f"at {index = }: {result = } != {results[index] = }"
|
|
||||||
index += 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_tasks():
|
|
||||||
asyncio.run(run_and_check_tasks())
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import cognee
|
import cognee
|
||||||
|
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||||
from cognee.modules.search.operations import get_history
|
from cognee.modules.search.operations import get_history
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
@ -76,6 +77,12 @@ async def main():
|
||||||
|
|
||||||
assert len(history) == 6, "Search history is not correct."
|
assert len(history) == 6, "Search history is not correct."
|
||||||
|
|
||||||
|
# await render_graph()
|
||||||
|
graph_file_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization.html"
|
||||||
|
)
|
||||||
|
await visualize_graph(graph_file_path)
|
||||||
|
|
||||||
# Assert local data files are cleaned properly
|
# Assert local data files are cleaned properly
|
||||||
await cognee.prune.prune_data()
|
await cognee.prune.prune_data()
|
||||||
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
|
||||||
|
|
|
||||||
|
|
@ -74,13 +74,13 @@
|
||||||
" get_repo_file_dependencies,\n",
|
" get_repo_file_dependencies,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"from cognee.tasks.storage import add_data_points\n",
|
"from cognee.tasks.storage import add_data_points\n",
|
||||||
"from cognee.modules.pipelines.tasks.Task import Task\n",
|
"from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
|
||||||
"\n",
|
"\n",
|
||||||
"detailed_extraction = True\n",
|
"detailed_extraction = True\n",
|
||||||
"\n",
|
"\n",
|
||||||
"tasks = [\n",
|
"tasks = [\n",
|
||||||
" Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction),\n",
|
" Task(get_repo_file_dependencies, detailed_extraction=detailed_extraction),\n",
|
||||||
" Task(add_data_points, task_config={\"batch_size\": 100 if detailed_extraction else 500}),\n",
|
" Task(add_data_points, task_config=TaskConfig(needs=[get_repo_file_dependencies], output_batch_size=100 if detailed_extraction else 500)),\n",
|
||||||
"]"
|
"]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -518,11 +518,11 @@
|
||||||
"from cognee.modules.data.models import Dataset, Data\n",
|
"from cognee.modules.data.models import Dataset, Data\n",
|
||||||
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
|
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
|
||||||
"from cognee.modules.cognify.config import get_cognify_config\n",
|
"from cognee.modules.cognify.config import get_cognify_config\n",
|
||||||
"from cognee.modules.pipelines.tasks.Task import Task\n",
|
|
||||||
"from cognee.modules.pipelines import run_tasks\n",
|
"from cognee.modules.pipelines import run_tasks\n",
|
||||||
|
"from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
|
||||||
|
"from cognee.modules.pipelines.operations.needs import merge_needs\n",
|
||||||
"from cognee.modules.users.models import User\n",
|
"from cognee.modules.users.models import User\n",
|
||||||
"from cognee.tasks.documents import (\n",
|
"from cognee.tasks.documents import (\n",
|
||||||
" check_permissions_on_documents,\n",
|
|
||||||
" classify_documents,\n",
|
" classify_documents,\n",
|
||||||
" extract_chunks_from_documents,\n",
|
" extract_chunks_from_documents,\n",
|
||||||
")\n",
|
")\n",
|
||||||
|
|
@ -540,19 +540,25 @@
|
||||||
"\n",
|
"\n",
|
||||||
" tasks = [\n",
|
" tasks = [\n",
|
||||||
" Task(classify_documents),\n",
|
" Task(classify_documents),\n",
|
||||||
" Task(check_permissions_on_documents, user=user, permissions=[\"write\"]),\n",
|
" Task( # Extract text chunks based on the document type.\n",
|
||||||
" Task(\n",
|
" extract_chunks_from_documents,\n",
|
||||||
" extract_chunks_from_documents, max_chunk_size=get_max_chunk_tokens()\n",
|
" max_chunk_size=get_max_chunk_tokens(),\n",
|
||||||
" ), # Extract text chunks based on the document type.\n",
|
" task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),\n",
|
||||||
" Task(\n",
|
" ),\n",
|
||||||
" extract_graph_from_data, graph_model=KnowledgeGraph, task_config={\"batch_size\": 10}\n",
|
" Task( # Generate knowledge graphs from the document chunks.\n",
|
||||||
" ), # Generate knowledge graphs from the document chunks.\n",
|
" extract_graph_from_data,\n",
|
||||||
|
" graph_model=KnowledgeGraph,\n",
|
||||||
|
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
|
||||||
|
" ),\n",
|
||||||
" Task(\n",
|
" Task(\n",
|
||||||
" summarize_text,\n",
|
" summarize_text,\n",
|
||||||
" summarization_model=cognee_config.summarization_model,\n",
|
" summarization_model=cognee_config.summarization_model,\n",
|
||||||
" task_config={\"batch_size\": 10},\n",
|
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
|
||||||
|
" ),\n",
|
||||||
|
" Task(\n",
|
||||||
|
" add_data_points,\n",
|
||||||
|
" task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),\n",
|
||||||
" ),\n",
|
" ),\n",
|
||||||
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
|
|
||||||
" ]\n",
|
" ]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" pipeline_run = run_tasks(tasks, dataset.id, data_documents, \"cognify_pipeline\")\n",
|
" pipeline_run = run_tasks(tasks, dataset.id, data_documents, \"cognify_pipeline\")\n",
|
||||||
|
|
@ -1041,7 +1047,7 @@
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "py312",
|
"display_name": ".venv",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
|
|
@ -1055,7 +1061,7 @@
|
||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.12.8"
|
"version": "3.11.8"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|
|
||||||
|
|
@ -397,14 +397,15 @@
|
||||||
"from cognee.modules.data.models import Dataset, Data\n",
|
"from cognee.modules.data.models import Dataset, Data\n",
|
||||||
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
|
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
|
||||||
"from cognee.modules.cognify.config import get_cognify_config\n",
|
"from cognee.modules.cognify.config import get_cognify_config\n",
|
||||||
"from cognee.modules.pipelines.tasks.Task import Task\n",
|
|
||||||
"from cognee.modules.pipelines import run_tasks\n",
|
"from cognee.modules.pipelines import run_tasks\n",
|
||||||
|
"from cognee.modules.pipelines.tasks import Task, TaskConfig\n",
|
||||||
|
"from cognee.modules.pipelines.operations.needs import merge_needs\n",
|
||||||
"from cognee.modules.users.models import User\n",
|
"from cognee.modules.users.models import User\n",
|
||||||
"from cognee.tasks.documents import (\n",
|
"from cognee.tasks.documents import (\n",
|
||||||
" check_permissions_on_documents,\n",
|
|
||||||
" classify_documents,\n",
|
" classify_documents,\n",
|
||||||
" extract_chunks_from_documents,\n",
|
" extract_chunks_from_documents,\n",
|
||||||
")\n",
|
")\n",
|
||||||
|
"from cognee.infrastructure.llm import get_max_chunk_tokens\n",
|
||||||
"from cognee.tasks.graph import extract_graph_from_data\n",
|
"from cognee.tasks.graph import extract_graph_from_data\n",
|
||||||
"from cognee.tasks.storage import add_data_points\n",
|
"from cognee.tasks.storage import add_data_points\n",
|
||||||
"from cognee.tasks.summarization import summarize_text\n",
|
"from cognee.tasks.summarization import summarize_text\n",
|
||||||
|
|
@ -418,17 +419,25 @@
|
||||||
"\n",
|
"\n",
|
||||||
" tasks = [\n",
|
" tasks = [\n",
|
||||||
" Task(classify_documents),\n",
|
" Task(classify_documents),\n",
|
||||||
" Task(check_permissions_on_documents, user=user, permissions=[\"write\"]),\n",
|
" Task( # Extract text chunks based on the document type.\n",
|
||||||
" Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n",
|
" extract_chunks_from_documents,\n",
|
||||||
" Task(\n",
|
" max_chunk_size=get_max_chunk_tokens(),\n",
|
||||||
" extract_graph_from_data, graph_model=KnowledgeGraph, task_config={\"batch_size\": 10}\n",
|
" task_config=TaskConfig(needs=[classify_documents], output_batch_size=10),\n",
|
||||||
" ), # Generate knowledge graphs from the document chunks.\n",
|
" ),\n",
|
||||||
|
" Task( # Generate knowledge graphs from the document chunks.\n",
|
||||||
|
" extract_graph_from_data,\n",
|
||||||
|
" graph_model=KnowledgeGraph,\n",
|
||||||
|
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
|
||||||
|
" ),\n",
|
||||||
" Task(\n",
|
" Task(\n",
|
||||||
" summarize_text,\n",
|
" summarize_text,\n",
|
||||||
" summarization_model=cognee_config.summarization_model,\n",
|
" summarization_model=cognee_config.summarization_model,\n",
|
||||||
" task_config={\"batch_size\": 10},\n",
|
" task_config=TaskConfig(needs=[extract_chunks_from_documents]),\n",
|
||||||
|
" ),\n",
|
||||||
|
" Task(\n",
|
||||||
|
" add_data_points,\n",
|
||||||
|
" task_config=TaskConfig(needs=[merge_needs(summarize_text, extract_graph_from_data)]),\n",
|
||||||
" ),\n",
|
" ),\n",
|
||||||
" Task(add_data_points, task_config={\"batch_size\": 10}),\n",
|
|
||||||
" ]\n",
|
" ]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" pipeline = run_tasks(tasks, data_documents)\n",
|
" pipeline = run_tasks(tasks, data_documents)\n",
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
Loading…
Add table
Reference in a new issue