tasks work post refactor
This commit is contained in:
parent 13e4b971f1
commit 4d8748ec75
8 changed files with 188 additions and 119 deletions
@@ -96,9 +96,6 @@ class AppClients:
        # Initialize patched OpenAI client
        self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI())

        # Initialize Docling converter
        self.converter = DocumentConverter()

        return self

# Global clients instance
@@ -11,12 +11,13 @@ from .connection_manager import ConnectionManager
class ConnectorService:
    """Service to manage document connectors and process files"""

-    def __init__(self, opensearch_client, patched_async_client, process_pool, embed_model: str, index_name: str):
+    def __init__(self, opensearch_client, patched_async_client, process_pool, embed_model: str, index_name: str, task_service=None):
        self.opensearch = opensearch_client
        self.openai_client = patched_async_client
        self.process_pool = process_pool
        self.embed_model = embed_model
        self.index_name = index_name
+        self.task_service = task_service
        self.connection_manager = ConnectionManager()

    async def initialize(self):
@@ -113,6 +114,9 @@ class ConnectorService:

    async def sync_connector_files(self, connection_id: str, user_id: str, max_files: int = None) -> str:
        """Sync files from a connector connection using existing task tracking system"""
+        if not self.task_service:
+            raise ValueError("TaskService not available - connector sync requires task service dependency")
+
        print(f"[DEBUG] Starting sync for connection {connection_id}, max_files={max_files}")

        connector = await self.get_connector(connection_id)
@@ -155,104 +159,14 @@ class ConnectorService:
        if not files_to_process:
            raise ValueError("No files found to sync")

-        # Create upload task using existing task system
-        import uuid
-        from app import UploadTask, FileTask, TaskStatus, task_store, background_upload_processor
+        # Create custom processor for connector files
+        from models.processors import ConnectorFileProcessor
+        processor = ConnectorFileProcessor(self, connection_id, files_to_process)

-        task_id = str(uuid.uuid4())
-        upload_task = UploadTask(
-            task_id=task_id,
-            total_files=len(files_to_process),
-            file_tasks={f"connector_file_{file_info['id']}": FileTask(file_path=f"connector_file_{file_info['id']}") for file_info in files_to_process}
-        )
+        # Use file IDs as items (no more fake file paths!)
+        file_ids = [file_info['id'] for file_info in files_to_process]

-        # Store task for user
-        if user_id not in task_store:
-            task_store[user_id] = {}
-        task_store[user_id][task_id] = upload_task
+        # Create custom task using TaskService
+        task_id = await self.task_service.create_custom_task(user_id, file_ids, processor)

-        # Start background processing with connector-specific logic
-        import asyncio
-        from app import background_tasks
-        background_task = asyncio.create_task(self._background_connector_sync(user_id, task_id, connection_id, files_to_process))
-        background_tasks.add(background_task)
-        background_task.add_done_callback(background_tasks.discard)

        return task_id

-    async def _background_connector_sync(self, user_id: str, task_id: str, connection_id: str, files_to_process: List[Dict]):
-        """Background task to sync connector files"""
-        from app import task_store, TaskStatus
-        import datetime
-
-        try:
-            upload_task = task_store[user_id][task_id]
-            upload_task.status = TaskStatus.RUNNING
-            upload_task.updated_at = datetime.datetime.now().timestamp()
-
-            connector = await self.get_connector(connection_id)
-            if not connector:
-                raise ValueError(f"Connection '{connection_id}' not found")
-
-            # Process files with limited concurrency
-            semaphore = asyncio.Semaphore(4)  # Limit concurrent file processing
-
-            async def process_connector_file(file_info):
-                async with semaphore:
-                    file_key = f"connector_file_{file_info['id']}"
-                    file_task = upload_task.file_tasks[file_key]
-                    file_task.status = TaskStatus.RUNNING
-                    file_task.updated_at = datetime.datetime.now().timestamp()
-
-                    try:
-                        # Get file content from connector
-                        document = await connector.get_file_content(file_info['id'])
-
-                        # Process using existing pipeline
-                        result = await self.process_connector_document(document, user_id)
-
-                        file_task.status = TaskStatus.COMPLETED
-                        file_task.result = result
-                        upload_task.successful_files += 1
-
-                    except Exception as e:
-                        import sys
-                        import traceback
-
-                        error_msg = f"[ERROR] Failed to process connector file {file_info['id']}: {e}"
-                        print(error_msg, file=sys.stderr, flush=True)
-                        traceback.print_exc(file=sys.stderr)
-                        sys.stderr.flush()
-
-                        # Also store full traceback in task error
-                        full_error = f"{str(e)}\n{traceback.format_exc()}"
-                        file_task.status = TaskStatus.FAILED
-                        file_task.error = full_error
-                        upload_task.failed_files += 1
-                    finally:
-                        file_task.updated_at = datetime.datetime.now().timestamp()
-                        upload_task.processed_files += 1
-                        upload_task.updated_at = datetime.datetime.now().timestamp()
-
-            # Process all files concurrently
-            tasks = [process_connector_file(file_info) for file_info in files_to_process]
-            await asyncio.gather(*tasks, return_exceptions=True)
-
-            # Update connection last sync time
-            await self.connection_manager.update_last_sync(connection_id)
-
-            upload_task.status = TaskStatus.COMPLETED
-            upload_task.updated_at = datetime.datetime.now().timestamp()
-
-        except Exception as e:
-            import sys
-            import traceback
-
-            error_msg = f"[ERROR] Background connector sync failed for task {task_id}: {e}"
-            print(error_msg, file=sys.stderr, flush=True)
-            traceback.print_exc(file=sys.stderr)
-            sys.stderr.flush()
-
-            if user_id in task_store and task_id in task_store[user_id]:
-                task_store[user_id][task_id].status = TaskStatus.FAILED
-                task_store[user_id][task_id].updated_at = datetime.datetime.now().timestamp()
-            return task_id
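For orientation, here is a rough sketch of how the refactored sync is meant to be driven from calling code. It is not part of this commit; the connection ID, polling interval, and the keys in the dict returned by get_task_status are assumptions, but the two calls it exercises (sync_connector_files and get_task_status) are the ones touched in this diff.

import asyncio

async def run_connector_sync(connector_service, task_service, user_id: str) -> dict:
    # sync_connector_files now just registers a ConnectorFileProcessor with
    # TaskService and returns the task_id of the background task.
    task_id = await connector_service.sync_connector_files(
        connection_id="conn-123",  # placeholder connection ID
        user_id=user_id,
        max_files=10,
    )

    # Poll the same task-status API used for regular uploads.
    while True:
        status = task_service.get_task_status(user_id, task_id)
        # The exact keys of the returned dict are an assumption here.
        if status.get("status") in ("completed", "failed"):
            return status
        await asyncio.sleep(1.0)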
src/main.py (19 lines changed)
@@ -1,10 +1,18 @@
import asyncio
import atexit
-import torch
import multiprocessing
from functools import partial
from starlette.applications import Starlette
from starlette.routing import Route
+
+# Set multiprocessing start method to 'spawn' for CUDA compatibility
+multiprocessing.set_start_method('spawn', force=True)
+
+# Create process pool FIRST, before any torch/CUDA imports
+from utils.process_pool import process_pool
+
+import torch
+
# Configuration and setup
from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET
from utils.gpu_detection import detect_gpu_devices
@@ -65,19 +73,20 @@ def initialize_services():
    # Initialize services
    document_service = DocumentService()
    search_service = SearchService()
-    task_service = TaskService(document_service)
+    task_service = TaskService(document_service, process_pool)
    chat_service = ChatService()

    # Set process pool for document service
-    document_service.process_pool = task_service.process_pool
+    document_service.process_pool = process_pool

    # Initialize connector service
    connector_service = ConnectorService(
        opensearch_client=clients.opensearch,
        patched_async_client=clients.patched_async_client,
-        process_pool=task_service.process_pool,
+        process_pool=process_pool,
        embed_model="text-embedding-3-small",
-        index_name=INDEX_NAME
+        index_name=INDEX_NAME,
+        task_service=task_service
    )

    # Initialize auth service
src/models/processors.py (new file, 78 lines)
@@ -0,0 +1,78 @@
from abc import ABC, abstractmethod
from typing import Any, Dict
from .tasks import UploadTask, FileTask


class TaskProcessor(ABC):
    """Abstract base class for task processors"""

    @abstractmethod
    async def process_item(self, upload_task: UploadTask, item: Any, file_task: FileTask) -> None:
        """
        Process a single item in the task.

        Args:
            upload_task: The overall upload task
            item: The item to process (could be file path, file info, etc.)
            file_task: The specific file task to update
        """
        pass


class DocumentFileProcessor(TaskProcessor):
    """Default processor for regular file uploads"""

    def __init__(self, document_service):
        self.document_service = document_service

    async def process_item(self, upload_task: UploadTask, item: str, file_task: FileTask) -> None:
        """Process a regular file path using DocumentService"""
        # This calls the existing logic
        await self.document_service.process_single_file_task(upload_task, item)


class ConnectorFileProcessor(TaskProcessor):
    """Processor for connector file uploads"""

    def __init__(self, connector_service, connection_id: str, files_to_process: list):
        self.connector_service = connector_service
        self.connection_id = connection_id
        self.files_to_process = files_to_process
        # Create lookup map for file info
        self.file_info_map = {f['id']: f for f in files_to_process}

    async def process_item(self, upload_task: UploadTask, item: str, file_task: FileTask) -> None:
        """Process a connector file using ConnectorService"""
        from models.tasks import TaskStatus
        import time

        file_id = item  # item is the connector file ID
        file_info = self.file_info_map.get(file_id)

        if not file_info:
            raise ValueError(f"File info not found for {file_id}")

        # Get the connector
        connector = await self.connector_service.get_connector(self.connection_id)
        if not connector:
            raise ValueError(f"Connection '{self.connection_id}' not found")

        # Get file content from connector
        document = await connector.get_file_content(file_info['id'])

        # Get user_id from task store lookup
        user_id = None
        for uid, tasks in self.connector_service.task_service.task_store.items():
            if upload_task.task_id in tasks:
                user_id = uid
                break

        if not user_id:
            raise ValueError("Could not determine user_id for task")

        # Process using existing pipeline
        result = await self.connector_service.process_connector_document(document, user_id)

        file_task.status = TaskStatus.COMPLETED
        file_task.result = result
        upload_task.successful_files += 1
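The TaskProcessor ABC above is the seam this refactor introduces: anything that can turn an "item" into an indexed result can reuse the shared task bookkeeping. As a sketch of the intended extension pattern, here is a hypothetical third processor. It is not part of this commit; the injected indexer and its index_text method are invented for illustration, mirroring how document_service is injected into DocumentFileProcessor.

from models.processors import TaskProcessor
from models.tasks import TaskStatus, UploadTask, FileTask


class TextSnippetProcessor(TaskProcessor):
    """Hypothetical processor whose items are raw text snippets."""

    def __init__(self, indexer):
        # 'indexer' is a stand-in for whatever service actually stores the result.
        self.indexer = indexer

    async def process_item(self, upload_task: UploadTask, item: str, file_task: FileTask) -> None:
        result = await self.indexer.index_text(item)  # assumed async helper
        # Processors mark success themselves; failures are caught and recorded
        # by TaskService.background_custom_processor.
        file_task.status = TaskStatus.COMPLETED
        file_task.result = result
        upload_task.successful_files += 1

It would be started with await task_service.create_custom_task(user_id, snippets, TextSnippetProcessor(indexer)), exactly as ConnectorService does with ConnectorFileProcessor.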
@@ -132,7 +132,7 @@ class DocumentService:
        file_task.updated_at = time.time()

        try:
-            # Check if file already exists in index
+            # Handle regular file processing
            loop = asyncio.get_event_loop()

            # Run CPU-intensive docling processing in separate process
@@ -3,21 +3,21 @@ import uuid
import time
import random
from typing import Dict
-from concurrent.futures import ProcessPoolExecutor

from models.tasks import TaskStatus, UploadTask, FileTask
from utils.gpu_detection import get_worker_count

from src.utils.gpu_detection import get_worker_count


class TaskService:
-    def __init__(self, document_service=None):
+    def __init__(self, document_service=None, process_pool=None):
        self.document_service = document_service
+        self.process_pool = process_pool
        self.task_store: Dict[str, Dict[str, UploadTask]] = {}  # user_id -> {task_id -> UploadTask}
        self.background_tasks = set()

-        # Initialize process pool
-        max_workers = get_worker_count()
-        self.process_pool = ProcessPoolExecutor(max_workers=max_workers)
-        print(f"Process pool initialized with {max_workers} workers")
+        if self.process_pool is None:
+            raise ValueError("TaskService requires a process_pool parameter")

    async def exponential_backoff_delay(self, retry_count: int, base_delay: float = 1.0, max_delay: float = 60.0) -> None:
        """Apply exponential backoff with jitter"""
@@ -26,19 +26,29 @@ class TaskService:

    async def create_upload_task(self, user_id: str, file_paths: list) -> str:
        """Create a new upload task for bulk file processing"""
+        # Use default DocumentFileProcessor
+        from models.processors import DocumentFileProcessor
+        processor = DocumentFileProcessor(self.document_service)
+        return await self.create_custom_task(user_id, file_paths, processor)
+
+    async def create_custom_task(self, user_id: str, items: list, processor) -> str:
+        """Create a new task with custom processor for any type of items"""
        task_id = str(uuid.uuid4())
        upload_task = UploadTask(
            task_id=task_id,
-            total_files=len(file_paths),
-            file_tasks={path: FileTask(file_path=path) for path in file_paths}
+            total_files=len(items),
+            file_tasks={str(item): FileTask(file_path=str(item)) for item in items}
        )

+        # Attach the custom processor to the task
+        upload_task.processor = processor
+
        if user_id not in self.task_store:
            self.task_store[user_id] = {}
        self.task_store[user_id][task_id] = upload_task

        # Start background processing
-        background_task = asyncio.create_task(self.background_upload_processor(user_id, task_id))
+        background_task = asyncio.create_task(self.background_custom_processor(user_id, task_id, items))
        self.background_tasks.add(background_task)
        background_task.add_done_callback(self.background_tasks.discard)
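With create_upload_task reduced to a thin wrapper, the default upload path and connector syncs now converge on create_custom_task. A small usage sketch of the default path (the file paths are placeholders, not from this commit):

async def start_bulk_upload(task_service, user_id: str) -> str:
    # Placeholder paths. create_upload_task wraps them in a DocumentFileProcessor
    # and delegates to create_custom_task, so file_tasks are keyed by str(item)
    # whether items are local paths or connector file IDs.
    return await task_service.create_upload_task(
        user_id, ["/tmp/report.pdf", "/tmp/notes.docx"]
    )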
@@ -74,6 +84,58 @@ class TaskService:
            self.task_store[user_id][task_id].status = TaskStatus.FAILED
            self.task_store[user_id][task_id].updated_at = time.time()

+    async def background_custom_processor(self, user_id: str, task_id: str, items: list) -> None:
+        """Background task to process items using custom processor"""
+        try:
+            upload_task = self.task_store[user_id][task_id]
+            upload_task.status = TaskStatus.RUNNING
+            upload_task.updated_at = time.time()
+
+            processor = upload_task.processor
+
+            # Process items with limited concurrency
+            max_workers = get_worker_count()
+            semaphore = asyncio.Semaphore(max_workers * 2)
+
+            async def process_with_semaphore(item, item_key: str):
+                async with semaphore:
+                    file_task = upload_task.file_tasks[item_key]
+                    file_task.status = TaskStatus.RUNNING
+                    file_task.updated_at = time.time()
+
+                    try:
+                        await processor.process_item(upload_task, item, file_task)
+                    except Exception as e:
+                        print(f"[ERROR] Failed to process item {item}: {e}")
+                        import traceback
+                        traceback.print_exc()
+                        file_task.status = TaskStatus.FAILED
+                        file_task.error = str(e)
+                        upload_task.failed_files += 1
+                    finally:
+                        file_task.updated_at = time.time()
+                        upload_task.processed_files += 1
+                        upload_task.updated_at = time.time()
+
+            tasks = [
+                process_with_semaphore(item, str(item))
+                for item in items
+            ]
+
+            await asyncio.gather(*tasks, return_exceptions=True)
+
+            # Mark task as completed
+            upload_task.status = TaskStatus.COMPLETED
+            upload_task.updated_at = time.time()
+
+        except Exception as e:
+            print(f"[ERROR] Background custom processor failed for task {task_id}: {e}")
+            import traceback
+            traceback.print_exc()
+            if user_id in self.task_store and task_id in self.task_store[user_id]:
+                self.task_store[user_id][task_id].status = TaskStatus.FAILED
+                self.task_store[user_id][task_id].updated_at = time.time()
+
    def get_task_status(self, user_id: str, task_id: str) -> dict:
        """Get the status of a specific upload task"""
        if (not task_id or
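The concurrency shape of background_custom_processor is a semaphore sized at twice the worker count plus asyncio.gather with return_exceptions=True, so one bad item cannot cancel the rest. Here is a stripped-down, runnable sketch of just that pattern, with a sleep standing in for processor.process_item:

import asyncio

async def bounded_gather(items, max_workers: int = 4):
    # Cap in-flight work at max_workers * 2, mirroring the method above.
    semaphore = asyncio.Semaphore(max_workers * 2)

    async def process(item):
        async with semaphore:
            await asyncio.sleep(0.1)  # stand-in for processor.process_item(...)
            return item

    # return_exceptions=True collects failures instead of aborting the batch.
    return await asyncio.gather(*(process(i) for i in items), return_exceptions=True)

if __name__ == "__main__":
    print(asyncio.run(bounded_gather(range(10))))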
@@ -1,7 +1,6 @@
import hashlib
import os
from collections import defaultdict
from docling.document_converter import DocumentConverter
from .gpu_detection import detect_gpu_devices

# Global converter cache for worker processes
src/utils/process_pool.py (new file, 10 lines)
@@ -0,0 +1,10 @@
import os
from concurrent.futures import ProcessPoolExecutor
from utils.gpu_detection import get_worker_count

# Create shared process pool at import time (before CUDA initialization)
# This avoids the "Cannot re-initialize CUDA in forked subprocess" error
MAX_WORKERS = get_worker_count()
process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)

print(f"Shared process pool initialized with {MAX_WORKERS} workers")
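Building the pool at import time, before torch or anything CUDA-related is imported, is what makes the 'spawn' start method set in src/main.py effective: workers start from a clean interpreter instead of inheriting a half-initialized CUDA context. A sketch of how callers are expected to hand CPU-bound work to this shared pool (heavy_parse is a placeholder for the real docling conversion, not part of this commit):

import asyncio
from utils.process_pool import process_pool

def heavy_parse(path: str) -> int:
    # Placeholder for the CPU-bound docling work that runs in a worker process.
    return len(path)

async def parse_in_worker(path: str) -> int:
    loop = asyncio.get_event_loop()
    # Services receive this same pool via their constructors (see
    # initialize_services) and off-load blocking work without stalling the event loop.
    return await loop.run_in_executor(process_pool, heavy_parse, path)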