From ea24c81eab2a1f0176dcb4c65bdaaa65d679d236 Mon Sep 17 00:00:00 2001
From: phact
Date: Thu, 24 Jul 2025 01:51:45 -0400
Subject: [PATCH] multi worker uploads + docling gpu improvements

---
 .gitignore                      |   2 +
 docker-compose.yml              |   3 +
 frontend/src/app/admin/page.tsx |  78 ++++++-
 frontend/src/app/chat/page.tsx  | 171 ++++++++++++--
 src/app.py                      | 383 ++++++++++++++++++++++++++++++--
 5 files changed, 602 insertions(+), 35 deletions(-)

diff --git a/.gitignore b/.gitignore
index 53ed2bd6..32f3d866 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,5 @@ wheels/
 
 .env
 .idea/
+
+1001*.pdf
diff --git a/docker-compose.yml b/docker-compose.yml
index 3e2ef7ed..9f5f44d4 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -40,6 +40,8 @@ services:
       - OPENSEARCH_USERNAME=admin
       - OPENSEARCH_PASSWORD=${OPENSEARCH_PASSWORD}
       - OPENAI_API_KEY=${OPENAI_API_KEY}
+      - NVIDIA_DRIVER_CAPABILITIES=compute,utility
+      - NVIDIA_VISIBLE_DEVICES=all
     ports:
       - "3000:3000"
     volumes:
@@ -48,6 +50,7 @@ services:
       - ./pyproject.toml:/app/pyproject.toml
       - ./uv.lock:/app/uv.lock
      - ./documents:/app/documents
+    gpus: all
 
   langflow:
     volumes:
diff --git a/frontend/src/app/admin/page.tsx b/frontend/src/app/admin/page.tsx
index c7ad77e8..6588f230 100644
--- a/frontend/src/app/admin/page.tsx
+++ b/frontend/src/app/admin/page.tsx
@@ -66,9 +66,24 @@ export default function AdminPage() {
 
       const result = await response.json()
 
-      if (response.ok) {
-        const successful = result.results.filter((r: {status: string}) => r.status === "indexed").length
-        const total = result.results.length
+      if (response.status === 201) {
+        // New flow: got a task ID, start polling
+        const taskId = result.task_id || result.id
+        const totalFiles = result.total_files || 0
+
+        if (!taskId) {
+          throw new Error("No task ID received from server")
+        }
+
+        setUploadStatus(`🔄 Processing started for ${totalFiles} files... (Task ID: ${taskId})`)
+
+        // Start polling the task status
+        await pollPathTaskStatus(taskId, totalFiles)
+
+      } else if (response.ok) {
+        // Original flow: direct response with results
+        const successful = result.results?.filter((r: {status: string}) => r.status === "indexed").length || 0
+        const total = result.results?.length || 0
         setUploadStatus(`Path processed successfully! ${successful}/${total} files indexed.`)
         setFolderPath("")
       } else {
@@ -81,6 +96,63 @@
     }
   }
 
+  const pollPathTaskStatus = async (taskId: string, totalFiles: number) => {
+    const maxAttempts = 120 // Poll for up to 10 minutes (120 * 5s intervals) for large batches
+    let attempts = 0
+
+    const poll = async (): Promise<void> => {
+    try {
+      attempts++
+
+      const response = await fetch(`/api/tasks/${taskId}`)
+
+      if (!response.ok) {
+        throw new Error(`Failed to check task status: ${response.status}`)
+      }
+
+      const task = await response.json()
+
+      if (task.status === 'completed') {
+        setUploadStatus(`✅ Path processing completed! ${task.successful_files}/${task.total_files} files processed successfully.`)
+        setFolderPath("")
+        setPathUploadLoading(false)
+
+      } else if (task.status === 'failed' || task.status === 'error') {
+        setUploadStatus(`❌ Path processing failed: ${task.error || 'Unknown error occurred'}`)
+        setPathUploadLoading(false)
+
+      } else if (task.status === 'pending' || task.status === 'running') {
+        // Still in progress, update status and continue polling
+        const processed = task.processed_files || 0
+        const successful = task.successful_files || 0
+        const failed = task.failed_files || 0
+
+        setUploadStatus(`⏳ Processing files... 
${processed}/${totalFiles} processed (${successful} successful, ${failed} failed)`) + + // Continue polling if we haven't exceeded max attempts + if (attempts < maxAttempts) { + setTimeout(poll, 5000) // Poll every 5 seconds + } else { + setUploadStatus(`⚠️ Processing timeout after ${attempts} attempts. The task may still be running in the background.`) + setPathUploadLoading(false) + } + + } else { + setUploadStatus(`❓ Unknown task status: ${task.status}`) + setPathUploadLoading(false) + } + + } catch (error) { + console.error('Task polling error:', error) + setUploadStatus(`❌ Failed to check processing status: ${error instanceof Error ? error.message : 'Unknown error'}`) + setPathUploadLoading(false) + } + } + + // Start polling immediately + poll() + } + return (
diff --git a/frontend/src/app/chat/page.tsx b/frontend/src/app/chat/page.tsx index ed634339..342fcf25 100644 --- a/frontend/src/app/chat/page.tsx +++ b/frontend/src/app/chat/page.tsx @@ -64,10 +64,20 @@ export default function ChatPage() { } const handleFileUpload = async (file: File) => { + console.log("handleFileUpload called with file:", file.name) + if (isUploading) return setIsUploading(true) + // Add initial upload message + const uploadStartMessage: Message = { + role: "assistant", + content: `🔄 Starting upload of **${file.name}**...`, + timestamp: new Date() + } + setMessages(prev => [...prev, uploadStartMessage]) + try { const formData = new FormData() formData.append('file', file) @@ -84,27 +94,58 @@ export default function ChatPage() { body: formData, }) + console.log("Upload response status:", response.status) + if (!response.ok) { - throw new Error(`Upload failed: ${response.status}`) + const errorText = await response.text() + console.error("Upload failed with status:", response.status, "Response:", errorText) + throw new Error(`Upload failed: ${response.status} - ${errorText}`) } const result = await response.json() + console.log("Upload result:", result) - // Add upload confirmation as a system message in the UI - const uploadMessage: Message = { - role: "assistant", - content: `📄 Document uploaded: **${result.filename}** (${result.pages} pages, ${result.content_length.toLocaleString()} characters)\n\n${result.confirmation}`, - timestamp: new Date() - } - - setMessages(prev => [...prev, uploadMessage]) - - // Update the response ID for this endpoint - if (result.response_id) { - setPreviousResponseIds(prev => ({ - ...prev, - [endpoint]: result.response_id - })) + if (response.status === 201) { + // New flow: Got task ID, start polling + const taskId = result.task_id || result.id + + if (!taskId) { + console.error("No task ID in 201 response:", result) + throw new Error("No task ID received from server") + } + + // Update message to show polling started + const pollingMessage: Message = { + role: "assistant", + content: `⏳ Upload initiated for **${file.name}**. Processing... (Task ID: ${taskId})`, + timestamp: new Date() + } + setMessages(prev => [...prev.slice(0, -1), pollingMessage]) + + // Start polling the task status + await pollTaskStatus(taskId, file.name) + + } else if (response.ok) { + // Original flow: Direct response + + const uploadMessage: Message = { + role: "assistant", + content: `📄 Document uploaded: **${result.filename}** (${result.pages} pages, ${result.content_length.toLocaleString()} characters)\n\n${result.confirmation}`, + timestamp: new Date() + } + + setMessages(prev => [...prev.slice(0, -1), uploadMessage]) + + // Update the response ID for this endpoint + if (result.response_id) { + setPreviousResponseIds(prev => ({ + ...prev, + [endpoint]: result.response_id + })) + } + + } else { + throw new Error(`Upload failed: ${response.status}`) } } catch (error) { @@ -114,12 +155,108 @@ export default function ChatPage() { content: `❌ Upload failed: ${error instanceof Error ? 
error.message : 'Unknown error'}`,
         timestamp: new Date()
       }
-      setMessages(prev => [...prev, errorMessage])
+      setMessages(prev => [...prev.slice(0, -1), errorMessage])
     } finally {
       setIsUploading(false)
     }
   }
 
+  const pollTaskStatus = async (taskId: string, filename: string) => {
+    const maxAttempts = 60 // Poll for up to 5 minutes (60 * 5s intervals)
+    let attempts = 0
+
+    const poll = async (): Promise<void> => {
+      try {
+        attempts++
+
+        const response = await fetch(`/api/tasks/${taskId}`)
+        console.log("Task polling response status:", response.status)
+
+        if (!response.ok) {
+          const errorText = await response.text()
+          console.error("Task polling failed:", response.status, errorText)
+          throw new Error(`Failed to check task status: ${response.status} - ${errorText}`)
+        }
+
+        const task = await response.json()
+        console.log("Task polling result:", task)
+
+        // Safety check to ensure task object exists
+        if (!task) {
+          throw new Error("No task data received from server")
+        }
+
+        // Update the message based on task status
+        if (task.status === 'completed') {
+          const successMessage: Message = {
+            role: "assistant",
+            content: `✅ **${filename}** processed successfully!\n\n${task.result?.confirmation || 'Document has been added to the knowledge base.'}`,
+            timestamp: new Date()
+          }
+          setMessages(prev => [...prev.slice(0, -1), successMessage])
+
+          // Update response ID if available
+          if (task.result?.response_id) {
+            setPreviousResponseIds(prev => ({
+              ...prev,
+              [endpoint]: task.result.response_id
+            }))
+          }
+
+        } else if (task.status === 'failed' || task.status === 'error') {
+          const errorMessage: Message = {
+            role: "assistant",
+            content: `❌ Processing failed for **${filename}**: ${task.error || 'Unknown error occurred'}`,
+            timestamp: new Date()
+          }
+          setMessages(prev => [...prev.slice(0, -1), errorMessage])
+
+        } else if (task.status === 'pending' || task.status === 'running' || task.status === 'processing') {
+          // Still in progress, update message and continue polling
+          const progressMessage: Message = {
+            role: "assistant",
+            content: `⏳ Processing **${filename}**... (${task.status}) - Attempt ${attempts}/${maxAttempts}`,
+            timestamp: new Date()
+          }
+          setMessages(prev => [...prev.slice(0, -1), progressMessage])
+
+          // Continue polling if we haven't exceeded max attempts
+          if (attempts < maxAttempts) {
+            setTimeout(poll, 5000) // Poll every 5 seconds
+          } else {
+            const timeoutMessage: Message = {
+              role: "assistant",
+              content: `⚠️ Processing timeout for **${filename}**. The task may still be running in the background.`,
+              timestamp: new Date()
+            }
+            setMessages(prev => [...prev.slice(0, -1), timeoutMessage])
+          }
+
+        } else {
+          // Unknown status
+          const unknownMessage: Message = {
+            role: "assistant",
+            content: `❓ Unknown status for **${filename}**: ${task.status}`,
+            timestamp: new Date()
+          }
+          setMessages(prev => [...prev.slice(0, -1), unknownMessage])
+        }
+
+      } catch (error) {
+        console.error('Task polling error:', error)
+        const errorMessage: Message = {
+          role: "assistant",
+          content: `❌ Failed to check processing status for **${filename}**: ${error instanceof Error ? error.message : 'Unknown error'}`,
+          timestamp: new Date()
+        }
+        setMessages(prev => [...prev.slice(0, -1), errorMessage])
+      }
+    }
+
+    // Start polling immediately
+    poll()
+  }
+
   const handleDragEnter = (e: React.DragEvent) => {
     e.preventDefault()
     e.stopPropagation()
diff --git a/src/app.py b/src/app.py
index f9d45ceb..285de4e4 100644
--- a/src/app.py
+++ b/src/app.py
@@ -3,11 +3,16 @@
 import os
 from collections import defaultdict
 from typing import Any
+import uuid
+import time
+import random
+from dataclasses import dataclass, field
+from enum import Enum
+from concurrent.futures import ProcessPoolExecutor
+import multiprocessing
 
 from agent import async_chat, async_langflow
 
-os.environ['USE_CPU_ONLY'] = 'true'
-
 import hashlib
 import tempfile
 import asyncio
@@ -29,6 +34,10 @@
 from dotenv import load_dotenv
 load_dotenv()
 load_dotenv("../")
 
+import torch
+print("CUDA available:", torch.cuda.is_available())
+print("CUDA version PyTorch was built with:", torch.version.cuda)
+
 # Initialize Docling converter
 converter = DocumentConverter()  # basic converter; tweak via PipelineOptions if you need OCR, etc.
@@ -43,7 +52,7 @@
 langflow_key = os.getenv("LANGFLOW_SECRET_KEY")
 
 
-es = AsyncOpenSearch(
+opensearch = AsyncOpenSearch(
     hosts=[{"host": opensearch_host, "port": opensearch_port}],
     connection_class=AIOHttpConnection,
     scheme="https",
@@ -93,6 +102,183 @@
 langflow_client = AsyncOpenAI(
 )
 patched_async_client = patch_openai_with_mcp(AsyncOpenAI())  # Get the patched client back
 
+class TaskStatus(Enum):
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+@dataclass
+class FileTask:
+    file_path: str
+    status: TaskStatus = TaskStatus.PENDING
+    result: dict = None
+    error: str = None
+    retry_count: int = 0
+    created_at: float = field(default_factory=time.time)
+    updated_at: float = field(default_factory=time.time)
+
+@dataclass
+class UploadTask:
+    task_id: str
+    total_files: int
+    processed_files: int = 0
+    successful_files: int = 0
+    failed_files: int = 0
+    file_tasks: dict = field(default_factory=dict)
+    status: TaskStatus = TaskStatus.PENDING
+    created_at: float = field(default_factory=time.time)
+    updated_at: float = field(default_factory=time.time)
+
+task_store = {}
+background_tasks = set()
+
+# GPU device detection
+def detect_gpu_devices():
+    """Detect if GPU devices are actually available"""
+    try:
+        import torch
+        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+            return True, torch.cuda.device_count()
+    except ImportError:
+        pass
+
+    try:
+        import subprocess
+        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
+        if result.returncode == 0:
+            return True, "detected"
+    except (subprocess.SubprocessError, FileNotFoundError):
+        pass
+
+    return False, 0
+
+# GPU and concurrency configuration
+HAS_GPU_DEVICES, GPU_COUNT = detect_gpu_devices()
+
+if HAS_GPU_DEVICES:
+    # GPU mode with actual GPU devices: lower concurrency due to GPU memory constraints
+    DEFAULT_WORKERS = min(4, multiprocessing.cpu_count() // 2)
+    print(f"GPU mode enabled with {GPU_COUNT} GPU(s) - using limited concurrency ({DEFAULT_WORKERS} workers)")
+else:
+    # CPU-only mode (no GPU devices detected): full concurrency, no GPU memory constraints
+    DEFAULT_WORKERS = multiprocessing.cpu_count()
+    print(f"CPU-only mode enabled - using full concurrency ({DEFAULT_WORKERS} workers)")
+
+MAX_WORKERS = int(os.getenv("MAX_WORKERS", DEFAULT_WORKERS))
+process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)
+
+print(f"Process pool initialized with {MAX_WORKERS} workers")
+
+# Global converter cache for worker processes
+_worker_converter = None
+
+def get_worker_converter():
+    """Get or create a DocumentConverter instance for this worker process"""
+    global _worker_converter
+    if _worker_converter is None:
+        import os
+        from docling.document_converter import DocumentConverter
+
+        # Configure GPU settings for this worker
+        has_gpu_devices, _ = detect_gpu_devices()
+        if not has_gpu_devices:
+            # Force CPU-only mode in subprocess
+            os.environ['USE_CPU_ONLY'] = 'true'
+            os.environ['CUDA_VISIBLE_DEVICES'] = ''
+            os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
+            os.environ['TRANSFORMERS_OFFLINE'] = '0'
+            os.environ['TORCH_USE_CUDA_DSA'] = '0'
+
+            # Try to disable CUDA in torch if available
+            try:
+                import torch
+                torch.cuda.is_available = lambda: False
+            except ImportError:
+                pass
+        else:
+            # GPU mode - let libraries use GPU if available
+            os.environ.pop('USE_CPU_ONLY', None)
+            os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'  # Still disable progress bars
+
+        print(f"🔧 Initializing DocumentConverter in worker process (PID: {os.getpid()})")
+        _worker_converter = DocumentConverter()
+        print(f"✅ DocumentConverter ready in worker process (PID: {os.getpid()})")
+
+    return _worker_converter
+
+def process_document_sync(file_path: str):
+    """Synchronous document processing function for multiprocessing"""
+    import hashlib
+    from collections import defaultdict
+
+    # Get the cached converter for this worker
+    converter = get_worker_converter()
+
+    # Compute file hash
+    sha256 = hashlib.sha256()
+    with open(file_path, "rb") as f:
+        while True:
+            chunk = f.read(1 << 20)
+            if not chunk:
+                break
+            sha256.update(chunk)
+    file_hash = sha256.hexdigest()
+
+    # Convert with docling
+    result = converter.convert(file_path)
+    full_doc = result.document.export_to_dict()
+
+    # Extract relevant content (same logic as extract_relevant)
+    origin = full_doc.get("origin", {})
+    texts = full_doc.get("texts", [])
+
+    page_texts = defaultdict(list)
+    for txt in texts:
+        prov = txt.get("prov", [])
+        page_no = prov[0].get("page_no") if prov else None
+        if page_no is not None:
+            page_texts[page_no].append(txt.get("text", "").strip())
+
+    chunks = []
+    for page in sorted(page_texts):
+        joined = "\n".join(page_texts[page])
+        chunks.append({
+            "page": page,
+            "text": joined
+        })
+
+    return {
+        "id": file_hash,
+        "filename": origin.get("filename"),
+        "mimetype": origin.get("mimetype"),
+        "chunks": chunks,
+        "file_path": file_path
+    }
+
 async def wait_for_opensearch():
     """Wait for OpenSearch to be ready with retries"""
     max_retries = 30
@@ -100,7 +286,7 @@
 
     for attempt in range(max_retries):
         try:
-            await es.info()
+            await opensearch.info()
             print("OpenSearch is ready!")
             return
         except Exception as e:
@@ -113,8 +299,8 @@
 async 
def init_index(): await wait_for_opensearch() - if not await es.indices.exists(index=INDEX_NAME): - await es.indices.create(index=INDEX_NAME, body=index_body) + if not await opensearch.indices.exists(index=INDEX_NAME): + await opensearch.indices.create(index=INDEX_NAME, body=index_body) print(f"Created index '{INDEX_NAME}'") else: print(f"Index '{INDEX_NAME}' already exists, skipping creation.") @@ -155,6 +341,26 @@ def extract_relevant(doc_dict: dict) -> dict: "chunks": chunks } +async def exponential_backoff_delay(retry_count: int, base_delay: float = 1.0, max_delay: float = 60.0) -> None: + """Apply exponential backoff with jitter""" + delay = min(base_delay * (2 ** retry_count) + random.uniform(0, 1), max_delay) + await asyncio.sleep(delay) + +async def process_file_with_retry(file_path: str, max_retries: int = 3) -> dict: + """Process a file with retry logic - retries everything up to max_retries times""" + last_error = None + + for attempt in range(max_retries + 1): + try: + return await process_file_common(file_path) + except Exception as e: + last_error = e + if attempt < max_retries: + await exponential_backoff_delay(attempt) + continue + else: + raise last_error + async def process_file_common(file_path: str, file_hash: str = None): """ Common processing logic for both upload and upload_path. @@ -173,7 +379,7 @@ async def process_file_common(file_path: str, file_hash: str = None): sha256.update(chunk) file_hash = sha256.hexdigest() - exists = await es.exists(index=INDEX_NAME, id=file_hash) + exists = await opensearch.exists(index=INDEX_NAME, id=file_hash) if exists: return {"status": "unchanged", "id": file_hash} @@ -199,7 +405,7 @@ async def process_file_common(file_path: str, file_hash: str = None): "chunk_embedding": vect } chunk_id = f"{file_hash}_{i}" - await es.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc) + await opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc) return {"status": "indexed", "id": file_hash} async def process_file_on_disk(path: str): @@ -210,6 +416,94 @@ async def process_file_on_disk(path: str): result["path"] = path return result +async def process_single_file_task(upload_task: UploadTask, file_path: str) -> None: + """Process a single file and update task tracking""" + file_task = upload_task.file_tasks[file_path] + file_task.status = TaskStatus.RUNNING + file_task.updated_at = time.time() + + try: + # Check if file already exists in index + import asyncio + loop = asyncio.get_event_loop() + + # Run CPU-intensive docling processing in separate process + slim_doc = await loop.run_in_executor(process_pool, process_document_sync, file_path) + + # Check if already indexed + exists = await opensearch.exists(index=INDEX_NAME, id=slim_doc["id"]) + if exists: + result = {"status": "unchanged", "id": slim_doc["id"]} + else: + # Generate embeddings and index (I/O bound, keep in main process) + texts = [c["text"] for c in slim_doc["chunks"]] + resp = await patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts) + embeddings = [d.embedding for d in resp.data] + + # Index each chunk + for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)): + chunk_doc = { + "document_id": slim_doc["id"], + "filename": slim_doc["filename"], + "mimetype": slim_doc["mimetype"], + "page": chunk["page"], + "text": chunk["text"], + "chunk_embedding": vect + } + chunk_id = f"{slim_doc['id']}_{i}" + await opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc) + + result = {"status": "indexed", "id": slim_doc["id"]} + + 
result["path"] = file_path + file_task.status = TaskStatus.COMPLETED + file_task.result = result + upload_task.successful_files += 1 + + except Exception as e: + print(f"[ERROR] Failed to process file {file_path}: {e}") + import traceback + traceback.print_exc() + file_task.status = TaskStatus.FAILED + file_task.error = str(e) + upload_task.failed_files += 1 + finally: + file_task.updated_at = time.time() + upload_task.processed_files += 1 + upload_task.updated_at = time.time() + + if upload_task.processed_files >= upload_task.total_files: + upload_task.status = TaskStatus.COMPLETED + +async def background_upload_processor(task_id: str) -> None: + """Background task to process all files in an upload job with concurrency control""" + try: + upload_task = task_store[task_id] + upload_task.status = TaskStatus.RUNNING + upload_task.updated_at = time.time() + + # Process files with limited concurrency to avoid overwhelming the system + semaphore = asyncio.Semaphore(MAX_WORKERS * 2) # Allow 2x process pool size for async I/O + + async def process_with_semaphore(file_path: str): + async with semaphore: + await process_single_file_task(upload_task, file_path) + + tasks = [ + process_with_semaphore(file_path) + for file_path in upload_task.file_tasks.keys() + ] + + await asyncio.gather(*tasks, return_exceptions=True) + + except Exception as e: + print(f"[ERROR] Background upload processor failed for task {task_id}: {e}") + import traceback + traceback.print_exc() + if task_id in task_store: + task_store[task_id].status = TaskStatus.FAILED + task_store[task_id].updated_at = time.time() + async def upload(request: Request): form = await request.form() upload_file = form["file"] @@ -226,7 +520,7 @@ async def upload(request: Request): tmp.flush() file_hash = sha256.hexdigest() - exists = await es.exists(index=INDEX_NAME, id=file_hash) + exists = await opensearch.exists(index=INDEX_NAME, id=file_hash) if exists: return JSONResponse({"status": "unchanged", "id": file_hash}) @@ -243,12 +537,31 @@ async def upload_path(request: Request): if not base_dir or not os.path.isdir(base_dir): return JSONResponse({"error": "Invalid path"}, status_code=400) - tasks = [process_file_on_disk(os.path.join(root, fn)) - for root, _, files in os.walk(base_dir) - for fn in files] + file_paths = [os.path.join(root, fn) + for root, _, files in os.walk(base_dir) + for fn in files] + + if not file_paths: + return JSONResponse({"error": "No files found in directory"}, status_code=400) - results = await asyncio.gather(*tasks) - return JSONResponse({"results": results}) + task_id = str(uuid.uuid4()) + upload_task = UploadTask( + task_id=task_id, + total_files=len(file_paths), + file_tasks={path: FileTask(file_path=path) for path in file_paths} + ) + + task_store[task_id] = upload_task + + background_task = asyncio.create_task(background_upload_processor(task_id)) + background_tasks.add(background_task) + background_task.add_done_callback(background_tasks.discard) + + return JSONResponse({ + "task_id": task_id, + "total_files": len(file_paths), + "status": "accepted" + }, status_code=201) async def upload_context(request: Request): """Upload a file and add its content as context to the current conversation""" @@ -306,6 +619,38 @@ async def upload_context(request: Request): return JSONResponse(response_data) +async def task_status(request: Request): + """Get the status of an upload task""" + task_id = request.path_params.get("task_id") + + if not task_id or task_id not in task_store: + return JSONResponse({"error": "Task not 
found"}, status_code=404)
+
+    upload_task = task_store[task_id]
+
+    file_statuses = {}
+    for file_path, file_task in upload_task.file_tasks.items():
+        file_statuses[file_path] = {
+            "status": file_task.status.value,
+            "result": file_task.result,
+            "error": file_task.error,
+            "retry_count": file_task.retry_count,
+            "created_at": file_task.created_at,
+            "updated_at": file_task.updated_at
+        }
+
+    return JSONResponse({
+        "task_id": upload_task.task_id,
+        "status": upload_task.status.value,
+        "total_files": upload_task.total_files,
+        "processed_files": upload_task.processed_files,
+        "successful_files": upload_task.successful_files,
+        "failed_files": upload_task.failed_files,
+        "created_at": upload_task.created_at,
+        "updated_at": upload_task.updated_at,
+        "files": file_statuses
+    })
+
 async def search(request: Request):
     payload = await request.json()
@@ -345,7 +690,7 @@ async def search_tool(query: str)-> dict[str, Any]:
         "_source": ["filename", "mimetype", "page", "text"],
         "size": 10
     }
-    results = await es.search(index=INDEX_NAME, body=search_body)
+    results = await opensearch.search(index=INDEX_NAME, body=search_body)
     # Transform results to match expected format
     chunks = []
     for hit in results["hits"]["hits"]:
@@ -425,6 +770,7 @@ app = Starlette(debug=True, routes=[
     Route("/upload", upload, methods=["POST"]),
     Route("/upload_context", upload_context, methods=["POST"]),
     Route("/upload_path", upload_path, methods=["POST"]),
+    Route("/tasks/{task_id}", task_status, methods=["GET"]),
     Route("/search", search, methods=["POST"]),
     Route("/chat", chat_endpoint, methods=["POST"]),
     Route("/langflow", langflow_endpoint, methods=["POST"]),
@@ -432,10 +778,17 @@
 
 if __name__ == "__main__":
     import uvicorn
+    import atexit
 
     async def main():
         await init_index()
 
+    # Cleanup process pool on exit
+    def cleanup():
+        process_pool.shutdown(wait=True)
+
+    atexit.register(cleanup)
+
     asyncio.run(main())
 
     uvicorn.run(
         "app:app",
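-- 

Usage note (illustrative, not part of the patch): a minimal sketch of how a client could drive the asynchronous flow this patch introduces — POST a directory to /upload_path, receive a 201 response carrying a task_id, then poll the new /tasks/{task_id} route until the task reports completed or failed. The base URL, the "path" key in the request body, and the use of the requests library are assumptions; the response fields (task_id, total_files, processed_files, successful_files, failed_files, status) match the task_status handler added in src/app.py.

import time
import requests

BASE_URL = "http://localhost:8000"  # assumption: wherever src/app.py is served

# Kick off background processing of every file under a server-side directory
resp = requests.post(f"{BASE_URL}/upload_path", json={"path": "/app/documents"})  # "path" key is an assumption
resp.raise_for_status()
task = resp.json()
task_id = task["task_id"]
print(f"Accepted {task['total_files']} files as task {task_id}")

# Poll the task endpoint added in this patch until processing finishes
while True:
    status = requests.get(f"{BASE_URL}/tasks/{task_id}").json()
    print(f"{status['processed_files']}/{status['total_files']} processed "
          f"({status['successful_files']} ok, {status['failed_files']} failed)")
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(5)  # the frontend polls on the same 5-second interval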