diff --git a/src/api/tasks.py b/src/api/tasks.py index de4bf505..92779d09 100644 --- a/src/api/tasks.py +++ b/src/api/tasks.py @@ -26,7 +26,7 @@ async def cancel_task(request: Request, task_service, session_manager): task_id = request.path_params.get("task_id") user = request.state.user - success = task_service.cancel_task(user.user_id, task_id) + success = await task_service.cancel_task(user.user_id, task_id) if not success: return JSONResponse( {"error": "Task not found or cannot be cancelled"}, status_code=400 diff --git a/src/services/task_service.py b/src/services/task_service.py index 3052ba4f..be5312a0 100644 --- a/src/services/task_service.py +++ b/src/services/task_service.py @@ -352,7 +352,7 @@ class TaskService: tasks.sort(key=lambda x: x["created_at"], reverse=True) return tasks - def cancel_task(self, user_id: str, task_id: str) -> bool: + async def cancel_task(self, user_id: str, task_id: str) -> bool: """Cancel a task if it exists and is not already completed. Supports cancellation of shared default tasks stored under the anonymous user. @@ -384,18 +384,28 @@ class TaskService: and not upload_task.background_task.done() ): upload_task.background_task.cancel() + # Wait for the background task to actually stop to avoid race conditions + try: + await upload_task.background_task + except asyncio.CancelledError: + pass # Expected when we cancel the task + except Exception: + pass # Ignore other errors during cancellation # Mark task as failed (cancelled) upload_task.status = TaskStatus.FAILED upload_task.updated_at = time.time() - # Mark all pending file tasks as failed + # Mark all pending and running file tasks as failed for file_task in upload_task.file_tasks.values(): - if file_task.status == TaskStatus.PENDING: + if file_task.status in [TaskStatus.PENDING, TaskStatus.RUNNING]: + # Increment failed_files counter for both pending and running + # (running files haven't been counted yet in either counter) + upload_task.failed_files += 1 + file_task.status = TaskStatus.FAILED file_task.error = "Task cancelled by user" file_task.updated_at = time.time() - upload_task.failed_files += 1 return True