feat: expose cognee.pipelines (#125)

* fix: expose cognee.pipelines and fix batch task config

* fix: neo4j neighbours
This commit is contained in:
Boris 2024-07-27 10:01:44 +02:00 committed by GitHub
parent 86c7aa23a8
commit 5182051168
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 63 additions and 31 deletions

View file

@@ -4,3 +4,6 @@ from .api.v1.cognify.cognify_v2 import cognify
from .api.v1.datasets.datasets import datasets
from .api.v1.search.search import search, SearchType
from .api.v1.prune import prune
# Pipelines
from .modules import pipelines

View file

@@ -8,7 +8,7 @@ import logging
import sentry_sdk
from typing import Dict, Any, List, Union, Optional, Literal
from typing_extensions import Annotated
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

View file

@@ -60,8 +60,8 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No
root_node_id = "ROOT"
tasks = [
Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data
Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
Task(

View file

@@ -321,9 +321,9 @@ class Neo4jAdapter(GraphDBInterface):
return [result["id"] for result in results]
async def get_neighbours(self, node_id: str) -> list[str]:
results = await asyncio.gather(*[self.get_predecessor_ids(node_id)], self.get_successor_ids(node_id))
predecessor_ids, successor_ids = await asyncio.gather(self.get_predecessor_ids(node_id), self.get_successor_ids(node_id))
return [*results[0], *results[1]]
return [*predecessor_ids, *successor_ids]
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
query = f"""

View file

@@ -1,2 +1,3 @@
from .tasks.Task import Task
from .operations.run_tasks import run_tasks
from .operations.run_parallel import run_tasks_parallel

View file

@@ -8,27 +8,29 @@ async def main():
for i in range(num):
yield i + 1
async def add_one(num):
yield num + 1
async def multiply_by_two(nums):
async def add_one(nums):
for num in nums:
yield num * 2
yield num + 1
async def add_one_to_batched_data(num):
async def multiply_by_two(num):
yield num * 2
async def add_one_single(num):
yield num + 1
pipeline = run_tasks([
Task(number_generator, task_config = {"batch_size": 1}),
Task(number_generator),
Task(add_one, task_config = {"batch_size": 5}),
Task(multiply_by_two, task_config = {"batch_size": 1}),
Task(add_one_to_batched_data),
Task(add_one_single),
], 10)
results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
index = 0
async for result in pipeline:
print("\n")
print(result)
print("\n")
assert result == results[index]
index += 1
if __name__ == "__main__":
asyncio.run(main())

View file

@@ -10,13 +10,12 @@ async def run_tasks(tasks: [Task], data):
return
running_task = tasks[0]
batch_size = running_task.task_config["batch_size"]
leftover_tasks = tasks[1:]
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
# next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
if inspect.isasyncgenfunction(running_task.executable):
logger.info(f"Running async generator task: `{running_task.executable.__name__}`")
logger.info("Running async generator task: `%s`", running_task.executable.__name__)
try:
results = []
@@ -25,8 +24,8 @@ async def run_tasks(tasks: [Task], data):
async for partial_result in async_iterator:
results.append(partial_result)
if len(results) == batch_size:
async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
if len(results) == next_task_batch_size:
async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
yield result
results = []
@@ -37,7 +36,7 @@ async def run_tasks(tasks: [Task], data):
results = []
logger.info(f"Finished async generator task: `{running_task.executable.__name__}`")
logger.info("Finished async generator task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running async generator task: `%s`\n%s\n",
@@ -48,15 +47,15 @@ async def run_tasks(tasks: [Task], data):
raise error
elif inspect.isgeneratorfunction(running_task.executable):
logger.info(f"Running generator task: `{running_task.executable.__name__}`")
logger.info("Running generator task: `%s`", running_task.executable.__name__)
try:
results = []
for partial_result in running_task.run(data):
results.append(partial_result)
if len(results) == batch_size:
async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
if len(results) == next_task_batch_size:
async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
yield result
results = []
@@ -67,7 +66,7 @@ async def run_tasks(tasks: [Task], data):
results = []
logger.info(f"Running generator task: `{running_task.executable.__name__}`")
logger.info("Finished generator task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running generator task: `%s`\n%s\n",
@@ -78,13 +77,35 @@ async def run_tasks(tasks: [Task], data):
raise error
elif inspect.iscoroutinefunction(running_task.executable):
task_result = await running_task.run(data)
logger.info("Running coroutine task: `%s`", running_task.executable.__name__)
try:
task_result = await running_task.run(data)
async for result in run_tasks(leftover_tasks, task_result):
yield result
async for result in run_tasks(leftover_tasks, task_result):
yield result
logger.info("Finished coroutine task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running coroutine task: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info = True,
)
elif inspect.isfunction(running_task.executable):
task_result = running_task.run(data)
logger.info("Running function task: `%s`", running_task.executable.__name__)
try:
task_result = running_task.run(data)
async for result in run_tasks(leftover_tasks, task_result):
yield result
async for result in run_tasks(leftover_tasks, task_result):
yield result
logger.info("Finished function task: `%s`", running_task.executable.__name__)
except Exception as error:
logger.error(
"Error occurred while running function task: `%s`\n%s\n",
running_task.executable.__name__,
str(error),
exc_info = True,
)

5
cognee/pipelines.py Normal file
View file

@@ -0,0 +1,5 @@
# Don't add any more code here, this file is used only for the purpose
# of enabling imports from `cognee.pipelines` module.
# `from cognee.pipelines import Task` for example.
from .modules.pipelines import *