Refactor TaskService for improved readability and maintainability
This commit enhances the TaskService class by reorganizing import statements, updating type hints to use `dict` instead of `Dict`, and improving the formatting of method definitions for better clarity. Additionally, minor adjustments were made to comments and error handling, contributing to a more robust and well-documented codebase.
This commit is contained in:
parent
4be48270b7
commit
531ca7cd49
1 changed file with 79 additions and 74 deletions
|
|
@ -1,34 +1,43 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
|
||||||
import time
|
|
||||||
import random
|
import random
|
||||||
from typing import Dict
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
from models.tasks import TaskStatus, UploadTask, FileTask
|
from models.tasks import FileTask, TaskStatus, UploadTask
|
||||||
|
from utils.gpu_detection import get_worker_count
|
||||||
from src.utils.gpu_detection import get_worker_count
|
|
||||||
|
|
||||||
|
|
||||||
class TaskService:
    """In-memory service for tracking bulk file-processing tasks per user.

    Tasks are held in ``self.task_store`` (user_id -> {task_id -> UploadTask});
    heavy work is delegated to an externally supplied process pool.
    """

    def __init__(self, document_service=None, process_pool=None):
        """Initialize the service.

        Args:
            document_service: Collaborator used to process individual files
                (optional; only needed by the upload/processing paths).
            process_pool: Required executor for CPU-bound work. Despite the
                keyword default, ``None`` is rejected — the default exists
                only to keep the keyword-style call signature.

        Raises:
            ValueError: If ``process_pool`` is not provided.
        """
        self.document_service = document_service
        self.process_pool = process_pool
        # user_id -> {task_id -> UploadTask}
        self.task_store: dict[str, dict[str, UploadTask]] = {}
        # NOTE(review): presumably holds strong references to scheduled
        # asyncio background tasks so they are not garbage-collected
        # mid-flight — confirm against the code that populates it.
        self.background_tasks = set()

        # Fail fast: the service cannot operate without an executor.
        if self.process_pool is None:
            raise ValueError("TaskService requires a process_pool parameter")
||||||
async def exponential_backoff_delay(
    self, retry_count: int, base_delay: float = 1.0, max_delay: float = 60.0
) -> None:
    """Apply exponential backoff with jitter.

    Sleeps for ``min(base_delay * 2**retry_count + U(0, 1), max_delay)``
    seconds, where the uniform jitter de-synchronizes concurrent retriers.

    Args:
        retry_count: Zero-based retry attempt number.
        base_delay: Base delay in seconds for the first attempt.
        max_delay: Hard cap on the computed delay, in seconds.
    """
    delay = min(base_delay * (2**retry_count) + random.uniform(0, 1), max_delay)
    await asyncio.sleep(delay)
||||||
|
|
||||||
async def create_upload_task(
    self, user_id: str, file_paths: list, jwt_token: str = None, owner_name: str = None, owner_email: str = None
) -> str:
    """Create a new upload task for bulk file processing.

    Builds a ``DocumentFileProcessor`` bound to the caller's identity and
    delegates to :meth:`create_custom_task`.

    Args:
        user_id: Owner of the new task; also recorded on the processor.
        file_paths: Items (file paths) to process.
        jwt_token: Optional auth token forwarded to the processor.
        owner_name: Optional display name forwarded to the processor.
        owner_email: Optional email forwarded to the processor.

    Returns:
        The new task's id, as returned by ``create_custom_task``.
    """
    # Deferred import: processors depend on this module's collaborators,
    # so importing at call time avoids a circular import at module load.
    from models.processors import DocumentFileProcessor

    processor = DocumentFileProcessor(
        self.document_service,
        owner_user_id=user_id,
        jwt_token=jwt_token,
        owner_name=owner_name,
        owner_email=owner_email,
    )
    return await self.create_custom_task(user_id, file_paths, processor)
|
||||||
|
|
||||||
async def create_custom_task(self, user_id: str, items: list, processor) -> str:
|
async def create_custom_task(self, user_id: str, items: list, processor) -> str:
|
||||||
|
|
@ -37,7 +46,7 @@ class TaskService:
|
||||||
upload_task = UploadTask(
|
upload_task = UploadTask(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
total_files=len(items),
|
total_files=len(items),
|
||||||
file_tasks={str(item): FileTask(file_path=str(item)) for item in items}
|
file_tasks={str(item): FileTask(file_path=str(item)) for item in items},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attach the custom processor to the task
|
# Attach the custom processor to the task
|
||||||
|
|
@ -72,16 +81,14 @@ class TaskService:
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
await self.document_service.process_single_file_task(upload_task, file_path)
|
await self.document_service.process_single_file_task(upload_task, file_path)
|
||||||
|
|
||||||
tasks = [
|
tasks = [process_with_semaphore(file_path) for file_path in upload_task.file_tasks.keys()]
|
||||||
process_with_semaphore(file_path)
|
|
||||||
for file_path in upload_task.file_tasks.keys()
|
|
||||||
]
|
|
||||||
|
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[ERROR] Background upload processor failed for task {task_id}: {e}")
|
print(f"[ERROR] Background upload processor failed for task {task_id}: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
if user_id in self.task_store and task_id in self.task_store[user_id]:
|
if user_id in self.task_store and task_id in self.task_store[user_id]:
|
||||||
self.task_store[user_id][task_id].status = TaskStatus.FAILED
|
self.task_store[user_id][task_id].status = TaskStatus.FAILED
|
||||||
|
|
@ -111,6 +118,7 @@ class TaskService:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[ERROR] Failed to process item {item}: {e}")
|
print(f"[ERROR] Failed to process item {item}: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
file_task.status = TaskStatus.FAILED
|
file_task.status = TaskStatus.FAILED
|
||||||
file_task.error = str(e)
|
file_task.error = str(e)
|
||||||
|
|
@ -120,10 +128,7 @@ class TaskService:
|
||||||
upload_task.processed_files += 1
|
upload_task.processed_files += 1
|
||||||
upload_task.updated_at = time.time()
|
upload_task.updated_at = time.time()
|
||||||
|
|
||||||
tasks = [
|
tasks = [process_with_semaphore(item, str(item)) for item in items]
|
||||||
process_with_semaphore(item, str(item))
|
|
||||||
for item in items
|
|
||||||
]
|
|
||||||
|
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
|
@ -140,6 +145,7 @@ class TaskService:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[ERROR] Background custom processor failed for task {task_id}: {e}")
|
print(f"[ERROR] Background custom processor failed for task {task_id}: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
if user_id in self.task_store and task_id in self.task_store[user_id]:
|
if user_id in self.task_store and task_id in self.task_store[user_id]:
|
||||||
self.task_store[user_id][task_id].status = TaskStatus.FAILED
|
self.task_store[user_id][task_id].status = TaskStatus.FAILED
|
||||||
|
|
@ -147,9 +153,7 @@ class TaskService:
|
||||||
|
|
||||||
def get_task_status(self, user_id: str, task_id: str) -> dict:
|
def get_task_status(self, user_id: str, task_id: str) -> dict:
|
||||||
"""Get the status of a specific upload task"""
|
"""Get the status of a specific upload task"""
|
||||||
if (not task_id or
|
if not task_id or user_id not in self.task_store or task_id not in self.task_store[user_id]:
|
||||||
user_id not in self.task_store or
|
|
||||||
task_id not in self.task_store[user_id]):
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
upload_task = self.task_store[user_id][task_id]
|
upload_task = self.task_store[user_id][task_id]
|
||||||
|
|
@ -162,7 +166,7 @@ class TaskService:
|
||||||
"error": file_task.error,
|
"error": file_task.error,
|
||||||
"retry_count": file_task.retry_count,
|
"retry_count": file_task.retry_count,
|
||||||
"created_at": file_task.created_at,
|
"created_at": file_task.created_at,
|
||||||
"updated_at": file_task.updated_at
|
"updated_at": file_task.updated_at,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -174,7 +178,7 @@ class TaskService:
|
||||||
"failed_files": upload_task.failed_files,
|
"failed_files": upload_task.failed_files,
|
||||||
"created_at": upload_task.created_at,
|
"created_at": upload_task.created_at,
|
||||||
"updated_at": upload_task.updated_at,
|
"updated_at": upload_task.updated_at,
|
||||||
"files": file_statuses
|
"files": file_statuses,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_all_tasks(self, user_id: str) -> list:
|
def get_all_tasks(self, user_id: str) -> list:
|
||||||
|
|
@ -184,16 +188,18 @@ class TaskService:
|
||||||
|
|
||||||
tasks = []
|
tasks = []
|
||||||
for task_id, upload_task in self.task_store[user_id].items():
|
for task_id, upload_task in self.task_store[user_id].items():
|
||||||
tasks.append({
|
tasks.append(
|
||||||
"task_id": upload_task.task_id,
|
{
|
||||||
"status": upload_task.status.value,
|
"task_id": upload_task.task_id,
|
||||||
"total_files": upload_task.total_files,
|
"status": upload_task.status.value,
|
||||||
"processed_files": upload_task.processed_files,
|
"total_files": upload_task.total_files,
|
||||||
"successful_files": upload_task.successful_files,
|
"processed_files": upload_task.processed_files,
|
||||||
"failed_files": upload_task.failed_files,
|
"successful_files": upload_task.successful_files,
|
||||||
"created_at": upload_task.created_at,
|
"failed_files": upload_task.failed_files,
|
||||||
"updated_at": upload_task.updated_at
|
"created_at": upload_task.created_at,
|
||||||
})
|
"updated_at": upload_task.updated_at,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Sort by creation time, most recent first
|
# Sort by creation time, most recent first
|
||||||
tasks.sort(key=lambda x: x["created_at"], reverse=True)
|
tasks.sort(key=lambda x: x["created_at"], reverse=True)
|
||||||
|
|
@ -201,8 +207,7 @@ class TaskService:
|
||||||
|
|
||||||
def cancel_task(self, user_id: str, task_id: str) -> bool:
|
def cancel_task(self, user_id: str, task_id: str) -> bool:
|
||||||
"""Cancel a task if it exists and is not already completed"""
|
"""Cancel a task if it exists and is not already completed"""
|
||||||
if (user_id not in self.task_store or
|
if user_id not in self.task_store or task_id not in self.task_store[user_id]:
|
||||||
task_id not in self.task_store[user_id]):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
upload_task = self.task_store[user_id][task_id]
|
upload_task = self.task_store[user_id][task_id]
|
||||||
|
|
@ -212,7 +217,7 @@ class TaskService:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Cancel the background task to stop scheduling new work
|
# Cancel the background task to stop scheduling new work
|
||||||
if hasattr(upload_task, 'background_task') and not upload_task.background_task.done():
|
if hasattr(upload_task, "background_task") and not upload_task.background_task.done():
|
||||||
upload_task.background_task.cancel()
|
upload_task.background_task.cancel()
|
||||||
|
|
||||||
# Mark task as failed (cancelled)
|
# Mark task as failed (cancelled)
|
||||||
|
|
@ -231,5 +236,5 @@ class TaskService:
|
||||||
|
|
||||||
def shutdown(self):
    """Cleanup the process pool, blocking until pending work finishes.

    The ``hasattr`` guard makes this safe to call even if ``__init__``
    raised before the attribute was assigned.
    """
    if hasattr(self, "process_pool"):
        self.process_pool.shutdown(wait=True)
|
||||||
Loading…
Add table
Reference in a new issue