multi worker uploads + docling gpu improvements
This commit is contained in:
parent
b159164593
commit
ea24c81eab
5 changed files with 602 additions and 35 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -11,3 +11,5 @@ wheels/
|
|||
.env
|
||||
|
||||
.idea/
|
||||
|
||||
1001*.pdf
|
||||
|
|
|
|||
|
|
@ -40,6 +40,8 @@ services:
|
|||
- OPENSEARCH_USERNAME=admin
|
||||
- OPENSEARCH_PASSWORD=${OPENSEARCH_PASSWORD}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
||||
- NVIDIA_VISIBLE_DEVICES=all
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
|
|
@ -48,6 +50,7 @@ services:
|
|||
- ./pyproject.toml:/app/pyproject.toml
|
||||
- ./uv.lock:/app/uv.lock
|
||||
- ./documents:/app/documents
|
||||
gpus: all
|
||||
|
||||
langflow:
|
||||
volumes:
|
||||
|
|
|
|||
|
|
@ -66,9 +66,24 @@ export default function AdminPage() {
|
|||
|
||||
const result = await response.json()
|
||||
|
||||
if (response.ok) {
|
||||
const successful = result.results.filter((r: {status: string}) => r.status === "indexed").length
|
||||
const total = result.results.length
|
||||
if (response.status === 201) {
|
||||
// New flow: Got task ID, start polling
|
||||
const taskId = result.task_id || result.id
|
||||
const totalFiles = result.total_files || 0
|
||||
|
||||
if (!taskId) {
|
||||
throw new Error("No task ID received from server")
|
||||
}
|
||||
|
||||
setUploadStatus(`🔄 Processing started for ${totalFiles} files... (Task ID: ${taskId})`)
|
||||
|
||||
// Start polling the task status
|
||||
await pollPathTaskStatus(taskId, totalFiles)
|
||||
|
||||
} else if (response.ok) {
|
||||
// Original flow: Direct response with results
|
||||
const successful = result.results?.filter((r: {status: string}) => r.status === "indexed").length || 0
|
||||
const total = result.results?.length || 0
|
||||
setUploadStatus(`Path processed successfully! ${successful}/${total} files indexed.`)
|
||||
setFolderPath("")
|
||||
} else {
|
||||
|
|
@ -81,6 +96,63 @@ export default function AdminPage() {
|
|||
}
|
||||
}
|
||||
|
||||
const pollPathTaskStatus = async (taskId: string, totalFiles: number) => {
|
||||
const maxAttempts = 120 // Poll for up to 10 minutes (120 * 5s intervals) for large batches
|
||||
let attempts = 0
|
||||
|
||||
const poll = async (): Promise<void> => {
|
||||
try {
|
||||
attempts++
|
||||
|
||||
const response = await fetch(`/api/tasks/${taskId}`)
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to check task status: ${response.status}`)
|
||||
}
|
||||
|
||||
const task = await response.json()
|
||||
|
||||
if (task.status === 'completed') {
|
||||
setUploadStatus(`✅ Path processing completed! ${task.successful_files}/${task.total_files} files processed successfully.`)
|
||||
setFolderPath("")
|
||||
setPathUploadLoading(false)
|
||||
|
||||
} else if (task.status === 'failed' || task.status === 'error') {
|
||||
setUploadStatus(`❌ Path processing failed: ${task.error || 'Unknown error occurred'}`)
|
||||
setPathUploadLoading(false)
|
||||
|
||||
} else if (task.status === 'pending' || task.status === 'running') {
|
||||
// Still in progress, update status and continue polling
|
||||
const processed = task.processed_files || 0
|
||||
const successful = task.successful_files || 0
|
||||
const failed = task.failed_files || 0
|
||||
|
||||
setUploadStatus(`⏳ Processing files... ${processed}/${totalFiles} processed (${successful} successful, ${failed} failed)`)
|
||||
|
||||
// Continue polling if we haven't exceeded max attempts
|
||||
if (attempts < maxAttempts) {
|
||||
setTimeout(poll, 5000) // Poll every 5 seconds
|
||||
} else {
|
||||
setUploadStatus(`⚠️ Processing timeout after ${attempts} attempts. The task may still be running in the background.`)
|
||||
setPathUploadLoading(false)
|
||||
}
|
||||
|
||||
} else {
|
||||
setUploadStatus(`❓ Unknown task status: ${task.status}`)
|
||||
setPathUploadLoading(false)
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('Task polling error:', error)
|
||||
setUploadStatus(`❌ Failed to check processing status: ${error instanceof Error ? error.message : 'Unknown error'}`)
|
||||
setPathUploadLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Start polling immediately
|
||||
poll()
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-8">
|
||||
<div>
|
||||
|
|
|
|||
|
|
@ -64,10 +64,20 @@ export default function ChatPage() {
|
|||
}
|
||||
|
||||
const handleFileUpload = async (file: File) => {
|
||||
console.log("handleFileUpload called with file:", file.name)
|
||||
|
||||
if (isUploading) return
|
||||
|
||||
setIsUploading(true)
|
||||
|
||||
// Add initial upload message
|
||||
const uploadStartMessage: Message = {
|
||||
role: "assistant",
|
||||
content: `🔄 Starting upload of **${file.name}**...`,
|
||||
timestamp: new Date()
|
||||
}
|
||||
setMessages(prev => [...prev, uploadStartMessage])
|
||||
|
||||
try {
|
||||
const formData = new FormData()
|
||||
formData.append('file', file)
|
||||
|
|
@ -84,27 +94,58 @@ export default function ChatPage() {
|
|||
body: formData,
|
||||
})
|
||||
|
||||
console.log("Upload response status:", response.status)
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Upload failed: ${response.status}`)
|
||||
const errorText = await response.text()
|
||||
console.error("Upload failed with status:", response.status, "Response:", errorText)
|
||||
throw new Error(`Upload failed: ${response.status} - ${errorText}`)
|
||||
}
|
||||
|
||||
const result = await response.json()
|
||||
console.log("Upload result:", result)
|
||||
|
||||
// Add upload confirmation as a system message in the UI
|
||||
const uploadMessage: Message = {
|
||||
role: "assistant",
|
||||
content: `📄 Document uploaded: **${result.filename}** (${result.pages} pages, ${result.content_length.toLocaleString()} characters)\n\n${result.confirmation}`,
|
||||
timestamp: new Date()
|
||||
}
|
||||
|
||||
setMessages(prev => [...prev, uploadMessage])
|
||||
|
||||
// Update the response ID for this endpoint
|
||||
if (result.response_id) {
|
||||
setPreviousResponseIds(prev => ({
|
||||
...prev,
|
||||
[endpoint]: result.response_id
|
||||
}))
|
||||
if (response.status === 201) {
|
||||
// New flow: Got task ID, start polling
|
||||
const taskId = result.task_id || result.id
|
||||
|
||||
if (!taskId) {
|
||||
console.error("No task ID in 201 response:", result)
|
||||
throw new Error("No task ID received from server")
|
||||
}
|
||||
|
||||
// Update message to show polling started
|
||||
const pollingMessage: Message = {
|
||||
role: "assistant",
|
||||
content: `⏳ Upload initiated for **${file.name}**. Processing... (Task ID: ${taskId})`,
|
||||
timestamp: new Date()
|
||||
}
|
||||
setMessages(prev => [...prev.slice(0, -1), pollingMessage])
|
||||
|
||||
// Start polling the task status
|
||||
await pollTaskStatus(taskId, file.name)
|
||||
|
||||
} else if (response.ok) {
|
||||
// Original flow: Direct response
|
||||
|
||||
const uploadMessage: Message = {
|
||||
role: "assistant",
|
||||
content: `📄 Document uploaded: **${result.filename}** (${result.pages} pages, ${result.content_length.toLocaleString()} characters)\n\n${result.confirmation}`,
|
||||
timestamp: new Date()
|
||||
}
|
||||
|
||||
setMessages(prev => [...prev.slice(0, -1), uploadMessage])
|
||||
|
||||
// Update the response ID for this endpoint
|
||||
if (result.response_id) {
|
||||
setPreviousResponseIds(prev => ({
|
||||
...prev,
|
||||
[endpoint]: result.response_id
|
||||
}))
|
||||
}
|
||||
|
||||
} else {
|
||||
throw new Error(`Upload failed: ${response.status}`)
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
|
|
@ -114,12 +155,108 @@ export default function ChatPage() {
|
|||
content: `❌ Upload failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
timestamp: new Date()
|
||||
}
|
||||
setMessages(prev => [...prev, errorMessage])
|
||||
setMessages(prev => [...prev.slice(0, -1), errorMessage])
|
||||
} finally {
|
||||
setIsUploading(false)
|
||||
}
|
||||
}
|
||||
|
||||
const pollTaskStatus = async (taskId: string, filename: string) => {
|
||||
const maxAttempts = 60 // Poll for up to 5 minutes (60 * 5s intervals)
|
||||
let attempts = 0
|
||||
|
||||
const poll = async (): Promise<void> => {
|
||||
try {
|
||||
attempts++
|
||||
|
||||
const response = await fetch(`/api/tasks/${taskId}`)
|
||||
console.log("Task polling response status:", response.status)
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
console.error("Task polling failed:", response.status, errorText)
|
||||
throw new Error(`Failed to check task status: ${response.status} - ${errorText}`)
|
||||
}
|
||||
|
||||
const task = await response.json()
|
||||
console.log("Task polling result:", task)
|
||||
|
||||
// Safety check to ensure task object exists
|
||||
if (!task) {
|
||||
throw new Error("No task data received from server")
|
||||
}
|
||||
|
||||
// Update the message based on task status
|
||||
if (task.status === 'completed') {
|
||||
const successMessage: Message = {
|
||||
role: "assistant",
|
||||
content: `✅ **${filename}** processed successfully!\n\n${task.result?.confirmation || 'Document has been added to the knowledge base.'}`,
|
||||
timestamp: new Date()
|
||||
}
|
||||
setMessages(prev => [...prev.slice(0, -1), successMessage])
|
||||
|
||||
// Update response ID if available
|
||||
if (task.result?.response_id) {
|
||||
setPreviousResponseIds(prev => ({
|
||||
...prev,
|
||||
[endpoint]: task.result.response_id
|
||||
}))
|
||||
}
|
||||
|
||||
} else if (task.status === 'failed' || task.status === 'error') {
|
||||
const errorMessage: Message = {
|
||||
role: "assistant",
|
||||
content: `❌ Processing failed for **${filename}**: ${task.error || 'Unknown error occurred'}`,
|
||||
timestamp: new Date()
|
||||
}
|
||||
setMessages(prev => [...prev.slice(0, -1), errorMessage])
|
||||
|
||||
} else if (task.status === 'pending' || task.status === 'running' || task.status === 'processing') {
|
||||
// Still in progress, update message and continue polling
|
||||
const progressMessage: Message = {
|
||||
role: "assistant",
|
||||
content: `⏳ Processing **${filename}**... (${task.status}) - Attempt ${attempts}/${maxAttempts}`,
|
||||
timestamp: new Date()
|
||||
}
|
||||
setMessages(prev => [...prev.slice(0, -1), progressMessage])
|
||||
|
||||
// Continue polling if we haven't exceeded max attempts
|
||||
if (attempts < maxAttempts) {
|
||||
setTimeout(poll, 5000) // Poll every 5 seconds
|
||||
} else {
|
||||
const timeoutMessage: Message = {
|
||||
role: "assistant",
|
||||
content: `⚠️ Processing timeout for **${filename}**. The task may still be running in the background.`,
|
||||
timestamp: new Date()
|
||||
}
|
||||
setMessages(prev => [...prev.slice(0, -1), timeoutMessage])
|
||||
}
|
||||
|
||||
} else {
|
||||
// Unknown status
|
||||
const unknownMessage: Message = {
|
||||
role: "assistant",
|
||||
content: `❓ Unknown status for **${filename}**: ${task.status}`,
|
||||
timestamp: new Date()
|
||||
}
|
||||
setMessages(prev => [...prev.slice(0, -1), unknownMessage])
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('Task polling error:', error)
|
||||
const errorMessage: Message = {
|
||||
role: "assistant",
|
||||
content: `❌ Failed to check processing status for **${filename}**: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
timestamp: new Date()
|
||||
}
|
||||
setMessages(prev => [...prev.slice(0, -1), errorMessage])
|
||||
}
|
||||
}
|
||||
|
||||
// Start polling immediately
|
||||
poll()
|
||||
}
|
||||
|
||||
const handleDragEnter = (e: React.DragEvent) => {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
|
|
|
|||
383
src/app.py
383
src/app.py
|
|
@ -3,11 +3,16 @@
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
import uuid
|
||||
import time
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import multiprocessing
|
||||
|
||||
from agent import async_chat, async_langflow
|
||||
|
||||
os.environ['USE_CPU_ONLY'] = 'true'
|
||||
|
||||
import hashlib
|
||||
import tempfile
|
||||
import asyncio
|
||||
|
|
@ -29,6 +34,10 @@ from dotenv import load_dotenv
|
|||
load_dotenv()
|
||||
load_dotenv("../")
|
||||
|
||||
import torch
|
||||
print("CUDA available:", torch.cuda.is_available())
|
||||
print("CUDA version PyTorch was built with:", torch.version.cuda)
|
||||
|
||||
# Initialize Docling converter
|
||||
converter = DocumentConverter() # basic converter; tweak via PipelineOptions if you need OCR, etc.
|
||||
|
||||
|
|
@ -43,7 +52,7 @@ langflow_key = os.getenv("LANGFLOW_SECRET_KEY")
|
|||
|
||||
|
||||
|
||||
es = AsyncOpenSearch(
|
||||
opensearch = AsyncOpenSearch(
|
||||
hosts=[{"host": opensearch_host, "port": opensearch_port}],
|
||||
connection_class=AIOHttpConnection,
|
||||
scheme="https",
|
||||
|
|
@ -93,6 +102,183 @@ langflow_client = AsyncOpenAI(
|
|||
)
|
||||
patched_async_client = patch_openai_with_mcp(AsyncOpenAI()) # Get the patched client back
|
||||
|
||||
class TaskStatus(Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
@dataclass
|
||||
class FileTask:
|
||||
file_path: str
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
result: dict = None
|
||||
error: str = None
|
||||
retry_count: int = 0
|
||||
created_at: float = field(default_factory=time.time)
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
|
||||
@dataclass
|
||||
class UploadTask:
|
||||
task_id: str
|
||||
total_files: int
|
||||
processed_files: int = 0
|
||||
successful_files: int = 0
|
||||
failed_files: int = 0
|
||||
file_tasks: dict = field(default_factory=dict)
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
created_at: float = field(default_factory=time.time)
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
|
||||
task_store = {}
|
||||
background_tasks = set()
|
||||
|
||||
# GPU device detection
|
||||
def detect_gpu_devices():
|
||||
"""Detect if GPU devices are actually available"""
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
|
||||
return True, torch.cuda.device_count()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import subprocess
|
||||
result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
|
||||
if result.returncode == 0:
|
||||
return True, "detected"
|
||||
except (subprocess.SubprocessError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
return False, 0
|
||||
|
||||
# GPU and concurrency configuration
|
||||
HAS_GPU_DEVICES, GPU_COUNT = detect_gpu_devices()
|
||||
|
||||
if HAS_GPU_DEVICES:
|
||||
# GPU mode with actual GPU devices: Lower concurrency due to memory constraints
|
||||
DEFAULT_WORKERS = min(4, multiprocessing.cpu_count() // 2)
|
||||
print(f"GPU mode enabled with {GPU_COUNT} GPU(s) - using limited concurrency ({DEFAULT_WORKERS} workers)")
|
||||
elif HAS_GPU_DEVICES:
|
||||
# GPU mode requested but no devices found: Use full CPU concurrency
|
||||
DEFAULT_WORKERS = multiprocessing.cpu_count()
|
||||
print(f"GPU mode requested but no GPU devices found - falling back to full CPU concurrency ({DEFAULT_WORKERS} workers)")
|
||||
else:
|
||||
# CPU mode: Higher concurrency since no GPU memory constraints
|
||||
DEFAULT_WORKERS = multiprocessing.cpu_count()
|
||||
print(f"CPU-only mode enabled - using full concurrency ({DEFAULT_WORKERS} workers)")
|
||||
|
||||
MAX_WORKERS = int(os.getenv("MAX_WORKERS", DEFAULT_WORKERS))
|
||||
process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)
|
||||
|
||||
print(f"Process pool initialized with {MAX_WORKERS} workers")
|
||||
|
||||
# Global converter cache for worker processes
|
||||
_worker_converter = None
|
||||
|
||||
def get_worker_converter():
|
||||
"""Get or create a DocumentConverter instance for this worker process"""
|
||||
global _worker_converter
|
||||
if _worker_converter is None:
|
||||
import os
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
# Configure GPU settings for this worker
|
||||
has_gpu_devices, _ = detect_gpu_devices()
|
||||
if not has_gpu_devices:
|
||||
# Force CPU-only mode in subprocess
|
||||
os.environ['USE_CPU_ONLY'] = 'true'
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
||||
os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
|
||||
os.environ['TRANSFORMERS_OFFLINE'] = '0'
|
||||
os.environ['TORCH_USE_CUDA_DSA'] = '0'
|
||||
|
||||
# Try to disable CUDA in torch if available
|
||||
try:
|
||||
import torch
|
||||
torch.cuda.is_available = lambda: False
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
# GPU mode - let libraries use GPU if available
|
||||
os.environ.pop('USE_CPU_ONLY', None)
|
||||
os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' # Still disable progress bars
|
||||
|
||||
print(f"🔧 Initializing DocumentConverter in worker process (PID: {os.getpid()})")
|
||||
_worker_converter = DocumentConverter()
|
||||
print(f"✅ DocumentConverter ready in worker process (PID: {os.getpid()})")
|
||||
|
||||
return _worker_converter
|
||||
|
||||
def detect_gpu_devices():
|
||||
"""Detect if GPU devices are actually available"""
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
|
||||
return True, torch.cuda.device_count()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import subprocess
|
||||
result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
|
||||
if result.returncode == 0:
|
||||
return True, "detected"
|
||||
except (subprocess.SubprocessError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
return False, 0
|
||||
|
||||
def process_document_sync(file_path: str):
|
||||
"""Synchronous document processing function for multiprocessing"""
|
||||
import hashlib
|
||||
from collections import defaultdict
|
||||
|
||||
# Get the cached converter for this worker
|
||||
converter = get_worker_converter()
|
||||
|
||||
# Compute file hash
|
||||
sha256 = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(1 << 20)
|
||||
if not chunk:
|
||||
break
|
||||
sha256.update(chunk)
|
||||
file_hash = sha256.hexdigest()
|
||||
|
||||
# Convert with docling
|
||||
result = converter.convert(file_path)
|
||||
full_doc = result.document.export_to_dict()
|
||||
|
||||
# Extract relevant content (same logic as extract_relevant)
|
||||
origin = full_doc.get("origin", {})
|
||||
texts = full_doc.get("texts", [])
|
||||
|
||||
page_texts = defaultdict(list)
|
||||
for txt in texts:
|
||||
prov = txt.get("prov", [])
|
||||
page_no = prov[0].get("page_no") if prov else None
|
||||
if page_no is not None:
|
||||
page_texts[page_no].append(txt.get("text", "").strip())
|
||||
|
||||
chunks = []
|
||||
for page in sorted(page_texts):
|
||||
joined = "\n".join(page_texts[page])
|
||||
chunks.append({
|
||||
"page": page,
|
||||
"text": joined
|
||||
})
|
||||
|
||||
return {
|
||||
"id": file_hash,
|
||||
"filename": origin.get("filename"),
|
||||
"mimetype": origin.get("mimetype"),
|
||||
"chunks": chunks,
|
||||
"file_path": file_path
|
||||
}
|
||||
|
||||
async def wait_for_opensearch():
|
||||
"""Wait for OpenSearch to be ready with retries"""
|
||||
max_retries = 30
|
||||
|
|
@ -100,7 +286,7 @@ async def wait_for_opensearch():
|
|||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
await es.info()
|
||||
await opensearch.info()
|
||||
print("OpenSearch is ready!")
|
||||
return
|
||||
except Exception as e:
|
||||
|
|
@ -113,8 +299,8 @@ async def wait_for_opensearch():
|
|||
async def init_index():
|
||||
await wait_for_opensearch()
|
||||
|
||||
if not await es.indices.exists(index=INDEX_NAME):
|
||||
await es.indices.create(index=INDEX_NAME, body=index_body)
|
||||
if not await opensearch.indices.exists(index=INDEX_NAME):
|
||||
await opensearch.indices.create(index=INDEX_NAME, body=index_body)
|
||||
print(f"Created index '{INDEX_NAME}'")
|
||||
else:
|
||||
print(f"Index '{INDEX_NAME}' already exists, skipping creation.")
|
||||
|
|
@ -155,6 +341,26 @@ def extract_relevant(doc_dict: dict) -> dict:
|
|||
"chunks": chunks
|
||||
}
|
||||
|
||||
async def exponential_backoff_delay(retry_count: int, base_delay: float = 1.0, max_delay: float = 60.0) -> None:
|
||||
"""Apply exponential backoff with jitter"""
|
||||
delay = min(base_delay * (2 ** retry_count) + random.uniform(0, 1), max_delay)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
async def process_file_with_retry(file_path: str, max_retries: int = 3) -> dict:
|
||||
"""Process a file with retry logic - retries everything up to max_retries times"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return await process_file_common(file_path)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < max_retries:
|
||||
await exponential_backoff_delay(attempt)
|
||||
continue
|
||||
else:
|
||||
raise last_error
|
||||
|
||||
async def process_file_common(file_path: str, file_hash: str = None):
|
||||
"""
|
||||
Common processing logic for both upload and upload_path.
|
||||
|
|
@ -173,7 +379,7 @@ async def process_file_common(file_path: str, file_hash: str = None):
|
|||
sha256.update(chunk)
|
||||
file_hash = sha256.hexdigest()
|
||||
|
||||
exists = await es.exists(index=INDEX_NAME, id=file_hash)
|
||||
exists = await opensearch.exists(index=INDEX_NAME, id=file_hash)
|
||||
if exists:
|
||||
return {"status": "unchanged", "id": file_hash}
|
||||
|
||||
|
|
@ -199,7 +405,7 @@ async def process_file_common(file_path: str, file_hash: str = None):
|
|||
"chunk_embedding": vect
|
||||
}
|
||||
chunk_id = f"{file_hash}_{i}"
|
||||
await es.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
|
||||
await opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
|
||||
return {"status": "indexed", "id": file_hash}
|
||||
|
||||
async def process_file_on_disk(path: str):
|
||||
|
|
@ -210,6 +416,94 @@ async def process_file_on_disk(path: str):
|
|||
result["path"] = path
|
||||
return result
|
||||
|
||||
async def process_single_file_task(upload_task: UploadTask, file_path: str) -> None:
|
||||
"""Process a single file and update task tracking"""
|
||||
file_task = upload_task.file_tasks[file_path]
|
||||
file_task.status = TaskStatus.RUNNING
|
||||
file_task.updated_at = time.time()
|
||||
|
||||
try:
|
||||
# Check if file already exists in index
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Run CPU-intensive docling processing in separate process
|
||||
slim_doc = await loop.run_in_executor(process_pool, process_document_sync, file_path)
|
||||
|
||||
# Check if already indexed
|
||||
exists = await opensearch.exists(index=INDEX_NAME, id=slim_doc["id"])
|
||||
if exists:
|
||||
result = {"status": "unchanged", "id": slim_doc["id"]}
|
||||
else:
|
||||
# Generate embeddings and index (I/O bound, keep in main process)
|
||||
texts = [c["text"] for c in slim_doc["chunks"]]
|
||||
resp = await patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
|
||||
embeddings = [d.embedding for d in resp.data]
|
||||
|
||||
# Index each chunk
|
||||
for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
|
||||
chunk_doc = {
|
||||
"document_id": slim_doc["id"],
|
||||
"filename": slim_doc["filename"],
|
||||
"mimetype": slim_doc["mimetype"],
|
||||
"page": chunk["page"],
|
||||
"text": chunk["text"],
|
||||
"chunk_embedding": vect
|
||||
}
|
||||
chunk_id = f"{slim_doc['id']}_{i}"
|
||||
await opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
|
||||
|
||||
result = {"status": "indexed", "id": slim_doc["id"]}
|
||||
|
||||
result["path"] = file_path
|
||||
file_task.status = TaskStatus.COMPLETED
|
||||
file_task.result = result
|
||||
upload_task.successful_files += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to process file {file_path}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
file_task.status = TaskStatus.FAILED
|
||||
file_task.error = str(e)
|
||||
upload_task.failed_files += 1
|
||||
finally:
|
||||
file_task.updated_at = time.time()
|
||||
upload_task.processed_files += 1
|
||||
upload_task.updated_at = time.time()
|
||||
|
||||
if upload_task.processed_files >= upload_task.total_files:
|
||||
upload_task.status = TaskStatus.COMPLETED
|
||||
|
||||
async def background_upload_processor(task_id: str) -> None:
|
||||
"""Background task to process all files in an upload job with concurrency control"""
|
||||
try:
|
||||
upload_task = task_store[task_id]
|
||||
upload_task.status = TaskStatus.RUNNING
|
||||
upload_task.updated_at = time.time()
|
||||
|
||||
# Process files with limited concurrency to avoid overwhelming the system
|
||||
semaphore = asyncio.Semaphore(MAX_WORKERS * 2) # Allow 2x process pool size for async I/O
|
||||
|
||||
async def process_with_semaphore(file_path: str):
|
||||
async with semaphore:
|
||||
await process_single_file_task(upload_task, file_path)
|
||||
|
||||
tasks = [
|
||||
process_with_semaphore(file_path)
|
||||
for file_path in upload_task.file_tasks.keys()
|
||||
]
|
||||
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Background upload processor failed for task {task_id}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
if task_id in task_store:
|
||||
task_store[task_id].status = TaskStatus.FAILED
|
||||
task_store[task_id].updated_at = time.time()
|
||||
|
||||
async def upload(request: Request):
|
||||
form = await request.form()
|
||||
upload_file = form["file"]
|
||||
|
|
@ -226,7 +520,7 @@ async def upload(request: Request):
|
|||
tmp.flush()
|
||||
|
||||
file_hash = sha256.hexdigest()
|
||||
exists = await es.exists(index=INDEX_NAME, id=file_hash)
|
||||
exists = await opensearch.exists(index=INDEX_NAME, id=file_hash)
|
||||
if exists:
|
||||
return JSONResponse({"status": "unchanged", "id": file_hash})
|
||||
|
||||
|
|
@ -243,12 +537,31 @@ async def upload_path(request: Request):
|
|||
if not base_dir or not os.path.isdir(base_dir):
|
||||
return JSONResponse({"error": "Invalid path"}, status_code=400)
|
||||
|
||||
tasks = [process_file_on_disk(os.path.join(root, fn))
|
||||
for root, _, files in os.walk(base_dir)
|
||||
for fn in files]
|
||||
file_paths = [os.path.join(root, fn)
|
||||
for root, _, files in os.walk(base_dir)
|
||||
for fn in files]
|
||||
|
||||
if not file_paths:
|
||||
return JSONResponse({"error": "No files found in directory"}, status_code=400)
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
return JSONResponse({"results": results})
|
||||
task_id = str(uuid.uuid4())
|
||||
upload_task = UploadTask(
|
||||
task_id=task_id,
|
||||
total_files=len(file_paths),
|
||||
file_tasks={path: FileTask(file_path=path) for path in file_paths}
|
||||
)
|
||||
|
||||
task_store[task_id] = upload_task
|
||||
|
||||
background_task = asyncio.create_task(background_upload_processor(task_id))
|
||||
background_tasks.add(background_task)
|
||||
background_task.add_done_callback(background_tasks.discard)
|
||||
|
||||
return JSONResponse({
|
||||
"task_id": task_id,
|
||||
"total_files": len(file_paths),
|
||||
"status": "accepted"
|
||||
}, status_code=201)
|
||||
|
||||
async def upload_context(request: Request):
|
||||
"""Upload a file and add its content as context to the current conversation"""
|
||||
|
|
@ -306,6 +619,38 @@ async def upload_context(request: Request):
|
|||
|
||||
return JSONResponse(response_data)
|
||||
|
||||
async def task_status(request: Request):
|
||||
"""Get the status of an upload task"""
|
||||
task_id = request.path_params.get("task_id")
|
||||
|
||||
if not task_id or task_id not in task_store:
|
||||
return JSONResponse({"error": "Task not found"}, status_code=404)
|
||||
|
||||
upload_task = task_store[task_id]
|
||||
|
||||
file_statuses = {}
|
||||
for file_path, file_task in upload_task.file_tasks.items():
|
||||
file_statuses[file_path] = {
|
||||
"status": file_task.status.value,
|
||||
"result": file_task.result,
|
||||
"error": file_task.error,
|
||||
"retry_count": file_task.retry_count,
|
||||
"created_at": file_task.created_at,
|
||||
"updated_at": file_task.updated_at
|
||||
}
|
||||
|
||||
return JSONResponse({
|
||||
"task_id": upload_task.task_id,
|
||||
"status": upload_task.status.value,
|
||||
"total_files": upload_task.total_files,
|
||||
"processed_files": upload_task.processed_files,
|
||||
"successful_files": upload_task.successful_files,
|
||||
"failed_files": upload_task.failed_files,
|
||||
"created_at": upload_task.created_at,
|
||||
"updated_at": upload_task.updated_at,
|
||||
"files": file_statuses
|
||||
})
|
||||
|
||||
async def search(request: Request):
|
||||
|
||||
payload = await request.json()
|
||||
|
|
@ -345,7 +690,7 @@ async def search_tool(query: str)-> dict[str, Any]:
|
|||
"_source": ["filename", "mimetype", "page", "text"],
|
||||
"size": 10
|
||||
}
|
||||
results = await es.search(index=INDEX_NAME, body=search_body)
|
||||
results = await opensearch.search(index=INDEX_NAME, body=search_body)
|
||||
# Transform results to match expected format
|
||||
chunks = []
|
||||
for hit in results["hits"]["hits"]:
|
||||
|
|
@ -425,6 +770,7 @@ app = Starlette(debug=True, routes=[
|
|||
Route("/upload", upload, methods=["POST"]),
|
||||
Route("/upload_context", upload_context, methods=["POST"]),
|
||||
Route("/upload_path", upload_path, methods=["POST"]),
|
||||
Route("/tasks/{task_id}", task_status, methods=["GET"]),
|
||||
Route("/search", search, methods=["POST"]),
|
||||
Route("/chat", chat_endpoint, methods=["POST"]),
|
||||
Route("/langflow", langflow_endpoint, methods=["POST"]),
|
||||
|
|
@ -432,10 +778,17 @@ app = Starlette(debug=True, routes=[
|
|||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
import atexit
|
||||
|
||||
async def main():
|
||||
await init_index()
|
||||
|
||||
# Cleanup process pool on exit
|
||||
def cleanup():
|
||||
process_pool.shutdown(wait=True)
|
||||
|
||||
atexit.register(cleanup)
|
||||
|
||||
asyncio.run(main())
|
||||
uvicorn.run(
|
||||
"app:app",
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue