feat: expose cognee.pipelines (#125)
* fix: expose cognee.pipelines and fix batch task config * fix: neo4j neighbours
This commit is contained in:
parent
86c7aa23a8
commit
5182051168
9 changed files with 63 additions and 31 deletions
|
|
@ -4,3 +4,6 @@ from .api.v1.cognify.cognify_v2 import cognify
|
|||
from .api.v1.datasets.datasets import datasets
|
||||
from .api.v1.search.search import search, SearchType
|
||||
from .api.v1.prune import prune
|
||||
|
||||
# Pipelines
|
||||
from .modules import pipelines
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import logging
|
|||
import sentry_sdk
|
||||
from typing import Dict, Any, List, Union, Optional, Literal
|
||||
from typing_extensions import Annotated
|
||||
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
|
||||
from fastapi import FastAPI, HTTPException, Form, UploadFile, Query
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
|
|
|||
|
|
@ -60,8 +60,8 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No
|
|||
root_node_id = "ROOT"
|
||||
|
||||
tasks = [
|
||||
Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
|
||||
Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data
|
||||
Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
|
||||
Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
|
||||
Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
|
||||
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
|
||||
Task(
|
||||
|
|
|
|||
|
|
@ -321,9 +321,9 @@ class Neo4jAdapter(GraphDBInterface):
|
|||
return [result["id"] for result in results]
|
||||
|
||||
async def get_neighbours(self, node_id: str) -> list[str]:
|
||||
results = await asyncio.gather(*[self.get_predecessor_ids(node_id)], self.get_successor_ids(node_id))
|
||||
predecessor_ids, successor_ids = await asyncio.gather(self.get_predecessor_ids(node_id), self.get_successor_ids(node_id))
|
||||
|
||||
return [*results[0], *results[1]]
|
||||
return [*predecessor_ids, *successor_ids]
|
||||
|
||||
async def remove_connection_to_predecessors_of(self, node_ids: list[str], edge_label: str) -> None:
|
||||
query = f"""
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
from .tasks.Task import Task
|
||||
from .operations.run_tasks import run_tasks
|
||||
from .operations.run_parallel import run_tasks_parallel
|
||||
|
|
|
|||
|
|
@ -8,27 +8,29 @@ async def main():
|
|||
for i in range(num):
|
||||
yield i + 1
|
||||
|
||||
async def add_one(num):
|
||||
yield num + 1
|
||||
|
||||
async def multiply_by_two(nums):
|
||||
async def add_one(nums):
|
||||
for num in nums:
|
||||
yield num * 2
|
||||
yield num + 1
|
||||
|
||||
async def add_one_to_batched_data(num):
|
||||
async def multiply_by_two(num):
|
||||
yield num * 2
|
||||
|
||||
async def add_one_single(num):
|
||||
yield num + 1
|
||||
|
||||
pipeline = run_tasks([
|
||||
Task(number_generator, task_config = {"batch_size": 1}),
|
||||
Task(number_generator),
|
||||
Task(add_one, task_config = {"batch_size": 5}),
|
||||
Task(multiply_by_two, task_config = {"batch_size": 1}),
|
||||
Task(add_one_to_batched_data),
|
||||
Task(add_one_single),
|
||||
], 10)
|
||||
|
||||
results = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23]
|
||||
index = 0
|
||||
async for result in pipeline:
|
||||
print("\n")
|
||||
print(result)
|
||||
print("\n")
|
||||
assert result == results[index]
|
||||
index += 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -10,13 +10,12 @@ async def run_tasks(tasks: [Task], data):
|
|||
return
|
||||
|
||||
running_task = tasks[0]
|
||||
batch_size = running_task.task_config["batch_size"]
|
||||
leftover_tasks = tasks[1:]
|
||||
next_task = leftover_tasks[0] if len(leftover_tasks) > 1 else None
|
||||
# next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
|
||||
next_task_batch_size = next_task.task_config["batch_size"] if next_task else 1
|
||||
|
||||
if inspect.isasyncgenfunction(running_task.executable):
|
||||
logger.info(f"Running async generator task: `{running_task.executable.__name__}`")
|
||||
logger.info("Running async generator task: `%s`", running_task.executable.__name__)
|
||||
try:
|
||||
results = []
|
||||
|
||||
|
|
@ -25,8 +24,8 @@ async def run_tasks(tasks: [Task], data):
|
|||
async for partial_result in async_iterator:
|
||||
results.append(partial_result)
|
||||
|
||||
if len(results) == batch_size:
|
||||
async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
|
||||
if len(results) == next_task_batch_size:
|
||||
async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
|
||||
yield result
|
||||
|
||||
results = []
|
||||
|
|
@ -37,7 +36,7 @@ async def run_tasks(tasks: [Task], data):
|
|||
|
||||
results = []
|
||||
|
||||
logger.info(f"Finished async generator task: `{running_task.executable.__name__}`")
|
||||
logger.info("Finished async generator task: `%s`", running_task.executable.__name__)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
"Error occurred while running async generator task: `%s`\n%s\n",
|
||||
|
|
@ -48,15 +47,15 @@ async def run_tasks(tasks: [Task], data):
|
|||
raise error
|
||||
|
||||
elif inspect.isgeneratorfunction(running_task.executable):
|
||||
logger.info(f"Running generator task: `{running_task.executable.__name__}`")
|
||||
logger.info("Running generator task: `%s`", running_task.executable.__name__)
|
||||
try:
|
||||
results = []
|
||||
|
||||
for partial_result in running_task.run(data):
|
||||
results.append(partial_result)
|
||||
|
||||
if len(results) == batch_size:
|
||||
async for result in run_tasks(leftover_tasks, results[0] if batch_size == 1 else results):
|
||||
if len(results) == next_task_batch_size:
|
||||
async for result in run_tasks(leftover_tasks, results[0] if next_task_batch_size == 1 else results):
|
||||
yield result
|
||||
|
||||
results = []
|
||||
|
|
@ -67,7 +66,7 @@ async def run_tasks(tasks: [Task], data):
|
|||
|
||||
results = []
|
||||
|
||||
logger.info(f"Running generator task: `{running_task.executable.__name__}`")
|
||||
logger.info("Finished generator task: `%s`", running_task.executable.__name__)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
"Error occurred while running generator task: `%s`\n%s\n",
|
||||
|
|
@ -78,13 +77,35 @@ async def run_tasks(tasks: [Task], data):
|
|||
raise error
|
||||
|
||||
elif inspect.iscoroutinefunction(running_task.executable):
|
||||
task_result = await running_task.run(data)
|
||||
logger.info("Running coroutine task: `%s`", running_task.executable.__name__)
|
||||
try:
|
||||
task_result = await running_task.run(data)
|
||||
|
||||
async for result in run_tasks(leftover_tasks, task_result):
|
||||
yield result
|
||||
async for result in run_tasks(leftover_tasks, task_result):
|
||||
yield result
|
||||
|
||||
logger.info("Finished coroutine task: `%s`", running_task.executable.__name__)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
"Error occurred while running coroutine task: `%s`\n%s\n",
|
||||
running_task.executable.__name__,
|
||||
str(error),
|
||||
exc_info = True,
|
||||
)
|
||||
|
||||
elif inspect.isfunction(running_task.executable):
|
||||
task_result = running_task.run(data)
|
||||
logger.info("Running function task: `%s`", running_task.executable.__name__)
|
||||
try:
|
||||
task_result = running_task.run(data)
|
||||
|
||||
async for result in run_tasks(leftover_tasks, task_result):
|
||||
yield result
|
||||
async for result in run_tasks(leftover_tasks, task_result):
|
||||
yield result
|
||||
|
||||
logger.info("Finished function task: `%s`", running_task.executable.__name__)
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
"Error occurred while running function task: `%s`\n%s\n",
|
||||
running_task.executable.__name__,
|
||||
str(error),
|
||||
exc_info = True,
|
||||
)
|
||||
|
|
|
|||
5
cognee/pipelines.py
Normal file
5
cognee/pipelines.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Don't add any more code here, this file is used only for the purpose
|
||||
# of enabling imports from `cognee.pipelines` module.
|
||||
# `from cognee.pipelines import Task` for example.
|
||||
|
||||
from .modules.pipelines import *
|
||||
Loading…
Add table
Reference in a new issue