diff --git a/env.example b/env.example
index 828f962e..3c5113ff 100644
--- a/env.example
+++ b/env.example
@@ -208,6 +208,7 @@ OPENAI_LLM_MAX_COMPLETION_TOKENS=9000
 # OPENAI_LLM_EXTRA_BODY='{"chat_template_kwargs": {"enable_thinking": false}}'

 ### use the following command to see all support options for Ollama LLM
+### If LightRAG is deployed in Docker, use host.docker.internal instead of localhost in LLM_BINDING_HOST
 ### lightrag-server --llm-binding ollama --help
 ### Ollama Server Specific Parameters
 ### OLLAMA_LLM_NUM_CTX must be provided, and should at least larger than MAX_TOTAL_TOKENS + 2000
@@ -229,7 +230,7 @@ EMBEDDING_BINDING=ollama
 EMBEDDING_MODEL=bge-m3:latest
 EMBEDDING_DIM=1024
 EMBEDDING_BINDING_API_KEY=your_api_key
-# If the embedding service is deployed within the same Docker stack, use host.docker.internal instead of localhost
+# If LightRAG is deployed in Docker, use host.docker.internal instead of localhost
 EMBEDDING_BINDING_HOST=http://localhost:11434

 ### OpenAI compatible (VoyageAI embedding openai compatible)
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index c38b09e5..1d044256 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.4.9.4"
+__version__ = "1.4.9.5"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py
index 6268052f..de364382 100644
--- a/lightrag/api/__init__.py
+++ b/lightrag/api/__init__.py
@@ -1 +1 @@
-__api_version__ = "0243"
+__api_version__ = "0245"
diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py
index c33296c0..51b911ec 100644
--- a/lightrag/api/routers/document_routes.py
+++ b/lightrag/api/routers/document_routes.py
@@ -161,6 +161,28 @@ class ReprocessResponse(BaseModel):
         }


+class CancelPipelineResponse(BaseModel):
+    """Response model for pipeline cancellation operation
+
+    Attributes:
+        status: Status of the cancellation request
+        message: Message describing the operation result
+    """
+
+    status: Literal["cancellation_requested", "not_busy"] = Field(
+        description="Status of the cancellation request"
+    )
+    message: str = Field(description="Human-readable message describing the operation")
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "status": "cancellation_requested",
+                "message": "Pipeline cancellation has been requested. Documents will be marked as FAILED.",
+            }
+        }
+
+
 class InsertTextRequest(BaseModel):
     """Request model for inserting a single text document

@@ -458,7 +480,7 @@ class DocsStatusesResponse(BaseModel):
                         "id": "doc_789",
                         "content_summary": "Document pending final indexing",
                         "content_length": 7200,
-                        "status": "multimodal_processed",
+                        "status": "preprocessed",
                         "created_at": "2025-03-31T09:30:00",
                         "updated_at": "2025-03-31T09:35:00",
                         "track_id": "upload_20250331_093000_xyz789",
@@ -1525,7 +1547,19 @@ async def background_delete_documents(
     try:
         # Loop through each document ID and delete them one by one
         for i, doc_id in enumerate(doc_ids, 1):
+            # Check for cancellation at the start of each document deletion
             async with pipeline_status_lock:
+                if pipeline_status.get("cancellation_requested", False):
+                    cancel_msg = f"Deletion cancelled by user at document {i}/{total_docs}. {len(successful_deletions)} deleted, {total_docs - i + 1} remaining."
+ logger.info(cancel_msg) + pipeline_status["latest_message"] = cancel_msg + pipeline_status["history_messages"].append(cancel_msg) + # Add remaining documents to failed list with cancellation reason + failed_deletions.extend( + doc_ids[i - 1 :] + ) # i-1 because enumerate starts at 1 + break # Exit the loop, remaining documents unchanged + start_msg = f"Deleting document {i}/{total_docs}: {doc_id}" logger.info(start_msg) pipeline_status["cur_batch"] = i @@ -1688,6 +1722,10 @@ async def background_delete_documents( # Final summary and check for pending requests async with pipeline_status_lock: pipeline_status["busy"] = False + pipeline_status["pending_requests"] = False # Reset pending requests flag + pipeline_status["cancellation_requested"] = ( + False # Always reset cancellation flag + ) completion_msg = f"Deletion completed: {len(successful_deletions)} successful, {len(failed_deletions)} failed" pipeline_status["latest_message"] = completion_msg pipeline_status["history_messages"].append(completion_msg) @@ -2221,7 +2259,7 @@ def create_document_routes( logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) - # TODO: Deprecated + # TODO: Deprecated, use /documents/paginated instead @router.get( "", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)] ) @@ -2745,4 +2783,63 @@ def create_document_routes( logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + @router.post( + "/cancel_pipeline", + response_model=CancelPipelineResponse, + dependencies=[Depends(combined_auth)], + ) + async def cancel_pipeline(): + """ + Request cancellation of the currently running pipeline. + + This endpoint sets a cancellation flag in the pipeline status. The pipeline will: + 1. Check this flag at key processing points + 2. Stop processing new documents + 3. Cancel all running document processing tasks + 4. Mark all PROCESSING documents as FAILED with reason "User cancelled" + + The cancellation is graceful and ensures data consistency. Documents that have + completed processing will remain in PROCESSED status. + + Returns: + CancelPipelineResponse: Response with status and message + - status="cancellation_requested": Cancellation flag has been set + - status="not_busy": Pipeline is not currently running + + Raises: + HTTPException: If an error occurs while setting cancellation flag (500). + """ + try: + from lightrag.kg.shared_storage import ( + get_namespace_data, + get_pipeline_status_lock, + ) + + pipeline_status = await get_namespace_data("pipeline_status") + pipeline_status_lock = get_pipeline_status_lock() + + async with pipeline_status_lock: + if not pipeline_status.get("busy", False): + return CancelPipelineResponse( + status="not_busy", + message="Pipeline is not currently running. No cancellation needed.", + ) + + # Set cancellation flag + pipeline_status["cancellation_requested"] = True + cancel_msg = "Pipeline cancellation requested by user" + logger.info(cancel_msg) + pipeline_status["latest_message"] = cancel_msg + pipeline_status["history_messages"].append(cancel_msg) + + return CancelPipelineResponse( + status="cancellation_requested", + message="Pipeline cancellation has been requested. 
Documents will be marked as FAILED.", + ) + + except Exception as e: + logger.error(f"Error requesting pipeline cancellation: {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + return router diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 53cc41c0..f0ee0e98 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -73,6 +73,16 @@ class QueryRequest(BaseModel): ge=1, ) + hl_keywords: list[str] = Field( + default_factory=list, + description="List of high-level keywords to prioritize in retrieval. Leave empty to use the LLM to generate the keywords.", + ) + + ll_keywords: list[str] = Field( + default_factory=list, + description="List of low-level keywords to refine retrieval focus. Leave empty to use the LLM to generate the keywords.", + ) + conversation_history: Optional[List[Dict[str, Any]]] = Field( default=None, description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].", @@ -294,6 +304,16 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): } ``` + Bypass initial LLM call by providing high-level and low-level keywords: + ```json + { + "query": "What is Retrieval-Augmented-Generation?", + "hl_keywords": ["machine learning", "information retrieval", "natural language processing"], + "ll_keywords": ["retrieval augmented generation", "RAG", "knowledge base"], + "mode": "mix" + } + ``` + Advanced query with references: ```json { @@ -482,6 +502,16 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): } ``` + Bypass initial LLM call by providing high-level and low-level keywords: + ```json + { + "query": "What is Retrieval-Augmented-Generation?", + "hl_keywords": ["machine learning", "information retrieval", "natural language processing"], + "ll_keywords": ["retrieval augmented generation", "RAG", "knowledge base"], + "mode": "mix" + } + ``` + Complete response query: ```json { @@ -968,6 +998,16 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): } ``` + Bypass initial LLM call by providing high-level and low-level keywords: + ```json + { + "query": "What is Retrieval-Augmented-Generation?", + "hl_keywords": ["machine learning", "information retrieval", "natural language processing"], + "ll_keywords": ["retrieval augmented generation", "RAG", "knowledge base"], + "mode": "mix" + } + ``` + **Response Analysis:** - **Empty arrays**: Normal for certain modes (e.g., naive mode has no entities/relationships) - **Processing info**: Shows retrieval statistics and token usage diff --git a/lightrag/base.py b/lightrag/base.py index e569de2a..3cf40136 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -720,7 +720,7 @@ class DocStatus(str, Enum): PENDING = "pending" PROCESSING = "processing" - PREPROCESSED = "multimodal_processed" + PREPROCESSED = "preprocessed" PROCESSED = "processed" FAILED = "failed" @@ -751,6 +751,25 @@ class DocProcessingStatus: """Error message if failed""" metadata: dict[str, Any] = field(default_factory=dict) """Additional metadata""" + multimodal_processed: bool | None = field(default=None, repr=False) + """Internal field: indicates if multimodal processing is complete. Not shown in repr() but accessible for debugging.""" + + def __post_init__(self): + """ + Handle status conversion based on multimodal_processed field. 
+ + Business rules: + - If multimodal_processed is False and status is PROCESSED, + then change status to PREPROCESSED + - The multimodal_processed field is kept (with repr=False) for internal use and debugging + """ + # Apply status conversion logic + if self.multimodal_processed is not None: + if ( + self.multimodal_processed is False + and self.status == DocStatus.PROCESSED + ): + self.status = DocStatus.PREPROCESSED @dataclass diff --git a/lightrag/exceptions.py b/lightrag/exceptions.py index d57df1ac..09e1d0e7 100644 --- a/lightrag/exceptions.py +++ b/lightrag/exceptions.py @@ -96,3 +96,11 @@ class PipelineNotInitializedError(KeyError): f" await initialize_pipeline_status()" ) super().__init__(msg) + + +class PipelineCancelledException(Exception): + """Raised when pipeline processing is cancelled by user request.""" + + def __init__(self, message: str = "User cancelled"): + super().__init__(message) + self.message = message diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index cbe2cf82..edb7983c 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -462,14 +462,37 @@ class MilvusVectorDBStorage(BaseVectorStorage): if type_name in ["FloatVector", "FLOAT_VECTOR"]: existing_dimension = field.get("params", {}).get("dim") - if existing_dimension != current_dimension: + # Convert both to int for comparison to handle type mismatches + # (Milvus API may return string "1024" vs int 1024) + try: + existing_dim_int = ( + int(existing_dimension) + if existing_dimension is not None + else None + ) + current_dim_int = ( + int(current_dimension) + if current_dimension is not None + else None + ) + except (TypeError, ValueError) as e: + logger.error( + f"[{self.workspace}] Failed to parse dimensions: existing={existing_dimension} (type={type(existing_dimension)}), " + f"current={current_dimension} (type={type(current_dimension)}), error={e}" + ) + raise ValueError( + f"Invalid dimension values for collection '{self.final_namespace}': " + f"existing={existing_dimension}, current={current_dimension}" + ) from e + + if existing_dim_int != current_dim_int: raise ValueError( f"Vector dimension mismatch for collection '{self.final_namespace}': " - f"existing={existing_dimension}, current={current_dimension}" + f"existing={existing_dim_int}, current={current_dim_int}" ) logger.debug( - f"[{self.workspace}] Vector dimension check passed: {current_dimension}" + f"[{self.workspace}] Vector dimension check passed: {current_dim_int}" ) return diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index db26d1c1..723de69f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -4613,16 +4613,19 @@ class PGGraphStorage(BaseGraphStorage): Returns: A list of all nodes, where each node is a dictionary of its properties """ - query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n:base) - RETURN n - $$) AS (n agtype)""" + # Use native SQL to avoid Cypher wrapper overhead + # Original: SELECT * FROM cypher(...) 
with MATCH (n:base)
+        # Optimized: Direct table access for better performance
+        query = f"""
+            SELECT properties
+            FROM {self.graph_name}.base
+        """

         results = await self._query(query)
         nodes = []
         for result in results:
-            if result["n"]:
-                node_dict = result["n"]["properties"]
+            if result.get("properties"):
+                node_dict = result["properties"]

                 # Process string result, parse it to JSON dictionary
                 if isinstance(node_dict, str):
@@ -4632,6 +4635,7 @@ class PGGraphStorage(BaseGraphStorage):
                         logger.warning(
                             f"[{self.workspace}] Failed to parse node string: {node_dict}"
                         )
+                        continue

                 # Add node id (entity_id) to the dictionary for easier access
                 node_dict["id"] = node_dict.get("entity_id")
@@ -4643,12 +4647,21 @@ class PGGraphStorage(BaseGraphStorage):

         Returns:
             A list of all edges, where each edge is a dictionary of its properties
-            (The edge is bidirectional; deduplication must be handled by the caller)
+            (If two directed edges exist between the same pair of nodes, deduplication must be handled by the caller)
+        """
+        # Use native SQL to avoid Cartesian product (N×N) in Cypher MATCH
+        # Original Cypher: MATCH (a:base)-[r]-(b:base) creates ~50 billion row combinations
+        # Optimized: Start from edges table, join to nodes only to get entity_id
+        # Performance: O(E) instead of O(N²), ~50,000x faster for large graphs
+        query = f"""
+            SELECT DISTINCT
+                (ag_catalog.agtype_access_operator(VARIADIC ARRAY[a.properties, '"entity_id"'::agtype]))::text AS source,
+                (ag_catalog.agtype_access_operator(VARIADIC ARRAY[b.properties, '"entity_id"'::agtype]))::text AS target,
+                r.properties
+            FROM {self.graph_name}."DIRECTED" r
+            JOIN {self.graph_name}.base a ON r.start_id = a.id
+            JOIN {self.graph_name}.base b ON r.end_id = b.id
         """
-        query = f"""SELECT * FROM cypher('{self.graph_name}', $$
-              MATCH (a:base)-[r]-(b:base)
-              RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties
-            $$) AS (source text, target text, properties agtype)"""

         results = await self._query(query)
         edges = []
diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index e20dce52..33d43bfa 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -12,15 +12,15 @@ from lightrag.exceptions import PipelineNotInitializedError


 # Define a direct print function for critical logs that must be visible in all processes
-def direct_log(message, enable_output: bool = False, level: str = "DEBUG"):
+def direct_log(message, enable_output: bool = True, level: str = "DEBUG"):
     """
     Log a message directly to stderr to ensure visibility in all processes,
     including the Gunicorn master process.

     Args:
         message: The message to log
-        level: Log level (default: "DEBUG")
-        enable_output: Whether to actually output the log (default: True)
+        level: Log level of the message (its visibility is controlled by comparing it with the current logger level)
+        enable_output: Enable or disable this log message (set to False to force the message off)
     """
     if not enable_output:
         return
@@ -44,7 +44,6 @@ def direct_log(message, enable_output: bool = False, level: str = "DEBUG"):
     }

     message_level = level_mapping.get(level.upper(), logging.DEBUG)
-    # print(f"Diret_log: {level.upper()} {message_level} ?
{current_level}", file=sys.stderr, flush=True) if message_level >= current_level: print(f"{level}: {message}", file=sys.stderr, flush=True) @@ -141,7 +140,8 @@ class UnifiedLock(Generic[T]): if not self._is_async and self._async_lock is not None: await self._async_lock.acquire() direct_log( - f"== Lock == Process {self._pid}: Async lock for '{self._name}' acquired", + f"== Lock == Process {self._pid}: Acquired async lock '{self._name}", + level="DEBUG", enable_output=self._enable_logging, ) @@ -152,7 +152,8 @@ class UnifiedLock(Generic[T]): self._lock.acquire() direct_log( - f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})", + f"== Lock == Process {self._pid}: Acquired lock {self._name} (async={self._is_async})", + level="INFO", enable_output=self._enable_logging, ) return self @@ -168,7 +169,7 @@ class UnifiedLock(Generic[T]): direct_log( f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", level="ERROR", - enable_output=self._enable_logging, + enable_output=True, ) raise @@ -183,7 +184,8 @@ class UnifiedLock(Generic[T]): main_lock_released = True direct_log( - f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})", + f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})", + level="INFO", enable_output=self._enable_logging, ) @@ -191,7 +193,8 @@ class UnifiedLock(Generic[T]): if not self._is_async and self._async_lock is not None: self._async_lock.release() direct_log( - f"== Lock == Process {self._pid}: Async lock '{self._name}' released", + f"== Lock == Process {self._pid}: Released async lock {self._name}", + level="DEBUG", enable_output=self._enable_logging, ) @@ -199,7 +202,7 @@ class UnifiedLock(Generic[T]): direct_log( f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}", level="ERROR", - enable_output=self._enable_logging, + enable_output=True, ) # If main lock release failed but async lock hasn't been released, try to release it @@ -211,19 +214,20 @@ class UnifiedLock(Generic[T]): try: direct_log( f"== Lock == Process {self._pid}: Attempting to release async lock after main lock failure", - level="WARNING", + level="DEBUG", enable_output=self._enable_logging, ) self._async_lock.release() direct_log( f"== Lock == Process {self._pid}: Successfully released async lock after main lock failure", + level="INFO", enable_output=self._enable_logging, ) except Exception as inner_e: direct_log( f"== Lock == Process {self._pid}: Failed to release async lock after main lock failure: {inner_e}", level="ERROR", - enable_output=self._enable_logging, + enable_output=True, ) raise @@ -234,12 +238,14 @@ class UnifiedLock(Generic[T]): if self._is_async: raise RuntimeError("Use 'async with' for shared_storage lock") direct_log( - f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)", + f"== Lock == Process {self._pid}: Acquiring lock {self._name} (sync)", + level="DEBUG", enable_output=self._enable_logging, ) self._lock.acquire() direct_log( - f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)", + f"== Lock == Process {self._pid}: Acquired lock {self._name} (sync)", + level="INFO", enable_output=self._enable_logging, ) return self @@ -247,7 +253,7 @@ class UnifiedLock(Generic[T]): direct_log( f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}", level="ERROR", - enable_output=self._enable_logging, + enable_output=True, ) raise @@ -258,18 +264,20 @@ class UnifiedLock(Generic[T]): 
raise RuntimeError("Use 'async with' for shared_storage lock") direct_log( f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)", + level="DEBUG", enable_output=self._enable_logging, ) self._lock.release() direct_log( - f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)", + f"== Lock == Process {self._pid}: Released lock {self._name} (sync)", + level="INFO", enable_output=self._enable_logging, ) except Exception as e: direct_log( f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}", level="ERROR", - enable_output=self._enable_logging, + enable_output=True, ) raise @@ -401,7 +409,7 @@ def _perform_lock_cleanup( direct_log( f"== {lock_type} Lock == Cleanup failed: {e}", level="ERROR", - enable_output=False, + enable_output=True, ) return 0, earliest_cleanup_time, last_cleanup_time @@ -689,7 +697,7 @@ class KeyedUnifiedLock: direct_log( f"Error during multiprocess lock cleanup: {e}", level="ERROR", - enable_output=False, + enable_output=True, ) # 2. Cleanup async locks using generic function @@ -718,7 +726,7 @@ class KeyedUnifiedLock: direct_log( f"Error during async lock cleanup: {e}", level="ERROR", - enable_output=False, + enable_output=True, ) # 3. Get current status after cleanup @@ -772,7 +780,7 @@ class KeyedUnifiedLock: direct_log( f"Error getting keyed lock status: {e}", level="ERROR", - enable_output=False, + enable_output=True, ) return status @@ -797,32 +805,239 @@ class _KeyedLockContext: if enable_logging is not None else parent._default_enable_logging ) - self._ul: Optional[List["UnifiedLock"]] = None # set in __aenter__ + self._ul: Optional[List[Dict[str, Any]]] = None # set in __aenter__ # ----- enter ----- async def __aenter__(self): if self._ul is not None: raise RuntimeError("KeyedUnifiedLock already acquired in current context") - # acquire locks for all keys in the namespace self._ul = [] - for key in self._keys: - lock = self._parent._get_lock_for_key( - self._namespace, key, enable_logging=self._enable_logging - ) - await lock.__aenter__() - inc_debug_n_locks_acquired() - self._ul.append(lock) - return self + + try: + # Acquire locks for all keys in the namespace + for key in self._keys: + lock = None + entry = None + + try: + # 1. Get lock object (reference count is incremented here) + lock = self._parent._get_lock_for_key( + self._namespace, key, enable_logging=self._enable_logging + ) + + # 2. Immediately create and add entry to list (critical for rollback to work) + entry = { + "key": key, + "lock": lock, + "entered": False, + "debug_inc": False, + "ref_incremented": True, # Mark that reference count has been incremented + } + self._ul.append( + entry + ) # Add immediately after _get_lock_for_key for rollback to work + + # 3. 
Try to acquire the lock + # Use try-finally to ensure state is updated atomically + lock_acquired = False + try: + await lock.__aenter__() + lock_acquired = True # Lock successfully acquired + finally: + if lock_acquired: + entry["entered"] = True + inc_debug_n_locks_acquired() + entry["debug_inc"] = True + + except asyncio.CancelledError: + # Lock acquisition was cancelled + # The finally block above ensures entry["entered"] is correct + direct_log( + f"Lock acquisition cancelled for key {key}", + level="WARNING", + enable_output=self._enable_logging, + ) + raise + except Exception as e: + # Other exceptions, log and re-raise + direct_log( + f"Lock acquisition failed for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + raise + + return self + + except BaseException: + # Critical: if any exception occurs (including CancelledError) during lock acquisition, + # we must rollback all already acquired locks to prevent lock leaks + # Use shield to ensure rollback completes + await asyncio.shield(self._rollback_acquired_locks()) + raise + + async def _rollback_acquired_locks(self): + """Rollback all acquired locks in case of exception during __aenter__""" + if not self._ul: + return + + async def rollback_single_entry(entry): + """Rollback a single lock acquisition""" + key = entry["key"] + lock = entry["lock"] + debug_inc = entry["debug_inc"] + entered = entry["entered"] + ref_incremented = entry.get( + "ref_incremented", True + ) # Default to True for safety + + errors = [] + + # 1. If lock was acquired, release it + if entered: + try: + await lock.__aexit__(None, None, None) + except Exception as e: + errors.append(("lock_exit", e)) + direct_log( + f"Lock rollback error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + # 2. Release reference count (if it was incremented) + if ref_incremented: + try: + self._parent._release_lock_for_key(self._namespace, key) + except Exception as e: + errors.append(("ref_release", e)) + direct_log( + f"Lock rollback reference release error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + # 3. Decrement debug counter + if debug_inc: + try: + dec_debug_n_locks_acquired() + except Exception as e: + errors.append(("debug_dec", e)) + direct_log( + f"Lock rollback counter decrementing error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + return errors + + # Release already acquired locks in reverse order + for entry in reversed(self._ul): + # Use shield to protect each lock's rollback + try: + await asyncio.shield(rollback_single_entry(entry)) + except Exception as e: + # Log but continue rolling back other locks + direct_log( + f"Lock rollback unexpected error for {entry['key']}: {e}", + level="ERROR", + enable_output=True, + ) + + self._ul = None # ----- exit ----- async def __aexit__(self, exc_type, exc, tb): - # The UnifiedLock takes care of proper release order - for ul, key in zip(reversed(self._ul), reversed(self._keys)): - await ul.__aexit__(exc_type, exc, tb) - self._parent._release_lock_for_key(self._namespace, key) - dec_debug_n_locks_acquired() - self._ul = None + if self._ul is None: + return + + async def release_all_locks(): + """Release all locks with comprehensive error handling, protected from cancellation""" + + async def release_single_entry(entry, exc_type, exc, tb): + """Release a single lock with full protection""" + key = entry["key"] + lock = entry["lock"] + debug_inc = entry["debug_inc"] + entered = entry["entered"] + + errors = [] + + # 1. 
Release the lock + if entered: + try: + await lock.__aexit__(exc_type, exc, tb) + except Exception as e: + errors.append(("lock_exit", e)) + direct_log( + f"Lock release error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + # 2. Release reference count + try: + self._parent._release_lock_for_key(self._namespace, key) + except Exception as e: + errors.append(("ref_release", e)) + direct_log( + f"Lock release reference error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + # 3. Decrement debug counter + if debug_inc: + try: + dec_debug_n_locks_acquired() + except Exception as e: + errors.append(("debug_dec", e)) + direct_log( + f"Lock release counter decrementing error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + return errors + + all_errors = [] + + # Release locks in reverse order + # This entire loop is protected by the outer shield + for entry in reversed(self._ul): + try: + errors = await release_single_entry(entry, exc_type, exc, tb) + for error_type, error in errors: + all_errors.append((entry["key"], error_type, error)) + except Exception as e: + all_errors.append((entry["key"], "unexpected", e)) + direct_log( + f"Lock release unexpected error for {entry['key']}: {e}", + level="ERROR", + enable_output=True, + ) + + return all_errors + + # CRITICAL: Protect the entire release process with shield + # This ensures that even if cancellation occurs, all locks are released + try: + all_errors = await asyncio.shield(release_all_locks()) + except Exception as e: + direct_log( + f"Critical error during __aexit__ cleanup: {e}", + level="ERROR", + enable_output=True, + ) + all_errors = [] + finally: + # Always clear the lock list, even if shield was cancelled + self._ul = None + + # If there were release errors and no other exception, raise the first release error + if all_errors and exc_type is None: + raise all_errors[0][2] # (key, error_type, error) def get_internal_lock(enable_logging: bool = False) -> UnifiedLock: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 3b82d718..a5326b5b 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -22,6 +22,7 @@ from typing import ( Dict, ) from lightrag.prompt import PROMPTS +from lightrag.exceptions import PipelineCancelledException from lightrag.constants import ( DEFAULT_MAX_GLEANING, DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, @@ -86,7 +87,7 @@ from lightrag.operate import ( merge_nodes_and_edges, kg_query, naive_query, - _rebuild_knowledge_from_chunks, + rebuild_knowledge_from_chunks, ) from lightrag.constants import GRAPH_FIELD_SEP from lightrag.utils import ( @@ -711,7 +712,7 @@ class LightRAG: async def check_and_migrate_data(self): """Check if data migration is needed and perform migration if necessary""" - async with get_data_init_lock(enable_logging=True): + async with get_data_init_lock(): try: # Check if migration is needed: # 1. 
chunk_entity_relation_graph has entities and relations (count > 0) @@ -1605,6 +1606,7 @@ class LightRAG: "batchs": 0, # Total number of files to be processed "cur_batch": 0, # Number of files already processed "request_pending": False, # Clear any previous request + "cancellation_requested": False, # Initialize cancellation flag "latest_message": "", } ) @@ -1621,6 +1623,22 @@ class LightRAG: try: # Process documents until no more documents or requests while True: + # Check for cancellation request at the start of main loop + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + # Clear pending request + pipeline_status["request_pending"] = False + # Celar cancellation flag + pipeline_status["cancellation_requested"] = False + + log_message = "Pipeline cancelled by user" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + # Exit directly, skipping request_pending check + return + if not to_process_docs: log_message = "All enqueued documents have been processed" logger.info(log_message) @@ -1683,14 +1701,25 @@ class LightRAG: semaphore: asyncio.Semaphore, ) -> None: """Process single document""" + # Initialize variables at the start to prevent UnboundLocalError in error handling + file_path = "unknown_source" + current_file_number = 0 file_extraction_stage_ok = False + processing_start_time = int(time.time()) + first_stage_tasks = [] + entity_relation_task = None + async with semaphore: nonlocal processed_count - current_file_number = 0 # Initialize to prevent UnboundLocalError in error handling first_stage_tasks = [] entity_relation_task = None try: + # Check for cancellation before starting document processing + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException("User cancelled") + # Get file path from status document file_path = getattr( status_doc, "file_path", "unknown_source" @@ -1753,6 +1782,11 @@ class LightRAG: # Record processing start time processing_start_time = int(time.time()) + # Check for cancellation before entity extraction + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException("User cancelled") + # Process document in two stages # Stage 1: Process text chunks and docs (parallel execution) doc_status_task = asyncio.create_task( @@ -1807,16 +1841,29 @@ class LightRAG: file_extraction_stage_ok = True except Exception as e: - # Log error and update pipeline status - logger.error(traceback.format_exc()) - error_msg = f"Failed to extract document {current_file_number}/{total_files}: {file_path}" - logger.error(error_msg) - async with pipeline_status_lock: - pipeline_status["latest_message"] = error_msg - pipeline_status["history_messages"].append( - traceback.format_exc() - ) - pipeline_status["history_messages"].append(error_msg) + # Check if this is a user cancellation + if isinstance(e, PipelineCancelledException): + # User cancellation - log brief message only, no traceback + error_msg = f"User cancelled {current_file_number}/{total_files}: {file_path}" + logger.warning(error_msg) + async with pipeline_status_lock: + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append( + error_msg + ) + else: + # Other exceptions - log with traceback + logger.error(traceback.format_exc()) + error_msg = f"Failed to extract document {current_file_number}/{total_files}: {file_path}" 
+ logger.error(error_msg) + async with pipeline_status_lock: + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append( + traceback.format_exc() + ) + pipeline_status["history_messages"].append( + error_msg + ) # Cancel tasks that are not yet completed all_tasks = first_stage_tasks + ( @@ -1826,9 +1873,14 @@ class LightRAG: if task and not task.done(): task.cancel() - # Persistent llm cache + # Persistent llm cache with error handling if self.llm_response_cache: - await self.llm_response_cache.index_done_callback() + try: + await self.llm_response_cache.index_done_callback() + except Exception as persist_error: + logger.error( + f"Failed to persist LLM cache: {persist_error}" + ) # Record processing end time for failed case processing_end_time = int(time.time()) @@ -1858,6 +1910,15 @@ class LightRAG: # Concurrency is controlled by keyed lock for individual entities and relationships if file_extraction_stage_ok: try: + # Check for cancellation before merge + async with pipeline_status_lock: + if pipeline_status.get( + "cancellation_requested", False + ): + raise PipelineCancelledException( + "User cancelled" + ) + # Get chunk_results from entity_relation_task chunk_results = await entity_relation_task await merge_nodes_and_edges( @@ -1916,22 +1977,38 @@ class LightRAG: ) except Exception as e: - # Log error and update pipeline status - logger.error(traceback.format_exc()) - error_msg = f"Merging stage failed in document {current_file_number}/{total_files}: {file_path}" - logger.error(error_msg) - async with pipeline_status_lock: - pipeline_status["latest_message"] = error_msg - pipeline_status["history_messages"].append( - traceback.format_exc() - ) - pipeline_status["history_messages"].append( - error_msg - ) + # Check if this is a user cancellation + if isinstance(e, PipelineCancelledException): + # User cancellation - log brief message only, no traceback + error_msg = f"User cancelled during merge {current_file_number}/{total_files}: {file_path}" + logger.warning(error_msg) + async with pipeline_status_lock: + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append( + error_msg + ) + else: + # Other exceptions - log with traceback + logger.error(traceback.format_exc()) + error_msg = f"Merging stage failed in document {current_file_number}/{total_files}: {file_path}" + logger.error(error_msg) + async with pipeline_status_lock: + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append( + traceback.format_exc() + ) + pipeline_status["history_messages"].append( + error_msg + ) - # Persistent llm cache + # Persistent llm cache with error handling if self.llm_response_cache: - await self.llm_response_cache.index_done_callback() + try: + await self.llm_response_cache.index_done_callback() + except Exception as persist_error: + logger.error( + f"Failed to persist LLM cache: {persist_error}" + ) # Record processing end time for failed case processing_end_time = int(time.time()) @@ -1972,7 +2049,19 @@ class LightRAG: ) # Wait for all document processing to complete - await asyncio.gather(*doc_tasks) + try: + await asyncio.gather(*doc_tasks) + except PipelineCancelledException: + # Cancel all remaining tasks + for task in doc_tasks: + if not task.done(): + task.cancel() + + # Wait for all tasks to complete cancellation + await asyncio.wait(doc_tasks, return_when=asyncio.ALL_COMPLETED) + + # Exit directly (document statuses already updated in process_document) + return # Check if there's a 
pending request to process more documents (with lock) has_pending_request = False @@ -2003,11 +2092,14 @@ class LightRAG: to_process_docs.update(pending_docs) finally: - log_message = "Enqueued document processing pipeline stoped" + log_message = "Enqueued document processing pipeline stopped" logger.info(log_message) - # Always reset busy status when done or if an exception occurs (with lock) + # Always reset busy status and cancellation flag when done or if an exception occurs (with lock) async with pipeline_status_lock: pipeline_status["busy"] = False + pipeline_status["cancellation_requested"] = ( + False # Always reset cancellation flag + ) pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) @@ -3264,6 +3356,10 @@ class LightRAG: list(entities_to_delete) ) + # Delete from entity_chunks storage + if self.entity_chunks: + await self.entity_chunks.delete(list(entities_to_delete)) + async with pipeline_status_lock: log_message = f"Successfully deleted {len(entities_to_delete)} entities" logger.info(log_message) @@ -3293,6 +3389,14 @@ class LightRAG: list(relationships_to_delete) ) + # Delete from relation_chunks storage + if self.relation_chunks: + relation_storage_keys = [ + make_relation_chunk_key(src, tgt) + for src, tgt in relationships_to_delete + ] + await self.relation_chunks.delete(relation_storage_keys) + async with pipeline_status_lock: log_message = f"Successfully deleted {len(relationships_to_delete)} relations" logger.info(log_message) @@ -3309,7 +3413,7 @@ class LightRAG: # 8. Rebuild entities and relationships from remaining chunks if entities_to_rebuild or relationships_to_rebuild: try: - await _rebuild_knowledge_from_chunks( + await rebuild_knowledge_from_chunks( entities_to_rebuild=entities_to_rebuild, relationships_to_rebuild=relationships_to_rebuild, knowledge_graph_inst=self.chunk_entity_relation_graph, diff --git a/lightrag/operate.py b/lightrag/operate.py index 7be89c90..a1a0063f 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1,5 +1,6 @@ from __future__ import annotations from functools import partial +from pathlib import Path import asyncio import json @@ -7,6 +8,7 @@ import json_repair from typing import Any, AsyncIterator, overload, Literal from collections import Counter, defaultdict +from lightrag.exceptions import PipelineCancelledException from lightrag.utils import ( logger, compute_mdhash_id, @@ -69,7 +71,7 @@ from dotenv import load_dotenv # use the .env that is inside the current folder # allows to use different .env file for each lightrag instance # the OS environment variables take precedence over the .env file -load_dotenv(dotenv_path=".env", override=False) +load_dotenv(dotenv_path=Path(__file__).resolve().parent / ".env", override=False) def _truncate_entity_identifier( @@ -502,7 +504,7 @@ async def _handle_single_relationship_extraction( return None -async def _rebuild_knowledge_from_chunks( +async def rebuild_knowledge_from_chunks( entities_to_rebuild: dict[str, list[str]], relationships_to_rebuild: dict[tuple[str, str], list[str]], knowledge_graph_inst: BaseGraphStorage, @@ -710,6 +712,7 @@ async def _rebuild_knowledge_from_chunks( await _rebuild_single_relationship( knowledge_graph_inst=knowledge_graph_inst, relationships_vdb=relationships_vdb, + entities_vdb=entities_vdb, src=src, tgt=tgt, chunk_ids=chunk_ids, @@ -717,13 +720,14 @@ async def _rebuild_knowledge_from_chunks( llm_response_cache=llm_response_cache, global_config=global_config, 
relation_chunks_storage=relation_chunks_storage, + entity_chunks_storage=entity_chunks_storage, pipeline_status=pipeline_status, pipeline_status_lock=pipeline_status_lock, ) rebuilt_relationships_count += 1 except Exception as e: failed_relationships_count += 1 - status_message = f"Failed to rebuild `{src} - {tgt}`: {e}" + status_message = f"Failed to rebuild `{src}`~`{tgt}`: {e}" logger.info(status_message) # Per requirement, change to info if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: @@ -1292,6 +1296,7 @@ async def _rebuild_single_entity( async def _rebuild_single_relationship( knowledge_graph_inst: BaseGraphStorage, relationships_vdb: BaseVectorStorage, + entities_vdb: BaseVectorStorage, src: str, tgt: str, chunk_ids: list[str], @@ -1299,6 +1304,7 @@ async def _rebuild_single_relationship( llm_response_cache: BaseKVStorage, global_config: dict[str, str], relation_chunks_storage: BaseKVStorage | None = None, + entity_chunks_storage: BaseKVStorage | None = None, pipeline_status: dict | None = None, pipeline_status_lock=None, ) -> None: @@ -1428,6 +1434,10 @@ async def _rebuild_single_relationship( else: truncation_info = "" + # Sort src and tgt to ensure consistent ordering (smaller string first) + if src > tgt: + src, tgt = tgt, src + # Update relationship in graph storage updated_relationship_data = { **current_relationship, @@ -1442,6 +1452,63 @@ async def _rebuild_single_relationship( else current_relationship.get("file_path", "unknown_source"), "truncate": truncation_info, } + + # Ensure both endpoint nodes exist before writing the edge back + # (certain storage backends require pre-existing nodes). + node_description = ( + updated_relationship_data["description"] + if updated_relationship_data.get("description") + else current_relationship.get("description", "") + ) + node_source_id = updated_relationship_data.get("source_id", "") + node_file_path = updated_relationship_data.get("file_path", "unknown_source") + + for node_id in {src, tgt}: + if not (await knowledge_graph_inst.has_node(node_id)): + node_created_at = int(time.time()) + node_data = { + "entity_id": node_id, + "source_id": node_source_id, + "description": node_description, + "entity_type": "UNKNOWN", + "file_path": node_file_path, + "created_at": node_created_at, + "truncate": "", + } + await knowledge_graph_inst.upsert_node(node_id, node_data=node_data) + + # Update entity_chunks_storage for the newly created entity + if entity_chunks_storage is not None and limited_chunk_ids: + await entity_chunks_storage.upsert( + { + node_id: { + "chunk_ids": limited_chunk_ids, + "count": len(limited_chunk_ids), + } + } + ) + + # Update entity_vdb for the newly created entity + if entities_vdb is not None: + entity_vdb_id = compute_mdhash_id(node_id, prefix="ent-") + entity_content = f"{node_id}\n{node_description}" + vdb_data = { + entity_vdb_id: { + "content": entity_content, + "entity_name": node_id, + "source_id": node_source_id, + "entity_type": "UNKNOWN", + "file_path": node_file_path, + } + } + await safe_vdb_operation_with_exception( + operation=lambda payload=vdb_data: entities_vdb.upsert(payload), + operation_name="rebuild_added_entity_upsert", + entity_name=node_id, + max_retries=3, + retry_delay=0.1, + ) + await knowledge_graph_inst.upsert_edge(src, tgt, updated_relationship_data) # Update relationship in vector database @@ -1487,7 +1554,7 @@ async def _rebuild_single_relationship( raise # Re-raise exception # Log rebuild completion with truncation info - 
status_message = f"Rebuild `{src} - {tgt}` from {len(chunk_ids)} chunks" + status_message = f"Rebuild `{src}`~`{tgt}` from {len(chunk_ids)} chunks" if truncation_info: status_message += f" ({truncation_info})" # Add truncation info from apply_source_ids_limit if truncation occurred @@ -1639,6 +1706,12 @@ async def _merge_nodes_then_upsert( logger.error(f"Entity {entity_name} has no description") raise ValueError(f"Entity {entity_name} has no description") + # Check for cancellation before LLM summary + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException("User cancelled during entity summary") + # 8. Get summary description an LLM usage status description, llm_was_used = await _handle_entity_relation_summary( "Entity", @@ -1791,6 +1864,7 @@ async def _merge_edges_then_upsert( llm_response_cache: BaseKVStorage | None = None, added_entities: list = None, # New parameter to track entities added during edge processing relation_chunks_storage: BaseKVStorage | None = None, + entity_chunks_storage: BaseKVStorage | None = None, ): if src_id == tgt_id: return None @@ -1959,6 +2033,14 @@ async def _merge_edges_then_upsert( logger.error(f"Relation {src_id}~{tgt_id} has no description") raise ValueError(f"Relation {src_id}~{tgt_id} has no description") + # Check for cancellation before LLM summary + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during relation summary" + ) + # 8. Get summary description an LLM usage status description, llm_was_used = await _handle_entity_relation_summary( "Relation", @@ -2065,6 +2147,10 @@ async def _merge_edges_then_upsert( else: logger.debug(status_message) + # Sort src_id and tgt_id to ensure consistent ordering (smaller string first) + if src_id > tgt_id: + src_id, tgt_id = tgt_id, src_id + # 11. 
Update both graph and vector db for need_insert_id in [src_id, tgt_id]: if not (await knowledge_graph_inst.has_node(need_insert_id)): @@ -2080,6 +2166,19 @@ async def _merge_edges_then_upsert( } await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data) + # Update entity_chunks_storage for the newly created entity + if entity_chunks_storage is not None: + chunk_ids = [chunk_id for chunk_id in full_source_ids if chunk_id] + if chunk_ids: + await entity_chunks_storage.upsert( + { + need_insert_id: { + "chunk_ids": chunk_ids, + "count": len(chunk_ids), + } + } + ) + if entity_vdb is not None: entity_vdb_id = compute_mdhash_id(need_insert_id, prefix="ent-") entity_content = f"{need_insert_id}\n{description}" @@ -2216,6 +2315,12 @@ async def merge_nodes_and_edges( file_path: File path for logging """ + # Check for cancellation at the start of merge + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException("User cancelled during merge phase") + # Collect all nodes and edges from all chunks all_nodes = defaultdict(list) all_edges = defaultdict(list) @@ -2252,6 +2357,14 @@ async def merge_nodes_and_edges( async def _locked_process_entity_name(entity_name, entities): async with semaphore: + # Check for cancellation before processing entity + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during entity merge" + ) + workspace = global_config.get("workspace", "") namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" async with get_storage_keyed_lock( @@ -2274,9 +2387,7 @@ async def merge_nodes_and_edges( return entity_data except Exception as e: - error_msg = ( - f"Critical error in entity processing for `{entity_name}`: {e}" - ) + error_msg = f"Error processing entity `{entity_name}`: {e}" logger.error(error_msg) # Try to update pipeline status, but don't let status update failure affect main exception @@ -2312,36 +2423,32 @@ async def merge_nodes_and_edges( entity_tasks, return_when=asyncio.FIRST_EXCEPTION ) - # Check if any task raised an exception and ensure all exceptions are retrieved first_exception = None - successful_results = [] + processed_entities = [] for task in done: try: - exception = task.exception() - if exception is not None: - if first_exception is None: - first_exception = exception - else: - successful_results.append(task.result()) - except Exception as e: + result = task.result() + except BaseException as e: if first_exception is None: first_exception = e + else: + processed_entities.append(result) + + if pending: + for task in pending: + task.cancel() + pending_results = await asyncio.gather(*pending, return_exceptions=True) + for result in pending_results: + if isinstance(result, BaseException): + if first_exception is None: + first_exception = result + else: + processed_entities.append(result) - # If any task failed, cancel all pending tasks and raise the first exception if first_exception is not None: - # Cancel all pending tasks - for pending_task in pending: - pending_task.cancel() - # Wait for cancellation to complete - if pending: - await asyncio.wait(pending) - # Re-raise the first exception to notify the caller raise first_exception - # If all tasks completed successfully, collect results - processed_entities = [task.result() for task in 
entity_tasks] - # ===== Phase 2: Process all relationships concurrently ===== log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})" logger.info(log_message) @@ -2351,6 +2458,14 @@ async def merge_nodes_and_edges( async def _locked_process_edges(edge_key, edges): async with semaphore: + # Check for cancellation before processing edges + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during relation merge" + ) + workspace = global_config.get("workspace", "") namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" sorted_edge_key = sorted([edge_key[0], edge_key[1]]) @@ -2377,6 +2492,7 @@ async def merge_nodes_and_edges( llm_response_cache, added_entities, # Pass list to collect added entities relation_chunks_storage, + entity_chunks_storage, # Add entity_chunks_storage parameter ) if edge_data is None: @@ -2385,7 +2501,7 @@ async def merge_nodes_and_edges( return edge_data, added_entities except Exception as e: - error_msg = f"Critical error in relationship processing for `{sorted_edge_key}`: {e}" + error_msg = f"Error processing relation `{sorted_edge_key}`: {e}" logger.error(error_msg) # Try to update pipeline status, but don't let status update failure affect main exception @@ -2423,40 +2539,36 @@ async def merge_nodes_and_edges( edge_tasks, return_when=asyncio.FIRST_EXCEPTION ) - # Check if any task raised an exception and ensure all exceptions are retrieved first_exception = None - successful_results = [] for task in done: try: - exception = task.exception() - if exception is not None: - if first_exception is None: - first_exception = exception - else: - successful_results.append(task.result()) - except Exception as e: + edge_data, added_entities = task.result() + except BaseException as e: if first_exception is None: first_exception = e + else: + if edge_data is not None: + processed_edges.append(edge_data) + all_added_entities.extend(added_entities) + + if pending: + for task in pending: + task.cancel() + pending_results = await asyncio.gather(*pending, return_exceptions=True) + for result in pending_results: + if isinstance(result, BaseException): + if first_exception is None: + first_exception = result + else: + edge_data, added_entities = result + if edge_data is not None: + processed_edges.append(edge_data) + all_added_entities.extend(added_entities) - # If any task failed, cancel all pending tasks and raise the first exception if first_exception is not None: - # Cancel all pending tasks - for pending_task in pending: - pending_task.cancel() - # Wait for cancellation to complete - if pending: - await asyncio.wait(pending) - # Re-raise the first exception to notify the caller raise first_exception - # If all tasks completed successfully, collect results - for task in edge_tasks: - edge_data, added_entities = task.result() - if edge_data is not None: - processed_edges.append(edge_data) - all_added_entities.extend(added_entities) - # ===== Phase 3: Update full_entities and full_relations storage ===== if full_entities_storage and full_relations_storage and doc_id: try: @@ -2537,6 +2649,14 @@ async def extract_entities( llm_response_cache: BaseKVStorage | None = None, text_chunks_storage: BaseKVStorage | None = None, ) -> list: + # Check for cancellation at the start of entity extraction + if pipeline_status is not None and pipeline_status_lock is not 
None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during entity extraction" + ) + use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -2704,6 +2824,14 @@ async def extract_entities( async def _process_with_semaphore(chunk): async with semaphore: + # Check for cancellation before processing chunk + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during chunk processing" + ) + try: return await _process_single_content(chunk) except Exception as e: diff --git a/lightrag_webui/src/api/lightrag.ts b/lightrag_webui/src/api/lightrag.ts index cf9a7e7a..7a268642 100644 --- a/lightrag_webui/src/api/lightrag.ts +++ b/lightrag_webui/src/api/lightrag.ts @@ -167,7 +167,7 @@ export type DeleteDocResponse = { doc_id: string } -export type DocStatus = 'pending' | 'processing' | 'multimodal_processed' | 'processed' | 'failed' +export type DocStatus = 'pending' | 'processing' | 'preprocessed' | 'processed' | 'failed' export type DocStatusResponse = { id: string @@ -242,6 +242,7 @@ export type PipelineStatusResponse = { batchs: number cur_batch: number request_pending: boolean + cancellation_requested?: boolean latest_message: string history_messages?: string[] update_status?: Record @@ -691,6 +692,14 @@ export const getPipelineStatus = async (): Promise => { return response.data } +export const cancelPipeline = async (): Promise<{ + status: 'cancellation_requested' | 'not_busy' + message: string +}> => { + const response = await axiosInstance.post('/documents/cancel_pipeline') + return response.data +} + export const loginToServer = async (username: string, password: string): Promise => { const formData = new FormData(); formData.append('username', username); diff --git a/lightrag_webui/src/components/documents/PipelineStatusDialog.tsx b/lightrag_webui/src/components/documents/PipelineStatusDialog.tsx index 2a2c5d93..c368d69c 100644 --- a/lightrag_webui/src/components/documents/PipelineStatusDialog.tsx +++ b/lightrag_webui/src/components/documents/PipelineStatusDialog.tsx @@ -11,7 +11,7 @@ import { DialogDescription } from '@/components/ui/Dialog' import Button from '@/components/ui/Button' -import { getPipelineStatus, PipelineStatusResponse } from '@/api/lightrag' +import { getPipelineStatus, cancelPipeline, PipelineStatusResponse } from '@/api/lightrag' import { errorMessage } from '@/lib/utils' import { cn } from '@/lib/utils' @@ -30,6 +30,7 @@ export default function PipelineStatusDialog({ const [status, setStatus] = useState(null) const [position, setPosition] = useState('center') const [isUserScrolled, setIsUserScrolled] = useState(false) + const [showCancelConfirm, setShowCancelConfirm] = useState(false) const historyRef = useRef(null) // Reset position when dialog opens @@ -37,6 +38,9 @@ export default function PipelineStatusDialog({ if (open) { setPosition('center') setIsUserScrolled(false) + } else { + // Reset confirmation dialog state when main dialog closes + setShowCancelConfirm(false) } }, [open]) @@ -81,6 +85,24 @@ export default function PipelineStatusDialog({ return () => clearInterval(interval) }, [open, t]) + // Handle cancel pipeline confirmation + const handleConfirmCancel = async () => { + setShowCancelConfirm(false) + try { + const 
result = await cancelPipeline() + if (result.status === 'cancellation_requested') { + toast.success(t('documentPanel.pipelineStatus.cancelSuccess')) + } else if (result.status === 'not_busy') { + toast.info(t('documentPanel.pipelineStatus.cancelNotBusy')) + } + } catch (err) { + toast.error(t('documentPanel.pipelineStatus.cancelFailed', { error: errorMessage(err) })) + } + } + + // Determine if cancel button should be enabled + const canCancel = status?.busy === true && !status?.cancellation_requested + return ( - {/* Pipeline Status */} -
-
-
{t('documentPanel.pipelineStatus.busy')}:
-
-
-
-
{t('documentPanel.pipelineStatus.requestPending')}:
-
+ {/* Pipeline Status - with cancel button */} +
+ {/* Left side: Status indicators */} +
+
+
{t('documentPanel.pipelineStatus.busy')}:
+
+
+
+
{t('documentPanel.pipelineStatus.requestPending')}:
+
+
+ {/* Only show cancellation status when it's requested */} + {status?.cancellation_requested && ( +
+
{t('documentPanel.pipelineStatus.cancellationRequested')}:
+
+
+ )}
+ + {/* Right side: Cancel button - only show when pipeline is busy */} + {status?.busy && ( + + )}
{/* Job Information */} @@ -172,31 +221,49 @@ export default function PipelineStatusDialog({
- {/* Latest Message */} -
-
{t('documentPanel.pipelineStatus.latestMessage')}:
-
- {status?.latest_message || '-'} -
-
- {/* History Messages */}
-
{t('documentPanel.pipelineStatus.historyMessages')}:
+
{t('documentPanel.pipelineStatus.pipelineMessages')}:
{status?.history_messages?.length ? ( status.history_messages.map((msg, idx) => ( -
{msg}
+
{msg}
)) ) : '-'}
+ + {/* Cancel Confirmation Dialog */} + + + + {t('documentPanel.pipelineStatus.cancelConfirmTitle')} + + {t('documentPanel.pipelineStatus.cancelConfirmDescription')} + + +
+ + +
+
+
) } diff --git a/lightrag_webui/src/features/DocumentManager.tsx b/lightrag_webui/src/features/DocumentManager.tsx index 530e98c7..406faf2b 100644 --- a/lightrag_webui/src/features/DocumentManager.tsx +++ b/lightrag_webui/src/features/DocumentManager.tsx @@ -21,7 +21,6 @@ import PaginationControls from '@/components/ui/PaginationControls' import { scanNewDocuments, - reprocessFailedDocuments, getDocumentsPaginated, DocsStatusesResponse, DocStatus, @@ -52,7 +51,7 @@ const getCountValue = (counts: Record, ...keys: string[]): numbe const hasActiveDocumentsStatus = (counts: Record): boolean => getCountValue(counts, 'PROCESSING', 'processing') > 0 || getCountValue(counts, 'PENDING', 'pending') > 0 || - getCountValue(counts, 'PREPROCESSED', 'preprocessed', 'multimodal_processed') > 0 + getCountValue(counts, 'PREPROCESSED', 'preprocessed') > 0 const getDisplayFileName = (doc: DocStatusResponse, maxLength: number = 20): string => { // Check if file_path exists and is a non-empty string @@ -257,7 +256,7 @@ export default function DocumentManager() { const [pageByStatus, setPageByStatus] = useState>({ all: 1, processed: 1, - multimodal_processed: 1, + preprocessed: 1, processing: 1, pending: 1, failed: 1, @@ -324,7 +323,7 @@ export default function DocumentManager() { setPageByStatus({ all: 1, processed: 1, - 'multimodal_processed': 1, + preprocessed: 1, processing: 1, pending: 1, failed: 1, @@ -471,8 +470,8 @@ export default function DocumentManager() { const processedCount = getCountValue(statusCounts, 'PROCESSED', 'processed') || documentCounts.processed || 0; const preprocessedCount = - getCountValue(statusCounts, 'PREPROCESSED', 'preprocessed', 'multimodal_processed') || - documentCounts.multimodal_processed || + getCountValue(statusCounts, 'PREPROCESSED', 'preprocessed') || + documentCounts.preprocessed || 0; const processingCount = getCountValue(statusCounts, 'PROCESSING', 'processing') || documentCounts.processing || 0; const pendingCount = getCountValue(statusCounts, 'PENDING', 'pending') || documentCounts.pending || 0; @@ -481,7 +480,7 @@ export default function DocumentManager() { // Store previous status counts const prevStatusCounts = useRef({ processed: 0, - multimodal_processed: 0, + preprocessed: 0, processing: 0, pending: 0, failed: 0 @@ -572,7 +571,7 @@ export default function DocumentManager() { const legacyDocs: DocsStatusesResponse = { statuses: { processed: response.documents.filter((doc: DocStatusResponse) => doc.status === 'processed'), - multimodal_processed: response.documents.filter((doc: DocStatusResponse) => doc.status === 'multimodal_processed'), + preprocessed: response.documents.filter((doc: DocStatusResponse) => doc.status === 'preprocessed'), processing: response.documents.filter((doc: DocStatusResponse) => doc.status === 'processing'), pending: response.documents.filter((doc: DocStatusResponse) => doc.status === 'pending'), failed: response.documents.filter((doc: DocStatusResponse) => doc.status === 'failed') @@ -868,42 +867,6 @@ export default function DocumentManager() { } }, [t, startPollingInterval, currentTab, health, statusCounts]) - const retryFailedDocuments = useCallback(async () => { - try { - // Check if component is still mounted before starting the request - if (!isMountedRef.current) return; - - const { status, message, track_id: _track_id } = await reprocessFailedDocuments(); // eslint-disable-line @typescript-eslint/no-unused-vars - - // Check again if component is still mounted after the request completes - if (!isMountedRef.current) return; - - // 
Note: _track_id is available for future use (e.g., progress tracking) - toast.message(message || status); - - // Reset health check timer with 1 second delay to avoid race condition - useBackendState.getState().resetHealthCheckTimerDelayed(1000); - - // Start fast refresh with 2-second interval immediately after retry - startPollingInterval(2000); - - // Set recovery timer to restore normal polling interval after 15 seconds - setTimeout(() => { - if (isMountedRef.current && currentTab === 'documents' && health) { - // Restore intelligent polling interval based on document status - const hasActiveDocuments = hasActiveDocumentsStatus(statusCounts); - const normalInterval = hasActiveDocuments ? 5000 : 30000; - startPollingInterval(normalInterval); - } - }, 15000); // Restore after 15 seconds - } catch (err) { - // Only show error if component is still mounted - if (isMountedRef.current) { - toast.error(errorMessage(err)); - } - } - }, [startPollingInterval, currentTab, health, statusCounts]) - // Handle page size change - update state and save to store const handlePageSizeChange = useCallback((newPageSize: number) => { if (newPageSize === pagination.page_size) return; @@ -915,7 +878,7 @@ export default function DocumentManager() { setPageByStatus({ all: 1, processed: 1, - multimodal_processed: 1, + preprocessed: 1, processing: 1, pending: 1, failed: 1, @@ -956,7 +919,7 @@ export default function DocumentManager() { const legacyDocs: DocsStatusesResponse = { statuses: { processed: response.documents.filter(doc => doc.status === 'processed'), - multimodal_processed: response.documents.filter(doc => doc.status === 'multimodal_processed'), + preprocessed: response.documents.filter(doc => doc.status === 'preprocessed'), processing: response.documents.filter(doc => doc.status === 'processing'), pending: response.documents.filter(doc => doc.status === 'pending'), failed: response.documents.filter(doc => doc.status === 'failed') @@ -1032,7 +995,7 @@ export default function DocumentManager() { // Get new status counts const newStatusCounts = { processed: docs?.statuses?.processed?.length || 0, - multimodal_processed: docs?.statuses?.multimodal_processed?.length || 0, + preprocessed: docs?.statuses?.preprocessed?.length || 0, processing: docs?.statuses?.processing?.length || 0, pending: docs?.statuses?.pending?.length || 0, failed: docs?.statuses?.failed?.length || 0 @@ -1166,16 +1129,6 @@ export default function DocumentManager() { > {t('documentPanel.documentManager.scanButton')} -
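Note: the sketch below is a minimal client-side example of the new cancellation endpoint introduced above (POST /documents/cancel_pipeline). The base URL, port, and the X-API-Key header are assumptions for a typical deployment and should be adjusted to the actual server address and auth configuration; the two response statuses are the ones defined by CancelPipelineResponse.

```python
import requests

# Assumed deployment values; adjust to your LightRAG server and auth setup.
LIGHTRAG_URL = "http://localhost:9621"
HEADERS = {"X-API-Key": "your_api_key"}  # omit if the server runs without an API key

# Ask the server to cancel the currently running pipeline.
resp = requests.post(
    f"{LIGHTRAG_URL}/documents/cancel_pipeline", headers=HEADERS, timeout=30
)
resp.raise_for_status()
result = resp.json()  # CancelPipelineResponse: {"status": ..., "message": ...}

if result["status"] == "cancellation_requested":
    # The flag is set: in-flight documents will be marked FAILED ("User cancelled"),
    # while documents that already finished processing stay PROCESSED.
    print(result["message"])
else:  # "not_busy"
    # The pipeline was idle; nothing to cancel.
    print(result["message"])
```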