Refactor main.py for improved organization and clarity

This commit reorganizes the import statements in main.py, enhancing the structure and readability of the code. It also includes minor formatting adjustments for consistency. The changes contribute to a cleaner codebase, facilitating easier maintenance and future development.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2025-09-03 10:36:54 -03:00
parent a10b35f631
commit 4be48270b7

View file

@ -31,15 +31,16 @@ from auth_middleware import optional_auth, require_auth
# Configuration and setup
from config.settings import INDEX_BODY, INDEX_NAME, SESSION_SECRET, clients
# Existing services
from connectors.service import ConnectorService
from services.auth_service import AuthService
from services.chat_service import ChatService
from services.document_service import DocumentService
from services.knowledge_filter_service import KnowledgeFilterService
# Services
from services.langflow_file_service import LangflowFileService
from services.document_service import DocumentService
from services.knowledge_filter_service import KnowledgeFilterService
from services.monitor_service import MonitorService
from services.search_service import SearchService
from services.task_service import TaskService
@ -53,7 +54,7 @@ async def wait_for_opensearch():
"""Wait for OpenSearch to be ready with retries"""
max_retries = 30
retry_delay = 2
for attempt in range(max_retries):
try:
await clients.opensearch.info()
@ -74,11 +75,11 @@ async def configure_alerting_security():
alerting_settings = {
"persistent": {
"plugins.alerting.filter_by_backend_roles": "false",
"opendistro.alerting.filter_by_backend_roles": "false",
"opendistro.alerting.filter_by_backend_roles": "false",
"opensearch.notifications.general.filter_by_backend_roles": "false"
}
}
# Use admin client (clients.opensearch uses admin credentials)
response = await clients.opensearch.cluster.put_settings(body=alerting_settings)
print("Alerting security settings configured successfully")
@ -90,14 +91,14 @@ async def configure_alerting_security():
async def init_index():
"""Initialize OpenSearch index and security roles"""
await wait_for_opensearch()
# Create documents index
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
print(f"Created index '{INDEX_NAME}'")
else:
print(f"Index '{INDEX_NAME}' already exists, skipping creation.")
# Create knowledge filters index
knowledge_filter_index_name = "knowledge_filters"
knowledge_filter_index_body = {
@ -116,13 +117,13 @@ async def init_index():
}
}
}
if not await clients.opensearch.indices.exists(index=knowledge_filter_index_name):
await clients.opensearch.indices.create(index=knowledge_filter_index_name, body=knowledge_filter_index_body)
print(f"Created index '{knowledge_filter_index_name}'")
else:
print(f"Index '{knowledge_filter_index_name}' already exists, skipping creation.")
# Configure alerting plugin security settings
await configure_alerting_security()
@ -131,10 +132,10 @@ def generate_jwt_keys():
keys_dir = "keys"
private_key_path = os.path.join(keys_dir, "private_key.pem")
public_key_path = os.path.join(keys_dir, "public_key.pem")
# Create keys directory if it doesn't exist
os.makedirs(keys_dir, exist_ok=True)
# Generate keys if they don't exist
if not os.path.exists(private_key_path):
try:
@ -142,12 +143,12 @@ def generate_jwt_keys():
subprocess.run([
"openssl", "genrsa", "-out", private_key_path, "2048"
], check=True, capture_output=True)
# Generate public key
subprocess.run([
"openssl", "rsa", "-in", private_key_path, "-pubout", "-out", public_key_path
], check=True, capture_output=True)
print("Generated RSA keys for JWT signing")
except subprocess.CalledProcessError as e:
print(f"Failed to generate RSA keys: {e}")
@ -163,19 +164,19 @@ async def init_index_when_ready():
except Exception as e:
print(f"OpenSearch index initialization failed: {e}")
print("OIDC endpoints will still work, but document operations may fail until OpenSearch is ready")
async def initialize_services():
"""Initialize all services and their dependencies"""
# Generate JWT keys if they don't exist
generate_jwt_keys()
# Initialize clients (now async to generate Langflow API key)
await clients.initialize()
# Initialize session manager
session_manager = SessionManager(SESSION_SECRET)
# Initialize services
document_service = DocumentService(session_manager=session_manager)
search_service = SearchService(session_manager)
@ -183,10 +184,10 @@ async def initialize_services():
chat_service = ChatService()
knowledge_filter_service = KnowledgeFilterService(session_manager)
monitor_service = MonitorService(session_manager)
# Set process pool for document service
document_service.process_pool = process_pool
# Initialize connector service
connector_service = ConnectorService(
patched_async_client=clients.patched_async_client,
@ -196,10 +197,10 @@ async def initialize_services():
task_service=task_service,
session_manager=session_manager
)
# Initialize auth service
auth_service = AuthService(session_manager, connector_service)
# Load persisted connector connections at startup so webhooks and syncs
# can resolve existing subscriptions immediately after server boot
# Skip in no-auth mode since connectors require OAuth
@ -213,7 +214,7 @@ async def initialize_services():
print(f"[WARNING] Failed to load persisted connections on startup: {e}")
else:
print("[CONNECTORS] Skipping connection loading in no-auth mode")
# New: Langflow file service
langflow_file_service = LangflowFileService()
@ -234,13 +235,13 @@ async def initialize_services():
async def create_app():
"""Create and configure the Starlette application"""
services = await initialize_services()
# Create route handlers with service dependencies injected
routes = [
# Upload endpoints
Route("/upload",
Route("/upload",
require_auth(services['session_manager'])(
partial(upload.upload,
partial(upload.upload,
document_service=services['document_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
@ -266,15 +267,15 @@ async def create_app():
langflow_file_service=services['langflow_file_service'],
session_manager=services['session_manager'])
), methods=["DELETE"]),
Route("/upload_context",
Route("/upload_context",
require_auth(services['session_manager'])(
partial(upload.upload_context,
document_service=services['document_service'],
chat_service=services['chat_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/upload_path",
require_auth(services['session_manager'])(
partial(upload.upload_path,
@ -294,227 +295,227 @@ async def create_app():
task_service=services['task_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/tasks/{task_id}",
Route("/tasks/{task_id}",
require_auth(services['session_manager'])(
partial(tasks.task_status,
task_service=services['task_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
Route("/tasks",
Route("/tasks",
require_auth(services['session_manager'])(
partial(tasks.all_tasks,
task_service=services['task_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
Route("/tasks/{task_id}/cancel",
Route("/tasks/{task_id}/cancel",
require_auth(services['session_manager'])(
partial(tasks.cancel_task,
task_service=services['task_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
# Search endpoint
Route("/search",
Route("/search",
require_auth(services['session_manager'])(
partial(search.search,
search_service=services['search_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
# Knowledge Filter endpoints
Route("/knowledge-filter",
Route("/knowledge-filter",
require_auth(services['session_manager'])(
partial(knowledge_filter.create_knowledge_filter,
knowledge_filter_service=services['knowledge_filter_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/knowledge-filter/search",
Route("/knowledge-filter/search",
require_auth(services['session_manager'])(
partial(knowledge_filter.search_knowledge_filters,
knowledge_filter_service=services['knowledge_filter_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/knowledge-filter/{filter_id}",
Route("/knowledge-filter/{filter_id}",
require_auth(services['session_manager'])(
partial(knowledge_filter.get_knowledge_filter,
knowledge_filter_service=services['knowledge_filter_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
Route("/knowledge-filter/{filter_id}",
Route("/knowledge-filter/{filter_id}",
require_auth(services['session_manager'])(
partial(knowledge_filter.update_knowledge_filter,
knowledge_filter_service=services['knowledge_filter_service'],
session_manager=services['session_manager'])
), methods=["PUT"]),
Route("/knowledge-filter/{filter_id}",
Route("/knowledge-filter/{filter_id}",
require_auth(services['session_manager'])(
partial(knowledge_filter.delete_knowledge_filter,
knowledge_filter_service=services['knowledge_filter_service'],
session_manager=services['session_manager'])
), methods=["DELETE"]),
# Knowledge Filter Subscription endpoints
Route("/knowledge-filter/{filter_id}/subscribe",
Route("/knowledge-filter/{filter_id}/subscribe",
require_auth(services['session_manager'])(
partial(knowledge_filter.subscribe_to_knowledge_filter,
knowledge_filter_service=services['knowledge_filter_service'],
monitor_service=services['monitor_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/knowledge-filter/{filter_id}/subscriptions",
Route("/knowledge-filter/{filter_id}/subscriptions",
require_auth(services['session_manager'])(
partial(knowledge_filter.list_knowledge_filter_subscriptions,
knowledge_filter_service=services['knowledge_filter_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
Route("/knowledge-filter/{filter_id}/subscribe/{subscription_id}",
Route("/knowledge-filter/{filter_id}/subscribe/{subscription_id}",
require_auth(services['session_manager'])(
partial(knowledge_filter.cancel_knowledge_filter_subscription,
knowledge_filter_service=services['knowledge_filter_service'],
monitor_service=services['monitor_service'],
session_manager=services['session_manager'])
), methods=["DELETE"]),
# Knowledge Filter Webhook endpoint (no auth required - called by OpenSearch)
Route("/knowledge-filter/{filter_id}/webhook/{subscription_id}",
Route("/knowledge-filter/{filter_id}/webhook/{subscription_id}",
partial(knowledge_filter.knowledge_filter_webhook,
knowledge_filter_service=services['knowledge_filter_service'],
session_manager=services['session_manager']),
session_manager=services['session_manager']),
methods=["POST"]),
# Chat endpoints
Route("/chat",
Route("/chat",
require_auth(services['session_manager'])(
partial(chat.chat_endpoint,
chat_service=services['chat_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/langflow",
Route("/langflow",
require_auth(services['session_manager'])(
partial(chat.langflow_endpoint,
chat_service=services['chat_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
# Chat history endpoints
Route("/chat/history",
Route("/chat/history",
require_auth(services['session_manager'])(
partial(chat.chat_history_endpoint,
chat_service=services['chat_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
Route("/langflow/history",
Route("/langflow/history",
require_auth(services['session_manager'])(
partial(chat.langflow_history_endpoint,
chat_service=services['chat_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
# Authentication endpoints
Route("/auth/init",
Route("/auth/init",
optional_auth(services['session_manager'])(
partial(auth.auth_init,
auth_service=services['auth_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/auth/callback",
Route("/auth/callback",
partial(auth.auth_callback,
auth_service=services['auth_service'],
session_manager=services['session_manager']),
session_manager=services['session_manager']),
methods=["POST"]),
Route("/auth/me",
Route("/auth/me",
optional_auth(services['session_manager'])(
partial(auth.auth_me,
auth_service=services['auth_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
Route("/auth/logout",
Route("/auth/logout",
require_auth(services['session_manager'])(
partial(auth.auth_logout,
auth_service=services['auth_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
# Connector endpoints
Route("/connectors",
Route("/connectors",
require_auth(services['session_manager'])(
partial(connectors.list_connectors,
connector_service=services['connector_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
Route("/connectors/{connector_type}/sync",
Route("/connectors/{connector_type}/sync",
require_auth(services['session_manager'])(
partial(connectors.connector_sync,
connector_service=services['connector_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/connectors/{connector_type}/status",
Route("/connectors/{connector_type}/status",
require_auth(services['session_manager'])(
partial(connectors.connector_status,
connector_service=services['connector_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
Route("/connectors/{connector_type}/webhook",
Route("/connectors/{connector_type}/webhook",
partial(connectors.connector_webhook,
connector_service=services['connector_service'],
session_manager=services['session_manager']),
session_manager=services['session_manager']),
methods=["POST", "GET"]),
# OIDC endpoints
Route("/.well-known/openid-configuration",
Route("/.well-known/openid-configuration",
partial(oidc.oidc_discovery,
session_manager=services['session_manager']),
session_manager=services['session_manager']),
methods=["GET"]),
Route("/auth/jwks",
Route("/auth/jwks",
partial(oidc.jwks_endpoint,
session_manager=services['session_manager']),
session_manager=services['session_manager']),
methods=["GET"]),
Route("/auth/introspect",
Route("/auth/introspect",
partial(oidc.token_introspection,
session_manager=services['session_manager']),
session_manager=services['session_manager']),
methods=["POST"]),
# Settings endpoint
Route("/settings",
Route("/settings",
require_auth(services['session_manager'])(
partial(settings.get_settings,
session_manager=services['session_manager'])
), methods=["GET"]),
]
app = Starlette(debug=True, routes=routes)
app.state.services = services # Store services for cleanup
# Add startup event handler
@app.on_event("startup")
@app.on_event("startup")
async def startup_event():
# Start index initialization in background to avoid blocking OIDC endpoints
asyncio.create_task(init_index_when_ready())
# Add shutdown event handler
@app.on_event("shutdown")
async def shutdown_event():
await cleanup_subscriptions_proper(services)
return app
async def startup():
@ -533,15 +534,15 @@ def cleanup():
async def cleanup_subscriptions_proper(services):
"""Cancel all active webhook subscriptions"""
print("[CLEANUP] Cancelling active webhook subscriptions...")
try:
connector_service = services['connector_service']
await connector_service.connection_manager.load_connections()
# Get all active connections with webhook subscriptions
# Get all active connections with webhook subscriptions
all_connections = await connector_service.connection_manager.list_connections()
active_connections = [c for c in all_connections if c.is_active and c.config.get('webhook_channel_id')]
for connection in active_connections:
try:
print(f"[CLEANUP] Cancelling subscription for connection {connection.connection_id}")
@ -552,22 +553,22 @@ async def cleanup_subscriptions_proper(services):
print(f"[CLEANUP] Cancelled subscription {subscription_id}")
except Exception as e:
print(f"[ERROR] Failed to cancel subscription for {connection.connection_id}: {e}")
print(f"[CLEANUP] Finished cancelling {len(active_connections)} subscriptions")
except Exception as e:
print(f"[ERROR] Failed to cleanup subscriptions: {e}")
if __name__ == "__main__":
import uvicorn
# Register cleanup function
atexit.register(cleanup)
# Create app asynchronously
app = asyncio.run(create_app())
# Run the server (startup tasks now handled by Starlette startup event)
uvicorn.run(
app,