From f2c522ce7a04c7043d9a8394d96cef025c8d28aa Mon Sep 17 00:00:00 2001
From: Arjun Rao
Date: Thu, 8 May 2025 11:00:56 +1000
Subject: [PATCH 1/3] Allow max_connections to be configured in postgres

---
 config.ini.example           | 1 +
 lightrag/kg/postgres_impl.py | 6 +++++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/config.ini.example b/config.ini.example
index 5ff7cfbb..c6edcb60 100644
--- a/config.ini.example
+++ b/config.ini.example
@@ -20,3 +20,4 @@ user = your_username
 password = your_password
 database = your_database
 workspace = default # Optional, defaults to "default"
+max_connections = 12
\ No newline at end of file
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index cb302e8c..c71bc867 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -54,7 +54,7 @@ class PostgreSQLDB:
         self.password = config.get("password", None)
         self.database = config.get("database", "postgres")
         self.workspace = config.get("workspace", "default")
-        self.max = 12
+        self.max = int(config.get("max_connections", 12))  # env/config values arrive as strings
         self.increment = 1
         self.pool: Pool | None = None

@@ -250,6 +250,10 @@ class ClientManager:
                 "POSTGRES_WORKSPACE",
                 config.get("postgres", "workspace", fallback="default"),
             ),
+            "max_connections": os.environ.get(
+                "POSTGRES_MAX_CONNECTIONS",
+                config.get("postgres", "max_connections", fallback=12),
+            ),
         }

     @classmethod
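
The new knob can be supplied either via config.ini or through the environment, with the environment variable taking precedence; a minimal sketch of both (key and variable names as introduced by the patch, values illustrative):

    # config.ini
    [postgres]
    max_connections = 20

    # or, overriding config.ini at runtime:
    #   export POSTGRES_MAX_CONNECTIONS=20

Either way the value arrives as a string, which is why PostgreSQLDB coerces it with int() before sizing the pool.
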
From b7eae4d7c0573525f46742d027d38162b0beb277 Mon Sep 17 00:00:00 2001
From: Arjun Rao
Date: Thu, 8 May 2025 11:42:53 +1000
Subject: [PATCH 2/3] Use the context manager for the openai client

This avoids resource-cleanup issues (too many open files) under massively
parallel calls to the OpenAI API, since relying on garbage collection for
RAII-style cleanup is highly unreliable in Python in such contexts.
---
 lightrag/llm/openai.py | 26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 2f01c969..cd44bb93 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -177,14 +177,15 @@ async def openai_complete_if_cache(
     logger.debug("===== Sending Query to LLM =====")

     try:
-        if "response_format" in kwargs:
-            response = await openai_async_client.beta.chat.completions.parse(
-                model=model, messages=messages, **kwargs
-            )
-        else:
-            response = await openai_async_client.chat.completions.create(
-                model=model, messages=messages, **kwargs
-            )
+        async with openai_async_client:
+            if "response_format" in kwargs:
+                response = await openai_async_client.beta.chat.completions.parse(
+                    model=model, messages=messages, **kwargs
+                )
+            else:
+                response = await openai_async_client.chat.completions.create(
+                    model=model, messages=messages, **kwargs
+                )
     except APIConnectionError as e:
         logger.error(f"OpenAI API Connection Error: {e}")
         raise
@@ -421,7 +422,8 @@ async def openai_embed(
         api_key=api_key, base_url=base_url, client_configs=client_configs
     )

-    response = await openai_async_client.embeddings.create(
-        model=model, input=texts, encoding_format="float"
-    )
-    return np.array([dp.embedding for dp in response.data])
+    async with openai_async_client:
+        response = await openai_async_client.embeddings.create(
+            model=model, input=texts, encoding_format="float"
+        )
+        return np.array([dp.embedding for dp in response.data])
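
The fix relies on AsyncOpenAI implementing the async context-manager protocol: __aexit__ closes the underlying HTTP connection pool deterministically instead of leaving cleanup to garbage collection. A self-contained sketch of the pattern (the model name is illustrative, not from the patch):

    import asyncio
    from openai import AsyncOpenAI

    async def embed_once(text: str) -> list[float]:
        async with AsyncOpenAI() as client:
            # The pooled connections are closed on exit even if the call
            # raises, so descriptors cannot accumulate across parallel tasks.
            resp = await client.embeddings.create(
                model="text-embedding-3-small", input=[text]
            )
            return resp.data[0].embedding

    # asyncio.run(embed_once("hello"))
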
From f8149790e48000f4c781f3ee615171708437f692 Mon Sep 17 00:00:00 2001
From: Arjun Rao
Date: Thu, 8 May 2025 11:35:10 +1000
Subject: [PATCH 3/3] Initial commit with keyed graph lock

---
 lightrag/kg/shared_storage.py | 260 +++++++++++++++++++++++++++++++++-
 lightrag/lightrag.py          | 118 +++++++--------
 lightrag/operate.py           | 167 +++++++++++-----------
 3 files changed, 405 insertions(+), 140 deletions(-)

diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index 6f36f2c4..d5780f2e 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -1,9 +1,12 @@
+from collections import defaultdict
 import os
 import sys
 import asyncio
+import multiprocessing as mp
 from multiprocessing.synchronize import Lock as ProcessLock
 from multiprocessing import Manager
-from typing import Any, Dict, Optional, Union, TypeVar, Generic
+import time
+from typing import Any, Callable, Dict, List, Optional, Union, TypeVar, Generic


 # Define a direct print function for critical logs that must be visible in all processes
@@ -27,8 +30,14 @@ LockType = Union[ProcessLock, asyncio.Lock]
 _is_multiprocess = None
 _workers = None
 _manager = None
+_lock_registry: Optional[Dict[str, mp.synchronize.Lock]] = None
+_lock_registry_count: Optional[Dict[str, int]] = None
+_lock_cleanup_data: Optional[Dict[str, float]] = None  # combined key -> release timestamp
+_registry_guard = None
 _initialized = None

+CLEANUP_KEYED_LOCKS_AFTER_SECONDS = 300
+
 # shared data for storage across processes
 _shared_dicts: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
@@ -40,10 +49,31 @@
 _internal_lock: Optional[LockType] = None
 _pipeline_status_lock: Optional[LockType] = None
 _graph_db_lock: Optional[LockType] = None
 _data_init_lock: Optional[LockType] = None
+_graph_db_lock_keyed: Optional["KeyedUnifiedLock"] = None

 # async locks for coroutine synchronization in multiprocess mode
 _async_locks: Optional[Dict[str, asyncio.Lock]] = None

+DEBUG_LOCKS = False
+_debug_n_locks_acquired: int = 0
+
+def inc_debug_n_locks_acquired():
+    global _debug_n_locks_acquired
+    if DEBUG_LOCKS:
+        _debug_n_locks_acquired += 1
+        print(f"DEBUG: Keyed Lock acquired, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True)
+
+def dec_debug_n_locks_acquired():
+    global _debug_n_locks_acquired
+    if DEBUG_LOCKS:
+        if _debug_n_locks_acquired > 0:
+            _debug_n_locks_acquired -= 1
+            print(f"DEBUG: Keyed Lock released, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True)
+        else:
+            raise RuntimeError("Attempting to release lock when no locks are acquired")
+
+def get_debug_n_locks_acquired():
+    global _debug_n_locks_acquired
+    return _debug_n_locks_acquired

 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
@@ -210,6 +240,207 @@ class UnifiedLock(Generic[T]):
             )
             raise

+    def locked(self) -> bool:
+        # asyncio.Lock and the underlying sync lock expose the same method,
+        # so a single call covers both cases
+        return self._lock.locked()
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 2. CROSS‑PROCESS FACTORY (one manager.Lock shared by *all* processes)
+# ─────────────────────────────────────────────────────────────────────────────
+def _get_combined_key(factory_name: str, key: str) -> str:
+    """Return the combined registry key for the factory and key."""
+    return f"{factory_name}:{key}"
+
+def _get_or_create_shared_raw_mp_lock(factory_name: str, key: str) -> Optional[mp.synchronize.Lock]:
+    """Return the *singleton* manager.Lock() proxy for *key*, creating it if needed."""
+    if not _is_multiprocess:
+        return None
+
+    with _registry_guard:
+        combined_key = _get_combined_key(factory_name, key)
+        raw = _lock_registry.get(combined_key)
+        count = _lock_registry_count.get(combined_key)
+        if raw is None:
+            raw = _manager.Lock()
+            _lock_registry[combined_key] = raw
+            count = 0
+            _lock_registry_count[combined_key] = count
+        elif count is None:
+            raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}")
+        count += 1
+        _lock_registry_count[combined_key] = count
+        if count == 1 and combined_key in _lock_cleanup_data:
+            _lock_cleanup_data.pop(combined_key)
+        return raw
+
+
+def _release_shared_raw_mp_lock(factory_name: str, key: str):
+    """Drop one reference to the *singleton* manager.Lock() proxy for *key*."""
+    if not _is_multiprocess:
+        return
+
+    with _registry_guard:
+        combined_key = _get_combined_key(factory_name, key)
+        raw = _lock_registry.get(combined_key)
+        count = _lock_registry_count.get(combined_key)
+        if raw is None and count is None:
+            return
+        elif raw is None or count is None:
+            raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}")
+
+        count -= 1
+        if count < 0:
+            raise RuntimeError(f"Attempting to remove lock for {key} but it is not in the registry")
+        _lock_registry_count[combined_key] = count
+
+        if count == 0:
+            _lock_cleanup_data[combined_key] = time.time()
+
+        # Evict entries that have sat unreferenced for longer than the TTL
+        for stale_key, released_at in list(_lock_cleanup_data.items()):
+            if time.time() - released_at > CLEANUP_KEYED_LOCKS_AFTER_SECONDS:
+                _lock_registry.pop(stale_key)
+                _lock_registry_count.pop(stale_key)
+                _lock_cleanup_data.pop(stale_key)
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 3. PARAMETER‑KEYED WRAPPER (unchanged except it *accepts a factory*)
+# ─────────────────────────────────────────────────────────────────────────────
+class KeyedUnifiedLock:
+    """
+    Parameter‑keyed wrapper around `UnifiedLock`.
+
+    • Keeps only a table of per‑key *asyncio* gates locally
+    • Fetches the shared process‑wide mutex on *every* acquire
+    • Builds a fresh `UnifiedLock` each time, so `enable_logging`
+      (or future options) can vary per call.
+    """
+
+    # ---------------- construction ----------------
+    def __init__(self, factory_name: str, *, default_enable_logging: bool = True) -> None:
+        self._factory_name = factory_name
+        self._default_enable_logging = default_enable_logging
+        self._async_lock: Dict[str, asyncio.Lock] = {}        # key → asyncio.Lock
+        self._async_lock_count: Dict[str, int] = {}           # key → reference count
+        self._async_lock_cleanup_data: Dict[str, float] = {}  # key → release timestamp
+        self._mp_locks: Dict[str, mp.synchronize.Lock] = {}   # key → mp.synchronize.Lock
+
+    # ---------------- public API ------------------
+    def __call__(self, keys: list[str], *, enable_logging: Optional[bool] = None):
+        """
+        Ergonomic helper so you can write:
+
+            async with keyed_locks(["alpha"]):
+                ...
+        """
+        if enable_logging is None:
+            enable_logging = self._default_enable_logging
+        return _KeyedLockContext(self, factory_name=self._factory_name, keys=keys, enable_logging=enable_logging)
+
+    def _get_or_create_async_lock(self, key: str) -> asyncio.Lock:
+        async_lock = self._async_lock.get(key)
+        count = self._async_lock_count.get(key, 0)
+        if async_lock is None:
+            async_lock = asyncio.Lock()
+            self._async_lock[key] = async_lock
+        elif count == 0 and key in self._async_lock_cleanup_data:
+            self._async_lock_cleanup_data.pop(key)
+        count += 1
+        self._async_lock_count[key] = count
+        return async_lock
+
+    def _release_async_lock(self, key: str):
+        count = self._async_lock_count.get(key, 0)
+        count -= 1
+        if count == 0:
+            self._async_lock_cleanup_data[key] = time.time()
+        self._async_lock_count[key] = count
+
+        # Evict per-key gates that have sat unreferenced for longer than the TTL
+        for stale_key, released_at in list(self._async_lock_cleanup_data.items()):
+            if time.time() - released_at > CLEANUP_KEYED_LOCKS_AFTER_SECONDS:
+                self._async_lock.pop(stale_key)
+                self._async_lock_count.pop(stale_key)
+                self._async_lock_cleanup_data.pop(stale_key)
+
+    def _get_lock_for_key(self, key: str, enable_logging: bool = False) -> UnifiedLock:
+        # 1. get (or create) the per‑process async gate for this key
+        #    (synchronous, so no additional lock is needed here)
+        async_lock = self._get_or_create_async_lock(key)
+
+        # 2. fetch the shared raw lock (None when running single-process)
+        raw_lock = _get_or_create_shared_raw_mp_lock(self._factory_name, key)
+        is_multiprocess = raw_lock is not None
+        if not is_multiprocess:
+            raw_lock = async_lock
+
+        # 3. build a *fresh* UnifiedLock with the chosen logging flag
+        if is_multiprocess:
+            return UnifiedLock(
+                lock=raw_lock,
+                is_async=False,  # manager.Lock is synchronous
+                name=f"key:{self._factory_name}:{key}",
+                enable_logging=enable_logging,
+                async_lock=async_lock,  # prevents event‑loop blocking
+            )
+        else:
+            return UnifiedLock(
+                lock=raw_lock,
+                is_async=True,
+                name=f"key:{self._factory_name}:{key}",
+                enable_logging=enable_logging,
+                async_lock=None,  # no extra gate needed in single-process mode
+            )
+
+    def _release_lock_for_key(self, key: str):
+        self._release_async_lock(key)
+        _release_shared_raw_mp_lock(self._factory_name, key)
+
+class _KeyedLockContext:
+    def __init__(
+        self,
+        parent: KeyedUnifiedLock,
+        factory_name: str,
+        keys: list[str],
+        enable_logging: bool,
+    ) -> None:
+        self._parent = parent
+        self._factory_name = factory_name
+
+        # The sorting is critical: a consistent acquire (and reverse release)
+        # order across all callers is what avoids deadlocks
+        self._keys = sorted(keys)
+        self._enable_logging = (
+            enable_logging if enable_logging is not None
+            else parent._default_enable_logging
+        )
+        self._ul: Optional[List["UnifiedLock"]] = None  # set in __aenter__
+
+    # ----- enter -----
+    async def __aenter__(self):
+        if self._ul is not None:
+            raise RuntimeError("KeyedUnifiedLock already acquired in current context")
+
+        # 4. acquire the per-key locks, in sorted order
+        self._ul = []
+        for key in self._keys:
+            lock = self._parent._get_lock_for_key(key, enable_logging=self._enable_logging)
+            await lock.__aenter__()
+            inc_debug_n_locks_acquired()
+            self._ul.append(lock)
+        return self
+
+    # ----- exit -----
+    async def __aexit__(self, exc_type, exc, tb):
+        # Release in reverse acquisition order; the UnifiedLock takes care of
+        # the rest
+        for ul, key in zip(reversed(self._ul), reversed(self._keys)):
+            await ul.__aexit__(exc_type, exc, tb)
+            self._parent._release_lock_for_key(key)
+            dec_debug_n_locks_acquired()
+        self._ul = None
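
Taken together, the pieces above give a per-key lock with refcounted, TTL-based cleanup. A minimal usage sketch (assuming single-process mode and the module as defined above; "demo" is an arbitrary factory name):

    import asyncio

    keyed_locks = KeyedUnifiedLock("demo", default_enable_logging=False)

    async def touch(key: str, seen: list[str]):
        # Coroutines contending on the same key serialize here;
        # disjoint keys proceed in parallel.
        async with keyed_locks([key]):
            seen.append(key)
            await asyncio.sleep(0)

    async def main():
        seen: list[str] = []
        await asyncio.gather(touch("alpha", seen), touch("alpha", seen), touch("beta", seen))

    asyncio.run(main())

Because _KeyedLockContext sorts its keys, callers passing ["beta", "alpha"] and ["alpha", "beta"] both acquire alpha before beta, so the classic lock-order-inversion deadlock cannot occur.
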
 def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
@@ -258,6 +489,14 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
         async_lock=async_lock,
     )

+def get_graph_db_lock_keyed(keys: str | list[str], enable_logging: bool = False) -> "_KeyedLockContext":
+    """return a keyed graph database lock context for the given key(s), for atomic per-key operations"""
+    global _graph_db_lock_keyed
+    if _graph_db_lock_keyed is None:
+        raise RuntimeError("Shared-Data is not initialized")
+    if isinstance(keys, str):
+        keys = [keys]
+    return _graph_db_lock_keyed(keys, enable_logging=enable_logging)

 def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified data initialization lock for ensuring atomic data initialization"""
@@ -294,6 +533,10 @@ def initialize_share_data(workers: int = 1):
         _workers, \
         _is_multiprocess, \
         _storage_lock, \
+        _lock_registry, \
+        _lock_registry_count, \
+        _lock_cleanup_data, \
+        _registry_guard, \
         _internal_lock, \
         _pipeline_status_lock, \
         _graph_db_lock, \
@@ -302,7 +545,8 @@ def initialize_share_data(workers: int = 1):
         _init_flags, \
         _initialized, \
         _update_flags, \
-        _async_locks
+        _async_locks, \
+        _graph_db_lock_keyed

     # Check if already initialized
     if _initialized:
@@ -316,6 +560,10 @@ def initialize_share_data(workers: int = 1):
     if workers > 1:
         _is_multiprocess = True
         _manager = Manager()
+        _lock_registry = _manager.dict()
+        _lock_registry_count = _manager.dict()
+        _lock_cleanup_data = _manager.dict()
+        _registry_guard = _manager.RLock()
         _internal_lock = _manager.Lock()
         _storage_lock = _manager.Lock()
         _pipeline_status_lock = _manager.Lock()
@@ -324,6 +572,10 @@ def initialize_share_data(workers: int = 1):
         _shared_dicts = _manager.dict()
         _init_flags = _manager.dict()
         _update_flags = _manager.dict()
+
+        _graph_db_lock_keyed = KeyedUnifiedLock(
+            factory_name="graph_db_lock",
+        )

         # Initialize async locks for multiprocess mode
         _async_locks = {
@@ -348,6 +600,10 @@ def initialize_share_data(workers: int = 1):
         _init_flags = {}
         _update_flags = {}
         _async_locks = None  # No need for async locks in single process mode
+
+        _graph_db_lock_keyed = KeyedUnifiedLock(
+            factory_name="graph_db_lock",
+        )

         direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")

     # Mark as initialized
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 7a79da31..d4c28b17 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1024,73 +1024,73 @@ class LightRAG:
                         }
                     )

-                # Semphore was released here
+                    # Semaphore is NOT released here; however, the profile context is

-                if file_extraction_stage_ok:
-                    try:
-                        # Get chunk_results from entity_relation_task
-                        chunk_results = await entity_relation_task
-                        await merge_nodes_and_edges(
-                            chunk_results=chunk_results,  # result collected from entity_relation_task
-                            knowledge_graph_inst=self.chunk_entity_relation_graph,
-                            entity_vdb=self.entities_vdb,
-                            relationships_vdb=self.relationships_vdb,
-                            global_config=asdict(self),
-                            pipeline_status=pipeline_status,
-                            pipeline_status_lock=pipeline_status_lock,
-                            llm_response_cache=self.llm_response_cache,
-                            current_file_number=current_file_number,
-                            total_files=total_files,
-                            file_path=file_path,
-                        )
+                    if file_extraction_stage_ok:
+                        try:
+                            # Get chunk_results from entity_relation_task
+                            chunk_results = await entity_relation_task
+                            await merge_nodes_and_edges(
+                                chunk_results=chunk_results,  # result collected from entity_relation_task
+                                knowledge_graph_inst=self.chunk_entity_relation_graph,
+                                entity_vdb=self.entities_vdb,
+                                relationships_vdb=self.relationships_vdb,
+                                global_config=asdict(self),
+                                pipeline_status=pipeline_status,
+                                pipeline_status_lock=pipeline_status_lock,
+                                llm_response_cache=self.llm_response_cache,
+                                current_file_number=current_file_number,
+                                total_files=total_files,
+                                file_path=file_path,
+                            )

-                        await self.doc_status.upsert(
-                            {
-                                doc_id: {
-                                    "status": DocStatus.PROCESSED,
-                                    "chunks_count": len(chunks),
-                                    "content": status_doc.content,
-                                    "content_summary": status_doc.content_summary,
-                                    "content_length": status_doc.content_length,
-                                    "created_at": status_doc.created_at,
-                                    "updated_at": datetime.now().isoformat(),
-                                    "file_path": file_path,
-                                }
-                            }
-                        )
+                            await self.doc_status.upsert(
+                                {
+                                    doc_id: {
+                                        "status": DocStatus.PROCESSED,
+                                        "chunks_count": len(chunks),
+                                        "content": status_doc.content,
+                                        "content_summary": status_doc.content_summary,
+                                        "content_length": status_doc.content_length,
+                                        "created_at": status_doc.created_at,
+                                        "updated_at": datetime.now().isoformat(),
+                                        "file_path": file_path,
+                                    }
+                                }
+                            )

-                        # Call _insert_done after processing each file
-                        await self._insert_done()
+                            # Call _insert_done after processing each file
+                            await self._insert_done()

-                        async with pipeline_status_lock:
-                            log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
-                            logger.info(log_message)
-                            pipeline_status["latest_message"] = log_message
-                            pipeline_status["history_messages"].append(log_message)
+                            async with pipeline_status_lock:
+                                log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
+                                logger.info(log_message)
+                                pipeline_status["latest_message"] = log_message
+                                pipeline_status["history_messages"].append(log_message)

-                    except Exception as e:
-                        # Log error and update pipeline status
-                        error_msg = f"Merging stage failed in document {doc_id}: {traceback.format_exc()}"
-                        logger.error(error_msg)
-                        async with pipeline_status_lock:
-                            pipeline_status["latest_message"] = error_msg
-                            pipeline_status["history_messages"].append(error_msg)
+                        except Exception as e:
+                            # Log error and update pipeline status
+                            error_msg = f"Merging stage failed in document {doc_id}: {traceback.format_exc()}"
+                            logger.error(error_msg)
+                            async with pipeline_status_lock:
+                                pipeline_status["latest_message"] = error_msg
+                                pipeline_status["history_messages"].append(error_msg)

-                        # Update document status to failed
-                        await self.doc_status.upsert(
-                            {
-                                doc_id: {
-                                    "status": DocStatus.FAILED,
-                                    "error": str(e),
-                                    "content": status_doc.content,
-                                    "content_summary": status_doc.content_summary,
-                                    "content_length": status_doc.content_length,
-                                    "created_at": status_doc.created_at,
-                                    "updated_at": datetime.now().isoformat(),
-                                    "file_path": file_path,
-                                }
-                            }
-                        )
+                            # Update document status to failed
+                            await self.doc_status.upsert(
+                                {
+                                    doc_id: {
+                                        "status": DocStatus.FAILED,
+                                        "error": str(e),
+                                        "content": status_doc.content,
+                                        "content_summary": status_doc.content_summary,
+                                        "content_length": status_doc.content_length,
+                                        "created_at": status_doc.created_at,
+                                        "updated_at": datetime.now().isoformat(),
+                                        "file_path": file_path,
+                                    }
+                                }
+                            )

         # Create processing tasks for all documents
         doc_tasks = []
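
Reduced to its control-flow skeleton, the re-indentation keeps the merge stage inside the per-file concurrency guard; a sketch with illustrative stand-in names (extract_stage, merge_stage, and the status helpers are not real LightRAG functions):

    async def process_one_document():
        async with semaphore:  # now held through merging as well
            file_extraction_stage_ok = await extract_stage()
            if file_extraction_stage_ok:
                try:
                    await merge_stage()     # graph-heavy work
                    await mark_processed()
                except Exception:
                    await mark_failed()

Previously the semaphore was released right after extraction, so any number of merge stages could run concurrently.
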
"content_length": status_doc.content_length, + "created_at": status_doc.created_at, + "updated_at": datetime.now().isoformat(), + "file_path": file_path, + } } - } - ) + ) # Create processing tasks for all documents doc_tasks = [] diff --git a/lightrag/operate.py b/lightrag/operate.py index d82965e2..a0a74c52 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -9,6 +9,8 @@ import os from typing import Any, AsyncIterator from collections import Counter, defaultdict +from .kg.shared_storage import get_graph_db_lock_keyed + from .utils import ( logger, clean_str, @@ -403,27 +405,31 @@ async def _merge_edges_then_upsert( ) for need_insert_id in [src_id, tgt_id]: - if not (await knowledge_graph_inst.has_node(need_insert_id)): - # # Discard this edge if the node does not exist - # if need_insert_id == src_id: - # logger.warning( - # f"Discard edge: {src_id} - {tgt_id} | Source node missing" - # ) - # else: - # logger.warning( - # f"Discard edge: {src_id} - {tgt_id} | Target node missing" - # ) - # return None - await knowledge_graph_inst.upsert_node( - need_insert_id, - node_data={ - "entity_id": need_insert_id, - "source_id": source_id, - "description": description, - "entity_type": "UNKNOWN", - "file_path": file_path, - }, - ) + if (await knowledge_graph_inst.has_node(need_insert_id)): + # This is so that the initial check for the existence of the node need not be locked + continue + async with get_graph_db_lock_keyed([need_insert_id], enable_logging=False): + if not (await knowledge_graph_inst.has_node(need_insert_id)): + # # Discard this edge if the node does not exist + # if need_insert_id == src_id: + # logger.warning( + # f"Discard edge: {src_id} - {tgt_id} | Source node missing" + # ) + # else: + # logger.warning( + # f"Discard edge: {src_id} - {tgt_id} | Target node missing" + # ) + # return None + await knowledge_graph_inst.upsert_node( + need_insert_id, + node_data={ + "entity_id": need_insert_id, + "source_id": source_id, + "description": description, + "entity_type": "UNKNOWN", + "file_path": file_path, + }, + ) force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] @@ -523,23 +529,30 @@ async def merge_nodes_and_edges( all_edges[sorted_edge_key].extend(edges) # Centralized processing of all nodes and edges - entities_data = [] - relationships_data = [] + total_entities_count = len(all_nodes) + total_relations_count = len(all_edges) # Merge nodes and edges # Use graph database lock to ensure atomic merges and updates graph_db_lock = get_graph_db_lock(enable_logging=False) - async with graph_db_lock: + async with pipeline_status_lock: + log_message = ( + f"Merging stage {current_file_number}/{total_files}: {file_path}" + ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + # Process and update all entities at once + log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}" + logger.info(log_message) + if pipeline_status is not None: async with pipeline_status_lock: - log_message = ( - f"Merging stage {current_file_number}/{total_files}: {file_path}" - ) - logger.info(log_message) pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) - # Process and update all entities at once - for entity_name, entities in all_nodes.items(): + async def _locked_process_entity_name(entity_name, entities): + async with get_graph_db_lock_keyed([entity_name], enable_logging=False): entity_data 
@@ -523,23 +529,30 @@ async def merge_nodes_and_edges(
             all_edges[sorted_edge_key].extend(edges)

     # Centralized processing of all nodes and edges
-    entities_data = []
-    relationships_data = []
+    total_entities_count = len(all_nodes)
+    total_relations_count = len(all_edges)

     # Merge nodes and edges
     # Use graph database lock to ensure atomic merges and updates
     graph_db_lock = get_graph_db_lock(enable_logging=False)
-    async with graph_db_lock:
     async with pipeline_status_lock:
         log_message = (
             f"Merging stage {current_file_number}/{total_files}: {file_path}"
         )
         logger.info(log_message)
         pipeline_status["latest_message"] = log_message
         pipeline_status["history_messages"].append(log_message)
+
+    # Process and update all entities at once
+    log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
+    logger.info(log_message)
+    if pipeline_status is not None:
         async with pipeline_status_lock:
             pipeline_status["latest_message"] = log_message
             pipeline_status["history_messages"].append(log_message)

-        # Process and update all entities at once
-        for entity_name, entities in all_nodes.items():
+    async def _locked_process_entity_name(entity_name, entities):
+        async with get_graph_db_lock_keyed([entity_name], enable_logging=False):
             entity_data = await _merge_nodes_then_upsert(
                 entity_name,
                 entities,
                 knowledge_graph_inst,
                 global_config,
                 pipeline_status,
                 pipeline_status_lock,
                 llm_response_cache,
             )
-            entities_data.append(entity_data)
+            if entity_vdb is not None:
+                data_for_vdb = {
+                    compute_mdhash_id(entity_data["entity_name"], prefix="ent-"): {
+                        "entity_name": entity_data["entity_name"],
+                        "entity_type": entity_data["entity_type"],
+                        "content": f"{entity_data['entity_name']}\n{entity_data['description']}",
+                        "source_id": entity_data["source_id"],
+                        "file_path": entity_data.get("file_path", "unknown_source"),
+                    }
+                }
+                await entity_vdb.upsert(data_for_vdb)
+            return entity_data

-    # Process and update all relationships at once
-    for edge_key, edges in all_edges.items():
+    tasks = []
+    for entity_name, entities in all_nodes.items():
+        tasks.append(asyncio.create_task(_locked_process_entity_name(entity_name, entities)))
+    await asyncio.gather(*tasks)
"tgt_id": dp["tgt_id"], - "keywords": dp["keywords"], - "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}", - "source_id": dp["source_id"], - "file_path": dp.get("file_path", "unknown_source"), - } - for dp in relationships_data - } - await relationships_vdb.upsert(data_for_vdb) + await relationships_vdb.upsert(data_for_vdb) + return edge_data + tasks = [] + for edge_key, edges in all_edges.items(): + tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges))) + await asyncio.gather(*tasks) async def extract_entities( chunks: dict[str, TextChunkSchema],