diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index c9e26614..fdd4adcd 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -1,4 +1,3 @@ -from collections import defaultdict import os import sys import asyncio @@ -6,7 +5,7 @@ import multiprocessing as mp from multiprocessing.synchronize import Lock as ProcessLock from multiprocessing import Manager import time -from typing import Any, Callable, Dict, List, Optional, Union, TypeVar, Generic +from typing import Any, Dict, List, Optional, Union, TypeVar, Generic # Define a direct print function for critical logs that must be visible in all processes @@ -56,25 +55,38 @@ _async_locks: Optional[Dict[str, asyncio.Lock]] = None DEBUG_LOCKS = False _debug_n_locks_acquired: int = 0 + + def inc_debug_n_locks_acquired(): global _debug_n_locks_acquired if DEBUG_LOCKS: _debug_n_locks_acquired += 1 - print(f"DEBUG: Keyed Lock acquired, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True) + print( + f"DEBUG: Keyed Lock acquired, total: {_debug_n_locks_acquired:>5}", + end="\r", + flush=True, + ) + def dec_debug_n_locks_acquired(): global _debug_n_locks_acquired if DEBUG_LOCKS: if _debug_n_locks_acquired > 0: _debug_n_locks_acquired -= 1 - print(f"DEBUG: Keyed Lock released, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True) + print( + f"DEBUG: Keyed Lock released, total: {_debug_n_locks_acquired:>5}", + end="\r", + flush=True, + ) else: raise RuntimeError("Attempting to release lock when no locks are acquired") + def get_debug_n_locks_acquired(): global _debug_n_locks_acquired return _debug_n_locks_acquired + class UnifiedLock(Generic[T]): """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" @@ -246,6 +258,7 @@ class UnifiedLock(Generic[T]): else: return self._lock.locked() + # ───────────────────────────────────────────────────────────────────────────── # 2. CROSS‑PROCESS FACTORY (one manager.Lock shared by *all* processes) # ───────────────────────────────────────────────────────────────────────────── @@ -253,7 +266,10 @@ def _get_combined_key(factory_name: str, key: str) -> str: """Return the combined key for the factory and key.""" return f"{factory_name}:{key}" -def _get_or_create_shared_raw_mp_lock(factory_name: str, key: str) -> Optional[mp.synchronize.Lock]: + +def _get_or_create_shared_raw_mp_lock( + factory_name: str, key: str +) -> Optional[mp.synchronize.Lock]: """Return the *singleton* manager.Lock() proxy for *key*, creating if needed.""" if not _is_multiprocess: return None @@ -265,13 +281,19 @@ def _get_or_create_shared_raw_mp_lock(factory_name: str, key: str) -> Optional[m if raw is None: raw = _manager.Lock() _lock_registry[combined_key] = raw - _lock_registry_count[combined_key] = 1 # 修复:新锁初始化为1,与释放逻辑保持一致 + _lock_registry_count[combined_key] = ( + 1 # 修复:新锁初始化为1,与释放逻辑保持一致 + ) else: if count is None: - raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}") + raise RuntimeError( + f"Shared-Data lock registry for {factory_name} is corrupted for key {key}" + ) count += 1 _lock_registry_count[combined_key] = count - if count == 1 and combined_key in _lock_cleanup_data: # 把再次使用的锁添剔除出待清理字典 + if ( + count == 1 and combined_key in _lock_cleanup_data + ): # 把再次使用的锁添剔除出待清理字典 _lock_cleanup_data.pop(combined_key) return raw @@ -288,25 +310,29 @@ def _release_shared_raw_mp_lock(factory_name: str, key: str): if raw is None and count is None: return elif raw is None or count is None: - raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}") + raise RuntimeError( + f"Shared-Data lock registry for {factory_name} is corrupted for key {key}" + ) count -= 1 if count < 0: - raise RuntimeError(f"Attempting to release lock for {key} more times than it was acquired") - + raise RuntimeError( + f"Attempting to release lock for {key} more times than it was acquired" + ) + _lock_registry_count[combined_key] = count - + current_time = time.time() if count == 0: _lock_cleanup_data[combined_key] = current_time - + # 清理过期的锁 for cleanup_key, cleanup_time in list(_lock_cleanup_data.items()): if current_time - cleanup_time > CLEANUP_KEYED_LOCKS_AFTER_SECONDS: _lock_registry.pop(cleanup_key, None) _lock_registry_count.pop(cleanup_key, None) _lock_cleanup_data.pop(cleanup_key, None) - + # ───────────────────────────────────────────────────────────────────────────── # 3. PARAMETER‑KEYED WRAPPER (unchanged except it *accepts a factory*) @@ -322,7 +348,9 @@ class KeyedUnifiedLock: """ # ---------------- construction ---------------- - def __init__(self, factory_name: str, *, default_enable_logging: bool = True) -> None: + def __init__( + self, factory_name: str, *, default_enable_logging: bool = True + ) -> None: self._factory_name = factory_name self._default_enable_logging = default_enable_logging self._async_lock: Dict[str, asyncio.Lock] = {} # key → asyncio.Lock @@ -340,7 +368,12 @@ class KeyedUnifiedLock: """ if enable_logging is None: enable_logging = self._default_enable_logging - return _KeyedLockContext(self, factory_name=self._factory_name, keys=keys, enable_logging=enable_logging) + return _KeyedLockContext( + self, + factory_name=self._factory_name, + keys=keys, + enable_logging=enable_logging, + ) def _get_or_create_async_lock(self, key: str) -> asyncio.Lock: async_lock = self._async_lock.get(key) @@ -357,7 +390,7 @@ class KeyedUnifiedLock: def _release_async_lock(self, key: str): count = self._async_lock_count.get(key, 0) count -= 1 - + current_time = time.time() # 优化:只调用一次 time.time() if count == 0: self._async_lock_cleanup_data[key] = current_time @@ -370,7 +403,6 @@ class KeyedUnifiedLock: self._async_lock_count.pop(cleanup_key) self._async_lock_cleanup_data.pop(cleanup_key) - def _get_lock_for_key(self, key: str, enable_logging: bool = False) -> UnifiedLock: # 1. get (or create) the per‑process async gate for this key # Is synchronous, so no need to acquire a lock @@ -381,15 +413,15 @@ class KeyedUnifiedLock: is_multiprocess = raw_lock is not None if not is_multiprocess: raw_lock = async_lock - + # 3. build a *fresh* UnifiedLock with the chosen logging flag if is_multiprocess: return UnifiedLock( lock=raw_lock, - is_async=False, # manager.Lock is synchronous + is_async=False, # manager.Lock is synchronous name=f"key:{self._factory_name}:{key}", enable_logging=enable_logging, - async_lock=async_lock, # prevents event‑loop blocking + async_lock=async_lock, # prevents event‑loop blocking ) else: return UnifiedLock( @@ -397,13 +429,14 @@ class KeyedUnifiedLock: is_async=True, name=f"key:{self._factory_name}:{key}", enable_logging=enable_logging, - async_lock=None, # No need for async lock in single process mode + async_lock=None, # No need for async lock in single process mode ) - + def _release_lock_for_key(self, key: str): self._release_async_lock(key) _release_shared_raw_mp_lock(self._factory_name, key) + class _KeyedLockContext: def __init__( self, @@ -419,7 +452,8 @@ class _KeyedLockContext: # to avoid deadlocks self._keys = sorted(keys) self._enable_logging = ( - enable_logging if enable_logging is not None + enable_logging + if enable_logging is not None else parent._default_enable_logging ) self._ul: Optional[List["UnifiedLock"]] = None # set in __aenter__ @@ -432,7 +466,9 @@ class _KeyedLockContext: # 4. acquire it self._ul = [] for key in self._keys: - lock = self._parent._get_lock_for_key(key, enable_logging=self._enable_logging) + lock = self._parent._get_lock_for_key( + key, enable_logging=self._enable_logging + ) await lock.__aenter__() inc_debug_n_locks_acquired() self._ul.append(lock) @@ -447,6 +483,7 @@ class _KeyedLockContext: dec_debug_n_locks_acquired() self._ul = None + def get_internal_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" async_lock = _async_locks.get("internal_lock") if _is_multiprocess else None @@ -494,7 +531,10 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock: async_lock=async_lock, ) -def get_graph_db_lock_keyed(keys: str | list[str], enable_logging: bool = False) -> KeyedUnifiedLock: + +def get_graph_db_lock_keyed( + keys: str | list[str], enable_logging: bool = False +) -> KeyedUnifiedLock: """return unified graph database lock for ensuring atomic operations""" global _graph_db_lock_keyed if _graph_db_lock_keyed is None: @@ -503,6 +543,7 @@ def get_graph_db_lock_keyed(keys: str | list[str], enable_logging: bool = False) keys = [keys] return _graph_db_lock_keyed(keys, enable_logging=enable_logging) + def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock: """return unified data initialization lock for ensuring atomic data initialization""" async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None @@ -577,7 +618,7 @@ def initialize_share_data(workers: int = 1): _shared_dicts = _manager.dict() _init_flags = _manager.dict() _update_flags = _manager.dict() - + _graph_db_lock_keyed = KeyedUnifiedLock( factory_name="graph_db_lock", ) @@ -605,7 +646,7 @@ def initialize_share_data(workers: int = 1): _init_flags = {} _update_flags = {} _async_locks = None # No need for async locks in single process mode - + _graph_db_lock_keyed = KeyedUnifiedLock( factory_name="graph_db_lock", )