refac: Add robust time out handling for LLM request

This commit is contained in:
yangdx 2025-08-29 13:50:35 +08:00
parent ac2db35160
commit 925e631a9a
5 changed files with 331 additions and 112 deletions

View file

@ -156,8 +156,8 @@ MAX_PARALLEL_INSERT=2
### LLM Configuration ### LLM Configuration
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock ### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
########################################################### ###########################################################
### LLM request timeout setting for all llm (set to TIMEOUT if not specified, 0 means no timeout for Ollma) ### LLM request timeout setting for all llm (0 means no timeout for Ollma)
# LLM_TIMEOUT=150 # LLM_TIMEOUT=180
LLM_BINDING=openai LLM_BINDING=openai
LLM_MODEL=gpt-4o LLM_MODEL=gpt-4o
@ -206,7 +206,7 @@ OLLAMA_LLM_NUM_CTX=32768
### Embedding Configuration (Should not be changed after the first file processed) ### Embedding Configuration (Should not be changed after the first file processed)
### EMBEDDING_BINDING: ollama, openai, azure_openai, jina, lollms, aws_bedrock ### EMBEDDING_BINDING: ollama, openai, azure_openai, jina, lollms, aws_bedrock
#################################################################################### ####################################################################################
### see also env.ollama-binding-options.example for fine tuning ollama # EMBEDDING_TIMEOUT=30
EMBEDDING_BINDING=ollama EMBEDDING_BINDING=ollama
EMBEDDING_MODEL=bge-m3:latest EMBEDDING_MODEL=bge-m3:latest
EMBEDDING_DIM=1024 EMBEDDING_DIM=1024

View file

@ -39,6 +39,8 @@ from lightrag.constants import (
DEFAULT_LOG_MAX_BYTES, DEFAULT_LOG_MAX_BYTES,
DEFAULT_LOG_BACKUP_COUNT, DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_FILENAME, DEFAULT_LOG_FILENAME,
DEFAULT_LLM_TIMEOUT,
DEFAULT_EMBEDDING_TIMEOUT,
) )
from lightrag.api.routers.document_routes import ( from lightrag.api.routers.document_routes import (
DocumentManager, DocumentManager,
@ -256,7 +258,10 @@ def create_app(args):
if args.embedding_binding == "jina": if args.embedding_binding == "jina":
from lightrag.llm.jina import jina_embed from lightrag.llm.jina import jina_embed
llm_timeout = get_env_value("LLM_TIMEOUT", args.timeout, int) llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
embedding_timeout = get_env_value(
"EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT, int
)
async def openai_alike_model_complete( async def openai_alike_model_complete(
prompt, prompt,
@ -487,6 +492,8 @@ def create_app(args):
else {} else {}
), ),
embedding_func=embedding_func, embedding_func=embedding_func,
default_llm_timeout=llm_timeout,
default_embedding_timeout=embedding_timeout,
kv_storage=args.kv_storage, kv_storage=args.kv_storage,
graph_storage=args.graph_storage, graph_storage=args.graph_storage,
vector_storage=args.vector_storage, vector_storage=args.vector_storage,
@ -517,6 +524,8 @@ def create_app(args):
summary_max_tokens=args.summary_max_tokens, summary_max_tokens=args.summary_max_tokens,
summary_context_size=args.summary_context_size, summary_context_size=args.summary_context_size,
embedding_func=embedding_func, embedding_func=embedding_func,
default_llm_timeout=llm_timeout,
default_embedding_timeout=embedding_timeout,
kv_storage=args.kv_storage, kv_storage=args.kv_storage,
graph_storage=args.graph_storage, graph_storage=args.graph_storage,
vector_storage=args.vector_storage, vector_storage=args.vector_storage,

View file

@ -64,8 +64,12 @@ DEFAULT_MAX_PARALLEL_INSERT = 2 # Default maximum parallel insert operations
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC = 8 # Default max async for embedding functions DEFAULT_EMBEDDING_FUNC_MAX_ASYNC = 8 # Default max async for embedding functions
DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computations DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computations
# gunicorn worker timeout(as default LLM request timeout if LLM_TIMEOUT is not set) # Gunicorn worker timeout
DEFAULT_TIMEOUT = 150 DEFAULT_TIMEOUT = 210
# Default llm and embedding timeout
DEFAULT_LLM_TIMEOUT = 180
DEFAULT_EMBEDDING_TIMEOUT = 30
# Logging configuration defaults # Logging configuration defaults
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB

View file

@ -41,6 +41,8 @@ from lightrag.constants import (
DEFAULT_MAX_GRAPH_NODES, DEFAULT_MAX_GRAPH_NODES,
DEFAULT_ENTITY_TYPES, DEFAULT_ENTITY_TYPES,
DEFAULT_SUMMARY_LANGUAGE, DEFAULT_SUMMARY_LANGUAGE,
DEFAULT_LLM_TIMEOUT,
DEFAULT_EMBEDDING_TIMEOUT,
) )
from lightrag.utils import get_env_value from lightrag.utils import get_env_value
@ -277,6 +279,10 @@ class LightRAG:
- use_llm_check: If True, validates cached embeddings using an LLM. - use_llm_check: If True, validates cached embeddings using an LLM.
""" """
default_embedding_timeout: int = field(
default=int(os.getenv("EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT))
)
# LLM Configuration # LLM Configuration
# --- # ---
@ -311,6 +317,10 @@ class LightRAG:
llm_model_kwargs: dict[str, Any] = field(default_factory=dict) llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional keyword arguments passed to the LLM model function.""" """Additional keyword arguments passed to the LLM model function."""
default_llm_timeout: int = field(
default=int(os.getenv("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT))
)
# Rerank Configuration # Rerank Configuration
# --- # ---
@ -457,7 +467,8 @@ class LightRAG:
# Init Embedding # Init Embedding
self.embedding_func = priority_limit_async_func_call( self.embedding_func = priority_limit_async_func_call(
self.embedding_func_max_async self.embedding_func_max_async,
llm_timeout=self.default_embedding_timeout,
)(self.embedding_func) )(self.embedding_func)
# Initialize all storages # Initialize all storages
@ -550,7 +561,11 @@ class LightRAG:
# Directly use llm_response_cache, don't create a new object # Directly use llm_response_cache, don't create a new object
hashing_kv = self.llm_response_cache hashing_kv = self.llm_response_cache
self.llm_model_func = priority_limit_async_func_call(self.llm_model_max_async)( # Get timeout from LLM model kwargs for dynamic timeout calculation
self.llm_model_func = priority_limit_async_func_call(
self.llm_model_max_async,
llm_timeout=self.default_llm_timeout,
)(
partial( partial(
self.llm_model_func, # type: ignore self.llm_model_func, # type: ignore
hashing_kv=hashing_kv, hashing_kv=hashing_kv,

View file

@ -254,6 +254,18 @@ class UnlimitedSemaphore:
pass pass
@dataclass
class TaskState:
"""Task state tracking for priority queue management"""
future: asyncio.Future
start_time: float
execution_start_time: float = None
worker_started: bool = False
cancellation_requested: bool = False
cleanup_done: bool = False
@dataclass @dataclass
class EmbeddingFunc: class EmbeddingFunc:
embedding_dim: int embedding_dim: int
@ -323,20 +335,58 @@ def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
return None return None
# Custom exception class # Custom exception classes
class QueueFullError(Exception): class QueueFullError(Exception):
"""Raised when the queue is full and the wait times out""" """Raised when the queue is full and the wait times out"""
pass pass
def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): class WorkerTimeoutError(Exception):
"""Worker-level timeout exception with specific timeout information"""
def __init__(self, timeout_value: float, timeout_type: str = "execution"):
self.timeout_value = timeout_value
self.timeout_type = timeout_type
super().__init__(f"Worker {timeout_type} timeout after {timeout_value}s")
class HealthCheckTimeoutError(Exception):
"""Health Check-level timeout exception"""
def __init__(self, timeout_value: float, execution_duration: float):
self.timeout_value = timeout_value
self.execution_duration = execution_duration
super().__init__(
f"Task forcefully terminated due to execution timeout (>{timeout_value}s, actual: {execution_duration:.1f}s)"
)
def priority_limit_async_func_call(
max_size: int,
llm_timeout: float = None,
max_execution_timeout: float = None,
max_task_duration: float = None,
max_queue_size: int = 1000,
cleanup_timeout: float = 2.0,
):
""" """
Enhanced priority-limited asynchronous function call decorator Enhanced priority-limited asynchronous function call decorator with robust timeout handling
This decorator provides a comprehensive solution for managing concurrent LLM requests with:
- Multi-layer timeout protection (LLM -> Worker -> Health Check -> User)
- Task state tracking to prevent race conditions
- Enhanced health check system with stuck task detection
- Proper resource cleanup and error recovery
Args: Args:
max_size: Maximum number of concurrent calls max_size: Maximum number of concurrent calls
max_queue_size: Maximum queue capacity to prevent memory overflow max_queue_size: Maximum queue capacity to prevent memory overflow
llm_timeout: LLM provider timeout (from global config), used to calculate other timeouts
max_execution_timeout: Maximum time for worker to execute function (defaults to llm_timeout + 30s)
max_task_duration: Maximum time before health check intervenes (defaults to llm_timeout + 60s)
cleanup_timeout: Maximum time to wait for cleanup operations (defaults to 2.0s)
Returns: Returns:
Decorator function Decorator function
""" """
@ -345,81 +395,173 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
# Ensure func is callable # Ensure func is callable
if not callable(func): if not callable(func):
raise TypeError(f"Expected a callable object, got {type(func)}") raise TypeError(f"Expected a callable object, got {type(func)}")
# Calculate timeout hierarchy if llm_timeout is provided (Dynamic Timeout Calculation)
if llm_timeout is not None:
nonlocal max_execution_timeout, max_task_duration
if max_execution_timeout is None:
max_execution_timeout = (
llm_timeout + 30
) # LLM timeout + 30s buffer for network delays
if max_task_duration is None:
max_task_duration = (
llm_timeout + 60
) # LLM timeout + 1min buffer for execution phase
queue = asyncio.PriorityQueue(maxsize=max_queue_size) queue = asyncio.PriorityQueue(maxsize=max_queue_size)
tasks = set() tasks = set()
initialization_lock = asyncio.Lock() initialization_lock = asyncio.Lock()
counter = 0 counter = 0
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
initialized = False # Global initialization flag initialized = False
worker_health_check_task = None worker_health_check_task = None
# Track active future objects for cleanup # Enhanced task state management
task_states = {} # task_id -> TaskState
task_states_lock = asyncio.Lock()
active_futures = weakref.WeakSet() active_futures = weakref.WeakSet()
reinit_count = 0 # Reinitialization counter to track system health reinit_count = 0
# Worker function to process tasks in the queue
async def worker(): async def worker():
"""Worker that processes tasks in the priority queue""" """Enhanced worker that processes tasks with proper timeout and state management"""
try: try:
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
# Use timeout to get tasks, allowing periodic checking of shutdown signal # Get task from queue with timeout for shutdown checking
try: try:
( (
priority, priority,
count, count,
future, task_id,
args, args,
kwargs, kwargs,
) = await asyncio.wait_for(queue.get(), timeout=1.0) ) = await asyncio.wait_for(queue.get(), timeout=1.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
# Timeout is just to check shutdown signal, continue to next iteration
continue continue
# If future is cancelled, skip execution # Get task state and mark worker as started
if future.cancelled(): async with task_states_lock:
if task_id not in task_states:
queue.task_done()
continue
task_state = task_states[task_id]
task_state.worker_started = True
# Record execution start time when worker actually begins processing
task_state.execution_start_time = (
asyncio.get_event_loop().time()
)
# Check if task was cancelled before worker started
if (
task_state.cancellation_requested
or task_state.future.cancelled()
):
async with task_states_lock:
task_states.pop(task_id, None)
queue.task_done() queue.task_done()
continue continue
try: try:
# Execute function # Execute function with timeout protection
result = await func(*args, **kwargs) if max_execution_timeout is not None:
# If future is not done, set the result result = await asyncio.wait_for(
if not future.done(): func(*args, **kwargs), timeout=max_execution_timeout
future.set_result(result) )
except asyncio.CancelledError: else:
if not future.done(): result = await func(*args, **kwargs)
future.cancel()
logger.debug("limit_async: Task cancelled during execution") # Set result if future is still valid
except Exception as e: if not task_state.future.done():
logger.error( task_state.future.set_result(result)
f"limit_async: Error in decorated function: {str(e)}"
except asyncio.TimeoutError:
# Worker-level timeout (max_execution_timeout exceeded)
logger.warning(
f"limit_async: Worker timeout for task {task_id} after {max_execution_timeout}s"
) )
if not future.done(): if not task_state.future.done():
future.set_exception(e) task_state.future.set_exception(
WorkerTimeoutError(
max_execution_timeout, "execution"
)
)
except asyncio.CancelledError:
# Task was cancelled during execution
if not task_state.future.done():
task_state.future.cancel()
logger.debug(
f"limit_async: Task {task_id} cancelled during execution"
)
except Exception as e:
# Function execution error
logger.error(
f"limit_async: Error in decorated function for task {task_id}: {str(e)}"
)
if not task_state.future.done():
task_state.future.set_exception(e)
finally: finally:
# Clean up task state
async with task_states_lock:
task_states.pop(task_id, None)
queue.task_done() queue.task_done()
except Exception as e: except Exception as e:
# Catch all exceptions in worker loop to prevent worker termination # Critical error in worker loop
logger.error(f"limit_async: Critical error in worker: {str(e)}") logger.error(f"limit_async: Critical error in worker: {str(e)}")
await asyncio.sleep(0.1) # Prevent high CPU usage await asyncio.sleep(0.1)
finally: finally:
logger.debug("limit_async: Worker exiting") logger.debug("limit_async: Worker exiting")
async def health_check(): async def enhanced_health_check():
"""Periodically check worker health status and recover""" """Enhanced health check with stuck task detection and recovery"""
nonlocal initialized nonlocal initialized
try: try:
while not shutdown_event.is_set(): while not shutdown_event.is_set():
await asyncio.sleep(5) # Check every 5 seconds await asyncio.sleep(5) # Check every 5 seconds
# No longer acquire lock, directly operate on task set current_time = asyncio.get_event_loop().time()
# Use a copy of the task set to avoid concurrent modification
# Detect and handle stuck tasks based on execution start time
if max_task_duration is not None:
stuck_tasks = []
async with task_states_lock:
for task_id, task_state in list(task_states.items()):
# Only check tasks that have started execution
if (
task_state.worker_started
and task_state.execution_start_time is not None
and current_time - task_state.execution_start_time
> max_task_duration
):
stuck_tasks.append(
(
task_id,
current_time
- task_state.execution_start_time,
)
)
# Force cleanup of stuck tasks
for task_id, execution_duration in stuck_tasks:
logger.warning(
f"limit_async: Detected stuck task {task_id} (execution time: {execution_duration:.1f}s), forcing cleanup"
)
async with task_states_lock:
if task_id in task_states:
task_state = task_states[task_id]
if not task_state.future.done():
task_state.future.set_exception(
HealthCheckTimeoutError(
max_task_duration, execution_duration
)
)
task_states.pop(task_id, None)
# Worker recovery logic
current_tasks = set(tasks) current_tasks = set(tasks)
done_tasks = {t for t in current_tasks if t.done()} done_tasks = {t for t in current_tasks if t.done()}
tasks.difference_update(done_tasks) tasks.difference_update(done_tasks)
# Calculate active tasks count
active_tasks_count = len(tasks) active_tasks_count = len(tasks)
workers_needed = max_size - active_tasks_count workers_needed = max_size - active_tasks_count
@ -432,21 +574,16 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
task = asyncio.create_task(worker()) task = asyncio.create_task(worker())
new_tasks.add(task) new_tasks.add(task)
task.add_done_callback(tasks.discard) task.add_done_callback(tasks.discard)
# Update task set in one operation
tasks.update(new_tasks) tasks.update(new_tasks)
except Exception as e: except Exception as e:
logger.error(f"limit_async: Error in health check: {str(e)}") logger.error(f"limit_async: Error in enhanced health check: {str(e)}")
finally: finally:
logger.debug("limit_async: Health check task exiting") logger.debug("limit_async: Enhanced health check task exiting")
initialized = False initialized = False
async def ensure_workers(): async def ensure_workers():
"""Ensure worker threads and health check system are available """Ensure worker system is initialized with enhanced error handling"""
This function checks if the worker system is already initialized.
If not, it performs a one-time initialization of all worker threads
and starts the health check system.
"""
nonlocal initialized, worker_health_check_task, tasks, reinit_count nonlocal initialized, worker_health_check_task, tasks, reinit_count
if initialized: if initialized:
@ -456,45 +593,56 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
if initialized: if initialized:
return return
# Increment reinitialization counter if this is not the first initialization
if reinit_count > 0: if reinit_count > 0:
reinit_count += 1 reinit_count += 1
logger.warning( logger.warning(
f"limit_async: Reinitializing needed (count: {reinit_count})" f"limit_async: Reinitializing system (count: {reinit_count})"
) )
else: else:
reinit_count = 1 # First initialization reinit_count = 1
# Check for completed tasks and remove them from the task set # Clean up completed tasks
current_tasks = set(tasks) current_tasks = set(tasks)
done_tasks = {t for t in current_tasks if t.done()} done_tasks = {t for t in current_tasks if t.done()}
tasks.difference_update(done_tasks) tasks.difference_update(done_tasks)
# Log active tasks count during reinitialization
active_tasks_count = len(tasks) active_tasks_count = len(tasks)
if active_tasks_count > 0 and reinit_count > 1: if active_tasks_count > 0 and reinit_count > 1:
logger.warning( logger.warning(
f"limit_async: {active_tasks_count} tasks still running during reinitialization" f"limit_async: {active_tasks_count} tasks still running during reinitialization"
) )
# Create initial worker tasks, only adding the number needed # Create worker tasks
workers_needed = max_size - active_tasks_count workers_needed = max_size - active_tasks_count
for _ in range(workers_needed): for _ in range(workers_needed):
task = asyncio.create_task(worker()) task = asyncio.create_task(worker())
tasks.add(task) tasks.add(task)
task.add_done_callback(tasks.discard) task.add_done_callback(tasks.discard)
# Start health check # Start enhanced health check
worker_health_check_task = asyncio.create_task(health_check()) worker_health_check_task = asyncio.create_task(enhanced_health_check())
initialized = True initialized = True
logger.info(f"limit_async: {workers_needed} new workers initialized") # Log dynamic timeout configuration
timeout_info = []
if llm_timeout is not None:
timeout_info.append(f"LLM: {llm_timeout}s")
if max_execution_timeout is not None:
timeout_info.append(f"Execution: {max_execution_timeout}s")
if max_task_duration is not None:
timeout_info.append(f"Health Check: {max_task_duration}s")
timeout_str = (
f" (Timeouts: {', '.join(timeout_info)})" if timeout_info else ""
)
logger.info(
f"limit_async: {workers_needed} new workers initialized with dynamic timeout handling{timeout_str}"
)
async def shutdown(): async def shutdown():
"""Gracefully shut down all workers and the queue""" """Gracefully shut down all workers and cleanup resources"""
logger.info("limit_async: Shutting down priority queue workers") logger.info("limit_async: Shutting down priority queue workers")
# Set the shutdown event
shutdown_event.set() shutdown_event.set()
# Cancel all active futures # Cancel all active futures
@ -502,7 +650,14 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
if not future.done(): if not future.done():
future.cancel() future.cancel()
# Wait for the queue to empty # Cancel all pending tasks
async with task_states_lock:
for task_id, task_state in list(task_states.items()):
if not task_state.future.done():
task_state.future.cancel()
task_states.clear()
# Wait for queue to empty with timeout
try: try:
await asyncio.wait_for(queue.join(), timeout=5.0) await asyncio.wait_for(queue.join(), timeout=5.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -510,7 +665,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
"limit_async: Timeout waiting for queue to empty during shutdown" "limit_async: Timeout waiting for queue to empty during shutdown"
) )
# Cancel all worker tasks # Cancel worker tasks
for task in list(tasks): for task in list(tasks):
if not task.done(): if not task.done():
task.cancel() task.cancel()
@ -519,7 +674,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
if tasks: if tasks:
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
# Cancel the health check task # Cancel health check task
if worker_health_check_task and not worker_health_check_task.done(): if worker_health_check_task and not worker_health_check_task.done():
worker_health_check_task.cancel() worker_health_check_task.cancel()
try: try:
@ -534,77 +689,113 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
*args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs *args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs
): ):
""" """
Execute the function with priority-based concurrency control Execute function with enhanced priority-based concurrency control and timeout handling
Args: Args:
*args: Positional arguments passed to the function *args: Positional arguments passed to the function
_priority: Call priority (lower values have higher priority) _priority: Call priority (lower values have higher priority)
_timeout: Maximum time to wait for function completion (in seconds) _timeout: Maximum time to wait for completion (in seconds, none means determinded by max_execution_timeout of the queue)
_queue_timeout: Maximum time to wait for entering the queue (in seconds) _queue_timeout: Maximum time to wait for entering the queue (in seconds)
**kwargs: Keyword arguments passed to the function **kwargs: Keyword arguments passed to the function
Returns: Returns:
The result of the function call The result of the function call
Raises: Raises:
TimeoutError: If the function call times out TimeoutError: If the function call times out at any level
QueueFullError: If the queue is full and waiting times out QueueFullError: If the queue is full and waiting times out
Any exception raised by the decorated function Any exception raised by the decorated function
""" """
# Ensure worker system is initialized
await ensure_workers() await ensure_workers()
# Create a future for the result # Generate unique task ID
task_id = f"{id(asyncio.current_task())}_{asyncio.get_event_loop().time()}"
future = asyncio.Future() future = asyncio.Future()
active_futures.add(future)
nonlocal counter # Create task state
async with initialization_lock: task_state = TaskState(
current_count = counter # Use local variable to avoid race conditions future=future, start_time=asyncio.get_event_loop().time()
counter += 1 )
# Try to put the task into the queue, supporting timeout
try: try:
if _queue_timeout is not None: # Register task state
# Use timeout to wait for queue space async with task_states_lock:
try: task_states[task_id] = task_state
active_futures.add(future)
# Get counter for FIFO ordering
nonlocal counter
async with initialization_lock:
current_count = counter
counter += 1
# Queue the task with timeout handling
try:
if _queue_timeout is not None:
await asyncio.wait_for( await asyncio.wait_for(
# current_count is used to ensure FIFO order queue.put(
queue.put((_priority, current_count, future, args, kwargs)), (_priority, current_count, task_id, args, kwargs)
),
timeout=_queue_timeout, timeout=_queue_timeout,
) )
except asyncio.TimeoutError: else:
raise QueueFullError( await queue.put(
f"Queue full, timeout after {_queue_timeout} seconds" (_priority, current_count, task_id, args, kwargs)
) )
else: except asyncio.TimeoutError:
# No timeout, may wait indefinitely raise QueueFullError(
# current_count is used to ensure FIFO order f"Queue full, timeout after {_queue_timeout} seconds"
await queue.put((_priority, current_count, future, args, kwargs)) )
except Exception as e: except Exception as e:
# Clean up the future # Clean up on queue error
if not future.done(): if not future.done():
future.set_exception(e) future.set_exception(e)
active_futures.discard(future) raise
raise
try: # Wait for result with timeout handling
# Wait for the result, optional timeout try:
if _timeout is not None: if _timeout is not None:
try:
return await asyncio.wait_for(future, _timeout) return await asyncio.wait_for(future, _timeout)
except asyncio.TimeoutError: else:
# Cancel the future return await future
if not future.done(): except asyncio.TimeoutError:
future.cancel() # This is user-level timeout (asyncio.wait_for caused)
raise TimeoutError( # Mark cancellation request
f"limit_async: Task timed out after {_timeout} seconds" async with task_states_lock:
) if task_id in task_states:
else: task_states[task_id].cancellation_requested = True
# Wait for the result without timeout
return await future
finally:
# Clean up the future reference
active_futures.discard(future)
# Add the shutdown method to the decorated function # Cancel future
if not future.done():
future.cancel()
# Wait for worker cleanup with timeout
cleanup_start = asyncio.get_event_loop().time()
while (
task_id in task_states
and asyncio.get_event_loop().time() - cleanup_start
< cleanup_timeout
):
await asyncio.sleep(0.1)
raise TimeoutError(
f"limit_async: User timeout after {_timeout} seconds"
)
except WorkerTimeoutError as e:
# This is Worker-level timeout, directly propagate exception information
raise TimeoutError(f"limit_async: {str(e)}")
except HealthCheckTimeoutError as e:
# This is Health Check-level timeout, directly propagate exception information
raise TimeoutError(f"limit_async: {str(e)}")
finally:
# Ensure cleanup
active_futures.discard(future)
async with task_states_lock:
task_states.pop(task_id, None)
# Add shutdown method to decorated function
wait_func.shutdown = shutdown wait_func.shutdown = shutdown
return wait_func return wait_func