refactor: rework forwarding of data batch size through pipeline calls

This commit is contained in:
Igor Ilic 2025-10-15 20:18:48 +02:00
parent ad4a732e28
commit 2fb06e0729
4 changed files with 13 additions and 8 deletions

View file

@ -41,6 +41,7 @@ async def add(
extraction_rules: Optional[Dict[str, Any]] = None, extraction_rules: Optional[Dict[str, Any]] = None,
tavily_config: Optional[BaseModel] = None, tavily_config: Optional[BaseModel] = None,
soup_crawler_config: Optional[BaseModel] = None, soup_crawler_config: Optional[BaseModel] = None,
data_batch_size: Optional[int] = 20,
): ):
""" """
Add data to Cognee for knowledge graph processing. Add data to Cognee for knowledge graph processing.
@ -235,6 +236,7 @@ async def add(
vector_db_config=vector_db_config, vector_db_config=vector_db_config,
graph_db_config=graph_db_config, graph_db_config=graph_db_config,
incremental_loading=incremental_loading, incremental_loading=incremental_loading,
data_batch_size=data_batch_size,
): ):
pipeline_run_info = run_info pipeline_run_info = run_info

View file

@ -51,6 +51,7 @@ async def cognify(
incremental_loading: bool = True, incremental_loading: bool = True,
custom_prompt: Optional[str] = None, custom_prompt: Optional[str] = None,
temporal_cognify: bool = False, temporal_cognify: bool = False,
data_batch_size: int = 20,
): ):
""" """
Transform ingested data into a structured knowledge graph. Transform ingested data into a structured knowledge graph.
@ -228,6 +229,7 @@ async def cognify(
graph_db_config=graph_db_config, graph_db_config=graph_db_config,
incremental_loading=incremental_loading, incremental_loading=incremental_loading,
pipeline_name="cognify_pipeline", pipeline_name="cognify_pipeline",
data_batch_size=data_batch_size,
) )

View file

@ -35,6 +35,7 @@ async def run_pipeline(
vector_db_config: dict = None, vector_db_config: dict = None,
graph_db_config: dict = None, graph_db_config: dict = None,
incremental_loading: bool = False, incremental_loading: bool = False,
data_batch_size: int = 20,
): ):
validate_pipeline_tasks(tasks) validate_pipeline_tasks(tasks)
await setup_and_check_environment(vector_db_config, graph_db_config) await setup_and_check_environment(vector_db_config, graph_db_config)
@ -50,6 +51,7 @@ async def run_pipeline(
pipeline_name=pipeline_name, pipeline_name=pipeline_name,
context={"dataset": dataset}, context={"dataset": dataset},
incremental_loading=incremental_loading, incremental_loading=incremental_loading,
data_batch_size=data_batch_size,
): ):
yield run_info yield run_info
@ -62,6 +64,7 @@ async def run_pipeline_per_dataset(
pipeline_name: str = "custom_pipeline", pipeline_name: str = "custom_pipeline",
context: dict = None, context: dict = None,
incremental_loading=False, incremental_loading=False,
data_batch_size: int = 20,
): ):
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True # Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
await set_database_global_context_variables(dataset.id, dataset.owner_id) await set_database_global_context_variables(dataset.id, dataset.owner_id)
@ -77,7 +80,7 @@ async def run_pipeline_per_dataset(
return return
pipeline_run = run_tasks( pipeline_run = run_tasks(
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading tasks, dataset.id, data, user, pipeline_name, context, incremental_loading, data_batch_size
) )
async for pipeline_run_info in pipeline_run: async for pipeline_run_info in pipeline_run:

View file

@ -24,14 +24,11 @@ from cognee.modules.pipelines.operations import (
log_pipeline_run_complete, log_pipeline_run_complete,
log_pipeline_run_error, log_pipeline_run_error,
) )
from .run_tasks_with_telemetry import run_tasks_with_telemetry
from .run_tasks_data_item import run_tasks_data_item from .run_tasks_data_item import run_tasks_data_item
from ..tasks.task import Task from ..tasks.task import Task
logger = get_logger("run_tasks(tasks: [Task], data)") logger = get_logger("run_tasks(tasks: [Task], data)")
# TODO: See if this parameter should be configurable as input for run_tasks itself
DOCUMENT_BATCH_SIZE = 10
def override_run_tasks(new_gen): def override_run_tasks(new_gen):
@ -62,6 +59,7 @@ async def run_tasks(
pipeline_name: str = "unknown_pipeline", pipeline_name: str = "unknown_pipeline",
context: dict = None, context: dict = None,
incremental_loading: bool = False, incremental_loading: bool = False,
data_batch_size: int = 20,
): ):
if not user: if not user:
user = await get_default_user() user = await get_default_user()
@ -93,12 +91,12 @@ async def run_tasks(
# Create and gather batches of async tasks of data items that will run the pipeline for the data item # Create and gather batches of async tasks of data items that will run the pipeline for the data item
results = [] results = []
for start in range(0, len(data), DOCUMENT_BATCH_SIZE): for start in range(0, len(data), data_batch_size):
document_batch = data[start : start + DOCUMENT_BATCH_SIZE] data_batch = data[start : start + data_batch_size]
data_item_tasks = [ data_item_tasks = [
asyncio.create_task( asyncio.create_task(
_run_tasks_data_item( run_tasks_data_item(
data_item, data_item,
dataset, dataset,
tasks, tasks,
@ -110,7 +108,7 @@ async def run_tasks(
incremental_loading, incremental_loading,
) )
) )
for data_item in document_batch for data_item in data_batch
] ]
results.extend(await asyncio.gather(*data_item_tasks)) results.extend(await asyncio.gather(*data_item_tasks))