feat: Batch document handling (#1469)
<!-- .github/pull_request_template.md -->

## Description

Add a batching system for document processing that limits the number of documents processed in parallel in Cognee.

## Type of Change

<!-- Please check the relevant option -->

- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Pre-submission Checklist

<!-- Please check all boxes that apply before submitting your PR -->

- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the issue/feature**
- [x] My code follows the project's coding standards and style guidelines
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
Commit 559d5009f7
4 changed files with 31 additions and 19 deletions
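For context, a minimal usage sketch of the new parameter. Only the `data_per_batch` keyword added in this diff is taken from the PR; the document paths and the batch size of 5 are illustrative.

```python
import asyncio

import cognee


async def main():
    # Hypothetical document paths, used only to illustrate the new keyword.
    documents = [f"/data/docs/report_{i}.txt" for i in range(100)]

    # Ingest the documents; at most 5 are handled concurrently per batch
    # instead of the default of 20 introduced by this PR.
    await cognee.add(documents, data_per_batch=5)

    # Build the knowledge graph with the same batch limit.
    await cognee.cognify(data_per_batch=5)


if __name__ == "__main__":
    asyncio.run(main())
```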
@@ -41,6 +41,7 @@ async def add(
     extraction_rules: Optional[Dict[str, Any]] = None,
     tavily_config: Optional[BaseModel] = None,
     soup_crawler_config: Optional[BaseModel] = None,
+    data_per_batch: Optional[int] = 20,
 ):
     """
     Add data to Cognee for knowledge graph processing.
@@ -235,6 +236,7 @@ async def add(
         vector_db_config=vector_db_config,
         graph_db_config=graph_db_config,
         incremental_loading=incremental_loading,
+        data_per_batch=data_per_batch,
     ):
         pipeline_run_info = run_info
@@ -51,6 +51,7 @@ async def cognify(
     incremental_loading: bool = True,
     custom_prompt: Optional[str] = None,
     temporal_cognify: bool = False,
+    data_per_batch: int = 20,
 ):
     """
     Transform ingested data into a structured knowledge graph.
@@ -228,6 +229,7 @@ async def cognify(
         graph_db_config=graph_db_config,
         incremental_loading=incremental_loading,
         pipeline_name="cognify_pipeline",
+        data_per_batch=data_per_batch,
     )
@@ -35,6 +35,7 @@ async def run_pipeline(
     vector_db_config: dict = None,
     graph_db_config: dict = None,
     incremental_loading: bool = False,
+    data_per_batch: int = 20,
 ):
     validate_pipeline_tasks(tasks)
     await setup_and_check_environment(vector_db_config, graph_db_config)
@@ -50,6 +51,7 @@ async def run_pipeline(
         pipeline_name=pipeline_name,
         context={"dataset": dataset},
         incremental_loading=incremental_loading,
+        data_per_batch=data_per_batch,
     ):
         yield run_info
@@ -62,6 +64,7 @@ async def run_pipeline_per_dataset(
     pipeline_name: str = "custom_pipeline",
     context: dict = None,
     incremental_loading=False,
+    data_per_batch: int = 20,
 ):
     # Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
     await set_database_global_context_variables(dataset.id, dataset.owner_id)
@@ -77,7 +80,7 @@ async def run_pipeline_per_dataset(
         return

     pipeline_run = run_tasks(
-        tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
+        tasks, dataset.id, data, user, pipeline_name, context, incremental_loading, data_per_batch
     )

     async for pipeline_run_info in pipeline_run:
@@ -24,7 +24,6 @@ from cognee.modules.pipelines.operations import (
     log_pipeline_run_complete,
     log_pipeline_run_error,
 )
-from .run_tasks_with_telemetry import run_tasks_with_telemetry
 from .run_tasks_data_item import run_tasks_data_item
 from ..tasks.task import Task
@@ -60,6 +59,7 @@ async def run_tasks(
     pipeline_name: str = "unknown_pipeline",
     context: dict = None,
     incremental_loading: bool = False,
+    data_per_batch: int = 20,
 ):
     if not user:
         user = await get_default_user()
@@ -89,24 +89,29 @@ async def run_tasks(
     if incremental_loading:
         data = await resolve_data_directories(data)

-    # Create async tasks per data item that will run the pipeline for the data item
-    data_item_tasks = [
-        asyncio.create_task(
-            run_tasks_data_item(
-                data_item,
-                dataset,
-                tasks,
-                pipeline_name,
-                pipeline_id,
-                pipeline_run_id,
-                context,
-                user,
-                incremental_loading,
-            )
-        )
-        for data_item in data
-    ]
-    results = await asyncio.gather(*data_item_tasks)
+    # Create and gather batches of async tasks of data items that will run the pipeline for the data item
+    results = []
+    for start in range(0, len(data), data_per_batch):
+        data_batch = data[start : start + data_per_batch]
+
+        data_item_tasks = [
+            asyncio.create_task(
+                run_tasks_data_item(
+                    data_item,
+                    dataset,
+                    tasks,
+                    pipeline_name,
+                    pipeline_id,
+                    pipeline_run_id,
+                    context,
+                    user,
+                    incremental_loading,
+                )
+            )
+            for data_item in data_batch
+        ]
+
+        results.extend(await asyncio.gather(*data_item_tasks))

     # Remove skipped data items from results
     results = [result for result in results if result]
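The `run_tasks` change above applies a standard batch-and-gather pattern: slice the data list into fixed-size chunks and await each chunk with `asyncio.gather` before starting the next. Below is a self-contained sketch of the same idea; the helper name `gather_in_batches` and the demo handler are hypothetical and not part of this PR.

```python
import asyncio
from typing import Any, Awaitable, Callable, Iterable, List


async def gather_in_batches(
    items: Iterable[Any],
    handler: Callable[[Any], Awaitable[Any]],
    batch_size: int = 20,
) -> List[Any]:
    """Run `handler` over `items`, awaiting at most `batch_size` coroutines at a time."""
    items = list(items)
    results: List[Any] = []
    for start in range(0, len(items), batch_size):
        batch = items[start : start + batch_size]
        # Each batch is fully awaited before the next one starts,
        # which caps how many handlers run concurrently.
        results.extend(await asyncio.gather(*(handler(item) for item in batch)))
    return results


async def _demo():
    async def fake_process(doc: str) -> str:
        await asyncio.sleep(0.01)  # stand-in for real document processing
        return f"processed {doc}"

    processed = await gather_in_batches(
        [f"doc_{i}" for i in range(50)], fake_process, batch_size=10
    )
    print(len(processed))  # 50


if __name__ == "__main__":
    asyncio.run(_demo())
```

One trade-off of this pattern: the slowest item in a batch gates the start of the next batch, so an `asyncio.Semaphore` would keep concurrency steadier; per-batch gathering, however, keeps the change to `run_tasks` minimal.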