Merge branch 'using_keyed_lock_for_max_concurrency' into merge_lock_with_key

commit 35eb86942f
3 changed files with 422 additions and 158 deletions
[file 1/3: kg/shared_storage.py — path inferred from the "from .kg.shared_storage import ..." line in file 3]

@@ -1,9 +1,12 @@
 from collections import defaultdict
 import os
 import sys
 import asyncio
+import multiprocessing as mp
 from multiprocessing.synchronize import Lock as ProcessLock
 from multiprocessing import Manager
-from typing import Any, Dict, Optional, Union, TypeVar, Generic
+import time
+from typing import Any, Callable, Dict, List, Optional, Union, TypeVar, Generic


 # Define a direct print function for critical logs that must be visible in all processes
@@ -27,8 +30,14 @@ LockType = Union[ProcessLock, asyncio.Lock]
 _is_multiprocess = None
 _workers = None
 _manager = None
+_lock_registry: Optional[Dict[str, mp.synchronize.Lock]] = None
+_lock_registry_count: Optional[Dict[str, int]] = None
+_lock_cleanup_data: Optional[Dict[str, float]] = None  # combined key -> release timestamp
+_registry_guard = None
 _initialized = None

+CLEANUP_KEYED_LOCKS_AFTER_SECONDS = 300
+
 # shared data for storage across processes
 _shared_dicts: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
@@ -40,10 +49,31 @@ _internal_lock: Optional[LockType] = None
 _pipeline_status_lock: Optional[LockType] = None
 _graph_db_lock: Optional[LockType] = None
 _data_init_lock: Optional[LockType] = None
+_graph_db_lock_keyed: Optional["KeyedUnifiedLock"] = None

 # async locks for coroutine synchronization in multiprocess mode
 _async_locks: Optional[Dict[str, asyncio.Lock]] = None

+DEBUG_LOCKS = False
+_debug_n_locks_acquired: int = 0
+
+
+def inc_debug_n_locks_acquired():
+    global _debug_n_locks_acquired
+    if DEBUG_LOCKS:
+        _debug_n_locks_acquired += 1
+        print(f"DEBUG: Keyed Lock acquired, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True)
+
+
+def dec_debug_n_locks_acquired():
+    global _debug_n_locks_acquired
+    if DEBUG_LOCKS:
+        if _debug_n_locks_acquired > 0:
+            _debug_n_locks_acquired -= 1
+            print(f"DEBUG: Keyed Lock released, total: {_debug_n_locks_acquired:>5}", end="\r", flush=True)
+        else:
+            raise RuntimeError("Attempting to release lock when no locks are acquired")
+
+
+def get_debug_n_locks_acquired():
+    return _debug_n_locks_acquired
+
+
 class UnifiedLock(Generic[T]):
     """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
@@ -210,6 +240,207 @@ class UnifiedLock(Generic[T]):
             )
             raise

+    def locked(self) -> bool:
+        # Both the asyncio.Lock and the multiprocessing lock expose .locked()
+        if self._is_async:
+            return self._lock.locked()
+        else:
+            return self._lock.locked()
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 2. CROSS-PROCESS FACTORY (one manager.Lock shared by *all* processes)
+# ─────────────────────────────────────────────────────────────────────────────
+def _get_combined_key(factory_name: str, key: str) -> str:
+    """Return the combined key for the factory and key."""
+    return f"{factory_name}:{key}"
+
+
+def _get_or_create_shared_raw_mp_lock(factory_name: str, key: str) -> Optional[mp.synchronize.Lock]:
+    """Return the *singleton* manager.Lock() proxy for *key*, creating it if needed."""
+    if not _is_multiprocess:
+        return None
+
+    with _registry_guard:
+        combined_key = _get_combined_key(factory_name, key)
+        raw = _lock_registry.get(combined_key)
+        count = _lock_registry_count.get(combined_key)
+        if raw is None:
+            raw = _manager.Lock()
+            _lock_registry[combined_key] = raw
+            count = 0  # a fresh lock starts with a zero reference count
+        elif count is None:
+            raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}")
+        count += 1
+        _lock_registry_count[combined_key] = count
+        # A re-acquired key is no longer a cleanup candidate
+        if count == 1 and combined_key in _lock_cleanup_data:
+            _lock_cleanup_data.pop(combined_key)
+        return raw
+
+
+def _release_shared_raw_mp_lock(factory_name: str, key: str):
+    """Release the *singleton* manager.Lock() proxy for *key*."""
+    if not _is_multiprocess:
+        return
+
+    with _registry_guard:
+        combined_key = _get_combined_key(factory_name, key)
+        raw = _lock_registry.get(combined_key)
+        count = _lock_registry_count.get(combined_key)
+        if raw is None and count is None:
+            return
+        elif raw is None or count is None:
+            raise RuntimeError(f"Shared-Data lock registry for {factory_name} is corrupted for key {key}")
+
+        count -= 1
+        if count < 0:
+            raise RuntimeError(f"Attempting to remove lock for {key} but it is not in the registry")
+        else:
+            _lock_registry_count[combined_key] = count
+
+        if count == 0:
+            _lock_cleanup_data[combined_key] = time.time()
+
+        # Evict locks that have stayed unreferenced beyond the cleanup window
+        for stale_key, released_at in list(_lock_cleanup_data.items()):
+            if time.time() - released_at > CLEANUP_KEYED_LOCKS_AFTER_SECONDS:
+                _lock_registry.pop(stale_key)
+                _lock_registry_count.pop(stale_key)
+                _lock_cleanup_data.pop(stale_key)
+# ─────────────────────────────────────────────────────────────────────────────
+# 3. PARAMETER-KEYED WRAPPER (unchanged except it *accepts a factory*)
+# ─────────────────────────────────────────────────────────────────────────────
+class KeyedUnifiedLock:
+    """
+    Parameter-keyed wrapper around `UnifiedLock`.
+
+    • Keeps only a table of per-key *asyncio* gates locally
+    • Fetches the shared process-wide mutex on *every* acquire
+    • Builds a fresh `UnifiedLock` each time, so `enable_logging`
+      (or future options) can vary per call.
+    """
+
+    # ---------------- construction ----------------
+    def __init__(self, factory_name: str, *, default_enable_logging: bool = True) -> None:
+        self._factory_name = factory_name
+        self._default_enable_logging = default_enable_logging
+        self._async_lock: Dict[str, asyncio.Lock] = {}        # key -> asyncio.Lock
+        self._async_lock_count: Dict[str, int] = {}           # key -> reference count
+        self._async_lock_cleanup_data: Dict[str, float] = {}  # key -> release timestamp
+        self._mp_locks: Dict[str, mp.synchronize.Lock] = {}   # key -> mp.synchronize.Lock
+
+    # ---------------- public API ------------------
+    def __call__(self, keys: list[str], *, enable_logging: Optional[bool] = None):
+        """
+        Ergonomic helper so you can write:
+
+            async with keyed_locks(["alpha"]):
+                ...
+        """
+        if enable_logging is None:
+            enable_logging = self._default_enable_logging
+        return _KeyedLockContext(self, factory_name=self._factory_name, keys=keys, enable_logging=enable_logging)
+
+    def _get_or_create_async_lock(self, key: str) -> asyncio.Lock:
+        async_lock = self._async_lock.get(key)
+        count = self._async_lock_count.get(key, 0)
+        if async_lock is None:
+            async_lock = asyncio.Lock()
+            self._async_lock[key] = async_lock
+        elif count == 0 and key in self._async_lock_cleanup_data:
+            # A re-acquired key is no longer a cleanup candidate
+            self._async_lock_cleanup_data.pop(key)
+        count += 1
+        self._async_lock_count[key] = count
+        return async_lock
+
+    def _release_async_lock(self, key: str):
+        count = self._async_lock_count.get(key, 0)
+        count -= 1
+        if count == 0:
+            self._async_lock_cleanup_data[key] = time.time()
+        self._async_lock_count[key] = count
+
+        # Evict per-key gates that have stayed unreferenced beyond the cleanup window
+        for stale_key, released_at in list(self._async_lock_cleanup_data.items()):
+            if time.time() - released_at > CLEANUP_KEYED_LOCKS_AFTER_SECONDS:
+                self._async_lock.pop(stale_key)
+                self._async_lock_count.pop(stale_key)
+                self._async_lock_cleanup_data.pop(stale_key)
+
+    def _get_lock_for_key(self, key: str, enable_logging: bool = False) -> UnifiedLock:
+        # 1. get (or create) the per-process async gate for this key
+        #    (synchronous, so no extra lock is needed)
+        async_lock = self._get_or_create_async_lock(key)
+
+        # 2. fetch the shared raw lock (None in single-process mode)
+        raw_lock = _get_or_create_shared_raw_mp_lock(self._factory_name, key)
+        is_multiprocess = raw_lock is not None
+        if not is_multiprocess:
+            raw_lock = async_lock
+
+        # 3. build a *fresh* UnifiedLock with the chosen logging flag
+        if is_multiprocess:
+            return UnifiedLock(
+                lock=raw_lock,
+                is_async=False,  # manager.Lock is synchronous
+                name=f"key:{self._factory_name}:{key}",
+                enable_logging=enable_logging,
+                async_lock=async_lock,  # prevents event-loop blocking
+            )
+        else:
+            return UnifiedLock(
+                lock=raw_lock,
+                is_async=True,
+                name=f"key:{self._factory_name}:{key}",
+                enable_logging=enable_logging,
+                async_lock=None,  # no async gate needed in single-process mode
+            )

+    def _release_lock_for_key(self, key: str):
+        self._release_async_lock(key)
+        _release_shared_raw_mp_lock(self._factory_name, key)
+
+
+class _KeyedLockContext:
+    def __init__(
+        self,
+        parent: KeyedUnifiedLock,
+        factory_name: str,
+        keys: list[str],
+        enable_logging: bool,
+    ) -> None:
+        self._parent = parent
+        self._factory_name = factory_name
+        # The sorting is critical: a consistent acquire/release order across
+        # coroutines and processes is what avoids deadlocks
+        self._keys = sorted(keys)
+        self._enable_logging = (
+            enable_logging if enable_logging is not None
+            else parent._default_enable_logging
+        )
+        self._ul: Optional[List["UnifiedLock"]] = None  # set in __aenter__
+
+    # ----- enter -----
+    async def __aenter__(self):
+        if self._ul is not None:
+            raise RuntimeError("KeyedUnifiedLock already acquired in current context")
+
+        # acquire one UnifiedLock per key, in sorted order
+        self._ul = []
+        for key in self._keys:
+            lock = self._parent._get_lock_for_key(key, enable_logging=self._enable_logging)
+            await lock.__aenter__()
+            inc_debug_n_locks_acquired()
+            self._ul.append(lock)
+        return self
+
+    # ----- exit -----
+    async def __aexit__(self, exc_type, exc, tb):
+        # Release in reverse acquisition order
+        for ul, key in zip(reversed(self._ul), reversed(self._keys)):
+            await ul.__aexit__(exc_type, exc, tb)
+            self._parent._release_lock_for_key(key)
+            dec_debug_n_locks_acquired()
+        self._ul = None
+
+
 def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
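The registry above is reference-counted with deferred cleanup. A standalone sketch (not this commit's code) of the lifecycle rule it implements — the count rises on acquire, falls on release, and a key whose count hits zero is timestamped and only evicted once it has stayed idle past the window:

import time

CLEANUP_AFTER = 300  # seconds, mirroring CLEANUP_KEYED_LOCKS_AFTER_SECONDS
registry, counts, idle_since = {}, {}, {}

def acquire(key: str) -> object:
    lock = registry.setdefault(key, object())  # stand-in for manager.Lock()
    counts[key] = counts.get(key, 0) + 1
    idle_since.pop(key, None)                  # active again: not a cleanup candidate
    return lock

def release(key: str) -> None:
    counts[key] -= 1
    if counts[key] == 0:
        idle_since[key] = time.time()
    for k, t in list(idle_since.items()):      # opportunistic sweep on every release
        if time.time() - t > CLEANUP_AFTER:
            registry.pop(k); counts.pop(k); idle_since.pop(k)

`_KeyedLockContext` sorts its keys before acquiring. A minimal standalone sketch (also not part of the commit) of why one global acquisition order prevents deadlock when callers hold several keys at once:

import asyncio

_key_locks = {k: asyncio.Lock() for k in ("a", "b")}

async def with_keys(keys: list[str]):
    for key in sorted(keys):       # canonical order: never a-then-b in one task and b-then-a in another
        await _key_locks[key].acquire()
    try:
        await asyncio.sleep(0.01)  # critical section touching all keys
    finally:
        for key in reversed(sorted(keys)):
            _key_locks[key].release()

async def main():
    # Without sorting, these two callers could deadlock; with it, they serialize.
    await asyncio.gather(with_keys(["a", "b"]), with_keys(["b", "a"]))
    print("done")

asyncio.run(main())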
@@ -258,6 +489,14 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
         async_lock=async_lock,
     )

+
+def get_graph_db_lock_keyed(keys: str | list[str], enable_logging: bool = False) -> "_KeyedLockContext":
+    """return keyed unified graph database lock for ensuring atomic operations on the given keys"""
+    if _graph_db_lock_keyed is None:
+        raise RuntimeError("Shared-Data is not initialized")
+    if isinstance(keys, str):
+        keys = [keys]
+    return _graph_db_lock_keyed(keys, enable_logging=enable_logging)
+
+
 def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified data initialization lock for ensuring atomic data initialization"""
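A hedged usage sketch of the new public helper (the `lightrag` package name is an assumption; the diff only shows the relative import `.kg.shared_storage`):

from lightrag.kg.shared_storage import (  # package name assumed
    initialize_share_data,
    get_graph_db_lock_keyed,
)

initialize_share_data(workers=1)  # must run once per process before keyed locks exist

async def update_two_entities():
    # Only "alpha" and "beta" are serialized; work on other keys is not blocked.
    async with get_graph_db_lock_keyed(["alpha", "beta"], enable_logging=False):
        ...  # atomic read-modify-write of the two graph nodes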
@@ -294,6 +533,10 @@ def initialize_share_data(workers: int = 1):
         _workers, \
         _is_multiprocess, \
         _storage_lock, \
+        _lock_registry, \
+        _lock_registry_count, \
+        _lock_cleanup_data, \
+        _registry_guard, \
         _internal_lock, \
         _pipeline_status_lock, \
         _graph_db_lock, \

@@ -302,7 +545,8 @@ def initialize_share_data(workers: int = 1):
         _init_flags, \
         _initialized, \
         _update_flags, \
-        _async_locks
+        _async_locks, \
+        _graph_db_lock_keyed

     # Check if already initialized
     if _initialized:

@@ -316,6 +560,10 @@ def initialize_share_data(workers: int = 1):
     if workers > 1:
         _is_multiprocess = True
         _manager = Manager()
+        _lock_registry = _manager.dict()
+        _lock_registry_count = _manager.dict()
+        _lock_cleanup_data = _manager.dict()
+        _registry_guard = _manager.RLock()
         _internal_lock = _manager.Lock()
         _storage_lock = _manager.Lock()
         _pipeline_status_lock = _manager.Lock()

@@ -324,6 +572,10 @@ def initialize_share_data(workers: int = 1):
         _shared_dicts = _manager.dict()
         _init_flags = _manager.dict()
         _update_flags = _manager.dict()

+        _graph_db_lock_keyed = KeyedUnifiedLock(
+            factory_name="graph_db_lock",
+        )
+
         # Initialize async locks for multiprocess mode
         _async_locks = {

@@ -348,6 +600,10 @@ def initialize_share_data(workers: int = 1):
         _init_flags = {}
         _update_flags = {}
         _async_locks = None  # No need for async locks in single process mode

+        _graph_db_lock_keyed = KeyedUnifiedLock(
+            factory_name="graph_db_lock",
+        )
+
         direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")

     # Mark as initialized
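A minimal, self-contained sketch (not the diff's code) of the Manager mechanism `initialize_share_data` relies on: proxies stored in a managed dict refer to one underlying lock, no matter which worker process acquires them:

import multiprocessing as mp
import time

def worker(registry, counter):
    # Every process receives proxy objects, so registry["alpha"] refers to
    # the *same* underlying lock regardless of which process acquires it.
    with registry["alpha"]:
        counter["n"] = counter["n"] + 1  # serialized by the shared keyed lock
        time.sleep(0.01)

if __name__ == "__main__":
    with mp.Manager() as manager:
        registry = manager.dict()
        guard = manager.RLock()          # mirrors _registry_guard
        counter = manager.dict(n=0)
        with guard:                      # registry mutations happen under the guard
            registry["alpha"] = manager.Lock()  # one lock per key, created once
        procs = [mp.Process(target=worker, args=(registry, counter)) for _ in range(4)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        print(counter["n"])  # 4 — all increments were serialized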
[file 2/3: module defining class LightRAG]

@@ -1046,83 +1046,85 @@ class LightRAG:
                 }
             )

-            # Semaphore released, concurrency controlled by graph_db_lock in merge_nodes_and_edges instead
+            # Semaphore is NOT released here, because the merge_nodes_and_edges function is highly concurrent
+            # and, more importantly, it is the bottleneck that needs to be resource-controlled in massively
+            # parallel insertions

             if file_extraction_stage_ok:
                 try:
                     # Get chunk_results from entity_relation_task
                     chunk_results = await entity_relation_task
                     await merge_nodes_and_edges(
                         chunk_results=chunk_results,  # results collected from entity_relation_task
                         knowledge_graph_inst=self.chunk_entity_relation_graph,
                         entity_vdb=self.entities_vdb,
                         relationships_vdb=self.relationships_vdb,
                         global_config=asdict(self),
                         pipeline_status=pipeline_status,
                         pipeline_status_lock=pipeline_status_lock,
                         llm_response_cache=self.llm_response_cache,
                         current_file_number=current_file_number,
                         total_files=total_files,
                         file_path=file_path,
                     )

                     await self.doc_status.upsert(
                         {
                             doc_id: {
                                 "status": DocStatus.PROCESSED,
                                 "chunks_count": len(chunks),
                                 "content": status_doc.content,
                                 "content_summary": status_doc.content_summary,
                                 "content_length": status_doc.content_length,
                                 "created_at": status_doc.created_at,
                                 "updated_at": datetime.now(
                                     timezone.utc
                                 ).isoformat(),
                                 "file_path": file_path,
                             }
                         }
                     )

                     # Call _insert_done after processing each file
                     await self._insert_done()

                     async with pipeline_status_lock:
                         log_message = f"Completed processing file {current_file_number}/{total_files}: {file_path}"
                         logger.info(log_message)
                         pipeline_status["latest_message"] = log_message
                         pipeline_status["history_messages"].append(log_message)

                 except Exception as e:
                     # Log error and update pipeline status
                     logger.error(traceback.format_exc())
                     error_msg = f"Merging stage failed in document {current_file_number}/{total_files}: {file_path}"
                     logger.error(error_msg)
                     async with pipeline_status_lock:
                         pipeline_status["latest_message"] = error_msg
                         pipeline_status["history_messages"].append(
                             traceback.format_exc()
                         )
                         pipeline_status["history_messages"].append(error_msg)

                     # Persist llm cache
                     if self.llm_response_cache:
                         await self.llm_response_cache.index_done_callback()

                     # Update document status to failed
                     await self.doc_status.upsert(
                         {
                             doc_id: {
                                 "status": DocStatus.FAILED,
                                 "error": str(e),
                                 "content": status_doc.content,
                                 "content_summary": status_doc.content_summary,
                                 "content_length": status_doc.content_length,
                                 "created_at": status_doc.created_at,
                                 "updated_at": datetime.now().isoformat(),
                                 "file_path": file_path,
                             }
                         }
                     )

         # Create processing tasks for all documents
         doc_tasks = []
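A small sketch (not LightRAG code) of the throttling pattern the new comment describes: the semaphore is deliberately held across the merge stage, so the number of documents inside the bottleneck stays bounded:

import asyncio

async def extract(doc_id: str):
    await asyncio.sleep(0.01)  # stand-in for chunking + entity extraction

async def merge(doc_id: str):
    await asyncio.sleep(0.02)  # stand-in for the merge bottleneck

async def process_document(doc_id: str, sem: asyncio.Semaphore):
    async with sem:          # acquired before extraction...
        await extract(doc_id)
        await merge(doc_id)  # ...and deliberately still held through the merge
    # released only after the bottleneck stage completes

async def main():
    sem = asyncio.Semaphore(4)  # assumed cap on documents in flight
    await asyncio.gather(*(process_document(f"doc{i}", sem) for i in range(16)))

asyncio.run(main())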
[file 3/3: module defining merge_nodes_and_edges / _merge_edges_then_upsert]

@@ -422,28 +422,32 @@ async def _merge_edges_then_upsert(
     )

     for need_insert_id in [src_id, tgt_id]:
-        if not (await knowledge_graph_inst.has_node(need_insert_id)):
-            # # Discard this edge if the node does not exist
-            # if need_insert_id == src_id:
-            #     logger.warning(
-            #         f"Discard edge: {src_id} - {tgt_id} | Source node missing"
-            #     )
-            # else:
-            #     logger.warning(
-            #         f"Discard edge: {src_id} - {tgt_id} | Target node missing"
-            #     )
-            # return None
-            await knowledge_graph_inst.upsert_node(
-                need_insert_id,
-                node_data={
-                    "entity_id": need_insert_id,
-                    "source_id": source_id,
-                    "description": description,
-                    "entity_type": "UNKNOWN",
-                    "file_path": file_path,
-                    "created_at": int(time.time()),
-                },
-            )
+        if await knowledge_graph_inst.has_node(need_insert_id):
+            # This is so that the initial check for the existence of the node need not be locked
+            continue
+        async with get_graph_db_lock_keyed([need_insert_id], enable_logging=False):
+            if not (await knowledge_graph_inst.has_node(need_insert_id)):
+                # # Discard this edge if the node does not exist
+                # if need_insert_id == src_id:
+                #     logger.warning(
+                #         f"Discard edge: {src_id} - {tgt_id} | Source node missing"
+                #     )
+                # else:
+                #     logger.warning(
+                #         f"Discard edge: {src_id} - {tgt_id} | Target node missing"
+                #     )
+                # return None
+                await knowledge_graph_inst.upsert_node(
+                    need_insert_id,
+                    node_data={
+                        "entity_id": need_insert_id,
+                        "source_id": source_id,
+                        "description": description,
+                        "entity_type": "UNKNOWN",
+                        "file_path": file_path,
+                        "created_at": int(time.time()),
+                    },
+                )

     force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"]
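The rewritten loop is a double-checked locking pattern: check cheaply without the lock, and only if the node seems missing, take the keyed lock and check again before creating it. A standalone sketch (not the commit's code) of the same idea:

import asyncio

_nodes: set[str] = set()
_key_locks: dict[str, asyncio.Lock] = {}

async def ensure_node(node_id: str):
    if node_id in _nodes:                                        # fast path: no lock needed
        return
    lock = _key_locks.setdefault(node_id, asyncio.Lock())
    async with lock:                                             # slow path: serialize per key
        if node_id not in _nodes:                                # re-check: another task may have won the race
            _nodes.add(node_id)                                  # stand-in for upsert_node(...)

async def main():
    await asyncio.gather(*(ensure_node("n1") for _ in range(10)))
    print(len(_nodes))  # 1 — the node was created exactly once

asyncio.run(main())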
@@ -527,7 +531,8 @@ async def merge_nodes_and_edges(
         llm_response_cache: LLM response cache
     """
     # Get lock manager from shared storage
-    from .kg.shared_storage import get_graph_db_lock
+    from .kg.shared_storage import get_graph_db_lock_keyed
+

     # Collect all nodes and edges from all chunks
     all_nodes = defaultdict(list)
@@ -544,23 +549,28 @@ async def merge_nodes_and_edges(
         all_edges[sorted_edge_key].extend(edges)

     # Centralized processing of all nodes and edges
-    entities_data = []
-    relationships_data = []
+    total_entities_count = len(all_nodes)
+    total_relations_count = len(all_edges)

     # Merge nodes and edges
-    # Use graph database lock to ensure atomic merges and updates
-    graph_db_lock = get_graph_db_lock(enable_logging=False)
-    async with graph_db_lock:
-        async with pipeline_status_lock:
-            log_message = (
-                f"Merging stage {current_file_number}/{total_files}: {file_path}"
-            )
-            logger.info(log_message)
-            pipeline_status["latest_message"] = log_message
-            pipeline_status["history_messages"].append(log_message)
+    async with pipeline_status_lock:
+        log_message = (
+            f"Merging stage {current_file_number}/{total_files}: {file_path}"
+        )
+        logger.info(log_message)
+        pipeline_status["latest_message"] = log_message
+        pipeline_status["history_messages"].append(log_message)

-        # Process and update all entities at once
-        for entity_name, entities in all_nodes.items():
+    # Process and update all entities at once
+    log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
+    logger.info(log_message)
+    if pipeline_status is not None:
+        async with pipeline_status_lock:
+            pipeline_status["latest_message"] = log_message
+            pipeline_status["history_messages"].append(log_message)
+
+    async def _locked_process_entity_name(entity_name, entities):
+        async with get_graph_db_lock_keyed([entity_name], enable_logging=False):
             entity_data = await _merge_nodes_then_upsert(
                 entity_name,
                 entities,
@@ -570,10 +580,34 @@ async def merge_nodes_and_edges(
                 pipeline_status_lock,
                 llm_response_cache,
             )
-            entities_data.append(entity_data)
+            if entity_vdb is not None:
+                data_for_vdb = {
+                    compute_mdhash_id(entity_data["entity_name"], prefix="ent-"): {
+                        "entity_name": entity_data["entity_name"],
+                        "entity_type": entity_data["entity_type"],
+                        "content": f"{entity_data['entity_name']}\n{entity_data['description']}",
+                        "source_id": entity_data["source_id"],
+                        "file_path": entity_data.get("file_path", "unknown_source"),
+                    }
+                }
+                await entity_vdb.upsert(data_for_vdb)
+            return entity_data
+
+    tasks = []
+    for entity_name, entities in all_nodes.items():
+        tasks.append(asyncio.create_task(_locked_process_entity_name(entity_name, entities)))
+    await asyncio.gather(*tasks)

-        # Process and update all relationships at once
-        for edge_key, edges in all_edges.items():
+    # Process and update all relationships at once
+    log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
+    logger.info(log_message)
+    if pipeline_status is not None:
+        async with pipeline_status_lock:
+            pipeline_status["latest_message"] = log_message
+            pipeline_status["history_messages"].append(log_message)
+
+    async def _locked_process_edges(edge_key, edges):
+        async with get_graph_db_lock_keyed(f"{edge_key[0]}-{edge_key[1]}", enable_logging=False):
             edge_data = await _merge_edges_then_upsert(
                 edge_key[0],
                 edge_key[1],
@@ -584,55 +618,27 @@ async def merge_nodes_and_edges(
                 pipeline_status_lock,
                 llm_response_cache,
             )
-            if edge_data is not None:
-                relationships_data.append(edge_data)
-
-        # Update total counts
-        total_entities_count = len(entities_data)
-        total_relations_count = len(relationships_data)
-
-        log_message = f"Updating {total_entities_count} entities {current_file_number}/{total_files}: {file_path}"
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-        # Update vector databases with all collected data
-        if entity_vdb is not None and entities_data:
-            data_for_vdb = {
-                compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                    "entity_name": dp["entity_name"],
-                    "entity_type": dp["entity_type"],
-                    "content": f"{dp['entity_name']}\n{dp['description']}",
-                    "source_id": dp["source_id"],
-                    "file_path": dp.get("file_path", "unknown_source"),
-                }
-                for dp in entities_data
-            }
-            await entity_vdb.upsert(data_for_vdb)
-
-        log_message = f"Updating {total_relations_count} relations {current_file_number}/{total_files}: {file_path}"
-        logger.info(log_message)
-        if pipeline_status is not None:
-            async with pipeline_status_lock:
-                pipeline_status["latest_message"] = log_message
-                pipeline_status["history_messages"].append(log_message)
-
-        if relationships_vdb is not None and relationships_data:
-            data_for_vdb = {
-                compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
-                    "src_id": dp["src_id"],
-                    "tgt_id": dp["tgt_id"],
-                    "keywords": dp["keywords"],
-                    "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
-                    "source_id": dp["source_id"],
-                    "file_path": dp.get("file_path", "unknown_source"),
-                }
-                for dp in relationships_data
-            }
-            await relationships_vdb.upsert(data_for_vdb)
+            if edge_data is None:
+                return None
+
+            if relationships_vdb is not None:
+                data_for_vdb = {
+                    compute_mdhash_id(edge_data["src_id"] + edge_data["tgt_id"], prefix="rel-"): {
+                        "src_id": edge_data["src_id"],
+                        "tgt_id": edge_data["tgt_id"],
+                        "keywords": edge_data["keywords"],
+                        "content": f"{edge_data['src_id']}\t{edge_data['tgt_id']}\n{edge_data['keywords']}\n{edge_data['description']}",
+                        "source_id": edge_data["source_id"],
+                        "file_path": edge_data.get("file_path", "unknown_source"),
+                    }
+                }
+                await relationships_vdb.upsert(data_for_vdb)
+            return edge_data
+
+    tasks = []
+    for edge_key, edges in all_edges.items():
+        tasks.append(asyncio.create_task(_locked_process_edges(edge_key, edges)))
+    await asyncio.gather(*tasks)


 async def extract_entities(
     chunks: dict[str, TextChunkSchema],
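A standalone sketch (not the commit's code) of the fan-out pattern `merge_nodes_and_edges` now uses: one task per key, all launched together, with a per-key lock serializing same-key work while distinct keys proceed in parallel:

import asyncio

_key_locks: dict[str, asyncio.Lock] = {}
results: dict[str, int] = {}

async def merge_one(key: str, value: int):
    lock = _key_locks.setdefault(key, asyncio.Lock())
    async with lock:  # same-key merges are serialized; distinct keys run concurrently
        results[key] = results.get(key, 0) + value

async def main():
    items = [("alpha", 1), ("beta", 2), ("alpha", 3)]
    tasks = [asyncio.create_task(merge_one(k, v)) for k, v in items]
    await asyncio.gather(*tasks)
    print(results)  # {'alpha': 4, 'beta': 2}

asyncio.run(main())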