diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py
index 65394f1ec..0f14683f9 100644
--- a/cognee/api/v1/add/add.py
+++ b/cognee/api/v1/add/add.py
@@ -41,6 +41,7 @@ async def add(
     extraction_rules: Optional[Dict[str, Any]] = None,
     tavily_config: Optional[BaseModel] = None,
     soup_crawler_config: Optional[BaseModel] = None,
+    data_per_batch: Optional[int] = 20,
 ):
     """
     Add data to Cognee for knowledge graph processing.
@@ -235,6 +236,7 @@ async def add(
         vector_db_config=vector_db_config,
         graph_db_config=graph_db_config,
         incremental_loading=incremental_loading,
+        data_per_batch=data_per_batch,
     ):
         pipeline_run_info = run_info
 
diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py
index c3045f00a..1eb266765 100644
--- a/cognee/api/v1/cognify/cognify.py
+++ b/cognee/api/v1/cognify/cognify.py
@@ -51,6 +51,7 @@ async def cognify(
     incremental_loading: bool = True,
     custom_prompt: Optional[str] = None,
     temporal_cognify: bool = False,
+    data_per_batch: int = 20,
 ):
     """
     Transform ingested data into a structured knowledge graph.
@@ -228,6 +229,7 @@ async def cognify(
         graph_db_config=graph_db_config,
         incremental_loading=incremental_loading,
         pipeline_name="cognify_pipeline",
+        data_per_batch=data_per_batch,
     )
 
 
diff --git a/cognee/modules/pipelines/operations/pipeline.py b/cognee/modules/pipelines/operations/pipeline.py
index b59a171f7..e15e9e505 100644
--- a/cognee/modules/pipelines/operations/pipeline.py
+++ b/cognee/modules/pipelines/operations/pipeline.py
@@ -35,6 +35,7 @@ async def run_pipeline(
     vector_db_config: dict = None,
     graph_db_config: dict = None,
     incremental_loading: bool = False,
+    data_per_batch: int = 20,
 ):
     validate_pipeline_tasks(tasks)
     await setup_and_check_environment(vector_db_config, graph_db_config)
@@ -50,6 +51,7 @@ async def run_pipeline(
         pipeline_name=pipeline_name,
         context={"dataset": dataset},
         incremental_loading=incremental_loading,
+        data_per_batch=data_per_batch,
     ):
         yield run_info
 
@@ -62,6 +64,7 @@ async def run_pipeline_per_dataset(
     pipeline_name: str = "custom_pipeline",
     context: dict = None,
     incremental_loading=False,
+    data_per_batch: int = 20,
 ):
     # Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
     await set_database_global_context_variables(dataset.id, dataset.owner_id)
@@ -77,7 +80,7 @@ async def run_pipeline_per_dataset(
         return
 
     pipeline_run = run_tasks(
-        tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
+        tasks, dataset.id, data, user, pipeline_name, context, incremental_loading, data_per_batch
     )
 
     async for pipeline_run_info in pipeline_run:
diff --git a/cognee/modules/pipelines/operations/run_tasks.py b/cognee/modules/pipelines/operations/run_tasks.py
index 4a0c77309..ecc2f647b 100644
--- a/cognee/modules/pipelines/operations/run_tasks.py
+++ b/cognee/modules/pipelines/operations/run_tasks.py
@@ -24,7 +24,6 @@ from cognee.modules.pipelines.operations import (
     log_pipeline_run_complete,
     log_pipeline_run_error,
 )
-from .run_tasks_with_telemetry import run_tasks_with_telemetry
 from .run_tasks_data_item import run_tasks_data_item
 
 from ..tasks.task import Task
@@ -60,6 +59,7 @@ async def run_tasks(
     pipeline_name: str = "unknown_pipeline",
     context: dict = None,
     incremental_loading: bool = False,
+    data_per_batch: int = 20,
 ):
     if not user:
         user = await get_default_user()
@@ -89,24 +89,29 @@ async def run_tasks(
     if incremental_loading:
         data = await resolve_data_directories(data)
 
-    # Create async tasks per data item that will run the pipeline for the data item
-    data_item_tasks = [
-        asyncio.create_task(
-            run_tasks_data_item(
-                data_item,
-                dataset,
-                tasks,
-                pipeline_name,
-                pipeline_id,
-                pipeline_run_id,
-                context,
-                user,
-                incremental_loading,
+    # Create and gather batches of async tasks of data items that will run the pipeline for the data item
+    results = []
+    for start in range(0, len(data), data_per_batch):
+        data_batch = data[start : start + data_per_batch]
+
+        data_item_tasks = [
+            asyncio.create_task(
+                run_tasks_data_item(
+                    data_item,
+                    dataset,
+                    tasks,
+                    pipeline_name,
+                    pipeline_id,
+                    pipeline_run_id,
+                    context,
+                    user,
+                    incremental_loading,
+                )
             )
-        )
-        for data_item in data
-    ]
-    results = await asyncio.gather(*data_item_tasks)
+            for data_item in data_batch
+        ]
+
+        results.extend(await asyncio.gather(*data_item_tasks))
 
     # Remove skipped data items from results
     results = [result for result in results if result]
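
For reference, a minimal usage sketch of the new `data_per_batch` knob through the public entry points touched by this diff (`cognee.add` and `cognee.cognify`). The documents and the batch size of 2 are hypothetical; the parameter defaults to 20, and within each batch data items are still gathered concurrently while batches themselves run one after another.

```python
import asyncio

import cognee


async def main():
    # Ingest a few documents; with data_per_batch=2 the pipeline schedules
    # at most two data items concurrently per batch instead of creating
    # one asyncio task per item up front.
    await cognee.add(
        ["First document text.", "Second document text.", "Third document text."],
        data_per_batch=2,
    )

    # The same parameter is forwarded to the cognify pipeline.
    await cognee.cognify(data_per_batch=2)


if __name__ == "__main__":
    asyncio.run(main())
```

Bounding the size of each `asyncio.gather` call this way caps how many data items are processed at once, so a large ingest no longer fans out into an unbounded number of concurrent pipeline runs.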