327 lines
No EOL
12 KiB
Python
327 lines
No EOL
12 KiB
Python
import asyncio
|
|
import atexit
|
|
import multiprocessing
|
|
from functools import partial
|
|
from starlette.applications import Starlette
|
|
from starlette.routing import Route
|
|
|
|
# Set multiprocessing start method to 'spawn' for CUDA compatibility
|
|
multiprocessing.set_start_method('spawn', force=True)
|
|
|
|
# Create process pool FIRST, before any torch/CUDA imports
|
|
from utils.process_pool import process_pool
|
|
|
|
import torch
|
|
|
|
# Configuration and setup
|
|
from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET
|
|
from utils.gpu_detection import detect_gpu_devices
|
|
|
|
# Services
|
|
from services.document_service import DocumentService
|
|
from services.search_service import SearchService
|
|
from services.task_service import TaskService
|
|
from services.auth_service import AuthService
|
|
from services.chat_service import ChatService
|
|
|
|
# Existing services
|
|
from connectors.service import ConnectorService
|
|
from session_manager import SessionManager
|
|
from auth_middleware import require_auth, optional_auth
|
|
|
|
# API endpoints
|
|
from api import upload, search, chat, auth, connectors, tasks, oidc
|
|
|
|
print("CUDA available:", torch.cuda.is_available())
|
|
print("CUDA version PyTorch was built with:", torch.version.cuda)
|
|
|
|
async def wait_for_opensearch():
|
|
"""Wait for OpenSearch to be ready with retries"""
|
|
max_retries = 30
|
|
retry_delay = 2
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
await clients.opensearch.info()
|
|
print("OpenSearch is ready!")
|
|
return
|
|
except Exception as e:
|
|
print(f"Attempt {attempt + 1}/{max_retries}: OpenSearch not ready yet ({e})")
|
|
if attempt < max_retries - 1:
|
|
await asyncio.sleep(retry_delay)
|
|
else:
|
|
raise Exception("OpenSearch failed to become ready")
|
|
|
|
async def init_index():
|
|
"""Initialize OpenSearch index and security roles"""
|
|
await wait_for_opensearch()
|
|
|
|
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
|
|
await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
|
|
print(f"Created index '{INDEX_NAME}'")
|
|
else:
|
|
print(f"Index '{INDEX_NAME}' already exists, skipping creation.")
|
|
|
|
async def init_index_when_ready():
|
|
"""Initialize OpenSearch index when it becomes available"""
|
|
try:
|
|
await init_index()
|
|
print("OpenSearch index initialization completed successfully")
|
|
except Exception as e:
|
|
print(f"OpenSearch index initialization failed: {e}")
|
|
print("OIDC endpoints will still work, but document operations may fail until OpenSearch is ready")
|
|
|
|
|
|
def initialize_services():
|
|
"""Initialize all services and their dependencies"""
|
|
# Initialize clients
|
|
clients.initialize()
|
|
|
|
# Initialize session manager
|
|
session_manager = SessionManager(SESSION_SECRET)
|
|
|
|
# Initialize services
|
|
document_service = DocumentService(session_manager=session_manager)
|
|
search_service = SearchService(session_manager)
|
|
task_service = TaskService(document_service, process_pool)
|
|
chat_service = ChatService()
|
|
|
|
# Set process pool for document service
|
|
document_service.process_pool = process_pool
|
|
|
|
# Initialize connector service
|
|
connector_service = ConnectorService(
|
|
patched_async_client=clients.patched_async_client,
|
|
process_pool=process_pool,
|
|
embed_model="text-embedding-3-small",
|
|
index_name=INDEX_NAME,
|
|
task_service=task_service,
|
|
session_manager=session_manager
|
|
)
|
|
|
|
# Initialize auth service
|
|
auth_service = AuthService(session_manager, connector_service)
|
|
|
|
return {
|
|
'document_service': document_service,
|
|
'search_service': search_service,
|
|
'task_service': task_service,
|
|
'chat_service': chat_service,
|
|
'auth_service': auth_service,
|
|
'connector_service': connector_service,
|
|
'session_manager': session_manager
|
|
}
|
|
|
|
def create_app():
|
|
"""Create and configure the Starlette application"""
|
|
services = initialize_services()
|
|
|
|
# Create route handlers with service dependencies injected
|
|
routes = [
|
|
# Upload endpoints
|
|
Route("/upload",
|
|
require_auth(services['session_manager'])(
|
|
partial(upload.upload,
|
|
document_service=services['document_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["POST"]),
|
|
|
|
Route("/upload_context",
|
|
require_auth(services['session_manager'])(
|
|
partial(upload.upload_context,
|
|
document_service=services['document_service'],
|
|
chat_service=services['chat_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["POST"]),
|
|
|
|
Route("/upload_path",
|
|
require_auth(services['session_manager'])(
|
|
partial(upload.upload_path,
|
|
task_service=services['task_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["POST"]),
|
|
|
|
Route("/tasks/{task_id}",
|
|
require_auth(services['session_manager'])(
|
|
partial(tasks.task_status,
|
|
task_service=services['task_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["GET"]),
|
|
|
|
Route("/tasks",
|
|
require_auth(services['session_manager'])(
|
|
partial(tasks.all_tasks,
|
|
task_service=services['task_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["GET"]),
|
|
|
|
Route("/tasks/{task_id}/cancel",
|
|
require_auth(services['session_manager'])(
|
|
partial(tasks.cancel_task,
|
|
task_service=services['task_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["POST"]),
|
|
|
|
# Search endpoint
|
|
Route("/search",
|
|
require_auth(services['session_manager'])(
|
|
partial(search.search,
|
|
search_service=services['search_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["POST"]),
|
|
|
|
# Chat endpoints
|
|
Route("/chat",
|
|
require_auth(services['session_manager'])(
|
|
partial(chat.chat_endpoint,
|
|
chat_service=services['chat_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["POST"]),
|
|
|
|
Route("/langflow",
|
|
require_auth(services['session_manager'])(
|
|
partial(chat.langflow_endpoint,
|
|
chat_service=services['chat_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["POST"]),
|
|
|
|
# Authentication endpoints
|
|
Route("/auth/init",
|
|
optional_auth(services['session_manager'])(
|
|
partial(auth.auth_init,
|
|
auth_service=services['auth_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["POST"]),
|
|
|
|
Route("/auth/callback",
|
|
partial(auth.auth_callback,
|
|
auth_service=services['auth_service'],
|
|
session_manager=services['session_manager']),
|
|
methods=["POST"]),
|
|
|
|
Route("/auth/me",
|
|
optional_auth(services['session_manager'])(
|
|
partial(auth.auth_me,
|
|
auth_service=services['auth_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["GET"]),
|
|
|
|
Route("/auth/logout",
|
|
require_auth(services['session_manager'])(
|
|
partial(auth.auth_logout,
|
|
auth_service=services['auth_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["POST"]),
|
|
|
|
# Connector endpoints
|
|
Route("/connectors/{connector_type}/sync",
|
|
require_auth(services['session_manager'])(
|
|
partial(connectors.connector_sync,
|
|
connector_service=services['connector_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["POST"]),
|
|
|
|
Route("/connectors/{connector_type}/status",
|
|
require_auth(services['session_manager'])(
|
|
partial(connectors.connector_status,
|
|
connector_service=services['connector_service'],
|
|
session_manager=services['session_manager'])
|
|
), methods=["GET"]),
|
|
|
|
Route("/connectors/{connector_type}/webhook",
|
|
partial(connectors.connector_webhook,
|
|
connector_service=services['connector_service'],
|
|
session_manager=services['session_manager']),
|
|
methods=["POST", "GET"]),
|
|
|
|
# OIDC endpoints
|
|
Route("/.well-known/openid-configuration",
|
|
partial(oidc.oidc_discovery,
|
|
session_manager=services['session_manager']),
|
|
methods=["GET"]),
|
|
|
|
Route("/auth/jwks",
|
|
partial(oidc.jwks_endpoint,
|
|
session_manager=services['session_manager']),
|
|
methods=["GET"]),
|
|
|
|
Route("/auth/introspect",
|
|
partial(oidc.token_introspection,
|
|
session_manager=services['session_manager']),
|
|
methods=["POST"]),
|
|
]
|
|
|
|
app = Starlette(debug=True, routes=routes)
|
|
app.state.services = services # Store services for cleanup
|
|
|
|
# Add startup event handler
|
|
@app.on_event("startup")
|
|
async def startup_event():
|
|
# Start index initialization in background to avoid blocking OIDC endpoints
|
|
asyncio.create_task(init_index_when_ready())
|
|
|
|
# Add shutdown event handler
|
|
@app.on_event("shutdown")
|
|
async def shutdown_event():
|
|
await cleanup_subscriptions_proper(services)
|
|
|
|
return app
|
|
|
|
async def startup():
|
|
"""Application startup tasks"""
|
|
await init_index()
|
|
# Get services from app state if needed for initialization
|
|
# services = app.state.services
|
|
# await services['connector_service'].initialize()
|
|
|
|
def cleanup():
|
|
"""Cleanup on application shutdown"""
|
|
# Cleanup process pools only (webhooks handled by Starlette shutdown)
|
|
print("[CLEANUP] Shutting down...")
|
|
pass
|
|
|
|
async def cleanup_subscriptions_proper(services):
|
|
"""Cancel all active webhook subscriptions"""
|
|
print("[CLEANUP] Cancelling active webhook subscriptions...")
|
|
|
|
try:
|
|
connector_service = services['connector_service']
|
|
await connector_service.connection_manager.load_connections()
|
|
|
|
# Get all active connections with webhook subscriptions
|
|
all_connections = await connector_service.connection_manager.list_connections()
|
|
active_connections = [c for c in all_connections if c.is_active and c.config.get('webhook_channel_id')]
|
|
|
|
for connection in active_connections:
|
|
try:
|
|
print(f"[CLEANUP] Cancelling subscription for connection {connection.connection_id}")
|
|
connector = await connector_service.get_connector(connection.connection_id)
|
|
if connector:
|
|
subscription_id = connection.config.get('webhook_channel_id')
|
|
resource_id = connection.config.get('resource_id') # If stored
|
|
await connector.cleanup_subscription(subscription_id, resource_id)
|
|
print(f"[CLEANUP] Cancelled subscription {subscription_id}")
|
|
except Exception as e:
|
|
print(f"[ERROR] Failed to cancel subscription for {connection.connection_id}: {e}")
|
|
|
|
print(f"[CLEANUP] Finished cancelling {len(active_connections)} subscriptions")
|
|
|
|
except Exception as e:
|
|
print(f"[ERROR] Failed to cleanup subscriptions: {e}")
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
|
|
# Register cleanup function
|
|
atexit.register(cleanup)
|
|
|
|
# Create app
|
|
app = create_app()
|
|
|
|
# Run the server (startup tasks now handled by Starlette startup event)
|
|
uvicorn.run(
|
|
app,
|
|
host="0.0.0.0",
|
|
port=8000,
|
|
reload=False, # Disable reload since we're running from main
|
|
) |