From 925e631a9a04f955643b06ee84be4231c2143733 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 29 Aug 2025 13:50:35 +0800
Subject: [PATCH] refac: Add robust timeout handling for LLM requests

---
 env.example                     |   6 +-
 lightrag/api/lightrag_server.py |  11 +-
 lightrag/constants.py           |   8 +-
 lightrag/lightrag.py            |  19 +-
 lightrag/utils.py               | 399 +++++++++++++++++++++++---------
 5 files changed, 331 insertions(+), 112 deletions(-)

diff --git a/env.example b/env.example
index c5cb4d28..35a1e5fa 100644
--- a/env.example
+++ b/env.example
@@ -156,8 +156,8 @@ MAX_PARALLEL_INSERT=2
 ### LLM Configuration
 ### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
 ###########################################################
-### LLM request timeout setting for all llm (set to TIMEOUT if not specified, 0 means no timeout for Ollma)
-# LLM_TIMEOUT=150
+### LLM request timeout setting for all LLMs (0 means no timeout for Ollama)
+# LLM_TIMEOUT=180

 LLM_BINDING=openai
 LLM_MODEL=gpt-4o
@@ -206,7 +206,7 @@ OLLAMA_LLM_NUM_CTX=32768
 ### Embedding Configuration (Should not be changed after the first file processed)
 ### EMBEDDING_BINDING: ollama, openai, azure_openai, jina, lollms, aws_bedrock
 ####################################################################################
-### see also env.ollama-binding-options.example for fine tuning ollama
+# EMBEDDING_TIMEOUT=30
 EMBEDDING_BINDING=ollama
 EMBEDDING_MODEL=bge-m3:latest
 EMBEDDING_DIM=1024
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index a2a4d848..41ede089 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -39,6 +39,8 @@ from lightrag.constants import (
     DEFAULT_LOG_MAX_BYTES,
     DEFAULT_LOG_BACKUP_COUNT,
     DEFAULT_LOG_FILENAME,
+    DEFAULT_LLM_TIMEOUT,
+    DEFAULT_EMBEDDING_TIMEOUT,
 )
 from lightrag.api.routers.document_routes import (
     DocumentManager,
@@ -256,7 +258,10 @@ def create_app(args):
     if args.embedding_binding == "jina":
         from lightrag.llm.jina import jina_embed

-    llm_timeout = get_env_value("LLM_TIMEOUT", args.timeout, int)
+    llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
+    embedding_timeout = get_env_value(
+        "EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT, int
+    )

     async def openai_alike_model_complete(
         prompt,
@@ -487,6 +492,8 @@ def create_app(args):
                 else {}
             ),
             embedding_func=embedding_func,
+            default_llm_timeout=llm_timeout,
+            default_embedding_timeout=embedding_timeout,
             kv_storage=args.kv_storage,
             graph_storage=args.graph_storage,
             vector_storage=args.vector_storage,
@@ -517,6 +524,8 @@ def create_app(args):
             summary_max_tokens=args.summary_max_tokens,
             summary_context_size=args.summary_context_size,
             embedding_func=embedding_func,
+            default_llm_timeout=llm_timeout,
+            default_embedding_timeout=embedding_timeout,
             kv_storage=args.kv_storage,
             graph_storage=args.graph_storage,
             vector_storage=args.vector_storage,
diff --git a/lightrag/constants.py b/lightrag/constants.py
index a7cf5640..d78d869c 100644
--- a/lightrag/constants.py
+++ b/lightrag/constants.py
@@ -64,8 +64,12 @@ DEFAULT_MAX_PARALLEL_INSERT = 2  # Default maximum parallel insert operations
 DEFAULT_EMBEDDING_FUNC_MAX_ASYNC = 8  # Default max async for embedding functions
 DEFAULT_EMBEDDING_BATCH_NUM = 10  # Default batch size for embedding computations

-# gunicorn worker timeout(as default LLM request timeout if LLM_TIMEOUT is not set)
-DEFAULT_TIMEOUT = 150
+# Gunicorn worker timeout
+DEFAULT_TIMEOUT = 210
+
+# Default LLM and embedding timeouts
+DEFAULT_LLM_TIMEOUT = 180
+DEFAULT_EMBEDDING_TIMEOUT = 30

 # Logging
configuration defaults DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 34ff87e6..23e6f575 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -41,6 +41,8 @@ from lightrag.constants import ( DEFAULT_MAX_GRAPH_NODES, DEFAULT_ENTITY_TYPES, DEFAULT_SUMMARY_LANGUAGE, + DEFAULT_LLM_TIMEOUT, + DEFAULT_EMBEDDING_TIMEOUT, ) from lightrag.utils import get_env_value @@ -277,6 +279,10 @@ class LightRAG: - use_llm_check: If True, validates cached embeddings using an LLM. """ + default_embedding_timeout: int = field( + default=int(os.getenv("EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT)) + ) + # LLM Configuration # --- @@ -311,6 +317,10 @@ class LightRAG: llm_model_kwargs: dict[str, Any] = field(default_factory=dict) """Additional keyword arguments passed to the LLM model function.""" + default_llm_timeout: int = field( + default=int(os.getenv("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT)) + ) + # Rerank Configuration # --- @@ -457,7 +467,8 @@ class LightRAG: # Init Embedding self.embedding_func = priority_limit_async_func_call( - self.embedding_func_max_async + self.embedding_func_max_async, + llm_timeout=self.default_embedding_timeout, )(self.embedding_func) # Initialize all storages @@ -550,7 +561,11 @@ class LightRAG: # Directly use llm_response_cache, don't create a new object hashing_kv = self.llm_response_cache - self.llm_model_func = priority_limit_async_func_call(self.llm_model_max_async)( + # Get timeout from LLM model kwargs for dynamic timeout calculation + self.llm_model_func = priority_limit_async_func_call( + self.llm_model_max_async, + llm_timeout=self.default_llm_timeout, + )( partial( self.llm_model_func, # type: ignore hashing_kv=hashing_kv, diff --git a/lightrag/utils.py b/lightrag/utils.py index cb03c537..cd67e9f3 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -254,6 +254,18 @@ class UnlimitedSemaphore: pass +@dataclass +class TaskState: + """Task state tracking for priority queue management""" + + future: asyncio.Future + start_time: float + execution_start_time: float = None + worker_started: bool = False + cancellation_requested: bool = False + cleanup_done: bool = False + + @dataclass class EmbeddingFunc: embedding_dim: int @@ -323,20 +335,58 @@ def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None: return None -# Custom exception class +# Custom exception classes class QueueFullError(Exception): """Raised when the queue is full and the wait times out""" pass -def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): +class WorkerTimeoutError(Exception): + """Worker-level timeout exception with specific timeout information""" + + def __init__(self, timeout_value: float, timeout_type: str = "execution"): + self.timeout_value = timeout_value + self.timeout_type = timeout_type + super().__init__(f"Worker {timeout_type} timeout after {timeout_value}s") + + +class HealthCheckTimeoutError(Exception): + """Health Check-level timeout exception""" + + def __init__(self, timeout_value: float, execution_duration: float): + self.timeout_value = timeout_value + self.execution_duration = execution_duration + super().__init__( + f"Task forcefully terminated due to execution timeout (>{timeout_value}s, actual: {execution_duration:.1f}s)" + ) + + +def priority_limit_async_func_call( + max_size: int, + llm_timeout: float = None, + max_execution_timeout: float = None, + max_task_duration: float = None, + max_queue_size: int = 1000, + cleanup_timeout: float = 2.0, +): """ - Enhanced 
priority-limited asynchronous function call decorator + Enhanced priority-limited asynchronous function call decorator with robust timeout handling + + This decorator provides a comprehensive solution for managing concurrent LLM requests with: + - Multi-layer timeout protection (LLM -> Worker -> Health Check -> User) + - Task state tracking to prevent race conditions + - Enhanced health check system with stuck task detection + - Proper resource cleanup and error recovery Args: max_size: Maximum number of concurrent calls max_queue_size: Maximum queue capacity to prevent memory overflow + llm_timeout: LLM provider timeout (from global config), used to calculate other timeouts + max_execution_timeout: Maximum time for worker to execute function (defaults to llm_timeout + 30s) + max_task_duration: Maximum time before health check intervenes (defaults to llm_timeout + 60s) + cleanup_timeout: Maximum time to wait for cleanup operations (defaults to 2.0s) + Returns: Decorator function """ @@ -345,81 +395,173 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): # Ensure func is callable if not callable(func): raise TypeError(f"Expected a callable object, got {type(func)}") + + # Calculate timeout hierarchy if llm_timeout is provided (Dynamic Timeout Calculation) + if llm_timeout is not None: + nonlocal max_execution_timeout, max_task_duration + if max_execution_timeout is None: + max_execution_timeout = ( + llm_timeout + 30 + ) # LLM timeout + 30s buffer for network delays + if max_task_duration is None: + max_task_duration = ( + llm_timeout + 60 + ) # LLM timeout + 1min buffer for execution phase + queue = asyncio.PriorityQueue(maxsize=max_queue_size) tasks = set() initialization_lock = asyncio.Lock() counter = 0 shutdown_event = asyncio.Event() - initialized = False # Global initialization flag + initialized = False worker_health_check_task = None - # Track active future objects for cleanup + # Enhanced task state management + task_states = {} # task_id -> TaskState + task_states_lock = asyncio.Lock() active_futures = weakref.WeakSet() - reinit_count = 0 # Reinitialization counter to track system health + reinit_count = 0 - # Worker function to process tasks in the queue async def worker(): - """Worker that processes tasks in the priority queue""" + """Enhanced worker that processes tasks with proper timeout and state management""" try: while not shutdown_event.is_set(): try: - # Use timeout to get tasks, allowing periodic checking of shutdown signal + # Get task from queue with timeout for shutdown checking try: ( priority, count, - future, + task_id, args, kwargs, ) = await asyncio.wait_for(queue.get(), timeout=1.0) except asyncio.TimeoutError: - # Timeout is just to check shutdown signal, continue to next iteration continue - # If future is cancelled, skip execution - if future.cancelled(): + # Get task state and mark worker as started + async with task_states_lock: + if task_id not in task_states: + queue.task_done() + continue + task_state = task_states[task_id] + task_state.worker_started = True + # Record execution start time when worker actually begins processing + task_state.execution_start_time = ( + asyncio.get_event_loop().time() + ) + + # Check if task was cancelled before worker started + if ( + task_state.cancellation_requested + or task_state.future.cancelled() + ): + async with task_states_lock: + task_states.pop(task_id, None) queue.task_done() continue try: - # Execute function - result = await func(*args, **kwargs) - # If future is not done, set 
the result - if not future.done(): - future.set_result(result) - except asyncio.CancelledError: - if not future.done(): - future.cancel() - logger.debug("limit_async: Task cancelled during execution") - except Exception as e: - logger.error( - f"limit_async: Error in decorated function: {str(e)}" + # Execute function with timeout protection + if max_execution_timeout is not None: + result = await asyncio.wait_for( + func(*args, **kwargs), timeout=max_execution_timeout + ) + else: + result = await func(*args, **kwargs) + + # Set result if future is still valid + if not task_state.future.done(): + task_state.future.set_result(result) + + except asyncio.TimeoutError: + # Worker-level timeout (max_execution_timeout exceeded) + logger.warning( + f"limit_async: Worker timeout for task {task_id} after {max_execution_timeout}s" ) - if not future.done(): - future.set_exception(e) + if not task_state.future.done(): + task_state.future.set_exception( + WorkerTimeoutError( + max_execution_timeout, "execution" + ) + ) + except asyncio.CancelledError: + # Task was cancelled during execution + if not task_state.future.done(): + task_state.future.cancel() + logger.debug( + f"limit_async: Task {task_id} cancelled during execution" + ) + except Exception as e: + # Function execution error + logger.error( + f"limit_async: Error in decorated function for task {task_id}: {str(e)}" + ) + if not task_state.future.done(): + task_state.future.set_exception(e) finally: + # Clean up task state + async with task_states_lock: + task_states.pop(task_id, None) queue.task_done() + except Exception as e: - # Catch all exceptions in worker loop to prevent worker termination + # Critical error in worker loop logger.error(f"limit_async: Critical error in worker: {str(e)}") - await asyncio.sleep(0.1) # Prevent high CPU usage + await asyncio.sleep(0.1) finally: logger.debug("limit_async: Worker exiting") - async def health_check(): - """Periodically check worker health status and recover""" + async def enhanced_health_check(): + """Enhanced health check with stuck task detection and recovery""" nonlocal initialized try: while not shutdown_event.is_set(): await asyncio.sleep(5) # Check every 5 seconds - # No longer acquire lock, directly operate on task set - # Use a copy of the task set to avoid concurrent modification + current_time = asyncio.get_event_loop().time() + + # Detect and handle stuck tasks based on execution start time + if max_task_duration is not None: + stuck_tasks = [] + async with task_states_lock: + for task_id, task_state in list(task_states.items()): + # Only check tasks that have started execution + if ( + task_state.worker_started + and task_state.execution_start_time is not None + and current_time - task_state.execution_start_time + > max_task_duration + ): + stuck_tasks.append( + ( + task_id, + current_time + - task_state.execution_start_time, + ) + ) + + # Force cleanup of stuck tasks + for task_id, execution_duration in stuck_tasks: + logger.warning( + f"limit_async: Detected stuck task {task_id} (execution time: {execution_duration:.1f}s), forcing cleanup" + ) + async with task_states_lock: + if task_id in task_states: + task_state = task_states[task_id] + if not task_state.future.done(): + task_state.future.set_exception( + HealthCheckTimeoutError( + max_task_duration, execution_duration + ) + ) + task_states.pop(task_id, None) + + # Worker recovery logic current_tasks = set(tasks) done_tasks = {t for t in current_tasks if t.done()} tasks.difference_update(done_tasks) - # Calculate active tasks 
count active_tasks_count = len(tasks) workers_needed = max_size - active_tasks_count @@ -432,21 +574,16 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): task = asyncio.create_task(worker()) new_tasks.add(task) task.add_done_callback(tasks.discard) - # Update task set in one operation tasks.update(new_tasks) + except Exception as e: - logger.error(f"limit_async: Error in health check: {str(e)}") + logger.error(f"limit_async: Error in enhanced health check: {str(e)}") finally: - logger.debug("limit_async: Health check task exiting") + logger.debug("limit_async: Enhanced health check task exiting") initialized = False async def ensure_workers(): - """Ensure worker threads and health check system are available - - This function checks if the worker system is already initialized. - If not, it performs a one-time initialization of all worker threads - and starts the health check system. - """ + """Ensure worker system is initialized with enhanced error handling""" nonlocal initialized, worker_health_check_task, tasks, reinit_count if initialized: @@ -456,45 +593,56 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): if initialized: return - # Increment reinitialization counter if this is not the first initialization if reinit_count > 0: reinit_count += 1 logger.warning( - f"limit_async: Reinitializing needed (count: {reinit_count})" + f"limit_async: Reinitializing system (count: {reinit_count})" ) else: - reinit_count = 1 # First initialization + reinit_count = 1 - # Check for completed tasks and remove them from the task set + # Clean up completed tasks current_tasks = set(tasks) done_tasks = {t for t in current_tasks if t.done()} tasks.difference_update(done_tasks) - # Log active tasks count during reinitialization active_tasks_count = len(tasks) if active_tasks_count > 0 and reinit_count > 1: logger.warning( f"limit_async: {active_tasks_count} tasks still running during reinitialization" ) - # Create initial worker tasks, only adding the number needed + # Create worker tasks workers_needed = max_size - active_tasks_count for _ in range(workers_needed): task = asyncio.create_task(worker()) tasks.add(task) task.add_done_callback(tasks.discard) - # Start health check - worker_health_check_task = asyncio.create_task(health_check()) + # Start enhanced health check + worker_health_check_task = asyncio.create_task(enhanced_health_check()) initialized = True - logger.info(f"limit_async: {workers_needed} new workers initialized") + # Log dynamic timeout configuration + timeout_info = [] + if llm_timeout is not None: + timeout_info.append(f"LLM: {llm_timeout}s") + if max_execution_timeout is not None: + timeout_info.append(f"Execution: {max_execution_timeout}s") + if max_task_duration is not None: + timeout_info.append(f"Health Check: {max_task_duration}s") + + timeout_str = ( + f" (Timeouts: {', '.join(timeout_info)})" if timeout_info else "" + ) + logger.info( + f"limit_async: {workers_needed} new workers initialized with dynamic timeout handling{timeout_str}" + ) async def shutdown(): - """Gracefully shut down all workers and the queue""" + """Gracefully shut down all workers and cleanup resources""" logger.info("limit_async: Shutting down priority queue workers") - # Set the shutdown event shutdown_event.set() # Cancel all active futures @@ -502,7 +650,14 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): if not future.done(): future.cancel() - # Wait for the queue to empty + # Cancel all pending tasks + 
async with task_states_lock: + for task_id, task_state in list(task_states.items()): + if not task_state.future.done(): + task_state.future.cancel() + task_states.clear() + + # Wait for queue to empty with timeout try: await asyncio.wait_for(queue.join(), timeout=5.0) except asyncio.TimeoutError: @@ -510,7 +665,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): "limit_async: Timeout waiting for queue to empty during shutdown" ) - # Cancel all worker tasks + # Cancel worker tasks for task in list(tasks): if not task.done(): task.cancel() @@ -519,7 +674,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): if tasks: await asyncio.gather(*tasks, return_exceptions=True) - # Cancel the health check task + # Cancel health check task if worker_health_check_task and not worker_health_check_task.done(): worker_health_check_task.cancel() try: @@ -534,77 +689,113 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): *args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs ): """ - Execute the function with priority-based concurrency control + Execute function with enhanced priority-based concurrency control and timeout handling + Args: *args: Positional arguments passed to the function _priority: Call priority (lower values have higher priority) - _timeout: Maximum time to wait for function completion (in seconds) + _timeout: Maximum time to wait for completion (in seconds, none means determinded by max_execution_timeout of the queue) _queue_timeout: Maximum time to wait for entering the queue (in seconds) **kwargs: Keyword arguments passed to the function + Returns: The result of the function call + Raises: - TimeoutError: If the function call times out + TimeoutError: If the function call times out at any level QueueFullError: If the queue is full and waiting times out Any exception raised by the decorated function """ - # Ensure worker system is initialized await ensure_workers() - # Create a future for the result + # Generate unique task ID + task_id = f"{id(asyncio.current_task())}_{asyncio.get_event_loop().time()}" future = asyncio.Future() - active_futures.add(future) - nonlocal counter - async with initialization_lock: - current_count = counter # Use local variable to avoid race conditions - counter += 1 + # Create task state + task_state = TaskState( + future=future, start_time=asyncio.get_event_loop().time() + ) - # Try to put the task into the queue, supporting timeout try: - if _queue_timeout is not None: - # Use timeout to wait for queue space - try: + # Register task state + async with task_states_lock: + task_states[task_id] = task_state + + active_futures.add(future) + + # Get counter for FIFO ordering + nonlocal counter + async with initialization_lock: + current_count = counter + counter += 1 + + # Queue the task with timeout handling + try: + if _queue_timeout is not None: await asyncio.wait_for( - # current_count is used to ensure FIFO order - queue.put((_priority, current_count, future, args, kwargs)), + queue.put( + (_priority, current_count, task_id, args, kwargs) + ), timeout=_queue_timeout, ) - except asyncio.TimeoutError: - raise QueueFullError( - f"Queue full, timeout after {_queue_timeout} seconds" + else: + await queue.put( + (_priority, current_count, task_id, args, kwargs) ) - else: - # No timeout, may wait indefinitely - # current_count is used to ensure FIFO order - await queue.put((_priority, current_count, future, args, kwargs)) - except Exception as e: - # Clean up 
the future - if not future.done(): - future.set_exception(e) - active_futures.discard(future) - raise + except asyncio.TimeoutError: + raise QueueFullError( + f"Queue full, timeout after {_queue_timeout} seconds" + ) + except Exception as e: + # Clean up on queue error + if not future.done(): + future.set_exception(e) + raise - try: - # Wait for the result, optional timeout - if _timeout is not None: - try: + # Wait for result with timeout handling + try: + if _timeout is not None: return await asyncio.wait_for(future, _timeout) - except asyncio.TimeoutError: - # Cancel the future - if not future.done(): - future.cancel() - raise TimeoutError( - f"limit_async: Task timed out after {_timeout} seconds" - ) - else: - # Wait for the result without timeout - return await future - finally: - # Clean up the future reference - active_futures.discard(future) + else: + return await future + except asyncio.TimeoutError: + # This is user-level timeout (asyncio.wait_for caused) + # Mark cancellation request + async with task_states_lock: + if task_id in task_states: + task_states[task_id].cancellation_requested = True - # Add the shutdown method to the decorated function + # Cancel future + if not future.done(): + future.cancel() + + # Wait for worker cleanup with timeout + cleanup_start = asyncio.get_event_loop().time() + while ( + task_id in task_states + and asyncio.get_event_loop().time() - cleanup_start + < cleanup_timeout + ): + await asyncio.sleep(0.1) + + raise TimeoutError( + f"limit_async: User timeout after {_timeout} seconds" + ) + except WorkerTimeoutError as e: + # This is Worker-level timeout, directly propagate exception information + raise TimeoutError(f"limit_async: {str(e)}") + except HealthCheckTimeoutError as e: + # This is Health Check-level timeout, directly propagate exception information + raise TimeoutError(f"limit_async: {str(e)}") + + finally: + # Ensure cleanup + active_futures.discard(future) + async with task_states_lock: + task_states.pop(task_id, None) + + # Add shutdown method to decorated function wait_func.shutdown = shutdown return wait_func
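
For reference, the layered timeouts that priority_limit_async_func_call derives when only llm_timeout is supplied can be summarized in a short sketch. The helper name derive_timeouts below is illustrative (the real calculation lives inline in lightrag/utils.py), but the arithmetic mirrors this patch: worker execution gets a 30s buffer on top of the LLM timeout, and the health check gets a 60s buffer on top of that.

    from typing import Optional

    def derive_timeouts(
        llm_timeout: float,
        max_execution_timeout: Optional[float] = None,
        max_task_duration: Optional[float] = None,
    ) -> dict:
        """Mirror of the dynamic timeout calculation in priority_limit_async_func_call."""
        if max_execution_timeout is None:
            # Worker-level asyncio.wait_for around the decorated function
            max_execution_timeout = llm_timeout + 30
        if max_task_duration is None:
            # Health-check limit for stuck tasks (checked every 5 seconds)
            max_task_duration = llm_timeout + 60
        return {
            "llm": llm_timeout,
            "worker_execution": max_execution_timeout,
            "health_check": max_task_duration,
        }

    # With the new default LLM_TIMEOUT=180:
    # {'llm': 180, 'worker_execution': 210, 'health_check': 240}
    print(derive_timeouts(180))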
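
A minimal usage sketch of the decorator as changed by this patch, assuming the patch is applied and lightrag is importable. fake_llm and its delay argument are stand-ins for a real LLM binding; _priority, _timeout, _queue_timeout, and the attached shutdown() method are the knobs actually exposed by wait_func in lightrag/utils.py.

    import asyncio

    from lightrag.utils import priority_limit_async_func_call

    @priority_limit_async_func_call(max_size=4, llm_timeout=180)
    async def fake_llm(prompt: str, delay: float = 0.5) -> str:
        # Stand-in for a real LLM binding; `delay` simulates model latency.
        await asyncio.sleep(delay)
        return f"answer to: {prompt}"

    async def main() -> None:
        # Normal call: queued, picked up by a worker, result returned via the future.
        print(await fake_llm("hello"))

        # Caller-side timeout (_timeout) is enforced independently of the
        # worker-level max_execution_timeout (llm_timeout + 30s) and the
        # health-check limit (llm_timeout + 60s); here it fires first and
        # surfaces as a plain TimeoutError.
        try:
            await fake_llm("slow question", delay=3, _priority=0, _timeout=1)
        except TimeoutError as exc:
            print(f"caller-side timeout: {exc}")

        # Stop workers and the health-check task when done.
        await fake_llm.shutdown()

    asyncio.run(main())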
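
Because worker-level (WorkerTimeoutError) and health-check-level (HealthCheckTimeoutError) timeouts are re-raised to the caller as plain TimeoutError, and queue saturation surfaces as QueueFullError, caller-side handling can stay simple. A hedged sketch of one such pattern (guarded_call is a hypothetical helper, not part of the patch):

    from typing import Optional

    from lightrag.utils import QueueFullError

    async def guarded_call(fn, prompt: str) -> Optional[str]:
        """Call a decorated LLM function with bounded queueing and unified timeout handling."""
        try:
            # Wait at most 10 seconds for a free slot in the priority queue.
            return await fn(prompt, _queue_timeout=10)
        except QueueFullError:
            # Too many requests in flight; the caller can back off and retry.
            return None
        except TimeoutError as exc:
            # The message indicates which layer fired, e.g.
            # "limit_async: User timeout after ...", "limit_async: Worker execution timeout after ...",
            # or "limit_async: Task forcefully terminated due to execution timeout ...".
            print(f"LLM call timed out: {exc}")
            return None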