From a9ec15e669f45d2f9bb5c438a9a0ad96d64570b0 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 25 Oct 2025 03:06:45 +0800
Subject: [PATCH] Resolve lock leakage issue during user cancellation handling
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Change default log level to INFO
• Force enable error logging output
• Add lock cleanup rollback protection
• Handle LLM cache persistence errors
• Fix async task exception handling
---
 lightrag/kg/shared_storage.py | 260 ++++++++++++++++++++++++++++++----
 lightrag/lightrag.py          |  18 ++-
 lightrag/operate.py           |  86 +++++------
 3 files changed, 285 insertions(+), 79 deletions(-)

diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index e20dce52..26fc3832 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -12,7 +12,7 @@ from lightrag.exceptions import PipelineNotInitializedError


 # Define a direct print function for critical logs that must be visible in all processes
-def direct_log(message, enable_output: bool = False, level: str = "DEBUG"):
+def direct_log(message, enable_output: bool = False, level: str = "INFO"):
     """
     Log a message directly to stderr to ensure visibility in all processes,
     including the Gunicorn master process.
@@ -44,7 +44,6 @@ def direct_log(message, enable_output: bool = False, level: str = "DEBUG"):
     }
     message_level = level_mapping.get(level.upper(), logging.DEBUG)

-    # print(f"Diret_log: {level.upper()} {message_level} ? {current_level}", file=sys.stderr, flush=True)
     if message_level >= current_level:
         print(f"{level}: {message}", file=sys.stderr, flush=True)

@@ -168,7 +167,7 @@ class UnifiedLock(Generic[T]):
             direct_log(
                 f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
                 level="ERROR",
-                enable_output=self._enable_logging,
+                enable_output=True,
             )
             raise

@@ -199,7 +198,7 @@ class UnifiedLock(Generic[T]):
             direct_log(
                 f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
                 level="ERROR",
-                enable_output=self._enable_logging,
+                enable_output=True,
             )

             # If main lock release failed but async lock hasn't been released, try to release it
@@ -223,7 +222,7 @@ class UnifiedLock(Generic[T]):
                     direct_log(
                         f"== Lock == Process {self._pid}: Failed to release async lock after main lock failure: {inner_e}",
                         level="ERROR",
-                        enable_output=self._enable_logging,
+                        enable_output=True,
                     )
             raise

@@ -247,7 +246,7 @@ class UnifiedLock(Generic[T]):
             direct_log(
                 f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
                 level="ERROR",
-                enable_output=self._enable_logging,
+                enable_output=True,
             )
             raise

@@ -269,7 +268,7 @@ class UnifiedLock(Generic[T]):
             direct_log(
                 f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
                 level="ERROR",
-                enable_output=self._enable_logging,
+                enable_output=True,
             )
             raise

@@ -401,7 +400,7 @@ def _perform_lock_cleanup(
         direct_log(
             f"== {lock_type} Lock == Cleanup failed: {e}",
             level="ERROR",
-            enable_output=False,
+            enable_output=True,
         )
         return 0, earliest_cleanup_time, last_cleanup_time

@@ -689,7 +688,7 @@ class KeyedUnifiedLock:
                 direct_log(
                     f"Error during multiprocess lock cleanup: {e}",
                     level="ERROR",
-                    enable_output=False,
+                    enable_output=True,
                 )

         # 2. Cleanup async locks using generic function
@@ -718,7 +717,7 @@ class KeyedUnifiedLock:
             direct_log(
                 f"Error during async lock cleanup: {e}",
                 level="ERROR",
-                enable_output=False,
+                enable_output=True,
             )

         # 3. Get current status after cleanup
@@ -772,7 +771,7 @@ class KeyedUnifiedLock:
             direct_log(
                 f"Error getting keyed lock status: {e}",
                 level="ERROR",
-                enable_output=False,
+                enable_output=True,
             )
         return status

@@ -797,32 +796,239 @@ class _KeyedLockContext:
             if enable_logging is not None
             else parent._default_enable_logging
         )
-        self._ul: Optional[List["UnifiedLock"]] = None  # set in __aenter__
+        self._ul: Optional[List[Dict[str, Any]]] = None  # set in __aenter__

     # ----- enter -----
     async def __aenter__(self):
         if self._ul is not None:
             raise RuntimeError("KeyedUnifiedLock already acquired in current context")
-        # acquire locks for all keys in the namespace
         self._ul = []
-        for key in self._keys:
-            lock = self._parent._get_lock_for_key(
-                self._namespace, key, enable_logging=self._enable_logging
-            )
-            await lock.__aenter__()
-            inc_debug_n_locks_acquired()
-            self._ul.append(lock)
-        return self
+
+        try:
+            # Acquire locks for all keys in the namespace
+            for key in self._keys:
+                lock = None
+                entry = None
+
+                try:
+                    # 1. Get lock object (reference count is incremented here)
+                    lock = self._parent._get_lock_for_key(
+                        self._namespace, key, enable_logging=self._enable_logging
+                    )
+
+                    # 2. Immediately create and add entry to list (critical for rollback to work)
+                    entry = {
+                        "key": key,
+                        "lock": lock,
+                        "entered": False,
+                        "debug_inc": False,
+                        "ref_incremented": True,  # Mark that reference count has been incremented
+                    }
+                    self._ul.append(
+                        entry
+                    )  # Add immediately after _get_lock_for_key for rollback to work
+
+                    # 3. Try to acquire the lock
+                    # Use try-finally to ensure state is updated atomically
+                    lock_acquired = False
+                    try:
+                        await lock.__aenter__()
+                        lock_acquired = True  # Lock successfully acquired
+                    finally:
+                        if lock_acquired:
+                            entry["entered"] = True
+                            inc_debug_n_locks_acquired()
+                            entry["debug_inc"] = True
+
+                except asyncio.CancelledError:
+                    # Lock acquisition was cancelled
+                    # The finally block above ensures entry["entered"] is correct
+                    direct_log(
+                        f"Lock acquisition cancelled for key {key}",
+                        level="WARNING",
+                        enable_output=self._enable_logging,
+                    )
+                    raise
+                except Exception as e:
+                    # Other exceptions, log and re-raise
+                    direct_log(
+                        f"Lock acquisition failed for key {key}: {e}",
+                        level="ERROR",
+                        enable_output=True,
+                    )
+                    raise
+
+            return self
+
+        except BaseException:
+            # Critical: if any exception occurs (including CancelledError) during lock acquisition,
+            # we must rollback all already acquired locks to prevent lock leaks
+            # Use shield to ensure rollback completes
+            await asyncio.shield(self._rollback_acquired_locks())
+            raise
+
+    async def _rollback_acquired_locks(self):
+        """Rollback all acquired locks in case of exception during __aenter__"""
+        if not self._ul:
+            return
+
+        async def rollback_single_entry(entry):
+            """Rollback a single lock acquisition"""
+            key = entry["key"]
+            lock = entry["lock"]
+            debug_inc = entry["debug_inc"]
+            entered = entry["entered"]
+            ref_incremented = entry.get(
+                "ref_incremented", True
+            )  # Default to True for safety
+
+            errors = []
+
+            # 1. If lock was acquired, release it
+            if entered:
+                try:
+                    await lock.__aexit__(None, None, None)
+                except Exception as e:
+                    errors.append(("lock_exit", e))
+                    direct_log(
+                        f"Lock rollback error for key {key}: {e}",
+                        level="ERROR",
+                        enable_output=True,
+                    )
+
+            # 2. Release reference count (if it was incremented)
+            if ref_incremented:
+                try:
+                    self._parent._release_lock_for_key(self._namespace, key)
+                except Exception as e:
+                    errors.append(("ref_release", e))
+                    direct_log(
+                        f"Lock rollback reference release error for key {key}: {e}",
+                        level="ERROR",
+                        enable_output=True,
+                    )
+
+            # 3. Decrement debug counter
+            if debug_inc:
+                try:
+                    dec_debug_n_locks_acquired()
+                except Exception as e:
+                    errors.append(("debug_dec", e))
+                    direct_log(
+                        f"Lock rollback counter decrementing error for key {key}: {e}",
+                        level="ERROR",
+                        enable_output=True,
+                    )
+
+            return errors
+
+        # Release already acquired locks in reverse order
+        for entry in reversed(self._ul):
+            # Use shield to protect each lock's rollback
+            try:
+                await asyncio.shield(rollback_single_entry(entry))
+            except Exception as e:
+                # Log but continue rolling back other locks
+                direct_log(
+                    f"Lock rollback unexpected error for {entry['key']}: {e}",
+                    level="ERROR",
+                    enable_output=True,
+                )
+
+        self._ul = None

     # ----- exit -----
     async def __aexit__(self, exc_type, exc, tb):
-        # The UnifiedLock takes care of proper release order
-        for ul, key in zip(reversed(self._ul), reversed(self._keys)):
-            await ul.__aexit__(exc_type, exc, tb)
-            self._parent._release_lock_for_key(self._namespace, key)
-            dec_debug_n_locks_acquired()
-        self._ul = None
+        if self._ul is None:
+            return
+
+        async def release_all_locks():
+            """Release all locks with comprehensive error handling, protected from cancellation"""
+
+            async def release_single_entry(entry, exc_type, exc, tb):
+                """Release a single lock with full protection"""
+                key = entry["key"]
+                lock = entry["lock"]
+                debug_inc = entry["debug_inc"]
+                entered = entry["entered"]
+
+                errors = []
+
+                # 1. Release the lock
+                if entered:
+                    try:
+                        await lock.__aexit__(exc_type, exc, tb)
+                    except Exception as e:
+                        errors.append(("lock_exit", e))
+                        direct_log(
+                            f"Lock release error for key {key}: {e}",
+                            level="ERROR",
+                            enable_output=True,
+                        )
+
+                # 2. Release reference count
+                try:
+                    self._parent._release_lock_for_key(self._namespace, key)
+                except Exception as e:
+                    errors.append(("ref_release", e))
+                    direct_log(
+                        f"Lock release reference error for key {key}: {e}",
+                        level="ERROR",
+                        enable_output=True,
+                    )
+
+                # 3. Decrement debug counter
+                if debug_inc:
+                    try:
+                        dec_debug_n_locks_acquired()
+                    except Exception as e:
+                        errors.append(("debug_dec", e))
+                        direct_log(
+                            f"Lock release counter decrementing error for key {key}: {e}",
+                            level="ERROR",
+                            enable_output=True,
+                        )
+
+                return errors
+
+            all_errors = []
+
+            # Release locks in reverse order
+            # This entire loop is protected by the outer shield
+            for entry in reversed(self._ul):
+                try:
+                    errors = await release_single_entry(entry, exc_type, exc, tb)
+                    for error_type, error in errors:
+                        all_errors.append((entry["key"], error_type, error))
+                except Exception as e:
+                    all_errors.append((entry["key"], "unexpected", e))
+                    direct_log(
+                        f"Lock release unexpected error for {entry['key']}: {e}",
+                        level="ERROR",
+                        enable_output=True,
+                    )
+
+            return all_errors
+
+        # CRITICAL: Protect the entire release process with shield
+        # This ensures that even if cancellation occurs, all locks are released
+        try:
+            all_errors = await asyncio.shield(release_all_locks())
+        except Exception as e:
+            direct_log(
+                f"Critical error during __aexit__ cleanup: {e}",
+                level="ERROR",
+                enable_output=True,
+            )
+            all_errors = []
+        finally:
+            # Always clear the lock list, even if shield was cancelled
+            self._ul = None
+
+        # If there were release errors and no other exception, raise the first release error
+        if all_errors and exc_type is None:
+            raise all_errors[0][2]  # (key, error_type, error)


 def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index dff637f6..24ea0209 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1871,9 +1871,14 @@ class LightRAG:
                         if task and not task.done():
                             task.cancel()

-                    # Persistent llm cache
+                    # Persistent llm cache with error handling
                     if self.llm_response_cache:
-                        await self.llm_response_cache.index_done_callback()
+                        try:
+                            await self.llm_response_cache.index_done_callback()
+                        except Exception as persist_error:
+                            logger.error(
+                                f"Failed to persist LLM cache: {persist_error}"
+                            )

                     # Record processing end time for failed case
                     processing_end_time = int(time.time())
@@ -1994,9 +1999,14 @@ class LightRAG:
                             error_msg
                         )

-                    # Persistent llm cache
+                    # Persistent llm cache with error handling
                     if self.llm_response_cache:
-                        await self.llm_response_cache.index_done_callback()
+                        try:
+                            await self.llm_response_cache.index_done_callback()
+                        except Exception as persist_error:
+                            logger.error(
+                                f"Failed to persist LLM cache: {persist_error}"
+                            )

                     # Record processing end time for failed case
                     processing_end_time = int(time.time())
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 496c000c..36c8251d 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -2302,9 +2302,7 @@ async def merge_nodes_and_edges(
                 return entity_data

             except Exception as e:
-                error_msg = (
-                    f"Critical error in entity processing for `{entity_name}`: {e}"
-                )
+                error_msg = f"Error processing entity `{entity_name}`: {e}"
                 logger.error(error_msg)

                 # Try to update pipeline status, but don't let status update failure affect main exception
@@ -2340,36 +2338,32 @@ async def merge_nodes_and_edges(
             entity_tasks, return_when=asyncio.FIRST_EXCEPTION
         )

-        # Check if any task raised an exception and ensure all exceptions are retrieved
         first_exception = None
-        successful_results = []
+        processed_entities = []

         for task in done:
             try:
-                exception = task.exception()
-                if exception is not None:
-                    if first_exception is None:
-                        first_exception = exception
-                else:
-                    successful_results.append(task.result())
-            except Exception as e:
+                result = task.result()
+            except BaseException as e:
                 if first_exception is None:
                     first_exception = e
+            else:
+                processed_entities.append(result)
+
+        if pending:
+            for task in pending:
+                task.cancel()
+            pending_results = await asyncio.gather(*pending, return_exceptions=True)
+            for result in pending_results:
+                if isinstance(result, BaseException):
+                    if first_exception is None:
+                        first_exception = result
+                else:
+                    processed_entities.append(result)

-        # If any task failed, cancel all pending tasks and raise the first exception
         if first_exception is not None:
-            # Cancel all pending tasks
-            for pending_task in pending:
-                pending_task.cancel()
-            # Wait for cancellation to complete
-            if pending:
-                await asyncio.wait(pending)
-            # Re-raise the first exception to notify the caller
             raise first_exception

-        # If all tasks completed successfully, collect results
-        processed_entities = [task.result() for task in entity_tasks]
-
         # ===== Phase 2: Process all relationships concurrently =====
         log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
         logger.info(log_message)
@@ -2421,7 +2415,7 @@ async def merge_nodes_and_edges(
                 return edge_data, added_entities

             except Exception as e:
-                error_msg = f"Critical error in relationship processing for `{sorted_edge_key}`: {e}"
+                error_msg = f"Error processing relation `{sorted_edge_key}`: {e}"
                 logger.error(error_msg)

                 # Try to update pipeline status, but don't let status update failure affect main exception
@@ -2459,40 +2453,36 @@ async def merge_nodes_and_edges(
             edge_tasks, return_when=asyncio.FIRST_EXCEPTION
         )

-        # Check if any task raised an exception and ensure all exceptions are retrieved
         first_exception = None
-        successful_results = []

         for task in done:
             try:
-                exception = task.exception()
-                if exception is not None:
-                    if first_exception is None:
-                        first_exception = exception
-                else:
-                    successful_results.append(task.result())
-            except Exception as e:
+                edge_data, added_entities = task.result()
+            except BaseException as e:
                 if first_exception is None:
                     first_exception = e
+            else:
+                if edge_data is not None:
+                    processed_edges.append(edge_data)
+                all_added_entities.extend(added_entities)
+
+        if pending:
+            for task in pending:
+                task.cancel()
+            pending_results = await asyncio.gather(*pending, return_exceptions=True)
+            for result in pending_results:
+                if isinstance(result, BaseException):
+                    if first_exception is None:
+                        first_exception = result
+                else:
+                    edge_data, added_entities = result
+                    if edge_data is not None:
+                        processed_edges.append(edge_data)
+                    all_added_entities.extend(added_entities)

-        # If any task failed, cancel all pending tasks and raise the first exception
         if first_exception is not None:
-            # Cancel all pending tasks
-            for pending_task in pending:
-                pending_task.cancel()
-            # Wait for cancellation to complete
-            if pending:
-                await asyncio.wait(pending)
-            # Re-raise the first exception to notify the caller
             raise first_exception

-        # If all tasks completed successfully, collect results
-        for task in edge_tasks:
-            edge_data, added_entities = task.result()
-            if edge_data is not None:
-                processed_edges.append(edge_data)
-            all_added_entities.extend(added_entities)
-
         # ===== Phase 3: Update full_entities and full_relations storage =====
         if full_entities_storage and full_relations_storage and doc_id:
            try:
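
The shared_storage.py hunks above amount to an acquire-with-rollback pattern for keyed locks: every lock obtained before a failure or cancellation is released, under asyncio.shield, before the exception propagates, so a cancelled caller cannot leak locks. The following is a minimal, self-contained sketch of that pattern only; KeyedLocks, acquire_many, _rollback and the "entity:*" keys are hypothetical illustration names, not LightRAG APIs, and reference counting and per-process locks are deliberately omitted.

import asyncio
from typing import Dict, List


class KeyedLocks:
    """Illustrative stand-in for a keyed lock registry (not the LightRAG class)."""

    def __init__(self) -> None:
        self._locks: Dict[str, asyncio.Lock] = {}

    def get(self, key: str) -> asyncio.Lock:
        # The real patch also tracks a reference count per key; omitted here.
        return self._locks.setdefault(key, asyncio.Lock())


async def acquire_many(registry: KeyedLocks, keys: List[str]) -> List[asyncio.Lock]:
    """Acquire locks for all keys, rolling back on any failure or cancellation."""
    acquired: List[asyncio.Lock] = []
    try:
        for key in keys:
            lock = registry.get(key)
            await lock.acquire()
            acquired.append(lock)  # record immediately so rollback can see it
        return acquired
    except BaseException:
        # Shield the rollback so a further cancellation cannot interrupt it.
        await asyncio.shield(_rollback(acquired))
        raise


async def _rollback(acquired: List[asyncio.Lock]) -> None:
    # Release partially acquired locks in reverse order.
    for lock in reversed(acquired):
        lock.release()


async def main() -> None:
    registry = KeyedLocks()
    blocker = registry.get("entity:B")
    await blocker.acquire()  # simulate another task already holding "entity:B"

    task = asyncio.create_task(acquire_many(registry, ["entity:A", "entity:B"]))
    await asyncio.sleep(0.1)
    task.cancel()  # user cancellation arrives while "entity:B" is still pending
    try:
        await task
    except asyncio.CancelledError:
        pass

    # "entity:A" was rolled back, so it is free again and nothing leaked.
    assert not registry.get("entity:A").locked()
    print("no lock leaked after cancellation")


if __name__ == "__main__":
    asyncio.run(main())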
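
The operate.py hunks apply a related defensive pattern to the concurrent merge tasks: wait with FIRST_EXCEPTION, harvest results and the first failure from the finished tasks, cancel and drain the pending ones with gather(return_exceptions=True) so their exceptions are retrieved, then re-raise. A minimal sketch of that pattern under assumed names (run_until_first_failure and the helper coroutines are illustrative, not LightRAG functions):

import asyncio
from typing import Any, List, Optional


async def run_until_first_failure(coros: List[Any]) -> List[Any]:
    tasks = [asyncio.ensure_future(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)

    first_exception: Optional[BaseException] = None
    results: List[Any] = []

    # Collect results and the first exception from tasks that already finished.
    for task in done:
        try:
            result = task.result()
        except BaseException as e:  # includes CancelledError
            if first_exception is None:
                first_exception = e
        else:
            results.append(result)

    # Cancel and drain the rest so no exception is left unretrieved.
    if pending:
        for task in pending:
            task.cancel()
        drained = await asyncio.gather(*pending, return_exceptions=True)
        for result in drained:
            if isinstance(result, BaseException):
                if first_exception is None:
                    first_exception = result
            else:
                results.append(result)

    if first_exception is not None:
        raise first_exception
    return results


async def main() -> None:
    async def ok(n: int) -> int:
        await asyncio.sleep(0.01 * n)
        return n

    async def boom() -> int:
        await asyncio.sleep(0.02)
        raise ValueError("merge failed")

    try:
        await run_until_first_failure([ok(1), boom(), ok(5)])
    except ValueError as e:
        print(f"first failure propagated: {e}")


if __name__ == "__main__":
    asyncio.run(main())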