cancel task endpoint
This commit is contained in:
parent
c1eb61a973
commit
a282f2a9f8
4 changed files with 99 additions and 21 deletions
|
|
@ -2,38 +2,40 @@ from starlette.requests import Request
|
|||
from starlette.responses import JSONResponse
|
||||
|
||||
async def connector_sync(request: Request, connector_service, session_manager):
|
||||
"""Sync files from a connector connection"""
|
||||
"""Sync files from all active connections of a connector type"""
|
||||
connector_type = request.path_params.get("connector_type", "google_drive")
|
||||
data = await request.json()
|
||||
connection_id = data.get("connection_id")
|
||||
max_files = data.get("max_files")
|
||||
|
||||
if not connection_id:
|
||||
return JSONResponse({"error": "connection_id is required"}, status_code=400)
|
||||
|
||||
try:
|
||||
print(f"[DEBUG] Starting connector sync for connection_id={connection_id}, max_files={max_files}")
|
||||
print(f"[DEBUG] Starting connector sync for connector_type={connector_type}, max_files={max_files}")
|
||||
|
||||
# Verify user owns this connection
|
||||
user = request.state.user
|
||||
print(f"[DEBUG] User: {user.user_id}")
|
||||
|
||||
connection_config = await connector_service.connection_manager.get_connection(connection_id)
|
||||
print(f"[DEBUG] Got connection config: {connection_config is not None}")
|
||||
# Get all active connections for this connector type and user
|
||||
connections = await connector_service.connection_manager.list_connections(
|
||||
user_id=user.user_id,
|
||||
connector_type=connector_type
|
||||
)
|
||||
|
||||
if not connection_config:
|
||||
return JSONResponse({"error": "Connection not found"}, status_code=404)
|
||||
active_connections = [conn for conn in connections if conn.is_active]
|
||||
if not active_connections:
|
||||
return JSONResponse({"error": f"No active {connector_type} connections found"}, status_code=404)
|
||||
|
||||
if connection_config.user_id != user.user_id:
|
||||
return JSONResponse({"error": "Access denied"}, status_code=403)
|
||||
|
||||
print(f"[DEBUG] About to call sync_connector_files")
|
||||
task_id = await connector_service.sync_connector_files(connection_id, user.user_id, max_files)
|
||||
print(f"[DEBUG] Got task_id: {task_id}")
|
||||
# Start sync tasks for all active connections
|
||||
task_ids = []
|
||||
for connection in active_connections:
|
||||
print(f"[DEBUG] About to call sync_connector_files for connection {connection.connection_id}")
|
||||
task_id = await connector_service.sync_connector_files(connection.connection_id, user.user_id, max_files)
|
||||
task_ids.append(task_id)
|
||||
print(f"[DEBUG] Got task_id: {task_id}")
|
||||
|
||||
return JSONResponse({
|
||||
"task_id": task_id,
|
||||
"task_ids": task_ids,
|
||||
"status": "sync_started",
|
||||
"message": f"Started syncing files from connection {connection_id}"
|
||||
"message": f"Started syncing files from {len(active_connections)} {connector_type} connection(s)",
|
||||
"connections_synced": len(active_connections)
|
||||
},
|
||||
status_code=201
|
||||
)
|
||||
|
|
|
|||
30
src/api/tasks.py
Normal file
30
src/api/tasks.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
async def task_status(request: Request, task_service, session_manager):
|
||||
"""Get the status of a specific task"""
|
||||
task_id = request.path_params.get("task_id")
|
||||
user = request.state.user
|
||||
|
||||
task_status_result = task_service.get_task_status(user.user_id, task_id)
|
||||
if not task_status_result:
|
||||
return JSONResponse({"error": "Task not found"}, status_code=404)
|
||||
|
||||
return JSONResponse(task_status_result)
|
||||
|
||||
async def all_tasks(request: Request, task_service, session_manager):
|
||||
"""Get all tasks for the authenticated user"""
|
||||
user = request.state.user
|
||||
tasks = task_service.get_all_tasks(user.user_id)
|
||||
return JSONResponse({"tasks": tasks})
|
||||
|
||||
async def cancel_task(request: Request, task_service, session_manager):
|
||||
"""Cancel a task"""
|
||||
task_id = request.path_params.get("task_id")
|
||||
user = request.state.user
|
||||
|
||||
success = task_service.cancel_task(user.user_id, task_id)
|
||||
if not success:
|
||||
return JSONResponse({"error": "Task not found or cannot be cancelled"}, status_code=400)
|
||||
|
||||
return JSONResponse({"status": "cancelled", "task_id": task_id})
|
||||
11
src/main.py
11
src/main.py
|
|
@ -145,6 +145,13 @@ def create_app():
|
|||
session_manager=services['session_manager'])
|
||||
), methods=["GET"]),
|
||||
|
||||
Route("/tasks/{task_id}/cancel",
|
||||
require_auth(services['session_manager'])(
|
||||
partial(tasks.cancel_task,
|
||||
task_service=services['task_service'],
|
||||
session_manager=services['session_manager'])
|
||||
), methods=["POST"]),
|
||||
|
||||
# Search endpoint
|
||||
Route("/search",
|
||||
require_auth(services['session_manager'])(
|
||||
|
|
@ -197,14 +204,14 @@ def create_app():
|
|||
), methods=["POST"]),
|
||||
|
||||
# Connector endpoints
|
||||
Route("/connectors/sync",
|
||||
Route("/connectors/{connector_type}/sync",
|
||||
require_auth(services['session_manager'])(
|
||||
partial(connectors.connector_sync,
|
||||
connector_service=services['connector_service'],
|
||||
session_manager=services['session_manager'])
|
||||
), methods=["POST"]),
|
||||
|
||||
Route("/connectors/status/{connector_type}",
|
||||
Route("/connectors/{connector_type}/status",
|
||||
require_auth(services['session_manager'])(
|
||||
partial(connectors.connector_status,
|
||||
connector_service=services['connector_service'],
|
||||
|
|
|
|||
|
|
@ -52,6 +52,9 @@ class TaskService:
|
|||
self.background_tasks.add(background_task)
|
||||
background_task.add_done_callback(self.background_tasks.discard)
|
||||
|
||||
# Store reference to background task for cancellation
|
||||
upload_task.background_task = background_task
|
||||
|
||||
return task_id
|
||||
|
||||
async def background_upload_processor(self, user_id: str, task_id: str) -> None:
|
||||
|
|
@ -128,6 +131,12 @@ class TaskService:
|
|||
upload_task.status = TaskStatus.COMPLETED
|
||||
upload_task.updated_at = time.time()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
print(f"[INFO] Background processor for task {task_id} was cancelled")
|
||||
if user_id in self.task_store and task_id in self.task_store[user_id]:
|
||||
# Task status and pending files already handled by cancel_task()
|
||||
pass
|
||||
raise # Re-raise to properly handle cancellation
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Background custom processor failed for task {task_id}: {e}")
|
||||
import traceback
|
||||
|
|
@ -190,6 +199,36 @@ class TaskService:
|
|||
tasks.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
return tasks
|
||||
|
||||
def cancel_task(self, user_id: str, task_id: str) -> bool:
|
||||
"""Cancel a task if it exists and is not already completed"""
|
||||
if (user_id not in self.task_store or
|
||||
task_id not in self.task_store[user_id]):
|
||||
return False
|
||||
|
||||
upload_task = self.task_store[user_id][task_id]
|
||||
|
||||
# Can only cancel pending or running tasks
|
||||
if upload_task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
|
||||
return False
|
||||
|
||||
# Cancel the background task to stop scheduling new work
|
||||
if hasattr(upload_task, 'background_task') and not upload_task.background_task.done():
|
||||
upload_task.background_task.cancel()
|
||||
|
||||
# Mark task as failed (cancelled)
|
||||
upload_task.status = TaskStatus.FAILED
|
||||
upload_task.updated_at = time.time()
|
||||
|
||||
# Mark all pending file tasks as failed
|
||||
for file_task in upload_task.file_tasks.values():
|
||||
if file_task.status == TaskStatus.PENDING:
|
||||
file_task.status = TaskStatus.FAILED
|
||||
file_task.error = "Task cancelled by user"
|
||||
file_task.updated_at = time.time()
|
||||
upload_task.failed_files += 1
|
||||
|
||||
return True
|
||||
|
||||
def shutdown(self):
|
||||
"""Cleanup process pool"""
|
||||
if hasattr(self, 'process_pool'):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue