diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 6f36f2c4..b83e058c 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -1,23 +1,46 @@ import os import sys import asyncio +import multiprocessing as mp from multiprocessing.synchronize import Lock as ProcessLock from multiprocessing import Manager -from typing import Any, Dict, Optional, Union, TypeVar, Generic +import time +import logging +from typing import Any, Dict, List, Optional, Union, TypeVar, Generic # Define a direct print function for critical logs that must be visible in all processes -def direct_log(message, level="INFO", enable_output: bool = True): +def direct_log(message, enable_output: bool = True, level: str = "DEBUG"): """ Log a message directly to stderr to ensure visibility in all processes, including the Gunicorn master process. Args: message: The message to log - level: Log level (default: "INFO") + level: Log level (default: "DEBUG") enable_output: Whether to actually output the log (default: True) """ - if enable_output: + # Get the current logger level from the lightrag logger + try: + from lightrag.utils import logger + + current_level = logger.getEffectiveLevel() + except ImportError: + # Fallback if lightrag.utils is not available + current_level = logging.INFO + + # Convert string level to numeric level for comparison + level_mapping = { + "DEBUG": logging.DEBUG, # 10 + "INFO": logging.INFO, # 20 + "WARNING": logging.WARNING, # 30 + "ERROR": logging.ERROR, # 40 + "CRITICAL": logging.CRITICAL, # 50 + } + message_level = level_mapping.get(level.upper(), logging.DEBUG) + + # print(f"Diret_log: {level.upper()} {message_level} ? {current_level}", file=sys.stderr, flush=True) + if enable_output or (message_level >= current_level): print(f"{level}: {message}", file=sys.stderr, flush=True) @@ -27,6 +50,23 @@ LockType = Union[ProcessLock, asyncio.Lock] _is_multiprocess = None _workers = None _manager = None + +# Global singleton data for multi-process keyed locks +_lock_registry: Optional[Dict[str, mp.synchronize.Lock]] = None +_lock_registry_count: Optional[Dict[str, int]] = None +_lock_cleanup_data: Optional[Dict[str, time.time]] = None +_registry_guard = None +# Timeout for keyed locks in seconds +CLEANUP_KEYED_LOCKS_AFTER_SECONDS = 300 +# Threshold for triggering cleanup - only clean when pending list exceeds this size +CLEANUP_THRESHOLD = 500 +# Minimum interval between cleanup operations in seconds +MIN_CLEANUP_INTERVAL_SECONDS = 30 +# Track the earliest cleanup time for efficient cleanup triggering (multiprocess locks only) +_earliest_mp_cleanup_time: Optional[float] = None +# Track the last cleanup time to enforce minimum interval (multiprocess locks only) +_last_mp_cleanup_time: Optional[float] = None + _initialized = None # shared data for storage across processes @@ -40,10 +80,37 @@ _internal_lock: Optional[LockType] = None _pipeline_status_lock: Optional[LockType] = None _graph_db_lock: Optional[LockType] = None _data_init_lock: Optional[LockType] = None +# Manager for all keyed locks +_graph_db_lock_keyed: Optional["KeyedUnifiedLock"] = None # async locks for coroutine synchronization in multiprocess mode _async_locks: Optional[Dict[str, asyncio.Lock]] = None +DEBUG_LOCKS = False +_debug_n_locks_acquired: int = 0 + + +def inc_debug_n_locks_acquired(): + global _debug_n_locks_acquired + if DEBUG_LOCKS: + _debug_n_locks_acquired += 1 + print(f"DEBUG: Keyed Lock acquired, total: {_debug_n_locks_acquired:>5}") + + +def dec_debug_n_locks_acquired(): + global _debug_n_locks_acquired + if DEBUG_LOCKS: + if _debug_n_locks_acquired > 0: + _debug_n_locks_acquired -= 1 + print(f"DEBUG: Keyed Lock released, total: {_debug_n_locks_acquired:>5}") + else: + raise RuntimeError("Attempting to release lock when no locks are acquired") + + +def get_debug_n_locks_acquired(): + global _debug_n_locks_acquired + return _debug_n_locks_acquired + class UnifiedLock(Generic[T]): """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" @@ -65,17 +132,8 @@ class UnifiedLock(Generic[T]): async def __aenter__(self) -> "UnifiedLock[T]": try: - # direct_log( - # f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})", - # enable_output=self._enable_logging, - # ) - # If in multiprocess mode and async lock exists, acquire it first if not self._is_async and self._async_lock is not None: - # direct_log( - # f"== Lock == Process {self._pid}: Acquiring async lock for '{self._name}'", - # enable_output=self._enable_logging, - # ) await self._async_lock.acquire() direct_log( f"== Lock == Process {self._pid}: Async lock for '{self._name}' acquired", @@ -210,6 +268,406 @@ class UnifiedLock(Generic[T]): ) raise + def locked(self) -> bool: + if self._is_async: + return self._lock.locked() + else: + return self._lock.locked() + + +def _get_combined_key(factory_name: str, key: str) -> str: + """Return the combined key for the factory and key.""" + return f"{factory_name}:{key}" + + +def _get_or_create_shared_raw_mp_lock( + factory_name: str, key: str +) -> Optional[mp.synchronize.Lock]: + """Return the *singleton* manager.Lock() proxy for keyed lock, creating if needed.""" + if not _is_multiprocess: + return None + + with _registry_guard: + combined_key = _get_combined_key(factory_name, key) + raw = _lock_registry.get(combined_key) + count = _lock_registry_count.get(combined_key) + if raw is None: + raw = _manager.Lock() + _lock_registry[combined_key] = raw + count = 0 + else: + if count is None: + raise RuntimeError( + f"Shared-Data lock registry for {factory_name} is corrupted for key {key}" + ) + if ( + count == 1 and combined_key in _lock_cleanup_data + ): # Reusing an key waiting for cleanup, remove it from cleanup list + _lock_cleanup_data.pop(combined_key) + count += 1 + _lock_registry_count[combined_key] = count + return raw + + +def _release_shared_raw_mp_lock(factory_name: str, key: str): + """Release the *singleton* manager.Lock() proxy for *key*.""" + if not _is_multiprocess: + return + + global _earliest_mp_cleanup_time, _last_mp_cleanup_time + + with _registry_guard: + combined_key = _get_combined_key(factory_name, key) + raw = _lock_registry.get(combined_key) + count = _lock_registry_count.get(combined_key) + if raw is None and count is None: + return + elif raw is None or count is None: + raise RuntimeError( + f"Shared-Data lock registry for {factory_name} is corrupted for key {key}" + ) + + count -= 1 + if count < 0: + raise RuntimeError( + f"Attempting to release lock for {key} more times than it was acquired" + ) + + _lock_registry_count[combined_key] = count + + current_time = time.time() + if count == 0: + _lock_cleanup_data[combined_key] = current_time + + # Update earliest multiprocess cleanup time (only when earlier) + if ( + _earliest_mp_cleanup_time is None + or current_time < _earliest_mp_cleanup_time + ): + _earliest_mp_cleanup_time = current_time + + # Efficient cleanup triggering with minimum interval control + total_cleanup_len = len(_lock_cleanup_data) + if total_cleanup_len >= CLEANUP_THRESHOLD: + # Time rollback detection + if ( + _last_mp_cleanup_time is not None + and current_time < _last_mp_cleanup_time + ): + direct_log( + "== mp Lock == Time rollback detected, resetting cleanup time", + level="WARNING", + enable_output=False, + ) + _last_mp_cleanup_time = None + + # Check cleanup conditions + has_expired_locks = ( + _earliest_mp_cleanup_time is not None + and current_time - _earliest_mp_cleanup_time + > CLEANUP_KEYED_LOCKS_AFTER_SECONDS + ) + + interval_satisfied = ( + _last_mp_cleanup_time is None + or current_time - _last_mp_cleanup_time > MIN_CLEANUP_INTERVAL_SECONDS + ) + + if has_expired_locks and interval_satisfied: + try: + cleaned_count = 0 + new_earliest_time = None + + # Perform cleanup while maintaining the new earliest time + for cleanup_key, cleanup_time in list(_lock_cleanup_data.items()): + if ( + current_time - cleanup_time + > CLEANUP_KEYED_LOCKS_AFTER_SECONDS + ): + # Clean expired locks + _lock_registry.pop(cleanup_key, None) + _lock_registry_count.pop(cleanup_key, None) + _lock_cleanup_data.pop(cleanup_key, None) + cleaned_count += 1 + else: + # Track the earliest time among remaining locks + if ( + new_earliest_time is None + or cleanup_time < new_earliest_time + ): + new_earliest_time = cleanup_time + + # Update state only after successful cleanup + _earliest_mp_cleanup_time = new_earliest_time + _last_mp_cleanup_time = current_time + + if cleaned_count > 0: + next_cleanup_in = max( + ( + new_earliest_time + + CLEANUP_KEYED_LOCKS_AFTER_SECONDS + - current_time + ) + if new_earliest_time + else float("inf"), + MIN_CLEANUP_INTERVAL_SECONDS, + ) + direct_log( + f"== mp Lock == Cleaned up {cleaned_count}/{total_cleanup_len} expired locks, " + f"next cleanup in {next_cleanup_in:.1f}s", + enable_output=False, + level="INFO", + ) + + except Exception as e: + direct_log( + f"== mp Lock == Cleanup failed: {e}", + level="ERROR", + enable_output=False, + ) + # Don't update _last_mp_cleanup_time to allow retry + + +class KeyedUnifiedLock: + """ + Manager for unified keyed locks, supporting both single and multi-process + + • Keeps only a table of async keyed locks locally + • Fetches the multi-process keyed lockon every acquire + • Builds a fresh `UnifiedLock` each time, so `enable_logging` + (or future options) can vary per call. + """ + + def __init__( + self, factory_name: str, *, default_enable_logging: bool = True + ) -> None: + self._factory_name = factory_name + self._default_enable_logging = default_enable_logging + self._async_lock: Dict[str, asyncio.Lock] = {} # local keyed locks + self._async_lock_count: Dict[ + str, int + ] = {} # local keyed locks referenced count + self._async_lock_cleanup_data: Dict[ + str, time.time + ] = {} # local keyed locks timeout + self._mp_locks: Dict[ + str, mp.synchronize.Lock + ] = {} # multi-process lock proxies + self._earliest_async_cleanup_time: Optional[float] = ( + None # track earliest async cleanup time + ) + self._last_async_cleanup_time: Optional[float] = ( + None # track last async cleanup time for minimum interval + ) + + def __call__(self, keys: list[str], *, enable_logging: Optional[bool] = None): + """ + Ergonomic helper so you can write: + + async with keyed_locks("alpha"): + ... + """ + if enable_logging is None: + enable_logging = self._default_enable_logging + return _KeyedLockContext( + self, + factory_name=self._factory_name, + keys=keys, + enable_logging=enable_logging, + ) + + def _get_or_create_async_lock(self, key: str) -> asyncio.Lock: + async_lock = self._async_lock.get(key) + count = self._async_lock_count.get(key, 0) + if async_lock is None: + async_lock = asyncio.Lock() + self._async_lock[key] = async_lock + elif count == 0 and key in self._async_lock_cleanup_data: + self._async_lock_cleanup_data.pop(key) + count += 1 + self._async_lock_count[key] = count + return async_lock + + def _release_async_lock(self, key: str): + count = self._async_lock_count.get(key, 0) + count -= 1 + + current_time = time.time() + if count == 0: + self._async_lock_cleanup_data[key] = current_time + + # Update earliest async cleanup time (only when earlier) + if ( + self._earliest_async_cleanup_time is None + or current_time < self._earliest_async_cleanup_time + ): + self._earliest_async_cleanup_time = current_time + self._async_lock_count[key] = count + + # Efficient cleanup triggering with minimum interval control + total_cleanup_len = len(self._async_lock_cleanup_data) + if total_cleanup_len >= CLEANUP_THRESHOLD: + # Time rollback detection + if ( + self._last_async_cleanup_time is not None + and current_time < self._last_async_cleanup_time + ): + direct_log( + "== async Lock == Time rollback detected, resetting cleanup time", + level="WARNING", + enable_output=False, + ) + self._last_async_cleanup_time = None + + # Check cleanup conditions + has_expired_locks = ( + self._earliest_async_cleanup_time is not None + and current_time - self._earliest_async_cleanup_time + > CLEANUP_KEYED_LOCKS_AFTER_SECONDS + ) + + interval_satisfied = ( + self._last_async_cleanup_time is None + or current_time - self._last_async_cleanup_time + > MIN_CLEANUP_INTERVAL_SECONDS + ) + + if has_expired_locks and interval_satisfied: + try: + cleaned_count = 0 + new_earliest_time = None + + # Perform cleanup while maintaining the new earliest time + for cleanup_key, cleanup_time in list( + self._async_lock_cleanup_data.items() + ): + if ( + current_time - cleanup_time + > CLEANUP_KEYED_LOCKS_AFTER_SECONDS + ): + # Clean expired async locks + self._async_lock.pop(cleanup_key) + self._async_lock_count.pop(cleanup_key) + self._async_lock_cleanup_data.pop(cleanup_key) + cleaned_count += 1 + else: + # Track the earliest time among remaining locks + if ( + new_earliest_time is None + or cleanup_time < new_earliest_time + ): + new_earliest_time = cleanup_time + + # Update state only after successful cleanup + self._earliest_async_cleanup_time = new_earliest_time + self._last_async_cleanup_time = current_time + + if cleaned_count > 0: + next_cleanup_in = max( + ( + new_earliest_time + + CLEANUP_KEYED_LOCKS_AFTER_SECONDS + - current_time + ) + if new_earliest_time + else float("inf"), + MIN_CLEANUP_INTERVAL_SECONDS, + ) + direct_log( + f"== async Lock == Cleaned up {cleaned_count}/{total_cleanup_len} expired async locks, " + f"next cleanup in {next_cleanup_in:.1f}s", + enable_output=False, + level="INFO", + ) + + except Exception as e: + direct_log( + f"== async Lock == Cleanup failed: {e}", + level="ERROR", + enable_output=False, + ) + # Don't update _last_async_cleanup_time to allow retry + + def _get_lock_for_key(self, key: str, enable_logging: bool = False) -> UnifiedLock: + # 1. get (or create) the per‑process async gate for this key + # Is synchronous, so no need to acquire a lock + async_lock = self._get_or_create_async_lock(key) + + # 2. fetch the shared raw lock + raw_lock = _get_or_create_shared_raw_mp_lock(self._factory_name, key) + is_multiprocess = raw_lock is not None + if not is_multiprocess: + raw_lock = async_lock + + # 3. build a *fresh* UnifiedLock with the chosen logging flag + if is_multiprocess: + return UnifiedLock( + lock=raw_lock, + is_async=False, # manager.Lock is synchronous + name=_get_combined_key(self._factory_name, key), + enable_logging=enable_logging, + async_lock=async_lock, # prevents event‑loop blocking + ) + else: + return UnifiedLock( + lock=raw_lock, + is_async=True, + name=_get_combined_key(self._factory_name, key), + enable_logging=enable_logging, + async_lock=None, # No need for async lock in single process mode + ) + + def _release_lock_for_key(self, key: str): + self._release_async_lock(key) + _release_shared_raw_mp_lock(self._factory_name, key) + + +class _KeyedLockContext: + def __init__( + self, + parent: KeyedUnifiedLock, + factory_name: str, + keys: list[str], + enable_logging: bool, + ) -> None: + self._parent = parent + self._factory_name = factory_name + + # The sorting is critical to ensure proper lock and release order + # to avoid deadlocks + self._keys = sorted(keys) + self._enable_logging = ( + enable_logging + if enable_logging is not None + else parent._default_enable_logging + ) + self._ul: Optional[List["UnifiedLock"]] = None # set in __aenter__ + + # ----- enter ----- + async def __aenter__(self): + if self._ul is not None: + raise RuntimeError("KeyedUnifiedLock already acquired in current context") + + # 4. acquire it + self._ul = [] + for key in self._keys: + lock = self._parent._get_lock_for_key( + key, enable_logging=self._enable_logging + ) + await lock.__aenter__() + inc_debug_n_locks_acquired() + self._ul.append(lock) + return self # or return self._key if you prefer + + # ----- exit ----- + async def __aexit__(self, exc_type, exc, tb): + # The UnifiedLock takes care of proper release order + for ul, key in zip(reversed(self._ul), reversed(self._keys)): + await ul.__aexit__(exc_type, exc, tb) + self._parent._release_lock_for_key(key) + dec_debug_n_locks_acquired() + self._ul = None + def get_internal_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" @@ -259,6 +717,18 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock: ) +def get_graph_db_lock_keyed( + keys: str | list[str], enable_logging: bool = False +) -> KeyedUnifiedLock: + """return unified graph database lock for ensuring atomic operations""" + global _graph_db_lock_keyed + if _graph_db_lock_keyed is None: + raise RuntimeError("Shared-Data is not initialized") + if isinstance(keys, str): + keys = [keys] + return _graph_db_lock_keyed(keys, enable_logging=enable_logging) + + def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock: """return unified data initialization lock for ensuring atomic data initialization""" async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None @@ -294,6 +764,10 @@ def initialize_share_data(workers: int = 1): _workers, \ _is_multiprocess, \ _storage_lock, \ + _lock_registry, \ + _lock_registry_count, \ + _lock_cleanup_data, \ + _registry_guard, \ _internal_lock, \ _pipeline_status_lock, \ _graph_db_lock, \ @@ -302,7 +776,10 @@ def initialize_share_data(workers: int = 1): _init_flags, \ _initialized, \ _update_flags, \ - _async_locks + _async_locks, \ + _graph_db_lock_keyed, \ + _earliest_mp_cleanup_time, \ + _last_mp_cleanup_time # Check if already initialized if _initialized: @@ -316,6 +793,10 @@ def initialize_share_data(workers: int = 1): if workers > 1: _is_multiprocess = True _manager = Manager() + _lock_registry = _manager.dict() + _lock_registry_count = _manager.dict() + _lock_cleanup_data = _manager.dict() + _registry_guard = _manager.RLock() _internal_lock = _manager.Lock() _storage_lock = _manager.Lock() _pipeline_status_lock = _manager.Lock() @@ -325,6 +806,10 @@ def initialize_share_data(workers: int = 1): _init_flags = _manager.dict() _update_flags = _manager.dict() + _graph_db_lock_keyed = KeyedUnifiedLock( + factory_name="GraphDB", + ) + # Initialize async locks for multiprocess mode _async_locks = { "internal_lock": asyncio.Lock(), @@ -348,8 +833,16 @@ def initialize_share_data(workers: int = 1): _init_flags = {} _update_flags = {} _async_locks = None # No need for async locks in single process mode + + _graph_db_lock_keyed = KeyedUnifiedLock( + factory_name="GraphDB", + ) direct_log(f"Process {os.getpid()} Shared-Data created for Single Process") + # Initialize multiprocess cleanup times + _earliest_mp_cleanup_time = None + _last_mp_cleanup_time = None + # Mark as initialized _initialized = True diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index bc3c289a..feb7ab16 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1094,86 +1094,89 @@ class LightRAG: } ) - # Semphore released, concurrency controlled by graph_db_lock in merge_nodes_and_edges instead - - if file_extraction_stage_ok: - try: - # Get chunk_results from entity_relation_task - chunk_results = await entity_relation_task - await merge_nodes_and_edges( - chunk_results=chunk_results, # result collected from entity_relation_task - knowledge_graph_inst=self.chunk_entity_relation_graph, - entity_vdb=self.entities_vdb, - relationships_vdb=self.relationships_vdb, - global_config=asdict(self), - pipeline_status=pipeline_status, - pipeline_status_lock=pipeline_status_lock, - llm_response_cache=self.llm_response_cache, - current_file_number=current_file_number, - total_files=total_files, - file_path=file_path, - ) - - await self.doc_status.upsert( - { - doc_id: { - "status": DocStatus.PROCESSED, - "chunks_count": len(chunks), - "chunks_list": list( - chunks.keys() - ), # 保留 chunks_list - "content": status_doc.content, - "content_summary": status_doc.content_summary, - "content_length": status_doc.content_length, - "created_at": status_doc.created_at, - "updated_at": datetime.now( - timezone.utc - ).isoformat(), - "file_path": file_path, - } - } - ) - - # Call _insert_done after processing each file - await self._insert_done() - - async with pipeline_status_lock: - log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - - except Exception as e: - # Log error and update pipeline status - logger.error(traceback.format_exc()) - error_msg = f"Merging stage failed in document {current_file_number}/{total_files}: {file_path}" - logger.error(error_msg) - async with pipeline_status_lock: - pipeline_status["latest_message"] = error_msg - pipeline_status["history_messages"].append( - traceback.format_exc() + # Concurrency is controlled by graph db lock for individual entities and relationships + if file_extraction_stage_ok: + try: + # Get chunk_results from entity_relation_task + chunk_results = await entity_relation_task + await merge_nodes_and_edges( + chunk_results=chunk_results, # result collected from entity_relation_task + knowledge_graph_inst=self.chunk_entity_relation_graph, + entity_vdb=self.entities_vdb, + relationships_vdb=self.relationships_vdb, + global_config=asdict(self), + pipeline_status=pipeline_status, + pipeline_status_lock=pipeline_status_lock, + llm_response_cache=self.llm_response_cache, + current_file_number=current_file_number, + total_files=total_files, + file_path=file_path, ) - pipeline_status["history_messages"].append(error_msg) - # Persistent llm cache - if self.llm_response_cache: - await self.llm_response_cache.index_done_callback() - - # Update document status to failed - await self.doc_status.upsert( - { - doc_id: { - "status": DocStatus.FAILED, - "error": str(e), - "content": status_doc.content, - "content_summary": status_doc.content_summary, - "content_length": status_doc.content_length, - "created_at": status_doc.created_at, - "updated_at": datetime.now().isoformat(), - "file_path": file_path, + await self.doc_status.upsert( + { + doc_id: { + "status": DocStatus.PROCESSED, + "chunks_count": len(chunks), + "chunks_list": list( + chunks.keys() + ), # 保留 chunks_list + "content": status_doc.content, + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, + "updated_at": datetime.now( + timezone.utc + ).isoformat(), + "file_path": file_path, + } } - } - ) + ) + + # Call _insert_done after processing each file + await self._insert_done() + + async with pipeline_status_lock: + log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append( + log_message + ) + + except Exception as e: + # Log error and update pipeline status + logger.error(traceback.format_exc()) + error_msg = f"Merging stage failed in document {current_file_number}/{total_files}: {file_path}" + logger.error(error_msg) + async with pipeline_status_lock: + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append( + traceback.format_exc() + ) + pipeline_status["history_messages"].append( + error_msg + ) + + # Persistent llm cache + if self.llm_response_cache: + await self.llm_response_cache.index_done_callback() + + # Update document status to failed + await self.doc_status.upsert( + { + doc_id: { + "status": DocStatus.FAILED, + "error": str(e), + "content": status_doc.content, + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, + "updated_at": datetime.now().isoformat(), + "file_path": file_path, + } + } + ) # Create processing tasks for all documents doc_tasks = [] diff --git a/lightrag/operate.py b/lightrag/operate.py index be4499ab..2008cb51 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -37,6 +37,7 @@ from .base import ( ) from .prompt import PROMPTS from .constants import GRAPH_FIELD_SEP +from .kg.shared_storage import get_graph_db_lock_keyed import time from dotenv import load_dotenv @@ -1015,28 +1016,32 @@ async def _merge_edges_then_upsert( ) for need_insert_id in [src_id, tgt_id]: - if not (await knowledge_graph_inst.has_node(need_insert_id)): - # # Discard this edge if the node does not exist - # if need_insert_id == src_id: - # logger.warning( - # f"Discard edge: {src_id} - {tgt_id} | Source node missing" - # ) - # else: - # logger.warning( - # f"Discard edge: {src_id} - {tgt_id} | Target node missing" - # ) - # return None - await knowledge_graph_inst.upsert_node( - need_insert_id, - node_data={ - "entity_id": need_insert_id, - "source_id": source_id, - "description": description, - "entity_type": "UNKNOWN", - "file_path": file_path, - "created_at": int(time.time()), - }, - ) + if await knowledge_graph_inst.has_node(need_insert_id): + # This is so that the initial check for the existence of the node need not be locked + continue + async with get_graph_db_lock_keyed([need_insert_id], enable_logging=False): + if not (await knowledge_graph_inst.has_node(need_insert_id)): + # # Discard this edge if the node does not exist + # if need_insert_id == src_id: + # logger.warning( + # f"Discard edge: {src_id} - {tgt_id} | Source node missing" + # ) + # else: + # logger.warning( + # f"Discard edge: {src_id} - {tgt_id} | Target node missing" + # ) + # return None + await knowledge_graph_inst.upsert_node( + need_insert_id, + node_data={ + "entity_id": need_insert_id, + "source_id": source_id, + "description": description, + "entity_type": "UNKNOWN", + "file_path": file_path, + "created_at": int(time.time()), + }, + ) force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] @@ -1118,8 +1123,6 @@ async def merge_nodes_and_edges( pipeline_status_lock: Lock for pipeline status llm_response_cache: LLM response cache """ - # Get lock manager from shared storage - from .kg.shared_storage import get_graph_db_lock # Collect all nodes and edges from all chunks all_nodes = defaultdict(list) @@ -1136,94 +1139,101 @@ async def merge_nodes_and_edges( all_edges[sorted_edge_key].extend(edges) # Centralized processing of all nodes and edges - entities_data = [] - relationships_data = [] + total_entities_count = len(all_nodes) + total_relations_count = len(all_edges) # Merge nodes and edges - # Use graph database lock to ensure atomic merges and updates - graph_db_lock = get_graph_db_lock(enable_logging=False) - async with graph_db_lock: - async with pipeline_status_lock: - log_message = ( - f"Merging stage {current_file_number}/{total_files}: {file_path}" - ) - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}" + logger.info(log_message) + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) - # Process and update all entities at once - for entity_name, entities in all_nodes.items(): - entity_data = await _merge_nodes_then_upsert( - entity_name, - entities, - knowledge_graph_inst, - global_config, - pipeline_status, - pipeline_status_lock, - llm_response_cache, - ) - entities_data.append(entity_data) + # Process and update all entities and relationships in parallel + log_message = f"Processing: {total_entities_count} entities and {total_relations_count} relations" + logger.info(log_message) + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) - # Process and update all relationships at once - for edge_key, edges in all_edges.items(): - edge_data = await _merge_edges_then_upsert( - edge_key[0], - edge_key[1], - edges, - knowledge_graph_inst, - global_config, - pipeline_status, - pipeline_status_lock, - llm_response_cache, - ) - if edge_data is not None: - relationships_data.append(edge_data) + # Get max async tasks limit from global_config for semaphore control + llm_model_max_async = global_config.get("llm_model_max_async", 4) + semaphore = asyncio.Semaphore(llm_model_max_async) - # Update total counts - total_entities_count = len(entities_data) - total_relations_count = len(relationships_data) + async def _locked_process_entity_name(entity_name, entities): + async with semaphore: + async with get_graph_db_lock_keyed([entity_name], enable_logging=False): + entity_data = await _merge_nodes_then_upsert( + entity_name, + entities, + knowledge_graph_inst, + global_config, + pipeline_status, + pipeline_status_lock, + llm_response_cache, + ) + if entity_vdb is not None: + data_for_vdb = { + compute_mdhash_id(entity_data["entity_name"], prefix="ent-"): { + "entity_name": entity_data["entity_name"], + "entity_type": entity_data["entity_type"], + "content": f"{entity_data['entity_name']}\n{entity_data['description']}", + "source_id": entity_data["source_id"], + "file_path": entity_data.get("file_path", "unknown_source"), + } + } + await entity_vdb.upsert(data_for_vdb) + return entity_data - log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}" - logger.info(log_message) - if pipeline_status is not None: - async with pipeline_status_lock: - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + async def _locked_process_edges(edge_key, edges): + async with semaphore: + async with get_graph_db_lock_keyed( + f"{edge_key[0]}-{edge_key[1]}", enable_logging=False + ): + edge_data = await _merge_edges_then_upsert( + edge_key[0], + edge_key[1], + edges, + knowledge_graph_inst, + global_config, + pipeline_status, + pipeline_status_lock, + llm_response_cache, + ) + if edge_data is None: + return None - # Update vector databases with all collected data - if entity_vdb is not None and entities_data: - data_for_vdb = { - compute_mdhash_id(dp["entity_name"], prefix="ent-"): { - "entity_name": dp["entity_name"], - "entity_type": dp["entity_type"], - "content": f"{dp['entity_name']}\n{dp['description']}", - "source_id": dp["source_id"], - "file_path": dp.get("file_path", "unknown_source"), - } - for dp in entities_data - } - await entity_vdb.upsert(data_for_vdb) + if relationships_vdb is not None: + data_for_vdb = { + compute_mdhash_id( + edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-" + ): { + "src_id": edge_data["src_id"], + "tgt_id": edge_data["tgt_id"], + "keywords": edge_data["keywords"], + "content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}", + "source_id": edge_data["source_id"], + "file_path": edge_data.get("file_path", "unknown_source"), + } + } + await relationships_vdb.upsert(data_for_vdb) + return edge_data - log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}" - logger.info(log_message) - if pipeline_status is not None: - async with pipeline_status_lock: - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + # Create a single task queue for both entities and edges + tasks = [] - if relationships_vdb is not None and relationships_data: - data_for_vdb = { - compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { - "src_id": dp["src_id"], - "tgt_id": dp["tgt_id"], - "keywords": dp["keywords"], - "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}", - "source_id": dp["source_id"], - "file_path": dp.get("file_path", "unknown_source"), - } - for dp in relationships_data - } - await relationships_vdb.upsert(data_for_vdb) + # Add entity processing tasks + for entity_name, entities in all_nodes.items(): + tasks.append( + asyncio.create_task(_locked_process_entity_name(entity_name, entities)) + ) + + # Add edge processing tasks + for edge_key, edges in all_edges.items(): + tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges))) + + # Execute all tasks in parallel with semaphore control + await asyncio.gather(*tasks) async def extract_entities(