Add pipeline cancellation feature for graceful processing termination
• Add cancel_pipeline API endpoint • Implement PipelineCancelledException • Add cancellation checks in main loop • Handle task cancellation gracefully • Mark cancelled docs as FAILED
This commit is contained in:
parent
6a29b5daa0
commit
743aefc655
4 changed files with 183 additions and 3 deletions
|
|
@ -161,6 +161,28 @@ class ReprocessResponse(BaseModel):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CancelPipelineResponse(BaseModel):
|
||||||
|
"""Response model for pipeline cancellation operation
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
status: Status of the cancellation request
|
||||||
|
message: Message describing the operation result
|
||||||
|
"""
|
||||||
|
|
||||||
|
status: Literal["cancellation_requested", "not_busy"] = Field(
|
||||||
|
description="Status of the cancellation request"
|
||||||
|
)
|
||||||
|
message: str = Field(description="Human-readable message describing the operation")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"status": "cancellation_requested",
|
||||||
|
"message": "Pipeline cancellation has been requested. Documents will be marked as FAILED.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class InsertTextRequest(BaseModel):
|
class InsertTextRequest(BaseModel):
|
||||||
"""Request model for inserting a single text document
|
"""Request model for inserting a single text document
|
||||||
|
|
||||||
|
|
@ -2754,4 +2776,63 @@ def create_document_routes(
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/cancel_pipeline",
|
||||||
|
response_model=CancelPipelineResponse,
|
||||||
|
dependencies=[Depends(combined_auth)],
|
||||||
|
)
|
||||||
|
async def cancel_pipeline():
|
||||||
|
"""
|
||||||
|
Request cancellation of the currently running pipeline.
|
||||||
|
|
||||||
|
This endpoint sets a cancellation flag in the pipeline status. The pipeline will:
|
||||||
|
1. Check this flag at key processing points
|
||||||
|
2. Stop processing new documents
|
||||||
|
3. Cancel all running document processing tasks
|
||||||
|
4. Mark all PROCESSING documents as FAILED with reason "User cancelled"
|
||||||
|
|
||||||
|
The cancellation is graceful and ensures data consistency. Documents that have
|
||||||
|
completed processing will remain in PROCESSED status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CancelPipelineResponse: Response with status and message
|
||||||
|
- status="cancellation_requested": Cancellation flag has been set
|
||||||
|
- status="not_busy": Pipeline is not currently running
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If an error occurs while setting cancellation flag (500).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from lightrag.kg.shared_storage import (
|
||||||
|
get_namespace_data,
|
||||||
|
get_pipeline_status_lock,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline_status = await get_namespace_data("pipeline_status")
|
||||||
|
pipeline_status_lock = get_pipeline_status_lock()
|
||||||
|
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if not pipeline_status.get("busy", False):
|
||||||
|
return CancelPipelineResponse(
|
||||||
|
status="not_busy",
|
||||||
|
message="Pipeline is not currently running. No cancellation needed.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set cancellation flag
|
||||||
|
pipeline_status["cancellation_requested"] = True
|
||||||
|
cancel_msg = "Pipeline cancellation requested by user"
|
||||||
|
logger.info(cancel_msg)
|
||||||
|
pipeline_status["latest_message"] = cancel_msg
|
||||||
|
pipeline_status["history_messages"].append(cancel_msg)
|
||||||
|
|
||||||
|
return CancelPipelineResponse(
|
||||||
|
status="cancellation_requested",
|
||||||
|
message="Pipeline cancellation has been requested. Documents will be marked as FAILED.",
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error requesting pipeline cancellation: {str(e)}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|
|
||||||
|
|
@ -96,3 +96,11 @@ class PipelineNotInitializedError(KeyError):
|
||||||
f" await initialize_pipeline_status()"
|
f" await initialize_pipeline_status()"
|
||||||
)
|
)
|
||||||
super().__init__(msg)
|
super().__init__(msg)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineCancelledException(Exception):
|
||||||
|
"""Raised when pipeline processing is cancelled by user request."""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "User cancelled"):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
)
|
)
|
||||||
from lightrag.prompt import PROMPTS
|
from lightrag.prompt import PROMPTS
|
||||||
|
from lightrag.exceptions import PipelineCancelledException
|
||||||
from lightrag.constants import (
|
from lightrag.constants import (
|
||||||
DEFAULT_MAX_GLEANING,
|
DEFAULT_MAX_GLEANING,
|
||||||
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
|
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
|
||||||
|
|
@ -1603,6 +1604,7 @@ class LightRAG:
|
||||||
"batchs": 0, # Total number of files to be processed
|
"batchs": 0, # Total number of files to be processed
|
||||||
"cur_batch": 0, # Number of files already processed
|
"cur_batch": 0, # Number of files already processed
|
||||||
"request_pending": False, # Clear any previous request
|
"request_pending": False, # Clear any previous request
|
||||||
|
"cancellation_requested": False, # Initialize cancellation flag
|
||||||
"latest_message": "",
|
"latest_message": "",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -1619,6 +1621,22 @@ class LightRAG:
|
||||||
try:
|
try:
|
||||||
# Process documents until no more documents or requests
|
# Process documents until no more documents or requests
|
||||||
while True:
|
while True:
|
||||||
|
# Check for cancellation request at the start of main loop
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
# Clear pending request
|
||||||
|
pipeline_status["request_pending"] = False
|
||||||
|
# Clear cancellation flag
|
||||||
|
pipeline_status["cancellation_requested"] = False
|
||||||
|
|
||||||
|
log_message = "Pipeline cancelled by user"
|
||||||
|
logger.info(log_message)
|
||||||
|
pipeline_status["latest_message"] = log_message
|
||||||
|
pipeline_status["history_messages"].append(log_message)
|
||||||
|
|
||||||
|
# Exit directly, skipping request_pending check
|
||||||
|
return
|
||||||
|
|
||||||
if not to_process_docs:
|
if not to_process_docs:
|
||||||
log_message = "All enqueued documents have been processed"
|
log_message = "All enqueued documents have been processed"
|
||||||
logger.info(log_message)
|
logger.info(log_message)
|
||||||
|
|
@ -1689,6 +1707,11 @@ class LightRAG:
|
||||||
first_stage_tasks = []
|
first_stage_tasks = []
|
||||||
entity_relation_task = None
|
entity_relation_task = None
|
||||||
try:
|
try:
|
||||||
|
# Check for cancellation before starting document processing
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException("User cancelled")
|
||||||
|
|
||||||
# Get file path from status document
|
# Get file path from status document
|
||||||
file_path = getattr(
|
file_path = getattr(
|
||||||
status_doc, "file_path", "unknown_source"
|
status_doc, "file_path", "unknown_source"
|
||||||
|
|
@ -1751,6 +1774,11 @@ class LightRAG:
|
||||||
# Record processing start time
|
# Record processing start time
|
||||||
processing_start_time = int(time.time())
|
processing_start_time = int(time.time())
|
||||||
|
|
||||||
|
# Check for cancellation before entity extraction
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException("User cancelled")
|
||||||
|
|
||||||
# Process document in two stages
|
# Process document in two stages
|
||||||
# Stage 1: Process text chunks and docs (parallel execution)
|
# Stage 1: Process text chunks and docs (parallel execution)
|
||||||
doc_status_task = asyncio.create_task(
|
doc_status_task = asyncio.create_task(
|
||||||
|
|
@ -1856,6 +1884,15 @@ class LightRAG:
|
||||||
# Concurrency is controlled by keyed lock for individual entities and relationships
|
# Concurrency is controlled by keyed lock for individual entities and relationships
|
||||||
if file_extraction_stage_ok:
|
if file_extraction_stage_ok:
|
||||||
try:
|
try:
|
||||||
|
# Check for cancellation before merge
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get(
|
||||||
|
"cancellation_requested", False
|
||||||
|
):
|
||||||
|
raise PipelineCancelledException(
|
||||||
|
"User cancelled"
|
||||||
|
)
|
||||||
|
|
||||||
# Get chunk_results from entity_relation_task
|
# Get chunk_results from entity_relation_task
|
||||||
chunk_results = await entity_relation_task
|
chunk_results = await entity_relation_task
|
||||||
await merge_nodes_and_edges(
|
await merge_nodes_and_edges(
|
||||||
|
|
@ -1970,7 +2007,19 @@ class LightRAG:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for all document processing to complete
|
# Wait for all document processing to complete
|
||||||
await asyncio.gather(*doc_tasks)
|
try:
|
||||||
|
await asyncio.gather(*doc_tasks)
|
||||||
|
except PipelineCancelledException:
|
||||||
|
# Cancel all remaining tasks
|
||||||
|
for task in doc_tasks:
|
||||||
|
if not task.done():
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
# Wait for all tasks to complete cancellation
|
||||||
|
await asyncio.wait(doc_tasks, return_when=asyncio.ALL_COMPLETED)
|
||||||
|
|
||||||
|
# Exit directly (document statuses already updated in process_document)
|
||||||
|
return
|
||||||
|
|
||||||
# Check if there's a pending request to process more documents (with lock)
|
# Check if there's a pending request to process more documents (with lock)
|
||||||
has_pending_request = False
|
has_pending_request = False
|
||||||
|
|
@ -2001,11 +2050,14 @@ class LightRAG:
|
||||||
to_process_docs.update(pending_docs)
|
to_process_docs.update(pending_docs)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
log_message = "Enqueued document processing pipeline stoped"
|
log_message = "Enqueued document processing pipeline stopped"
|
||||||
logger.info(log_message)
|
logger.info(log_message)
|
||||||
# Always reset busy status when done or if an exception occurs (with lock)
|
# Always reset busy status and cancellation flag when done or if an exception occurs (with lock)
|
||||||
async with pipeline_status_lock:
|
async with pipeline_status_lock:
|
||||||
pipeline_status["busy"] = False
|
pipeline_status["busy"] = False
|
||||||
|
pipeline_status["cancellation_requested"] = (
|
||||||
|
False # Always reset cancellation flag
|
||||||
|
)
|
||||||
pipeline_status["latest_message"] = log_message
|
pipeline_status["latest_message"] = log_message
|
||||||
pipeline_status["history_messages"].append(log_message)
|
pipeline_status["history_messages"].append(log_message)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import json_repair
|
||||||
from typing import Any, AsyncIterator, overload, Literal
|
from typing import Any, AsyncIterator, overload, Literal
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
|
|
||||||
|
from lightrag.exceptions import PipelineCancelledException
|
||||||
from lightrag.utils import (
|
from lightrag.utils import (
|
||||||
logger,
|
logger,
|
||||||
compute_mdhash_id,
|
compute_mdhash_id,
|
||||||
|
|
@ -2214,6 +2215,12 @@ async def merge_nodes_and_edges(
|
||||||
file_path: File path for logging
|
file_path: File path for logging
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Check for cancellation at the start of merge
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException("User cancelled during merge phase")
|
||||||
|
|
||||||
# Collect all nodes and edges from all chunks
|
# Collect all nodes and edges from all chunks
|
||||||
all_nodes = defaultdict(list)
|
all_nodes = defaultdict(list)
|
||||||
all_edges = defaultdict(list)
|
all_edges = defaultdict(list)
|
||||||
|
|
@ -2250,6 +2257,14 @@ async def merge_nodes_and_edges(
|
||||||
|
|
||||||
async def _locked_process_entity_name(entity_name, entities):
|
async def _locked_process_entity_name(entity_name, entities):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
|
# Check for cancellation before processing entity
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException(
|
||||||
|
"User cancelled during entity merge"
|
||||||
|
)
|
||||||
|
|
||||||
workspace = global_config.get("workspace", "")
|
workspace = global_config.get("workspace", "")
|
||||||
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
||||||
async with get_storage_keyed_lock(
|
async with get_storage_keyed_lock(
|
||||||
|
|
@ -2349,6 +2364,14 @@ async def merge_nodes_and_edges(
|
||||||
|
|
||||||
async def _locked_process_edges(edge_key, edges):
|
async def _locked_process_edges(edge_key, edges):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
|
# Check for cancellation before processing edges
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException(
|
||||||
|
"User cancelled during relation merge"
|
||||||
|
)
|
||||||
|
|
||||||
workspace = global_config.get("workspace", "")
|
workspace = global_config.get("workspace", "")
|
||||||
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
|
||||||
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
|
sorted_edge_key = sorted([edge_key[0], edge_key[1]])
|
||||||
|
|
@ -2535,6 +2558,14 @@ async def extract_entities(
|
||||||
llm_response_cache: BaseKVStorage | None = None,
|
llm_response_cache: BaseKVStorage | None = None,
|
||||||
text_chunks_storage: BaseKVStorage | None = None,
|
text_chunks_storage: BaseKVStorage | None = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
|
# Check for cancellation at the start of entity extraction
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException(
|
||||||
|
"User cancelled during entity extraction"
|
||||||
|
)
|
||||||
|
|
||||||
use_llm_func: callable = global_config["llm_model_func"]
|
use_llm_func: callable = global_config["llm_model_func"]
|
||||||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||||||
|
|
||||||
|
|
@ -2702,6 +2733,14 @@ async def extract_entities(
|
||||||
|
|
||||||
async def _process_with_semaphore(chunk):
|
async def _process_with_semaphore(chunk):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
|
# Check for cancellation before processing chunk
|
||||||
|
if pipeline_status is not None and pipeline_status_lock is not None:
|
||||||
|
async with pipeline_status_lock:
|
||||||
|
if pipeline_status.get("cancellation_requested", False):
|
||||||
|
raise PipelineCancelledException(
|
||||||
|
"User cancelled during chunk processing"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await _process_single_content(chunk)
|
return await _process_single_content(chunk)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue