Refactor: Generalize keyed lock with namespace support

Refactored the `KeyedUnifiedLock` to be generic and support dynamic namespaces. This decouples the locking mechanism from a specific "GraphDB" implementation, allowing it to be reused across different components and workspaces safely.

Key changes:
- `KeyedUnifiedLock` now takes a `namespace` parameter on lock acquisition.
- Renamed `_graph_db_lock_keyed` to a more generic `_storage_keyed_lock`.
- Replaced `get_graph_db_lock_keyed` with `get_storage_keyed_lock` to support namespaces.
This commit is contained in:
yangdx 2025-07-12 12:10:12 +08:00
parent f2d875f8ab
commit 2ade3067f8
2 changed files with 73 additions and 59 deletions

View file

@ -81,7 +81,7 @@ _pipeline_status_lock: Optional[LockType] = None
_graph_db_lock: Optional[LockType] = None _graph_db_lock: Optional[LockType] = None
_data_init_lock: Optional[LockType] = None _data_init_lock: Optional[LockType] = None
# Manager for all keyed locks # Manager for all keyed locks
_graph_db_lock_keyed: Optional["KeyedUnifiedLock"] = None _storage_keyed_lock: Optional["KeyedUnifiedLock"] = None
# async locks for coroutine synchronization in multiprocess mode # async locks for coroutine synchronization in multiprocess mode
_async_locks: Optional[Dict[str, asyncio.Lock]] = None _async_locks: Optional[Dict[str, asyncio.Lock]] = None
@ -379,12 +379,12 @@ def _release_shared_raw_mp_lock(factory_name: str, key: str):
new_earliest_time = None new_earliest_time = None
# Perform cleanup while maintaining the new earliest time # Perform cleanup while maintaining the new earliest time
# Clean expired locks from all namespaces
for cleanup_key, cleanup_time in list(_lock_cleanup_data.items()): for cleanup_key, cleanup_time in list(_lock_cleanup_data.items()):
if ( if (
current_time - cleanup_time current_time - cleanup_time
> CLEANUP_KEYED_LOCKS_AFTER_SECONDS > CLEANUP_KEYED_LOCKS_AFTER_SECONDS
): ):
# Clean expired locks
_lock_registry.pop(cleanup_key, None) _lock_registry.pop(cleanup_key, None)
_lock_registry_count.pop(cleanup_key, None) _lock_registry_count.pop(cleanup_key, None)
_lock_cleanup_data.pop(cleanup_key, None) _lock_cleanup_data.pop(cleanup_key, None)
@ -433,15 +433,13 @@ class KeyedUnifiedLock:
Manager for unified keyed locks, supporting both single and multi-process Manager for unified keyed locks, supporting both single and multi-process
Keeps only a table of async keyed locks locally Keeps only a table of async keyed locks locally
Fetches the multi-process keyed lockon every acquire Fetches the multi-process keyed lock on every acquire
Builds a fresh `UnifiedLock` each time, so `enable_logging` Builds a fresh `UnifiedLock` each time, so `enable_logging`
(or future options) can vary per call. (or future options) can vary per call.
Supports dynamic namespaces specified at lock usage time
""" """
def __init__( def __init__(self, *, default_enable_logging: bool = True) -> None:
self, factory_name: str, *, default_enable_logging: bool = True
) -> None:
self._factory_name = factory_name
self._default_enable_logging = default_enable_logging self._default_enable_logging = default_enable_logging
self._async_lock: Dict[str, asyncio.Lock] = {} # local keyed locks self._async_lock: Dict[str, asyncio.Lock] = {} # local keyed locks
self._async_lock_count: Dict[ self._async_lock_count: Dict[
@ -460,41 +458,43 @@ class KeyedUnifiedLock:
None # track last async cleanup time for minimum interval None # track last async cleanup time for minimum interval
) )
def __call__(self, keys: list[str], *, enable_logging: Optional[bool] = None): def __call__(
self, namespace: str, keys: list[str], *, enable_logging: Optional[bool] = None
):
""" """
Ergonomic helper so you can write: Ergonomic helper so you can write:
async with keyed_locks("alpha"): async with storage_keyed_lock("namespace", ["key1", "key2"]):
... ...
""" """
if enable_logging is None: if enable_logging is None:
enable_logging = self._default_enable_logging enable_logging = self._default_enable_logging
return _KeyedLockContext( return _KeyedLockContext(
self, self,
factory_name=self._factory_name, namespace=namespace,
keys=keys, keys=keys,
enable_logging=enable_logging, enable_logging=enable_logging,
) )
def _get_or_create_async_lock(self, key: str) -> asyncio.Lock: def _get_or_create_async_lock(self, combined_key: str) -> asyncio.Lock:
async_lock = self._async_lock.get(key) async_lock = self._async_lock.get(combined_key)
count = self._async_lock_count.get(key, 0) count = self._async_lock_count.get(combined_key, 0)
if async_lock is None: if async_lock is None:
async_lock = asyncio.Lock() async_lock = asyncio.Lock()
self._async_lock[key] = async_lock self._async_lock[combined_key] = async_lock
elif count == 0 and key in self._async_lock_cleanup_data: elif count == 0 and combined_key in self._async_lock_cleanup_data:
self._async_lock_cleanup_data.pop(key) self._async_lock_cleanup_data.pop(combined_key)
count += 1 count += 1
self._async_lock_count[key] = count self._async_lock_count[combined_key] = count
return async_lock return async_lock
def _release_async_lock(self, key: str): def _release_async_lock(self, combined_key: str):
count = self._async_lock_count.get(key, 0) count = self._async_lock_count.get(combined_key, 0)
count -= 1 count -= 1
current_time = time.time() current_time = time.time()
if count == 0: if count == 0:
self._async_lock_cleanup_data[key] = current_time self._async_lock_cleanup_data[combined_key] = current_time
# Update earliest async cleanup time (only when earlier) # Update earliest async cleanup time (only when earlier)
if ( if (
@ -502,7 +502,7 @@ class KeyedUnifiedLock:
or current_time < self._earliest_async_cleanup_time or current_time < self._earliest_async_cleanup_time
): ):
self._earliest_async_cleanup_time = current_time self._earliest_async_cleanup_time = current_time
self._async_lock_count[key] = count self._async_lock_count[combined_key] = count
# Efficient cleanup triggering with minimum interval control # Efficient cleanup triggering with minimum interval control
total_cleanup_len = len(self._async_lock_cleanup_data) total_cleanup_len = len(self._async_lock_cleanup_data)
@ -538,6 +538,7 @@ class KeyedUnifiedLock:
new_earliest_time = None new_earliest_time = None
# Perform cleanup while maintaining the new earliest time # Perform cleanup while maintaining the new earliest time
# Clean expired async locks from all namespaces
for cleanup_key, cleanup_time in list( for cleanup_key, cleanup_time in list(
self._async_lock_cleanup_data.items() self._async_lock_cleanup_data.items()
): ):
@ -545,7 +546,6 @@ class KeyedUnifiedLock:
current_time - cleanup_time current_time - cleanup_time
> CLEANUP_KEYED_LOCKS_AFTER_SECONDS > CLEANUP_KEYED_LOCKS_AFTER_SECONDS
): ):
# Clean expired async locks
self._async_lock.pop(cleanup_key) self._async_lock.pop(cleanup_key)
self._async_lock_count.pop(cleanup_key) self._async_lock_count.pop(cleanup_key)
self._async_lock_cleanup_data.pop(cleanup_key) self._async_lock_cleanup_data.pop(cleanup_key)
@ -588,23 +588,28 @@ class KeyedUnifiedLock:
) )
# Don't update _last_async_cleanup_time to allow retry # Don't update _last_async_cleanup_time to allow retry
def _get_lock_for_key(self, key: str, enable_logging: bool = False) -> UnifiedLock: def _get_lock_for_key(
# 1. get (or create) the perprocess async gate for this key self, namespace: str, key: str, enable_logging: bool = False
# Is synchronous, so no need to acquire a lock ) -> UnifiedLock:
async_lock = self._get_or_create_async_lock(key) # 1. Create combined key for this namespace:key combination
combined_key = _get_combined_key(namespace, key)
# 2. fetch the shared raw lock # 2. get (or create) the perprocess async gate for this combined key
raw_lock = _get_or_create_shared_raw_mp_lock(self._factory_name, key) # Is synchronous, so no need to acquire a lock
async_lock = self._get_or_create_async_lock(combined_key)
# 3. fetch the shared raw lock
raw_lock = _get_or_create_shared_raw_mp_lock(namespace, key)
is_multiprocess = raw_lock is not None is_multiprocess = raw_lock is not None
if not is_multiprocess: if not is_multiprocess:
raw_lock = async_lock raw_lock = async_lock
# 3. build a *fresh* UnifiedLock with the chosen logging flag # 4. build a *fresh* UnifiedLock with the chosen logging flag
if is_multiprocess: if is_multiprocess:
return UnifiedLock( return UnifiedLock(
lock=raw_lock, lock=raw_lock,
is_async=False, # manager.Lock is synchronous is_async=False, # manager.Lock is synchronous
name=_get_combined_key(self._factory_name, key), name=combined_key,
enable_logging=enable_logging, enable_logging=enable_logging,
async_lock=async_lock, # prevents eventloop blocking async_lock=async_lock, # prevents eventloop blocking
) )
@ -612,26 +617,27 @@ class KeyedUnifiedLock:
return UnifiedLock( return UnifiedLock(
lock=raw_lock, lock=raw_lock,
is_async=True, is_async=True,
name=_get_combined_key(self._factory_name, key), name=combined_key,
enable_logging=enable_logging, enable_logging=enable_logging,
async_lock=None, # No need for async lock in single process mode async_lock=None, # No need for async lock in single process mode
) )
def _release_lock_for_key(self, key: str): def _release_lock_for_key(self, namespace: str, key: str):
self._release_async_lock(key) combined_key = _get_combined_key(namespace, key)
_release_shared_raw_mp_lock(self._factory_name, key) self._release_async_lock(combined_key)
_release_shared_raw_mp_lock(namespace, key)
class _KeyedLockContext: class _KeyedLockContext:
def __init__( def __init__(
self, self,
parent: KeyedUnifiedLock, parent: KeyedUnifiedLock,
factory_name: str, namespace: str,
keys: list[str], keys: list[str],
enable_logging: bool, enable_logging: bool,
) -> None: ) -> None:
self._parent = parent self._parent = parent
self._factory_name = factory_name self._namespace = namespace
# The sorting is critical to ensure proper lock and release order # The sorting is critical to ensure proper lock and release order
# to avoid deadlocks # to avoid deadlocks
@ -648,23 +654,23 @@ class _KeyedLockContext:
if self._ul is not None: if self._ul is not None:
raise RuntimeError("KeyedUnifiedLock already acquired in current context") raise RuntimeError("KeyedUnifiedLock already acquired in current context")
# 4. acquire it # acquire locks for all keys in the namespace
self._ul = [] self._ul = []
for key in self._keys: for key in self._keys:
lock = self._parent._get_lock_for_key( lock = self._parent._get_lock_for_key(
key, enable_logging=self._enable_logging self._namespace, key, enable_logging=self._enable_logging
) )
await lock.__aenter__() await lock.__aenter__()
inc_debug_n_locks_acquired() inc_debug_n_locks_acquired()
self._ul.append(lock) self._ul.append(lock)
return self # or return self._key if you prefer return self
# ----- exit ----- # ----- exit -----
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
# The UnifiedLock takes care of proper release order # The UnifiedLock takes care of proper release order
for ul, key in zip(reversed(self._ul), reversed(self._keys)): for ul, key in zip(reversed(self._ul), reversed(self._keys)):
await ul.__aexit__(exc_type, exc, tb) await ul.__aexit__(exc_type, exc, tb)
self._parent._release_lock_for_key(key) self._parent._release_lock_for_key(self._namespace, key)
dec_debug_n_locks_acquired() dec_debug_n_locks_acquired()
self._ul = None self._ul = None
@ -717,16 +723,16 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
) )
def get_graph_db_lock_keyed( def get_storage_keyed_lock(
keys: str | list[str], enable_logging: bool = False keys: str | list[str], namespace: str = "default", enable_logging: bool = False
) -> KeyedUnifiedLock: ) -> _KeyedLockContext:
"""return unified graph database lock for ensuring atomic operations""" """Return unified storage keyed lock for ensuring atomic operations across different namespaces"""
global _graph_db_lock_keyed global _storage_keyed_lock
if _graph_db_lock_keyed is None: if _storage_keyed_lock is None:
raise RuntimeError("Shared-Data is not initialized") raise RuntimeError("Shared-Data is not initialized")
if isinstance(keys, str): if isinstance(keys, str):
keys = [keys] keys = [keys]
return _graph_db_lock_keyed(keys, enable_logging=enable_logging) return _storage_keyed_lock(namespace, keys, enable_logging=enable_logging)
def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock: def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
@ -777,7 +783,7 @@ def initialize_share_data(workers: int = 1):
_initialized, \ _initialized, \
_update_flags, \ _update_flags, \
_async_locks, \ _async_locks, \
_graph_db_lock_keyed, \ _storage_keyed_lock, \
_earliest_mp_cleanup_time, \ _earliest_mp_cleanup_time, \
_last_mp_cleanup_time _last_mp_cleanup_time
@ -806,9 +812,7 @@ def initialize_share_data(workers: int = 1):
_init_flags = _manager.dict() _init_flags = _manager.dict()
_update_flags = _manager.dict() _update_flags = _manager.dict()
_graph_db_lock_keyed = KeyedUnifiedLock( _storage_keyed_lock = KeyedUnifiedLock()
factory_name="GraphDB",
)
# Initialize async locks for multiprocess mode # Initialize async locks for multiprocess mode
_async_locks = { _async_locks = {
@ -834,9 +838,7 @@ def initialize_share_data(workers: int = 1):
_update_flags = {} _update_flags = {}
_async_locks = None # No need for async locks in single process mode _async_locks = None # No need for async locks in single process mode
_graph_db_lock_keyed = KeyedUnifiedLock( _storage_keyed_lock = KeyedUnifiedLock()
factory_name="GraphDB",
)
direct_log(f"Process {os.getpid()} Shared-Data created for Single Process") direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
# Initialize multiprocess cleanup times # Initialize multiprocess cleanup times

View file

@ -37,7 +37,7 @@ from .base import (
) )
from .prompt import PROMPTS from .prompt import PROMPTS
from .constants import GRAPH_FIELD_SEP from .constants import GRAPH_FIELD_SEP
from .kg.shared_storage import get_graph_db_lock_keyed from .kg.shared_storage import get_storage_keyed_lock
import time import time
from dotenv import load_dotenv from dotenv import load_dotenv
@ -1019,7 +1019,11 @@ async def _merge_edges_then_upsert(
if await knowledge_graph_inst.has_node(need_insert_id): if await knowledge_graph_inst.has_node(need_insert_id):
# This is so that the initial check for the existence of the node need not be locked # This is so that the initial check for the existence of the node need not be locked
continue continue
async with get_graph_db_lock_keyed([need_insert_id], enable_logging=False): workspace = global_config.get("workspace", "")
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
async with get_storage_keyed_lock(
[need_insert_id], namespace=namespace, enable_logging=False
):
if not (await knowledge_graph_inst.has_node(need_insert_id)): if not (await knowledge_graph_inst.has_node(need_insert_id)):
# # Discard this edge if the node does not exist # # Discard this edge if the node does not exist
# if need_insert_id == src_id: # if need_insert_id == src_id:
@ -1162,7 +1166,11 @@ async def merge_nodes_and_edges(
async def _locked_process_entity_name(entity_name, entities): async def _locked_process_entity_name(entity_name, entities):
async with semaphore: async with semaphore:
async with get_graph_db_lock_keyed([entity_name], enable_logging=False): workspace = global_config.get("workspace", "")
namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
async with get_storage_keyed_lock(
[entity_name], namespace=namespace, enable_logging=False
):
entity_data = await _merge_nodes_then_upsert( entity_data = await _merge_nodes_then_upsert(
entity_name, entity_name,
entities, entities,
@ -1187,8 +1195,12 @@ async def merge_nodes_and_edges(
async def _locked_process_edges(edge_key, edges): async def _locked_process_edges(edge_key, edges):
async with semaphore: async with semaphore:
async with get_graph_db_lock_keyed( workspace = global_config.get("workspace", "")
f"{edge_key[0]}-{edge_key[1]}", enable_logging=False namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
async with get_storage_keyed_lock(
f"{edge_key[0]}-{edge_key[1]}",
namespace=namespace,
enable_logging=False,
): ):
edge_data = await _merge_edges_then_upsert( edge_data = await _merge_edges_then_upsert(
edge_key[0], edge_key[0],