multi worker uploads + docling gpu improvements
parent b159164593
commit ea24c81eab
5 changed files with 602 additions and 35 deletions

.gitignore (vendored): 2 changes

@@ -11,3 +11,5 @@ wheels/
 .env
 .idea/
+
+1001*.pdf
Compose service configuration:

@@ -40,6 +40,8 @@ services:
       - OPENSEARCH_USERNAME=admin
       - OPENSEARCH_PASSWORD=${OPENSEARCH_PASSWORD}
       - OPENAI_API_KEY=${OPENAI_API_KEY}
+      - NVIDIA_DRIVER_CAPABILITIES=compute,utility
+      - NVIDIA_VISIBLE_DEVICES=all
     ports:
       - "3000:3000"
     volumes:
@@ -48,6 +50,7 @@ services:
       - ./pyproject.toml:/app/pyproject.toml
       - ./uv.lock:/app/uv.lock
       - ./documents:/app/documents
+    gpus: all
 
   langflow:
     volumes:
Admin page component (AdminPage):

@@ -66,9 +66,24 @@ export default function AdminPage() {
 
       const result = await response.json()
 
-      if (response.ok) {
-        const successful = result.results.filter((r: {status: string}) => r.status === "indexed").length
-        const total = result.results.length
+      if (response.status === 201) {
+        // New flow: Got task ID, start polling
+        const taskId = result.task_id || result.id
+        const totalFiles = result.total_files || 0
+
+        if (!taskId) {
+          throw new Error("No task ID received from server")
+        }
+
+        setUploadStatus(`🔄 Processing started for ${totalFiles} files... (Task ID: ${taskId})`)
+
+        // Start polling the task status
+        await pollPathTaskStatus(taskId, totalFiles)
+
+      } else if (response.ok) {
+        // Original flow: Direct response with results
+        const successful = result.results?.filter((r: {status: string}) => r.status === "indexed").length || 0
+        const total = result.results?.length || 0
         setUploadStatus(`Path processed successfully! ${successful}/${total} files indexed.`)
         setFolderPath("")
       } else {
@@ -81,6 +96,63 @@ export default function AdminPage() {
     }
   }
 
+  const pollPathTaskStatus = async (taskId: string, totalFiles: number) => {
+    const maxAttempts = 120 // Poll for up to 10 minutes (120 * 5s intervals) for large batches
+    let attempts = 0
+
+    const poll = async (): Promise<void> => {
+      try {
+        attempts++
+
+        const response = await fetch(`/api/tasks/${taskId}`)
+
+        if (!response.ok) {
+          throw new Error(`Failed to check task status: ${response.status}`)
+        }
+
+        const task = await response.json()
+
+        if (task.status === 'completed') {
+          setUploadStatus(`✅ Path processing completed! ${task.successful_files}/${task.total_files} files processed successfully.`)
+          setFolderPath("")
+          setPathUploadLoading(false)
+
+        } else if (task.status === 'failed' || task.status === 'error') {
+          setUploadStatus(`❌ Path processing failed: ${task.error || 'Unknown error occurred'}`)
+          setPathUploadLoading(false)
+
+        } else if (task.status === 'pending' || task.status === 'running') {
+          // Still in progress, update status and continue polling
+          const processed = task.processed_files || 0
+          const successful = task.successful_files || 0
+          const failed = task.failed_files || 0
+
+          setUploadStatus(`⏳ Processing files... ${processed}/${totalFiles} processed (${successful} successful, ${failed} failed)`)
+
+          // Continue polling if we haven't exceeded max attempts
+          if (attempts < maxAttempts) {
+            setTimeout(poll, 5000) // Poll every 5 seconds
+          } else {
+            setUploadStatus(`⚠️ Processing timeout after ${attempts} attempts. The task may still be running in the background.`)
+            setPathUploadLoading(false)
+          }
+
+        } else {
+          setUploadStatus(`❓ Unknown task status: ${task.status}`)
+          setPathUploadLoading(false)
+        }
+
+      } catch (error) {
+        console.error('Task polling error:', error)
+        setUploadStatus(`❌ Failed to check processing status: ${error instanceof Error ? error.message : 'Unknown error'}`)
+        setPathUploadLoading(false)
+      }
+    }
+
+    // Start polling immediately
+    poll()
+  }
+
   return (
     <div className="space-y-8">
       <div>
Chat page component (ChatPage):

@@ -64,10 +64,20 @@ export default function ChatPage() {
   }
 
   const handleFileUpload = async (file: File) => {
+    console.log("handleFileUpload called with file:", file.name)
+
     if (isUploading) return
 
     setIsUploading(true)
 
+    // Add initial upload message
+    const uploadStartMessage: Message = {
+      role: "assistant",
+      content: `🔄 Starting upload of **${file.name}**...`,
+      timestamp: new Date()
+    }
+    setMessages(prev => [...prev, uploadStartMessage])
+
     try {
       const formData = new FormData()
       formData.append('file', file)
@@ -84,27 +94,58 @@ export default function ChatPage() {
         body: formData,
       })
 
+      console.log("Upload response status:", response.status)
+
       if (!response.ok) {
-        throw new Error(`Upload failed: ${response.status}`)
+        const errorText = await response.text()
+        console.error("Upload failed with status:", response.status, "Response:", errorText)
+        throw new Error(`Upload failed: ${response.status} - ${errorText}`)
       }
 
       const result = await response.json()
+      console.log("Upload result:", result)
 
-      // Add upload confirmation as a system message in the UI
-      const uploadMessage: Message = {
-        role: "assistant",
-        content: `📄 Document uploaded: **${result.filename}** (${result.pages} pages, ${result.content_length.toLocaleString()} characters)\n\n${result.confirmation}`,
-        timestamp: new Date()
-      }
-
-      setMessages(prev => [...prev, uploadMessage])
-
-      // Update the response ID for this endpoint
-      if (result.response_id) {
-        setPreviousResponseIds(prev => ({
-          ...prev,
-          [endpoint]: result.response_id
-        }))
-      }
+      if (response.status === 201) {
+        // New flow: Got task ID, start polling
+        const taskId = result.task_id || result.id
+
+        if (!taskId) {
+          console.error("No task ID in 201 response:", result)
+          throw new Error("No task ID received from server")
+        }
+
+        // Update message to show polling started
+        const pollingMessage: Message = {
+          role: "assistant",
+          content: `⏳ Upload initiated for **${file.name}**. Processing... (Task ID: ${taskId})`,
+          timestamp: new Date()
+        }
+        setMessages(prev => [...prev.slice(0, -1), pollingMessage])
+
+        // Start polling the task status
+        await pollTaskStatus(taskId, file.name)
+
+      } else if (response.ok) {
+        // Original flow: Direct response
+        const uploadMessage: Message = {
+          role: "assistant",
+          content: `📄 Document uploaded: **${result.filename}** (${result.pages} pages, ${result.content_length.toLocaleString()} characters)\n\n${result.confirmation}`,
+          timestamp: new Date()
+        }
+        setMessages(prev => [...prev.slice(0, -1), uploadMessage])
+
+        // Update the response ID for this endpoint
+        if (result.response_id) {
+          setPreviousResponseIds(prev => ({
+            ...prev,
+            [endpoint]: result.response_id
+          }))
+        }
+
+      } else {
+        throw new Error(`Upload failed: ${response.status}`)
+      }
 
     } catch (error) {
@@ -114,12 +155,108 @@ export default function ChatPage() {
         content: `❌ Upload failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
         timestamp: new Date()
       }
-      setMessages(prev => [...prev, errorMessage])
+      setMessages(prev => [...prev.slice(0, -1), errorMessage])
     } finally {
       setIsUploading(false)
     }
   }
 
+  const pollTaskStatus = async (taskId: string, filename: string) => {
+    const maxAttempts = 60 // Poll for up to 5 minutes (60 * 5s intervals)
+    let attempts = 0
+
+    const poll = async (): Promise<void> => {
+      try {
+        attempts++
+
+        const response = await fetch(`/api/tasks/${taskId}`)
+        console.log("Task polling response status:", response.status)
+
+        if (!response.ok) {
+          const errorText = await response.text()
+          console.error("Task polling failed:", response.status, errorText)
+          throw new Error(`Failed to check task status: ${response.status} - ${errorText}`)
+        }
+
+        const task = await response.json()
+        console.log("Task polling result:", task)
+
+        // Safety check to ensure task object exists
+        if (!task) {
+          throw new Error("No task data received from server")
+        }
+
+        // Update the message based on task status
+        if (task.status === 'completed') {
+          const successMessage: Message = {
+            role: "assistant",
+            content: `✅ **${filename}** processed successfully!\n\n${task.result?.confirmation || 'Document has been added to the knowledge base.'}`,
+            timestamp: new Date()
+          }
+          setMessages(prev => [...prev.slice(0, -1), successMessage])
+
+          // Update response ID if available
+          if (task.result?.response_id) {
+            setPreviousResponseIds(prev => ({
+              ...prev,
+              [endpoint]: task.result.response_id
+            }))
+          }
+
+        } else if (task.status === 'failed' || task.status === 'error') {
+          const errorMessage: Message = {
+            role: "assistant",
+            content: `❌ Processing failed for **${filename}**: ${task.error || 'Unknown error occurred'}`,
+            timestamp: new Date()
+          }
+          setMessages(prev => [...prev.slice(0, -1), errorMessage])
+
+        } else if (task.status === 'pending' || task.status === 'running' || task.status === 'processing') {
+          // Still in progress, update message and continue polling
+          const progressMessage: Message = {
+            role: "assistant",
+            content: `⏳ Processing **${filename}**... (${task.status}) - Attempt ${attempts}/${maxAttempts}`,
+            timestamp: new Date()
+          }
+          setMessages(prev => [...prev.slice(0, -1), progressMessage])
+
+          // Continue polling if we haven't exceeded max attempts
+          if (attempts < maxAttempts) {
+            setTimeout(poll, 5000) // Poll every 5 seconds
+          } else {
+            const timeoutMessage: Message = {
+              role: "assistant",
+              content: `⚠️ Processing timeout for **${filename}**. The task may still be running in the background.`,
+              timestamp: new Date()
+            }
+            setMessages(prev => [...prev.slice(0, -1), timeoutMessage])
+          }
+
+        } else {
+          // Unknown status
+          const unknownMessage: Message = {
+            role: "assistant",
+            content: `❓ Unknown status for **${filename}**: ${task.status}`,
+            timestamp: new Date()
+          }
+          setMessages(prev => [...prev.slice(0, -1), unknownMessage])
+        }
+
+      } catch (error) {
+        console.error('Task polling error:', error)
+        const errorMessage: Message = {
+          role: "assistant",
+          content: `❌ Failed to check processing status for **${filename}**: ${error instanceof Error ? error.message : 'Unknown error'}`,
+          timestamp: new Date()
+        }
+        setMessages(prev => [...prev.slice(0, -1), errorMessage])
+      }
+    }
+
+    // Start polling immediately
+    poll()
+  }
+
   const handleDragEnter = (e: React.DragEvent) => {
     e.preventDefault()
     e.stopPropagation()
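Both pages follow the same upload-then-poll protocol: a 201 response carries a task_id, which is then polled until the task reports completed or failed. A minimal out-of-browser sketch of that protocol in Python; the base URL, the requests package, and the "path" field name for /upload_path are assumptions (the diff only shows the server reading base_dir):

import time
import requests  # assumption: any HTTP client works here

BASE_URL = "http://localhost:8000"  # assumption: wherever the Starlette app is reachable

def upload_path_and_wait(folder: str, interval: float = 5.0, max_attempts: int = 120) -> dict:
    """POST a folder for indexing, then poll /tasks/{task_id} until it finishes."""
    # assumption: the directory is sent as a "path" field
    resp = requests.post(f"{BASE_URL}/upload_path", json={"path": folder})
    if resp.status_code != 201:
        return resp.json()  # pre-existing direct-response flow (or an error payload)

    task_id = resp.json()["task_id"]
    for _ in range(max_attempts):
        task = requests.get(f"{BASE_URL}/tasks/{task_id}").json()
        if task["status"] in ("completed", "failed"):
            return task
        print(f'{task["processed_files"]}/{task["total_files"]} processed '
              f'({task["successful_files"]} ok, {task["failed_files"]} failed)')
        time.sleep(interval)  # the UI uses the same 5 s interval with a bounded attempt count
    raise TimeoutError(f"task {task_id} still running after {max_attempts} polls")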
src/app.py: 383 changes

@@ -3,11 +3,16 @@
 import os
 from collections import defaultdict
 from typing import Any
+import uuid
+import time
+import random
+from dataclasses import dataclass, field
+from enum import Enum
+from concurrent.futures import ProcessPoolExecutor
+import multiprocessing
 
 from agent import async_chat, async_langflow
 
-os.environ['USE_CPU_ONLY'] = 'true'
-
 import hashlib
 import tempfile
 import asyncio
@@ -29,6 +34,10 @@ from dotenv import load_dotenv
 load_dotenv()
 load_dotenv("../")
 
+import torch
+print("CUDA available:", torch.cuda.is_available())
+print("CUDA version PyTorch was built with:", torch.version.cuda)
+
 # Initialize Docling converter
 converter = DocumentConverter()  # basic converter; tweak via PipelineOptions if you need OCR, etc.
@@ -43,7 +52,7 @@ langflow_key = os.getenv("LANGFLOW_SECRET_KEY")
 
 
 
-es = AsyncOpenSearch(
+opensearch = AsyncOpenSearch(
     hosts=[{"host": opensearch_host, "port": opensearch_port}],
     connection_class=AIOHttpConnection,
     scheme="https",
@@ -93,6 +102,183 @@ langflow_client = AsyncOpenAI(
 )
 patched_async_client = patch_openai_with_mcp(AsyncOpenAI())  # Get the patched client back
 
+class TaskStatus(Enum):
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+@dataclass
+class FileTask:
+    file_path: str
+    status: TaskStatus = TaskStatus.PENDING
+    result: dict = None
+    error: str = None
+    retry_count: int = 0
+    created_at: float = field(default_factory=time.time)
+    updated_at: float = field(default_factory=time.time)
+
+@dataclass
+class UploadTask:
+    task_id: str
+    total_files: int
+    processed_files: int = 0
+    successful_files: int = 0
+    failed_files: int = 0
+    file_tasks: dict = field(default_factory=dict)
+    status: TaskStatus = TaskStatus.PENDING
+    created_at: float = field(default_factory=time.time)
+    updated_at: float = field(default_factory=time.time)
+
+task_store = {}
+background_tasks = set()
+
+# GPU device detection
+def detect_gpu_devices():
+    """Detect if GPU devices are actually available"""
+    try:
+        import torch
+        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+            return True, torch.cuda.device_count()
+    except ImportError:
+        pass
+
+    try:
+        import subprocess
+        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
+        if result.returncode == 0:
+            return True, "detected"
+    except (subprocess.SubprocessError, FileNotFoundError):
+        pass
+
+    return False, 0
+
+# GPU and concurrency configuration
+HAS_GPU_DEVICES, GPU_COUNT = detect_gpu_devices()
+
+if HAS_GPU_DEVICES:
+    # GPU mode with actual GPU devices: Lower concurrency due to memory constraints
+    DEFAULT_WORKERS = min(4, multiprocessing.cpu_count() // 2)
+    print(f"GPU mode enabled with {GPU_COUNT} GPU(s) - using limited concurrency ({DEFAULT_WORKERS} workers)")
+elif HAS_GPU_DEVICES:
+    # GPU mode requested but no devices found: Use full CPU concurrency
+    DEFAULT_WORKERS = multiprocessing.cpu_count()
+    print(f"GPU mode requested but no GPU devices found - falling back to full CPU concurrency ({DEFAULT_WORKERS} workers)")
+else:
+    # CPU mode: Higher concurrency since no GPU memory constraints
+    DEFAULT_WORKERS = multiprocessing.cpu_count()
+    print(f"CPU-only mode enabled - using full concurrency ({DEFAULT_WORKERS} workers)")
+
+MAX_WORKERS = int(os.getenv("MAX_WORKERS", DEFAULT_WORKERS))
+process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)
+
+print(f"Process pool initialized with {MAX_WORKERS} workers")
+
+# Global converter cache for worker processes
+_worker_converter = None
+
+def get_worker_converter():
+    """Get or create a DocumentConverter instance for this worker process"""
+    global _worker_converter
+    if _worker_converter is None:
+        import os
+        from docling.document_converter import DocumentConverter
+
+        # Configure GPU settings for this worker
+        has_gpu_devices, _ = detect_gpu_devices()
+        if not has_gpu_devices:
+            # Force CPU-only mode in subprocess
+            os.environ['USE_CPU_ONLY'] = 'true'
+            os.environ['CUDA_VISIBLE_DEVICES'] = ''
+            os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
+            os.environ['TRANSFORMERS_OFFLINE'] = '0'
+            os.environ['TORCH_USE_CUDA_DSA'] = '0'
+
+            # Try to disable CUDA in torch if available
+            try:
+                import torch
+                torch.cuda.is_available = lambda: False
+            except ImportError:
+                pass
+        else:
+            # GPU mode - let libraries use GPU if available
+            os.environ.pop('USE_CPU_ONLY', None)
+            os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'  # Still disable progress bars
+
+        print(f"🔧 Initializing DocumentConverter in worker process (PID: {os.getpid()})")
+        _worker_converter = DocumentConverter()
+        print(f"✅ DocumentConverter ready in worker process (PID: {os.getpid()})")
+
+    return _worker_converter
+
+def detect_gpu_devices():
+    """Detect if GPU devices are actually available"""
+    try:
+        import torch
+        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+            return True, torch.cuda.device_count()
+    except ImportError:
+        pass
+
+    try:
+        import subprocess
+        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
+        if result.returncode == 0:
+            return True, "detected"
+    except (subprocess.SubprocessError, FileNotFoundError):
+        pass
+
+    return False, 0
+
+def process_document_sync(file_path: str):
+    """Synchronous document processing function for multiprocessing"""
+    import hashlib
+    from collections import defaultdict
+
+    # Get the cached converter for this worker
+    converter = get_worker_converter()
+
+    # Compute file hash
+    sha256 = hashlib.sha256()
+    with open(file_path, "rb") as f:
+        while True:
+            chunk = f.read(1 << 20)
+            if not chunk:
+                break
+            sha256.update(chunk)
+    file_hash = sha256.hexdigest()
+
+    # Convert with docling
+    result = converter.convert(file_path)
+    full_doc = result.document.export_to_dict()
+
+    # Extract relevant content (same logic as extract_relevant)
+    origin = full_doc.get("origin", {})
+    texts = full_doc.get("texts", [])
+
+    page_texts = defaultdict(list)
+    for txt in texts:
+        prov = txt.get("prov", [])
+        page_no = prov[0].get("page_no") if prov else None
+        if page_no is not None:
+            page_texts[page_no].append(txt.get("text", "").strip())
+
+    chunks = []
+    for page in sorted(page_texts):
+        joined = "\n".join(page_texts[page])
+        chunks.append({
+            "page": page,
+            "text": joined
+        })
+
+    return {
+        "id": file_hash,
+        "filename": origin.get("filename"),
+        "mimetype": origin.get("mimetype"),
+        "chunks": chunks,
+        "file_path": file_path
+    }
+
 async def wait_for_opensearch():
     """Wait for OpenSearch to be ready with retries"""
     max_retries = 30
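The _worker_converter global above relies on each ProcessPoolExecutor worker having its own copy of the module's globals, so the Docling models are loaded once per worker process rather than once per document. A toy illustration of that caching pattern (a standalone sketch, not project code):

import os
from concurrent.futures import ProcessPoolExecutor

_expensive = None  # stands in for the cached DocumentConverter

def _get_expensive():
    global _expensive
    if _expensive is None:
        print(f"building expensive object once in PID {os.getpid()}")
        _expensive = object()
    return _expensive

def work(i: int) -> int:
    _get_expensive()  # built on first use in each worker, reused afterwards
    return i * i

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as pool:
        print(list(pool.map(work, range(8))))  # expect at most two "building" lines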
@@ -100,7 +286,7 @@ async def wait_for_opensearch():
 
     for attempt in range(max_retries):
         try:
-            await es.info()
+            await opensearch.info()
             print("OpenSearch is ready!")
             return
         except Exception as e:
@@ -113,8 +299,8 @@ async def wait_for_opensearch():
 async def init_index():
     await wait_for_opensearch()
 
-    if not await es.indices.exists(index=INDEX_NAME):
-        await es.indices.create(index=INDEX_NAME, body=index_body)
+    if not await opensearch.indices.exists(index=INDEX_NAME):
+        await opensearch.indices.create(index=INDEX_NAME, body=index_body)
         print(f"Created index '{INDEX_NAME}'")
     else:
         print(f"Index '{INDEX_NAME}' already exists, skipping creation.")
@@ -155,6 +341,26 @@ def extract_relevant(doc_dict: dict) -> dict:
         "chunks": chunks
     }
 
+async def exponential_backoff_delay(retry_count: int, base_delay: float = 1.0, max_delay: float = 60.0) -> None:
+    """Apply exponential backoff with jitter"""
+    delay = min(base_delay * (2 ** retry_count) + random.uniform(0, 1), max_delay)
+    await asyncio.sleep(delay)
+
+async def process_file_with_retry(file_path: str, max_retries: int = 3) -> dict:
+    """Process a file with retry logic - retries everything up to max_retries times"""
+    last_error = None
+
+    for attempt in range(max_retries + 1):
+        try:
+            return await process_file_common(file_path)
+        except Exception as e:
+            last_error = e
+            if attempt < max_retries:
+                await exponential_backoff_delay(attempt)
+                continue
+            else:
+                raise last_error
+
 async def process_file_common(file_path: str, file_hash: str = None):
     """
     Common processing logic for both upload and upload_path.
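For reference, the deterministic part of the backoff schedule above grows as base_delay · 2^retry and is capped at max_delay, with up to one extra second of uniform jitter on each wait:

# Worked example of exponential_backoff_delay's schedule (jitter omitted for clarity)
base_delay, max_delay = 1.0, 60.0
for retry in range(7):
    print(retry, min(base_delay * 2 ** retry, max_delay))
# -> 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 60.0 (capped); the real delay adds random.uniform(0, 1)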
@@ -173,7 +379,7 @@ async def process_file_common(file_path: str, file_hash: str = None):
             sha256.update(chunk)
         file_hash = sha256.hexdigest()
 
-    exists = await es.exists(index=INDEX_NAME, id=file_hash)
+    exists = await opensearch.exists(index=INDEX_NAME, id=file_hash)
     if exists:
         return {"status": "unchanged", "id": file_hash}
@@ -199,7 +405,7 @@ async def process_file_common(file_path: str, file_hash: str = None):
             "chunk_embedding": vect
         }
         chunk_id = f"{file_hash}_{i}"
-        await es.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
+        await opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
     return {"status": "indexed", "id": file_hash}
 
 async def process_file_on_disk(path: str):
@@ -210,6 +416,94 @@ async def process_file_on_disk(path: str):
     result["path"] = path
     return result
 
+async def process_single_file_task(upload_task: UploadTask, file_path: str) -> None:
+    """Process a single file and update task tracking"""
+    file_task = upload_task.file_tasks[file_path]
+    file_task.status = TaskStatus.RUNNING
+    file_task.updated_at = time.time()
+
+    try:
+        # Check if file already exists in index
+        import asyncio
+        loop = asyncio.get_event_loop()
+
+        # Run CPU-intensive docling processing in separate process
+        slim_doc = await loop.run_in_executor(process_pool, process_document_sync, file_path)
+
+        # Check if already indexed
+        exists = await opensearch.exists(index=INDEX_NAME, id=slim_doc["id"])
+        if exists:
+            result = {"status": "unchanged", "id": slim_doc["id"]}
+        else:
+            # Generate embeddings and index (I/O bound, keep in main process)
+            texts = [c["text"] for c in slim_doc["chunks"]]
+            resp = await patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
+            embeddings = [d.embedding for d in resp.data]
+
+            # Index each chunk
+            for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
+                chunk_doc = {
+                    "document_id": slim_doc["id"],
+                    "filename": slim_doc["filename"],
+                    "mimetype": slim_doc["mimetype"],
+                    "page": chunk["page"],
+                    "text": chunk["text"],
+                    "chunk_embedding": vect
+                }
+                chunk_id = f"{slim_doc['id']}_{i}"
+                await opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
+
+            result = {"status": "indexed", "id": slim_doc["id"]}
+
+        result["path"] = file_path
+        file_task.status = TaskStatus.COMPLETED
+        file_task.result = result
+        upload_task.successful_files += 1
+
+    except Exception as e:
+        print(f"[ERROR] Failed to process file {file_path}: {e}")
+        import traceback
+        traceback.print_exc()
+        file_task.status = TaskStatus.FAILED
+        file_task.error = str(e)
+        upload_task.failed_files += 1
+    finally:
+        file_task.updated_at = time.time()
+        upload_task.processed_files += 1
+        upload_task.updated_at = time.time()
+
+        if upload_task.processed_files >= upload_task.total_files:
+            upload_task.status = TaskStatus.COMPLETED
+
+async def background_upload_processor(task_id: str) -> None:
+    """Background task to process all files in an upload job with concurrency control"""
+    try:
+        upload_task = task_store[task_id]
+        upload_task.status = TaskStatus.RUNNING
+        upload_task.updated_at = time.time()
+
+        # Process files with limited concurrency to avoid overwhelming the system
+        semaphore = asyncio.Semaphore(MAX_WORKERS * 2)  # Allow 2x process pool size for async I/O
+
+        async def process_with_semaphore(file_path: str):
+            async with semaphore:
+                await process_single_file_task(upload_task, file_path)
+
+        tasks = [
+            process_with_semaphore(file_path)
+            for file_path in upload_task.file_tasks.keys()
+        ]
+
+        await asyncio.gather(*tasks, return_exceptions=True)
+
+    except Exception as e:
+        print(f"[ERROR] Background upload processor failed for task {task_id}: {e}")
+        import traceback
+        traceback.print_exc()
+        if task_id in task_store:
+            task_store[task_id].status = TaskStatus.FAILED
+            task_store[task_id].updated_at = time.time()
+
 async def upload(request: Request):
     form = await request.form()
     upload_file = form["file"]
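background_upload_processor combines two standard asyncio patterns: a Semaphore caps how many files are in flight at once (2× the process-pool size, so the pool stays busy while embedding and indexing I/O overlap), and gather(..., return_exceptions=True) keeps one failing file from cancelling its siblings. A self-contained toy version of the same shape:

import asyncio

async def handle(sem: asyncio.Semaphore, name: str) -> str:
    async with sem:
        await asyncio.sleep(0.1)  # stands in for docling + embedding + indexing work
        if name == "bad.pdf":
            raise RuntimeError("boom")
        return f"{name}: indexed"

async def main() -> None:
    sem = asyncio.Semaphore(2)  # the real code sizes this as MAX_WORKERS * 2
    results = await asyncio.gather(
        *(handle(sem, n) for n in ["a.pdf", "b.pdf", "bad.pdf", "c.pdf"]),
        return_exceptions=True,
    )
    print(results)  # the RuntimeError shows up in-place instead of cancelling the others

asyncio.run(main())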
@@ -226,7 +520,7 @@ async def upload(request: Request):
         tmp.flush()
 
     file_hash = sha256.hexdigest()
-    exists = await es.exists(index=INDEX_NAME, id=file_hash)
+    exists = await opensearch.exists(index=INDEX_NAME, id=file_hash)
     if exists:
         return JSONResponse({"status": "unchanged", "id": file_hash})
@@ -243,12 +537,31 @@ async def upload_path(request: Request):
     if not base_dir or not os.path.isdir(base_dir):
         return JSONResponse({"error": "Invalid path"}, status_code=400)
 
-    tasks = [process_file_on_disk(os.path.join(root, fn))
+    file_paths = [os.path.join(root, fn)
              for root, _, files in os.walk(base_dir)
              for fn in files]
 
-    results = await asyncio.gather(*tasks)
-    return JSONResponse({"results": results})
+    if not file_paths:
+        return JSONResponse({"error": "No files found in directory"}, status_code=400)
+
+    task_id = str(uuid.uuid4())
+    upload_task = UploadTask(
+        task_id=task_id,
+        total_files=len(file_paths),
+        file_tasks={path: FileTask(file_path=path) for path in file_paths}
+    )
+
+    task_store[task_id] = upload_task
+
+    background_task = asyncio.create_task(background_upload_processor(task_id))
+    background_tasks.add(background_task)
+    background_task.add_done_callback(background_tasks.discard)
+
+    return JSONResponse({
+        "task_id": task_id,
+        "total_files": len(file_paths),
+        "status": "accepted"
+    }, status_code=201)
 
 async def upload_context(request: Request):
     """Upload a file and add its content as context to the current conversation"""
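The create_task / background_tasks.add / add_done_callback(discard) trio in upload_path is the usual fire-and-forget idiom: the set keeps a strong reference so the in-flight task cannot be garbage-collected before it finishes, and the callback drops the reference when it completes. A standalone toy version of the idiom:

import asyncio

background = set()

async def job(n: int) -> None:
    await asyncio.sleep(0.1)
    print(f"job {n} done")

async def main() -> None:
    for n in range(3):
        t = asyncio.create_task(job(n))
        background.add(t)                        # strong reference while running
        t.add_done_callback(background.discard)  # self-cleanup on completion
    await asyncio.sleep(0.5)  # give the background jobs time to finish

asyncio.run(main())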
@@ -306,6 +619,38 @@ async def upload_context(request: Request):
 
     return JSONResponse(response_data)
 
+async def task_status(request: Request):
+    """Get the status of an upload task"""
+    task_id = request.path_params.get("task_id")
+
+    if not task_id or task_id not in task_store:
+        return JSONResponse({"error": "Task not found"}, status_code=404)
+
+    upload_task = task_store[task_id]
+
+    file_statuses = {}
+    for file_path, file_task in upload_task.file_tasks.items():
+        file_statuses[file_path] = {
+            "status": file_task.status.value,
+            "result": file_task.result,
+            "error": file_task.error,
+            "retry_count": file_task.retry_count,
+            "created_at": file_task.created_at,
+            "updated_at": file_task.updated_at
+        }
+
+    return JSONResponse({
+        "task_id": upload_task.task_id,
+        "status": upload_task.status.value,
+        "total_files": upload_task.total_files,
+        "processed_files": upload_task.processed_files,
+        "successful_files": upload_task.successful_files,
+        "failed_files": upload_task.failed_files,
+        "created_at": upload_task.created_at,
+        "updated_at": upload_task.updated_at,
+        "files": file_statuses
+    })
+
 async def search(request: Request):
 
     payload = await request.json()
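Putting the two halves together, a GET /tasks/{task_id} response mid-run looks roughly like the following; the field names come from task_status above and from what the admin and chat pages read, while the concrete values are purely illustrative:

# Illustrative /tasks/{task_id} payload (values are made up)
example_task = {
    "task_id": "6f0c2c9e-...",   # uuid4 issued by /upload_path
    "status": "running",          # pending | running | completed | failed
    "total_files": 12,
    "processed_files": 5,
    "successful_files": 4,
    "failed_files": 1,
    "created_at": 1723380000.0,   # time.time() timestamps
    "updated_at": 1723380042.5,
    "files": {
        "/app/documents/report.pdf": {
            "status": "completed",
            "result": {"status": "indexed", "id": "<sha256>", "path": "/app/documents/report.pdf"},
            "error": None,
            "retry_count": 0,
            "created_at": 1723380000.0,
            "updated_at": 1723380021.3,
        },
    },
}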
@@ -345,7 +690,7 @@ async def search_tool(query: str)-> dict[str, Any]:
         "_source": ["filename", "mimetype", "page", "text"],
         "size": 10
     }
-    results = await es.search(index=INDEX_NAME, body=search_body)
+    results = await opensearch.search(index=INDEX_NAME, body=search_body)
     # Transform results to match expected format
     chunks = []
     for hit in results["hits"]["hits"]:
@@ -425,6 +770,7 @@ app = Starlette(debug=True, routes=[
     Route("/upload", upload, methods=["POST"]),
     Route("/upload_context", upload_context, methods=["POST"]),
     Route("/upload_path", upload_path, methods=["POST"]),
+    Route("/tasks/{task_id}", task_status, methods=["GET"]),
     Route("/search", search, methods=["POST"]),
     Route("/chat", chat_endpoint, methods=["POST"]),
     Route("/langflow", langflow_endpoint, methods=["POST"]),
@@ -432,10 +778,17 @@ app = Starlette(debug=True, routes=[
 
 if __name__ == "__main__":
     import uvicorn
+    import atexit
 
     async def main():
         await init_index()
 
+    # Cleanup process pool on exit
+    def cleanup():
+        process_pool.shutdown(wait=True)
+
+    atexit.register(cleanup)
+
     asyncio.run(main())
     uvicorn.run(
         "app:app",