feat: Batch document handling (#1469)

<!-- .github/pull_request_template.md -->

## Description
Add a batching system for document processing that limits the number of
documents processed in parallel in Cognee.
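
Below is a minimal usage sketch of the new parameter, assuming the public `cognee.add` and `cognee.cognify` entry points shown in the diff; the file paths and the batch size of 10 are placeholder values, not part of this PR.

```python
import asyncio

import cognee


async def main():
    # Ingest documents, allowing at most 10 data items to be processed per batch
    # (this PR's default is 20).
    await cognee.add(["docs/report_a.md", "docs/report_b.md"], data_per_batch=10)

    # Build the knowledge graph with the same batch limit.
    await cognee.cognify(data_per_batch=10)


asyncio.run(main())
```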

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

Commit 559d5009f7, authored by Vasilije on 2025-10-18 09:48:52 +02:00, committed by GitHub.
4 changed files with 31 additions and 19 deletions

File 1 of 4

@@ -41,6 +41,7 @@ async def add(
extraction_rules: Optional[Dict[str, Any]] = None,
tavily_config: Optional[BaseModel] = None,
soup_crawler_config: Optional[BaseModel] = None,
data_per_batch: Optional[int] = 20,
):
"""
Add data to Cognee for knowledge graph processing.
@@ -235,6 +236,7 @@ async def add(
vector_db_config=vector_db_config,
graph_db_config=graph_db_config,
incremental_loading=incremental_loading,
data_per_batch=data_per_batch,
):
pipeline_run_info = run_info
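
For reference, `add` consumes the pipeline as an async generator and keeps the last emitted run info, as shown in the hunk above. A minimal sketch of that consumption pattern; `run_pipeline_stub` is a hypothetical stand-in, not the project's pipeline:

```python
import asyncio
from typing import AsyncIterator


async def run_pipeline_stub(data_per_batch: int = 20) -> AsyncIterator[str]:
    # Hypothetical stand-in: yields run-info updates as the pipeline progresses.
    for step in ("started", f"data_per_batch={data_per_batch}", "completed"):
        yield step


async def main():
    pipeline_run_info = None
    # Same pattern as add(): iterate every update, remember the most recent one.
    async for run_info in run_pipeline_stub(data_per_batch=20):
        pipeline_run_info = run_info
    print(pipeline_run_info)  # -> completed


asyncio.run(main())
```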

File 2 of 4

@@ -51,6 +51,7 @@ async def cognify(
incremental_loading: bool = True,
custom_prompt: Optional[str] = None,
temporal_cognify: bool = False,
data_per_batch: int = 20,
):
"""
Transform ingested data into a structured knowledge graph.
@@ -228,6 +229,7 @@ async def cognify(
graph_db_config=graph_db_config,
incremental_loading=incremental_loading,
pipeline_name="cognify_pipeline",
data_per_batch=data_per_batch,
)
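
With the slicing used later in `run_tasks` (`data[start : start + data_per_batch]`), the number of batches is `ceil(len(data) / data_per_batch)`; for example, 50 data items with the default of 20 produce batches of 20, 20, and 10. A quick check of that arithmetic:

```python
import math


def batch_sizes(total_items: int, data_per_batch: int = 20) -> list[int]:
    # Mirrors the slicing in run_tasks: data[start : start + data_per_batch].
    return [
        min(data_per_batch, total_items - start)
        for start in range(0, total_items, data_per_batch)
    ]


assert batch_sizes(50, 20) == [20, 20, 10]
assert len(batch_sizes(50, 20)) == math.ceil(50 / 20)
```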

File 3 of 4

@@ -35,6 +35,7 @@ async def run_pipeline(
vector_db_config: dict = None,
graph_db_config: dict = None,
incremental_loading: bool = False,
data_per_batch: int = 20,
):
validate_pipeline_tasks(tasks)
await setup_and_check_environment(vector_db_config, graph_db_config)
@@ -50,6 +51,7 @@ async def run_pipeline(
pipeline_name=pipeline_name,
context={"dataset": dataset},
incremental_loading=incremental_loading,
data_per_batch=data_per_batch,
):
yield run_info
@@ -62,6 +64,7 @@ async def run_pipeline_per_dataset(
pipeline_name: str = "custom_pipeline",
context: dict = None,
incremental_loading=False,
data_per_batch: int = 20,
):
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
await set_database_global_context_variables(dataset.id, dataset.owner_id)
@@ -77,7 +80,7 @@ async def run_pipeline_per_dataset(
return
pipeline_run = run_tasks(
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading, data_per_batch
)
async for pipeline_run_info in pipeline_run:
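
The new parameter is threaded through the call chain add/cognify -> run_pipeline -> run_pipeline_per_dataset -> run_tasks. A simplified sketch of that forwarding, with placeholder function bodies rather than the project's code:

```python
import asyncio


async def run_tasks_sketch(data: list, data_per_batch: int = 20) -> list:
    # Innermost layer (run_tasks in the real code): where batching happens.
    return [data[start : start + data_per_batch] for start in range(0, len(data), data_per_batch)]


async def run_pipeline_per_dataset_sketch(data: list, data_per_batch: int = 20) -> list:
    # Middle layer: forwards the value to run_tasks.
    return await run_tasks_sketch(data, data_per_batch)


async def run_pipeline_sketch(data: list, data_per_batch: int = 20) -> list:
    # Outer layer: add() and cognify() pass their value in here.
    return await run_pipeline_per_dataset_sketch(data, data_per_batch=data_per_batch)


print(asyncio.run(run_pipeline_sketch(list(range(5)), data_per_batch=2)))
# [[0, 1], [2, 3], [4]]
```

Note that in the hunk above, `run_pipeline_per_dataset` forwards `data_per_batch` to `run_tasks` positionally, which relies on it sitting right after `incremental_loading` in the `run_tasks` signature; passing it as a keyword argument is an equally workable style.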

File 4 of 4

@@ -24,7 +24,6 @@ from cognee.modules.pipelines.operations import (
log_pipeline_run_complete,
log_pipeline_run_error,
)
from .run_tasks_with_telemetry import run_tasks_with_telemetry
from .run_tasks_data_item import run_tasks_data_item
from ..tasks.task import Task
@@ -60,6 +59,7 @@ async def run_tasks(
pipeline_name: str = "unknown_pipeline",
context: dict = None,
incremental_loading: bool = False,
data_per_batch: int = 20,
):
if not user:
user = await get_default_user()
@@ -89,24 +89,29 @@ async def run_tasks(
if incremental_loading:
data = await resolve_data_directories(data)
# Create async tasks per data item that will run the pipeline for the data item
data_item_tasks = [
asyncio.create_task(
run_tasks_data_item(
data_item,
dataset,
tasks,
pipeline_name,
pipeline_id,
pipeline_run_id,
context,
user,
incremental_loading,
# Create and gather batches of async tasks of data items that will run the pipeline for the data item
results = []
for start in range(0, len(data), data_per_batch):
data_batch = data[start : start + data_per_batch]
data_item_tasks = [
asyncio.create_task(
run_tasks_data_item(
data_item,
dataset,
tasks,
pipeline_name,
pipeline_id,
pipeline_run_id,
context,
user,
incremental_loading,
)
)
)
for data_item in data
]
results = await asyncio.gather(*data_item_tasks)
for data_item in data_batch
]
results.extend(await asyncio.gather(*data_item_tasks))
# Remove skipped data items from results
results = [result for result in results if result]
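
The core of the change is this batch-and-gather pattern: slice the data into fixed-size chunks and `asyncio.gather` one chunk at a time, so at most `data_per_batch` items are processed concurrently. A self-contained sketch of the same pattern, with a dummy worker in place of `run_tasks_data_item`:

```python
import asyncio


async def process_item(item: int) -> int:
    # Dummy stand-in for run_tasks_data_item.
    await asyncio.sleep(0.01)
    return item * 2


async def process_in_batches(data: list, data_per_batch: int = 20) -> list:
    results = []
    for start in range(0, len(data), data_per_batch):
        data_batch = data[start : start + data_per_batch]
        # Only this batch's tasks run concurrently; the next batch starts
        # after the current one has fully completed.
        batch_tasks = [asyncio.create_task(process_item(item)) for item in data_batch]
        results.extend(await asyncio.gather(*batch_tasks))
    return results


print(len(asyncio.run(process_in_batches(list(range(45))))))  # 45 results, run in 3 batches
```

This bounds concurrency at `data_per_batch` tasks and keeps memory for in-flight work bounded as well; an `asyncio.Semaphore`-based limiter is a common alternative that avoids waiting for the slowest item in each batch, at the cost of a little more code.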