Merge pull request #1770 from danielaskdd/merge_lock_with_key

Refactor: Optimize keyed lock cleanup logic with time and size tracking
This commit is contained in:
Daniel.y 2025-07-12 05:24:36 +08:00 committed by GitHub
commit ad7d7d0854
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 700 additions and 194 deletions

View file

@ -1,23 +1,46 @@
import os
import sys
import asyncio
import multiprocessing as mp
from multiprocessing.synchronize import Lock as ProcessLock
from multiprocessing import Manager
from typing import Any, Dict, Optional, Union, TypeVar, Generic
import time
import logging
from typing import Any, Dict, List, Optional, Union, TypeVar, Generic
# Define a direct print function for critical logs that must be visible in all processes
def direct_log(message, enable_output: bool = True, level: str = "DEBUG"):
    """
    Log a message directly to stderr to ensure visibility in all processes,
    including the Gunicorn master process.

    Args:
        message: The message to log
        enable_output: When True, force output regardless of the logger level
            (default: True); when False, the message is emitted only if its
            level passes the current lightrag logger level.
        level: Log level name, e.g. "DEBUG"/"INFO"/... (default: "DEBUG")
    """
    # Get the current logger level from the lightrag logger
    try:
        from lightrag.utils import logger

        current_level = logger.getEffectiveLevel()
    except ImportError:
        # Fallback if lightrag.utils is not available
        current_level = logging.INFO

    # Convert string level to numeric level for comparison; unknown names
    # fall back to DEBUG so they are only shown at verbose settings.
    level_mapping = {
        "DEBUG": logging.DEBUG,  # 10
        "INFO": logging.INFO,  # 20
        "WARNING": logging.WARNING,  # 30
        "ERROR": logging.ERROR,  # 40
        "CRITICAL": logging.CRITICAL,  # 50
    }
    message_level = level_mapping.get(level.upper(), logging.DEBUG)

    if enable_output or (message_level >= current_level):
        print(f"{level}: {message}", file=sys.stderr, flush=True)
@ -27,6 +50,23 @@ LockType = Union[ProcessLock, asyncio.Lock]
_is_multiprocess = None
_workers = None
_manager = None
# Global singleton data for multi-process keyed locks
_lock_registry: Optional[Dict[str, mp.synchronize.Lock]] = None
_lock_registry_count: Optional[Dict[str, int]] = None
_lock_cleanup_data: Optional[Dict[str, time.time]] = None
_registry_guard = None
# Timeout for keyed locks in seconds
CLEANUP_KEYED_LOCKS_AFTER_SECONDS = 300
# Threshold for triggering cleanup - only clean when pending list exceeds this size
CLEANUP_THRESHOLD = 500
# Minimum interval between cleanup operations in seconds
MIN_CLEANUP_INTERVAL_SECONDS = 30
# Track the earliest cleanup time for efficient cleanup triggering (multiprocess locks only)
_earliest_mp_cleanup_time: Optional[float] = None
# Track the last cleanup time to enforce minimum interval (multiprocess locks only)
_last_mp_cleanup_time: Optional[float] = None
_initialized = None
# shared data for storage across processes
@ -40,10 +80,37 @@ _internal_lock: Optional[LockType] = None
_pipeline_status_lock: Optional[LockType] = None
_graph_db_lock: Optional[LockType] = None
_data_init_lock: Optional[LockType] = None
# Manager for all keyed locks
_graph_db_lock_keyed: Optional["KeyedUnifiedLock"] = None
# async locks for coroutine synchronization in multiprocess mode
_async_locks: Optional[Dict[str, asyncio.Lock]] = None
DEBUG_LOCKS = False
_debug_n_locks_acquired: int = 0
def inc_debug_n_locks_acquired():
    """Bump the debug lock counter and print it (no-op unless DEBUG_LOCKS)."""
    global _debug_n_locks_acquired
    if not DEBUG_LOCKS:
        return
    _debug_n_locks_acquired += 1
    print(f"DEBUG: Keyed Lock acquired, total: {_debug_n_locks_acquired:>5}")
def dec_debug_n_locks_acquired():
    """Decrement the debug lock counter, raising if it would go negative."""
    global _debug_n_locks_acquired
    if not DEBUG_LOCKS:
        return
    if _debug_n_locks_acquired <= 0:
        raise RuntimeError("Attempting to release lock when no locks are acquired")
    _debug_n_locks_acquired -= 1
    print(f"DEBUG: Keyed Lock released, total: {_debug_n_locks_acquired:>5}")
def get_debug_n_locks_acquired() -> int:
    """Return the current debug keyed-lock counter."""
    # `global` is only required for assignment; a plain read resolves to the
    # module global anyway, so the declaration was removed.
    return _debug_n_locks_acquired
class UnifiedLock(Generic[T]):
"""Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
@ -65,17 +132,8 @@ class UnifiedLock(Generic[T]):
async def __aenter__(self) -> "UnifiedLock[T]":
try:
# direct_log(
# f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
# enable_output=self._enable_logging,
# )
# If in multiprocess mode and async lock exists, acquire it first
if not self._is_async and self._async_lock is not None:
# direct_log(
# f"== Lock == Process {self._pid}: Acquiring async lock for '{self._name}'",
# enable_output=self._enable_logging,
# )
await self._async_lock.acquire()
direct_log(
f"== Lock == Process {self._pid}: Async lock for '{self._name}' acquired",
@ -210,6 +268,406 @@ class UnifiedLock(Generic[T]):
)
raise
def locked(self) -> bool:
    """Return True if the underlying lock is currently held.

    Both the async and the sync lock objects expose the same ``locked()``
    accessor, so the previous branch on ``self._is_async`` (whose two arms
    were identical) is unnecessary.
    """
    return self._lock.locked()
def _get_combined_key(factory_name: str, key: str) -> str:
"""Return the combined key for the factory and key."""
return f"{factory_name}:{key}"
def _get_or_create_shared_raw_mp_lock(
    factory_name: str, key: str
) -> Optional[mp.synchronize.Lock]:
    """Return the *singleton* manager.Lock() proxy for keyed lock, creating if needed.

    Increments the per-key reference count under ``_registry_guard``.

    Returns:
        The shared ``manager.Lock()`` proxy, or None in single-process mode
        (callers then fall back to the local async lock).
    """
    if not _is_multiprocess:
        return None
    with _registry_guard:
        combined_key = _get_combined_key(factory_name, key)
        raw = _lock_registry.get(combined_key)
        count = _lock_registry_count.get(combined_key)
        if raw is None:
            raw = _manager.Lock()
            _lock_registry[combined_key] = raw
            count = 0
        else:
            if count is None:
                raise RuntimeError(
                    f"Shared-Data lock registry for {factory_name} is corrupted for key {key}"
                )
            # Reusing a key that is pending cleanup: a fully released lock has
            # count == 0 (the release path only enqueues cleanup when the count
            # drops to 0, and the async sibling checks count == 0 as well).
            # The previous check used count == 1, which left the entry in the
            # cleanup list and allowed the sweeper to drop a lock that was just
            # re-acquired.
            if count == 0 and combined_key in _lock_cleanup_data:
                _lock_cleanup_data.pop(combined_key)
        count += 1
        _lock_registry_count[combined_key] = count
        return raw
def _release_shared_raw_mp_lock(factory_name: str, key: str):
    """Release the *singleton* manager.Lock() proxy for *key*.

    Decrements the per-key reference count under ``_registry_guard``. When the
    count reaches zero the key is queued for delayed cleanup; once the pending
    queue exceeds CLEANUP_THRESHOLD, expired entries (older than
    CLEANUP_KEYED_LOCKS_AFTER_SECONDS) are swept, at most once every
    MIN_CLEANUP_INTERVAL_SECONDS.

    Raises:
        RuntimeError: If the registry is corrupted (only one of the two maps
            knows the key) or the key is released more times than acquired.
    """
    if not _is_multiprocess:
        # Single-process mode has no shared registry to maintain.
        return
    global _earliest_mp_cleanup_time, _last_mp_cleanup_time
    with _registry_guard:
        combined_key = _get_combined_key(factory_name, key)
        raw = _lock_registry.get(combined_key)
        count = _lock_registry_count.get(combined_key)
        if raw is None and count is None:
            # Already cleaned up (or never acquired) — nothing to release.
            return
        elif raw is None or count is None:
            raise RuntimeError(
                f"Shared-Data lock registry for {factory_name} is corrupted for key {key}"
            )
        count -= 1
        if count < 0:
            raise RuntimeError(
                f"Attempting to release lock for {key} more times than it was acquired"
            )
        _lock_registry_count[combined_key] = count
        current_time = time.time()
        if count == 0:
            # Last reference gone: mark the key as a cleanup candidate.
            _lock_cleanup_data[combined_key] = current_time
            # Update earliest multiprocess cleanup time (only when earlier)
            if (
                _earliest_mp_cleanup_time is None
                or current_time < _earliest_mp_cleanup_time
            ):
                _earliest_mp_cleanup_time = current_time
        # Efficient cleanup triggering with minimum interval control
        total_cleanup_len = len(_lock_cleanup_data)
        if total_cleanup_len >= CLEANUP_THRESHOLD:
            # Time rollback detection (e.g. system clock moved backwards):
            # reset the interval tracker so cleanup is not blocked forever.
            if (
                _last_mp_cleanup_time is not None
                and current_time < _last_mp_cleanup_time
            ):
                direct_log(
                    "== mp Lock == Time rollback detected, resetting cleanup time",
                    level="WARNING",
                    enable_output=False,
                )
                _last_mp_cleanup_time = None
            # Check cleanup conditions: at least one entry is old enough, and
            # the minimum interval since the last sweep has elapsed.
            has_expired_locks = (
                _earliest_mp_cleanup_time is not None
                and current_time - _earliest_mp_cleanup_time
                > CLEANUP_KEYED_LOCKS_AFTER_SECONDS
            )
            interval_satisfied = (
                _last_mp_cleanup_time is None
                or current_time - _last_mp_cleanup_time > MIN_CLEANUP_INTERVAL_SECONDS
            )
            if has_expired_locks and interval_satisfied:
                try:
                    cleaned_count = 0
                    new_earliest_time = None
                    # Perform cleanup while maintaining the new earliest time.
                    # Iterate over a snapshot since the dict is mutated inside.
                    for cleanup_key, cleanup_time in list(_lock_cleanup_data.items()):
                        if (
                            current_time - cleanup_time
                            > CLEANUP_KEYED_LOCKS_AFTER_SECONDS
                        ):
                            # Clean expired locks
                            _lock_registry.pop(cleanup_key, None)
                            _lock_registry_count.pop(cleanup_key, None)
                            _lock_cleanup_data.pop(cleanup_key, None)
                            cleaned_count += 1
                        else:
                            # Track the earliest time among remaining locks
                            if (
                                new_earliest_time is None
                                or cleanup_time < new_earliest_time
                            ):
                                new_earliest_time = cleanup_time
                    # Update state only after successful cleanup
                    _earliest_mp_cleanup_time = new_earliest_time
                    _last_mp_cleanup_time = current_time
                    if cleaned_count > 0:
                        # Informational: when the next sweep could fire, bounded
                        # below by the minimum interval.
                        next_cleanup_in = max(
                            (
                                new_earliest_time
                                + CLEANUP_KEYED_LOCKS_AFTER_SECONDS
                                - current_time
                            )
                            if new_earliest_time
                            else float("inf"),
                            MIN_CLEANUP_INTERVAL_SECONDS,
                        )
                        direct_log(
                            f"== mp Lock == Cleaned up {cleaned_count}/{total_cleanup_len} expired locks, "
                            f"next cleanup in {next_cleanup_in:.1f}s",
                            enable_output=False,
                            level="INFO",
                        )
                except Exception as e:
                    direct_log(
                        f"== mp Lock == Cleanup failed: {e}",
                        level="ERROR",
                        enable_output=False,
                    )
                    # Don't update _last_mp_cleanup_time to allow retry
class KeyedUnifiedLock:
    """
    Manager for unified keyed locks, supporting both single and multi-process.

    - Keeps only a table of async keyed locks locally.
    - Fetches the multi-process keyed lock on every acquire.
    - Builds a fresh `UnifiedLock` each time, so `enable_logging`
      (or future options) can vary per call.

    NOTE(review): the per-key bookkeeping here is not guarded by a lock; it
    appears to assume single-event-loop, non-threaded access — confirm.
    """

    def __init__(
        self, factory_name: str, *, default_enable_logging: bool = True
    ) -> None:
        # Namespace prepended to every key in the shared registry.
        self._factory_name = factory_name
        self._default_enable_logging = default_enable_logging
        self._async_lock: Dict[str, asyncio.Lock] = {}  # local keyed locks
        self._async_lock_count: Dict[
            str, int
        ] = {}  # local keyed locks referenced count
        self._async_lock_cleanup_data: Dict[
            str, float
        ] = {}  # release timestamp per fully-released key (cleanup candidates)
        self._mp_locks: Dict[
            str, mp.synchronize.Lock
        ] = {}  # multi-process lock proxies
        self._earliest_async_cleanup_time: Optional[float] = (
            None  # track earliest async cleanup time
        )
        self._last_async_cleanup_time: Optional[float] = (
            None  # track last async cleanup time for minimum interval
        )

    def __call__(self, keys: list[str], *, enable_logging: Optional[bool] = None):
        """
        Ergonomic helper so you can write:

            async with keyed_locks("alpha"):
                ...

        Returns an async context manager that locks all *keys* on entry.
        """
        if enable_logging is None:
            enable_logging = self._default_enable_logging
        return _KeyedLockContext(
            self,
            factory_name=self._factory_name,
            keys=keys,
            enable_logging=enable_logging,
        )

    def _get_or_create_async_lock(self, key: str) -> asyncio.Lock:
        # Get (or lazily create) the local asyncio.Lock for *key* and bump its
        # reference count. A key being reused while pending cleanup is removed
        # from the cleanup list so the sweeper does not drop it.
        async_lock = self._async_lock.get(key)
        count = self._async_lock_count.get(key, 0)
        if async_lock is None:
            async_lock = asyncio.Lock()
            self._async_lock[key] = async_lock
        elif count == 0 and key in self._async_lock_cleanup_data:
            self._async_lock_cleanup_data.pop(key)
        count += 1
        self._async_lock_count[key] = count
        return async_lock

    def _release_async_lock(self, key: str):
        # Drop one reference for *key*; when the count reaches zero the key is
        # queued for delayed cleanup. Once the pending queue exceeds
        # CLEANUP_THRESHOLD, expired entries are swept (at most once every
        # MIN_CLEANUP_INTERVAL_SECONDS).
        # NOTE(review): unlike the multiprocess release path, a negative count
        # is not rejected here — confirm over-release cannot happen.
        count = self._async_lock_count.get(key, 0)
        count -= 1
        current_time = time.time()
        if count == 0:
            self._async_lock_cleanup_data[key] = current_time
            # Update earliest async cleanup time (only when earlier)
            if (
                self._earliest_async_cleanup_time is None
                or current_time < self._earliest_async_cleanup_time
            ):
                self._earliest_async_cleanup_time = current_time
        self._async_lock_count[key] = count
        # Efficient cleanup triggering with minimum interval control
        total_cleanup_len = len(self._async_lock_cleanup_data)
        if total_cleanup_len >= CLEANUP_THRESHOLD:
            # Time rollback detection (system clock moved backwards): reset
            # the interval tracker so cleanup is not blocked forever.
            if (
                self._last_async_cleanup_time is not None
                and current_time < self._last_async_cleanup_time
            ):
                direct_log(
                    "== async Lock == Time rollback detected, resetting cleanup time",
                    level="WARNING",
                    enable_output=False,
                )
                self._last_async_cleanup_time = None
            # Check cleanup conditions: at least one entry old enough, and the
            # minimum interval since the last sweep has elapsed.
            has_expired_locks = (
                self._earliest_async_cleanup_time is not None
                and current_time - self._earliest_async_cleanup_time
                > CLEANUP_KEYED_LOCKS_AFTER_SECONDS
            )
            interval_satisfied = (
                self._last_async_cleanup_time is None
                or current_time - self._last_async_cleanup_time
                > MIN_CLEANUP_INTERVAL_SECONDS
            )
            if has_expired_locks and interval_satisfied:
                try:
                    cleaned_count = 0
                    new_earliest_time = None
                    # Perform cleanup while maintaining the new earliest time.
                    # Iterate over a snapshot since the dict is mutated inside.
                    for cleanup_key, cleanup_time in list(
                        self._async_lock_cleanup_data.items()
                    ):
                        if (
                            current_time - cleanup_time
                            > CLEANUP_KEYED_LOCKS_AFTER_SECONDS
                        ):
                            # Clean expired async locks
                            self._async_lock.pop(cleanup_key)
                            self._async_lock_count.pop(cleanup_key)
                            self._async_lock_cleanup_data.pop(cleanup_key)
                            cleaned_count += 1
                        else:
                            # Track the earliest time among remaining locks
                            if (
                                new_earliest_time is None
                                or cleanup_time < new_earliest_time
                            ):
                                new_earliest_time = cleanup_time
                    # Update state only after successful cleanup
                    self._earliest_async_cleanup_time = new_earliest_time
                    self._last_async_cleanup_time = current_time
                    if cleaned_count > 0:
                        # Informational: when the next sweep could fire,
                        # bounded below by the minimum interval.
                        next_cleanup_in = max(
                            (
                                new_earliest_time
                                + CLEANUP_KEYED_LOCKS_AFTER_SECONDS
                                - current_time
                            )
                            if new_earliest_time
                            else float("inf"),
                            MIN_CLEANUP_INTERVAL_SECONDS,
                        )
                        direct_log(
                            f"== async Lock == Cleaned up {cleaned_count}/{total_cleanup_len} expired async locks, "
                            f"next cleanup in {next_cleanup_in:.1f}s",
                            enable_output=False,
                            level="INFO",
                        )
                except Exception as e:
                    direct_log(
                        f"== async Lock == Cleanup failed: {e}",
                        level="ERROR",
                        enable_output=False,
                    )
                    # Don't update _last_async_cleanup_time to allow retry

    def _get_lock_for_key(self, key: str, enable_logging: bool = False) -> UnifiedLock:
        # 1. get (or create) the per-process async gate for this key.
        #    Is synchronous, so no need to acquire a lock.
        async_lock = self._get_or_create_async_lock(key)

        # 2. fetch the shared raw lock (None in single-process mode)
        raw_lock = _get_or_create_shared_raw_mp_lock(self._factory_name, key)
        is_multiprocess = raw_lock is not None
        if not is_multiprocess:
            raw_lock = async_lock

        # 3. build a *fresh* UnifiedLock with the chosen logging flag
        if is_multiprocess:
            return UnifiedLock(
                lock=raw_lock,
                is_async=False,  # manager.Lock is synchronous
                name=_get_combined_key(self._factory_name, key),
                enable_logging=enable_logging,
                async_lock=async_lock,  # prevents eventloop blocking
            )
        else:
            return UnifiedLock(
                lock=raw_lock,
                is_async=True,
                name=_get_combined_key(self._factory_name, key),
                enable_logging=enable_logging,
                async_lock=None,  # No need for async lock in single process mode
            )

    def _release_lock_for_key(self, key: str):
        # Drop both the local async reference and the shared registry reference.
        self._release_async_lock(key)
        _release_shared_raw_mp_lock(self._factory_name, key)
class _KeyedLockContext:
    """Async context manager that acquires one UnifiedLock per key.

    Keys are locked in sorted order on entry and released in reverse order on
    exit; the deterministic ordering prevents lock-order deadlocks between
    contexts that share keys.
    """

    def __init__(
        self,
        parent: KeyedUnifiedLock,
        factory_name: str,
        keys: list[str],
        enable_logging: bool,
    ) -> None:
        self._parent = parent
        self._factory_name = factory_name
        # The sorting is critical to ensure proper lock and release order
        # to avoid deadlocks
        self._keys = sorted(keys)
        self._enable_logging = (
            enable_logging
            if enable_logging is not None
            else parent._default_enable_logging
        )
        self._ul: Optional[List["UnifiedLock"]] = None  # set in __aenter__

    # ----- enter -----
    async def __aenter__(self):
        # A context instance is single-use-at-a-time: re-entering before exit
        # would double-acquire the same keys.
        if self._ul is not None:
            raise RuntimeError("KeyedUnifiedLock already acquired in current context")
        # 4. acquire it
        # NOTE(review): if acquisition fails partway through the loop, the
        # locks already acquired are not rolled back here — confirm callers
        # rely on __aexit__ semantics or that failure is unrecoverable.
        self._ul = []
        for key in self._keys:
            lock = self._parent._get_lock_for_key(
                key, enable_logging=self._enable_logging
            )
            await lock.__aenter__()
            inc_debug_n_locks_acquired()
            self._ul.append(lock)
        return self  # or return self._key if you prefer

    # ----- exit -----
    async def __aexit__(self, exc_type, exc, tb):
        # The UnifiedLock takes care of proper release order
        for ul, key in zip(reversed(self._ul), reversed(self._keys)):
            await ul.__aexit__(exc_type, exc, tb)
            self._parent._release_lock_for_key(key)
            dec_debug_n_locks_acquired()
        self._ul = None
def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
@ -259,6 +717,18 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
)
def get_graph_db_lock_keyed(
    keys: str | list[str], enable_logging: bool = False
) -> "_KeyedLockContext":
    """Return a keyed graph-database lock context for ensuring atomic operations.

    Args:
        keys: A single key or a list of keys to lock (a bare string is wrapped
            in a one-element list).
        enable_logging: Whether the returned lock should emit debug logs.

    Returns:
        An async context manager that locks all requested keys on entry.
        (The previous annotation claimed ``KeyedUnifiedLock``, but calling the
        manager produces a ``_KeyedLockContext``.)

    Raises:
        RuntimeError: If initialize_share_data() has not been called yet.
    """
    # Reading the module global needs no `global` declaration.
    if _graph_db_lock_keyed is None:
        raise RuntimeError("Shared-Data is not initialized")
    if isinstance(keys, str):
        keys = [keys]
    return _graph_db_lock_keyed(keys, enable_logging=enable_logging)
def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified data initialization lock for ensuring atomic data initialization"""
async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None
@ -294,6 +764,10 @@ def initialize_share_data(workers: int = 1):
_workers, \
_is_multiprocess, \
_storage_lock, \
_lock_registry, \
_lock_registry_count, \
_lock_cleanup_data, \
_registry_guard, \
_internal_lock, \
_pipeline_status_lock, \
_graph_db_lock, \
@ -302,7 +776,10 @@ def initialize_share_data(workers: int = 1):
_init_flags, \
_initialized, \
_update_flags, \
_async_locks
_async_locks, \
_graph_db_lock_keyed, \
_earliest_mp_cleanup_time, \
_last_mp_cleanup_time
# Check if already initialized
if _initialized:
@ -316,6 +793,10 @@ def initialize_share_data(workers: int = 1):
if workers > 1:
_is_multiprocess = True
_manager = Manager()
_lock_registry = _manager.dict()
_lock_registry_count = _manager.dict()
_lock_cleanup_data = _manager.dict()
_registry_guard = _manager.RLock()
_internal_lock = _manager.Lock()
_storage_lock = _manager.Lock()
_pipeline_status_lock = _manager.Lock()
@ -325,6 +806,10 @@ def initialize_share_data(workers: int = 1):
_init_flags = _manager.dict()
_update_flags = _manager.dict()
_graph_db_lock_keyed = KeyedUnifiedLock(
factory_name="GraphDB",
)
# Initialize async locks for multiprocess mode
_async_locks = {
"internal_lock": asyncio.Lock(),
@ -348,8 +833,16 @@ def initialize_share_data(workers: int = 1):
_init_flags = {}
_update_flags = {}
_async_locks = None # No need for async locks in single process mode
_graph_db_lock_keyed = KeyedUnifiedLock(
factory_name="GraphDB",
)
direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
# Initialize multiprocess cleanup times
_earliest_mp_cleanup_time = None
_last_mp_cleanup_time = None
# Mark as initialized
_initialized = True

View file

@ -1094,86 +1094,89 @@ class LightRAG:
}
)
# Semphore released, concurrency controlled by graph_db_lock in merge_nodes_and_edges instead
if file_extraction_stage_ok:
try:
# Get chunk_results from entity_relation_task
chunk_results = await entity_relation_task
await merge_nodes_and_edges(
chunk_results=chunk_results, # result collected from entity_relation_task
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache,
current_file_number=current_file_number,
total_files=total_files,
file_path=file_path,
)
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"chunks_list": list(
chunks.keys()
), # 保留 chunks_list
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now(
timezone.utc
).isoformat(),
"file_path": file_path,
}
}
)
# Call _insert_done after processing each file
await self._insert_done()
async with pipeline_status_lock:
log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
except Exception as e:
# Log error and update pipeline status
logger.error(traceback.format_exc())
error_msg = f"Merging stage failed in document {current_file_number}/{total_files}: {file_path}"
logger.error(error_msg)
async with pipeline_status_lock:
pipeline_status["latest_message"] = error_msg
pipeline_status["history_messages"].append(
traceback.format_exc()
# Concurrency is controlled by graph db lock for individual entities and relationships
if file_extraction_stage_ok:
try:
# Get chunk_results from entity_relation_task
chunk_results = await entity_relation_task
await merge_nodes_and_edges(
chunk_results=chunk_results, # result collected from entity_relation_task
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache,
current_file_number=current_file_number,
total_files=total_files,
file_path=file_path,
)
pipeline_status["history_messages"].append(error_msg)
# Persistent llm cache
if self.llm_response_cache:
await self.llm_response_cache.index_done_callback()
# Update document status to failed
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.FAILED,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
"file_path": file_path,
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"chunks_list": list(
chunks.keys()
), # 保留 chunks_list
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now(
timezone.utc
).isoformat(),
"file_path": file_path,
}
}
}
)
)
# Call _insert_done after processing each file
await self._insert_done()
async with pipeline_status_lock:
log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(
log_message
)
except Exception as e:
# Log error and update pipeline status
logger.error(traceback.format_exc())
error_msg = f"Merging stage failed in document {current_file_number}/{total_files}: {file_path}"
logger.error(error_msg)
async with pipeline_status_lock:
pipeline_status["latest_message"] = error_msg
pipeline_status["history_messages"].append(
traceback.format_exc()
)
pipeline_status["history_messages"].append(
error_msg
)
# Persistent llm cache
if self.llm_response_cache:
await self.llm_response_cache.index_done_callback()
# Update document status to failed
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.FAILED,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
"file_path": file_path,
}
}
)
# Create processing tasks for all documents
doc_tasks = []

View file

@ -37,6 +37,7 @@ from .base import (
)
from .prompt import PROMPTS
from .constants import GRAPH_FIELD_SEP
from .kg.shared_storage import get_graph_db_lock_keyed
import time
from dotenv import load_dotenv
@ -1015,28 +1016,32 @@ async def _merge_edges_then_upsert(
)
for need_insert_id in [src_id, tgt_id]:
if not (await knowledge_graph_inst.has_node(need_insert_id)):
# # Discard this edge if the node does not exist
# if need_insert_id == src_id:
# logger.warning(
# f"Discard edge: {src_id} - {tgt_id} | Source node missing"
# )
# else:
# logger.warning(
# f"Discard edge: {src_id} - {tgt_id} | Target node missing"
# )
# return None
await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"entity_id": need_insert_id,
"source_id": source_id,
"description": description,
"entity_type": "UNKNOWN",
"file_path": file_path,
"created_at": int(time.time()),
},
)
if await knowledge_graph_inst.has_node(need_insert_id):
# This is so that the initial check for the existence of the node need not be locked
continue
async with get_graph_db_lock_keyed([need_insert_id], enable_logging=False):
if not (await knowledge_graph_inst.has_node(need_insert_id)):
# # Discard this edge if the node does not exist
# if need_insert_id == src_id:
# logger.warning(
# f"Discard edge: {src_id} - {tgt_id} | Source node missing"
# )
# else:
# logger.warning(
# f"Discard edge: {src_id} - {tgt_id} | Target node missing"
# )
# return None
await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"entity_id": need_insert_id,
"source_id": source_id,
"description": description,
"entity_type": "UNKNOWN",
"file_path": file_path,
"created_at": int(time.time()),
},
)
force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
@ -1118,8 +1123,6 @@ async def merge_nodes_and_edges(
pipeline_status_lock: Lock for pipeline status
llm_response_cache: LLM response cache
"""
# Get lock manager from shared storage
from .kg.shared_storage import get_graph_db_lock
# Collect all nodes and edges from all chunks
all_nodes = defaultdict(list)
@ -1136,94 +1139,101 @@ async def merge_nodes_and_edges(
all_edges[sorted_edge_key].extend(edges)
# Centralized processing of all nodes and edges
entities_data = []
relationships_data = []
total_entities_count = len(all_nodes)
total_relations_count = len(all_edges)
# Merge nodes and edges
# Use graph database lock to ensure atomic merges and updates
graph_db_lock = get_graph_db_lock(enable_logging=False)
async with graph_db_lock:
async with pipeline_status_lock:
log_message = (
f"Merging stage {current_file_number}/{total_files}: {file_path}"
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Process and update all entities at once
for entity_name, entities in all_nodes.items():
entity_data = await _merge_nodes_then_upsert(
entity_name,
entities,
knowledge_graph_inst,
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
entities_data.append(entity_data)
# Process and update all entities and relationships in parallel
log_message = f"Processing: {total_entities_count} entities and {total_relations_count} relations"
logger.info(log_message)
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Process and update all relationships at once
for edge_key, edges in all_edges.items():
edge_data = await _merge_edges_then_upsert(
edge_key[0],
edge_key[1],
edges,
knowledge_graph_inst,
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
if edge_data is not None:
relationships_data.append(edge_data)
# Get max async tasks limit from global_config for semaphore control
llm_model_max_async = global_config.get("llm_model_max_async", 4)
semaphore = asyncio.Semaphore(llm_model_max_async)
# Update total counts
total_entities_count = len(entities_data)
total_relations_count = len(relationships_data)
async def _locked_process_entity_name(entity_name, entities):
async with semaphore:
async with get_graph_db_lock_keyed([entity_name], enable_logging=False):
entity_data = await _merge_nodes_then_upsert(
entity_name,
entities,
knowledge_graph_inst,
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
if entity_vdb is not None:
data_for_vdb = {
compute_mdhash_id(entity_data["entity_name"], prefix="ent-"): {
"entity_name": entity_data["entity_name"],
"entity_type": entity_data["entity_type"],
"content": f"{entity_data['entity_name']}\n{entity_data['description']}",
"source_id": entity_data["source_id"],
"file_path": entity_data.get("file_path", "unknown_source"),
}
}
await entity_vdb.upsert(data_for_vdb)
return entity_data
log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
async def _locked_process_edges(edge_key, edges):
async with semaphore:
async with get_graph_db_lock_keyed(
f"{edge_key[0]}-{edge_key[1]}", enable_logging=False
):
edge_data = await _merge_edges_then_upsert(
edge_key[0],
edge_key[1],
edges,
knowledge_graph_inst,
global_config,
pipeline_status,
pipeline_status_lock,
llm_response_cache,
)
if edge_data is None:
return None
# Update vector databases with all collected data
if entity_vdb is not None and entities_data:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"entity_name": dp["entity_name"],
"entity_type": dp["entity_type"],
"content": f"{dp['entity_name']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in entities_data
}
await entity_vdb.upsert(data_for_vdb)
if relationships_vdb is not None:
data_for_vdb = {
compute_mdhash_id(
edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-"
): {
"src_id": edge_data["src_id"],
"tgt_id": edge_data["tgt_id"],
"keywords": edge_data["keywords"],
"content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}",
"source_id": edge_data["source_id"],
"file_path": edge_data.get("file_path", "unknown_source"),
}
}
await relationships_vdb.upsert(data_for_vdb)
return edge_data
log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
logger.info(log_message)
if pipeline_status is not None:
async with pipeline_status_lock:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Create a single task queue for both entities and edges
tasks = []
if relationships_vdb is not None and relationships_data:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"keywords": dp["keywords"],
"content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
"source_id": dp["source_id"],
"file_path": dp.get("file_path", "unknown_source"),
}
for dp in relationships_data
}
await relationships_vdb.upsert(data_for_vdb)
# Add entity processing tasks
for entity_name, entities in all_nodes.items():
tasks.append(
asyncio.create_task(_locked_process_entity_name(entity_name, entities))
)
# Add edge processing tasks
for edge_key, edges in all_edges.items():
tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges)))
# Execute all tasks in parallel with semaphore control
await asyncio.gather(*tasks)
async def extract_entities(