cancel task endpoint

phact 2025-07-30 22:42:01 -04:00
parent c1eb61a973
commit a282f2a9f8
4 changed files with 99 additions and 21 deletions

@@ -2,38 +2,40 @@ from starlette.requests import Request
 from starlette.responses import JSONResponse
 async def connector_sync(request: Request, connector_service, session_manager):
-    """Sync files from a connector connection"""
+    """Sync files from all active connections of a connector type"""
+    connector_type = request.path_params.get("connector_type", "google_drive")
     data = await request.json()
-    connection_id = data.get("connection_id")
     max_files = data.get("max_files")
-    if not connection_id:
-        return JSONResponse({"error": "connection_id is required"}, status_code=400)
     try:
-        print(f"[DEBUG] Starting connector sync for connection_id={connection_id}, max_files={max_files}")
+        print(f"[DEBUG] Starting connector sync for connector_type={connector_type}, max_files={max_files}")
-        # Verify user owns this connection
         user = request.state.user
         print(f"[DEBUG] User: {user.user_id}")
-        connection_config = await connector_service.connection_manager.get_connection(connection_id)
-        print(f"[DEBUG] Got connection config: {connection_config is not None}")
+        # Get all active connections for this connector type and user
+        connections = await connector_service.connection_manager.list_connections(
+            user_id=user.user_id,
+            connector_type=connector_type
+        )
-        if not connection_config:
-            return JSONResponse({"error": "Connection not found"}, status_code=404)
+        active_connections = [conn for conn in connections if conn.is_active]
+        if not active_connections:
+            return JSONResponse({"error": f"No active {connector_type} connections found"}, status_code=404)
-        if connection_config.user_id != user.user_id:
-            return JSONResponse({"error": "Access denied"}, status_code=403)
-        print(f"[DEBUG] About to call sync_connector_files")
-        task_id = await connector_service.sync_connector_files(connection_id, user.user_id, max_files)
-        print(f"[DEBUG] Got task_id: {task_id}")
+        # Start sync tasks for all active connections
+        task_ids = []
+        for connection in active_connections:
+            print(f"[DEBUG] About to call sync_connector_files for connection {connection.connection_id}")
+            task_id = await connector_service.sync_connector_files(connection.connection_id, user.user_id, max_files)
+            task_ids.append(task_id)
+            print(f"[DEBUG] Got task_id: {task_id}")
         return JSONResponse({
-            "task_id": task_id,
+            "task_ids": task_ids,
             "status": "sync_started",
-            "message": f"Started syncing files from connection {connection_id}"
+            "message": f"Started syncing files from {len(active_connections)} {connector_type} connection(s)",
+            "connections_synced": len(active_connections)
         },
         status_code=201
         )

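With this change the connector type comes from the URL path and the request body only carries optional sync options, so a client no longer sends a connection_id. A hedged sketch of a client call against the reworked endpoint, assuming a local deployment and a bearer-style session header (base URL, header, and connector type are placeholders, not part of this commit):

import requests

# Placeholder base URL and auth header; real values depend on the deployment.
resp = requests.post(
    "http://localhost:8000/connectors/google_drive/sync",
    headers={"Authorization": "Bearer <session-token>"},
    json={"max_files": 25},   # connection_id is no longer required
)
print(resp.status_code)   # 201 when at least one sync task was started
print(resp.json())        # {"task_ids": [...], "status": "sync_started", ...}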
src/api/tasks.py (new file, +30 lines)

@@ -0,0 +1,30 @@
+from starlette.requests import Request
+from starlette.responses import JSONResponse
+async def task_status(request: Request, task_service, session_manager):
+    """Get the status of a specific task"""
+    task_id = request.path_params.get("task_id")
+    user = request.state.user
+    task_status_result = task_service.get_task_status(user.user_id, task_id)
+    if not task_status_result:
+        return JSONResponse({"error": "Task not found"}, status_code=404)
+    return JSONResponse(task_status_result)
+async def all_tasks(request: Request, task_service, session_manager):
+    """Get all tasks for the authenticated user"""
+    user = request.state.user
+    tasks = task_service.get_all_tasks(user.user_id)
+    return JSONResponse({"tasks": tasks})
+async def cancel_task(request: Request, task_service, session_manager):
+    """Cancel a task"""
+    task_id = request.path_params.get("task_id")
+    user = request.state.user
+    success = task_service.cancel_task(user.user_id, task_id)
+    if not success:
+        return JSONResponse({"error": "Task not found or cannot be cancelled"}, status_code=400)
+    return JSONResponse({"status": "cancelled", "task_id": task_id})

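The new handlers in src/api/tasks.py read only the task_id path parameter and the authenticated user from request.state, so the client surface is small. A rough client sketch, assuming the status route is mounted at /tasks/{task_id} and using placeholder URL, auth header, and task id:

import requests

base = "http://localhost:8000"                          # placeholder
headers = {"Authorization": "Bearer <session-token>"}   # placeholder
task_id = "<task-id-from-an-earlier-response>"          # placeholder

status = requests.get(f"{base}/tasks/{task_id}", headers=headers)
print(status.json())                        # task status payload, or {"error": "Task not found"}

cancel = requests.post(f"{base}/tasks/{task_id}/cancel", headers=headers)
print(cancel.status_code, cancel.json())    # 200 {"status": "cancelled", ...}, or 400 if not cancellable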
@@ -145,6 +145,13 @@ def create_app():
                     session_manager=services['session_manager'])
             ), methods=["GET"]),
+        Route("/tasks/{task_id}/cancel",
+            require_auth(services['session_manager'])(
+                partial(tasks.cancel_task,
+                    task_service=services['task_service'],
+                    session_manager=services['session_manager'])
+            ), methods=["POST"]),
         # Search endpoint
         Route("/search",
             require_auth(services['session_manager'])(
@@ -197,14 +204,14 @@ def create_app():
             ), methods=["POST"]),
         # Connector endpoints
-        Route("/connectors/sync",
+        Route("/connectors/{connector_type}/sync",
             require_auth(services['session_manager'])(
                 partial(connectors.connector_sync,
                     connector_service=services['connector_service'],
                     session_manager=services['session_manager'])
             ), methods=["POST"]),
-        Route("/connectors/status/{connector_type}",
+        Route("/connectors/{connector_type}/status",
             require_auth(services['session_manager'])(
                 partial(connectors.connector_status,
                     connector_service=services['connector_service'],

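The new cancel route relies on Starlette's path parameters: the {task_id} segment is parsed by the router and exposed to the handler via request.path_params, which is how cancel_task in src/api/tasks.py reads it. A minimal, standalone Starlette sketch of that mechanism (illustrative only, not this repo's create_app):

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route

async def cancel(request: Request):
    # The {task_id} path segment shows up in request.path_params.
    return JSONResponse({"status": "cancelled", "task_id": request.path_params["task_id"]})

app = Starlette(routes=[
    Route("/tasks/{task_id}/cancel", cancel, methods=["POST"]),
])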
@@ -52,6 +52,9 @@ class TaskService:
         self.background_tasks.add(background_task)
         background_task.add_done_callback(self.background_tasks.discard)
+        # Store reference to background task for cancellation
+        upload_task.background_task = background_task
         return task_id
     async def background_upload_processor(self, user_id: str, task_id: str) -> None:
@@ -128,6 +131,12 @@ class TaskService:
             upload_task.status = TaskStatus.COMPLETED
             upload_task.updated_at = time.time()
+        except asyncio.CancelledError:
+            print(f"[INFO] Background processor for task {task_id} was cancelled")
+            if user_id in self.task_store and task_id in self.task_store[user_id]:
+                # Task status and pending files already handled by cancel_task()
+                pass
+            raise  # Re-raise to properly handle cancellation
         except Exception as e:
             print(f"[ERROR] Background custom processor failed for task {task_id}: {e}")
             import traceback
@@ -190,6 +199,36 @@ class TaskService:
         tasks.sort(key=lambda x: x["created_at"], reverse=True)
         return tasks
+    def cancel_task(self, user_id: str, task_id: str) -> bool:
+        """Cancel a task if it exists and is not already completed"""
+        if (user_id not in self.task_store or
+                task_id not in self.task_store[user_id]):
+            return False
+        upload_task = self.task_store[user_id][task_id]
+        # Can only cancel pending or running tasks
+        if upload_task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
+            return False
+        # Cancel the background task to stop scheduling new work
+        if hasattr(upload_task, 'background_task') and not upload_task.background_task.done():
+            upload_task.background_task.cancel()
+        # Mark task as failed (cancelled)
+        upload_task.status = TaskStatus.FAILED
+        upload_task.updated_at = time.time()
+        # Mark all pending file tasks as failed
+        for file_task in upload_task.file_tasks.values():
+            if file_task.status == TaskStatus.PENDING:
+                file_task.status = TaskStatus.FAILED
+                file_task.error = "Task cancelled by user"
+                file_task.updated_at = time.time()
+                upload_task.failed_files += 1
+        return True
     def shutdown(self):
         """Cleanup process pool"""
         if hasattr(self, 'process_pool'):
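
For context, TaskService.cancel_task leans on standard asyncio cancellation: it keeps a handle to the background task, calls cancel() on it, and the processor catches CancelledError to note the cancellation before re-raising. A self-contained sketch of that pattern (illustrative only, not code from this repo):

import asyncio

async def worker():
    try:
        await asyncio.sleep(3600)       # stands in for the background upload processor
    except asyncio.CancelledError:
        print("worker observed cancellation, cleaning up")
        raise                           # re-raise so the task ends in the cancelled state

async def main():
    task = asyncio.create_task(worker())   # analogous to upload_task.background_task
    await asyncio.sleep(0)                 # give the worker a chance to start
    task.cancel()                          # analogous to TaskService.cancel_task()
    try:
        await task
    except asyncio.CancelledError:
        print("task cancelled:", task.cancelled())

asyncio.run(main())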