openrag/src/main.py
2025-08-11 16:45:54 -04:00

327 lines
No EOL
12 KiB
Python

import asyncio
import atexit
import multiprocessing
from functools import partial
from starlette.applications import Starlette
from starlette.routing import Route
# Set multiprocessing start method to 'spawn' for CUDA compatibility
multiprocessing.set_start_method('spawn', force=True)
# Create process pool FIRST, before any torch/CUDA imports
from utils.process_pool import process_pool
import torch
# Configuration and setup
from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET
from utils.gpu_detection import detect_gpu_devices
# Services
from services.document_service import DocumentService
from services.search_service import SearchService
from services.task_service import TaskService
from services.auth_service import AuthService
from services.chat_service import ChatService
# Existing services
from connectors.service import ConnectorService
from session_manager import SessionManager
from auth_middleware import require_auth, optional_auth
# API endpoints
from api import upload, search, chat, auth, connectors, tasks, oidc
# Log CUDA status at import time so startup logs show whether GPU
# acceleration is available before any indexing/model work begins.
print("CUDA available:", torch.cuda.is_available())
print("CUDA version PyTorch was built with:", torch.version.cuda)
async def wait_for_opensearch():
    """Block until the OpenSearch cluster answers an ``info()`` call.

    Polls ``clients.opensearch.info()`` up to ``max_retries`` times,
    sleeping ``retry_delay`` seconds between attempts and logging each
    failed probe.

    Raises:
        RuntimeError: if OpenSearch never becomes ready within the retry
            budget; chained to the last connection error so the root
            cause appears in the traceback. (RuntimeError subclasses
            Exception, so existing ``except Exception`` callers still work.)
    """
    max_retries = 30
    retry_delay = 2
    last_error = None
    for attempt in range(max_retries):
        try:
            await clients.opensearch.info()
            print("OpenSearch is ready!")
            return
        except Exception as e:  # best-effort probe: any failure means "not ready yet"
            last_error = e
            print(f"Attempt {attempt + 1}/{max_retries}: OpenSearch not ready yet ({e})")
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
    # Retry budget exhausted — fail loudly, preserving the underlying cause.
    raise RuntimeError("OpenSearch failed to become ready") from last_error
async def init_index():
    """Wait for OpenSearch, then create the application index if absent."""
    await wait_for_opensearch()
    already_there = await clients.opensearch.indices.exists(index=INDEX_NAME)
    if already_there:
        print(f"Index '{INDEX_NAME}' already exists, skipping creation.")
    else:
        await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
        print(f"Created index '{INDEX_NAME}'")
async def init_index_when_ready():
    """Run init_index(), logging any failure instead of propagating it.

    Intended to run as a background task so a slow/unavailable OpenSearch
    does not block the rest of the application.
    """
    try:
        await init_index()
    except Exception as e:
        print(f"OpenSearch index initialization failed: {e}")
        print("OIDC endpoints will still work, but document operations may fail until OpenSearch is ready")
    else:
        print("OpenSearch index initialization completed successfully")
def initialize_services():
    """Wire up shared clients, the session manager, and every service.

    Returns:
        dict mapping service name -> service instance, consumed by
        create_app() when injecting dependencies into route handlers.
    """
    # Shared API clients must exist before any service is constructed.
    clients.initialize()

    sessions = SessionManager(SESSION_SECRET)

    documents = DocumentService(session_manager=sessions)
    searcher = SearchService(sessions)
    tasks_svc = TaskService(documents, process_pool)
    chatter = ChatService()

    # Document ingestion is offloaded to the shared process pool.
    documents.process_pool = process_pool

    connector_svc = ConnectorService(
        patched_async_client=clients.patched_async_client,
        process_pool=process_pool,
        embed_model="text-embedding-3-small",
        index_name=INDEX_NAME,
        task_service=tasks_svc,
        session_manager=sessions,
    )

    authenticator = AuthService(sessions, connector_svc)

    return {
        'document_service': documents,
        'search_service': searcher,
        'task_service': tasks_svc,
        'chat_service': chatter,
        'auth_service': authenticator,
        'connector_service': connector_svc,
        'session_manager': sessions,
    }
def create_app():
    """Create and configure the Starlette application.

    Builds all services, injects them into the route handlers via
    functools.partial, and installs a lifespan that kicks off OpenSearch
    index initialization in the background on startup and cancels webhook
    subscriptions on shutdown.

    Returns:
        The configured Starlette app; the service dict is kept on
        ``app.state.services`` for later access/cleanup.
    """
    from contextlib import asynccontextmanager

    services = initialize_services()
    sm = services['session_manager']

    def bind(handler, **deps):
        # Inject the session manager plus any handler-specific services.
        return partial(handler, session_manager=sm, **deps)

    def protected(handler, **deps):
        # Route requires an authenticated session.
        return require_auth(sm)(bind(handler, **deps))

    def semi_public(handler, **deps):
        # Route works with or without an authenticated session.
        return optional_auth(sm)(bind(handler, **deps))

    routes = [
        # Upload endpoints
        Route("/upload",
              protected(upload.upload,
                        document_service=services['document_service']),
              methods=["POST"]),
        Route("/upload_context",
              protected(upload.upload_context,
                        document_service=services['document_service'],
                        chat_service=services['chat_service']),
              methods=["POST"]),
        Route("/upload_path",
              protected(upload.upload_path,
                        task_service=services['task_service']),
              methods=["POST"]),
        # Task endpoints
        Route("/tasks/{task_id}",
              protected(tasks.task_status,
                        task_service=services['task_service']),
              methods=["GET"]),
        Route("/tasks",
              protected(tasks.all_tasks,
                        task_service=services['task_service']),
              methods=["GET"]),
        Route("/tasks/{task_id}/cancel",
              protected(tasks.cancel_task,
                        task_service=services['task_service']),
              methods=["POST"]),
        # Search endpoint
        Route("/search",
              protected(search.search,
                        search_service=services['search_service']),
              methods=["POST"]),
        # Chat endpoints
        Route("/chat",
              protected(chat.chat_endpoint,
                        chat_service=services['chat_service']),
              methods=["POST"]),
        Route("/langflow",
              protected(chat.langflow_endpoint,
                        chat_service=services['chat_service']),
              methods=["POST"]),
        # Authentication endpoints
        Route("/auth/init",
              semi_public(auth.auth_init,
                          auth_service=services['auth_service']),
              methods=["POST"]),
        Route("/auth/callback",
              bind(auth.auth_callback,
                   auth_service=services['auth_service']),
              methods=["POST"]),
        Route("/auth/me",
              semi_public(auth.auth_me,
                          auth_service=services['auth_service']),
              methods=["GET"]),
        Route("/auth/logout",
              protected(auth.auth_logout,
                        auth_service=services['auth_service']),
              methods=["POST"]),
        # Connector endpoints
        Route("/connectors/{connector_type}/sync",
              protected(connectors.connector_sync,
                        connector_service=services['connector_service']),
              methods=["POST"]),
        Route("/connectors/{connector_type}/status",
              protected(connectors.connector_status,
                        connector_service=services['connector_service']),
              methods=["GET"]),
        # Webhook is unauthenticated: called by external providers.
        Route("/connectors/{connector_type}/webhook",
              bind(connectors.connector_webhook,
                   connector_service=services['connector_service']),
              methods=["POST", "GET"]),
        # OIDC endpoints (public by design)
        Route("/.well-known/openid-configuration",
              bind(oidc.oidc_discovery),
              methods=["GET"]),
        Route("/auth/jwks",
              bind(oidc.jwks_endpoint),
              methods=["GET"]),
        Route("/auth/introspect",
              bind(oidc.token_introspection),
              methods=["POST"]),
    ]

    @asynccontextmanager
    async def lifespan(app):
        # Start index initialization in background to avoid blocking OIDC
        # endpoints. Keep a reference on app.state so the task cannot be
        # garbage-collected (and silently cancelled) mid-flight.
        app.state.index_init_task = asyncio.create_task(init_index_when_ready())
        yield
        await cleanup_subscriptions_proper(services)

    # lifespan replaces the deprecated @app.on_event("startup"/"shutdown") API.
    app = Starlette(debug=True, routes=routes, lifespan=lifespan)
    app.state.services = services  # Store services for cleanup
    return app
async def startup():
    """Application startup tasks.

    NOTE(review): not referenced anywhere in this module — startup is
    driven by the app's startup handling in create_app(); confirm whether
    this helper is still needed before removing.
    """
    await init_index()
def cleanup():
    """Cleanup on application shutdown (registered via atexit).

    Only announces shutdown: webhook subscription cancellation is handled
    by the Starlette shutdown path, not here. (Removed a redundant
    trailing ``pass`` — dead statement after the print.)
    """
    print("[CLEANUP] Shutting down...")
async def cleanup_subscriptions_proper(services):
    """Cancel every active webhook subscription, best-effort.

    Failures on individual connections are logged and skipped; a failure
    to even list connections is logged and swallowed so shutdown proceeds.
    """
    print("[CLEANUP] Cancelling active webhook subscriptions...")
    try:
        svc = services['connector_service']
        manager = svc.connection_manager
        await manager.load_connections()
        # Only connections that are active AND have a registered webhook
        # channel need cancelling.
        everything = await manager.list_connections()
        targets = [
            conn for conn in everything
            if conn.is_active and conn.config.get('webhook_channel_id')
        ]
        for conn in targets:
            try:
                print(f"[CLEANUP] Cancelling subscription for connection {conn.connection_id}")
                connector = await svc.get_connector(conn.connection_id)
                if connector:
                    channel_id = conn.config.get('webhook_channel_id')
                    resource_id = conn.config.get('resource_id')  # If stored
                    await connector.cleanup_subscription(channel_id, resource_id)
                    print(f"[CLEANUP] Cancelled subscription {channel_id}")
            except Exception as e:
                print(f"[ERROR] Failed to cancel subscription for {conn.connection_id}: {e}")
        print(f"[CLEANUP] Finished cancelling {len(targets)} subscriptions")
    except Exception as e:
        print(f"[ERROR] Failed to cleanup subscriptions: {e}")
if __name__ == "__main__":
    import uvicorn

    # Make sure the shutdown message fires even on abnormal exits.
    atexit.register(cleanup)

    application = create_app()

    # Startup tasks are driven by the Starlette startup handling inside
    # create_app(); reload stays off because we run directly from main.
    uvicorn.run(
        application,
        host="0.0.0.0",
        port=8000,
        reload=False,
    )