From 6b0c0ef81574d117e23ff8b9f8295376cb25473a Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 17 Nov 2025 04:07:37 +0800
Subject: [PATCH] Refactor namespace lock to support reusable async context
 manager
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Add NamespaceLock class wrapper
• Fix lock re-entrance issues
• Enable concurrent lock usage
• Fresh context per async with block
• Update get_namespace_lock API

(cherry picked from commit 7deb9a64b9ae579f8f6fa4fc2e627d7d47e9eae3)
---
 lightrag/kg/shared_storage.py | 339 +++++++++++++++++++++++-----------
 1 file changed, 231 insertions(+), 108 deletions(-)

diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index d9905f48..87f0f9a9 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -10,17 +10,19 @@ from typing import Any, Dict, List, Optional, Union, TypeVar, Generic
 
 from lightrag.exceptions import PipelineNotInitializedError
 
+DEBUG_LOCKS = False
+
 
 # Define a direct print function for critical logs that must be visible in all processes
-def direct_log(message, enable_output: bool = False, level: str = "INFO"):
+def direct_log(message, enable_output: bool = True, level: str = "DEBUG"):
     """
     Log a message directly to stderr to ensure visibility in all processes,
     including the Gunicorn master process.
 
     Args:
         message: The message to log
-        level: Log level (default: "DEBUG")
-        enable_output: Whether to actually output the log (default: True)
+        level: Log level of the message (visibility is decided by comparing it with the current logger level)
+        enable_output: Enable or disable the log message (set to False to suppress it unconditionally)
     """
     if not enable_output:
         return
@@ -73,16 +75,16 @@ _last_mp_cleanup_time: Optional[float] = None
 
 _initialized = None
 
+# Default workspace for backward compatibility
+_default_workspace: Optional[str] = None
+
 # shared data for storage across processes
 _shared_dicts: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
 _update_flags: Optional[Dict[str, bool]] = None  # namespace -> updated
 
 # locks for mutex access
-_storage_lock: Optional[LockType] = None
 _internal_lock: Optional[LockType] = None
-_pipeline_status_lock: Optional[LockType] = None
-_graph_db_lock: Optional[LockType] = None
 _data_init_lock: Optional[LockType] = None
 # Manager for all keyed locks
 _storage_keyed_lock: Optional["KeyedUnifiedLock"] = None
@@ -90,10 +92,25 @@ _storage_keyed_lock: Optional["KeyedUnifiedLock"] = None
 # async locks for coroutine synchronization in multiprocess mode
 _async_locks: Optional[Dict[str, asyncio.Lock]] = None
 
-DEBUG_LOCKS = False
 _debug_n_locks_acquired: int = 0
 
 
+def get_final_namespace(namespace: str, workspace: str | None = None) -> str:
+    global _default_workspace
+    if workspace is None:
+        workspace = _default_workspace
+
+    if workspace is None:
+        direct_log(
+            f"Error: Namespace operation invoked without a workspace, pid={os.getpid()}",
+            level="ERROR",
+        )
+        raise ValueError("Namespace operation invoked without a workspace")
+
+    final_namespace = f"{workspace}:{namespace}" if workspace else namespace
+    return final_namespace
+
+
 def inc_debug_n_locks_acquired():
     global _debug_n_locks_acquired
     if DEBUG_LOCKS:
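As a usage sketch of the key composition introduced above (illustrative only; it
assumes set_default_workspace(), added later in this patch, has been called first):

    from lightrag.kg.shared_storage import (
        get_final_namespace,
        set_default_workspace,
    )

    set_default_workspace("tenant_a")
    # workspace=None falls back to the default workspace
    assert get_final_namespace("pipeline_status") == "tenant_a:pipeline_status"
    # an explicit workspace takes precedence
    assert get_final_namespace("pipeline_status", workspace="ws1") == "ws1:pipeline_status"
    # an empty workspace selects the global (unprefixed) namespace
    assert get_final_namespace("pipeline_status", workspace="") == "pipeline_status"
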
@@ -140,7 +157,8 @@ class UnifiedLock(Generic[T]):
             if not self._is_async and self._async_lock is not None:
                 await self._async_lock.acquire()
                 direct_log(
-                    f"== Lock == Process {self._pid}: Async lock for '{self._name}' acquired",
+                    f"== Lock == Process {self._pid}: Acquired async lock '{self._name}'",
+                    level="DEBUG",
                     enable_output=self._enable_logging,
                 )
 
@@ -151,7 +169,8 @@
                 self._lock.acquire()
 
             direct_log(
-                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
+                f"== Lock == Process {self._pid}: Acquired lock '{self._name}' (async={self._is_async})",
+                level="INFO",
                 enable_output=self._enable_logging,
             )
             return self
@@ -182,7 +201,8 @@
                 main_lock_released = True
                 direct_log(
-                    f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
+                    f"== Lock == Process {self._pid}: Released lock '{self._name}' (async={self._is_async})",
+                    level="INFO",
                     enable_output=self._enable_logging,
                 )
 
@@ -190,7 +210,8 @@
             if not self._is_async and self._async_lock is not None:
                 self._async_lock.release()
                 direct_log(
-                    f"== Lock == Process {self._pid}: Async lock '{self._name}' released",
+                    f"== Lock == Process {self._pid}: Released async lock '{self._name}'",
+                    level="DEBUG",
                     enable_output=self._enable_logging,
                 )
 
@@ -210,12 +231,13 @@
             try:
                 direct_log(
                     f"== Lock == Process {self._pid}: Attempting to release async lock after main lock failure",
-                    level="WARNING",
+                    level="DEBUG",
                     enable_output=self._enable_logging,
                 )
                 self._async_lock.release()
                 direct_log(
                     f"== Lock == Process {self._pid}: Successfully released async lock after main lock failure",
+                    level="INFO",
                     enable_output=self._enable_logging,
                 )
             except Exception as inner_e:
@@ -233,12 +255,14 @@
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
             direct_log(
                 f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)",
+                level="DEBUG",
                 enable_output=self._enable_logging,
             )
             self._lock.acquire()
             direct_log(
-                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)",
+                f"== Lock == Process {self._pid}: Acquired lock '{self._name}' (sync)",
+                level="INFO",
                 enable_output=self._enable_logging,
             )
             return self
@@ -257,11 +281,13 @@
                 raise RuntimeError("Use 'async with' for shared_storage lock")
             direct_log(
                 f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)",
+                level="DEBUG",
                 enable_output=self._enable_logging,
             )
             self._lock.release()
             direct_log(
-                f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)",
+                f"== Lock == Process {self._pid}: Released lock '{self._name}' (sync)",
+                level="INFO",
                 enable_output=self._enable_logging,
             )
         except Exception as e:
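The logging changes above do not alter UnifiedLock's contract; a minimal usage
sketch (single-process mode, using only names defined in this module):

    import asyncio
    from lightrag.kg.shared_storage import initialize_share_data, get_internal_lock

    async def critical_section():
        # In multiprocess mode an extra async lock serializes coroutines within
        # one process while the manager lock serializes across processes.
        async with get_internal_lock(enable_logging=True):
            ...  # mutate shared state here

    initialize_share_data(workers=1)  # single-process: a plain asyncio.Lock is used
    asyncio.run(critical_section())
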
@@ -1043,40 +1069,10 @@ def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     )
 
 
-def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
-    """return unified storage lock for data consistency"""
-    async_lock = _async_locks.get("storage_lock") if _is_multiprocess else None
-    return UnifiedLock(
-        lock=_storage_lock,
-        is_async=not _is_multiprocess,
-        name="storage_lock",
-        enable_logging=enable_logging,
-        async_lock=async_lock,
-    )
-
-
-def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
-    """return unified storage lock for data consistency"""
-    async_lock = _async_locks.get("pipeline_status_lock") if _is_multiprocess else None
-    return UnifiedLock(
-        lock=_pipeline_status_lock,
-        is_async=not _is_multiprocess,
-        name="pipeline_status_lock",
-        enable_logging=enable_logging,
-        async_lock=async_lock,
-    )
-
-
-def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
-    """return unified graph database lock for ensuring atomic operations"""
-    async_lock = _async_locks.get("graph_db_lock") if _is_multiprocess else None
-    return UnifiedLock(
-        lock=_graph_db_lock,
-        is_async=not _is_multiprocess,
-        name="graph_db_lock",
-        enable_logging=enable_logging,
-        async_lock=async_lock,
-    )
+# The workspace-based storage_lock is implemented by get_storage_keyed_lock instead.
+# The workspace-based pipeline_status_lock is implemented by get_storage_keyed_lock instead.
+# No workspace-based graph_db_lock is needed: data integrity is ensured by
+# entity-level keyed locks and by allowing only one process to run the pipeline at a time.
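# For reference, a sketch of the keyed-lock replacement for the removed global
# locks; the call shape is inferred from the NamespaceLock wrapper added later
# in this patch, and "ws1" is a placeholder workspace:
#
#     async def update_pipeline_state():
#         # Equivalent of the old get_pipeline_status_lock(), now scoped per
#         # workspace-qualified namespace and per key.
#         async with get_storage_keyed_lock(
#             ["default_key"], namespace="ws1:pipeline_status", enable_logging=False
#         ):
#             ...  # critical section formerly guarded by _pipeline_status_lock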
""" - pipeline_namespace = await get_namespace_data(namespace, first_init=True) + pipeline_namespace = await get_namespace_data( + "pipeline_status", first_init=True, workspace=workspace + ) async with get_internal_lock(): # Check if already initialized by checking for required fields @@ -1273,7 +1265,7 @@ async def initialize_pipeline_status(namespace: str = "pipeline_status"): return # Create a shared list object for history_messages - history_messages = _manager.list() if _is_multiprocess and _manager is not None else [] + history_messages = _manager.list() if _is_multiprocess else [] pipeline_namespace.update( { "autoscanned": False, # Auto-scan started @@ -1288,10 +1280,14 @@ async def initialize_pipeline_status(namespace: str = "pipeline_status"): "history_messages": history_messages, # 使用共享列表对象 } ) - direct_log(f"Process {os.getpid()} Pipeline namespace initialized: [{namespace}]") + + final_namespace = get_final_namespace("pipeline_status", workspace) + direct_log( + f"Process {os.getpid()} Pipeline namespace '{final_namespace}' initialized" + ) -async def get_update_flag(namespace: str): +async def get_update_flag(namespace: str, workspace: str | None = None): """ Create a namespace's update flag for a workers. Returen the update flag to caller for referencing or reset. @@ -1300,14 +1296,16 @@ async def get_update_flag(namespace: str): if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") + final_namespace = get_final_namespace(namespace, workspace) + async with get_internal_lock(): - if namespace not in _update_flags: + if final_namespace not in _update_flags: if _is_multiprocess and _manager is not None: - _update_flags[namespace] = _manager.list() + _update_flags[final_namespace] = _manager.list() else: - _update_flags[namespace] = [] + _update_flags[final_namespace] = [] direct_log( - f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]" + f"Process {os.getpid()} initialized updated flags for namespace: [{final_namespace}]" ) if _is_multiprocess and _manager is not None: @@ -1320,39 +1318,43 @@ async def get_update_flag(namespace: str): new_update_flag = MutableBoolean(False) - _update_flags[namespace].append(new_update_flag) + _update_flags[final_namespace].append(new_update_flag) return new_update_flag -async def set_all_update_flags(namespace: str): +async def set_all_update_flags(namespace: str, workspace: str | None = None): """Set all update flag of namespace indicating all workers need to reload data from files""" global _update_flags if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") + final_namespace = get_final_namespace(namespace, workspace) + async with get_internal_lock(): - if namespace not in _update_flags: - raise ValueError(f"Namespace {namespace} not found in update flags") + if final_namespace not in _update_flags: + raise ValueError(f"Namespace {final_namespace} not found in update flags") # Update flags for both modes - for i in range(len(_update_flags[namespace])): - _update_flags[namespace][i].value = True + for i in range(len(_update_flags[final_namespace])): + _update_flags[final_namespace][i].value = True -async def clear_all_update_flags(namespace: str): +async def clear_all_update_flags(namespace: str, workspace: str | None = None): """Clear all update flag of namespace indicating all workers need to reload data from files""" global _update_flags if _update_flags is None: raise ValueError("Try to create namespace before 
Shared-Data is initialized") + final_namespace = get_final_namespace(namespace, workspace) + async with get_internal_lock(): - if namespace not in _update_flags: - raise ValueError(f"Namespace {namespace} not found in update flags") + if final_namespace not in _update_flags: + raise ValueError(f"Namespace {final_namespace} not found in update flags") # Update flags for both modes - for i in range(len(_update_flags[namespace])): - _update_flags[namespace][i].value = False + for i in range(len(_update_flags[final_namespace])): + _update_flags[final_namespace][i].value = False -async def get_all_update_flags_status() -> Dict[str, list]: +async def get_all_update_flags_status(workspace: str | None = None) -> Dict[str, list]: """ Get update flags status for all namespaces. @@ -1362,9 +1364,17 @@ async def get_all_update_flags_status() -> Dict[str, list]: if _update_flags is None: return {} + if workspace is None: + workspace = get_default_workspace + result = {} async with get_internal_lock(): for namespace, flags in _update_flags.items(): + namespace_split = namespace.split(":") + if workspace and not namespace_split[0] == workspace: + continue + if not workspace and namespace_split[0]: + continue worker_statuses = [] for flag in flags: if _is_multiprocess: @@ -1376,7 +1386,9 @@ async def get_all_update_flags_status() -> Dict[str, list]: return result -async def try_initialize_namespace(namespace: str) -> bool: +async def try_initialize_namespace( + namespace: str, workspace: str | None = None +) -> bool: """ Returns True if the current worker(process) gets initialization permission for loading data later. The worker does not get the permission is prohibited to load data from files. @@ -1386,52 +1398,139 @@ async def try_initialize_namespace(namespace: str) -> bool: if _init_flags is None: raise ValueError("Try to create nanmespace before Shared-Data is initialized") + final_namespace = get_final_namespace(namespace, workspace) + async with get_internal_lock(): - if namespace not in _init_flags: - _init_flags[namespace] = True + if final_namespace not in _init_flags: + _init_flags[final_namespace] = True direct_log( - f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]" + f"Process {os.getpid()} ready to initialize storage namespace: [{final_namespace}]" ) return True direct_log( - f"Process {os.getpid()} storage namespace already initialized: [{namespace}]" + f"Process {os.getpid()} storage namespace already initialized: [{final_namespace}]" ) return False async def get_namespace_data( - namespace: str, first_init: bool = False + namespace: str, first_init: bool = False, workspace: str | None = None ) -> Dict[str, Any]: """get the shared data reference for specific namespace Args: namespace: The namespace to retrieve - allow_create: If True, allows creation of the namespace if it doesn't exist. - Used internally by initialize_pipeline_status(). + first_init: If True, allows pipeline_status namespace to create namespace if it doesn't exist. + Prevent getting pipeline_status namespace without initialize_pipeline_status(). + This parameter is used internally by initialize_pipeline_status(). 
-async def get_all_update_flags_status() -> Dict[str, list]:
+async def get_all_update_flags_status(workspace: str | None = None) -> Dict[str, list]:
     """
     Get update flags status for all namespaces.
 
@@ -1362,9 +1364,17 @@
     if _update_flags is None:
         return {}
 
+    if workspace is None:
+        workspace = get_default_workspace()
+
     result = {}
     async with get_internal_lock():
         for namespace, flags in _update_flags.items():
+            namespace_split = namespace.split(":", 1)
+            if workspace and namespace_split[0] != workspace:
+                continue
+            if not workspace and len(namespace_split) > 1:
+                continue
             worker_statuses = []
             for flag in flags:
                 if _is_multiprocess:
@@ -1376,7 +1386,9 @@
     return result
 
 
-async def try_initialize_namespace(namespace: str) -> bool:
+async def try_initialize_namespace(
+    namespace: str, workspace: str | None = None
+) -> bool:
     """
     Returns True if the current worker (process) gets initialization permission for loading data later.
     A worker that does not get the permission must not load data from files.
     """
     global _init_flags
     if _init_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")
 
+    final_namespace = get_final_namespace(namespace, workspace)
+
     async with get_internal_lock():
-        if namespace not in _init_flags:
-            _init_flags[namespace] = True
+        if final_namespace not in _init_flags:
+            _init_flags[final_namespace] = True
             direct_log(
-                f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+                f"Process {os.getpid()} ready to initialize storage namespace: [{final_namespace}]"
             )
             return True
     direct_log(
-        f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
+        f"Process {os.getpid()} storage namespace already initialized: [{final_namespace}]"
     )
     return False
 
 
 async def get_namespace_data(
-    namespace: str, first_init: bool = False
+    namespace: str, first_init: bool = False, workspace: str | None = None
 ) -> Dict[str, Any]:
     """get the shared data reference for a specific namespace
 
     Args:
         namespace: The namespace to retrieve
-        allow_create: If True, allows creation of the namespace if it doesn't exist.
-                      Used internally by initialize_pipeline_status().
+        first_init: If True, allows the pipeline_status namespace to be created if it
+                    doesn't exist; otherwise, getting pipeline_status before
+                    initialize_pipeline_status() raises an error. This parameter is
+                    used internally by initialize_pipeline_status().
+        workspace: Workspace identifier (may be an empty string for the global namespace)
     """
     if _shared_dicts is None:
         direct_log(
-            f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}",
+            f"Error: Try to get a namespace before it is initialized, pid={os.getpid()}",
             level="ERROR",
         )
         raise ValueError("Shared dictionaries not initialized")
 
+    final_namespace = get_final_namespace(namespace, workspace)
+
     async with get_internal_lock():
-        if namespace not in _shared_dicts:
+        if final_namespace not in _shared_dicts:
             # Special handling for pipeline_status namespace
-            if namespace == "pipeline_status" and not first_init:
+            if (
+                final_namespace == "pipeline_status"
+                or final_namespace.endswith(":pipeline_status")
+            ) and not first_init:
                 # Check if pipeline_status should have been initialized but wasn't
-                # This helps users understand they need to call initialize_pipeline_status()
+                # This reminds users to call initialize_pipeline_status() before get_namespace_data()
+                raise PipelineNotInitializedError(final_namespace)
-                raise PipelineNotInitializedError(namespace)
             # For other namespaces, or when first_init=True, create them dynamically
             if _is_multiprocess and _manager is not None:
-                _shared_dicts[namespace] = _manager.dict()
+                _shared_dicts[final_namespace] = _manager.dict()
             else:
-                _shared_dicts[namespace] = {}
+                _shared_dicts[final_namespace] = {}
 
-    return _shared_dicts[namespace]
+    return _shared_dicts[final_namespace]
+
+
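A read-side sketch of the guard above (the workspace value is a placeholder):

    from lightrag.exceptions import PipelineNotInitializedError

    async def read_pipeline_status(workspace: str):
        try:
            status = await get_namespace_data("pipeline_status", workspace=workspace)
        except PipelineNotInitializedError:
            # pipeline_status is never created on demand; initialize it first
            await initialize_pipeline_status(workspace=workspace)
            status = await get_namespace_data("pipeline_status", workspace=workspace)
        return status
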
+class NamespaceLock:
+    """
+    Reusable namespace lock wrapper that creates a fresh context on each use.
+
+    This class solves the lock re-entrance issue by implementing the async context
+    manager protocol. Each time it is used in an 'async with' statement, it creates
+    a new _KeyedLockContext internally, allowing the same NamespaceLock instance
+    to be used multiple times safely, even in concurrent scenarios.
+
+    Example:
+        lock = NamespaceLock("my_namespace", "workspace1")
+
+        # Can be used multiple times safely
+        async with lock:
+            await do_something()
+
+        # Can even be used concurrently (each use gets its own context)
+        await asyncio.gather(
+            use_lock_1(lock),
+            use_lock_2(lock)
+        )
+    """
+
+    def __init__(
+        self, namespace: str, workspace: str | None = None, enable_logging: bool = False
+    ):
+        self._namespace = namespace
+        self._workspace = workspace
+        self._enable_logging = enable_logging
+        # Per-task stacks of active contexts, so concurrent (or nested) uses of
+        # the same instance never share or overwrite one another's context.
+        self._ctx_stacks: dict = {}
+
+    async def __aenter__(self):
+        """Create a fresh context each time we enter"""
+        final_namespace = get_final_namespace(self._namespace, self._workspace)
+        ctx = get_storage_keyed_lock(
+            ["default_key"],
+            namespace=final_namespace,
+            enable_logging=self._enable_logging,
+        )
+        result = await ctx.__aenter__()
+        task = asyncio.current_task()
+        self._ctx_stacks.setdefault(task, []).append(ctx)
+        return result
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Exit the most recent context entered by the current task"""
+        task = asyncio.current_task()
+        stack = self._ctx_stacks.get(task)
+        if not stack:
+            raise RuntimeError("NamespaceLock exited without being entered")
+        ctx = stack.pop()
+        if not stack:
+            del self._ctx_stacks[task]
+        return await ctx.__aexit__(exc_type, exc_val, exc_tb)
+
+
+def get_namespace_lock(
+    namespace: str, workspace: str | None = None, enable_logging: bool = False
+) -> NamespaceLock:
+    """Get a reusable namespace lock wrapper.
+
+    This function returns a NamespaceLock instance that can be used multiple times
+    safely, even in concurrent scenarios. Each use creates a fresh lock context
+    internally, preventing lock re-entrance errors.
+
+    Args:
+        namespace: The namespace to get the lock for.
+        workspace: Workspace identifier (may be an empty string for the global namespace)
+        enable_logging: Whether to enable lock operation logging
+
+    Returns:
+        NamespaceLock: A reusable lock wrapper that can be used with 'async with'
+
+    Example:
+        lock = get_namespace_lock("pipeline_status", workspace="space1")
+
+        # Can be used multiple times
+        async with lock:
+            await do_something()
+
+        async with lock:
+            await do_something_else()
+    """
+    return NamespaceLock(namespace, workspace, enable_logging)
 
 
 def finalize_share_data():
@@ -1447,17 +1546,13 @@
     global \
         _manager, \
         _is_multiprocess, \
-        _storage_lock, \
         _internal_lock, \
-        _pipeline_status_lock, \
-        _graph_db_lock, \
         _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
         _initialized, \
         _update_flags, \
-        _async_locks, \
-        _default_workspace
+        _async_locks
 
     # Check if already initialized
     if not _initialized:
@@ -1516,13 +1611,41 @@
     _is_multiprocess = None
     _shared_dicts = None
     _init_flags = None
-    _storage_lock = None
     _internal_lock = None
-    _pipeline_status_lock = None
-    _graph_db_lock = None
     _data_init_lock = None
     _update_flags = None
     _async_locks = None
-    _default_workspace = None
 
     direct_log(f"Process {os.getpid()} storage data finalization complete")
+
+
+def set_default_workspace(workspace: str | None = None):
+    """
+    Set the default workspace for namespace operations (backward compatibility).
+
+    This allows get_namespace_data(), get_namespace_lock() or initialize_pipeline_status()
+    to automatically use the correct workspace when called without a workspace parameter,
+    maintaining compatibility with legacy code that doesn't pass workspace explicitly.
+
+    Args:
+        workspace: Workspace identifier (may be an empty string for the global namespace)
+    """
+    global _default_workspace
+    if workspace is None:
+        workspace = ""
+    _default_workspace = workspace
+    direct_log(
+        f"Default workspace set to: '{_default_workspace}' (empty means global)",
+        level="DEBUG",
+    )
+
+
+def get_default_workspace() -> Optional[str]:
+    """
+    Get the default workspace (backward compatibility).
+
+    Returns:
+        The default workspace string. An empty string means the global namespace;
+        None means the default has not been set.
+    """
+    return _default_workspace
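
An end-to-end backward-compatibility sketch: legacy callers omit the workspace
everywhere, so the default set once at startup is applied by get_final_namespace()
on every call (the workspace name is a placeholder):

    initialize_share_data(workers=1)
    set_default_workspace("tenant_a")

    async def legacy_caller():
        await initialize_pipeline_status()  # resolves to "tenant_a:pipeline_status"
        async with get_namespace_lock("pipeline_status"):  # same resolution
            status = await get_namespace_data("pipeline_status")
            status["job_name"] = "reindex"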