diff --git a/src/main.py b/src/main.py index d19ff68d..f59136aa 100644 --- a/src/main.py +++ b/src/main.py @@ -528,6 +528,357 @@ async def create_app(): ), methods=["GET"], ), + Route( + "/upload_bucket", + require_auth(services["session_manager"])( + partial( + upload.upload_bucket, + task_service=services["task_service"], + session_manager=services["session_manager"], + ) + ), + methods=["POST"], + ), + Route( + "/tasks/{task_id}", + require_auth(services["session_manager"])( + partial( + tasks.task_status, + task_service=services["task_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), + Route( + "/tasks", + require_auth(services["session_manager"])( + partial( + tasks.all_tasks, + task_service=services["task_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), + Route( + "/tasks/{task_id}/cancel", + require_auth(services["session_manager"])( + partial( + tasks.cancel_task, + task_service=services["task_service"], + session_manager=services["session_manager"], + ) + ), + methods=["POST"], + ), + # Search endpoint + Route( + "/search", + require_auth(services["session_manager"])( + partial( + search.search, + search_service=services["search_service"], + session_manager=services["session_manager"], + ) + ), + methods=["POST"], + ), + # Knowledge Filter endpoints + Route( + "/knowledge-filter", + require_auth(services["session_manager"])( + partial( + knowledge_filter.create_knowledge_filter, + knowledge_filter_service=services["knowledge_filter_service"], + session_manager=services["session_manager"], + ) + ), + methods=["POST"], + ), + Route( + "/knowledge-filter/search", + require_auth(services["session_manager"])( + partial( + knowledge_filter.search_knowledge_filters, + knowledge_filter_service=services["knowledge_filter_service"], + session_manager=services["session_manager"], + ) + ), + methods=["POST"], + ), + Route( + "/knowledge-filter/{filter_id}", + require_auth(services["session_manager"])( + partial( + knowledge_filter.get_knowledge_filter, + knowledge_filter_service=services["knowledge_filter_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), + Route( + "/knowledge-filter/{filter_id}", + require_auth(services["session_manager"])( + partial( + knowledge_filter.update_knowledge_filter, + knowledge_filter_service=services["knowledge_filter_service"], + session_manager=services["session_manager"], + ) + ), + methods=["PUT"], + ), + Route( + "/knowledge-filter/{filter_id}", + require_auth(services["session_manager"])( + partial( + knowledge_filter.delete_knowledge_filter, + knowledge_filter_service=services["knowledge_filter_service"], + session_manager=services["session_manager"], + ) + ), + methods=["DELETE"], + ), + # Knowledge Filter Subscription endpoints + Route( + "/knowledge-filter/{filter_id}/subscribe", + require_auth(services["session_manager"])( + partial( + knowledge_filter.subscribe_to_knowledge_filter, + knowledge_filter_service=services["knowledge_filter_service"], + monitor_service=services["monitor_service"], + session_manager=services["session_manager"], + ) + ), + methods=["POST"], + ), + Route( + "/knowledge-filter/{filter_id}/subscriptions", + require_auth(services["session_manager"])( + partial( + knowledge_filter.list_knowledge_filter_subscriptions, + knowledge_filter_service=services["knowledge_filter_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), + Route( + "/knowledge-filter/{filter_id}/subscribe/{subscription_id}", + require_auth(services["session_manager"])( + partial( + knowledge_filter.cancel_knowledge_filter_subscription, + knowledge_filter_service=services["knowledge_filter_service"], + monitor_service=services["monitor_service"], + session_manager=services["session_manager"], + ) + ), + methods=["DELETE"], + ), + # Knowledge Filter Webhook endpoint (no auth required - called by OpenSearch) + Route( + "/knowledge-filter/{filter_id}/webhook/{subscription_id}", + partial( + knowledge_filter.knowledge_filter_webhook, + knowledge_filter_service=services["knowledge_filter_service"], + session_manager=services["session_manager"], + ), + methods=["POST"], + ), + # Chat endpoints + Route( + "/chat", + require_auth(services["session_manager"])( + partial( + chat.chat_endpoint, + chat_service=services["chat_service"], + session_manager=services["session_manager"], + ) + ), + methods=["POST"], + ), + Route( + "/langflow", + require_auth(services["session_manager"])( + partial( + chat.langflow_endpoint, + chat_service=services["chat_service"], + session_manager=services["session_manager"], + ) + ), + methods=["POST"], + ), + # Chat history endpoints + Route( + "/chat/history", + require_auth(services["session_manager"])( + partial( + chat.chat_history_endpoint, + chat_service=services["chat_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), + Route( + "/langflow/history", + require_auth(services["session_manager"])( + partial( + chat.langflow_history_endpoint, + chat_service=services["chat_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), + # Authentication endpoints + Route( + "/auth/init", + optional_auth(services["session_manager"])( + partial( + auth.auth_init, + auth_service=services["auth_service"], + session_manager=services["session_manager"], + ) + ), + methods=["POST"], + ), + Route( + "/auth/callback", + partial( + auth.auth_callback, + auth_service=services["auth_service"], + session_manager=services["session_manager"], + ), + methods=["POST"], + ), + Route( + "/auth/me", + optional_auth(services["session_manager"])( + partial( + auth.auth_me, + auth_service=services["auth_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), + Route( + "/auth/logout", + require_auth(services["session_manager"])( + partial( + auth.auth_logout, + auth_service=services["auth_service"], + session_manager=services["session_manager"], + ) + ), + methods=["POST"], + ), + # Connector endpoints + Route( + "/connectors", + require_auth(services["session_manager"])( + partial( + connectors.list_connectors, + connector_service=services["connector_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), + Route( + "/connectors/{connector_type}/sync", + require_auth(services["session_manager"])( + partial( + connectors.connector_sync, + connector_service=services["connector_service"], + session_manager=services["session_manager"], + ) + ), + methods=["POST"], + ), + Route( + "/connectors/{connector_type}/status", + require_auth(services["session_manager"])( + partial( + connectors.connector_status, + connector_service=services["connector_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), + Route( + "/connectors/{connector_type}/token", + require_auth(services["session_manager"])( + partial( + connectors.connector_token, + connector_service=services["connector_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), + Route( + "/connectors/{connector_type}/webhook", + partial( + connectors.connector_webhook, + connector_service=services["connector_service"], + session_manager=services["session_manager"], + ), + methods=["POST", "GET"], + ), + # OIDC endpoints + Route( + "/.well-known/openid-configuration", + partial(oidc.oidc_discovery, session_manager=services["session_manager"]), + methods=["GET"], + ), + Route( + "/auth/jwks", + partial(oidc.jwks_endpoint, session_manager=services["session_manager"]), + methods=["GET"], + ), + Route( + "/auth/introspect", + partial( + oidc.token_introspection, session_manager=services["session_manager"] + ), + methods=["POST"], + ), + # Settings endpoint + Route( + "/settings", + require_auth(services["session_manager"])( + partial( + settings.get_settings, session_manager=services["session_manager"] + ) + ), + methods=["GET"], + ), + Route( + "/nudges", + require_auth(services["session_manager"])( + partial( + nudges.nudges_from_kb_endpoint, + chat_service=services["chat_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), + Route( + "/nudges/{chat_id}", + require_auth(services["session_manager"])( + partial( + nudges.nudges_from_chat_id_endpoint, + chat_service=services["chat_service"], + session_manager=services["session_manager"], + ) + ), + methods=["GET"], + ), Route( "/router/upload_ingest", require_auth(services["session_manager"])( @@ -540,4 +891,105 @@ async def create_app(): ), methods=["POST"], ), - ] \ No newline at end of file + ] + + app = Starlette(debug=True, routes=routes) + app.state.services = services # Store services for cleanup + app.state.background_tasks = set() + + # Add startup event handler + @app.on_event("startup") + async def startup_event(): + # Start index initialization in background to avoid blocking OIDC endpoints + t1 = asyncio.create_task(startup_tasks(services)) + app.state.background_tasks.add(t1) + t1.add_done_callback(app.state.background_tasks.discard) + + # Add shutdown event handler + @app.on_event("shutdown") + async def shutdown_event(): + await cleanup_subscriptions_proper(services) + + return app + + +async def startup(): + """Application startup tasks""" + await init_index() + # Get services from app state if needed for initialization + # services = app.state.services + # await services['connector_service'].initialize() + + +def cleanup(): + """Cleanup on application shutdown""" + # Cleanup process pools only (webhooks handled by Starlette shutdown) + logger.info("Application shutting down") + pass + + +async def cleanup_subscriptions_proper(services): + """Cancel all active webhook subscriptions""" + logger.info("Cancelling active webhook subscriptions") + + try: + connector_service = services["connector_service"] + await connector_service.connection_manager.load_connections() + + # Get all active connections with webhook subscriptions + all_connections = await connector_service.connection_manager.list_connections() + active_connections = [ + c + for c in all_connections + if c.is_active and c.config.get("webhook_channel_id") + ] + + for connection in active_connections: + try: + logger.info( + "Cancelling subscription for connection", + connection_id=connection.connection_id, + ) + connector = await connector_service.get_connector( + connection.connection_id + ) + if connector: + subscription_id = connection.config.get("webhook_channel_id") + await connector.cleanup_subscription(subscription_id) + logger.info( + "Cancelled subscription", subscription_id=subscription_id + ) + except Exception as e: + logger.error( + "Failed to cancel subscription", + connection_id=connection.connection_id, + error=str(e), + ) + + logger.info( + "Finished cancelling subscriptions", + subscription_count=len(active_connections), + ) + + except Exception as e: + logger.error("Failed to cleanup subscriptions", error=str(e)) + + +if __name__ == "__main__": + import uvicorn + + # TUI check already handled at top of file + # Register cleanup function + atexit.register(cleanup) + + # Create app asynchronously + app = asyncio.run(create_app()) + + # Run the server (startup tasks now handled by Starlette startup event) + uvicorn.run( + app, + workers=1, + host="0.0.0.0", + port=8000, + reload=False, # Disable reload since we're running from main + )