# app.py
import datetime
import os
from collections import defaultdict
from typing import Any
import uuid
import time
import random
from dataclasses import dataclass, field
from enum import Enum
from concurrent.futures import ProcessPoolExecutor
import multiprocessing

from agent import async_chat, async_langflow

# Import connector components
from connectors.service import ConnectorService
from connectors.google_drive import GoogleDriveConnector
from session_manager import SessionManager
from auth_middleware import require_auth, optional_auth

import hashlib
import tempfile
import asyncio

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse, StreamingResponse
from starlette.routing import Route

import aiofiles
from opensearchpy import AsyncOpenSearch
from opensearchpy._async.http_aiohttp import AIOHttpConnection
from docling.document_converter import DocumentConverter
from agentd.patch import patch_openai_with_mcp
from openai import AsyncOpenAI
from agentd.tool_decorator import tool
from dotenv import load_dotenv

load_dotenv()
load_dotenv("../.env")  # also pick up a .env in the parent directory, if present

import torch
print("CUDA available:", torch.cuda.is_available())
print("CUDA version PyTorch was built with:", torch.version.cuda)

# Initialize Docling converter
converter = DocumentConverter()  # basic converter; tweak via PipelineOptions if you need OCR, etc.

# Initialize Async OpenSearch (adjust hosts/auth as needed)
opensearch_host = os.getenv("OPENSEARCH_HOST", "localhost")
opensearch_port = int(os.getenv("OPENSEARCH_PORT", "9200"))
opensearch_username = os.getenv("OPENSEARCH_USERNAME", "admin")
opensearch_password = os.getenv("OPENSEARCH_PASSWORD")
langflow_url = os.getenv("LANGFLOW_URL", "http://localhost:7860")
flow_id = os.getenv("FLOW_ID")
langflow_key = os.getenv("LANGFLOW_SECRET_KEY")

opensearch = AsyncOpenSearch(
    hosts=[{"host": opensearch_host, "port": opensearch_port}],
    connection_class=AIOHttpConnection,
    scheme="https",
    use_ssl=True,
    verify_certs=False,
    ssl_assert_fingerprint=None,
    http_auth=(opensearch_username, opensearch_password),
    http_compress=True,
)

INDEX_NAME = "documents"
VECTOR_DIM = 1536  # e.g. text-embedding-3-small output size
EMBED_MODEL = "text-embedding-3-small"
index_body = {
    "settings": {
        "index": {"knn": True},
        "number_of_shards": 1,
        "number_of_replicas": 1
    },
    "mappings": {
        "properties": {
            "document_id": {"type": "keyword"},
            "filename": {"type": "keyword"},
            "mimetype": {"type": "keyword"},
            "page": {"type": "integer"},
            "text": {"type": "text"},
            "chunk_embedding": {
                "type": "knn_vector",
                "dimension": VECTOR_DIM,
                "method": {
                    "name": "disk_ann",
                    "engine": "jvector",
                    "space_type": "l2",
                    "parameters": {
                        "ef_construction": 100,
                        "m": 16
                    }
                }
            },
            # Connector and source information
            "source_url": {"type": "keyword"},
            "connector_type": {"type": "keyword"},
            # ACL fields
            "owner": {"type": "keyword"},
            "allowed_users": {"type": "keyword"},
            "allowed_groups": {"type": "keyword"},
            "user_permissions": {"type": "object"},
            "group_permissions": {"type": "object"},
            # Timestamps
            "created_time": {"type": "date"},
            "modified_time": {"type": "date"},
            "indexed_time": {"type": "date"},
            # Additional metadata
            "metadata": {"type": "object"}
        }
    }
}

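# Illustrative shape of a single indexed chunk document (field names follow the
# mapping above and the indexing code in process_file_common(); the values are
# made-up examples, not real data):
#
#     {
#         "document_id": "3a7bd3e2360a3d...",        # SHA-256 of the source file
#         "filename": "report.pdf",
#         "mimetype": "application/pdf",
#         "page": 3,
#         "text": "Flattened page text or tab-separated table rows...",
#         "chunk_embedding": [0.012, -0.087, ...],   # VECTOR_DIM floats
#         "owner": "google-user-id",                 # set for user uploads
#         "indexed_time": "2024-01-01T12:00:00"
#     }
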
langflow_client = AsyncOpenAI(
    base_url=f"{langflow_url}/api/v1",
    api_key=langflow_key
)
patched_async_client = patch_openai_with_mcp(AsyncOpenAI())  # Get the patched client back

# Initialize connector service
connector_service = ConnectorService(
    opensearch_client=opensearch,
    patched_async_client=patched_async_client,
    process_pool=None,  # Will be set after process_pool is initialized
    embed_model=EMBED_MODEL,
    index_name=INDEX_NAME
)

# Initialize session manager
session_secret = os.getenv("SESSION_SECRET", "your-secret-key-change-in-production")
session_manager = SessionManager(session_secret)

# Track used authorization codes to prevent duplicate usage
used_auth_codes = set()

class TaskStatus(Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class FileTask:
    file_path: str
    status: TaskStatus = TaskStatus.PENDING
    result: dict = None
    error: str = None
    retry_count: int = 0
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)


@dataclass
class UploadTask:
    task_id: str
    total_files: int
    processed_files: int = 0
    successful_files: int = 0
    failed_files: int = 0
    file_tasks: dict = field(default_factory=dict)
    status: TaskStatus = TaskStatus.PENDING
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)


task_store = {}  # user_id -> {task_id -> UploadTask}
background_tasks = set()

# GPU device detection
def detect_gpu_devices():
    """Detect if GPU devices are actually available"""
    try:
        import torch
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            return True, torch.cuda.device_count()
    except ImportError:
        pass

    try:
        import subprocess
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
        if result.returncode == 0:
            return True, "detected"
    except (subprocess.SubprocessError, FileNotFoundError):
        pass

    return False, 0

# GPU and concurrency configuration
HAS_GPU_DEVICES, GPU_COUNT = detect_gpu_devices()

if HAS_GPU_DEVICES:
    # GPU mode with actual GPU devices: lower concurrency due to memory constraints
    DEFAULT_WORKERS = max(1, min(4, multiprocessing.cpu_count() // 2))
    print(f"GPU mode enabled with {GPU_COUNT} GPU(s) - using limited concurrency ({DEFAULT_WORKERS} workers)")
else:
    # No GPU devices found: use full CPU concurrency since there are no GPU memory constraints
    DEFAULT_WORKERS = multiprocessing.cpu_count()
    print(f"CPU-only mode enabled - using full concurrency ({DEFAULT_WORKERS} workers)")

MAX_WORKERS = int(os.getenv("MAX_WORKERS", DEFAULT_WORKERS))
process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)
connector_service.process_pool = process_pool  # Set the process pool for connector service

print(f"Process pool initialized with {MAX_WORKERS} workers")

# Global converter cache for worker processes
_worker_converter = None


def get_worker_converter():
    """Get or create a DocumentConverter instance for this worker process"""
    global _worker_converter
    if _worker_converter is None:
        import os
        from docling.document_converter import DocumentConverter

        # Configure GPU settings for this worker
        has_gpu_devices, _ = detect_gpu_devices()
        if not has_gpu_devices:
            # Force CPU-only mode in subprocess
            os.environ['USE_CPU_ONLY'] = 'true'
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
            os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
            os.environ['TRANSFORMERS_OFFLINE'] = '0'
            os.environ['TORCH_USE_CUDA_DSA'] = '0'

            # Try to disable CUDA in torch if available
            try:
                import torch
                torch.cuda.is_available = lambda: False
            except ImportError:
                pass
        else:
            # GPU mode - let libraries use GPU if available
            os.environ.pop('USE_CPU_ONLY', None)
            os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'  # Still disable progress bars

        print(f"🔧 Initializing DocumentConverter in worker process (PID: {os.getpid()})")
        _worker_converter = DocumentConverter()
        print(f"✅ DocumentConverter ready in worker process (PID: {os.getpid()})")

    return _worker_converter

def process_document_sync(file_path: str):
    """Synchronous document processing function for multiprocessing"""
    import hashlib
    from collections import defaultdict

    # Get the cached converter for this worker
    converter = get_worker_converter()

    # Compute file hash
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        while True:
            chunk = f.read(1 << 20)
            if not chunk:
                break
            sha256.update(chunk)
    file_hash = sha256.hexdigest()

    # Convert with docling
    result = converter.convert(file_path)
    full_doc = result.document.export_to_dict()

    # Extract relevant content (same logic as extract_relevant)
    origin = full_doc.get("origin", {})
    texts = full_doc.get("texts", [])

    page_texts = defaultdict(list)
    for txt in texts:
        prov = txt.get("prov", [])
        page_no = prov[0].get("page_no") if prov else None
        if page_no is not None:
            page_texts[page_no].append(txt.get("text", "").strip())

    chunks = []
    for page in sorted(page_texts):
        joined = "\n".join(page_texts[page])
        chunks.append({
            "page": page,
            "text": joined
        })

    return {
        "id": file_hash,
        "filename": origin.get("filename"),
        "mimetype": origin.get("mimetype"),
        "chunks": chunks,
        "file_path": file_path
    }

async def wait_for_opensearch():
    """Wait for OpenSearch to be ready with retries"""
    max_retries = 30
    retry_delay = 2

    for attempt in range(max_retries):
        try:
            await opensearch.info()
            print("OpenSearch is ready!")
            return
        except Exception as e:
            print(f"Attempt {attempt + 1}/{max_retries}: OpenSearch not ready yet ({e})")
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
            else:
                raise Exception("OpenSearch failed to become ready")

async def init_index():
    await wait_for_opensearch()

    if not await opensearch.indices.exists(index=INDEX_NAME):
        await opensearch.indices.create(index=INDEX_NAME, body=index_body)
        print(f"Created index '{INDEX_NAME}'")
    else:
        print(f"Index '{INDEX_NAME}' already exists, skipping creation.")

def extract_relevant(doc_dict: dict) -> dict:
    """
    Given the full export_to_dict() result:
      - Grabs origin metadata (hash, filename, mimetype)
      - Finds every text fragment in `texts`, groups them by page_no
      - Flattens tables in `tables` into tab-separated text, grouping by row
      - Concatenates each page's fragments and each table into its own chunk

    Returns a slimmed dict ready for indexing, with each chunk under "text".
    """
    origin = doc_dict.get("origin", {})
    chunks = []

    # 1) process free-text fragments
    page_texts = defaultdict(list)
    for txt in doc_dict.get("texts", []):
        prov = txt.get("prov", [])
        page_no = prov[0].get("page_no") if prov else None
        if page_no is not None:
            page_texts[page_no].append(txt.get("text", "").strip())

    for page in sorted(page_texts):
        chunks.append({
            "page": page,
            "type": "text",
            "text": "\n".join(page_texts[page])
        })

    # 2) process tables
    for t_idx, table in enumerate(doc_dict.get("tables", [])):
        prov = table.get("prov", [])
        page_no = prov[0].get("page_no") if prov else None

        # group cells by their row index
        rows = defaultdict(list)
        for cell in table.get("data", {}).get("table_cells", []):
            r = cell.get("start_row_offset_idx")
            c = cell.get("start_col_offset_idx")
            text = cell.get("text", "").strip()
            rows[r].append((c, text))

        # build a tab-separated line for each row, in order
        flat_rows = []
        for r in sorted(rows):
            cells = [txt for _, txt in sorted(rows[r], key=lambda x: x[0])]
            flat_rows.append("\t".join(cells))

        chunks.append({
            "page": page_no,
            "type": "table",
            "table_index": t_idx,
            "text": "\n".join(flat_rows)
        })

    return {
        "id": origin.get("binary_hash"),
        "filename": origin.get("filename"),
        "mimetype": origin.get("mimetype"),
        "chunks": chunks
    }

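# Illustrative output of extract_relevant() for a two-page PDF with one table
# (made-up values, shown only to document the shape consumed by the indexing
# code below):
#
#     {
#         "id": 1234567890,                     # docling's origin.binary_hash
#         "filename": "report.pdf",
#         "mimetype": "application/pdf",
#         "chunks": [
#             {"page": 1, "type": "text", "text": "Intro paragraph..."},
#             {"page": 2, "type": "text", "text": "More body text..."},
#             {"page": 2, "type": "table", "table_index": 0, "text": "Name\tQty\nFoo\t3"},
#         ],
#     }
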
async def exponential_backoff_delay(retry_count: int, base_delay: float = 1.0, max_delay: float = 60.0) -> None:
    """Apply exponential backoff with jitter"""
    delay = min(base_delay * (2 ** retry_count) + random.uniform(0, 1), max_delay)
    await asyncio.sleep(delay)

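# With the defaults above the sleep grows roughly as 1-2s, 2-3s, 4-5s, 8-9s, ...
# (base_delay * 2**retry_count plus up to 1s of jitter), capped at max_delay=60s.
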
async def process_file_with_retry(file_path: str, max_retries: int = 3) -> dict:
    """Process a file with retry logic - retries everything up to max_retries times"""
    last_error = None

    for attempt in range(max_retries + 1):
        try:
            return await process_file_common(file_path)
        except Exception as e:
            last_error = e
            if attempt < max_retries:
                await exponential_backoff_delay(attempt)
                continue
            else:
                raise last_error

async def process_file_common(file_path: str, file_hash: str = None, owner_user_id: str = None):
    """
    Common processing logic for both upload and upload_path.
    1. Optionally compute SHA256 hash if not provided.
    2. Convert with docling and extract relevant content.
    3. Add embeddings.
    4. Index into OpenSearch.
    """
    if file_hash is None:
        sha256 = hashlib.sha256()
        async with aiofiles.open(file_path, "rb") as f:
            while True:
                chunk = await f.read(1 << 20)
                if not chunk:
                    break
                sha256.update(chunk)
        file_hash = sha256.hexdigest()

    # Chunks are stored under "{file_hash}_{i}", so check the first chunk id
    # to detect a previously indexed file.
    exists = await opensearch.exists(index=INDEX_NAME, id=f"{file_hash}_0")
    if exists:
        return {"status": "unchanged", "id": file_hash}

    # convert and extract
    # TODO: Check if docling can handle in-memory bytes instead of file path
    # This would eliminate the need for temp files in upload flow
    result = converter.convert(file_path)
    full_doc = result.document.export_to_dict()
    slim_doc = extract_relevant(full_doc)

    texts = [c["text"] for c in slim_doc["chunks"]]
    resp = await patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
    embeddings = [d.embedding for d in resp.data]

    # Index each chunk as a separate document
    for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
        chunk_doc = {
            "document_id": file_hash,
            "filename": slim_doc["filename"],
            "mimetype": slim_doc["mimetype"],
            "page": chunk["page"],
            "text": chunk["text"],
            "chunk_embedding": vect,
            "owner": owner_user_id,  # User who uploaded/owns this document
            "indexed_time": datetime.datetime.now().isoformat()
        }
        chunk_id = f"{file_hash}_{i}"
        await opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)
    return {"status": "indexed", "id": file_hash}

async def process_file_on_disk(path: str):
    """
    Process a file already on disk.
    """
    result = await process_file_common(path)
    result["path"] = path
    return result

async def process_single_file_task(upload_task: UploadTask, file_path: str) -> None:
    """Process a single file and update task tracking"""
    file_task = upload_task.file_tasks[file_path]
    file_task.status = TaskStatus.RUNNING
    file_task.updated_at = time.time()

    try:
        loop = asyncio.get_event_loop()

        # Run CPU-intensive docling processing in separate process
        slim_doc = await loop.run_in_executor(process_pool, process_document_sync, file_path)

        # Check if already indexed (chunks are stored under "{id}_{i}")
        exists = await opensearch.exists(index=INDEX_NAME, id=f"{slim_doc['id']}_0")
        if exists:
            result = {"status": "unchanged", "id": slim_doc["id"]}
        else:
            # Generate embeddings and index (I/O bound, keep in main process)
            texts = [c["text"] for c in slim_doc["chunks"]]
            resp = await patched_async_client.embeddings.create(model=EMBED_MODEL, input=texts)
            embeddings = [d.embedding for d in resp.data]

            # Index each chunk
            for i, (chunk, vect) in enumerate(zip(slim_doc["chunks"], embeddings)):
                chunk_doc = {
                    "document_id": slim_doc["id"],
                    "filename": slim_doc["filename"],
                    "mimetype": slim_doc["mimetype"],
                    "page": chunk["page"],
                    "text": chunk["text"],
                    "chunk_embedding": vect
                }
                chunk_id = f"{slim_doc['id']}_{i}"
                await opensearch.index(index=INDEX_NAME, id=chunk_id, body=chunk_doc)

            result = {"status": "indexed", "id": slim_doc["id"]}

        result["path"] = file_path
        file_task.status = TaskStatus.COMPLETED
        file_task.result = result
        upload_task.successful_files += 1

    except Exception as e:
        print(f"[ERROR] Failed to process file {file_path}: {e}")
        import traceback
        traceback.print_exc()
        file_task.status = TaskStatus.FAILED
        file_task.error = str(e)
        upload_task.failed_files += 1
    finally:
        file_task.updated_at = time.time()
        upload_task.processed_files += 1
        upload_task.updated_at = time.time()

        if upload_task.processed_files >= upload_task.total_files:
            upload_task.status = TaskStatus.COMPLETED

async def background_upload_processor(user_id: str, task_id: str) -> None:
    """Background task to process all files in an upload job with concurrency control"""
    try:
        upload_task = task_store[user_id][task_id]
        upload_task.status = TaskStatus.RUNNING
        upload_task.updated_at = time.time()

        # Process files with limited concurrency to avoid overwhelming the system
        semaphore = asyncio.Semaphore(MAX_WORKERS * 2)  # Allow 2x process pool size for async I/O

        async def process_with_semaphore(file_path: str):
            async with semaphore:
                await process_single_file_task(upload_task, file_path)

        tasks = [
            process_with_semaphore(file_path)
            for file_path in upload_task.file_tasks.keys()
        ]

        await asyncio.gather(*tasks, return_exceptions=True)

    except Exception as e:
        print(f"[ERROR] Background upload processor failed for task {task_id}: {e}")
        import traceback
        traceback.print_exc()
        if user_id in task_store and task_id in task_store[user_id]:
            task_store[user_id][task_id].status = TaskStatus.FAILED
            task_store[user_id][task_id].updated_at = time.time()

@require_auth(session_manager)
async def upload(request: Request):
    form = await request.form()
    upload_file = form["file"]

    sha256 = hashlib.sha256()
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        while True:
            chunk = await upload_file.read(1 << 20)
            if not chunk:
                break
            sha256.update(chunk)
            tmp.write(chunk)
        tmp.flush()

        file_hash = sha256.hexdigest()
        exists = await opensearch.exists(index=INDEX_NAME, id=f"{file_hash}_0")  # first chunk id marks an already-indexed file
        if exists:
            return JSONResponse({"status": "unchanged", "id": file_hash})

        user = request.state.user
        result = await process_file_common(tmp.name, file_hash, owner_user_id=user.user_id)
        return JSONResponse(result)

    finally:
        tmp.close()
        os.remove(tmp.name)

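# Illustrative usage (assumes the server from the __main__ block below on
# http://localhost:8000 and a valid "auth_token" cookie from /auth/callback):
#
#     POST /upload            multipart/form-data with a "file" field
#     -> {"status": "indexed", "id": "<sha256 of the file>"}
#        or {"status": "unchanged", "id": "..."} if it was indexed before
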
@require_auth(session_manager)
async def upload_path(request: Request):
    payload = await request.json()
    base_dir = payload.get("path")
    if not base_dir or not os.path.isdir(base_dir):
        return JSONResponse({"error": "Invalid path"}, status_code=400)

    file_paths = [os.path.join(root, fn)
                  for root, _, files in os.walk(base_dir)
                  for fn in files]

    if not file_paths:
        return JSONResponse({"error": "No files found in directory"}, status_code=400)

    task_id = str(uuid.uuid4())
    upload_task = UploadTask(
        task_id=task_id,
        total_files=len(file_paths),
        file_tasks={path: FileTask(file_path=path) for path in file_paths}
    )

    user = request.state.user
    if user.user_id not in task_store:
        task_store[user.user_id] = {}
    task_store[user.user_id][task_id] = upload_task

    background_task = asyncio.create_task(background_upload_processor(user.user_id, task_id))
    background_tasks.add(background_task)
    background_task.add_done_callback(background_tasks.discard)

    return JSONResponse({
        "task_id": task_id,
        "total_files": len(file_paths),
        "status": "accepted"
    }, status_code=201)

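# Illustrative request/response for the directory upload endpoint (the path is
# an example; it must exist on the server's filesystem):
#
#     POST /upload_path       {"path": "/data/docs"}
#     -> 201 {"task_id": "<uuid>", "total_files": 42, "status": "accepted"}
#
# Progress can then be polled via GET /tasks/<task_id>.
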
@require_auth(session_manager)
async def upload_context(request: Request):
    """Upload a file and add its content as context to the current conversation"""
    import io
    from docling_core.types.io import DocumentStream

    form = await request.form()
    upload_file = form["file"]
    filename = upload_file.filename or "uploaded_document"

    # Get optional parameters
    previous_response_id = form.get("previous_response_id")
    endpoint = form.get("endpoint", "langflow")  # default to langflow

    # Stream file content into BytesIO
    content = io.BytesIO()
    while True:
        chunk = await upload_file.read(1 << 20)  # 1MB chunks
        if not chunk:
            break
        content.write(chunk)
    content.seek(0)  # Reset to beginning for reading

    # Create DocumentStream and process with docling
    doc_stream = DocumentStream(name=filename, stream=content)
    result = converter.convert(doc_stream)
    full_doc = result.document.export_to_dict()
    slim_doc = extract_relevant(full_doc)

    # Extract all text content
    all_text = []
    for chunk in slim_doc["chunks"]:
        all_text.append(f"Page {chunk['page']}:\n{chunk['text']}")

    full_content = "\n\n".join(all_text)

    # Send document content as user message to get proper response_id
    document_prompt = f"I'm uploading a document called '{filename}'. Here is its content:\n\n{full_content}\n\nPlease confirm you've received this document and are ready to answer questions about it."

    if endpoint == "langflow":
        response_text, response_id = await async_langflow(langflow_client, flow_id, document_prompt, previous_response_id=previous_response_id)
    else:  # chat
        response_text, response_id = await async_chat(patched_async_client, document_prompt, previous_response_id=previous_response_id)

    response_data = {
        "status": "context_added",
        "filename": filename,
        "pages": len(slim_doc["chunks"]),
        "content_length": len(full_content),
        "response_id": response_id,
        "confirmation": response_text
    }

    return JSONResponse(response_data)

@require_auth(session_manager)
async def task_status(request: Request):
    """Get the status of an upload task"""
    task_id = request.path_params.get("task_id")

    user = request.state.user

    if (not task_id or
            user.user_id not in task_store or
            task_id not in task_store[user.user_id]):
        return JSONResponse({"error": "Task not found"}, status_code=404)

    upload_task = task_store[user.user_id][task_id]

    file_statuses = {}
    for file_path, file_task in upload_task.file_tasks.items():
        file_statuses[file_path] = {
            "status": file_task.status.value,
            "result": file_task.result,
            "error": file_task.error,
            "retry_count": file_task.retry_count,
            "created_at": file_task.created_at,
            "updated_at": file_task.updated_at
        }

    return JSONResponse({
        "task_id": upload_task.task_id,
        "status": upload_task.status.value,
        "total_files": upload_task.total_files,
        "processed_files": upload_task.processed_files,
        "successful_files": upload_task.successful_files,
        "failed_files": upload_task.failed_files,
        "created_at": upload_task.created_at,
        "updated_at": upload_task.updated_at,
        "files": file_statuses
    })

@require_auth(session_manager)
async def search(request: Request):
    payload = await request.json()
    query = payload.get("query")
    if not query:
        return JSONResponse({"error": "Query is required"}, status_code=400)

    user = request.state.user
    return JSONResponse(await search_tool(query, user_id=user.user_id))

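# Illustrative search request (any query string works; the response is whatever
# search_tool() below returns):
#
#     POST /search            {"query": "quarterly revenue by region"}
#     -> {"results": [{"filename": ..., "page": ..., "text": ..., "score": ...}, ...]}
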
@tool
async def search_tool(query: str, user_id: str = None) -> dict[str, Any]:
    """
    Use this tool to search for documents relevant to the query.

    Args:
        query (str): query string to search the corpus
        user_id (str): user ID for access control (optional)

    Returns:
        dict (str, Any): {"results": [chunks]} on success
    """
    # Require authentication - no anonymous access to search
    # (checked before embedding the query so unauthenticated calls cost nothing)
    if not user_id:
        return {"results": [], "error": "Authentication required"}

    # Embed the query
    resp = await patched_async_client.embeddings.create(model=EMBED_MODEL, input=[query])
    query_embedding = resp.data[0].embedding

    # Base query structure
    search_body = {
        "query": {
            "bool": {
                "must": [
                    {
                        "knn": {
                            "chunk_embedding": {
                                "vector": query_embedding,
                                "k": 10
                            }
                        }
                    }
                ]
            }
        },
        "_source": ["filename", "mimetype", "page", "text", "source_url", "owner", "allowed_users", "allowed_groups"],
        "size": 10
    }

    # Authenticated user access control
    # User can access documents if:
    # 1. They own the document (owner field matches user_id)
    # 2. They're in allowed_users list
    # 3. Document has no ACL (public documents)
    # TODO: Add group access control later
    should_clauses = [
        {"term": {"owner": user_id}},
        {"term": {"allowed_users": user_id}},
        {"bool": {"must_not": {"exists": {"field": "owner"}}}}  # Public docs
    ]

    search_body["query"]["bool"]["should"] = should_clauses
    search_body["query"]["bool"]["minimum_should_match"] = 1

    results = await opensearch.search(index=INDEX_NAME, body=search_body)

    # Transform results
    chunks = []
    for hit in results["hits"]["hits"]:
        chunks.append({
            "filename": hit["_source"]["filename"],
            "mimetype": hit["_source"]["mimetype"],
            "page": hit["_source"]["page"],
            "text": hit["_source"]["text"],
            "score": hit["_score"],
            "source_url": hit["_source"].get("source_url"),
            "owner": hit["_source"].get("owner")
        })
    return {"results": chunks}

@require_auth(session_manager)
async def chat_endpoint(request: Request):
    data = await request.json()
    prompt = data.get("prompt", "")
    previous_response_id = data.get("previous_response_id")
    stream = data.get("stream", False)

    # Get authenticated user
    user = request.state.user
    user_id = user.user_id

    if not prompt:
        return JSONResponse({"error": "Prompt is required"}, status_code=400)

    if stream:
        from agent import async_chat_stream
        return StreamingResponse(
            async_chat_stream(patched_async_client, prompt, user_id, previous_response_id=previous_response_id),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "Cache-Control"
            }
        )
    else:
        response_text, response_id = await async_chat(patched_async_client, prompt, user_id, previous_response_id=previous_response_id)
        response_data = {"response": response_text}
        if response_id:
            response_data["response_id"] = response_id
        return JSONResponse(response_data)

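# Illustrative chat request (set "stream": true to receive server-sent events
# instead of a single JSON body):
#
#     POST /chat              {"prompt": "Summarize the uploaded report",
#                              "previous_response_id": "<id from an earlier reply>"}
#     -> {"response": "...", "response_id": "..."}
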
@require_auth(session_manager)
async def langflow_endpoint(request: Request):
    data = await request.json()
    prompt = data.get("prompt", "")
    previous_response_id = data.get("previous_response_id")
    stream = data.get("stream", False)

    if not prompt:
        return JSONResponse({"error": "Prompt is required"}, status_code=400)

    if not langflow_url or not flow_id or not langflow_key:
        return JSONResponse({"error": "LANGFLOW_URL, FLOW_ID, and LANGFLOW_SECRET_KEY environment variables are required"}, status_code=500)

    try:
        if stream:
            from agent import async_langflow_stream
            return StreamingResponse(
                async_langflow_stream(langflow_client, flow_id, prompt, previous_response_id=previous_response_id),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Headers": "Cache-Control"
                }
            )
        else:
            response_text, response_id = await async_langflow(langflow_client, flow_id, prompt, previous_response_id=previous_response_id)
            response_data = {"response": response_text}
            if response_id:
                response_data["response_id"] = response_id
            return JSONResponse(response_data)

    except Exception as e:
        return JSONResponse({"error": f"Langflow request failed: {str(e)}"}, status_code=500)

# Authentication endpoints
@optional_auth(session_manager)  # Allow both authenticated and non-authenticated users
async def auth_init(request: Request):
    """Initialize OAuth flow for authentication or data source connection"""
    try:
        data = await request.json()
        provider = data.get("provider")  # "google", "microsoft", etc.
        purpose = data.get("purpose", "data_source")  # "app_auth" or "data_source"
        connection_name = data.get("name", f"{provider}_{purpose}")
        redirect_uri = data.get("redirect_uri")  # Frontend provides this

        # Get user from authentication if available
        user = getattr(request.state, 'user', None)
        user_id = user.user_id if user else None

        if provider != "google":
            return JSONResponse({"error": "Unsupported provider"}, status_code=400)

        if not redirect_uri:
            return JSONResponse({"error": "redirect_uri is required"}, status_code=400)

        # Get OAuth client configuration from environment
        google_client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
        if not google_client_id:
            return JSONResponse({"error": "Google OAuth client ID not configured"}, status_code=500)

        # Create connection configuration
        token_file = f"{provider}_{purpose}_{uuid.uuid4().hex[:8]}.json"
        config = {
            "client_id": google_client_id,
            "token_file": token_file,
            "provider": provider,
            "purpose": purpose,
            "redirect_uri": redirect_uri  # Store redirect_uri for use in callback
        }

        # Create connection in manager
        # For data sources, use provider name (e.g. "google_drive")
        # For app auth, connector_type doesn't matter since it gets deleted
        connector_type = f"{provider}_drive" if purpose == "data_source" else f"{provider}_auth"
        connection_id = await connector_service.connection_manager.create_connection(
            connector_type=connector_type,
            name=connection_name,
            config=config,
            user_id=user_id
        )

        # Return OAuth configuration for client-side flow
        # Include both identity and data access scopes
        scopes = [
            # Identity scopes (for app auth)
            'openid',
            'email',
            'profile',
            # Data access scopes (for connectors)
            'https://www.googleapis.com/auth/drive.readonly',
            'https://www.googleapis.com/auth/drive.metadata.readonly'
        ]

        oauth_config = {
            "client_id": google_client_id,
            "scopes": scopes,
            "redirect_uri": redirect_uri,  # Use the redirect_uri from frontend
            "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth",
            "token_endpoint": "https://oauth2.googleapis.com/token"
        }

        return JSONResponse({
            "connection_id": connection_id,
            "oauth_config": oauth_config
        })

    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse({"error": f"Failed to initialize OAuth: {str(e)}"}, status_code=500)

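# Illustrative /auth/init exchange (the redirect_uri is whatever the frontend
# registered with Google; values here are placeholders):
#
#     POST /auth/init         {"provider": "google", "purpose": "app_auth",
#                              "redirect_uri": "http://localhost:3000/oauth/callback"}
#     -> {"connection_id": "<id>", "oauth_config": {"client_id": ..., "scopes": [...], ...}}
#
# The frontend then sends the user through Google's consent screen and posts the
# resulting authorization code to /auth/callback below.
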
async def auth_callback(request: Request):
    """Handle OAuth callback - exchange authorization code for tokens"""
    try:
        data = await request.json()
        connection_id = data.get("connection_id")
        authorization_code = data.get("authorization_code")
        state = data.get("state")

        if not all([connection_id, authorization_code]):
            return JSONResponse({"error": "Missing required parameters (connection_id, authorization_code)"}, status_code=400)

        # Check if authorization code has already been used
        if authorization_code in used_auth_codes:
            return JSONResponse({"error": "Authorization code already used"}, status_code=400)

        # Mark code as used to prevent duplicate requests
        used_auth_codes.add(authorization_code)

        try:
            # Get connection config
            connection_config = await connector_service.connection_manager.get_connection(connection_id)
            if not connection_config:
                return JSONResponse({"error": "Connection not found"}, status_code=404)

            # Exchange authorization code for tokens
            import httpx

            # Use the redirect_uri that was stored during auth_init
            redirect_uri = connection_config.config.get("redirect_uri")
            if not redirect_uri:
                return JSONResponse({"error": "Redirect URI not found in connection config"}, status_code=400)

            token_url = "https://oauth2.googleapis.com/token"

            token_payload = {
                "code": authorization_code,
                "client_id": connection_config.config["client_id"],
                "client_secret": os.getenv("GOOGLE_OAUTH_CLIENT_SECRET"),  # Need this for server-side
                "redirect_uri": redirect_uri,
                "grant_type": "authorization_code"
            }

            async with httpx.AsyncClient() as client:
                token_response = await client.post(token_url, data=token_payload)

                if token_response.status_code != 200:
                    raise Exception(f"Token exchange failed: {token_response.text}")

                token_data = token_response.json()

            # Store tokens in the token file
            token_file_data = {
                "token": token_data["access_token"],
                "refresh_token": token_data.get("refresh_token"),
                "scopes": [
                    "openid",
                    "email",
                    "profile",
                    "https://www.googleapis.com/auth/drive.readonly",
                    "https://www.googleapis.com/auth/drive.metadata.readonly"
                ]
            }

            # Add expiry if provided
            if token_data.get("expires_in"):
                from datetime import datetime, timedelta
                expiry = datetime.now() + timedelta(seconds=int(token_data["expires_in"]))
                token_file_data["expiry"] = expiry.isoformat()

            # Save tokens to file
            import json
            token_file_path = connection_config.config["token_file"]
            async with aiofiles.open(token_file_path, 'w') as f:
                await f.write(json.dumps(token_file_data, indent=2))

            # Route based on purpose
            purpose = connection_config.config.get("purpose", "data_source")

            if purpose == "app_auth":
                # Handle app authentication - create user session
                jwt_token = await session_manager.create_user_session(token_data["access_token"])

                if jwt_token:
                    # Get the user info to create a persistent Google Drive connection
                    user_info = await session_manager.get_user_info_from_token(token_data["access_token"])
                    user_id = user_info["id"] if user_info else None

                    if user_id:
                        # Convert the temporary auth connection to a persistent Google Drive connection
                        # Update the connection to be a data source connection with the user_id
                        await connector_service.connection_manager.update_connection(
                            connection_id=connection_id,
                            connector_type="google_drive",
                            name=f"Google Drive ({user_info.get('email', 'Unknown')})",
                            user_id=user_id,
                            config={
                                **connection_config.config,
                                "purpose": "data_source",  # Convert to data source
                                "user_email": user_info.get("email")
                            }
                        )

                        response = JSONResponse({
                            "status": "authenticated",
                            "purpose": "app_auth",
                            "redirect": "/",  # Redirect to home page instead of dashboard
                            "google_drive_connection_id": connection_id  # Return connection ID for frontend
                        })
                    else:
                        # Fallback: delete connection if we can't get user info
                        await connector_service.connection_manager.delete_connection(connection_id)
                        response = JSONResponse({
                            "status": "authenticated",
                            "purpose": "app_auth",
                            "redirect": "/"
                        })

                    # Set JWT as HTTP-only cookie for security
                    response.set_cookie(
                        key="auth_token",
                        value=jwt_token,
                        httponly=True,
                        secure=False,  # False for development/testing
                        samesite="lax",
                        max_age=7 * 24 * 60 * 60  # 7 days
                    )
                    return response
                else:
                    # Clean up connection if session creation failed
                    await connector_service.connection_manager.delete_connection(connection_id)
                    return JSONResponse({"error": "Failed to create user session"}, status_code=500)
            else:
                # Handle data source connection - keep the connection for syncing
                return JSONResponse({
                    "status": "authenticated",
                    "connection_id": connection_id,
                    "purpose": "data_source",
                    "connector_type": connection_config.connector_type
                })

        except Exception as e:
            import traceback
            traceback.print_exc()
            return JSONResponse({"error": f"OAuth callback failed: {str(e)}"}, status_code=500)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse({"error": f"Callback failed: {str(e)}"}, status_code=500)

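# Illustrative /auth/callback request (the authorization code comes from Google's
# redirect back to the frontend; connection_id is the one returned by /auth/init):
#
#     POST /auth/callback     {"connection_id": "<id>", "authorization_code": "<code from Google>"}
#     -> app_auth:    {"status": "authenticated", "purpose": "app_auth", "redirect": "/", ...}
#        data_source: {"status": "authenticated", "purpose": "data_source", "connection_id": ...}
#
# For app_auth the response also sets the HTTP-only "auth_token" cookie
# (see response.set_cookie above).
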
@optional_auth(session_manager)
async def auth_me(request: Request):
    """Get current user information"""
    user = getattr(request.state, 'user', None)

    if user:
        return JSONResponse({
            "authenticated": True,
            "user": {
                "user_id": user.user_id,
                "email": user.email,
                "name": user.name,
                "picture": user.picture,
                "provider": user.provider,
                "last_login": user.last_login.isoformat() if user.last_login else None
            }
        })
    else:
        return JSONResponse({
            "authenticated": False,
            "user": None
        })

@require_auth(session_manager)
async def auth_logout(request: Request):
    """Logout user by clearing auth cookie"""
    response = JSONResponse({
        "status": "logged_out",
        "message": "Successfully logged out"
    })

    # Clear the auth cookie
    response.delete_cookie(
        key="auth_token",
        httponly=True,
        secure=False,  # False for development/testing
        samesite="lax"
    )

    return response

@require_auth(session_manager)
async def connector_sync(request: Request):
    """Sync files from a connector connection"""
    data = await request.json()
    connection_id = data.get("connection_id")
    max_files = data.get("max_files")

    if not connection_id:
        return JSONResponse({"error": "connection_id is required"}, status_code=400)

    try:
        print(f"[DEBUG] Starting connector sync for connection_id={connection_id}, max_files={max_files}")

        # Verify user owns this connection
        user = request.state.user
        print(f"[DEBUG] User: {user.user_id}")

        connection_config = await connector_service.connection_manager.get_connection(connection_id)
        print(f"[DEBUG] Got connection config: {connection_config is not None}")

        if not connection_config:
            return JSONResponse({"error": "Connection not found"}, status_code=404)

        if connection_config.user_id != user.user_id:
            return JSONResponse({"error": "Access denied"}, status_code=403)

        print("[DEBUG] About to call sync_connector_files")
        task_id = await connector_service.sync_connector_files(connection_id, user.user_id, max_files)
        print(f"[DEBUG] Got task_id: {task_id}")

        return JSONResponse({
            "task_id": task_id,
            "status": "sync_started",
            "message": f"Started syncing files from connection {connection_id}"
        }, status_code=201)

    except Exception as e:
        import sys
        import traceback

        error_msg = f"[ERROR] Connector sync failed: {str(e)}"
        print(error_msg, file=sys.stderr, flush=True)
        traceback.print_exc(file=sys.stderr)
        sys.stderr.flush()

        return JSONResponse({"error": f"Sync failed: {str(e)}"}, status_code=500)

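# Illustrative sync request (connection_id comes from /auth/init or /auth/callback;
# max_files is optional and only caps the batch size):
#
#     POST /connectors/sync   {"connection_id": "<id>", "max_files": 100}
#     -> 201 {"task_id": "<id>", "status": "sync_started", "message": "..."}
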
@require_auth(session_manager)
async def connector_status(request: Request):
    """Get connector status for authenticated user"""
    connector_type = request.path_params.get("connector_type", "google_drive")
    user = request.state.user

    # Get connections for this connector type and user
    connections = await connector_service.connection_manager.list_connections(
        user_id=user.user_id,
        connector_type=connector_type
    )

    # Check if there are any active connections
    active_connections = [conn for conn in connections if conn.is_active]
    has_authenticated_connection = len(active_connections) > 0

    return JSONResponse({
        "connector_type": connector_type,
        "authenticated": has_authenticated_connection,  # For frontend compatibility
        "status": "connected" if has_authenticated_connection else "not_connected",
        "connections": [
            {
                "connection_id": conn.connection_id,
                "name": conn.name,
                "is_active": conn.is_active,
                "created_at": conn.created_at.isoformat(),
                "last_sync": conn.last_sync.isoformat() if conn.last_sync else None
            }
            for conn in connections
        ]
    })

app = Starlette(debug=True, routes=[
    Route("/upload", upload, methods=["POST"]),
    Route("/upload_context", upload_context, methods=["POST"]),
    Route("/upload_path", upload_path, methods=["POST"]),
    Route("/tasks/{task_id}", task_status, methods=["GET"]),
    Route("/search", search, methods=["POST"]),
    Route("/chat", chat_endpoint, methods=["POST"]),
    Route("/langflow", langflow_endpoint, methods=["POST"]),
    # Authentication endpoints
    Route("/auth/init", auth_init, methods=["POST"]),
    Route("/auth/callback", auth_callback, methods=["POST"]),
    Route("/auth/me", auth_me, methods=["GET"]),
    Route("/auth/logout", auth_logout, methods=["POST"]),
    Route("/connectors/sync", connector_sync, methods=["POST"]),
    Route("/connectors/status/{connector_type}", connector_status, methods=["GET"]),
])

if __name__ == "__main__":
    import uvicorn
    import atexit

    async def main():
        await init_index()
        await connector_service.initialize()

    # Cleanup process pool on exit
    def cleanup():
        process_pool.shutdown(wait=True)

    atexit.register(cleanup)

    asyncio.run(main())
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )