diff --git a/src/main.py b/src/main.py index dc7366d5..ecf40a78 100644 --- a/src/main.py +++ b/src/main.py @@ -31,15 +31,16 @@ from auth_middleware import optional_auth, require_auth # Configuration and setup from config.settings import INDEX_BODY, INDEX_NAME, SESSION_SECRET, clients + # Existing services from connectors.service import ConnectorService from services.auth_service import AuthService from services.chat_service import ChatService +from services.document_service import DocumentService +from services.knowledge_filter_service import KnowledgeFilterService # Services from services.langflow_file_service import LangflowFileService -from services.document_service import DocumentService -from services.knowledge_filter_service import KnowledgeFilterService from services.monitor_service import MonitorService from services.search_service import SearchService from services.task_service import TaskService @@ -53,7 +54,7 @@ async def wait_for_opensearch(): """Wait for OpenSearch to be ready with retries""" max_retries = 30 retry_delay = 2 - + for attempt in range(max_retries): try: await clients.opensearch.info() @@ -74,11 +75,11 @@ async def configure_alerting_security(): alerting_settings = { "persistent": { "plugins.alerting.filter_by_backend_roles": "false", - "opendistro.alerting.filter_by_backend_roles": "false", + "opendistro.alerting.filter_by_backend_roles": "false", "opensearch.notifications.general.filter_by_backend_roles": "false" } } - + # Use admin client (clients.opensearch uses admin credentials) response = await clients.opensearch.cluster.put_settings(body=alerting_settings) print("Alerting security settings configured successfully") @@ -90,14 +91,14 @@ async def configure_alerting_security(): async def init_index(): """Initialize OpenSearch index and security roles""" await wait_for_opensearch() - + # Create documents index if not await clients.opensearch.indices.exists(index=INDEX_NAME): await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY) print(f"Created index '{INDEX_NAME}'") else: print(f"Index '{INDEX_NAME}' already exists, skipping creation.") - + # Create knowledge filters index knowledge_filter_index_name = "knowledge_filters" knowledge_filter_index_body = { @@ -116,13 +117,13 @@ async def init_index(): } } } - + if not await clients.opensearch.indices.exists(index=knowledge_filter_index_name): await clients.opensearch.indices.create(index=knowledge_filter_index_name, body=knowledge_filter_index_body) print(f"Created index '{knowledge_filter_index_name}'") else: print(f"Index '{knowledge_filter_index_name}' already exists, skipping creation.") - + # Configure alerting plugin security settings await configure_alerting_security() @@ -131,10 +132,10 @@ def generate_jwt_keys(): keys_dir = "keys" private_key_path = os.path.join(keys_dir, "private_key.pem") public_key_path = os.path.join(keys_dir, "public_key.pem") - + # Create keys directory if it doesn't exist os.makedirs(keys_dir, exist_ok=True) - + # Generate keys if they don't exist if not os.path.exists(private_key_path): try: @@ -142,12 +143,12 @@ def generate_jwt_keys(): subprocess.run([ "openssl", "genrsa", "-out", private_key_path, "2048" ], check=True, capture_output=True) - + # Generate public key subprocess.run([ "openssl", "rsa", "-in", private_key_path, "-pubout", "-out", public_key_path ], check=True, capture_output=True) - + print("Generated RSA keys for JWT signing") except subprocess.CalledProcessError as e: print(f"Failed to generate RSA keys: {e}") @@ -163,19 +164,19 @@ async def init_index_when_ready(): except Exception as e: print(f"OpenSearch index initialization failed: {e}") print("OIDC endpoints will still work, but document operations may fail until OpenSearch is ready") - + async def initialize_services(): """Initialize all services and their dependencies""" # Generate JWT keys if they don't exist generate_jwt_keys() - + # Initialize clients (now async to generate Langflow API key) await clients.initialize() - + # Initialize session manager session_manager = SessionManager(SESSION_SECRET) - + # Initialize services document_service = DocumentService(session_manager=session_manager) search_service = SearchService(session_manager) @@ -183,10 +184,10 @@ async def initialize_services(): chat_service = ChatService() knowledge_filter_service = KnowledgeFilterService(session_manager) monitor_service = MonitorService(session_manager) - + # Set process pool for document service document_service.process_pool = process_pool - + # Initialize connector service connector_service = ConnectorService( patched_async_client=clients.patched_async_client, @@ -196,10 +197,10 @@ async def initialize_services(): task_service=task_service, session_manager=session_manager ) - + # Initialize auth service auth_service = AuthService(session_manager, connector_service) - + # Load persisted connector connections at startup so webhooks and syncs # can resolve existing subscriptions immediately after server boot # Skip in no-auth mode since connectors require OAuth @@ -213,7 +214,7 @@ async def initialize_services(): print(f"[WARNING] Failed to load persisted connections on startup: {e}") else: print("[CONNECTORS] Skipping connection loading in no-auth mode") - + # New: Langflow file service langflow_file_service = LangflowFileService() @@ -234,13 +235,13 @@ async def initialize_services(): async def create_app(): """Create and configure the Starlette application""" services = await initialize_services() - + # Create route handlers with service dependencies injected routes = [ # Upload endpoints - Route("/upload", + Route("/upload", require_auth(services['session_manager'])( - partial(upload.upload, + partial(upload.upload, document_service=services['document_service'], session_manager=services['session_manager']) ), methods=["POST"]), @@ -266,15 +267,15 @@ async def create_app(): langflow_file_service=services['langflow_file_service'], session_manager=services['session_manager']) ), methods=["DELETE"]), - - Route("/upload_context", + + Route("/upload_context", require_auth(services['session_manager'])( partial(upload.upload_context, document_service=services['document_service'], chat_service=services['chat_service'], session_manager=services['session_manager']) ), methods=["POST"]), - + Route("/upload_path", require_auth(services['session_manager'])( partial(upload.upload_path, @@ -294,227 +295,227 @@ async def create_app(): task_service=services['task_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/tasks/{task_id}", + + Route("/tasks/{task_id}", require_auth(services['session_manager'])( partial(tasks.task_status, task_service=services['task_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/tasks", + + Route("/tasks", require_auth(services['session_manager'])( partial(tasks.all_tasks, task_service=services['task_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/tasks/{task_id}/cancel", + + Route("/tasks/{task_id}/cancel", require_auth(services['session_manager'])( partial(tasks.cancel_task, task_service=services['task_service'], session_manager=services['session_manager']) ), methods=["POST"]), - + # Search endpoint - Route("/search", + Route("/search", require_auth(services['session_manager'])( partial(search.search, search_service=services['search_service'], session_manager=services['session_manager']) ), methods=["POST"]), - + # Knowledge Filter endpoints - Route("/knowledge-filter", + Route("/knowledge-filter", require_auth(services['session_manager'])( partial(knowledge_filter.create_knowledge_filter, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/knowledge-filter/search", + + Route("/knowledge-filter/search", require_auth(services['session_manager'])( partial(knowledge_filter.search_knowledge_filters, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/knowledge-filter/{filter_id}", + + Route("/knowledge-filter/{filter_id}", require_auth(services['session_manager'])( partial(knowledge_filter.get_knowledge_filter, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/knowledge-filter/{filter_id}", + + Route("/knowledge-filter/{filter_id}", require_auth(services['session_manager'])( partial(knowledge_filter.update_knowledge_filter, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["PUT"]), - - Route("/knowledge-filter/{filter_id}", + + Route("/knowledge-filter/{filter_id}", require_auth(services['session_manager'])( partial(knowledge_filter.delete_knowledge_filter, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["DELETE"]), - + # Knowledge Filter Subscription endpoints - Route("/knowledge-filter/{filter_id}/subscribe", + Route("/knowledge-filter/{filter_id}/subscribe", require_auth(services['session_manager'])( partial(knowledge_filter.subscribe_to_knowledge_filter, knowledge_filter_service=services['knowledge_filter_service'], monitor_service=services['monitor_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/knowledge-filter/{filter_id}/subscriptions", + + Route("/knowledge-filter/{filter_id}/subscriptions", require_auth(services['session_manager'])( partial(knowledge_filter.list_knowledge_filter_subscriptions, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/knowledge-filter/{filter_id}/subscribe/{subscription_id}", + + Route("/knowledge-filter/{filter_id}/subscribe/{subscription_id}", require_auth(services['session_manager'])( partial(knowledge_filter.cancel_knowledge_filter_subscription, knowledge_filter_service=services['knowledge_filter_service'], monitor_service=services['monitor_service'], session_manager=services['session_manager']) ), methods=["DELETE"]), - + # Knowledge Filter Webhook endpoint (no auth required - called by OpenSearch) - Route("/knowledge-filter/{filter_id}/webhook/{subscription_id}", + Route("/knowledge-filter/{filter_id}/webhook/{subscription_id}", partial(knowledge_filter.knowledge_filter_webhook, knowledge_filter_service=services['knowledge_filter_service'], - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["POST"]), - + # Chat endpoints - Route("/chat", + Route("/chat", require_auth(services['session_manager'])( partial(chat.chat_endpoint, chat_service=services['chat_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/langflow", + + Route("/langflow", require_auth(services['session_manager'])( partial(chat.langflow_endpoint, chat_service=services['chat_service'], session_manager=services['session_manager']) ), methods=["POST"]), - + # Chat history endpoints - Route("/chat/history", + Route("/chat/history", require_auth(services['session_manager'])( partial(chat.chat_history_endpoint, chat_service=services['chat_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/langflow/history", + + Route("/langflow/history", require_auth(services['session_manager'])( partial(chat.langflow_history_endpoint, chat_service=services['chat_service'], session_manager=services['session_manager']) ), methods=["GET"]), - + # Authentication endpoints - Route("/auth/init", + Route("/auth/init", optional_auth(services['session_manager'])( partial(auth.auth_init, auth_service=services['auth_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/auth/callback", + + Route("/auth/callback", partial(auth.auth_callback, auth_service=services['auth_service'], - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["POST"]), - - Route("/auth/me", + + Route("/auth/me", optional_auth(services['session_manager'])( partial(auth.auth_me, auth_service=services['auth_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/auth/logout", + + Route("/auth/logout", require_auth(services['session_manager'])( partial(auth.auth_logout, auth_service=services['auth_service'], session_manager=services['session_manager']) ), methods=["POST"]), - + # Connector endpoints - Route("/connectors", + Route("/connectors", require_auth(services['session_manager'])( partial(connectors.list_connectors, connector_service=services['connector_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/connectors/{connector_type}/sync", + + Route("/connectors/{connector_type}/sync", require_auth(services['session_manager'])( partial(connectors.connector_sync, connector_service=services['connector_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/connectors/{connector_type}/status", + + Route("/connectors/{connector_type}/status", require_auth(services['session_manager'])( partial(connectors.connector_status, connector_service=services['connector_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/connectors/{connector_type}/webhook", + + Route("/connectors/{connector_type}/webhook", partial(connectors.connector_webhook, connector_service=services['connector_service'], - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["POST", "GET"]), - + # OIDC endpoints - Route("/.well-known/openid-configuration", + Route("/.well-known/openid-configuration", partial(oidc.oidc_discovery, - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["GET"]), - - Route("/auth/jwks", + + Route("/auth/jwks", partial(oidc.jwks_endpoint, - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["GET"]), - - Route("/auth/introspect", + + Route("/auth/introspect", partial(oidc.token_introspection, - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["POST"]), - + # Settings endpoint - Route("/settings", + Route("/settings", require_auth(services['session_manager'])( partial(settings.get_settings, session_manager=services['session_manager']) ), methods=["GET"]), ] - + app = Starlette(debug=True, routes=routes) app.state.services = services # Store services for cleanup - + # Add startup event handler - @app.on_event("startup") + @app.on_event("startup") async def startup_event(): # Start index initialization in background to avoid blocking OIDC endpoints asyncio.create_task(init_index_when_ready()) - + # Add shutdown event handler @app.on_event("shutdown") async def shutdown_event(): await cleanup_subscriptions_proper(services) - + return app async def startup(): @@ -533,15 +534,15 @@ def cleanup(): async def cleanup_subscriptions_proper(services): """Cancel all active webhook subscriptions""" print("[CLEANUP] Cancelling active webhook subscriptions...") - + try: connector_service = services['connector_service'] await connector_service.connection_manager.load_connections() - - # Get all active connections with webhook subscriptions + + # Get all active connections with webhook subscriptions all_connections = await connector_service.connection_manager.list_connections() active_connections = [c for c in all_connections if c.is_active and c.config.get('webhook_channel_id')] - + for connection in active_connections: try: print(f"[CLEANUP] Cancelling subscription for connection {connection.connection_id}") @@ -552,22 +553,22 @@ async def cleanup_subscriptions_proper(services): print(f"[CLEANUP] Cancelled subscription {subscription_id}") except Exception as e: print(f"[ERROR] Failed to cancel subscription for {connection.connection_id}: {e}") - + print(f"[CLEANUP] Finished cancelling {len(active_connections)} subscriptions") - + except Exception as e: print(f"[ERROR] Failed to cleanup subscriptions: {e}") if __name__ == "__main__": import uvicorn - + # Register cleanup function atexit.register(cleanup) - + # Create app asynchronously app = asyncio.run(create_app()) - + # Run the server (startup tasks now handled by Starlette startup event) uvicorn.run( app,