refac: Add robust time out handling for LLM request
This commit is contained in:
parent
ac2db35160
commit
925e631a9a
5 changed files with 331 additions and 112 deletions
|
|
@ -156,8 +156,8 @@ MAX_PARALLEL_INSERT=2
|
||||||
### LLM Configuration
|
### LLM Configuration
|
||||||
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
|
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
|
||||||
###########################################################
|
###########################################################
|
||||||
### LLM request timeout setting for all llm (set to TIMEOUT if not specified, 0 means no timeout for Ollma)
|
### LLM request timeout setting for all llm (0 means no timeout for Ollma)
|
||||||
# LLM_TIMEOUT=150
|
# LLM_TIMEOUT=180
|
||||||
|
|
||||||
LLM_BINDING=openai
|
LLM_BINDING=openai
|
||||||
LLM_MODEL=gpt-4o
|
LLM_MODEL=gpt-4o
|
||||||
|
|
@ -206,7 +206,7 @@ OLLAMA_LLM_NUM_CTX=32768
|
||||||
### Embedding Configuration (Should not be changed after the first file processed)
|
### Embedding Configuration (Should not be changed after the first file processed)
|
||||||
### EMBEDDING_BINDING: ollama, openai, azure_openai, jina, lollms, aws_bedrock
|
### EMBEDDING_BINDING: ollama, openai, azure_openai, jina, lollms, aws_bedrock
|
||||||
####################################################################################
|
####################################################################################
|
||||||
### see also env.ollama-binding-options.example for fine tuning ollama
|
# EMBEDDING_TIMEOUT=30
|
||||||
EMBEDDING_BINDING=ollama
|
EMBEDDING_BINDING=ollama
|
||||||
EMBEDDING_MODEL=bge-m3:latest
|
EMBEDDING_MODEL=bge-m3:latest
|
||||||
EMBEDDING_DIM=1024
|
EMBEDDING_DIM=1024
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,8 @@ from lightrag.constants import (
|
||||||
DEFAULT_LOG_MAX_BYTES,
|
DEFAULT_LOG_MAX_BYTES,
|
||||||
DEFAULT_LOG_BACKUP_COUNT,
|
DEFAULT_LOG_BACKUP_COUNT,
|
||||||
DEFAULT_LOG_FILENAME,
|
DEFAULT_LOG_FILENAME,
|
||||||
|
DEFAULT_LLM_TIMEOUT,
|
||||||
|
DEFAULT_EMBEDDING_TIMEOUT,
|
||||||
)
|
)
|
||||||
from lightrag.api.routers.document_routes import (
|
from lightrag.api.routers.document_routes import (
|
||||||
DocumentManager,
|
DocumentManager,
|
||||||
|
|
@ -256,7 +258,10 @@ def create_app(args):
|
||||||
if args.embedding_binding == "jina":
|
if args.embedding_binding == "jina":
|
||||||
from lightrag.llm.jina import jina_embed
|
from lightrag.llm.jina import jina_embed
|
||||||
|
|
||||||
llm_timeout = get_env_value("LLM_TIMEOUT", args.timeout, int)
|
llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
|
||||||
|
embedding_timeout = get_env_value(
|
||||||
|
"EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT, int
|
||||||
|
)
|
||||||
|
|
||||||
async def openai_alike_model_complete(
|
async def openai_alike_model_complete(
|
||||||
prompt,
|
prompt,
|
||||||
|
|
@ -487,6 +492,8 @@ def create_app(args):
|
||||||
else {}
|
else {}
|
||||||
),
|
),
|
||||||
embedding_func=embedding_func,
|
embedding_func=embedding_func,
|
||||||
|
default_llm_timeout=llm_timeout,
|
||||||
|
default_embedding_timeout=embedding_timeout,
|
||||||
kv_storage=args.kv_storage,
|
kv_storage=args.kv_storage,
|
||||||
graph_storage=args.graph_storage,
|
graph_storage=args.graph_storage,
|
||||||
vector_storage=args.vector_storage,
|
vector_storage=args.vector_storage,
|
||||||
|
|
@ -517,6 +524,8 @@ def create_app(args):
|
||||||
summary_max_tokens=args.summary_max_tokens,
|
summary_max_tokens=args.summary_max_tokens,
|
||||||
summary_context_size=args.summary_context_size,
|
summary_context_size=args.summary_context_size,
|
||||||
embedding_func=embedding_func,
|
embedding_func=embedding_func,
|
||||||
|
default_llm_timeout=llm_timeout,
|
||||||
|
default_embedding_timeout=embedding_timeout,
|
||||||
kv_storage=args.kv_storage,
|
kv_storage=args.kv_storage,
|
||||||
graph_storage=args.graph_storage,
|
graph_storage=args.graph_storage,
|
||||||
vector_storage=args.vector_storage,
|
vector_storage=args.vector_storage,
|
||||||
|
|
|
||||||
|
|
@ -64,8 +64,12 @@ DEFAULT_MAX_PARALLEL_INSERT = 2 # Default maximum parallel insert operations
|
||||||
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC = 8 # Default max async for embedding functions
|
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC = 8 # Default max async for embedding functions
|
||||||
DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computations
|
DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computations
|
||||||
|
|
||||||
# gunicorn worker timeout(as default LLM request timeout if LLM_TIMEOUT is not set)
|
# Gunicorn worker timeout
|
||||||
DEFAULT_TIMEOUT = 150
|
DEFAULT_TIMEOUT = 210
|
||||||
|
|
||||||
|
# Default llm and embedding timeout
|
||||||
|
DEFAULT_LLM_TIMEOUT = 180
|
||||||
|
DEFAULT_EMBEDDING_TIMEOUT = 30
|
||||||
|
|
||||||
# Logging configuration defaults
|
# Logging configuration defaults
|
||||||
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB
|
DEFAULT_LOG_MAX_BYTES = 10485760 # Default 10MB
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,8 @@ from lightrag.constants import (
|
||||||
DEFAULT_MAX_GRAPH_NODES,
|
DEFAULT_MAX_GRAPH_NODES,
|
||||||
DEFAULT_ENTITY_TYPES,
|
DEFAULT_ENTITY_TYPES,
|
||||||
DEFAULT_SUMMARY_LANGUAGE,
|
DEFAULT_SUMMARY_LANGUAGE,
|
||||||
|
DEFAULT_LLM_TIMEOUT,
|
||||||
|
DEFAULT_EMBEDDING_TIMEOUT,
|
||||||
)
|
)
|
||||||
from lightrag.utils import get_env_value
|
from lightrag.utils import get_env_value
|
||||||
|
|
||||||
|
|
@ -277,6 +279,10 @@ class LightRAG:
|
||||||
- use_llm_check: If True, validates cached embeddings using an LLM.
|
- use_llm_check: If True, validates cached embeddings using an LLM.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
default_embedding_timeout: int = field(
|
||||||
|
default=int(os.getenv("EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT))
|
||||||
|
)
|
||||||
|
|
||||||
# LLM Configuration
|
# LLM Configuration
|
||||||
# ---
|
# ---
|
||||||
|
|
||||||
|
|
@ -311,6 +317,10 @@ class LightRAG:
|
||||||
llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
|
llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||||
"""Additional keyword arguments passed to the LLM model function."""
|
"""Additional keyword arguments passed to the LLM model function."""
|
||||||
|
|
||||||
|
default_llm_timeout: int = field(
|
||||||
|
default=int(os.getenv("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT))
|
||||||
|
)
|
||||||
|
|
||||||
# Rerank Configuration
|
# Rerank Configuration
|
||||||
# ---
|
# ---
|
||||||
|
|
||||||
|
|
@ -457,7 +467,8 @@ class LightRAG:
|
||||||
|
|
||||||
# Init Embedding
|
# Init Embedding
|
||||||
self.embedding_func = priority_limit_async_func_call(
|
self.embedding_func = priority_limit_async_func_call(
|
||||||
self.embedding_func_max_async
|
self.embedding_func_max_async,
|
||||||
|
llm_timeout=self.default_embedding_timeout,
|
||||||
)(self.embedding_func)
|
)(self.embedding_func)
|
||||||
|
|
||||||
# Initialize all storages
|
# Initialize all storages
|
||||||
|
|
@ -550,7 +561,11 @@ class LightRAG:
|
||||||
# Directly use llm_response_cache, don't create a new object
|
# Directly use llm_response_cache, don't create a new object
|
||||||
hashing_kv = self.llm_response_cache
|
hashing_kv = self.llm_response_cache
|
||||||
|
|
||||||
self.llm_model_func = priority_limit_async_func_call(self.llm_model_max_async)(
|
# Get timeout from LLM model kwargs for dynamic timeout calculation
|
||||||
|
self.llm_model_func = priority_limit_async_func_call(
|
||||||
|
self.llm_model_max_async,
|
||||||
|
llm_timeout=self.default_llm_timeout,
|
||||||
|
)(
|
||||||
partial(
|
partial(
|
||||||
self.llm_model_func, # type: ignore
|
self.llm_model_func, # type: ignore
|
||||||
hashing_kv=hashing_kv,
|
hashing_kv=hashing_kv,
|
||||||
|
|
|
||||||
|
|
@ -254,6 +254,18 @@ class UnlimitedSemaphore:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskState:
|
||||||
|
"""Task state tracking for priority queue management"""
|
||||||
|
|
||||||
|
future: asyncio.Future
|
||||||
|
start_time: float
|
||||||
|
execution_start_time: float = None
|
||||||
|
worker_started: bool = False
|
||||||
|
cancellation_requested: bool = False
|
||||||
|
cleanup_done: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingFunc:
|
class EmbeddingFunc:
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
|
|
@ -323,20 +335,58 @@ def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# Custom exception class
|
# Custom exception classes
|
||||||
class QueueFullError(Exception):
|
class QueueFullError(Exception):
|
||||||
"""Raised when the queue is full and the wait times out"""
|
"""Raised when the queue is full and the wait times out"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
|
class WorkerTimeoutError(Exception):
|
||||||
|
"""Worker-level timeout exception with specific timeout information"""
|
||||||
|
|
||||||
|
def __init__(self, timeout_value: float, timeout_type: str = "execution"):
|
||||||
|
self.timeout_value = timeout_value
|
||||||
|
self.timeout_type = timeout_type
|
||||||
|
super().__init__(f"Worker {timeout_type} timeout after {timeout_value}s")
|
||||||
|
|
||||||
|
|
||||||
|
class HealthCheckTimeoutError(Exception):
|
||||||
|
"""Health Check-level timeout exception"""
|
||||||
|
|
||||||
|
def __init__(self, timeout_value: float, execution_duration: float):
|
||||||
|
self.timeout_value = timeout_value
|
||||||
|
self.execution_duration = execution_duration
|
||||||
|
super().__init__(
|
||||||
|
f"Task forcefully terminated due to execution timeout (>{timeout_value}s, actual: {execution_duration:.1f}s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def priority_limit_async_func_call(
|
||||||
|
max_size: int,
|
||||||
|
llm_timeout: float = None,
|
||||||
|
max_execution_timeout: float = None,
|
||||||
|
max_task_duration: float = None,
|
||||||
|
max_queue_size: int = 1000,
|
||||||
|
cleanup_timeout: float = 2.0,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Enhanced priority-limited asynchronous function call decorator
|
Enhanced priority-limited asynchronous function call decorator with robust timeout handling
|
||||||
|
|
||||||
|
This decorator provides a comprehensive solution for managing concurrent LLM requests with:
|
||||||
|
- Multi-layer timeout protection (LLM -> Worker -> Health Check -> User)
|
||||||
|
- Task state tracking to prevent race conditions
|
||||||
|
- Enhanced health check system with stuck task detection
|
||||||
|
- Proper resource cleanup and error recovery
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
max_size: Maximum number of concurrent calls
|
max_size: Maximum number of concurrent calls
|
||||||
max_queue_size: Maximum queue capacity to prevent memory overflow
|
max_queue_size: Maximum queue capacity to prevent memory overflow
|
||||||
|
llm_timeout: LLM provider timeout (from global config), used to calculate other timeouts
|
||||||
|
max_execution_timeout: Maximum time for worker to execute function (defaults to llm_timeout + 30s)
|
||||||
|
max_task_duration: Maximum time before health check intervenes (defaults to llm_timeout + 60s)
|
||||||
|
cleanup_timeout: Maximum time to wait for cleanup operations (defaults to 2.0s)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Decorator function
|
Decorator function
|
||||||
"""
|
"""
|
||||||
|
|
@ -345,81 +395,173 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
|
||||||
# Ensure func is callable
|
# Ensure func is callable
|
||||||
if not callable(func):
|
if not callable(func):
|
||||||
raise TypeError(f"Expected a callable object, got {type(func)}")
|
raise TypeError(f"Expected a callable object, got {type(func)}")
|
||||||
|
|
||||||
|
# Calculate timeout hierarchy if llm_timeout is provided (Dynamic Timeout Calculation)
|
||||||
|
if llm_timeout is not None:
|
||||||
|
nonlocal max_execution_timeout, max_task_duration
|
||||||
|
if max_execution_timeout is None:
|
||||||
|
max_execution_timeout = (
|
||||||
|
llm_timeout + 30
|
||||||
|
) # LLM timeout + 30s buffer for network delays
|
||||||
|
if max_task_duration is None:
|
||||||
|
max_task_duration = (
|
||||||
|
llm_timeout + 60
|
||||||
|
) # LLM timeout + 1min buffer for execution phase
|
||||||
|
|
||||||
queue = asyncio.PriorityQueue(maxsize=max_queue_size)
|
queue = asyncio.PriorityQueue(maxsize=max_queue_size)
|
||||||
tasks = set()
|
tasks = set()
|
||||||
initialization_lock = asyncio.Lock()
|
initialization_lock = asyncio.Lock()
|
||||||
counter = 0
|
counter = 0
|
||||||
shutdown_event = asyncio.Event()
|
shutdown_event = asyncio.Event()
|
||||||
initialized = False # Global initialization flag
|
initialized = False
|
||||||
worker_health_check_task = None
|
worker_health_check_task = None
|
||||||
|
|
||||||
# Track active future objects for cleanup
|
# Enhanced task state management
|
||||||
|
task_states = {} # task_id -> TaskState
|
||||||
|
task_states_lock = asyncio.Lock()
|
||||||
active_futures = weakref.WeakSet()
|
active_futures = weakref.WeakSet()
|
||||||
reinit_count = 0 # Reinitialization counter to track system health
|
reinit_count = 0
|
||||||
|
|
||||||
# Worker function to process tasks in the queue
|
|
||||||
async def worker():
|
async def worker():
|
||||||
"""Worker that processes tasks in the priority queue"""
|
"""Enhanced worker that processes tasks with proper timeout and state management"""
|
||||||
try:
|
try:
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
# Use timeout to get tasks, allowing periodic checking of shutdown signal
|
# Get task from queue with timeout for shutdown checking
|
||||||
try:
|
try:
|
||||||
(
|
(
|
||||||
priority,
|
priority,
|
||||||
count,
|
count,
|
||||||
future,
|
task_id,
|
||||||
args,
|
args,
|
||||||
kwargs,
|
kwargs,
|
||||||
) = await asyncio.wait_for(queue.get(), timeout=1.0)
|
) = await asyncio.wait_for(queue.get(), timeout=1.0)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# Timeout is just to check shutdown signal, continue to next iteration
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If future is cancelled, skip execution
|
# Get task state and mark worker as started
|
||||||
if future.cancelled():
|
async with task_states_lock:
|
||||||
|
if task_id not in task_states:
|
||||||
|
queue.task_done()
|
||||||
|
continue
|
||||||
|
task_state = task_states[task_id]
|
||||||
|
task_state.worker_started = True
|
||||||
|
# Record execution start time when worker actually begins processing
|
||||||
|
task_state.execution_start_time = (
|
||||||
|
asyncio.get_event_loop().time()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if task was cancelled before worker started
|
||||||
|
if (
|
||||||
|
task_state.cancellation_requested
|
||||||
|
or task_state.future.cancelled()
|
||||||
|
):
|
||||||
|
async with task_states_lock:
|
||||||
|
task_states.pop(task_id, None)
|
||||||
queue.task_done()
|
queue.task_done()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Execute function
|
# Execute function with timeout protection
|
||||||
result = await func(*args, **kwargs)
|
if max_execution_timeout is not None:
|
||||||
# If future is not done, set the result
|
result = await asyncio.wait_for(
|
||||||
if not future.done():
|
func(*args, **kwargs), timeout=max_execution_timeout
|
||||||
future.set_result(result)
|
)
|
||||||
except asyncio.CancelledError:
|
else:
|
||||||
if not future.done():
|
result = await func(*args, **kwargs)
|
||||||
future.cancel()
|
|
||||||
logger.debug("limit_async: Task cancelled during execution")
|
# Set result if future is still valid
|
||||||
except Exception as e:
|
if not task_state.future.done():
|
||||||
logger.error(
|
task_state.future.set_result(result)
|
||||||
f"limit_async: Error in decorated function: {str(e)}"
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Worker-level timeout (max_execution_timeout exceeded)
|
||||||
|
logger.warning(
|
||||||
|
f"limit_async: Worker timeout for task {task_id} after {max_execution_timeout}s"
|
||||||
)
|
)
|
||||||
if not future.done():
|
if not task_state.future.done():
|
||||||
future.set_exception(e)
|
task_state.future.set_exception(
|
||||||
|
WorkerTimeoutError(
|
||||||
|
max_execution_timeout, "execution"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Task was cancelled during execution
|
||||||
|
if not task_state.future.done():
|
||||||
|
task_state.future.cancel()
|
||||||
|
logger.debug(
|
||||||
|
f"limit_async: Task {task_id} cancelled during execution"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Function execution error
|
||||||
|
logger.error(
|
||||||
|
f"limit_async: Error in decorated function for task {task_id}: {str(e)}"
|
||||||
|
)
|
||||||
|
if not task_state.future.done():
|
||||||
|
task_state.future.set_exception(e)
|
||||||
finally:
|
finally:
|
||||||
|
# Clean up task state
|
||||||
|
async with task_states_lock:
|
||||||
|
task_states.pop(task_id, None)
|
||||||
queue.task_done()
|
queue.task_done()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Catch all exceptions in worker loop to prevent worker termination
|
# Critical error in worker loop
|
||||||
logger.error(f"limit_async: Critical error in worker: {str(e)}")
|
logger.error(f"limit_async: Critical error in worker: {str(e)}")
|
||||||
await asyncio.sleep(0.1) # Prevent high CPU usage
|
await asyncio.sleep(0.1)
|
||||||
finally:
|
finally:
|
||||||
logger.debug("limit_async: Worker exiting")
|
logger.debug("limit_async: Worker exiting")
|
||||||
|
|
||||||
async def health_check():
|
async def enhanced_health_check():
|
||||||
"""Periodically check worker health status and recover"""
|
"""Enhanced health check with stuck task detection and recovery"""
|
||||||
nonlocal initialized
|
nonlocal initialized
|
||||||
try:
|
try:
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
await asyncio.sleep(5) # Check every 5 seconds
|
await asyncio.sleep(5) # Check every 5 seconds
|
||||||
|
|
||||||
# No longer acquire lock, directly operate on task set
|
current_time = asyncio.get_event_loop().time()
|
||||||
# Use a copy of the task set to avoid concurrent modification
|
|
||||||
|
# Detect and handle stuck tasks based on execution start time
|
||||||
|
if max_task_duration is not None:
|
||||||
|
stuck_tasks = []
|
||||||
|
async with task_states_lock:
|
||||||
|
for task_id, task_state in list(task_states.items()):
|
||||||
|
# Only check tasks that have started execution
|
||||||
|
if (
|
||||||
|
task_state.worker_started
|
||||||
|
and task_state.execution_start_time is not None
|
||||||
|
and current_time - task_state.execution_start_time
|
||||||
|
> max_task_duration
|
||||||
|
):
|
||||||
|
stuck_tasks.append(
|
||||||
|
(
|
||||||
|
task_id,
|
||||||
|
current_time
|
||||||
|
- task_state.execution_start_time,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Force cleanup of stuck tasks
|
||||||
|
for task_id, execution_duration in stuck_tasks:
|
||||||
|
logger.warning(
|
||||||
|
f"limit_async: Detected stuck task {task_id} (execution time: {execution_duration:.1f}s), forcing cleanup"
|
||||||
|
)
|
||||||
|
async with task_states_lock:
|
||||||
|
if task_id in task_states:
|
||||||
|
task_state = task_states[task_id]
|
||||||
|
if not task_state.future.done():
|
||||||
|
task_state.future.set_exception(
|
||||||
|
HealthCheckTimeoutError(
|
||||||
|
max_task_duration, execution_duration
|
||||||
|
)
|
||||||
|
)
|
||||||
|
task_states.pop(task_id, None)
|
||||||
|
|
||||||
|
# Worker recovery logic
|
||||||
current_tasks = set(tasks)
|
current_tasks = set(tasks)
|
||||||
done_tasks = {t for t in current_tasks if t.done()}
|
done_tasks = {t for t in current_tasks if t.done()}
|
||||||
tasks.difference_update(done_tasks)
|
tasks.difference_update(done_tasks)
|
||||||
|
|
||||||
# Calculate active tasks count
|
|
||||||
active_tasks_count = len(tasks)
|
active_tasks_count = len(tasks)
|
||||||
workers_needed = max_size - active_tasks_count
|
workers_needed = max_size - active_tasks_count
|
||||||
|
|
||||||
|
|
@ -432,21 +574,16 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
|
||||||
task = asyncio.create_task(worker())
|
task = asyncio.create_task(worker())
|
||||||
new_tasks.add(task)
|
new_tasks.add(task)
|
||||||
task.add_done_callback(tasks.discard)
|
task.add_done_callback(tasks.discard)
|
||||||
# Update task set in one operation
|
|
||||||
tasks.update(new_tasks)
|
tasks.update(new_tasks)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"limit_async: Error in health check: {str(e)}")
|
logger.error(f"limit_async: Error in enhanced health check: {str(e)}")
|
||||||
finally:
|
finally:
|
||||||
logger.debug("limit_async: Health check task exiting")
|
logger.debug("limit_async: Enhanced health check task exiting")
|
||||||
initialized = False
|
initialized = False
|
||||||
|
|
||||||
async def ensure_workers():
|
async def ensure_workers():
|
||||||
"""Ensure worker threads and health check system are available
|
"""Ensure worker system is initialized with enhanced error handling"""
|
||||||
|
|
||||||
This function checks if the worker system is already initialized.
|
|
||||||
If not, it performs a one-time initialization of all worker threads
|
|
||||||
and starts the health check system.
|
|
||||||
"""
|
|
||||||
nonlocal initialized, worker_health_check_task, tasks, reinit_count
|
nonlocal initialized, worker_health_check_task, tasks, reinit_count
|
||||||
|
|
||||||
if initialized:
|
if initialized:
|
||||||
|
|
@ -456,45 +593,56 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
|
||||||
if initialized:
|
if initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Increment reinitialization counter if this is not the first initialization
|
|
||||||
if reinit_count > 0:
|
if reinit_count > 0:
|
||||||
reinit_count += 1
|
reinit_count += 1
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"limit_async: Reinitializing needed (count: {reinit_count})"
|
f"limit_async: Reinitializing system (count: {reinit_count})"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
reinit_count = 1 # First initialization
|
reinit_count = 1
|
||||||
|
|
||||||
# Check for completed tasks and remove them from the task set
|
# Clean up completed tasks
|
||||||
current_tasks = set(tasks)
|
current_tasks = set(tasks)
|
||||||
done_tasks = {t for t in current_tasks if t.done()}
|
done_tasks = {t for t in current_tasks if t.done()}
|
||||||
tasks.difference_update(done_tasks)
|
tasks.difference_update(done_tasks)
|
||||||
|
|
||||||
# Log active tasks count during reinitialization
|
|
||||||
active_tasks_count = len(tasks)
|
active_tasks_count = len(tasks)
|
||||||
if active_tasks_count > 0 and reinit_count > 1:
|
if active_tasks_count > 0 and reinit_count > 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"limit_async: {active_tasks_count} tasks still running during reinitialization"
|
f"limit_async: {active_tasks_count} tasks still running during reinitialization"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create initial worker tasks, only adding the number needed
|
# Create worker tasks
|
||||||
workers_needed = max_size - active_tasks_count
|
workers_needed = max_size - active_tasks_count
|
||||||
for _ in range(workers_needed):
|
for _ in range(workers_needed):
|
||||||
task = asyncio.create_task(worker())
|
task = asyncio.create_task(worker())
|
||||||
tasks.add(task)
|
tasks.add(task)
|
||||||
task.add_done_callback(tasks.discard)
|
task.add_done_callback(tasks.discard)
|
||||||
|
|
||||||
# Start health check
|
# Start enhanced health check
|
||||||
worker_health_check_task = asyncio.create_task(health_check())
|
worker_health_check_task = asyncio.create_task(enhanced_health_check())
|
||||||
|
|
||||||
initialized = True
|
initialized = True
|
||||||
logger.info(f"limit_async: {workers_needed} new workers initialized")
|
# Log dynamic timeout configuration
|
||||||
|
timeout_info = []
|
||||||
|
if llm_timeout is not None:
|
||||||
|
timeout_info.append(f"LLM: {llm_timeout}s")
|
||||||
|
if max_execution_timeout is not None:
|
||||||
|
timeout_info.append(f"Execution: {max_execution_timeout}s")
|
||||||
|
if max_task_duration is not None:
|
||||||
|
timeout_info.append(f"Health Check: {max_task_duration}s")
|
||||||
|
|
||||||
|
timeout_str = (
|
||||||
|
f" (Timeouts: {', '.join(timeout_info)})" if timeout_info else ""
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"limit_async: {workers_needed} new workers initialized with dynamic timeout handling{timeout_str}"
|
||||||
|
)
|
||||||
|
|
||||||
async def shutdown():
|
async def shutdown():
|
||||||
"""Gracefully shut down all workers and the queue"""
|
"""Gracefully shut down all workers and cleanup resources"""
|
||||||
logger.info("limit_async: Shutting down priority queue workers")
|
logger.info("limit_async: Shutting down priority queue workers")
|
||||||
|
|
||||||
# Set the shutdown event
|
|
||||||
shutdown_event.set()
|
shutdown_event.set()
|
||||||
|
|
||||||
# Cancel all active futures
|
# Cancel all active futures
|
||||||
|
|
@ -502,7 +650,14 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
|
||||||
if not future.done():
|
if not future.done():
|
||||||
future.cancel()
|
future.cancel()
|
||||||
|
|
||||||
# Wait for the queue to empty
|
# Cancel all pending tasks
|
||||||
|
async with task_states_lock:
|
||||||
|
for task_id, task_state in list(task_states.items()):
|
||||||
|
if not task_state.future.done():
|
||||||
|
task_state.future.cancel()
|
||||||
|
task_states.clear()
|
||||||
|
|
||||||
|
# Wait for queue to empty with timeout
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(queue.join(), timeout=5.0)
|
await asyncio.wait_for(queue.join(), timeout=5.0)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|
@ -510,7 +665,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
|
||||||
"limit_async: Timeout waiting for queue to empty during shutdown"
|
"limit_async: Timeout waiting for queue to empty during shutdown"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Cancel all worker tasks
|
# Cancel worker tasks
|
||||||
for task in list(tasks):
|
for task in list(tasks):
|
||||||
if not task.done():
|
if not task.done():
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
@ -519,7 +674,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
|
||||||
if tasks:
|
if tasks:
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
# Cancel the health check task
|
# Cancel health check task
|
||||||
if worker_health_check_task and not worker_health_check_task.done():
|
if worker_health_check_task and not worker_health_check_task.done():
|
||||||
worker_health_check_task.cancel()
|
worker_health_check_task.cancel()
|
||||||
try:
|
try:
|
||||||
|
|
@ -534,77 +689,113 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
|
||||||
*args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs
|
*args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Execute the function with priority-based concurrency control
|
Execute function with enhanced priority-based concurrency control and timeout handling
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*args: Positional arguments passed to the function
|
*args: Positional arguments passed to the function
|
||||||
_priority: Call priority (lower values have higher priority)
|
_priority: Call priority (lower values have higher priority)
|
||||||
_timeout: Maximum time to wait for function completion (in seconds)
|
_timeout: Maximum time to wait for completion (in seconds, none means determinded by max_execution_timeout of the queue)
|
||||||
_queue_timeout: Maximum time to wait for entering the queue (in seconds)
|
_queue_timeout: Maximum time to wait for entering the queue (in seconds)
|
||||||
**kwargs: Keyword arguments passed to the function
|
**kwargs: Keyword arguments passed to the function
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The result of the function call
|
The result of the function call
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TimeoutError: If the function call times out
|
TimeoutError: If the function call times out at any level
|
||||||
QueueFullError: If the queue is full and waiting times out
|
QueueFullError: If the queue is full and waiting times out
|
||||||
Any exception raised by the decorated function
|
Any exception raised by the decorated function
|
||||||
"""
|
"""
|
||||||
# Ensure worker system is initialized
|
|
||||||
await ensure_workers()
|
await ensure_workers()
|
||||||
|
|
||||||
# Create a future for the result
|
# Generate unique task ID
|
||||||
|
task_id = f"{id(asyncio.current_task())}_{asyncio.get_event_loop().time()}"
|
||||||
future = asyncio.Future()
|
future = asyncio.Future()
|
||||||
active_futures.add(future)
|
|
||||||
|
|
||||||
nonlocal counter
|
# Create task state
|
||||||
async with initialization_lock:
|
task_state = TaskState(
|
||||||
current_count = counter # Use local variable to avoid race conditions
|
future=future, start_time=asyncio.get_event_loop().time()
|
||||||
counter += 1
|
)
|
||||||
|
|
||||||
# Try to put the task into the queue, supporting timeout
|
|
||||||
try:
|
try:
|
||||||
if _queue_timeout is not None:
|
# Register task state
|
||||||
# Use timeout to wait for queue space
|
async with task_states_lock:
|
||||||
try:
|
task_states[task_id] = task_state
|
||||||
|
|
||||||
|
active_futures.add(future)
|
||||||
|
|
||||||
|
# Get counter for FIFO ordering
|
||||||
|
nonlocal counter
|
||||||
|
async with initialization_lock:
|
||||||
|
current_count = counter
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
# Queue the task with timeout handling
|
||||||
|
try:
|
||||||
|
if _queue_timeout is not None:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
# current_count is used to ensure FIFO order
|
queue.put(
|
||||||
queue.put((_priority, current_count, future, args, kwargs)),
|
(_priority, current_count, task_id, args, kwargs)
|
||||||
|
),
|
||||||
timeout=_queue_timeout,
|
timeout=_queue_timeout,
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
else:
|
||||||
raise QueueFullError(
|
await queue.put(
|
||||||
f"Queue full, timeout after {_queue_timeout} seconds"
|
(_priority, current_count, task_id, args, kwargs)
|
||||||
)
|
)
|
||||||
else:
|
except asyncio.TimeoutError:
|
||||||
# No timeout, may wait indefinitely
|
raise QueueFullError(
|
||||||
# current_count is used to ensure FIFO order
|
f"Queue full, timeout after {_queue_timeout} seconds"
|
||||||
await queue.put((_priority, current_count, future, args, kwargs))
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Clean up the future
|
# Clean up on queue error
|
||||||
if not future.done():
|
if not future.done():
|
||||||
future.set_exception(e)
|
future.set_exception(e)
|
||||||
active_futures.discard(future)
|
raise
|
||||||
raise
|
|
||||||
|
|
||||||
try:
|
# Wait for result with timeout handling
|
||||||
# Wait for the result, optional timeout
|
try:
|
||||||
if _timeout is not None:
|
if _timeout is not None:
|
||||||
try:
|
|
||||||
return await asyncio.wait_for(future, _timeout)
|
return await asyncio.wait_for(future, _timeout)
|
||||||
except asyncio.TimeoutError:
|
else:
|
||||||
# Cancel the future
|
return await future
|
||||||
if not future.done():
|
except asyncio.TimeoutError:
|
||||||
future.cancel()
|
# This is user-level timeout (asyncio.wait_for caused)
|
||||||
raise TimeoutError(
|
# Mark cancellation request
|
||||||
f"limit_async: Task timed out after {_timeout} seconds"
|
async with task_states_lock:
|
||||||
)
|
if task_id in task_states:
|
||||||
else:
|
task_states[task_id].cancellation_requested = True
|
||||||
# Wait for the result without timeout
|
|
||||||
return await future
|
|
||||||
finally:
|
|
||||||
# Clean up the future reference
|
|
||||||
active_futures.discard(future)
|
|
||||||
|
|
||||||
# Add the shutdown method to the decorated function
|
# Cancel future
|
||||||
|
if not future.done():
|
||||||
|
future.cancel()
|
||||||
|
|
||||||
|
# Wait for worker cleanup with timeout
|
||||||
|
cleanup_start = asyncio.get_event_loop().time()
|
||||||
|
while (
|
||||||
|
task_id in task_states
|
||||||
|
and asyncio.get_event_loop().time() - cleanup_start
|
||||||
|
< cleanup_timeout
|
||||||
|
):
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
raise TimeoutError(
|
||||||
|
f"limit_async: User timeout after {_timeout} seconds"
|
||||||
|
)
|
||||||
|
except WorkerTimeoutError as e:
|
||||||
|
# This is Worker-level timeout, directly propagate exception information
|
||||||
|
raise TimeoutError(f"limit_async: {str(e)}")
|
||||||
|
except HealthCheckTimeoutError as e:
|
||||||
|
# This is Health Check-level timeout, directly propagate exception information
|
||||||
|
raise TimeoutError(f"limit_async: {str(e)}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Ensure cleanup
|
||||||
|
active_futures.discard(future)
|
||||||
|
async with task_states_lock:
|
||||||
|
task_states.pop(task_id, None)
|
||||||
|
|
||||||
|
# Add shutdown method to decorated function
|
||||||
wait_func.shutdown = shutdown
|
wait_func.shutdown = shutdown
|
||||||
|
|
||||||
return wait_func
|
return wait_func
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue