From 4be48270b774f970ee4c679d0cc0921d69e96e37 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 3 Sep 2025 10:36:54 -0300 Subject: [PATCH] Refactor main.py for improved organization and clarity This commit reorganizes the import statements in main.py, enhancing the structure and readability of the code. It also includes minor formatting adjustments for consistency. The changes contribute to a cleaner codebase, facilitating easier maintenance and future development. --- src/main.py | 213 ++++++++++++++++++++++++++-------------------------- 1 file changed, 107 insertions(+), 106 deletions(-) diff --git a/src/main.py b/src/main.py index dc7366d5..ecf40a78 100644 --- a/src/main.py +++ b/src/main.py @@ -31,15 +31,16 @@ from auth_middleware import optional_auth, require_auth # Configuration and setup from config.settings import INDEX_BODY, INDEX_NAME, SESSION_SECRET, clients + # Existing services from connectors.service import ConnectorService from services.auth_service import AuthService from services.chat_service import ChatService +from services.document_service import DocumentService +from services.knowledge_filter_service import KnowledgeFilterService # Services from services.langflow_file_service import LangflowFileService -from services.document_service import DocumentService -from services.knowledge_filter_service import KnowledgeFilterService from services.monitor_service import MonitorService from services.search_service import SearchService from services.task_service import TaskService @@ -53,7 +54,7 @@ async def wait_for_opensearch(): """Wait for OpenSearch to be ready with retries""" max_retries = 30 retry_delay = 2 - + for attempt in range(max_retries): try: await clients.opensearch.info() @@ -74,11 +75,11 @@ async def configure_alerting_security(): alerting_settings = { "persistent": { "plugins.alerting.filter_by_backend_roles": "false", - "opendistro.alerting.filter_by_backend_roles": "false", + "opendistro.alerting.filter_by_backend_roles": "false", "opensearch.notifications.general.filter_by_backend_roles": "false" } } - + # Use admin client (clients.opensearch uses admin credentials) response = await clients.opensearch.cluster.put_settings(body=alerting_settings) print("Alerting security settings configured successfully") @@ -90,14 +91,14 @@ async def configure_alerting_security(): async def init_index(): """Initialize OpenSearch index and security roles""" await wait_for_opensearch() - + # Create documents index if not await clients.opensearch.indices.exists(index=INDEX_NAME): await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY) print(f"Created index '{INDEX_NAME}'") else: print(f"Index '{INDEX_NAME}' already exists, skipping creation.") - + # Create knowledge filters index knowledge_filter_index_name = "knowledge_filters" knowledge_filter_index_body = { @@ -116,13 +117,13 @@ async def init_index(): } } } - + if not await clients.opensearch.indices.exists(index=knowledge_filter_index_name): await clients.opensearch.indices.create(index=knowledge_filter_index_name, body=knowledge_filter_index_body) print(f"Created index '{knowledge_filter_index_name}'") else: print(f"Index '{knowledge_filter_index_name}' already exists, skipping creation.") - + # Configure alerting plugin security settings await configure_alerting_security() @@ -131,10 +132,10 @@ def generate_jwt_keys(): keys_dir = "keys" private_key_path = os.path.join(keys_dir, "private_key.pem") public_key_path = os.path.join(keys_dir, "public_key.pem") - + # Create keys directory if it doesn't exist os.makedirs(keys_dir, exist_ok=True) - + # Generate keys if they don't exist if not os.path.exists(private_key_path): try: @@ -142,12 +143,12 @@ def generate_jwt_keys(): subprocess.run([ "openssl", "genrsa", "-out", private_key_path, "2048" ], check=True, capture_output=True) - + # Generate public key subprocess.run([ "openssl", "rsa", "-in", private_key_path, "-pubout", "-out", public_key_path ], check=True, capture_output=True) - + print("Generated RSA keys for JWT signing") except subprocess.CalledProcessError as e: print(f"Failed to generate RSA keys: {e}") @@ -163,19 +164,19 @@ async def init_index_when_ready(): except Exception as e: print(f"OpenSearch index initialization failed: {e}") print("OIDC endpoints will still work, but document operations may fail until OpenSearch is ready") - + async def initialize_services(): """Initialize all services and their dependencies""" # Generate JWT keys if they don't exist generate_jwt_keys() - + # Initialize clients (now async to generate Langflow API key) await clients.initialize() - + # Initialize session manager session_manager = SessionManager(SESSION_SECRET) - + # Initialize services document_service = DocumentService(session_manager=session_manager) search_service = SearchService(session_manager) @@ -183,10 +184,10 @@ async def initialize_services(): chat_service = ChatService() knowledge_filter_service = KnowledgeFilterService(session_manager) monitor_service = MonitorService(session_manager) - + # Set process pool for document service document_service.process_pool = process_pool - + # Initialize connector service connector_service = ConnectorService( patched_async_client=clients.patched_async_client, @@ -196,10 +197,10 @@ async def initialize_services(): task_service=task_service, session_manager=session_manager ) - + # Initialize auth service auth_service = AuthService(session_manager, connector_service) - + # Load persisted connector connections at startup so webhooks and syncs # can resolve existing subscriptions immediately after server boot # Skip in no-auth mode since connectors require OAuth @@ -213,7 +214,7 @@ async def initialize_services(): print(f"[WARNING] Failed to load persisted connections on startup: {e}") else: print("[CONNECTORS] Skipping connection loading in no-auth mode") - + # New: Langflow file service langflow_file_service = LangflowFileService() @@ -234,13 +235,13 @@ async def initialize_services(): async def create_app(): """Create and configure the Starlette application""" services = await initialize_services() - + # Create route handlers with service dependencies injected routes = [ # Upload endpoints - Route("/upload", + Route("/upload", require_auth(services['session_manager'])( - partial(upload.upload, + partial(upload.upload, document_service=services['document_service'], session_manager=services['session_manager']) ), methods=["POST"]), @@ -266,15 +267,15 @@ async def create_app(): langflow_file_service=services['langflow_file_service'], session_manager=services['session_manager']) ), methods=["DELETE"]), - - Route("/upload_context", + + Route("/upload_context", require_auth(services['session_manager'])( partial(upload.upload_context, document_service=services['document_service'], chat_service=services['chat_service'], session_manager=services['session_manager']) ), methods=["POST"]), - + Route("/upload_path", require_auth(services['session_manager'])( partial(upload.upload_path, @@ -294,227 +295,227 @@ async def create_app(): task_service=services['task_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/tasks/{task_id}", + + Route("/tasks/{task_id}", require_auth(services['session_manager'])( partial(tasks.task_status, task_service=services['task_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/tasks", + + Route("/tasks", require_auth(services['session_manager'])( partial(tasks.all_tasks, task_service=services['task_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/tasks/{task_id}/cancel", + + Route("/tasks/{task_id}/cancel", require_auth(services['session_manager'])( partial(tasks.cancel_task, task_service=services['task_service'], session_manager=services['session_manager']) ), methods=["POST"]), - + # Search endpoint - Route("/search", + Route("/search", require_auth(services['session_manager'])( partial(search.search, search_service=services['search_service'], session_manager=services['session_manager']) ), methods=["POST"]), - + # Knowledge Filter endpoints - Route("/knowledge-filter", + Route("/knowledge-filter", require_auth(services['session_manager'])( partial(knowledge_filter.create_knowledge_filter, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/knowledge-filter/search", + + Route("/knowledge-filter/search", require_auth(services['session_manager'])( partial(knowledge_filter.search_knowledge_filters, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/knowledge-filter/{filter_id}", + + Route("/knowledge-filter/{filter_id}", require_auth(services['session_manager'])( partial(knowledge_filter.get_knowledge_filter, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/knowledge-filter/{filter_id}", + + Route("/knowledge-filter/{filter_id}", require_auth(services['session_manager'])( partial(knowledge_filter.update_knowledge_filter, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["PUT"]), - - Route("/knowledge-filter/{filter_id}", + + Route("/knowledge-filter/{filter_id}", require_auth(services['session_manager'])( partial(knowledge_filter.delete_knowledge_filter, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["DELETE"]), - + # Knowledge Filter Subscription endpoints - Route("/knowledge-filter/{filter_id}/subscribe", + Route("/knowledge-filter/{filter_id}/subscribe", require_auth(services['session_manager'])( partial(knowledge_filter.subscribe_to_knowledge_filter, knowledge_filter_service=services['knowledge_filter_service'], monitor_service=services['monitor_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/knowledge-filter/{filter_id}/subscriptions", + + Route("/knowledge-filter/{filter_id}/subscriptions", require_auth(services['session_manager'])( partial(knowledge_filter.list_knowledge_filter_subscriptions, knowledge_filter_service=services['knowledge_filter_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/knowledge-filter/{filter_id}/subscribe/{subscription_id}", + + Route("/knowledge-filter/{filter_id}/subscribe/{subscription_id}", require_auth(services['session_manager'])( partial(knowledge_filter.cancel_knowledge_filter_subscription, knowledge_filter_service=services['knowledge_filter_service'], monitor_service=services['monitor_service'], session_manager=services['session_manager']) ), methods=["DELETE"]), - + # Knowledge Filter Webhook endpoint (no auth required - called by OpenSearch) - Route("/knowledge-filter/{filter_id}/webhook/{subscription_id}", + Route("/knowledge-filter/{filter_id}/webhook/{subscription_id}", partial(knowledge_filter.knowledge_filter_webhook, knowledge_filter_service=services['knowledge_filter_service'], - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["POST"]), - + # Chat endpoints - Route("/chat", + Route("/chat", require_auth(services['session_manager'])( partial(chat.chat_endpoint, chat_service=services['chat_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/langflow", + + Route("/langflow", require_auth(services['session_manager'])( partial(chat.langflow_endpoint, chat_service=services['chat_service'], session_manager=services['session_manager']) ), methods=["POST"]), - + # Chat history endpoints - Route("/chat/history", + Route("/chat/history", require_auth(services['session_manager'])( partial(chat.chat_history_endpoint, chat_service=services['chat_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/langflow/history", + + Route("/langflow/history", require_auth(services['session_manager'])( partial(chat.langflow_history_endpoint, chat_service=services['chat_service'], session_manager=services['session_manager']) ), methods=["GET"]), - + # Authentication endpoints - Route("/auth/init", + Route("/auth/init", optional_auth(services['session_manager'])( partial(auth.auth_init, auth_service=services['auth_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/auth/callback", + + Route("/auth/callback", partial(auth.auth_callback, auth_service=services['auth_service'], - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["POST"]), - - Route("/auth/me", + + Route("/auth/me", optional_auth(services['session_manager'])( partial(auth.auth_me, auth_service=services['auth_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/auth/logout", + + Route("/auth/logout", require_auth(services['session_manager'])( partial(auth.auth_logout, auth_service=services['auth_service'], session_manager=services['session_manager']) ), methods=["POST"]), - + # Connector endpoints - Route("/connectors", + Route("/connectors", require_auth(services['session_manager'])( partial(connectors.list_connectors, connector_service=services['connector_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/connectors/{connector_type}/sync", + + Route("/connectors/{connector_type}/sync", require_auth(services['session_manager'])( partial(connectors.connector_sync, connector_service=services['connector_service'], session_manager=services['session_manager']) ), methods=["POST"]), - - Route("/connectors/{connector_type}/status", + + Route("/connectors/{connector_type}/status", require_auth(services['session_manager'])( partial(connectors.connector_status, connector_service=services['connector_service'], session_manager=services['session_manager']) ), methods=["GET"]), - - Route("/connectors/{connector_type}/webhook", + + Route("/connectors/{connector_type}/webhook", partial(connectors.connector_webhook, connector_service=services['connector_service'], - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["POST", "GET"]), - + # OIDC endpoints - Route("/.well-known/openid-configuration", + Route("/.well-known/openid-configuration", partial(oidc.oidc_discovery, - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["GET"]), - - Route("/auth/jwks", + + Route("/auth/jwks", partial(oidc.jwks_endpoint, - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["GET"]), - - Route("/auth/introspect", + + Route("/auth/introspect", partial(oidc.token_introspection, - session_manager=services['session_manager']), + session_manager=services['session_manager']), methods=["POST"]), - + # Settings endpoint - Route("/settings", + Route("/settings", require_auth(services['session_manager'])( partial(settings.get_settings, session_manager=services['session_manager']) ), methods=["GET"]), ] - + app = Starlette(debug=True, routes=routes) app.state.services = services # Store services for cleanup - + # Add startup event handler - @app.on_event("startup") + @app.on_event("startup") async def startup_event(): # Start index initialization in background to avoid blocking OIDC endpoints asyncio.create_task(init_index_when_ready()) - + # Add shutdown event handler @app.on_event("shutdown") async def shutdown_event(): await cleanup_subscriptions_proper(services) - + return app async def startup(): @@ -533,15 +534,15 @@ def cleanup(): async def cleanup_subscriptions_proper(services): """Cancel all active webhook subscriptions""" print("[CLEANUP] Cancelling active webhook subscriptions...") - + try: connector_service = services['connector_service'] await connector_service.connection_manager.load_connections() - - # Get all active connections with webhook subscriptions + + # Get all active connections with webhook subscriptions all_connections = await connector_service.connection_manager.list_connections() active_connections = [c for c in all_connections if c.is_active and c.config.get('webhook_channel_id')] - + for connection in active_connections: try: print(f"[CLEANUP] Cancelling subscription for connection {connection.connection_id}") @@ -552,22 +553,22 @@ async def cleanup_subscriptions_proper(services): print(f"[CLEANUP] Cancelled subscription {subscription_id}") except Exception as e: print(f"[ERROR] Failed to cancel subscription for {connection.connection_id}: {e}") - + print(f"[CLEANUP] Finished cancelling {len(active_connections)} subscriptions") - + except Exception as e: print(f"[ERROR] Failed to cleanup subscriptions: {e}") if __name__ == "__main__": import uvicorn - + # Register cleanup function atexit.register(cleanup) - + # Create app asynchronously app = asyncio.run(create_app()) - + # Run the server (startup tasks now handled by Starlette startup event) uvicorn.run( app,