class TaskCheckpoint(DataBaseModel):
    """Checkpoint data for long-running tasks (RAPTOR, GraphRAG).

    One row per task. Aggregate counters are denormalized columns so they can
    be queried cheaply; the authoritative per-document state lives in the
    ``checkpoint_data`` JSON blob, which lets a crashed or paused task resume
    without reprocessing already-completed documents.
    """
    id = CharField(max_length=32, primary_key=True)
    # Associated Task.id; indexed so a checkpoint can be looked up by task.
    task_id = CharField(max_length=32, null=False, index=True, help_text="Associated task ID")
    task_type = CharField(max_length=32, null=False, help_text="Task type: raptor, graphrag")

    # Overall task state
    status = CharField(max_length=16, null=False, default="pending", index=True,
                       help_text="Status: pending, running, paused, completed, failed, cancelled")

    # Document tracking (denormalized from checkpoint_data for fast queries)
    total_documents = IntegerField(default=0, help_text="Total number of documents to process")
    completed_documents = IntegerField(default=0, help_text="Number of completed documents")
    failed_documents = IntegerField(default=0, help_text="Number of failed documents")
    pending_documents = IntegerField(default=0, help_text="Number of pending documents")

    # Progress tracking
    overall_progress = FloatField(default=0.0, help_text="Overall progress (0.0 to 1.0)")
    token_count = IntegerField(default=0, help_text="Total tokens consumed")

    # Checkpoint data (JSON).
    # BUGFIX: use a callable default. A literal ``default={}`` is a single
    # shared dict: every instance created without an explicit value would
    # alias (and mutate) the same object. peewee accepts callables for
    # ``default`` and invokes them per instance.
    checkpoint_data = JSONField(null=False, default=dict, help_text="Detailed checkpoint state")
    # Structure: {
    #   "doc_states": {
    #     "doc_id_1": {"status": "completed", "token_count": 1500, "chunks": 45, "completed_at": "..."},
    #     "doc_id_2": {"status": "failed", "error": "API timeout", "retry_count": 3, "last_attempt": "..."},
    #     "doc_id_3": {"status": "pending"},
    #   },
    #   "config": {...},
    #   "metadata": {...}
    # }

    # Timestamps (all nullable: set as the corresponding transition happens)
    started_at = DateTimeField(null=True, help_text="When task started")
    paused_at = DateTimeField(null=True, help_text="When task was paused")
    resumed_at = DateTimeField(null=True, help_text="When task was resumed")
    completed_at = DateTimeField(null=True, help_text="When task completed")
    last_checkpoint_at = DateTimeField(null=True, index=True, help_text="Last checkpoint save time")

    # Error tracking
    error_message = TextField(null=True, help_text="Error message if failed")
    retry_count = IntegerField(default=0, help_text="Number of retries attempted")

    class Meta:
        # NOTE(review): table creation is assumed to happen via the project's
        # DataBaseModel bootstrap; migrate_db() only adds the new Task columns.
        # Confirm task_checkpoint is created on fresh installs.
        db_table = "task_checkpoint"
"""
Checkpoint service for managing task checkpoints and resume functionality.
"""

import logging
from datetime import datetime
from typing import Optional, Dict, List, Any

from api.db.db_models import TaskCheckpoint
from api.db.services.common_service import CommonService
from api.utils import get_uuid


class CheckpointService(CommonService):
    """Service for managing task checkpoints.

    All public methods are defensive: they catch, log, and convert failures
    into falsy return values so callers in the task executor never crash on
    checkpoint bookkeeping. Counter columns on the row are always recomputed
    from the authoritative ``checkpoint_data["doc_states"]`` map.
    """

    model = TaskCheckpoint

    # ------------------------------------------------------------------ #
    # internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _recount(doc_states: Dict[str, Any]):
        """Recompute (completed, failed, pending, total_tokens) from doc_states.

        Uses .get("status") so a malformed entry degrades to "not counted"
        instead of raising KeyError.
        """
        completed = sum(1 for s in doc_states.values() if s.get("status") == "completed")
        failed = sum(1 for s in doc_states.values() if s.get("status") == "failed")
        pending = sum(1 for s in doc_states.values() if s.get("status") == "pending")
        total_tokens = sum(s.get("token_count", 0) for s in doc_states.values())
        return completed, failed, pending, total_tokens

    @classmethod
    def _transition(cls, checkpoint_id: str, status: str,
                    ts_field: Optional[str], verb: str) -> bool:
        """Set checkpoint.status (and optionally a timestamp column), save, log.

        Shared implementation behind pause/resume/cancel.
        """
        try:
            checkpoint = cls.model.get_by_id(checkpoint_id)
            checkpoint.status = status
            if ts_field:
                setattr(checkpoint, ts_field, datetime.now())
            checkpoint.save()
            logging.info(f"Checkpoint {checkpoint_id} {verb}")
            return True
        except Exception as e:
            logging.error(f"Failed to {verb} checkpoint {checkpoint_id}: {e}")
            return False

    # ------------------------------------------------------------------ #
    # lifecycle
    # ------------------------------------------------------------------ #

    @classmethod
    def create_checkpoint(
        cls,
        task_id: str,
        task_type: str,
        doc_ids: List[str],
        config: Dict[str, Any]
    ) -> TaskCheckpoint:
        """
        Create a new checkpoint for a task.

        Args:
            task_id: Task ID
            task_type: Type of task ("raptor" or "graphrag")
            doc_ids: List of document IDs to process
            config: Task configuration

        Returns:
            Created (and saved) TaskCheckpoint instance
        """
        checkpoint_id = get_uuid()

        # Every document starts out pending with zeroed counters.
        doc_states = {
            doc_id: {"status": "pending", "token_count": 0, "chunks": 0, "retry_count": 0}
            for doc_id in doc_ids
        }

        checkpoint_data = {
            "doc_states": doc_states,
            "config": config,
            "metadata": {"created_at": datetime.now().isoformat()},
        }

        checkpoint = cls.model(
            id=checkpoint_id,
            task_id=task_id,
            task_type=task_type,
            status="pending",
            total_documents=len(doc_ids),
            completed_documents=0,
            failed_documents=0,
            pending_documents=len(doc_ids),
            overall_progress=0.0,
            token_count=0,
            checkpoint_data=checkpoint_data,
            started_at=datetime.now(),
            last_checkpoint_at=datetime.now(),
        )
        checkpoint.save()

        logging.info(f"Created checkpoint {checkpoint_id} for task {task_id} with {len(doc_ids)} documents")
        return checkpoint

    @classmethod
    def get_by_task_id(cls, task_id: str) -> Optional[TaskCheckpoint]:
        """Get checkpoint by task ID, or None when absent / on DB error."""
        try:
            return cls.model.get(cls.model.task_id == task_id)
        except Exception:
            return None

    # ------------------------------------------------------------------ #
    # per-document bookkeeping
    # ------------------------------------------------------------------ #

    @classmethod
    def save_document_completion(
        cls,
        checkpoint_id: str,
        doc_id: str,
        token_count: int = 0,
        chunks: int = 0
    ) -> bool:
        """
        Record successful completion of a single document and refresh counters.

        Args:
            checkpoint_id: Checkpoint ID
            doc_id: Document ID
            token_count: Tokens consumed for this document
            chunks: Number of chunks generated

        Returns:
            True if successful
        """
        try:
            checkpoint = cls.model.get_by_id(checkpoint_id)

            doc_states = checkpoint.checkpoint_data.get("doc_states", {})
            if doc_id in doc_states:
                doc_states[doc_id] = {
                    "status": "completed",
                    "token_count": token_count,
                    "chunks": chunks,
                    "completed_at": datetime.now().isoformat(),
                    # Preserve how many attempts this document needed.
                    "retry_count": doc_states[doc_id].get("retry_count", 0),
                }

            completed, failed, pending, total_tokens = cls._recount(doc_states)
            progress = completed / checkpoint.total_documents if checkpoint.total_documents > 0 else 0.0

            checkpoint.checkpoint_data["doc_states"] = doc_states
            checkpoint.completed_documents = completed
            checkpoint.failed_documents = failed
            checkpoint.pending_documents = pending
            checkpoint.overall_progress = progress
            checkpoint.token_count = total_tokens
            checkpoint.last_checkpoint_at = datetime.now()

            # "completed" here means "no documents left to attempt"; failed
            # documents are still visible via failed_documents / doc_states.
            if pending == 0:
                checkpoint.status = "completed"
                checkpoint.completed_at = datetime.now()

            checkpoint.save()

            logging.info(f"Checkpoint {checkpoint_id}: Document {doc_id} completed ({completed}/{checkpoint.total_documents})")
            return True

        except Exception as e:
            logging.error(f"Failed to save document completion: {e}")
            return False

    @classmethod
    def save_document_failure(
        cls,
        checkpoint_id: str,
        doc_id: str,
        error: str
    ) -> bool:
        """
        Record failure of a single document (incrementing its retry count)
        and refresh counters.

        Args:
            checkpoint_id: Checkpoint ID
            doc_id: Document ID
            error: Error message

        Returns:
            True if successful
        """
        try:
            checkpoint = cls.model.get_by_id(checkpoint_id)

            doc_states = checkpoint.checkpoint_data.get("doc_states", {})
            if doc_id in doc_states:
                retry_count = doc_states[doc_id].get("retry_count", 0) + 1
                doc_states[doc_id] = {
                    "status": "failed",
                    "error": error,
                    "retry_count": retry_count,
                    "last_attempt": datetime.now().isoformat(),
                }

            completed, failed, pending, _ = cls._recount(doc_states)

            checkpoint.checkpoint_data["doc_states"] = doc_states
            checkpoint.completed_documents = completed
            checkpoint.failed_documents = failed
            checkpoint.pending_documents = pending
            checkpoint.last_checkpoint_at = datetime.now()
            checkpoint.save()

            logging.warning(f"Checkpoint {checkpoint_id}: Document {doc_id} failed: {error}")
            return True

        except Exception as e:
            logging.error(f"Failed to save document failure: {e}")
            return False

    @classmethod
    def get_pending_documents(cls, checkpoint_id: str) -> List[str]:
        """Return IDs of documents still pending (empty list on error)."""
        try:
            checkpoint = cls.model.get_by_id(checkpoint_id)
            doc_states = checkpoint.checkpoint_data.get("doc_states", {})
            return [doc_id for doc_id, state in doc_states.items()
                    if state.get("status") == "pending"]
        except Exception as e:
            logging.error(f"Failed to get pending documents: {e}")
            return []

    @classmethod
    def get_failed_documents(cls, checkpoint_id: str) -> List[Dict[str, Any]]:
        """Return failed documents with error details (empty list on error)."""
        try:
            checkpoint = cls.model.get_by_id(checkpoint_id)
            doc_states = checkpoint.checkpoint_data.get("doc_states", {})
            return [
                {
                    "doc_id": doc_id,
                    "error": state.get("error", "Unknown error"),
                    "retry_count": state.get("retry_count", 0),
                    "last_attempt": state.get("last_attempt"),
                }
                for doc_id, state in doc_states.items()
                if state.get("status") == "failed"
            ]
        except Exception as e:
            logging.error(f"Failed to get failed documents: {e}")
            return []

    # ------------------------------------------------------------------ #
    # state transitions & queries
    # ------------------------------------------------------------------ #

    @classmethod
    def pause_checkpoint(cls, checkpoint_id: str) -> bool:
        """Mark checkpoint as paused."""
        return cls._transition(checkpoint_id, "paused", "paused_at", "paused")

    @classmethod
    def resume_checkpoint(cls, checkpoint_id: str) -> bool:
        """Mark checkpoint as running again after a pause."""
        return cls._transition(checkpoint_id, "running", "resumed_at", "resumed")

    @classmethod
    def cancel_checkpoint(cls, checkpoint_id: str) -> bool:
        """Mark checkpoint as cancelled."""
        return cls._transition(checkpoint_id, "cancelled", None, "cancelled")

    @classmethod
    def is_paused(cls, checkpoint_id: str) -> bool:
        """True iff the checkpoint exists and its status is 'paused'."""
        try:
            return cls.model.get_by_id(checkpoint_id).status == "paused"
        except Exception:
            return False

    @classmethod
    def is_cancelled(cls, checkpoint_id: str) -> bool:
        """True iff the checkpoint exists and its status is 'cancelled'."""
        try:
            return cls.model.get_by_id(checkpoint_id).status == "cancelled"
        except Exception:
            return False

    @classmethod
    def should_retry(cls, checkpoint_id: str, doc_id: str, max_retries: int = 3) -> bool:
        """True when the document's retry_count is below max_retries."""
        try:
            checkpoint = cls.model.get_by_id(checkpoint_id)
            doc_states = checkpoint.checkpoint_data.get("doc_states", {})
            if doc_id in doc_states:
                return doc_states[doc_id].get("retry_count", 0) < max_retries
            return False
        except Exception:
            return False

    @classmethod
    def reset_document_for_retry(cls, checkpoint_id: str, doc_id: str) -> bool:
        """Reset a failed document back to pending, keeping its retry count."""
        try:
            checkpoint = cls.model.get_by_id(checkpoint_id)
            doc_states = checkpoint.checkpoint_data.get("doc_states", {})

            if doc_id in doc_states and doc_states[doc_id].get("status") == "failed":
                retry_count = doc_states[doc_id].get("retry_count", 0)
                doc_states[doc_id] = {
                    "status": "pending",
                    "token_count": 0,
                    "chunks": 0,
                    "retry_count": retry_count,  # keep retry count across attempts
                }

                _, failed, pending, _ = cls._recount(doc_states)

                checkpoint.checkpoint_data["doc_states"] = doc_states
                checkpoint.failed_documents = failed
                checkpoint.pending_documents = pending
                checkpoint.save()

                logging.info(f"Reset document {doc_id} for retry (attempt {retry_count + 1})")
                return True
            return False
        except Exception as e:
            logging.error(f"Failed to reset document for retry: {e}")
            return False

    @classmethod
    def get_checkpoint_status(cls, checkpoint_id: str) -> Optional[Dict[str, Any]]:
        """Return a JSON-serializable status snapshot, or None on error."""
        try:
            checkpoint = cls.model.get_by_id(checkpoint_id)

            def iso(dt):
                # Timestamps are nullable columns; serialize only when set.
                return dt.isoformat() if dt else None

            return {
                "checkpoint_id": checkpoint.id,
                "task_id": checkpoint.task_id,
                "task_type": checkpoint.task_type,
                "status": checkpoint.status,
                "progress": checkpoint.overall_progress,
                "total_documents": checkpoint.total_documents,
                "completed_documents": checkpoint.completed_documents,
                "failed_documents": checkpoint.failed_documents,
                "pending_documents": checkpoint.pending_documents,
                "token_count": checkpoint.token_count,
                "started_at": iso(checkpoint.started_at),
                "paused_at": iso(checkpoint.paused_at),
                "resumed_at": iso(checkpoint.resumed_at),
                "completed_at": iso(checkpoint.completed_at),
                "last_checkpoint_at": iso(checkpoint.last_checkpoint_at),
                "error_message": checkpoint.error_message,
            }
        except Exception as e:
            logging.error(f"Failed to get checkpoint status: {e}")
            return None
async def run_raptor_with_checkpoint(task, row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=None):
    """
    Checkpoint-aware RAPTOR execution that processes documents individually.

    This wrapper enables:
    - Per-document checkpointing
    - Pause/resume capability
    - Failure isolation
    - Automatic retry

    Args:
        task: Task row (dict); task["id"] keys the checkpoint.
        row: Passed through unchanged to run_raptor_for_kb.
        kb_parser_config: KB parser config; raptor settings under "raptor".
        chat_mdl / embd_mdl / vector_size: Passed through to run_raptor_for_kb.
        callback: Optional progress callback(prog=..., msg=...).
        doc_ids: Document IDs to process (None/empty -> nothing to do).

    Returns:
        (chunks, token_count) produced in THIS run. BUGFIX: the original
        returned None on the pause/cancel/already-done paths, but
        do_handle_task unpacks ``chunks, token_count = await ...`` and would
        raise TypeError; every path now returns a 2-tuple.
    """
    # BUGFIX: no mutable default argument; normalize to a fresh list.
    doc_ids = list(doc_ids) if doc_ids else []
    task_id = task["id"]
    raptor_config = kb_parser_config.get("raptor", {})

    def report(prog, msg):
        # callback defaults to None -- guard so bookkeeping never crashes.
        if callback:
            callback(prog=prog, msg=msg)

    # Create a checkpoint on first run, or reload it when resuming.
    checkpoint = CheckpointService.get_by_task_id(task_id)
    if not checkpoint:
        checkpoint = CheckpointService.create_checkpoint(
            task_id=task_id,
            task_type="raptor",
            doc_ids=doc_ids,
            config=raptor_config
        )
        logging.info(f"Created new checkpoint for RAPTOR task {task_id}")
    else:
        logging.info(f"Resuming RAPTOR task {task_id} from checkpoint {checkpoint.id}")

    # Already-completed documents are skipped on resume.
    pending_docs = CheckpointService.get_pending_documents(checkpoint.id)
    total_docs = len(doc_ids)

    all_results = []
    total_tokens = 0

    if not pending_docs:
        logging.info(f"All documents already processed for task {task_id}")
        report(1.0, "All documents completed")
        return all_results, total_tokens

    logging.info(f"Processing {len(pending_docs)}/{total_docs} pending documents")

    # Documents already attempted (completed or failed) in earlier runs.
    already_done = total_docs - len(pending_docs)

    for idx, doc_id in enumerate(pending_docs):
        # Fraction of documents attempted so far -- reported on early exits
        # instead of the original prog=0.0, which wiped visible progress.
        done_frac = (already_done + idx) / total_docs if total_docs else 1.0

        # Check for pause/cancel between documents.
        if CheckpointService.is_paused(checkpoint.id):
            logging.info(f"Task {task_id} paused at document {doc_id}")
            report(done_frac, "Task paused")
            return all_results, total_tokens

        if CheckpointService.is_cancelled(checkpoint.id):
            logging.info(f"Task {task_id} cancelled at document {doc_id}")
            report(done_frac, "Task cancelled")
            return all_results, total_tokens

        try:
            logging.info(f"Processing document {doc_id} ({idx + 1}/{len(pending_docs)})")

            # Run RAPTOR for this single document; per-document progress is
            # not forwarded (callback=None), only the outer loop reports.
            results, token_count = await run_raptor_for_kb(
                row, kb_parser_config, chat_mdl, embd_mdl, vector_size,
                callback=None,
                doc_ids=[doc_id]
            )

            all_results.extend(results)
            total_tokens += token_count

            # Persist per-document completion so a crash resumes from here.
            CheckpointService.save_document_completion(
                checkpoint.id,
                doc_id,
                token_count=token_count,
                chunks=len(results)
            )

            completed = already_done + idx + 1
            report(completed / total_docs, f"Completed {completed}/{total_docs} documents")

            logging.info(f"Document {doc_id} completed: {len(results)} chunks, {token_count} tokens")

        except Exception as e:
            error_msg = str(e)
            logging.error(f"Failed to process document {doc_id}: {error_msg}")

            # Record the failure (increments the document's retry count).
            CheckpointService.save_document_failure(
                checkpoint.id,
                doc_id,
                error=error_msg
            )

            if CheckpointService.should_retry(checkpoint.id, doc_id, max_retries=3):
                logging.info(f"Document {doc_id} will be retried later")
            else:
                logging.warning(f"Document {doc_id} exceeded max retries, skipping")

            # Fault tolerance: one bad document never aborts the batch.
            continue

    # Final status.
    failed_docs = CheckpointService.get_failed_documents(checkpoint.id)
    if failed_docs:
        logging.warning(f"Task {task_id} completed with {len(failed_docs)} failed documents")
        report(1.0, f"Completed with {len(failed_docs)} failures")
    else:
        logging.info(f"Task {task_id} completed successfully")
        report(1.0, "All documents completed successfully")

    return all_results, total_tokens
llm_name=task_llm_id, lang=task_language) - # run RAPTOR + + # Check if checkpointing is enabled (default: True for RAPTOR) + use_checkpoints = kb_parser_config.get("raptor", {}).get("use_checkpoints", True) + + # run RAPTOR with or without checkpoints async with kg_limiter: - chunks, token_count = await run_raptor_for_kb( - row=task, - kb_parser_config=kb_parser_config, - chat_mdl=chat_model, - embd_mdl=embedding_model, - vector_size=vector_size, - callback=progress_callback, - doc_ids=task.get("doc_ids", []), - ) + if use_checkpoints: + # Use checkpoint-aware version for fault tolerance + chunks, token_count = await run_raptor_with_checkpoint( + task=task, + row=task, + kb_parser_config=kb_parser_config, + chat_mdl=chat_model, + embd_mdl=embedding_model, + vector_size=vector_size, + callback=progress_callback, + doc_ids=task.get("doc_ids", []), + ) + else: + # Use original version (legacy mode) + chunks, token_count = await run_raptor_for_kb( + row=task, + kb_parser_config=kb_parser_config, + chat_mdl=chat_model, + embd_mdl=embedding_model, + vector_size=vector_size, + callback=progress_callback, + doc_ids=task.get("doc_ids", []), + ) if fake_doc_ids := task.get("doc_ids", []): task_doc_id = fake_doc_ids[0] # use the first document ID to represent this task for logging purposes # Either using graphrag or Standard chunking methods