diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index 6f36f2c4..d5780f2e 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -1,9 +1,12 @@
+from collections import defaultdict
 import os
 import sys
 import asyncio
+import multiprocessing as mp
 from multiprocessing.synchronize import Lock as ProcessLock
 from multiprocessing import Manager
-from typing import Any, Dict, Optional, Union, TypeVar, Generic
+import time
+from typing import Any, Callable, Dict, List, Optional, Union, TypeVar, Generic


 # Define a direct print function for critical logs that must be visible in all processes
@@ -27,8 +30,14 @@ LockType = Union[ProcessLock, asyncio.Lock]
 _is_multiprocess = None
 _workers = None
 _manager = None
+_lock_registry: Optional[Dict[str, mp.synchronize.Lock]] = None
+_lock_registry_count: Optional[Dict[str, int]] = None
+_lock_cleanup_data: Optional[Dict[str, float]] = None  # combined key -> release timestamp
+_registry_guard = None
 _initialized = None

+CLEANUP_KEYED_LOCKS_AFTER_SECONDS = 300
+
 # shared data for storage across processes
 _shared_dicts: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
@@ -40,10 +49,31 @@ _internal_lock: Optional[LockType] = None
 _pipeline_status_lock: Optional[LockType] = None
 _graph_db_lock: Optional[LockType] = None
 _data_init_lock: Optional[LockType] = None
+_graph_db_lock_keyed: Optional["KeyedUnifiedLock"] = None

 # async locks for coroutine synchronization in multiprocess mode
 _async_locks: Optional[Dict[str, asyncio.Lock]] = None

+DEBUG_LOCKS = False
+_debug_n_locks_acquired: int = 0
+
+
+def inc_debug_n_locks_acquired():
+    global _debug_n_locks_acquired
+    if DEBUG_LOCKS:
+        _debug_n_locks_acquired += 1
+        print(f"DEBUG: Keyed Lock acquired, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True)
+
+
+def dec_debug_n_locks_acquired():
+    global _debug_n_locks_acquired
+    if DEBUG_LOCKS:
+        if _debug_n_locks_acquired > 0:
+            _debug_n_locks_acquired -= 1
+            print(f"DEBUG: Keyed Lock released, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True)
+        else:
+            raise RuntimeError("Attempting to release lock when no locks are acquired")
+
+
+def get_debug_n_locks_acquired():
+    global _debug_n_locks_acquired
+    return _debug_n_locks_acquired
+

 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
@@ -210,6 +240,207 @@ class UnifiedLock(Generic[T]):
             )
             raise

+    def locked(self) -> bool:
+        if self._is_async:
+            return self._lock.locked()
+        else:
+            return self._lock.locked()
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 2. CROSS-PROCESS FACTORY (one manager.Lock shared by *all* processes)
+# ─────────────────────────────────────────────────────────────────────────────
+def _get_combined_key(factory_name: str, key: str) -> str:
+    """Return the combined key for the factory and key."""
+    return f"{factory_name}:{key}"
+
+
+def _get_or_create_shared_raw_mp_lock(factory_name: str, key: str) -> Optional[mp.synchronize.Lock]:
+    """Return the *singleton* manager.Lock() proxy for *key*, creating it if needed."""
+    if not _is_multiprocess:
+        return None
+
+    with _registry_guard:
+        combined_key = _get_combined_key(factory_name, key)
+        raw = _lock_registry.get(combined_key)
+        count = _lock_registry_count.get(combined_key)
+        if raw is None:
+            raw = _manager.Lock()
+            _lock_registry[combined_key] = raw
+            count = 0  # reference count for a newly created lock starts at zero
+        else:
+            if count is None:
+                raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}")
+        count += 1
+        _lock_registry_count[combined_key] = count
+        if count == 1 and combined_key in _lock_cleanup_data:
+            _lock_cleanup_data.pop(combined_key)
+        return raw
+
+
+def _release_shared_raw_mp_lock(factory_name: str, key: str):
+    """Release the *singleton* manager.Lock() proxy for *key*."""
+    if not _is_multiprocess:
+        return
+
+    with _registry_guard:
+        combined_key = _get_combined_key(factory_name, key)
+        raw = _lock_registry.get(combined_key)
+        count = _lock_registry_count.get(combined_key)
+        if raw is None and count is None:
+            return
+        elif raw is None or count is None:
+            raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}")
+
+        count -= 1
+        if count < 0:
+            raise RuntimeError(f"Attempting to remove lock for {key} but it is not in the registry")
+        else:
+            _lock_registry_count[combined_key] = count
+
+        if count == 0:
+            _lock_cleanup_data[combined_key] = time.time()
+
+        # Lazily evict locks that have been idle longer than the cleanup window
+        for combined_key, value in list(_lock_cleanup_data.items()):
+            if time.time() - value > CLEANUP_KEYED_LOCKS_AFTER_SECONDS:
+                _lock_registry.pop(combined_key)
+                _lock_registry_count.pop(combined_key)
+                _lock_cleanup_data.pop(combined_key)
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 3. PARAMETER-KEYED WRAPPER (unchanged except it *accepts a factory*)
+# ─────────────────────────────────────────────────────────────────────────────
+class KeyedUnifiedLock:
+    """
+    Parameter-keyed wrapper around `UnifiedLock`.
+
+    • Keeps only a table of per-key *asyncio* gates locally
+    • Fetches the shared process-wide mutex on *every* acquire
+    • Builds a fresh `UnifiedLock` each time, so `enable_logging`
+      (or future options) can vary per call.
+    """
+
+    # ---------------- construction ----------------
+    def __init__(self, factory_name: str, *, default_enable_logging: bool = True) -> None:
+        self._factory_name = factory_name
+        self._default_enable_logging = default_enable_logging
+        self._async_lock: Dict[str, asyncio.Lock] = {}  # key → asyncio.Lock
+        self._async_lock_count: Dict[str, int] = {}  # key → acquire count
+        self._async_lock_cleanup_data: Dict[str, float] = {}  # key → release timestamp
+        self._mp_locks: Dict[str, mp.synchronize.Lock] = {}  # key → mp.synchronize.Lock
+
+    # ---------------- public API ------------------
+    def __call__(self, keys: list[str], *, enable_logging: Optional[bool] = None):
+        """
+        Ergonomic helper so you can write:
+
+            async with keyed_locks(["alpha"]):
+                ...
+        """
+        if enable_logging is None:
+            enable_logging = self._default_enable_logging
+        return _KeyedLockContext(self, factory_name=self._factory_name, keys=keys, enable_logging=enable_logging)
+
+    def _get_or_create_async_lock(self, key: str) -> asyncio.Lock:
+        async_lock = self._async_lock.get(key)
+        count = self._async_lock_count.get(key, 0)
+        if async_lock is None:
+            async_lock = asyncio.Lock()
+            self._async_lock[key] = async_lock
+        elif count == 0 and key in self._async_lock_cleanup_data:
+            self._async_lock_cleanup_data.pop(key)
+        count += 1
+        self._async_lock_count[key] = count
+        return async_lock
+
+    def _release_async_lock(self, key: str):
+        count = self._async_lock_count.get(key, 0)
+        count -= 1
+        if count == 0:
+            self._async_lock_cleanup_data[key] = time.time()
+        self._async_lock_count[key] = count
+
+        # Lazily drop per-key asyncio gates that have been idle too long
+        for key, value in list(self._async_lock_cleanup_data.items()):
+            if time.time() - value > CLEANUP_KEYED_LOCKS_AFTER_SECONDS:
+                self._async_lock.pop(key)
+                self._async_lock_count.pop(key)
+                self._async_lock_cleanup_data.pop(key)
+
+    def _get_lock_for_key(self, key: str, enable_logging: bool = False) -> UnifiedLock:
+        # 1. get (or create) the per-process async gate for this key
+        #    (synchronous, so no lock needs to be held while doing it)
+        async_lock = self._get_or_create_async_lock(key)
+
+        # 2. fetch the shared raw lock
+        raw_lock = _get_or_create_shared_raw_mp_lock(self._factory_name, key)
+        is_multiprocess = raw_lock is not None
+        if not is_multiprocess:
+            raw_lock = async_lock
+
+        # 3. build a *fresh* UnifiedLock with the chosen logging flag
+        if is_multiprocess:
+            return UnifiedLock(
+                lock=raw_lock,
+                is_async=False,  # manager.Lock is synchronous
+                name=f"key:{self._factory_name}:{key}",
+                enable_logging=enable_logging,
+                async_lock=async_lock,  # prevents event-loop blocking
+            )
+        else:
+            return UnifiedLock(
+                lock=raw_lock,
+                is_async=True,
+                name=f"key:{self._factory_name}:{key}",
+                enable_logging=enable_logging,
+                async_lock=None,  # No need for async lock in single process mode
+            )
+
+    def _release_lock_for_key(self, key: str):
+        self._release_async_lock(key)
+        _release_shared_raw_mp_lock(self._factory_name, key)
+
+
+class _KeyedLockContext:
+    def __init__(
+        self,
+        parent: KeyedUnifiedLock,
+        factory_name: str,
+        keys: list[str],
+        enable_logging: bool,
+    ) -> None:
+        self._parent = parent
+        self._factory_name = factory_name
+
+        # Sorting is critical: acquiring keys in a consistent order
+        # avoids deadlocks between coroutines locking overlapping key sets
+        self._keys = sorted(keys)
+        self._enable_logging = (
+            enable_logging if enable_logging is not None
+            else parent._default_enable_logging
+        )
+        self._ul: Optional[List["UnifiedLock"]] = None  # set in __aenter__
+
+    # ----- enter -----
+    async def __aenter__(self):
+        if self._ul is not None:
+            raise RuntimeError("KeyedUnifiedLock already acquired in current context")
+
+        # 4. acquire the per-key locks, in sorted order
+        self._ul = []
+        for key in self._keys:
+            lock = self._parent._get_lock_for_key(key, enable_logging=self._enable_logging)
+            await lock.__aenter__()
+            inc_debug_n_locks_acquired()
+            self._ul.append(lock)
+        return self  # or return self._keys if you prefer
+
+    # ----- exit -----
+    async def __aexit__(self, exc_type, exc, tb):
+        # The UnifiedLock takes care of proper release order
+        for ul, key in zip(reversed(self._ul), reversed(self._keys)):
+            await ul.__aexit__(exc_type, exc, tb)
+            self._parent._release_lock_for_key(key)
+            dec_debug_n_locks_acquired()
+        self._ul = None
+

 def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
@@ -258,6 +489,14 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
         async_lock=async_lock,
     )

+
+def get_graph_db_lock_keyed(keys: str | list[str], enable_logging: bool = False) -> "_KeyedLockContext":
+    """return unified keyed graph database lock for ensuring atomic operations"""
+    global _graph_db_lock_keyed
+    if _graph_db_lock_keyed is None:
+        raise RuntimeError("Shared-Data is not initialized")
+    if isinstance(keys, str):
+        keys = [keys]
+    return _graph_db_lock_keyed(keys, enable_logging=enable_logging)
+

 def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified data initialization lock for ensuring atomic data initialization"""
@@ -294,6 +533,10 @@ def initialize_share_data(workers: int = 1):
         _workers, \
         _is_multiprocess, \
         _storage_lock, \
+        _lock_registry, \
+        _lock_registry_count, \
+        _lock_cleanup_data, \
+        _registry_guard, \
         _internal_lock, \
         _pipeline_status_lock, \
         _graph_db_lock, \
@@ -302,7 +545,8 @@
         _init_flags, \
        _initialized, \
         _update_flags, \
-        _async_locks
+        _async_locks, \
+        _graph_db_lock_keyed

     # Check if already initialized
     if _initialized:
@@ -316,6 +560,10 @@
     if workers > 1:
         _is_multiprocess = True
        _manager = Manager()
+        _lock_registry = _manager.dict()
+        _lock_registry_count = _manager.dict()
+        _lock_cleanup_data = _manager.dict()
+        _registry_guard = _manager.RLock()
         _internal_lock = _manager.Lock()
         _storage_lock = _manager.Lock()
         _pipeline_status_lock = _manager.Lock()
@@ -324,6 +572,10 @@
         _shared_dicts = _manager.dict()
         _init_flags = _manager.dict()
         _update_flags = _manager.dict()
+
+        _graph_db_lock_keyed = KeyedUnifiedLock(
+            factory_name="graph_db_lock",
+        )

         # Initialize async locks for multiprocess mode
         _async_locks = {
@@ -348,6 +600,10 @@
         _init_flags = {}
         _update_flags = {}
         _async_locks = None  # No need for async locks in single process mode
+
+        _graph_db_lock_keyed = KeyedUnifiedLock(
+            factory_name="graph_db_lock",
+        )

         direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")

     # Mark as initialized
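Usage sketch for the keyed lock added above (illustrative only, not part of the patch). It assumes `initialize_share_data()` has already been called in the current process, and the worker names and entity keys are placeholders. With `workers=1` each key maps to a plain `asyncio.Lock`; with `workers > 1` the same call path goes through the shared `Manager` registry instead.

    import asyncio

    from lightrag.kg.shared_storage import initialize_share_data, get_graph_db_lock_keyed

    async def main() -> None:
        async def worker(name: str, keys: list[str]) -> None:
            # Keys are sorted inside _KeyedLockContext, so two workers locking
            # overlapping key sets cannot deadlock on acquisition order.
            async with get_graph_db_lock_keyed(keys, enable_logging=False):
                print(f"{name} holds {sorted(keys)}")
                await asyncio.sleep(0.1)

        await asyncio.gather(
            worker("w1", ["Alice", "Bob"]),
            worker("w2", ["Bob", "Carol"]),
        )

    initialize_share_data(workers=1)  # single-process mode: per-key asyncio.Lock only
    asyncio.run(main())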
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 7a79da31..d4c28b17 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1024,73 +1024,73 @@ class LightRAG:
                     }
                 )

-                # Semphore was released here
+                    # Semaphore is NOT released here, but the profile context is

-                if file_extraction_stage_ok:
-                    try:
-                        # Get chunk_results from entity_relation_task
-                        chunk_results = await entity_relation_task
-                        await merge_nodes_and_edges(
-                            chunk_results=chunk_results,  # result collected from entity_relation_task
-                            knowledge_graph_inst=self.chunk_entity_relation_graph,
-                            entity_vdb=self.entities_vdb,
-                            relationships_vdb=self.relationships_vdb,
-                            global_config=asdict(self),
-                            pipeline_status=pipeline_status,
-                            pipeline_status_lock=pipeline_status_lock,
-                            llm_response_cache=self.llm_response_cache,
-                            current_file_number=current_file_number,
-                            total_files=total_files,
-                            file_path=file_path,
-                        )
+                    if file_extraction_stage_ok:
+                        try:
+                            # Get chunk_results from entity_relation_task
+                            chunk_results = await entity_relation_task
+                            await merge_nodes_and_edges(
+                                chunk_results=chunk_results,  # result collected from entity_relation_task
+                                knowledge_graph_inst=self.chunk_entity_relation_graph,
+                                entity_vdb=self.entities_vdb,
+                                relationships_vdb=self.relationships_vdb,
+                                global_config=asdict(self),
+                                pipeline_status=pipeline_status,
+                                pipeline_status_lock=pipeline_status_lock,
+                                llm_response_cache=self.llm_response_cache,
+                                current_file_number=current_file_number,
+                                total_files=total_files,
+                                file_path=file_path,
+                            )

-                        await self.doc_status.upsert(
-                            {
-                                doc_id: {
-                                    "status": DocStatus.PROCESSED,
-                                    "chunks_count": len(chunks),
-                                    "content": status_doc.content,
-                                    "content_summary": status_doc.content_summary,
-                                    "content_length": status_doc.content_length,
-                                    "created_at": status_doc.created_at,
-                                    "updated_at": datetime.now().isoformat(),
-                                    "file_path": file_path,
+                            await self.doc_status.upsert(
+                                {
+                                    doc_id: {
+                                        "status": DocStatus.PROCESSED,
+                                        "chunks_count": len(chunks),
+                                        "content": status_doc.content,
+                                        "content_summary": status_doc.content_summary,
+                                        "content_length": status_doc.content_length,
+                                        "created_at": status_doc.created_at,
+                                        "updated_at": datetime.now().isoformat(),
+                                        "file_path": file_path,
+                                    }
                                 }
-                            }
-                        )
+                            )

-                        # Call _insert_done after processing each file
-                        await self._insert_done()
+                            # Call _insert_done after processing each file
+                            await self._insert_done()

-                        async with pipeline_status_lock:
-                            log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
-                            logger.info(log_message)
-                            pipeline_status["latest_message"] = log_message
-                            pipeline_status["history_messages"].append(log_message)
+                            async with pipeline_status_lock:
+                                log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
+                                logger.info(log_message)
+                                pipeline_status["latest_message"] = log_message
+                                pipeline_status["history_messages"].append(log_message)

-                    except Exception as e:
-                        # Log error and update pipeline status
-                        error_msg = f"Merging stage failed in document {doc_id}: {traceback.format_exc()}"
-                        logger.error(error_msg)
-                        async with pipeline_status_lock:
-                            pipeline_status["latest_message"] = error_msg
-                            pipeline_status["history_messages"].append(error_msg)
+                        except Exception as e:
+                            # Log error and update pipeline status
+                            error_msg = f"Merging stage failed in document {doc_id}: {traceback.format_exc()}"
+                            logger.error(error_msg)
+                            async with pipeline_status_lock:
+                                pipeline_status["latest_message"] = error_msg
+                                pipeline_status["history_messages"].append(error_msg)

-                        # Update document status to failed
-                        await self.doc_status.upsert(
-                            {
-                                doc_id: {
-                                    "status": DocStatus.FAILED,
-                                    "error": str(e),
-                                    "content": status_doc.content,
-                                    "content_summary": status_doc.content_summary,
-                                    "content_length": status_doc.content_length,
-                                    "created_at": status_doc.created_at,
-                                    "updated_at": datetime.now().isoformat(),
-                                    "file_path": file_path,
+                            # Update document status to failed
+                            await self.doc_status.upsert(
+                                {
+                                    doc_id: {
+                                        "status": DocStatus.FAILED,
+                                        "error": str(e),
+                                        "content": status_doc.content,
+                                        "content_summary": status_doc.content_summary,
+                                        "content_length": status_doc.content_length,
+                                        "created_at": status_doc.created_at,
+                                        "updated_at": datetime.now().isoformat(),
+                                        "file_path": file_path,
+                                    }
                                 }
-                            }
-                        )
+                            )

         # Create processing tasks for all documents
         doc_tasks = []
diff --git a/lightrag/operate.py b/lightrag/operate.py
index d82965e2..a0a74c52 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -9,6 +9,8 @@ import os
 from typing import Any, AsyncIterator
 from collections import Counter, defaultdict

+from .kg.shared_storage import get_graph_db_lock_keyed
+
 from .utils import (
     logger,
     clean_str,
@@ -403,27 +405,31 @@
     )

     for need_insert_id in [src_id, tgt_id]:
-        if not (await knowledge_graph_inst.has_node(need_insert_id)):
-            # # Discard this edge if the node does not exist
-            # if need_insert_id == src_id:
-            #     logger.warning(
-            #         f"Discard edge: {src_id} - {tgt_id} | Source node missing"
-            #     )
-            # else:
-            #     logger.warning(
-            #         f"Discard edge: {src_id} - {tgt_id} | Target node missing"
-            #     )
-            # return None
-            await knowledge_graph_inst.upsert_node(
-                need_insert_id,
-                node_data={
-                    "entity_id": need_insert_id,
-                    "source_id": source_id,
-                    "description": description,
-                    "entity_type": "UNKNOWN",
-                    "file_path": file_path,
-                },
-            )
+        if (await knowledge_graph_inst.has_node(need_insert_id)):
+            # This is so that the initial check for the existence of the node need not be locked
+            continue
+        async with get_graph_db_lock_keyed([need_insert_id], enable_logging=False):
+            if not (await knowledge_graph_inst.has_node(need_insert_id)):
+                # # Discard this edge if the node does not exist
+                # if need_insert_id == src_id:
+                #     logger.warning(
+                #         f"Discard edge: {src_id} - {tgt_id} | Source node missing"
+                #     )
+                # else:
+                #     logger.warning(
+                #         f"Discard edge: {src_id} - {tgt_id} | Target node missing"
+                #     )
+                # return None
+                await knowledge_graph_inst.upsert_node(
+                    need_insert_id,
+                    node_data={
+                        "entity_id": need_insert_id,
+                        "source_id": source_id,
+                        "description": description,
+                        "entity_type": "UNKNOWN",
+                        "file_path": file_path,
+                    },
+                )

     force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
@@ -523,23 +529,30 @@ async def merge_nodes_and_edges(
             all_edges[sorted_edge_key].extend(edges)

     # Centralized processing of all nodes and edges
-    entities_data = []
-    relationships_data = []
+    total_entities_count = len(all_nodes)
+    total_relations_count = len(all_edges)

     # Merge nodes and edges
     # Use graph database lock to ensure atomic merges and updates
     graph_db_lock = get_graph_db_lock(enable_logging=False)

-    async with graph_db_lock:
+    async with pipeline_status_lock:
+        log_message = (
+            f"Merging stage {current_file_number}/{total_files}: {file_path}"
+        )
+        logger.info(log_message)
+        pipeline_status["latest_message"] = log_message
+        pipeline_status["history_messages"].append(log_message)
+
+    # Process and update all entities at once
+    log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
+    logger.info(log_message)
+    if pipeline_status is not None:
         async with pipeline_status_lock:
-            log_message = (
-                f"Merging stage {current_file_number}/{total_files}: {file_path}"
-            )
-            logger.info(log_message)
             pipeline_status["latest_message"] = log_message
             pipeline_status["history_messages"].append(log_message)

-        # Process and update all entities at once
-        for entity_name, entities in all_nodes.items():
+    async def _locked_process_entity_name(entity_name, entities):
+        async with get_graph_db_lock_keyed([entity_name], enable_logging=False):
             entity_data = await _merge_nodes_then_upsert(
                 entity_name,
                 entities,
@@ -549,10 +562,34 @@
                 pipeline_status_lock,
                 llm_response_cache,
             )
-            entities_data.append(entity_data)
+            if entity_vdb is not None:
+                data_for_vdb = {
+                    compute_mdhash_id(entity_data["entity_name"], prefix="ent-"): {
+                        "entity_name": entity_data["entity_name"],
+                        "entity_type": entity_data["entity_type"],
+                        "content": f"{entity_data['entity_name']}\n{entity_data['description']}",
+                        "source_id": entity_data["source_id"],
+                        "file_path": entity_data.get("file_path", "unknown_source"),
+                    }
+                }
+                await entity_vdb.upsert(data_for_vdb)
+            return entity_data

-        # Process and update all relationships at once
-        for edge_key, edges in all_edges.items():
+    tasks = []
+    for entity_name, entities in all_nodes.items():
+        tasks.append(asyncio.create_task(_locked_process_entity_name(entity_name, entities)))
+    await asyncio.gather(*tasks)
+
+    # Process and update all relationships at once
+    log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
+    logger.info(log_message)
+    if pipeline_status is not None:
+        async with pipeline_status_lock:
+            pipeline_status["latest_message"] = log_message
+            pipeline_status["history_messages"].append(log_message)
+
+    async def _locked_process_edges(edge_key, edges):
+        async with get_graph_db_lock_keyed(f"{edge_key[0]}-{edge_key[1]}", enable_logging=False):
             edge_data = await _merge_edges_then_upsert(
                 edge_key[0],
                 edge_key[1],
@@ -563,55 +600,27 @@
                 pipeline_status_lock,
                 llm_response_cache,
             )
-            if edge_data is not None:
-                relationships_data.append(edge_data)
+            if edge_data is None:
+                return None

-        # Update total counts
-        total_entities_count = len(entities_data)
-        total_relations_count = len(relationships_data)
-
-        log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-        # Update vector databases with all collected data
-        if entity_vdb is not None and entities_data:
-            data_for_vdb = {
-                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                    "entity_name": dp["entity_name"],
-                    "entity_type": dp["entity_type"],
-                    "content": f"{dp['entity_name']}\n{dp['description']}",
-                    "source_id": dp["source_id"],
-                    "file_path": dp.get("file_path", "unknown_source"),
+            if relationships_vdb is not None:
+                data_for_vdb = {
+                    compute_mdhash_id(edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-"): {
+                        "src_id": edge_data["src_id"],
+                        "tgt_id": edge_data["tgt_id"],
+                        "keywords": edge_data["keywords"],
+                        "content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}",
+                        "source_id": edge_data["source_id"],
+                        "file_path": edge_data.get("file_path", "unknown_source"),
+                    }
                 }
-                for dp in entities_data
-            }
-            await entity_vdb.upsert(data_for_vdb)
-
-        log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-        if relationships_vdb is not None and relationships_data:
-            data_for_vdb = {
-                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
-                    "src_id": dp["src_id"],
-                    "tgt_id": dp["tgt_id"],
-                    "keywords": dp["keywords"],
-                    "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
-                    "source_id": dp["source_id"],
-                    "file_path": dp.get("file_path", "unknown_source"),
-                }
-                for dp in relationships_data
-            }
-            await relationships_vdb.upsert(data_for_vdb)
+                await relationships_vdb.upsert(data_for_vdb)
+            return edge_data
+
+    tasks = []
+    for edge_key, edges in all_edges.items():
+        tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges)))
+    await asyncio.gather(*tasks)


 async def extract_entities(
     chunks: dict[str, TextChunkSchema],