Merge pull request #1770 from danielaskdd/merge_lock_with_key
Refac: Optimize keyed lock cleanup logic with time and size tracking
Commit: ad7d7d0854
3 changed files with 700 additions and 194 deletions
@@ -1,23 +1,46 @@
 import os
 import sys
 import asyncio
 import multiprocessing as mp
 from multiprocessing.synchronize import Lock as ProcessLock
 from multiprocessing import Manager
-from typing import Any, Dict, Optional, Union, TypeVar, Generic
+import time
+import logging
+from typing import Any, Dict, List, Optional, Union, TypeVar, Generic


 # Define a direct print function for critical logs that must be visible in all processes
-def direct_log(message, level="INFO", enable_output: bool = True):
+def direct_log(message, enable_output: bool = True, level: str = "DEBUG"):
     """
     Log a message directly to stderr to ensure visibility in all processes,
     including the Gunicorn master process.

     Args:
         message: The message to log
-        level: Log level (default: "INFO")
+        level: Log level (default: "DEBUG")
         enable_output: Whether to actually output the log (default: True)
     """
     if enable_output:
-        print(f"{level}: {message}", file=sys.stderr, flush=True)
+        # Get the current logger level from the lightrag logger
+        try:
+            from lightrag.utils import logger
+
+            current_level = logger.getEffectiveLevel()
+        except ImportError:
+            # Fallback if lightrag.utils is not available
+            current_level = logging.INFO
+
+        # Convert string level to numeric level for comparison
+        level_mapping = {
+            "DEBUG": logging.DEBUG,  # 10
+            "INFO": logging.INFO,  # 20
+            "WARNING": logging.WARNING,  # 30
+            "ERROR": logging.ERROR,  # 40
+            "CRITICAL": logging.CRITICAL,  # 50
+        }
+        message_level = level_mapping.get(level.upper(), logging.DEBUG)
+
+        # print(f"Direct_log: {level.upper()} {message_level} ? {current_level}", file=sys.stderr, flush=True)
+        if enable_output or (message_level >= current_level):
+            print(f"{level}: {message}", file=sys.stderr, flush=True)
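A quick usage sketch of the reworked direct_log (illustrative, not part of the diff): enable_output gates all output, while the level string is mapped to a numeric value and compared against the lightrag logger's effective level.

```python
from lightrag.kg.shared_storage import direct_log

direct_log("lock registry swept", level="WARNING")               # printed to stderr
direct_log("verbose trace", level="DEBUG", enable_output=False)  # suppressed entirely
```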
@@ -27,6 +50,23 @@ LockType = Union[ProcessLock, asyncio.Lock]
 _is_multiprocess = None
 _workers = None
 _manager = None

+# Global singleton data for multi-process keyed locks
+_lock_registry: Optional[Dict[str, mp.synchronize.Lock]] = None
+_lock_registry_count: Optional[Dict[str, int]] = None
+_lock_cleanup_data: Optional[Dict[str, float]] = None
+_registry_guard = None
+# Timeout for keyed locks in seconds
+CLEANUP_KEYED_LOCKS_AFTER_SECONDS = 300
+# Threshold for triggering cleanup - only clean when the pending list exceeds this size
+CLEANUP_THRESHOLD = 500
+# Minimum interval between cleanup operations in seconds
+MIN_CLEANUP_INTERVAL_SECONDS = 30
+# Track the earliest cleanup time for efficient cleanup triggering (multiprocess locks only)
+_earliest_mp_cleanup_time: Optional[float] = None
+# Track the last cleanup time to enforce minimum interval (multiprocess locks only)
+_last_mp_cleanup_time: Optional[float] = None

 _initialized = None

 # shared data for storage across processes
@@ -40,10 +80,37 @@ _internal_lock: Optional[LockType] = None
 _pipeline_status_lock: Optional[LockType] = None
 _graph_db_lock: Optional[LockType] = None
 _data_init_lock: Optional[LockType] = None
+# Manager for all keyed locks
+_graph_db_lock_keyed: Optional["KeyedUnifiedLock"] = None

 # async locks for coroutine synchronization in multiprocess mode
 _async_locks: Optional[Dict[str, asyncio.Lock]] = None

+DEBUG_LOCKS = False
+_debug_n_locks_acquired: int = 0
+
+
+def inc_debug_n_locks_acquired():
+    global _debug_n_locks_acquired
+    if DEBUG_LOCKS:
+        _debug_n_locks_acquired += 1
+        print(f"DEBUG: Keyed Lock acquired, total: {_debug_n_locks_acquired:>5}")
+
+
+def dec_debug_n_locks_acquired():
+    global _debug_n_locks_acquired
+    if DEBUG_LOCKS:
+        if _debug_n_locks_acquired > 0:
+            _debug_n_locks_acquired -= 1
+            print(f"DEBUG: Keyed Lock released, total: {_debug_n_locks_acquired:>5}")
+        else:
+            raise RuntimeError("Attempting to release lock when no locks are acquired")
+
+
+def get_debug_n_locks_acquired():
+    global _debug_n_locks_acquired
+    return _debug_n_locks_acquired
+
+
 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
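A minimal sketch (not part of the diff) of how the paired debug counters above surface unbalanced acquire/release calls once DEBUG_LOCKS is flipped on:

```python
import lightrag.kg.shared_storage as ss

ss.DEBUG_LOCKS = True
ss.inc_debug_n_locks_acquired()   # prints "... acquired, total:     1"
ss.dec_debug_n_locks_acquired()   # prints "... released, total:     0"
assert ss.get_debug_n_locks_acquired() == 0
ss.dec_debug_n_locks_acquired()   # raises RuntimeError: nothing left to release
```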
@@ -65,17 +132,8 @@ class UnifiedLock(Generic[T]):

     async def __aenter__(self) -> "UnifiedLock[T]":
         try:
-            # direct_log(
-            #     f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
-            #     enable_output=self._enable_logging,
-            # )
-
             # If in multiprocess mode and async lock exists, acquire it first
             if not self._is_async and self._async_lock is not None:
-                # direct_log(
-                #     f"== Lock == Process {self._pid}: Acquiring async lock for '{self._name}'",
-                #     enable_output=self._enable_logging,
-                # )
                 await self._async_lock.acquire()
                 direct_log(
                     f"== Lock == Process {self._pid}: Async lock for '{self._name}' acquired",
@@ -210,6 +268,406 @@ class UnifiedLock(Generic[T]):
             )
             raise

+    def locked(self) -> bool:
+        if self._is_async:
+            return self._lock.locked()
+        else:
+            return self._lock.locked()
+
+
+def _get_combined_key(factory_name: str, key: str) -> str:
+    """Return the combined key for the factory and key."""
+    return f"{factory_name}:{key}"
+
+
+def _get_or_create_shared_raw_mp_lock(
+    factory_name: str, key: str
+) -> Optional[mp.synchronize.Lock]:
+    """Return the *singleton* manager.Lock() proxy for a keyed lock, creating it if needed."""
+    if not _is_multiprocess:
+        return None
+
+    with _registry_guard:
+        combined_key = _get_combined_key(factory_name, key)
+        raw = _lock_registry.get(combined_key)
+        count = _lock_registry_count.get(combined_key)
+        if raw is None:
+            raw = _manager.Lock()
+            _lock_registry[combined_key] = raw
+            count = 0
+        else:
+            if count is None:
+                raise RuntimeError(
+                    f"Shared-Data lock registry for {factory_name} is corrupted for key {key}"
+                )
+            if (
+                count == 1 and combined_key in _lock_cleanup_data
+            ):  # Reusing a key waiting for cleanup, remove it from the cleanup list
+                _lock_cleanup_data.pop(combined_key)
+        count += 1
+        _lock_registry_count[combined_key] = count
+        return raw
+def _release_shared_raw_mp_lock(factory_name: str, key: str):
+    """Release the *singleton* manager.Lock() proxy for *key*."""
+    if not _is_multiprocess:
+        return
+
+    global _earliest_mp_cleanup_time, _last_mp_cleanup_time
+
+    with _registry_guard:
+        combined_key = _get_combined_key(factory_name, key)
+        raw = _lock_registry.get(combined_key)
+        count = _lock_registry_count.get(combined_key)
+        if raw is None and count is None:
+            return
+        elif raw is None or count is None:
+            raise RuntimeError(
+                f"Shared-Data lock registry for {factory_name} is corrupted for key {key}"
+            )
+
+        count -= 1
+        if count < 0:
+            raise RuntimeError(
+                f"Attempting to release lock for {key} more times than it was acquired"
+            )
+
+        _lock_registry_count[combined_key] = count
+
+        current_time = time.time()
+        if count == 0:
+            _lock_cleanup_data[combined_key] = current_time
+
+            # Update earliest multiprocess cleanup time (only when earlier)
+            if (
+                _earliest_mp_cleanup_time is None
+                or current_time < _earliest_mp_cleanup_time
+            ):
+                _earliest_mp_cleanup_time = current_time
+
+        # Efficient cleanup triggering with minimum interval control
+        total_cleanup_len = len(_lock_cleanup_data)
+        if total_cleanup_len >= CLEANUP_THRESHOLD:
+            # Time rollback detection
+            if (
+                _last_mp_cleanup_time is not None
+                and current_time < _last_mp_cleanup_time
+            ):
+                direct_log(
+                    "== mp Lock == Time rollback detected, resetting cleanup time",
+                    level="WARNING",
+                    enable_output=False,
+                )
+                _last_mp_cleanup_time = None
+
+            # Check cleanup conditions
+            has_expired_locks = (
+                _earliest_mp_cleanup_time is not None
+                and current_time - _earliest_mp_cleanup_time
+                > CLEANUP_KEYED_LOCKS_AFTER_SECONDS
+            )
+
+            interval_satisfied = (
+                _last_mp_cleanup_time is None
+                or current_time - _last_mp_cleanup_time > MIN_CLEANUP_INTERVAL_SECONDS
+            )
+
+            if has_expired_locks and interval_satisfied:
+                try:
+                    cleaned_count = 0
+                    new_earliest_time = None
+
+                    # Perform cleanup while maintaining the new earliest time
+                    for cleanup_key, cleanup_time in list(_lock_cleanup_data.items()):
+                        if (
+                            current_time - cleanup_time
+                            > CLEANUP_KEYED_LOCKS_AFTER_SECONDS
+                        ):
+                            # Clean expired locks
+                            _lock_registry.pop(cleanup_key, None)
+                            _lock_registry_count.pop(cleanup_key, None)
+                            _lock_cleanup_data.pop(cleanup_key, None)
+                            cleaned_count += 1
+                        else:
+                            # Track the earliest time among remaining locks
+                            if (
+                                new_earliest_time is None
+                                or cleanup_time < new_earliest_time
+                            ):
+                                new_earliest_time = cleanup_time
+
+                    # Update state only after successful cleanup
+                    _earliest_mp_cleanup_time = new_earliest_time
+                    _last_mp_cleanup_time = current_time
+
+                    if cleaned_count > 0:
+                        next_cleanup_in = max(
+                            (
+                                new_earliest_time
+                                + CLEANUP_KEYED_LOCKS_AFTER_SECONDS
+                                - current_time
+                            )
+                            if new_earliest_time
+                            else float("inf"),
+                            MIN_CLEANUP_INTERVAL_SECONDS,
+                        )
+                        direct_log(
+                            f"== mp Lock == Cleaned up {cleaned_count}/{total_cleanup_len} expired locks, "
+                            f"next cleanup in {next_cleanup_in:.1f}s",
+                            enable_output=False,
+                            level="INFO",
+                        )
+
+                except Exception as e:
+                    direct_log(
+                        f"== mp Lock == Cleanup failed: {e}",
+                        level="ERROR",
+                        enable_output=False,
+                    )
+                    # Don't update _last_mp_cleanup_time to allow retry
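The release path above only sweeps when three gates all pass: the pending list has reached CLEANUP_THRESHOLD, the oldest zero-refcount entry is older than CLEANUP_KEYED_LOCKS_AFTER_SECONDS, and at least MIN_CLEANUP_INTERVAL_SECONDS has elapsed since the last sweep. A standalone sketch of that policy, with hypothetical names mirroring the module constants:

```python
import time

CLEANUP_AFTER_S = 300   # mirrors CLEANUP_KEYED_LOCKS_AFTER_SECONDS
THRESHOLD = 500         # mirrors CLEANUP_THRESHOLD
MIN_INTERVAL_S = 30     # mirrors MIN_CLEANUP_INTERVAL_SECONDS

def should_sweep(pending: dict[str, float], earliest: float | None,
                 last_sweep: float | None, now: float | None = None) -> bool:
    """Return True when a cleanup sweep is allowed to run."""
    now = time.time() if now is None else now
    if len(pending) < THRESHOLD:                                   # size gate
        return False
    has_expired = earliest is not None and now - earliest > CLEANUP_AFTER_S
    interval_ok = last_sweep is None or now - last_sweep > MIN_INTERVAL_S
    return has_expired and interval_ok

# Example: 600 idle keys whose oldest entry expired six minutes ago
pending = {f"GraphDB:k{i}": 1000.0 for i in range(600)}
print(should_sweep(pending, earliest=1000.0, last_sweep=None, now=1360.0))  # True
```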
+class KeyedUnifiedLock:
+    """
+    Manager for unified keyed locks, supporting both single- and multi-process modes.
+
+    • Keeps only a table of async keyed locks locally
+    • Fetches the multi-process keyed lock on every acquire
+    • Builds a fresh `UnifiedLock` each time, so `enable_logging`
+      (or future options) can vary per call.
+    """
+
+    def __init__(
+        self, factory_name: str, *, default_enable_logging: bool = True
+    ) -> None:
+        self._factory_name = factory_name
+        self._default_enable_logging = default_enable_logging
+        self._async_lock: Dict[str, asyncio.Lock] = {}  # local keyed locks
+        self._async_lock_count: Dict[
+            str, int
+        ] = {}  # local keyed lock reference counts
+        self._async_lock_cleanup_data: Dict[
+            str, float
+        ] = {}  # local keyed lock timeouts
+        self._mp_locks: Dict[
+            str, mp.synchronize.Lock
+        ] = {}  # multi-process lock proxies
+        self._earliest_async_cleanup_time: Optional[float] = (
+            None  # track earliest async cleanup time
+        )
+        self._last_async_cleanup_time: Optional[float] = (
+            None  # track last async cleanup time for minimum interval
+        )
+
+    def __call__(self, keys: list[str], *, enable_logging: Optional[bool] = None):
+        """
+        Ergonomic helper so you can write:
+
+            async with keyed_locks(["alpha"]):
+                ...
+        """
+        if enable_logging is None:
+            enable_logging = self._default_enable_logging
+        return _KeyedLockContext(
+            self,
+            factory_name=self._factory_name,
+            keys=keys,
+            enable_logging=enable_logging,
+        )
+
+    def _get_or_create_async_lock(self, key: str) -> asyncio.Lock:
+        async_lock = self._async_lock.get(key)
+        count = self._async_lock_count.get(key, 0)
+        if async_lock is None:
+            async_lock = asyncio.Lock()
+            self._async_lock[key] = async_lock
+        elif count == 0 and key in self._async_lock_cleanup_data:
+            self._async_lock_cleanup_data.pop(key)
+        count += 1
+        self._async_lock_count[key] = count
+        return async_lock
+
+    def _release_async_lock(self, key: str):
+        count = self._async_lock_count.get(key, 0)
+        count -= 1
+
+        current_time = time.time()
+        if count == 0:
+            self._async_lock_cleanup_data[key] = current_time
+
+            # Update earliest async cleanup time (only when earlier)
+            if (
+                self._earliest_async_cleanup_time is None
+                or current_time < self._earliest_async_cleanup_time
+            ):
+                self._earliest_async_cleanup_time = current_time
+        self._async_lock_count[key] = count
+
+        # Efficient cleanup triggering with minimum interval control
+        total_cleanup_len = len(self._async_lock_cleanup_data)
+        if total_cleanup_len >= CLEANUP_THRESHOLD:
+            # Time rollback detection
+            if (
+                self._last_async_cleanup_time is not None
+                and current_time < self._last_async_cleanup_time
+            ):
+                direct_log(
+                    "== async Lock == Time rollback detected, resetting cleanup time",
+                    level="WARNING",
+                    enable_output=False,
+                )
+                self._last_async_cleanup_time = None
+
+            # Check cleanup conditions
+            has_expired_locks = (
+                self._earliest_async_cleanup_time is not None
+                and current_time - self._earliest_async_cleanup_time
+                > CLEANUP_KEYED_LOCKS_AFTER_SECONDS
+            )
+
+            interval_satisfied = (
+                self._last_async_cleanup_time is None
+                or current_time - self._last_async_cleanup_time
+                > MIN_CLEANUP_INTERVAL_SECONDS
+            )
+
+            if has_expired_locks and interval_satisfied:
+                try:
+                    cleaned_count = 0
+                    new_earliest_time = None
+
+                    # Perform cleanup while maintaining the new earliest time
+                    for cleanup_key, cleanup_time in list(
+                        self._async_lock_cleanup_data.items()
+                    ):
+                        if (
+                            current_time - cleanup_time
+                            > CLEANUP_KEYED_LOCKS_AFTER_SECONDS
+                        ):
+                            # Clean expired async locks
+                            self._async_lock.pop(cleanup_key)
+                            self._async_lock_count.pop(cleanup_key)
+                            self._async_lock_cleanup_data.pop(cleanup_key)
+                            cleaned_count += 1
+                        else:
+                            # Track the earliest time among remaining locks
+                            if (
+                                new_earliest_time is None
+                                or cleanup_time < new_earliest_time
+                            ):
+                                new_earliest_time = cleanup_time
+
+                    # Update state only after successful cleanup
+                    self._earliest_async_cleanup_time = new_earliest_time
+                    self._last_async_cleanup_time = current_time
+
+                    if cleaned_count > 0:
+                        next_cleanup_in = max(
+                            (
+                                new_earliest_time
+                                + CLEANUP_KEYED_LOCKS_AFTER_SECONDS
+                                - current_time
+                            )
+                            if new_earliest_time
+                            else float("inf"),
+                            MIN_CLEANUP_INTERVAL_SECONDS,
+                        )
+                        direct_log(
+                            f"== async Lock == Cleaned up {cleaned_count}/{total_cleanup_len} expired async locks, "
+                            f"next cleanup in {next_cleanup_in:.1f}s",
+                            enable_output=False,
+                            level="INFO",
+                        )
+
+                except Exception as e:
+                    direct_log(
+                        f"== async Lock == Cleanup failed: {e}",
+                        level="ERROR",
+                        enable_output=False,
+                    )
+                    # Don't update _last_async_cleanup_time to allow retry
+
+    def _get_lock_for_key(self, key: str, enable_logging: bool = False) -> UnifiedLock:
+        # 1. get (or create) the per-process async gate for this key.
+        #    This is synchronous, so no extra locking is needed
+        async_lock = self._get_or_create_async_lock(key)
+
+        # 2. fetch the shared raw lock
+        raw_lock = _get_or_create_shared_raw_mp_lock(self._factory_name, key)
+        is_multiprocess = raw_lock is not None
+        if not is_multiprocess:
+            raw_lock = async_lock
+
+        # 3. build a *fresh* UnifiedLock with the chosen logging flag
+        if is_multiprocess:
+            return UnifiedLock(
+                lock=raw_lock,
+                is_async=False,  # manager.Lock is synchronous
+                name=_get_combined_key(self._factory_name, key),
+                enable_logging=enable_logging,
+                async_lock=async_lock,  # prevents event-loop blocking
+            )
+        else:
+            return UnifiedLock(
+                lock=raw_lock,
+                is_async=True,
+                name=_get_combined_key(self._factory_name, key),
+                enable_logging=enable_logging,
+                async_lock=None,  # No need for an async lock in single-process mode
+            )
+
+    def _release_lock_for_key(self, key: str):
+        self._release_async_lock(key)
+        _release_shared_raw_mp_lock(self._factory_name, key)
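A minimal usage sketch of the manager above (assuming initialize_share_data() has already run; the instance name keyed_locks and the factory name "Demo" are illustrative):

```python
import asyncio
from lightrag.kg.shared_storage import KeyedUnifiedLock, initialize_share_data

async def main():
    initialize_share_data(workers=1)  # single-process mode: pure asyncio locks
    keyed_locks = KeyedUnifiedLock(factory_name="Demo", default_enable_logging=False)

    async def worker(name: str, key: str):
        # Only coroutines contending on the same key serialize; "a" and "b" run freely.
        async with keyed_locks([key]):
            print(f"{name} holds {key}")
            await asyncio.sleep(0.1)

    await asyncio.gather(worker("w1", "a"), worker("w2", "a"), worker("w3", "b"))

asyncio.run(main())
```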
+class _KeyedLockContext:
+    def __init__(
+        self,
+        parent: KeyedUnifiedLock,
+        factory_name: str,
+        keys: list[str],
+        enable_logging: bool,
+    ) -> None:
+        self._parent = parent
+        self._factory_name = factory_name
+
+        # Sorting is critical: it enforces a consistent acquire and release
+        # order across callers and thereby avoids deadlocks
+        self._keys = sorted(keys)
+        self._enable_logging = (
+            enable_logging
+            if enable_logging is not None
+            else parent._default_enable_logging
+        )
+        self._ul: Optional[List["UnifiedLock"]] = None  # set in __aenter__
+
+    # ----- enter -----
+    async def __aenter__(self):
+        if self._ul is not None:
+            raise RuntimeError("KeyedUnifiedLock already acquired in current context")
+
+        # Acquire every per-key lock, in sorted key order
+        self._ul = []
+        for key in self._keys:
+            lock = self._parent._get_lock_for_key(
+                key, enable_logging=self._enable_logging
+            )
+            await lock.__aenter__()
+            inc_debug_n_locks_acquired()
+            self._ul.append(lock)
+        return self
+
+    # ----- exit -----
+    async def __aexit__(self, exc_type, exc, tb):
+        # The UnifiedLock takes care of proper release order
+        for ul, key in zip(reversed(self._ul), reversed(self._keys)):
+            await ul.__aexit__(exc_type, exc, tb)
+            self._parent._release_lock_for_key(key)
+            dec_debug_n_locks_acquired()
+        self._ul = None
+
+
 def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
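Why the sorted() in the context manager matters: if two coroutines acquired the same pair of keys in opposite orders, each could block waiting on the other. Sorting gives every holder the same global acquisition order, as this small sketch with hypothetical keys shows:

```python
# Without sorting: task 1 acquires "b" then waits for "a",
# task 2 acquires "a" then waits for "b" -> deadlock.
# With sorted(keys), both tasks acquire "a" before "b", so one
# simply queues behind the other and the pair never cross-blocks.
t1_keys = sorted(["b", "a"])  # -> ["a", "b"]
t2_keys = sorted(["a", "b"])  # -> ["a", "b"]
assert t1_keys == t2_keys     # identical global acquisition order
```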
@@ -259,6 +717,18 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
     )


+def get_graph_db_lock_keyed(
+    keys: str | list[str], enable_logging: bool = False
+) -> KeyedUnifiedLock:
+    """return unified graph database lock for ensuring atomic operations"""
+    global _graph_db_lock_keyed
+    if _graph_db_lock_keyed is None:
+        raise RuntimeError("Shared-Data is not initialized")
+    if isinstance(keys, str):
+        keys = [keys]
+    return _graph_db_lock_keyed(keys, enable_logging=enable_logging)
+
+
 def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified data initialization lock for ensuring atomic data initialization"""
     async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None
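Typical call-site shape for the new accessor, mirroring how the operate module uses it (the entity names are illustrative; initialize_share_data() must have run first):

```python
from lightrag.kg.shared_storage import get_graph_db_lock_keyed

async def update_entities(entity_a: str, entity_b: str):
    # A single string or a list of keys is accepted; both entities are
    # locked together, in sorted order, for the duration of the block.
    async with get_graph_db_lock_keyed([entity_a, entity_b], enable_logging=False):
        ...  # read-modify-write the two graph nodes atomically
```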
@@ -294,6 +764,10 @@ def initialize_share_data(workers: int = 1):
         _workers, \
         _is_multiprocess, \
         _storage_lock, \
+        _lock_registry, \
+        _lock_registry_count, \
+        _lock_cleanup_data, \
+        _registry_guard, \
         _internal_lock, \
         _pipeline_status_lock, \
         _graph_db_lock, \

@@ -302,7 +776,10 @@ def initialize_share_data(workers: int = 1):
         _init_flags, \
         _initialized, \
         _update_flags, \
-        _async_locks
+        _async_locks, \
+        _graph_db_lock_keyed, \
+        _earliest_mp_cleanup_time, \
+        _last_mp_cleanup_time

     # Check if already initialized
     if _initialized:
@@ -316,6 +793,10 @@ def initialize_share_data(workers: int = 1):
     if workers > 1:
         _is_multiprocess = True
         _manager = Manager()
+        _lock_registry = _manager.dict()
+        _lock_registry_count = _manager.dict()
+        _lock_cleanup_data = _manager.dict()
+        _registry_guard = _manager.RLock()
         _internal_lock = _manager.Lock()
         _storage_lock = _manager.Lock()
         _pipeline_status_lock = _manager.Lock()

@@ -325,6 +806,10 @@ def initialize_share_data(workers: int = 1):
         _init_flags = _manager.dict()
         _update_flags = _manager.dict()

+        _graph_db_lock_keyed = KeyedUnifiedLock(
+            factory_name="GraphDB",
+        )
+
         # Initialize async locks for multiprocess mode
         _async_locks = {
             "internal_lock": asyncio.Lock(),

@@ -348,8 +833,16 @@ def initialize_share_data(workers: int = 1):
         _init_flags = {}
         _update_flags = {}
         _async_locks = None  # No need for async locks in single process mode
+
+        _graph_db_lock_keyed = KeyedUnifiedLock(
+            factory_name="GraphDB",
+        )
         direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")

+    # Initialize multiprocess cleanup times
+    _earliest_mp_cleanup_time = None
+    _last_mp_cleanup_time = None
+
     # Mark as initialized
     _initialized = True
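How initialization looks from the caller's side (a sketch; the worker count would normally come from the server configuration):

```python
from lightrag.kg.shared_storage import initialize_share_data

# Gunicorn-style deployment: one Manager-backed registry shared by all workers.
initialize_share_data(workers=4)   # multiprocess mode: manager dicts + RLock guard

# Plain single-process usage: everything stays in local dicts and asyncio locks.
# initialize_share_data()          # workers defaults to 1
```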
@@ -1094,86 +1094,89 @@ class LightRAG:
                         }
                     )

-            # Semaphore released, concurrency controlled by graph_db_lock in merge_nodes_and_edges instead
-
-            if file_extraction_stage_ok:
-                try:
-                    # Get chunk_results from entity_relation_task
-                    chunk_results = await entity_relation_task
-                    await merge_nodes_and_edges(
-                        chunk_results=chunk_results,  # result collected from entity_relation_task
-                        knowledge_graph_inst=self.chunk_entity_relation_graph,
-                        entity_vdb=self.entities_vdb,
-                        relationships_vdb=self.relationships_vdb,
-                        global_config=asdict(self),
-                        pipeline_status=pipeline_status,
-                        pipeline_status_lock=pipeline_status_lock,
-                        llm_response_cache=self.llm_response_cache,
-                        current_file_number=current_file_number,
-                        total_files=total_files,
-                        file_path=file_path,
-                    )
-
-                    await self.doc_status.upsert(
-                        {
-                            doc_id: {
-                                "status": DocStatus.PROCESSED,
-                                "chunks_count": len(chunks),
-                                "chunks_list": list(
-                                    chunks.keys()
-                                ),  # keep chunks_list
-                                "content": status_doc.content,
-                                "content_summary": status_doc.content_summary,
-                                "content_length": status_doc.content_length,
-                                "created_at": status_doc.created_at,
-                                "updated_at": datetime.now(
-                                    timezone.utc
-                                ).isoformat(),
-                                "file_path": file_path,
-                            }
-                        }
-                    )
-
-                    # Call _insert_done after processing each file
-                    await self._insert_done()
-
-                    async with pipeline_status_lock:
-                        log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
-                        logger.info(log_message)
-                        pipeline_status["latest_message"] = log_message
-                        pipeline_status["history_messages"].append(log_message)
-
-                except Exception as e:
-                    # Log error and update pipeline status
-                    logger.error(traceback.format_exc())
-                    error_msg = f"Merging stage failed in document {current_file_number}/{total_files}: {file_path}"
-                    logger.error(error_msg)
-                    async with pipeline_status_lock:
-                        pipeline_status["latest_message"] = error_msg
-                        pipeline_status["history_messages"].append(
-                            traceback.format_exc()
-                        )
-                        pipeline_status["history_messages"].append(error_msg)
-
-                    # Persistent llm cache
-                    if self.llm_response_cache:
-                        await self.llm_response_cache.index_done_callback()
-
-                    # Update document status to failed
-                    await self.doc_status.upsert(
-                        {
-                            doc_id: {
-                                "status": DocStatus.FAILED,
-                                "error": str(e),
-                                "content": status_doc.content,
-                                "content_summary": status_doc.content_summary,
-                                "content_length": status_doc.content_length,
-                                "created_at": status_doc.created_at,
-                                "updated_at": datetime.now().isoformat(),
-                                "file_path": file_path,
-                            }
-                        }
-                    )
+            # Concurrency is controlled by graph db lock for individual entities and relationships
+            if file_extraction_stage_ok:
+                try:
+                    # Get chunk_results from entity_relation_task
+                    chunk_results = await entity_relation_task
+                    await merge_nodes_and_edges(
+                        chunk_results=chunk_results,  # result collected from entity_relation_task
+                        knowledge_graph_inst=self.chunk_entity_relation_graph,
+                        entity_vdb=self.entities_vdb,
+                        relationships_vdb=self.relationships_vdb,
+                        global_config=asdict(self),
+                        pipeline_status=pipeline_status,
+                        pipeline_status_lock=pipeline_status_lock,
+                        llm_response_cache=self.llm_response_cache,
+                        current_file_number=current_file_number,
+                        total_files=total_files,
+                        file_path=file_path,
+                    )
+
+                    await self.doc_status.upsert(
+                        {
+                            doc_id: {
+                                "status": DocStatus.PROCESSED,
+                                "chunks_count": len(chunks),
+                                "chunks_list": list(
+                                    chunks.keys()
+                                ),  # keep chunks_list
+                                "content": status_doc.content,
+                                "content_summary": status_doc.content_summary,
+                                "content_length": status_doc.content_length,
+                                "created_at": status_doc.created_at,
+                                "updated_at": datetime.now(
+                                    timezone.utc
+                                ).isoformat(),
+                                "file_path": file_path,
+                            }
+                        }
+                    )
+
+                    # Call _insert_done after processing each file
+                    await self._insert_done()
+
+                    async with pipeline_status_lock:
+                        log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
+                        logger.info(log_message)
+                        pipeline_status["latest_message"] = log_message
+                        pipeline_status["history_messages"].append(
+                            log_message
+                        )
+
+                except Exception as e:
+                    # Log error and update pipeline status
+                    logger.error(traceback.format_exc())
+                    error_msg = f"Merging stage failed in document {current_file_number}/{total_files}: {file_path}"
+                    logger.error(error_msg)
+                    async with pipeline_status_lock:
+                        pipeline_status["latest_message"] = error_msg
+                        pipeline_status["history_messages"].append(
+                            traceback.format_exc()
+                        )
+                        pipeline_status["history_messages"].append(
+                            error_msg
+                        )
+
+                    # Persistent llm cache
+                    if self.llm_response_cache:
+                        await self.llm_response_cache.index_done_callback()
+
+                    # Update document status to failed
+                    await self.doc_status.upsert(
+                        {
+                            doc_id: {
+                                "status": DocStatus.FAILED,
+                                "error": str(e),
+                                "content": status_doc.content,
+                                "content_summary": status_doc.content_summary,
+                                "content_length": status_doc.content_length,
+                                "created_at": status_doc.created_at,
+                                "updated_at": datetime.now().isoformat(),
+                                "file_path": file_path,
+                            }
+                        }
+                    )

         # Create processing tasks for all documents
         doc_tasks = []
@@ -37,6 +37,7 @@ from .base import (
 )
 from .prompt import PROMPTS
 from .constants import GRAPH_FIELD_SEP
+from .kg.shared_storage import get_graph_db_lock_keyed
 import time
 from dotenv import load_dotenv
@@ -1015,28 +1016,32 @@ async def _merge_edges_then_upsert(
     )

     for need_insert_id in [src_id, tgt_id]:
-        if not (await knowledge_graph_inst.has_node(need_insert_id)):
-            # # Discard this edge if the node does not exist
-            # if need_insert_id == src_id:
-            #     logger.warning(
-            #         f"Discard edge: {src_id} - {tgt_id} | Source node missing"
-            #     )
-            # else:
-            #     logger.warning(
-            #         f"Discard edge: {src_id} - {tgt_id} | Target node missing"
-            #     )
-            # return None
-            await knowledge_graph_inst.upsert_node(
-                need_insert_id,
-                node_data={
-                    "entity_id": need_insert_id,
-                    "source_id": source_id,
-                    "description": description,
-                    "entity_type": "UNKNOWN",
-                    "file_path": file_path,
-                    "created_at": int(time.time()),
-                },
-            )
+        if await knowledge_graph_inst.has_node(need_insert_id):
+            # This is so that the initial check for the existence of the node need not be locked
+            continue
+        async with get_graph_db_lock_keyed([need_insert_id], enable_logging=False):
+            if not (await knowledge_graph_inst.has_node(need_insert_id)):
+                # # Discard this edge if the node does not exist
+                # if need_insert_id == src_id:
+                #     logger.warning(
+                #         f"Discard edge: {src_id} - {tgt_id} | Source node missing"
+                #     )
+                # else:
+                #     logger.warning(
+                #         f"Discard edge: {src_id} - {tgt_id} | Target node missing"
+                #     )
+                # return None
+                await knowledge_graph_inst.upsert_node(
+                    need_insert_id,
+                    node_data={
+                        "entity_id": need_insert_id,
+                        "source_id": source_id,
+                        "description": description,
+                        "entity_type": "UNKNOWN",
+                        "file_path": file_path,
+                        "created_at": int(time.time()),
+                    },
+                )

     force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
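The new block is a double-checked pattern: an unlocked fast-path check skips the keyed lock entirely when the node already exists, and the check is repeated under the lock so two coroutines cannot both insert. A generic sketch of the idiom with hypothetical store and lock names:

```python
import asyncio

store: dict[str, dict] = {}              # stand-in for the graph storage
key_locks: dict[str, asyncio.Lock] = {}  # stand-in for the keyed lock manager

async def ensure_node(node_id: str, data: dict) -> None:
    if node_id in store:                 # first check: cheap, no lock held
        return
    lock = key_locks.setdefault(node_id, asyncio.Lock())
    async with lock:                     # serialize only writers of this key
        if node_id not in store:         # second check: a racer may have inserted
            store[node_id] = data        # exactly one coroutine performs the insert

asyncio.run(ensure_node("n1", {"entity_type": "UNKNOWN"}))
```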
@@ -1118,8 +1123,6 @@ async def merge_nodes_and_edges(
         pipeline_status_lock: Lock for pipeline status
         llm_response_cache: LLM response cache
     """
-    # Get lock manager from shared storage
-    from .kg.shared_storage import get_graph_db_lock
-
     # Collect all nodes and edges from all chunks
     all_nodes = defaultdict(list)
@@ -1136,94 +1139,101 @@ async def merge_nodes_and_edges(
             all_edges[sorted_edge_key].extend(edges)

     # Centralized processing of all nodes and edges
-    entities_data = []
-    relationships_data = []
+    total_entities_count = len(all_nodes)
+    total_relations_count = len(all_edges)

-    # Merge nodes and edges
-    # Use graph database lock to ensure atomic merges and updates
-    graph_db_lock = get_graph_db_lock(enable_logging=False)
-    async with graph_db_lock:
-        async with pipeline_status_lock:
-            log_message = (
-                f"Merging stage {current_file_number}/{total_files}: {file_path}"
-            )
-            logger.info(log_message)
-            pipeline_status["latest_message"] = log_message
-            pipeline_status["history_messages"].append(log_message)
-
-        # Process and update all entities at once
-        for entity_name, entities in all_nodes.items():
-            entity_data = await _merge_nodes_then_upsert(
-                entity_name,
-                entities,
-                knowledge_graph_inst,
-                global_config,
-                pipeline_status,
-                pipeline_status_lock,
-                llm_response_cache,
-            )
-            entities_data.append(entity_data)
-
-        # Process and update all relationships at once
-        for edge_key, edges in all_edges.items():
-            edge_data = await _merge_edges_then_upsert(
-                edge_key[0],
-                edge_key[1],
-                edges,
-                knowledge_graph_inst,
-                global_config,
-                pipeline_status,
-                pipeline_status_lock,
-                llm_response_cache,
-            )
-            if edge_data is not None:
-                relationships_data.append(edge_data)
-
-        # Update total counts
-        total_entities_count = len(entities_data)
-        total_relations_count = len(relationships_data)
-
-        log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-        # Update vector databases with all collected data
-        if entity_vdb is not None and entities_data:
-            data_for_vdb = {
-                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                    "entity_name": dp["entity_name"],
-                    "entity_type": dp["entity_type"],
-                    "content": f"{dp['entity_name']}\n{dp['description']}",
-                    "source_id": dp["source_id"],
-                    "file_path": dp.get("file_path", "unknown_source"),
-                }
-                for dp in entities_data
-            }
-            await entity_vdb.upsert(data_for_vdb)
-
-        log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-        if relationships_vdb is not None and relationships_data:
-            data_for_vdb = {
-                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
-                    "src_id": dp["src_id"],
-                    "tgt_id": dp["tgt_id"],
-                    "keywords": dp["keywords"],
-                    "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
-                    "source_id": dp["source_id"],
-                    "file_path": dp.get("file_path", "unknown_source"),
-                }
-                for dp in relationships_data
-            }
-            await relationships_vdb.upsert(data_for_vdb)
+    log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}"
+    logger.info(log_message)
+    async with pipeline_status_lock:
+        pipeline_status["latest_message"] = log_message
+        pipeline_status["history_messages"].append(log_message)
+
+    # Process and update all entities and relationships in parallel
+    log_message = f"Processing: {total_entities_count} entities and {total_relations_count} relations"
+    logger.info(log_message)
+    async with pipeline_status_lock:
+        pipeline_status["latest_message"] = log_message
+        pipeline_status["history_messages"].append(log_message)
+
+    # Get max async tasks limit from global_config for semaphore control
+    llm_model_max_async = global_config.get("llm_model_max_async", 4)
+    semaphore = asyncio.Semaphore(llm_model_max_async)
+
+    async def _locked_process_entity_name(entity_name, entities):
+        async with semaphore:
+            async with get_graph_db_lock_keyed([entity_name], enable_logging=False):
+                entity_data = await _merge_nodes_then_upsert(
+                    entity_name,
+                    entities,
+                    knowledge_graph_inst,
+                    global_config,
+                    pipeline_status,
+                    pipeline_status_lock,
+                    llm_response_cache,
+                )
+                if entity_vdb is not None:
+                    data_for_vdb = {
+                        compute_mdhash_id(entity_data["entity_name"], prefix="ent-"): {
+                            "entity_name": entity_data["entity_name"],
+                            "entity_type": entity_data["entity_type"],
+                            "content": f"{entity_data['entity_name']}\n{entity_data['description']}",
+                            "source_id": entity_data["source_id"],
+                            "file_path": entity_data.get("file_path", "unknown_source"),
+                        }
+                    }
+                    await entity_vdb.upsert(data_for_vdb)
+                return entity_data
+
+    async def _locked_process_edges(edge_key, edges):
+        async with semaphore:
+            async with get_graph_db_lock_keyed(
+                f"{edge_key[0]}-{edge_key[1]}", enable_logging=False
+            ):
+                edge_data = await _merge_edges_then_upsert(
+                    edge_key[0],
+                    edge_key[1],
+                    edges,
+                    knowledge_graph_inst,
+                    global_config,
+                    pipeline_status,
+                    pipeline_status_lock,
+                    llm_response_cache,
+                )
+                if edge_data is None:
+                    return None
+
+                if relationships_vdb is not None:
+                    data_for_vdb = {
+                        compute_mdhash_id(
+                            edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-"
+                        ): {
+                            "src_id": edge_data["src_id"],
+                            "tgt_id": edge_data["tgt_id"],
+                            "keywords": edge_data["keywords"],
+                            "content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}",
+                            "source_id": edge_data["source_id"],
+                            "file_path": edge_data.get("file_path", "unknown_source"),
+                        }
+                    }
+                    await relationships_vdb.upsert(data_for_vdb)
+                return edge_data
+
+    # Create a single task queue for both entities and edges
+    tasks = []
+
+    # Add entity processing tasks
+    for entity_name, entities in all_nodes.items():
+        tasks.append(
+            asyncio.create_task(_locked_process_entity_name(entity_name, entities))
+        )
+
+    # Add edge processing tasks
+    for edge_key, edges in all_edges.items():
+        tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges)))
+
+    # Execute all tasks in parallel with semaphore control
+    await asyncio.gather(*tasks)


 async def extract_entities(
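The new scheduling combines bounded concurrency (a semaphore sized by llm_model_max_async), per-key mutual exclusion (keyed locks), and a single gather over all entity and edge tasks. A stripped-down sketch of the same pattern with a hypothetical process function:

```python
import asyncio

async def run_all(items: dict[str, list], max_async: int = 4):
    semaphore = asyncio.Semaphore(max_async)   # caps in-flight work
    key_locks: dict[str, asyncio.Lock] = {}    # per-key exclusion

    async def process(key: str, payload: list):
        async with semaphore:                  # at most max_async run at once
            lock = key_locks.setdefault(key, asyncio.Lock())
            async with lock:                   # the same key never runs twice concurrently
                await asyncio.sleep(0)         # real merge/upsert work goes here
                return key, len(payload)

    tasks = [asyncio.create_task(process(k, v)) for k, v in items.items()]
    return await asyncio.gather(*tasks)

print(asyncio.run(run_all({"a": [1, 2], "b": [3]})))  # [('a', 2), ('b', 1)]
```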