diff --git a/src/api/connectors.py b/src/api/connectors.py index 20c71f08..96390f4c 100644 --- a/src/api/connectors.py +++ b/src/api/connectors.py @@ -2,38 +2,40 @@ from starlette.requests import Request from starlette.responses import JSONResponse async def connector_sync(request: Request, connector_service, session_manager): - """Sync files from a connector connection""" + """Sync files from all active connections of a connector type""" + connector_type = request.path_params.get("connector_type", "google_drive") data = await request.json() - connection_id = data.get("connection_id") max_files = data.get("max_files") - if not connection_id: - return JSONResponse({"error": "connection_id is required"}, status_code=400) - try: - print(f"[DEBUG] Starting connector sync for connection_id={connection_id}, max_files={max_files}") + print(f"[DEBUG] Starting connector sync for connector_type={connector_type}, max_files={max_files}") - # Verify user owns this connection user = request.state.user print(f"[DEBUG] User: {user.user_id}") - connection_config = await connector_service.connection_manager.get_connection(connection_id) - print(f"[DEBUG] Got connection config: {connection_config is not None}") + # Get all active connections for this connector type and user + connections = await connector_service.connection_manager.list_connections( + user_id=user.user_id, + connector_type=connector_type + ) - if not connection_config: - return JSONResponse({"error": "Connection not found"}, status_code=404) + active_connections = [conn for conn in connections if conn.is_active] + if not active_connections: + return JSONResponse({"error": f"No active {connector_type} connections found"}, status_code=404) - if connection_config.user_id != user.user_id: - return JSONResponse({"error": "Access denied"}, status_code=403) - - print(f"[DEBUG] About to call sync_connector_files") - task_id = await connector_service.sync_connector_files(connection_id, user.user_id, max_files) - print(f"[DEBUG] Got task_id: {task_id}") + # Start sync tasks for all active connections + task_ids = [] + for connection in active_connections: + print(f"[DEBUG] About to call sync_connector_files for connection {connection.connection_id}") + task_id = await connector_service.sync_connector_files(connection.connection_id, user.user_id, max_files) + task_ids.append(task_id) + print(f"[DEBUG] Got task_id: {task_id}") return JSONResponse({ - "task_id": task_id, + "task_ids": task_ids, "status": "sync_started", - "message": f"Started syncing files from connection {connection_id}" + "message": f"Started syncing files from {len(active_connections)} {connector_type} connection(s)", + "connections_synced": len(active_connections) }, status_code=201 ) diff --git a/src/api/tasks.py b/src/api/tasks.py new file mode 100644 index 00000000..0d07837b --- /dev/null +++ b/src/api/tasks.py @@ -0,0 +1,30 @@ +from starlette.requests import Request +from starlette.responses import JSONResponse + +async def task_status(request: Request, task_service, session_manager): + """Get the status of a specific task""" + task_id = request.path_params.get("task_id") + user = request.state.user + + task_status_result = task_service.get_task_status(user.user_id, task_id) + if not task_status_result: + return JSONResponse({"error": "Task not found"}, status_code=404) + + return JSONResponse(task_status_result) + +async def all_tasks(request: Request, task_service, session_manager): + """Get all tasks for the authenticated user""" + user = request.state.user + tasks = task_service.get_all_tasks(user.user_id) + return JSONResponse({"tasks": tasks}) + +async def cancel_task(request: Request, task_service, session_manager): + """Cancel a task""" + task_id = request.path_params.get("task_id") + user = request.state.user + + success = task_service.cancel_task(user.user_id, task_id) + if not success: + return JSONResponse({"error": "Task not found or cannot be cancelled"}, status_code=400) + + return JSONResponse({"status": "cancelled", "task_id": task_id}) \ No newline at end of file diff --git a/src/main.py b/src/main.py index 10a8059a..2cb65679 100644 --- a/src/main.py +++ b/src/main.py @@ -145,6 +145,13 @@ def create_app(): session_manager=services['session_manager']) ), methods=["GET"]), + Route("/tasks/{task_id}/cancel", + require_auth(services['session_manager'])( + partial(tasks.cancel_task, + task_service=services['task_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + # Search endpoint Route("/search", require_auth(services['session_manager'])( @@ -197,14 +204,14 @@ def create_app(): ), methods=["POST"]), # Connector endpoints - Route("/connectors/sync", + Route("/connectors/{connector_type}/sync", require_auth(services['session_manager'])( partial(connectors.connector_sync, connector_service=services['connector_service'], session_manager=services['session_manager']) ), methods=["POST"]), - Route("/connectors/status/{connector_type}", + Route("/connectors/{connector_type}/status", require_auth(services['session_manager'])( partial(connectors.connector_status, connector_service=services['connector_service'], diff --git a/src/services/task_service.py b/src/services/task_service.py index 715bb51d..960d8ff2 100644 --- a/src/services/task_service.py +++ b/src/services/task_service.py @@ -52,6 +52,9 @@ class TaskService: self.background_tasks.add(background_task) background_task.add_done_callback(self.background_tasks.discard) + # Store reference to background task for cancellation + upload_task.background_task = background_task + return task_id async def background_upload_processor(self, user_id: str, task_id: str) -> None: @@ -128,6 +131,12 @@ class TaskService: upload_task.status = TaskStatus.COMPLETED upload_task.updated_at = time.time() + except asyncio.CancelledError: + print(f"[INFO] Background processor for task {task_id} was cancelled") + if user_id in self.task_store and task_id in self.task_store[user_id]: + # Task status and pending files already handled by cancel_task() + pass + raise # Re-raise to properly handle cancellation except Exception as e: print(f"[ERROR] Background custom processor failed for task {task_id}: {e}") import traceback @@ -190,6 +199,36 @@ class TaskService: tasks.sort(key=lambda x: x["created_at"], reverse=True) return tasks + def cancel_task(self, user_id: str, task_id: str) -> bool: + """Cancel a task if it exists and is not already completed""" + if (user_id not in self.task_store or + task_id not in self.task_store[user_id]): + return False + + upload_task = self.task_store[user_id][task_id] + + # Can only cancel pending or running tasks + if upload_task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]: + return False + + # Cancel the background task to stop scheduling new work + if hasattr(upload_task, 'background_task') and not upload_task.background_task.done(): + upload_task.background_task.cancel() + + # Mark task as failed (cancelled) + upload_task.status = TaskStatus.FAILED + upload_task.updated_at = time.time() + + # Mark all pending file tasks as failed + for file_task in upload_task.file_tasks.values(): + if file_task.status == TaskStatus.PENDING: + file_task.status = TaskStatus.FAILED + file_task.error = "Task cancelled by user" + file_task.updated_at = time.time() + upload_task.failed_files += 1 + + return True + def shutdown(self): """Cleanup process pool""" if hasattr(self, 'process_pool'):