From 5182051168497f26836036c926afad734bef2072 Mon Sep 17 00:00:00 2001 From: Boris Date: Sat, 27 Jul 2024 10:01:44 +0200 Subject: [PATCH] feat: expose cognee.pipelines (#125) * fix: expose cognee.pipelines and fix batch task config * fix: neo4j neighbours --- cognee/__init__.py | 3 ++ cognee/api/client.py | 2 +- cognee/api/v1/cognify/cognify_v2.py | 4 +- .../databases/graph/neo4j_driver/adapter.py | 4 +- cognee/modules/pipelines/__init__.py | 1 + .../__tests__/{__index__.py => __init__.py} | 0 .../operations/__tests__/run_tasks.test.py | 22 ++++---- .../modules/pipelines/operations/run_tasks.py | 53 +++++++++++++------ cognee/pipelines.py | 5 ++ 9 files changed, 63 insertions(+), 31 deletions(-) rename cognee/modules/pipelines/operations/__tests__/{__index__.py => __init__.py} (100%) create mode 100644 cognee/pipelines.py diff --git a/cognee/__init__.py b/cognee/__init__.py index 49a5d7f21..04c13eae1 100644 --- a/cognee/__init__.py +++ b/cognee/__init__.py @@ -4,3 +4,6 @@ from .api.v1.cognify.cognify_v2 import cognify from .api.v1.datasets.datasets import datasets from .api.v1.search.search import search, SearchType from .api.v1.prune import prune + +# Pipelines +from .modules import pipelines diff --git a/cognee/api/client.py b/cognee/api/client.py index 196582b5d..bee1c2922 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -8,7 +8,7 @@ import logging import sentry_sdk from typing import Dict, Any, List, Union, Optional, Literal from typing_extensions import Annotated -from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query +from fastapi import FastAPI, HTTPException, Form, UploadFile, Query from fastapi.responses import JSONResponse, FileResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel diff --git a/cognee/api/v1/cognify/cognify_v2.py b/cognee/api/v1/cognify/cognify_v2.py index e10680f19..4da5605d1 100644 --- a/cognee/api/v1/cognify/cognify_v2.py +++ b/cognee/api/v1/cognify/cognify_v2.py @@ -60,8 +60,8 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No root_node_id = "ROOT" tasks = [ - Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type - Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data + Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type + Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks Task( diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 51568973a..3cc93232f 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -321,9 +321,9 @@ class Neo4jAdapter(GraphDBInterface): return [result["id"] for result in results] async def get_neighbours(self, node_id: str) -> list[str]: - results = await asyncio.gather(*[self.get_predecessor_ids(node_id)], self.get_successor_ids(node_id)) + predecessor_ids, successor_ids = await asyncio.gather(self.get_predecessor_ids(node_id), self.get_successor_ids(node_id)) - return [*results[0], *results[1]] + return [*predecessor_ids, *successor_ids] async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None: query = f""" diff --git a/cognee/modules/pipelines/__init__.py b/cognee/modules/pipelines/__init__.py index cba929c36..5005c25f0 100644 --- a/cognee/modules/pipelines/__init__.py +++ b/cognee/modules/pipelines/__init__.py @@ -1,2 +1,3 @@ +from .tasks.Task import Task from .operations.run_tasks import run_tasks from .operations.run_parallel import run_tasks_parallel diff --git a/cognee/modules/pipelines/operations/__tests__/__index__.py b/cognee/modules/pipelines/operations/__tests__/__init__.py similarity index 100% rename from cognee/modules/pipelines/operations/__tests__/__index__.py rename to cognee/modules/pipelines/operations/__tests__/__init__.py diff --git a/cognee/modules/pipelines/operations/__tests__/run_tasks.test.py b/cognee/modules/pipelines/operations/__tests__/run_tasks.test.py index 387b97274..2fef802fd 100644 --- a/cognee/modules/pipelines/operations/__tests__/run_tasks.test.py +++ b/cognee/modules/pipelines/operations/__tests__/run_tasks.test.py @@ -8,27 +8,29 @@ async def main(): for i in range(num): yield i + 1 - async def add_one(num): - yield num + 1 - - async def multiply_by_two(nums): + async def add_one(nums): for num in nums: - yield num * 2 + yield num + 1 - async def add_one_to_batched_data(num): + async def multiply_by_two(num): + yield num * 2 + + async def add_one_single(num): yield num + 1 pipeline = run_tasks([ - Task(number_generator, task_config = {"batch_size": 1}), + Task(number_generator), Task(add_one, task_config = {"batch_size": 5}), Task(multiply_by_two, task_config = {"batch_size": 1}), - Task(add_one_to_batched_data), + Task(add_one_single), ], 10) + results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23] + index = 0 async for result in pipeline: - print("\n") print(result) - print("\n") + assert result == results[index] + index += 1 if __name__ == "__main__": asyncio.run(main()) diff --git a/cognee/modules/pipelines/operations/run_tasks.py b/cognee/modules/pipelines/operations/run_tasks.py index 1000743da..e372b239f 100644 --- a/cognee/modules/pipelines/operations/run_tasks.py +++ b/cognee/modules/pipelines/operations/run_tasks.py @@ -10,13 +10,12 @@ async def run_tasks(tasks: [Task], data): return running_task = tasks[0] - batch_size = running_task.task_config["batch_size"] leftover_tasks = tasks[1:] next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None - # next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1 + next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1 if inspect.isasyncgenfunction(running_task.executable): - logger.info(f"Running async generator task: `{running_task.executable.__name__}`") + logger.info("Running async generator task: `%s`", running_task.executable.__name__) try: results = [] @@ -25,8 +24,8 @@ async def run_tasks(tasks: [Task], data): async for partial_result in async_iterator: results.append(partial_result) - if len(results) == batch_size: - async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results): + if len(results) == next_task_batch_size: + async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results): yield result results = [] @@ -37,7 +36,7 @@ async def run_tasks(tasks: [Task], data): results = [] - logger.info(f"Finished async generator task: `{running_task.executable.__name__}`") + logger.info("Finished async generator task: `%s`", running_task.executable.__name__) except Exception as error: logger.error( "Error occurred while running async generator task: `%s`\n%s\n", @@ -48,15 +47,15 @@ async def run_tasks(tasks: [Task], data): raise error elif inspect.isgeneratorfunction(running_task.executable): - logger.info(f"Running generator task: `{running_task.executable.__name__}`") + logger.info("Running generator task: `%s`", running_task.executable.__name__) try: results = [] for partial_result in running_task.run(data): results.append(partial_result) - if len(results) == batch_size: - async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results): + if len(results) == next_task_batch_size: + async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results): yield result results = [] @@ -67,7 +66,7 @@ async def run_tasks(tasks: [Task], data): results = [] - logger.info(f"Running generator task: `{running_task.executable.__name__}`") + logger.info("Finished generator task: `%s`", running_task.executable.__name__) except Exception as error: logger.error( "Error occurred while running generator task: `%s`\n%s\n", @@ -78,13 +77,35 @@ async def run_tasks(tasks: [Task], data): raise error elif inspect.iscoroutinefunction(running_task.executable): - task_result = await running_task.run(data) + logger.info("Running coroutine task: `%s`", running_task.executable.__name__) + try: + task_result = await running_task.run(data) - async for result in run_tasks(leftover_tasks, task_result): - yield result + async for result in run_tasks(leftover_tasks, task_result): + yield result + logger.info("Finished coroutine task: `%s`", running_task.executable.__name__) + except Exception as error: + logger.error( + "Error occurred while running coroutine task: `%s`\n%s\n", + running_task.executable.__name__, + str(error), + exc_info = True, + ) + elif inspect.isfunction(running_task.executable): - task_result = running_task.run(data) + logger.info("Running function task: `%s`", running_task.executable.__name__) + try: + task_result = running_task.run(data) - async for result in run_tasks(leftover_tasks, task_result): - yield result + async for result in run_tasks(leftover_tasks, task_result): + yield result + + logger.info("Finished function task: `%s`", running_task.executable.__name__) + except Exception as error: + logger.error( + "Error occurred while running function task: `%s`\n%s\n", + running_task.executable.__name__, + str(error), + exc_info = True, + ) diff --git a/cognee/pipelines.py b/cognee/pipelines.py new file mode 100644 index 000000000..7e5e791d4 --- /dev/null +++ b/cognee/pipelines.py @@ -0,0 +1,5 @@ +# Don't add any more code here, this file is used only for the purpose +# of enabling imports from `cognee.pipelines` module. +# `from cognee.pipelines import Task` for example. + +from .modules.pipelines import *