diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py
index 3032bd4e8..d29d8c939 100644
--- a/cognee/api/v1/cognify/cognify.py
+++ b/cognee/api/v1/cognify/cognify.py
@@ -44,6 +44,7 @@ async def cognify(
     graph_model: BaseModel = KnowledgeGraph,
     chunker=TextChunker,
     chunk_size: int = None,
+    batch_size: int = None,
     config: Config = None,
     vector_db_config: dict = None,
     graph_db_config: dict = None,
@@ -105,6 +106,7 @@ async def cognify(
            Formula: min(embedding_max_completion_tokens, llm_max_completion_tokens // 2)
            Default limits: ~512-8192 tokens depending on models.
            Smaller chunks = more granular but potentially fragmented knowledge.
+        batch_size: Number of chunks to be processed in a single batch in Cognify tasks.
         vector_db_config: Custom vector database configuration for embeddings storage.
         graph_db_config: Custom graph database configuration for relationship storage.
         run_in_background: If True, starts processing asynchronously and returns immediately.
@@ -209,10 +211,18 @@ async def cognify(
     }
 
     if temporal_cognify:
-        tasks = await get_temporal_tasks(user, chunker, chunk_size)
+        tasks = await get_temporal_tasks(
+            user=user, chunker=chunker, chunk_size=chunk_size, batch_size=batch_size
+        )
     else:
         tasks = await get_default_tasks(
-            user, graph_model, chunker, chunk_size, config, custom_prompt
+            user=user,
+            graph_model=graph_model,
+            chunker=chunker,
+            chunk_size=chunk_size,
+            config=config,
+            custom_prompt=custom_prompt,
+            batch_size=batch_size,
         )
 
     # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
@@ -238,6 +248,7 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
     chunk_size: int = None,
     config: Config = None,
     custom_prompt: Optional[str] = None,
+    batch_size: int = 100,
 ) -> list[Task]:
     if config is None:
         ontology_config = get_ontology_env_config()
@@ -256,6 +267,9 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
             "ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
         }
 
+    if batch_size is None:
+        batch_size = 100
+
     default_tasks = [
         Task(classify_documents),
         Task(check_permissions_on_dataset, user=user, permissions=["write"]),
@@ -269,20 +283,20 @@ async def get_default_tasks(  # TODO: Find out a better way to do this (Boris's
             graph_model=graph_model,
             config=config,
             custom_prompt=custom_prompt,
-            task_config={"batch_size": 100},
+            task_config={"batch_size": batch_size},
         ),  # Generate knowledge graphs from the document chunks.
         Task(
             summarize_text,
-            task_config={"batch_size": 100},
+            task_config={"batch_size": batch_size},
         ),
-        Task(add_data_points, task_config={"batch_size": 100}),
+        Task(add_data_points, task_config={"batch_size": batch_size}),
     ]
 
     return default_tasks
 
 
 async def get_temporal_tasks(
-    user: User = None, chunker=TextChunker, chunk_size: int = None
+    user: User = None, chunker=TextChunker, chunk_size: int = None, batch_size: int = 10
 ) -> list[Task]:
     """
     Builds and returns a list of temporal processing tasks to be executed in sequence.
@@ -299,10 +313,14 @@ async def get_temporal_tasks(
         user (User, optional): The user requesting task execution, used for permission checks.
         chunker (Callable, optional): A text chunking function/class to split documents. Defaults to TextChunker.
         chunk_size (int, optional): Maximum token size per chunk. If not provided, uses system default.
+        batch_size (int, optional): Number of chunks to process in a single batch in Cognify tasks.
 
     Returns:
         list[Task]: A list of Task objects representing the temporal processing pipeline.
     """
+    if batch_size is None:
+        batch_size = 10
+
     temporal_tasks = [
         Task(classify_documents),
         Task(check_permissions_on_dataset, user=user, permissions=["write"]),
@@ -311,9 +329,9 @@ async def get_temporal_tasks(
             max_chunk_size=chunk_size or get_max_chunk_tokens(),
             chunker=chunker,
         ),
-        Task(extract_events_and_timestamps, task_config={"batch_size": 10}),
+        Task(extract_events_and_timestamps, task_config={"batch_size": batch_size}),
         Task(extract_knowledge_graph_from_events),
-        Task(add_data_points, task_config={"batch_size": 10}),
+        Task(add_data_points, task_config={"batch_size": batch_size}),
     ]
 
     return temporal_tasks
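
Usage sketch for the new parameter (not part of the patch; the sample text and batch_size values are illustrative, and it assumes the usual cognee.add() / cognee.cognify() flow):

```python
import asyncio

import cognee


async def main():
    # Ingest some sample text into the default dataset (illustrative content).
    await cognee.add("Cognee builds a knowledge graph from your documents.")

    # New in this patch: batch_size is forwarded to the Cognify tasks
    # (graph extraction, summarization, add_data_points). Leaving it as None
    # keeps the previous behavior (100 for the default pipeline, 10 for the
    # temporal pipeline).
    await cognee.cognify(batch_size=50)

    # The temporal pipeline takes the same knob (temporal_cognify as in this
    # module's cognify() signature):
    # await cognee.cognify(temporal_cognify=True, batch_size=25)


if __name__ == "__main__":
    asyncio.run(main())
```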