Resolve lock leakage issue during user cancellation handling

• Change direct_log's default log level from DEBUG to INFO
• Force-enable output for error-level lock log messages
• Add rollback protection to keyed lock acquisition and release
• Handle LLM cache persistence errors in failure paths
• Fix exception handling for concurrent entity and relation merge tasks
This commit is contained in:
yangdx 2025-10-25 03:06:45 +08:00
parent 77336e50b6
commit a9ec15e669
3 changed files with 285 additions and 79 deletions
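The core fix, per the bullets above, is rollback protection when a keyed lock acquisition is interrupted. A minimal sketch of that pattern follows; the class and names are illustrative stand-ins, not the actual UnifiedLock/KeyedUnifiedLock code, and the real implementation additionally shields the rollback and tracks per-key reference counts:

import asyncio
from typing import List

class MultiLock:
    """Acquire several asyncio locks; roll back on any failure.

    Without the rollback in __aenter__, a CancelledError raised while waiting
    for lock N would leave locks 1..N-1 held forever (the leak this commit fixes).
    """

    def __init__(self, locks: List[asyncio.Lock]):
        self._locks = locks
        self._acquired: List[asyncio.Lock] = []

    async def __aenter__(self):
        try:
            for lock in self._locks:
                await lock.acquire()         # cancellation can land on this await
                self._acquired.append(lock)  # record only after a successful acquire
            return self
        except BaseException:                # includes asyncio.CancelledError
            self._rollback()
            raise

    def _rollback(self):
        for lock in reversed(self._acquired):
            lock.release()
        self._acquired.clear()

    async def __aexit__(self, exc_type, exc, tb):
        self._rollback()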

View file

@@ -12,7 +12,7 @@ from lightrag.exceptions import PipelineNotInitializedError
# Define a direct print function for critical logs that must be visible in all processes
def direct_log(message, enable_output: bool = False, level: str = "DEBUG"):
def direct_log(message, enable_output: bool = False, level: str = "INFO"):
"""
Log a message directly to stderr to ensure visibility in all processes,
including the Gunicorn master process.
@@ -44,7 +44,6 @@ def direct_log(message, enable_output: bool = False, level: str = "DEBUG"):
}
message_level = level_mapping.get(level.upper(), logging.DEBUG)
# print(f"Diret_log: {level.upper()} {message_level} ? {current_level}", file=sys.stderr, flush=True)
if message_level >= current_level:
print(f"{level}: {message}", file=sys.stderr, flush=True)
@@ -168,7 +167,7 @@ class UnifiedLock(Generic[T]):
direct_log(
f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
level="ERROR",
enable_output=self._enable_logging,
enable_output=True,
)
raise
@@ -199,7 +198,7 @@ class UnifiedLock(Generic[T]):
direct_log(
f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
level="ERROR",
enable_output=self._enable_logging,
enable_output=True,
)
# If main lock release failed but async lock hasn't been released, try to release it
@@ -223,7 +222,7 @@ class UnifiedLock(Generic[T]):
direct_log(
f"== Lock == Process {self._pid}: Failed to release async lock after main lock failure: {inner_e}",
level="ERROR",
enable_output=self._enable_logging,
enable_output=True,
)
raise
@@ -247,7 +246,7 @@ class UnifiedLock(Generic[T]):
direct_log(
f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
level="ERROR",
enable_output=self._enable_logging,
enable_output=True,
)
raise
@@ -269,7 +268,7 @@ class UnifiedLock(Generic[T]):
direct_log(
f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
level="ERROR",
enable_output=self._enable_logging,
enable_output=True,
)
raise
@@ -401,7 +400,7 @@ def _perform_lock_cleanup(
direct_log(
f"== {lock_type} Lock == Cleanup failed: {e}",
level="ERROR",
enable_output=False,
enable_output=True,
)
return 0, earliest_cleanup_time, last_cleanup_time
@@ -689,7 +688,7 @@ class KeyedUnifiedLock:
direct_log(
f"Error during multiprocess lock cleanup: {e}",
level="ERROR",
enable_output=False,
enable_output=True,
)
# 2. Cleanup async locks using generic function
@@ -718,7 +717,7 @@ class KeyedUnifiedLock:
direct_log(
f"Error during async lock cleanup: {e}",
level="ERROR",
enable_output=False,
enable_output=True,
)
# 3. Get current status after cleanup
@@ -772,7 +771,7 @@ class KeyedUnifiedLock:
direct_log(
f"Error getting keyed lock status: {e}",
level="ERROR",
enable_output=False,
enable_output=True,
)
return status
@@ -797,32 +796,239 @@ class _KeyedLockContext:
if enable_logging is not None
else parent._default_enable_logging
)
self._ul: Optional[List["UnifiedLock"]] = None # set in __aenter__
self._ul: Optional[List[Dict[str, Any]]] = None # set in __aenter__
# ----- enter -----
async def __aenter__(self):
if self._ul is not None:
raise RuntimeError("KeyedUnifiedLock already acquired in current context")
# acquire locks for all keys in the namespace
self._ul = []
for key in self._keys:
lock = self._parent._get_lock_for_key(
self._namespace, key, enable_logging=self._enable_logging
)
await lock.__aenter__()
inc_debug_n_locks_acquired()
self._ul.append(lock)
return self
try:
# Acquire locks for all keys in the namespace
for key in self._keys:
lock = None
entry = None
try:
# 1. Get lock object (reference count is incremented here)
lock = self._parent._get_lock_for_key(
self._namespace, key, enable_logging=self._enable_logging
)
# 2. Immediately create and add entry to list (critical for rollback to work)
entry = {
"key": key,
"lock": lock,
"entered": False,
"debug_inc": False,
"ref_incremented": True, # Mark that reference count has been incremented
}
self._ul.append(
entry
) # Add immediately after _get_lock_for_key for rollback to work
# 3. Try to acquire the lock
# Use try-finally to ensure state is updated atomically
lock_acquired = False
try:
await lock.__aenter__()
lock_acquired = True # Lock successfully acquired
finally:
if lock_acquired:
entry["entered"] = True
inc_debug_n_locks_acquired()
entry["debug_inc"] = True
except asyncio.CancelledError:
# Lock acquisition was cancelled
# The finally block above ensures entry["entered"] is correct
direct_log(
f"Lock acquisition cancelled for key {key}",
level="WARNING",
enable_output=self._enable_logging,
)
raise
except Exception as e:
# Other exceptions, log and re-raise
direct_log(
f"Lock acquisition failed for key {key}: {e}",
level="ERROR",
enable_output=True,
)
raise
return self
except BaseException:
# Critical: if any exception occurs (including CancelledError) during lock acquisition,
# we must rollback all already acquired locks to prevent lock leaks
# Use shield to ensure rollback completes
await asyncio.shield(self._rollback_acquired_locks())
raise
async def _rollback_acquired_locks(self):
"""Rollback all acquired locks in case of exception during __aenter__"""
if not self._ul:
return
async def rollback_single_entry(entry):
"""Rollback a single lock acquisition"""
key = entry["key"]
lock = entry["lock"]
debug_inc = entry["debug_inc"]
entered = entry["entered"]
ref_incremented = entry.get(
"ref_incremented", True
) # Default to True for safety
errors = []
# 1. If lock was acquired, release it
if entered:
try:
await lock.__aexit__(None, None, None)
except Exception as e:
errors.append(("lock_exit", e))
direct_log(
f"Lock rollback error for key {key}: {e}",
level="ERROR",
enable_output=True,
)
# 2. Release reference count (if it was incremented)
if ref_incremented:
try:
self._parent._release_lock_for_key(self._namespace, key)
except Exception as e:
errors.append(("ref_release", e))
direct_log(
f"Lock rollback reference release error for key {key}: {e}",
level="ERROR",
enable_output=True,
)
# 3. Decrement debug counter
if debug_inc:
try:
dec_debug_n_locks_acquired()
except Exception as e:
errors.append(("debug_dec", e))
direct_log(
f"Lock rollback counter decrementing error for key {key}: {e}",
level="ERROR",
enable_output=True,
)
return errors
# Release already acquired locks in reverse order
for entry in reversed(self._ul):
# Use shield to protect each lock's rollback
try:
await asyncio.shield(rollback_single_entry(entry))
except Exception as e:
# Log but continue rolling back other locks
direct_log(
f"Lock rollback unexpected error for {entry['key']}: {e}",
level="ERROR",
enable_output=True,
)
self._ul = None
# ----- exit -----
async def __aexit__(self, exc_type, exc, tb):
# The UnifiedLock takes care of proper release order
for ul, key in zip(reversed(self._ul), reversed(self._keys)):
await ul.__aexit__(exc_type, exc, tb)
self._parent._release_lock_for_key(self._namespace, key)
dec_debug_n_locks_acquired()
self._ul = None
if self._ul is None:
return
async def release_all_locks():
"""Release all locks with comprehensive error handling, protected from cancellation"""
async def release_single_entry(entry, exc_type, exc, tb):
"""Release a single lock with full protection"""
key = entry["key"]
lock = entry["lock"]
debug_inc = entry["debug_inc"]
entered = entry["entered"]
errors = []
# 1. Release the lock
if entered:
try:
await lock.__aexit__(exc_type, exc, tb)
except Exception as e:
errors.append(("lock_exit", e))
direct_log(
f"Lock release error for key {key}: {e}",
level="ERROR",
enable_output=True,
)
# 2. Release reference count
try:
self._parent._release_lock_for_key(self._namespace, key)
except Exception as e:
errors.append(("ref_release", e))
direct_log(
f"Lock release reference error for key {key}: {e}",
level="ERROR",
enable_output=True,
)
# 3. Decrement debug counter
if debug_inc:
try:
dec_debug_n_locks_acquired()
except Exception as e:
errors.append(("debug_dec", e))
direct_log(
f"Lock release counter decrementing error for key {key}: {e}",
level="ERROR",
enable_output=True,
)
return errors
all_errors = []
# Release locks in reverse order
# This entire loop is protected by the outer shield
for entry in reversed(self._ul):
try:
errors = await release_single_entry(entry, exc_type, exc, tb)
for error_type, error in errors:
all_errors.append((entry["key"], error_type, error))
except Exception as e:
all_errors.append((entry["key"], "unexpected", e))
direct_log(
f"Lock release unexpected error for {entry['key']}: {e}",
level="ERROR",
enable_output=True,
)
return all_errors
# CRITICAL: Protect the entire release process with shield
# This ensures that even if cancellation occurs, all locks are released
try:
all_errors = await asyncio.shield(release_all_locks())
except Exception as e:
direct_log(
f"Critical error during __aexit__ cleanup: {e}",
level="ERROR",
enable_output=True,
)
all_errors = []
finally:
# Always clear the lock list, even if shield was cancelled
self._ul = None
# If there were release errors and no other exception, raise the first release error
if all_errors and exc_type is None:
raise all_errors[0][2] # (key, error_type, error)
def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
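The __aexit__ and rollback paths above wrap their release loops in asyncio.shield. A small self-contained demo of why that matters (illustrative names; sleeps stand in for lock work):

import asyncio

async def release_all(names):
    # Stand-in for the per-key release loop: each release may await.
    for name in names:
        await asyncio.sleep(0.05)
        print(f"released {name}")

async def holder():
    try:
        await asyncio.sleep(10)              # cancelled while "holding" three keyed locks
    finally:
        # The shield lets release_all run to completion even if another
        # cancellation arrives while the cleanup itself is awaiting; without
        # it, that second cancellation would abort the loop mid-way and leak
        # the remaining keys.
        await asyncio.shield(release_all(["k1", "k2", "k3"]))

async def main():
    task = asyncio.create_task(holder())
    await asyncio.sleep(0.1)
    task.cancel()                            # first cancel: triggers the finally block
    await asyncio.sleep(0.06)
    task.cancel()                            # second cancel: lands during the cleanup
    await asyncio.gather(task, return_exceptions=True)
    await asyncio.sleep(0.2)                 # give the shielded cleanup time to finish
    # All three "released ..." lines appear; unshielded, only "k1" would.

asyncio.run(main())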

View file

@@ -1871,9 +1871,14 @@ class LightRAG:
if task and not task.done():
task.cancel()
# Persistent llm cache
# Persistent llm cache with error handling
if self.llm_response_cache:
await self.llm_response_cache.index_done_callback()
try:
await self.llm_response_cache.index_done_callback()
except Exception as persist_error:
logger.error(
f"Failed to persist LLM cache: {persist_error}"
)
# Record processing end time for failed case
processing_end_time = int(time.time())
@@ -1994,9 +1999,14 @@
error_msg
)
# Persistent llm cache
# Persistent llm cache with error handling
if self.llm_response_cache:
await self.llm_response_cache.index_done_callback()
try:
await self.llm_response_cache.index_done_callback()
except Exception as persist_error:
logger.error(
f"Failed to persist LLM cache: {persist_error}"
)
# Record processing end time for failed case
processing_end_time = int(time.time())
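These two hunks run inside branches that are already handling a document-processing failure, so an unguarded index_done_callback() that itself failed would replace the original error and skip the remaining bookkeeping. A hedged sketch of the pattern (handle_failure and its arguments are made up; only index_done_callback() and the logger.error call mirror the diff):

import logging

logger = logging.getLogger(__name__)

async def handle_failure(cache, original_error: Exception):
    try:
        await cache.index_done_callback()    # persistence may itself fail (e.g. storage backend down)
    except Exception as persist_error:
        logger.error(f"Failed to persist LLM cache: {persist_error}")
    # ...record end time, update statuses...
    raise original_error                     # the original cause still reaches the caller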

View file

@@ -2302,9 +2302,7 @@ async def merge_nodes_and_edges(
return entity_data
except Exception as e:
error_msg = (
f"Critical error in entity processing for `{entity_name}`: {e}"
)
error_msg = f"Error processing entity `{entity_name}`: {e}"
logger.error(error_msg)
# Try to update pipeline status, but don't let status update failure affect main exception
@@ -2340,36 +2338,32 @@ async def merge_nodes_and_edges(
entity_tasks, return_when=asyncio.FIRST_EXCEPTION
)
# Check if any task raised an exception and ensure all exceptions are retrieved
first_exception = None
successful_results = []
processed_entities = []
for task in done:
try:
exception = task.exception()
if exception is not None:
if first_exception is None:
first_exception = exception
else:
successful_results.append(task.result())
except Exception as e:
result = task.result()
except BaseException as e:
if first_exception is None:
first_exception = e
else:
processed_entities.append(result)
if pending:
for task in pending:
task.cancel()
pending_results = await asyncio.gather(*pending, return_exceptions=True)
for result in pending_results:
if isinstance(result, BaseException):
if first_exception is None:
first_exception = result
else:
processed_entities.append(result)
# If any task failed, cancel all pending tasks and raise the first exception
if first_exception is not None:
# Cancel all pending tasks
for pending_task in pending:
pending_task.cancel()
# Wait for cancellation to complete
if pending:
await asyncio.wait(pending)
# Re-raise the first exception to notify the caller
raise first_exception
# If all tasks completed successfully, collect results
processed_entities = [task.result() for task in entity_tasks]
# ===== Phase 2: Process all relationships concurrently =====
log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
logger.info(log_message)
@@ -2421,7 +2415,7 @@ async def merge_nodes_and_edges(
return edge_data, added_entities
except Exception as e:
error_msg = f"Critical error in relationship processing for `{sorted_edge_key}`: {e}"
error_msg = f"Error processing relation `{sorted_edge_key}`: {e}"
logger.error(error_msg)
# Try to update pipeline status, but don't let status update failure affect main exception
@@ -2459,40 +2453,36 @@ async def merge_nodes_and_edges(
edge_tasks, return_when=asyncio.FIRST_EXCEPTION
)
# Check if any task raised an exception and ensure all exceptions are retrieved
first_exception = None
successful_results = []
for task in done:
try:
exception = task.exception()
if exception is not None:
if first_exception is None:
first_exception = exception
else:
successful_results.append(task.result())
except Exception as e:
edge_data, added_entities = task.result()
except BaseException as e:
if first_exception is None:
first_exception = e
else:
if edge_data is not None:
processed_edges.append(edge_data)
all_added_entities.extend(added_entities)
if pending:
for task in pending:
task.cancel()
pending_results = await asyncio.gather(*pending, return_exceptions=True)
for result in pending_results:
if isinstance(result, BaseException):
if first_exception is None:
first_exception = result
else:
edge_data, added_entities = result
if edge_data is not None:
processed_edges.append(edge_data)
all_added_entities.extend(added_entities)
# If any task failed, cancel all pending tasks and raise the first exception
if first_exception is not None:
# Cancel all pending tasks
for pending_task in pending:
pending_task.cancel()
# Wait for cancellation to complete
if pending:
await asyncio.wait(pending)
# Re-raise the first exception to notify the caller
raise first_exception
# If all tasks completed successfully, collect results
for task in edge_tasks:
edge_data, added_entities = task.result()
if edge_data is not None:
processed_edges.append(edge_data)
all_added_entities.extend(added_entities)
# ===== Phase 3: Update full_entities and full_relations storage =====
if full_entities_storage and full_relations_storage and doc_id:
try:
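The entity and relation phases above both switch to the same pattern: wait with FIRST_EXCEPTION, retrieve every exception so none goes unobserved, cancel and drain the pending tasks, then either re-raise the first failure or collect the results. A condensed, self-contained sketch of that pattern (generic names, not the merge_nodes_and_edges code itself):

import asyncio
from typing import Any, List

async def run_all_or_fail(coros) -> List[Any]:
    tasks = [asyncio.create_task(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)

    first_exception = None
    results = []                              # note: completion order, not submission order
    for task in done:
        try:
            results.append(task.result())     # re-raises the task's exception, if any
        except BaseException as e:            # includes CancelledError
            if first_exception is None:
                first_exception = e

    if pending:
        for task in pending:
            task.cancel()
        # Drain the cancelled tasks so every exception is retrieved (avoiding
        # "exception was never retrieved" warnings) and cancellation completes.
        for r in await asyncio.gather(*pending, return_exceptions=True):
            if isinstance(r, BaseException):
                if first_exception is None:
                    first_exception = r
            else:
                results.append(r)

    if first_exception is not None:
        raise first_exception
    return results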