Merge branch 'feat-googledrive-enhancements' of github.com:langflow-ai/openrag into feat-googledrive-enhancements

This commit is contained in:
Mike Fortman 2025-09-05 16:13:00 -05:00
commit 86d8c9f5b2
41 changed files with 1730 additions and 967 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

Binary file not shown.

View file

@ -95,7 +95,9 @@ async def async_response_stream(
chunk_count = 0 chunk_count = 0
async for chunk in response: async for chunk in response:
chunk_count += 1 chunk_count += 1
logger.debug("Stream chunk received", chunk_count=chunk_count, chunk=str(chunk)) logger.debug(
"Stream chunk received", chunk_count=chunk_count, chunk=str(chunk)
)
# Yield the raw event as JSON for the UI to process # Yield the raw event as JSON for the UI to process
import json import json
@ -241,7 +243,10 @@ async def async_langflow_stream(
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
log_prefix="langflow", log_prefix="langflow",
): ):
logger.debug("Yielding chunk from langflow stream", chunk_preview=chunk[:100].decode('utf-8', errors='replace')) logger.debug(
"Yielding chunk from langflow stream",
chunk_preview=chunk[:100].decode("utf-8", errors="replace"),
)
yield chunk yield chunk
logger.debug("Langflow stream completed") logger.debug("Langflow stream completed")
except Exception as e: except Exception as e:
@ -260,18 +265,24 @@ async def async_chat(
model: str = "gpt-4.1-mini", model: str = "gpt-4.1-mini",
previous_response_id: str = None, previous_response_id: str = None,
): ):
logger.debug("async_chat called", user_id=user_id, previous_response_id=previous_response_id) logger.debug(
"async_chat called", user_id=user_id, previous_response_id=previous_response_id
)
# Get the specific conversation thread (or create new one) # Get the specific conversation thread (or create new one)
conversation_state = get_conversation_thread(user_id, previous_response_id) conversation_state = get_conversation_thread(user_id, previous_response_id)
logger.debug("Got conversation state", message_count=len(conversation_state['messages'])) logger.debug(
"Got conversation state", message_count=len(conversation_state["messages"])
)
# Add user message to conversation with timestamp # Add user message to conversation with timestamp
from datetime import datetime from datetime import datetime
user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()} user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()}
conversation_state["messages"].append(user_message) conversation_state["messages"].append(user_message)
logger.debug("Added user message", message_count=len(conversation_state['messages'])) logger.debug(
"Added user message", message_count=len(conversation_state["messages"])
)
response_text, response_id = await async_response( response_text, response_id = await async_response(
async_client, async_client,
@ -280,7 +291,9 @@ async def async_chat(
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
log_prefix="agent", log_prefix="agent",
) )
logger.debug("Got response", response_preview=response_text[:50], response_id=response_id) logger.debug(
"Got response", response_preview=response_text[:50], response_id=response_id
)
# Add assistant response to conversation with response_id and timestamp # Add assistant response to conversation with response_id and timestamp
assistant_message = { assistant_message = {
@ -290,17 +303,26 @@ async def async_chat(
"timestamp": datetime.now(), "timestamp": datetime.now(),
} }
conversation_state["messages"].append(assistant_message) conversation_state["messages"].append(assistant_message)
logger.debug("Added assistant message", message_count=len(conversation_state['messages'])) logger.debug(
"Added assistant message", message_count=len(conversation_state["messages"])
)
# Store the conversation thread with its response_id # Store the conversation thread with its response_id
if response_id: if response_id:
conversation_state["last_activity"] = datetime.now() conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state) store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored conversation thread", user_id=user_id, response_id=response_id) logger.debug(
"Stored conversation thread", user_id=user_id, response_id=response_id
)
# Debug: Check what's in user_conversations now # Debug: Check what's in user_conversations now
conversations = get_user_conversations(user_id) conversations = get_user_conversations(user_id)
logger.debug("User conversations updated", user_id=user_id, conversation_count=len(conversations), conversation_ids=list(conversations.keys())) logger.debug(
"User conversations updated",
user_id=user_id,
conversation_count=len(conversations),
conversation_ids=list(conversations.keys()),
)
else: else:
logger.warning("No response_id received, conversation not stored") logger.warning("No response_id received, conversation not stored")
@ -363,7 +385,9 @@ async def async_chat_stream(
if response_id: if response_id:
conversation_state["last_activity"] = datetime.now() conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state) store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored conversation thread", user_id=user_id, response_id=response_id) logger.debug(
"Stored conversation thread", user_id=user_id, response_id=response_id
)
# Async langflow function with conversation storage (non-streaming) # Async langflow function with conversation storage (non-streaming)
@ -375,18 +399,28 @@ async def async_langflow_chat(
extra_headers: dict = None, extra_headers: dict = None,
previous_response_id: str = None, previous_response_id: str = None,
): ):
logger.debug("async_langflow_chat called", user_id=user_id, previous_response_id=previous_response_id) logger.debug(
"async_langflow_chat called",
user_id=user_id,
previous_response_id=previous_response_id,
)
# Get the specific conversation thread (or create new one) # Get the specific conversation thread (or create new one)
conversation_state = get_conversation_thread(user_id, previous_response_id) conversation_state = get_conversation_thread(user_id, previous_response_id)
logger.debug("Got langflow conversation state", message_count=len(conversation_state['messages'])) logger.debug(
"Got langflow conversation state",
message_count=len(conversation_state["messages"]),
)
# Add user message to conversation with timestamp # Add user message to conversation with timestamp
from datetime import datetime from datetime import datetime
user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()} user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()}
conversation_state["messages"].append(user_message) conversation_state["messages"].append(user_message)
logger.debug("Added user message to langflow", message_count=len(conversation_state['messages'])) logger.debug(
"Added user message to langflow",
message_count=len(conversation_state["messages"]),
)
response_text, response_id = await async_response( response_text, response_id = await async_response(
langflow_client, langflow_client,
@ -396,7 +430,11 @@ async def async_langflow_chat(
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
log_prefix="langflow", log_prefix="langflow",
) )
logger.debug("Got langflow response", response_preview=response_text[:50], response_id=response_id) logger.debug(
"Got langflow response",
response_preview=response_text[:50],
response_id=response_id,
)
# Add assistant response to conversation with response_id and timestamp # Add assistant response to conversation with response_id and timestamp
assistant_message = { assistant_message = {
@ -406,17 +444,29 @@ async def async_langflow_chat(
"timestamp": datetime.now(), "timestamp": datetime.now(),
} }
conversation_state["messages"].append(assistant_message) conversation_state["messages"].append(assistant_message)
logger.debug("Added assistant message to langflow", message_count=len(conversation_state['messages'])) logger.debug(
"Added assistant message to langflow",
message_count=len(conversation_state["messages"]),
)
# Store the conversation thread with its response_id # Store the conversation thread with its response_id
if response_id: if response_id:
conversation_state["last_activity"] = datetime.now() conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state) store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored langflow conversation thread", user_id=user_id, response_id=response_id) logger.debug(
"Stored langflow conversation thread",
user_id=user_id,
response_id=response_id,
)
# Debug: Check what's in user_conversations now # Debug: Check what's in user_conversations now
conversations = get_user_conversations(user_id) conversations = get_user_conversations(user_id)
logger.debug("User conversations updated", user_id=user_id, conversation_count=len(conversations), conversation_ids=list(conversations.keys())) logger.debug(
"User conversations updated",
user_id=user_id,
conversation_count=len(conversations),
conversation_ids=list(conversations.keys()),
)
else: else:
logger.warning("No response_id received from langflow, conversation not stored") logger.warning("No response_id received from langflow, conversation not stored")
@ -432,7 +482,11 @@ async def async_langflow_chat_stream(
extra_headers: dict = None, extra_headers: dict = None,
previous_response_id: str = None, previous_response_id: str = None,
): ):
logger.debug("async_langflow_chat_stream called", user_id=user_id, previous_response_id=previous_response_id) logger.debug(
"async_langflow_chat_stream called",
user_id=user_id,
previous_response_id=previous_response_id,
)
# Get the specific conversation thread (or create new one) # Get the specific conversation thread (or create new one)
conversation_state = get_conversation_thread(user_id, previous_response_id) conversation_state = get_conversation_thread(user_id, previous_response_id)
@ -483,4 +537,8 @@ async def async_langflow_chat_stream(
if response_id: if response_id:
conversation_state["last_activity"] = datetime.now() conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state) store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored langflow conversation thread", user_id=user_id, response_id=response_id) logger.debug(
"Stored langflow conversation thread",
user_id=user_id,
response_id=response_id,
)

View file

@ -25,6 +25,11 @@ async def connector_sync(request: Request, connector_service, session_manager):
selected_files = data.get("selected_files") selected_files = data.get("selected_files")
try: try:
logger.debug(
"Starting connector sync",
connector_type=connector_type,
max_files=max_files,
)
user = request.state.user user = request.state.user
jwt_token = request.state.jwt_token jwt_token = request.state.jwt_token
@ -44,6 +49,10 @@ async def connector_sync(request: Request, connector_service, session_manager):
# Start sync tasks for all active connections # Start sync tasks for all active connections
task_ids = [] task_ids = []
for connection in active_connections: for connection in active_connections:
logger.debug(
"About to call sync_connector_files for connection",
connection_id=connection.connection_id,
)
if selected_files: if selected_files:
task_id = await connector_service.sync_specific_files( task_id = await connector_service.sync_specific_files(
connection.connection_id, connection.connection_id,
@ -58,8 +67,6 @@ async def connector_sync(request: Request, connector_service, session_manager):
max_files, max_files,
jwt_token=jwt_token, jwt_token=jwt_token,
) )
task_ids.append(task_id)
return JSONResponse( return JSONResponse(
{ {
"task_ids": task_ids, "task_ids": task_ids,
@ -170,7 +177,9 @@ async def connector_webhook(request: Request, connector_service, session_manager
channel_id = None channel_id = None
if not channel_id: if not channel_id:
logger.warning("No channel ID found in webhook", connector_type=connector_type) logger.warning(
"No channel ID found in webhook", connector_type=connector_type
)
return JSONResponse({"status": "ignored", "reason": "no_channel_id"}) return JSONResponse({"status": "ignored", "reason": "no_channel_id"})
# Find the specific connection for this webhook # Find the specific connection for this webhook
@ -180,7 +189,9 @@ async def connector_webhook(request: Request, connector_service, session_manager
) )
) )
if not connection or not connection.is_active: if not connection or not connection.is_active:
logger.info("Unknown webhook channel, will auto-expire", channel_id=channel_id) logger.info(
"Unknown webhook channel, will auto-expire", channel_id=channel_id
)
return JSONResponse( return JSONResponse(
{"status": "ignored_unknown_channel", "channel_id": channel_id} {"status": "ignored_unknown_channel", "channel_id": channel_id}
) )
@ -190,7 +201,10 @@ async def connector_webhook(request: Request, connector_service, session_manager
# Get the connector instance # Get the connector instance
connector = await connector_service._get_connector(connection.connection_id) connector = await connector_service._get_connector(connection.connection_id)
if not connector: if not connector:
logger.error("Could not get connector for connection", connection_id=connection.connection_id) logger.error(
"Could not get connector for connection",
connection_id=connection.connection_id,
)
return JSONResponse( return JSONResponse(
{"status": "error", "reason": "connector_not_found"} {"status": "error", "reason": "connector_not_found"}
) )
@ -199,7 +213,11 @@ async def connector_webhook(request: Request, connector_service, session_manager
affected_files = await connector.handle_webhook(payload) affected_files = await connector.handle_webhook(payload)
if affected_files: if affected_files:
logger.info("Webhook connection files affected", connection_id=connection.connection_id, affected_count=len(affected_files)) logger.info(
"Webhook connection files affected",
connection_id=connection.connection_id,
affected_count=len(affected_files),
)
# Generate JWT token for the user (needed for OpenSearch authentication) # Generate JWT token for the user (needed for OpenSearch authentication)
user = session_manager.get_user(connection.user_id) user = session_manager.get_user(connection.user_id)
@ -223,7 +241,10 @@ async def connector_webhook(request: Request, connector_service, session_manager
} }
else: else:
# No specific files identified - just log the webhook # No specific files identified - just log the webhook
logger.info("Webhook general change detected, no specific files", connection_id=connection.connection_id) logger.info(
"Webhook general change detected, no specific files",
connection_id=connection.connection_id,
)
result = { result = {
"connection_id": connection.connection_id, "connection_id": connection.connection_id,
@ -241,7 +262,15 @@ async def connector_webhook(request: Request, connector_service, session_manager
) )
except Exception as e: except Exception as e:
logger.error("Failed to process webhook for connection", connection_id=connection.connection_id, error=str(e)) logger.error(
"Failed to process webhook for connection",
connection_id=connection.connection_id,
error=str(e),
)
import traceback
traceback.print_exc()
return JSONResponse( return JSONResponse(
{ {
"status": "error", "status": "error",

View file

@ -395,15 +395,19 @@ async def knowledge_filter_webhook(
# Get the webhook payload # Get the webhook payload
payload = await request.json() payload = await request.json()
logger.info("Knowledge filter webhook received", logger.info(
filter_id=filter_id, "Knowledge filter webhook received",
subscription_id=subscription_id, filter_id=filter_id,
payload_size=len(str(payload))) subscription_id=subscription_id,
payload_size=len(str(payload)),
)
# Extract findings from the payload # Extract findings from the payload
findings = payload.get("findings", []) findings = payload.get("findings", [])
if not findings: if not findings:
logger.info("No findings in webhook payload", subscription_id=subscription_id) logger.info(
"No findings in webhook payload", subscription_id=subscription_id
)
return JSONResponse({"status": "no_findings"}) return JSONResponse({"status": "no_findings"})
# Process the findings - these are the documents that matched the knowledge filter # Process the findings - these are the documents that matched the knowledge filter
@ -420,14 +424,18 @@ async def knowledge_filter_webhook(
) )
# Log the matched documents # Log the matched documents
logger.info("Knowledge filter matched documents", logger.info(
filter_id=filter_id, "Knowledge filter matched documents",
matched_count=len(matched_documents)) filter_id=filter_id,
matched_count=len(matched_documents),
)
for doc in matched_documents: for doc in matched_documents:
logger.debug("Matched document", logger.debug(
document_id=doc['document_id'], "Matched document",
index=doc['index'], document_id=doc["document_id"],
score=doc.get('score')) index=doc["index"],
score=doc.get("score"),
)
# Here you could add additional processing: # Here you could add additional processing:
# - Send notifications to external webhooks # - Send notifications to external webhooks
@ -446,10 +454,12 @@ async def knowledge_filter_webhook(
) )
except Exception as e: except Exception as e:
logger.error("Failed to process knowledge filter webhook", logger.error(
filter_id=filter_id, "Failed to process knowledge filter webhook",
subscription_id=subscription_id, filter_id=filter_id,
error=str(e)) subscription_id=subscription_id,
error=str(e),
)
import traceback import traceback
traceback.print_exc() traceback.print_exc()

View file

@ -23,14 +23,16 @@ async def search(request: Request, search_service, session_manager):
# Extract JWT token from auth middleware # Extract JWT token from auth middleware
jwt_token = request.state.jwt_token jwt_token = request.state.jwt_token
logger.debug("Search API request", logger.debug(
user=str(user), "Search API request",
user_id=user.user_id if user else None, user=str(user),
has_jwt_token=jwt_token is not None, user_id=user.user_id if user else None,
query=query, has_jwt_token=jwt_token is not None,
filters=filters, query=query,
limit=limit, filters=filters,
score_threshold=score_threshold) limit=limit,
score_threshold=score_threshold,
)
result = await search_service.search( result = await search_service.search(
query, query,

View file

@ -14,7 +14,7 @@ async def upload(request: Request, document_service, session_manager):
jwt_token = request.state.jwt_token jwt_token = request.state.jwt_token
from config.settings import is_no_auth_mode from config.settings import is_no_auth_mode
# In no-auth mode, pass None for owner fields so documents have no owner # In no-auth mode, pass None for owner fields so documents have no owner
# This allows all users to see them when switching to auth mode # This allows all users to see them when switching to auth mode
if is_no_auth_mode(): if is_no_auth_mode():
@ -25,7 +25,7 @@ async def upload(request: Request, document_service, session_manager):
owner_user_id = user.user_id owner_user_id = user.user_id
owner_name = user.name owner_name = user.name
owner_email = user.email owner_email = user.email
result = await document_service.process_upload_file( result = await document_service.process_upload_file(
upload_file, upload_file,
owner_user_id=owner_user_id, owner_user_id=owner_user_id,
@ -61,9 +61,9 @@ async def upload_path(request: Request, task_service, session_manager):
user = request.state.user user = request.state.user
jwt_token = request.state.jwt_token jwt_token = request.state.jwt_token
from config.settings import is_no_auth_mode from config.settings import is_no_auth_mode
# In no-auth mode, pass None for owner fields so documents have no owner # In no-auth mode, pass None for owner fields so documents have no owner
if is_no_auth_mode(): if is_no_auth_mode():
owner_user_id = None owner_user_id = None
@ -73,7 +73,7 @@ async def upload_path(request: Request, task_service, session_manager):
owner_user_id = user.user_id owner_user_id = user.user_id
owner_name = user.name owner_name = user.name
owner_email = user.email owner_email = user.email
task_id = await task_service.create_upload_task( task_id = await task_service.create_upload_task(
owner_user_id, owner_user_id,
file_paths, file_paths,

View file

@ -3,6 +3,9 @@ from starlette.responses import JSONResponse
from typing import Optional from typing import Optional
from session_manager import User from session_manager import User
from config.settings import is_no_auth_mode from config.settings import is_no_auth_mode
from utils.logging_config import get_logger
logger = get_logger(__name__)
def get_current_user(request: Request, session_manager) -> Optional[User]: def get_current_user(request: Request, session_manager) -> Optional[User]:
@ -25,22 +28,15 @@ def require_auth(session_manager):
async def wrapper(request: Request): async def wrapper(request: Request):
# In no-auth mode, bypass authentication entirely # In no-auth mode, bypass authentication entirely
if is_no_auth_mode(): if is_no_auth_mode():
print(f"[DEBUG] No-auth mode: Creating anonymous user") logger.debug("No-auth mode: Creating anonymous user")
# Create an anonymous user object so endpoints don't break # Create an anonymous user object so endpoints don't break
from session_manager import User from session_manager import User
from datetime import datetime from datetime import datetime
request.state.user = User( from session_manager import AnonymousUser
user_id="anonymous", request.state.user = AnonymousUser()
email="anonymous@localhost",
name="Anonymous User",
picture=None,
provider="none",
created_at=datetime.now(),
last_login=datetime.now(),
)
request.state.jwt_token = None # No JWT in no-auth mode request.state.jwt_token = None # No JWT in no-auth mode
print(f"[DEBUG] Set user_id=anonymous, jwt_token=None") logger.debug("Set user_id=anonymous, jwt_token=None")
return await handler(request) return await handler(request)
user = get_current_user(request, session_manager) user = get_current_user(request, session_manager)
@ -72,15 +68,8 @@ def optional_auth(session_manager):
from session_manager import User from session_manager import User
from datetime import datetime from datetime import datetime
request.state.user = User( from session_manager import AnonymousUser
user_id="anonymous", request.state.user = AnonymousUser()
email="anonymous@localhost",
name="Anonymous User",
picture=None,
provider="none",
created_at=datetime.now(),
last_login=datetime.now(),
)
request.state.jwt_token = None # No JWT in no-auth mode request.state.jwt_token = None # No JWT in no-auth mode
else: else:
user = get_current_user(request, session_manager) user = get_current_user(request, session_manager)

View file

@ -36,7 +36,12 @@ GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
def is_no_auth_mode(): def is_no_auth_mode():
"""Check if we're running in no-auth mode (OAuth credentials missing)""" """Check if we're running in no-auth mode (OAuth credentials missing)"""
result = not (GOOGLE_OAUTH_CLIENT_ID and GOOGLE_OAUTH_CLIENT_SECRET) result = not (GOOGLE_OAUTH_CLIENT_ID and GOOGLE_OAUTH_CLIENT_SECRET)
logger.debug("Checking auth mode", no_auth_mode=result, has_client_id=GOOGLE_OAUTH_CLIENT_ID is not None, has_client_secret=GOOGLE_OAUTH_CLIENT_SECRET is not None) logger.debug(
"Checking auth mode",
no_auth_mode=result,
has_client_id=GOOGLE_OAUTH_CLIENT_ID is not None,
has_client_secret=GOOGLE_OAUTH_CLIENT_SECRET is not None,
)
return result return result
@ -99,7 +104,9 @@ async def generate_langflow_api_key():
return LANGFLOW_KEY return LANGFLOW_KEY
if not LANGFLOW_SUPERUSER or not LANGFLOW_SUPERUSER_PASSWORD: if not LANGFLOW_SUPERUSER or not LANGFLOW_SUPERUSER_PASSWORD:
logger.warning("LANGFLOW_SUPERUSER and LANGFLOW_SUPERUSER_PASSWORD not set, skipping API key generation") logger.warning(
"LANGFLOW_SUPERUSER and LANGFLOW_SUPERUSER_PASSWORD not set, skipping API key generation"
)
return None return None
try: try:
@ -141,11 +148,19 @@ async def generate_langflow_api_key():
raise KeyError("api_key") raise KeyError("api_key")
LANGFLOW_KEY = api_key LANGFLOW_KEY = api_key
logger.info("Successfully generated Langflow API key", api_key_preview=api_key[:8]) logger.info(
"Successfully generated Langflow API key",
api_key_preview=api_key[:8],
)
return api_key return api_key
except (requests.exceptions.RequestException, KeyError) as e: except (requests.exceptions.RequestException, KeyError) as e:
last_error = e last_error = e
logger.warning("Attempt to generate Langflow API key failed", attempt=attempt, max_attempts=max_attempts, error=str(e)) logger.warning(
"Attempt to generate Langflow API key failed",
attempt=attempt,
max_attempts=max_attempts,
error=str(e),
)
if attempt < max_attempts: if attempt < max_attempts:
time.sleep(delay_seconds) time.sleep(delay_seconds)
else: else:
@ -195,7 +210,9 @@ class AppClients:
logger.warning("Failed to initialize Langflow client", error=str(e)) logger.warning("Failed to initialize Langflow client", error=str(e))
self.langflow_client = None self.langflow_client = None
if self.langflow_client is None: if self.langflow_client is None:
logger.warning("No Langflow client initialized yet, will attempt later on first use") logger.warning(
"No Langflow client initialized yet, will attempt later on first use"
)
# Initialize patched OpenAI client # Initialize patched OpenAI client
self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI()) self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI())
@ -218,7 +235,9 @@ class AppClients:
) )
logger.info("Langflow client initialized on-demand") logger.info("Langflow client initialized on-demand")
except Exception as e: except Exception as e:
logger.error("Failed to initialize Langflow client on-demand", error=str(e)) logger.error(
"Failed to initialize Langflow client on-demand", error=str(e)
)
self.langflow_client = None self.langflow_client = None
return self.langflow_client return self.langflow_client

View file

@ -321,13 +321,18 @@ class ConnectionManager:
if connection_config.config.get( if connection_config.config.get(
"webhook_channel_id" "webhook_channel_id"
) or connection_config.config.get("subscription_id"): ) or connection_config.config.get("subscription_id"):
logger.info("Webhook subscription already exists", connection_id=connection_id) logger.info(
"Webhook subscription already exists", connection_id=connection_id
)
return return
# Check if webhook URL is configured # Check if webhook URL is configured
webhook_url = connection_config.config.get("webhook_url") webhook_url = connection_config.config.get("webhook_url")
if not webhook_url: if not webhook_url:
logger.info("No webhook URL configured, skipping subscription setup", connection_id=connection_id) logger.info(
"No webhook URL configured, skipping subscription setup",
connection_id=connection_id,
)
return return
try: try:
@ -345,10 +350,18 @@ class ConnectionManager:
# Save updated connection config # Save updated connection config
await self.save_connections() await self.save_connections()
logger.info("Successfully set up webhook subscription", connection_id=connection_id, subscription_id=subscription_id) logger.info(
"Successfully set up webhook subscription",
connection_id=connection_id,
subscription_id=subscription_id,
)
except Exception as e: except Exception as e:
logger.error("Failed to setup webhook subscription", connection_id=connection_id, error=str(e)) logger.error(
"Failed to setup webhook subscription",
connection_id=connection_id,
error=str(e),
)
# Don't fail the entire connection setup if webhook fails # Don't fail the entire connection setup if webhook fails
async def _setup_webhook_for_new_connection( async def _setup_webhook_for_new_connection(
@ -356,12 +369,18 @@ class ConnectionManager:
): ):
"""Setup webhook subscription for a newly authenticated connection""" """Setup webhook subscription for a newly authenticated connection"""
try: try:
logger.info("Setting up subscription for newly authenticated connection", connection_id=connection_id) logger.info(
"Setting up subscription for newly authenticated connection",
connection_id=connection_id,
)
# Create and authenticate connector # Create and authenticate connector
connector = self._create_connector(connection_config) connector = self._create_connector(connection_config)
if not await connector.authenticate(): if not await connector.authenticate():
logger.error("Failed to authenticate connector for webhook setup", connection_id=connection_id) logger.error(
"Failed to authenticate connector for webhook setup",
connection_id=connection_id,
)
return return
# Setup subscription # Setup subscription
@ -376,8 +395,16 @@ class ConnectionManager:
# Save updated connection config # Save updated connection config
await self.save_connections() await self.save_connections()
logger.info("Successfully set up webhook subscription", connection_id=connection_id, subscription_id=subscription_id) logger.info(
"Successfully set up webhook subscription",
connection_id=connection_id,
subscription_id=subscription_id,
)
except Exception as e: except Exception as e:
logger.error("Failed to setup webhook subscription for new connection", connection_id=connection_id, error=str(e)) logger.error(
"Failed to setup webhook subscription for new connection",
connection_id=connection_id,
error=str(e),
)
# Don't fail the connection setup if webhook fails # Don't fail the connection setup if webhook fails

View file

@ -4,6 +4,11 @@ from typing import Dict, Any, List, Optional
from .base import BaseConnector, ConnectorDocument from .base import BaseConnector, ConnectorDocument
from utils.logging_config import get_logger from utils.logging_config import get_logger
logger = get_logger(__name__)
from .google_drive import GoogleDriveConnector
from .sharepoint import SharePointConnector
from .onedrive import OneDriveConnector
from .connection_manager import ConnectionManager from .connection_manager import ConnectionManager
logger = get_logger(__name__) logger = get_logger(__name__)
@ -62,7 +67,7 @@ class ConnectorService:
doc_service = DocumentService(session_manager=self.session_manager) doc_service = DocumentService(session_manager=self.session_manager)
print(f"[DEBUG] Processing connector document with ID: {document.id}") logger.debug("Processing connector document", document_id=document.id)
# Process using the existing pipeline but with connector document metadata # Process using the existing pipeline but with connector document metadata
result = await doc_service.process_file_common( result = await doc_service.process_file_common(
@ -77,7 +82,7 @@ class ConnectorService:
connector_type=connector_type, connector_type=connector_type,
) )
print(f"[DEBUG] Document processing result: {result}") logger.debug("Document processing result", result=result)
# If successfully indexed or already exists, update the indexed documents with connector metadata # If successfully indexed or already exists, update the indexed documents with connector metadata
if result["status"] in ["indexed", "unchanged"]: if result["status"] in ["indexed", "unchanged"]:
@ -104,7 +109,7 @@ class ConnectorService:
jwt_token: str = None, jwt_token: str = None,
): ):
"""Update indexed chunks with connector-specific metadata""" """Update indexed chunks with connector-specific metadata"""
print(f"[DEBUG] Looking for chunks with document_id: {document.id}") logger.debug("Looking for chunks", document_id=document.id)
# Find all chunks for this document # Find all chunks for this document
query = {"query": {"term": {"document_id": document.id}}} query = {"query": {"term": {"document_id": document.id}}}
@ -117,26 +122,34 @@ class ConnectorService:
try: try:
response = await opensearch_client.search(index=self.index_name, body=query) response = await opensearch_client.search(index=self.index_name, body=query)
except Exception as e: except Exception as e:
print( logger.error(
f"[ERROR] OpenSearch search failed for connector metadata update: {e}" "OpenSearch search failed for connector metadata update",
error=str(e),
query=query,
) )
print(f"[ERROR] Search query: {query}")
raise raise
print(f"[DEBUG] Search query: {query}") logger.debug(
print( "Search query executed",
f"[DEBUG] Found {len(response['hits']['hits'])} chunks matching document_id: {document.id}" query=query,
chunks_found=len(response["hits"]["hits"]),
document_id=document.id,
) )
# Update each chunk with connector metadata # Update each chunk with connector metadata
print( logger.debug(
f"[DEBUG] Updating {len(response['hits']['hits'])} chunks with connector_type: {connector_type}" "Updating chunks with connector_type",
chunk_count=len(response["hits"]["hits"]),
connector_type=connector_type,
) )
for hit in response["hits"]["hits"]: for hit in response["hits"]["hits"]:
chunk_id = hit["_id"] chunk_id = hit["_id"]
current_connector_type = hit["_source"].get("connector_type", "unknown") current_connector_type = hit["_source"].get("connector_type", "unknown")
print( logger.debug(
f"[DEBUG] Chunk {chunk_id}: current connector_type = {current_connector_type}, updating to {connector_type}" "Updating chunk connector metadata",
chunk_id=chunk_id,
current_connector_type=current_connector_type,
new_connector_type=connector_type,
) )
update_body = { update_body = {
@ -164,10 +177,14 @@ class ConnectorService:
await opensearch_client.update( await opensearch_client.update(
index=self.index_name, id=chunk_id, body=update_body index=self.index_name, id=chunk_id, body=update_body
) )
print(f"[DEBUG] Updated chunk {chunk_id} with connector metadata") logger.debug("Updated chunk with connector metadata", chunk_id=chunk_id)
except Exception as e: except Exception as e:
print(f"[ERROR] OpenSearch update failed for chunk {chunk_id}: {e}") logger.error(
print(f"[ERROR] Update body: {update_body}") "OpenSearch update failed for chunk",
chunk_id=chunk_id,
error=str(e),
update_body=update_body,
)
raise raise
def _get_file_extension(self, mimetype: str) -> str: def _get_file_extension(self, mimetype: str) -> str:
@ -226,11 +243,11 @@ class ConnectorService:
while True: while True:
# List files from connector with limit # List files from connector with limit
logger.info( logger.debug(
"Calling list_files", page_size=page_size, page_token=page_token "Calling list_files", page_size=page_size, page_token=page_token
) )
file_list = await connector.list_files(page_token, max_files=page_size) file_list = await connector.list_files(page_token, limit=page_size)
logger.info( logger.debug(
"Got files from connector", file_count=len(file_list.get("files", [])) "Got files from connector", file_count=len(file_list.get("files", []))
) )
files = file_list["files"] files = file_list["files"]

View file

@ -3,11 +3,13 @@ import sys
# Check for TUI flag FIRST, before any heavy imports # Check for TUI flag FIRST, before any heavy imports
if __name__ == "__main__" and len(sys.argv) > 1 and sys.argv[1] == "--tui": if __name__ == "__main__" and len(sys.argv) > 1 and sys.argv[1] == "--tui":
from tui.main import run_tui from tui.main import run_tui
run_tui() run_tui()
sys.exit(0) sys.exit(0)
# Configure structured logging early # Configure structured logging early
from utils.logging_config import configure_from_env, get_logger from utils.logging_config import configure_from_env, get_logger
configure_from_env() configure_from_env()
logger = get_logger(__name__) logger = get_logger(__name__)
@ -25,6 +27,8 @@ import torch
# Configuration and setup # Configuration and setup
from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET
from config.settings import is_no_auth_mode
from utils.gpu_detection import detect_gpu_devices
# Services # Services
from services.document_service import DocumentService from services.document_service import DocumentService
@ -56,8 +60,11 @@ from api import (
# Set multiprocessing start method to 'spawn' for CUDA compatibility # Set multiprocessing start method to 'spawn' for CUDA compatibility
multiprocessing.set_start_method("spawn", force=True) multiprocessing.set_start_method("spawn", force=True)
logger.info("CUDA available", cuda_available=torch.cuda.is_available()) logger.info(
logger.info("CUDA version PyTorch was built with", cuda_version=torch.version.cuda) "CUDA device information",
cuda_available=torch.cuda.is_available(),
cuda_version=torch.version.cuda,
)
async def wait_for_opensearch(): async def wait_for_opensearch():
@ -71,7 +78,12 @@ async def wait_for_opensearch():
logger.info("OpenSearch is ready") logger.info("OpenSearch is ready")
return return
except Exception as e: except Exception as e:
logger.warning("OpenSearch not ready yet", attempt=attempt + 1, max_retries=max_retries, error=str(e)) logger.warning(
"OpenSearch not ready yet",
attempt=attempt + 1,
max_retries=max_retries,
error=str(e),
)
if attempt < max_retries - 1: if attempt < max_retries - 1:
await asyncio.sleep(retry_delay) await asyncio.sleep(retry_delay)
else: else:
@ -93,7 +105,9 @@ async def configure_alerting_security():
# Use admin client (clients.opensearch uses admin credentials) # Use admin client (clients.opensearch uses admin credentials)
response = await clients.opensearch.cluster.put_settings(body=alerting_settings) response = await clients.opensearch.cluster.put_settings(body=alerting_settings)
logger.info("Alerting security settings configured successfully", response=response) logger.info(
"Alerting security settings configured successfully", response=response
)
except Exception as e: except Exception as e:
logger.warning("Failed to configure alerting security settings", error=str(e)) logger.warning("Failed to configure alerting security settings", error=str(e))
# Don't fail startup if alerting config fails # Don't fail startup if alerting config fails
@ -133,9 +147,14 @@ async def init_index():
await clients.opensearch.indices.create( await clients.opensearch.indices.create(
index=knowledge_filter_index_name, body=knowledge_filter_index_body index=knowledge_filter_index_name, body=knowledge_filter_index_body
) )
logger.info("Created knowledge filters index", index_name=knowledge_filter_index_name) logger.info(
"Created knowledge filters index", index_name=knowledge_filter_index_name
)
else: else:
logger.info("Knowledge filters index already exists, skipping creation", index_name=knowledge_filter_index_name) logger.info(
"Knowledge filters index already exists, skipping creation",
index_name=knowledge_filter_index_name,
)
# Configure alerting plugin security settings # Configure alerting plugin security settings
await configure_alerting_security() await configure_alerting_security()
@ -190,9 +209,59 @@ async def init_index_when_ready():
logger.info("OpenSearch index initialization completed successfully") logger.info("OpenSearch index initialization completed successfully")
except Exception as e: except Exception as e:
logger.error("OpenSearch index initialization failed", error=str(e)) logger.error("OpenSearch index initialization failed", error=str(e))
logger.warning("OIDC endpoints will still work, but document operations may fail until OpenSearch is ready") logger.warning(
"OIDC endpoints will still work, but document operations may fail until OpenSearch is ready"
)
async def ingest_default_documents_when_ready(services):
"""Scan the local documents folder and ingest files like a non-auth upload."""
try:
logger.info("Ingesting default documents when ready")
base_dir = os.path.abspath(os.path.join(os.getcwd(), "documents"))
if not os.path.isdir(base_dir):
logger.info("Default documents directory not found; skipping ingestion", base_dir=base_dir)
return
# Collect files recursively
file_paths = [
os.path.join(root, fn)
for root, _, files in os.walk(base_dir)
for fn in files
]
if not file_paths:
logger.info("No default documents found; nothing to ingest", base_dir=base_dir)
return
# Build a processor that DOES NOT set 'owner' on documents (owner_user_id=None)
from models.processors import DocumentFileProcessor
processor = DocumentFileProcessor(
services["document_service"],
owner_user_id=None,
jwt_token=None,
owner_name=None,
owner_email=None,
)
task_id = await services["task_service"].create_custom_task(
"anonymous", file_paths, processor
)
logger.info(
"Started default documents ingestion task",
task_id=task_id,
file_count=len(file_paths),
)
except Exception as e:
logger.error("Default documents ingestion failed", error=str(e))
async def startup_tasks(services):
"""Startup tasks"""
logger.info("Starting startup tasks")
await init_index()
await ingest_default_documents_when_ready(services)
async def initialize_services(): async def initialize_services():
"""Initialize all services and their dependencies""" """Initialize all services and their dependencies"""
# Generate JWT keys if they don't exist # Generate JWT keys if they don't exist
@ -237,9 +306,14 @@ async def initialize_services():
try: try:
await connector_service.initialize() await connector_service.initialize()
loaded_count = len(connector_service.connection_manager.connections) loaded_count = len(connector_service.connection_manager.connections)
logger.info("Loaded persisted connector connections on startup", loaded_count=loaded_count) logger.info(
"Loaded persisted connector connections on startup",
loaded_count=loaded_count,
)
except Exception as e: except Exception as e:
logger.warning("Failed to load persisted connections on startup", error=str(e)) logger.warning(
"Failed to load persisted connections on startup", error=str(e)
)
else: else:
logger.info("[CONNECTORS] Skipping connection loading in no-auth mode") logger.info("[CONNECTORS] Skipping connection loading in no-auth mode")
@ -639,12 +713,15 @@ async def create_app():
app = Starlette(debug=True, routes=routes) app = Starlette(debug=True, routes=routes)
app.state.services = services # Store services for cleanup app.state.services = services # Store services for cleanup
app.state.background_tasks = set()
# Add startup event handler # Add startup event handler
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
# Start index initialization in background to avoid blocking OIDC endpoints # Start index initialization in background to avoid blocking OIDC endpoints
asyncio.create_task(init_index_when_ready()) t1 = asyncio.create_task(startup_tasks(services))
app.state.background_tasks.add(t1)
t1.add_done_callback(app.state.background_tasks.discard)
# Add shutdown event handler # Add shutdown event handler
@app.on_event("shutdown") @app.on_event("shutdown")
@ -687,18 +764,30 @@ async def cleanup_subscriptions_proper(services):
for connection in active_connections: for connection in active_connections:
try: try:
logger.info("Cancelling subscription for connection", connection_id=connection.connection_id) logger.info(
"Cancelling subscription for connection",
connection_id=connection.connection_id,
)
connector = await connector_service.get_connector( connector = await connector_service.get_connector(
connection.connection_id connection.connection_id
) )
if connector: if connector:
subscription_id = connection.config.get("webhook_channel_id") subscription_id = connection.config.get("webhook_channel_id")
await connector.cleanup_subscription(subscription_id) await connector.cleanup_subscription(subscription_id)
logger.info("Cancelled subscription", subscription_id=subscription_id) logger.info(
"Cancelled subscription", subscription_id=subscription_id
)
except Exception as e: except Exception as e:
logger.error("Failed to cancel subscription", connection_id=connection.connection_id, error=str(e)) logger.error(
"Failed to cancel subscription",
connection_id=connection.connection_id,
error=str(e),
)
logger.info("Finished cancelling subscriptions", subscription_count=len(active_connections)) logger.info(
"Finished cancelling subscriptions",
subscription_count=len(active_connections),
)
except Exception as e: except Exception as e:
logger.error("Failed to cleanup subscriptions", error=str(e)) logger.error("Failed to cleanup subscriptions", error=str(e))

View file

@ -1,6 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict from typing import Any, Dict
from .tasks import UploadTask, FileTask from .tasks import UploadTask, FileTask
from utils.logging_config import get_logger
logger = get_logger(__name__)
class TaskProcessor(ABC): class TaskProcessor(ABC):
@ -211,7 +214,7 @@ class S3FileProcessor(TaskProcessor):
"connector_type": "s3", # S3 uploads "connector_type": "s3", # S3 uploads
"indexed_time": datetime.datetime.now().isoformat(), "indexed_time": datetime.datetime.now().isoformat(),
} }
# Only set owner fields if owner_user_id is provided (for no-auth mode support) # Only set owner fields if owner_user_id is provided (for no-auth mode support)
if self.owner_user_id is not None: if self.owner_user_id is not None:
chunk_doc["owner"] = self.owner_user_id chunk_doc["owner"] = self.owner_user_id
@ -225,10 +228,12 @@ class S3FileProcessor(TaskProcessor):
index=INDEX_NAME, id=chunk_id, body=chunk_doc index=INDEX_NAME, id=chunk_id, body=chunk_doc
) )
except Exception as e: except Exception as e:
print( logger.error(
f"[ERROR] OpenSearch indexing failed for S3 chunk {chunk_id}: {e}" "OpenSearch indexing failed for S3 chunk",
chunk_id=chunk_id,
error=str(e),
chunk_doc=chunk_doc,
) )
print(f"[ERROR] Chunk document: {chunk_doc}")
raise raise
result = {"status": "indexed", "id": slim_doc["id"]} result = {"status": "indexed", "id": slim_doc["id"]}

View file

@ -111,7 +111,10 @@ class ChatService:
# Pass the complete filter expression as a single header to Langflow (only if we have something to send) # Pass the complete filter expression as a single header to Langflow (only if we have something to send)
if filter_expression: if filter_expression:
logger.info("Sending OpenRAG query filter to Langflow", filter_expression=filter_expression) logger.info(
"Sending OpenRAG query filter to Langflow",
filter_expression=filter_expression,
)
extra_headers["X-LANGFLOW-GLOBAL-VAR-OPENRAG-QUERY-FILTER"] = json.dumps( extra_headers["X-LANGFLOW-GLOBAL-VAR-OPENRAG-QUERY-FILTER"] = json.dumps(
filter_expression filter_expression
) )
@ -201,7 +204,11 @@ class ChatService:
return {"error": "User ID is required", "conversations": []} return {"error": "User ID is required", "conversations": []}
conversations_dict = get_user_conversations(user_id) conversations_dict = get_user_conversations(user_id)
logger.debug("Getting chat history for user", user_id=user_id, conversation_count=len(conversations_dict)) logger.debug(
"Getting chat history for user",
user_id=user_id,
conversation_count=len(conversations_dict),
)
# Convert conversations dict to list format with metadata # Convert conversations dict to list format with metadata
conversations = [] conversations = []

View file

@ -196,7 +196,11 @@ class DocumentService:
index=INDEX_NAME, id=chunk_id, body=chunk_doc index=INDEX_NAME, id=chunk_id, body=chunk_doc
) )
except Exception as e: except Exception as e:
logger.error("OpenSearch indexing failed for chunk", chunk_id=chunk_id, error=str(e)) logger.error(
"OpenSearch indexing failed for chunk",
chunk_id=chunk_id,
error=str(e),
)
logger.error("Chunk document details", chunk_doc=chunk_doc) logger.error("Chunk document details", chunk_doc=chunk_doc)
raise raise
return {"status": "indexed", "id": file_hash} return {"status": "indexed", "id": file_hash}
@ -232,7 +236,9 @@ class DocumentService:
try: try:
exists = await opensearch_client.exists(index=INDEX_NAME, id=file_hash) exists = await opensearch_client.exists(index=INDEX_NAME, id=file_hash)
except Exception as e: except Exception as e:
logger.error("OpenSearch exists check failed", file_hash=file_hash, error=str(e)) logger.error(
"OpenSearch exists check failed", file_hash=file_hash, error=str(e)
)
raise raise
if exists: if exists:
return {"status": "unchanged", "id": file_hash} return {"status": "unchanged", "id": file_hash}
@ -372,7 +378,11 @@ class DocumentService:
index=INDEX_NAME, id=chunk_id, body=chunk_doc index=INDEX_NAME, id=chunk_id, body=chunk_doc
) )
except Exception as e: except Exception as e:
logger.error("OpenSearch indexing failed for batch chunk", chunk_id=chunk_id, error=str(e)) logger.error(
"OpenSearch indexing failed for batch chunk",
chunk_id=chunk_id,
error=str(e),
)
logger.error("Chunk document details", chunk_doc=chunk_doc) logger.error("Chunk document details", chunk_doc=chunk_doc)
raise raise
@ -388,9 +398,13 @@ class DocumentService:
from concurrent.futures import BrokenExecutor from concurrent.futures import BrokenExecutor
if isinstance(e, BrokenExecutor): if isinstance(e, BrokenExecutor):
logger.error("Process pool broken while processing file", file_path=file_path) logger.error(
"Process pool broken while processing file", file_path=file_path
)
logger.info("Worker process likely crashed") logger.info("Worker process likely crashed")
logger.info("You should see detailed crash logs above from the worker process") logger.info(
"You should see detailed crash logs above from the worker process"
)
# Mark pool as broken for potential recreation # Mark pool as broken for potential recreation
self._process_pool_broken = True self._process_pool_broken = True
@ -399,11 +413,15 @@ class DocumentService:
if self._recreate_process_pool(): if self._recreate_process_pool():
logger.info("Process pool successfully recreated") logger.info("Process pool successfully recreated")
else: else:
logger.warning("Failed to recreate process pool - future operations may fail") logger.warning(
"Failed to recreate process pool - future operations may fail"
)
file_task.error = f"Worker process crashed: {str(e)}" file_task.error = f"Worker process crashed: {str(e)}"
else: else:
logger.error("Failed to process file", file_path=file_path, error=str(e)) logger.error(
"Failed to process file", file_path=file_path, error=str(e)
)
file_task.error = str(e) file_task.error = str(e)
logger.error("Full traceback available") logger.error("Full traceback available")

View file

@ -195,7 +195,9 @@ class MonitorService:
return monitors return monitors
except Exception as e: except Exception as e:
logger.error("Error listing monitors for user", user_id=user_id, error=str(e)) logger.error(
"Error listing monitors for user", user_id=user_id, error=str(e)
)
return [] return []
async def list_monitors_for_filter( async def list_monitors_for_filter(
@ -236,7 +238,9 @@ class MonitorService:
return monitors return monitors
except Exception as e: except Exception as e:
logger.error("Error listing monitors for filter", filter_id=filter_id, error=str(e)) logger.error(
"Error listing monitors for filter", filter_id=filter_id, error=str(e)
)
return [] return []
async def _get_or_create_webhook_destination( async def _get_or_create_webhook_destination(

View file

@ -138,7 +138,11 @@ class SearchService:
search_body["min_score"] = score_threshold search_body["min_score"] = score_threshold
# Authentication required - DLS will handle document filtering automatically # Authentication required - DLS will handle document filtering automatically
logger.debug("search_service authentication info", user_id=user_id, has_jwt_token=jwt_token is not None) logger.debug(
"search_service authentication info",
user_id=user_id,
has_jwt_token=jwt_token is not None,
)
if not user_id: if not user_id:
logger.debug("search_service: user_id is None/empty, returning auth error") logger.debug("search_service: user_id is None/empty, returning auth error")
return {"results": [], "error": "Authentication required"} return {"results": [], "error": "Authentication required"}
@ -151,7 +155,9 @@ class SearchService:
try: try:
results = await opensearch_client.search(index=INDEX_NAME, body=search_body) results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
except Exception as e: except Exception as e:
logger.error("OpenSearch query failed", error=str(e), search_body=search_body) logger.error(
"OpenSearch query failed", error=str(e), search_body=search_body
)
# Re-raise the exception so the API returns the error to frontend # Re-raise the exception so the API returns the error to frontend
raise raise

View file

@ -2,11 +2,14 @@ import asyncio
import uuid import uuid
import time import time
import random import random
from typing import Dict from typing import Dict, Optional
from models.tasks import TaskStatus, UploadTask, FileTask from models.tasks import TaskStatus, UploadTask, FileTask
from session_manager import AnonymousUser
from src.utils.gpu_detection import get_worker_count from src.utils.gpu_detection import get_worker_count
from utils.logging_config import get_logger
logger = get_logger(__name__)
class TaskService: class TaskService:
@ -104,7 +107,9 @@ class TaskService:
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e: except Exception as e:
print(f"[ERROR] Background upload processor failed for task {task_id}: {e}") logger.error(
"Background upload processor failed", task_id=task_id, error=str(e)
)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@ -136,7 +141,9 @@ class TaskService:
try: try:
await processor.process_item(upload_task, item, file_task) await processor.process_item(upload_task, item, file_task)
except Exception as e: except Exception as e:
print(f"[ERROR] Failed to process item {item}: {e}") logger.error(
"Failed to process item", item=str(item), error=str(e)
)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@ -157,13 +164,15 @@ class TaskService:
upload_task.updated_at = time.time() upload_task.updated_at = time.time()
except asyncio.CancelledError: except asyncio.CancelledError:
print(f"[INFO] Background processor for task {task_id} was cancelled") logger.info("Background processor cancelled", task_id=task_id)
if user_id in self.task_store and task_id in self.task_store[user_id]: if user_id in self.task_store and task_id in self.task_store[user_id]:
# Task status and pending files already handled by cancel_task() # Task status and pending files already handled by cancel_task()
pass pass
raise # Re-raise to properly handle cancellation raise # Re-raise to properly handle cancellation
except Exception as e: except Exception as e:
print(f"[ERROR] Background custom processor failed for task {task_id}: {e}") logger.error(
"Background custom processor failed", task_id=task_id, error=str(e)
)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@ -171,16 +180,29 @@ class TaskService:
self.task_store[user_id][task_id].status = TaskStatus.FAILED self.task_store[user_id][task_id].status = TaskStatus.FAILED
self.task_store[user_id][task_id].updated_at = time.time() self.task_store[user_id][task_id].updated_at = time.time()
def get_task_status(self, user_id: str, task_id: str) -> dict: def get_task_status(self, user_id: str, task_id: str) -> Optional[dict]:
"""Get the status of a specific upload task""" """Get the status of a specific upload task
if (
not task_id Includes fallback to shared tasks stored under the "anonymous" user key
or user_id not in self.task_store so default system tasks are visible to all users.
or task_id not in self.task_store[user_id] """
): if not task_id:
return None return None
upload_task = self.task_store[user_id][task_id] # Prefer the caller's user_id; otherwise check shared/anonymous tasks
candidate_user_ids = [user_id, AnonymousUser().user_id]
upload_task = None
for candidate_user_id in candidate_user_ids:
if (
candidate_user_id in self.task_store
and task_id in self.task_store[candidate_user_id]
):
upload_task = self.task_store[candidate_user_id][task_id]
break
if upload_task is None:
return None
file_statuses = {} file_statuses = {}
for file_path, file_task in upload_task.file_tasks.items(): for file_path, file_task in upload_task.file_tasks.items():
@ -206,14 +228,21 @@ class TaskService:
} }
def get_all_tasks(self, user_id: str) -> list: def get_all_tasks(self, user_id: str) -> list:
"""Get all tasks for a user""" """Get all tasks for a user
if user_id not in self.task_store:
return []
tasks = [] Returns the union of the user's own tasks and shared default tasks stored
for task_id, upload_task in self.task_store[user_id].items(): under the "anonymous" user key. User-owned tasks take precedence
tasks.append( if a task_id overlaps.
{ """
tasks_by_id = {}
def add_tasks_from_store(store_user_id):
if store_user_id not in self.task_store:
return
for task_id, upload_task in self.task_store[store_user_id].items():
if task_id in tasks_by_id:
continue
tasks_by_id[task_id] = {
"task_id": upload_task.task_id, "task_id": upload_task.task_id,
"status": upload_task.status.value, "status": upload_task.status.value,
"total_files": upload_task.total_files, "total_files": upload_task.total_files,
@ -223,18 +252,36 @@ class TaskService:
"created_at": upload_task.created_at, "created_at": upload_task.created_at,
"updated_at": upload_task.updated_at, "updated_at": upload_task.updated_at,
} }
)
# Sort by creation time, most recent first # First, add user-owned tasks; then shared anonymous;
add_tasks_from_store(user_id)
add_tasks_from_store(AnonymousUser().user_id)
tasks = list(tasks_by_id.values())
tasks.sort(key=lambda x: x["created_at"], reverse=True) tasks.sort(key=lambda x: x["created_at"], reverse=True)
return tasks return tasks
def cancel_task(self, user_id: str, task_id: str) -> bool: def cancel_task(self, user_id: str, task_id: str) -> bool:
"""Cancel a task if it exists and is not already completed""" """Cancel a task if it exists and is not already completed.
if user_id not in self.task_store or task_id not in self.task_store[user_id]:
Supports cancellation of shared default tasks stored under the anonymous user.
"""
# Check candidate user IDs first, then anonymous to find which user ID the task is mapped to
candidate_user_ids = [user_id, AnonymousUser().user_id]
store_user_id = None
for candidate_user_id in candidate_user_ids:
if (
candidate_user_id in self.task_store
and task_id in self.task_store[candidate_user_id]
):
store_user_id = candidate_user_id
break
if store_user_id is None:
return False return False
upload_task = self.task_store[user_id][task_id] upload_task = self.task_store[store_user_id][task_id]
# Can only cancel pending or running tasks # Can only cancel pending or running tasks
if upload_task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]: if upload_task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:

View file

@ -6,8 +6,13 @@ from typing import Dict, Optional, Any
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
import os import os
from utils.logging_config import get_logger
logger = get_logger(__name__)
from utils.logging_config import get_logger
logger = get_logger(__name__)
@dataclass @dataclass
class User: class User:
"""User information from OAuth provider""" """User information from OAuth provider"""
@ -26,6 +31,19 @@ class User:
if self.last_login is None: if self.last_login is None:
self.last_login = datetime.now() self.last_login = datetime.now()
class AnonymousUser(User):
"""Anonymous user"""
def __init__(self):
super().__init__(
user_id="anonymous",
email="anonymous@localhost",
name="Anonymous User",
picture=None,
provider="none",
)
class SessionManager: class SessionManager:
"""Manages user sessions and JWT tokens""" """Manages user sessions and JWT tokens"""
@ -80,13 +98,15 @@ class SessionManager:
if response.status_code == 200: if response.status_code == 200:
return response.json() return response.json()
else: else:
print( logger.error(
f"Failed to get user info: {response.status_code} {response.text}" "Failed to get user info",
status_code=response.status_code,
response_text=response.text,
) )
return None return None
except Exception as e: except Exception as e:
print(f"Error getting user info: {e}") logger.error("Error getting user info", error=str(e))
return None return None
async def create_user_session( async def create_user_session(
@ -173,19 +193,24 @@ class SessionManager:
"""Get or create OpenSearch client for user with their JWT""" """Get or create OpenSearch client for user with their JWT"""
from config.settings import is_no_auth_mode from config.settings import is_no_auth_mode
print( logger.debug(
f"[DEBUG] get_user_opensearch_client: user_id={user_id}, jwt_token={'None' if jwt_token is None else 'present'}, no_auth_mode={is_no_auth_mode()}" "get_user_opensearch_client",
user_id=user_id,
jwt_token_present=(jwt_token is not None),
no_auth_mode=is_no_auth_mode(),
) )
# In no-auth mode, create anonymous JWT for OpenSearch DLS # In no-auth mode, create anonymous JWT for OpenSearch DLS
if is_no_auth_mode() and jwt_token is None: if jwt_token is None and (is_no_auth_mode() or user_id in (None, AnonymousUser().user_id)):
if not hasattr(self, "_anonymous_jwt"): if not hasattr(self, "_anonymous_jwt"):
# Create anonymous JWT token for OpenSearch OIDC # Create anonymous JWT token for OpenSearch OIDC
print(f"[DEBUG] Creating anonymous JWT...") logger.debug("Creating anonymous JWT")
self._anonymous_jwt = self._create_anonymous_jwt() self._anonymous_jwt = self._create_anonymous_jwt()
print(f"[DEBUG] Anonymous JWT created: {self._anonymous_jwt[:50]}...") logger.debug(
"Anonymous JWT created", jwt_prefix=self._anonymous_jwt[:50]
)
jwt_token = self._anonymous_jwt jwt_token = self._anonymous_jwt
print(f"[DEBUG] Using anonymous JWT for OpenSearch") logger.debug("Using anonymous JWT for OpenSearch")
# Check if we have a cached client for this user # Check if we have a cached client for this user
if user_id not in self.user_opensearch_clients: if user_id not in self.user_opensearch_clients:
@ -199,14 +224,5 @@ class SessionManager:
def _create_anonymous_jwt(self) -> str: def _create_anonymous_jwt(self) -> str:
"""Create JWT token for anonymous user in no-auth mode""" """Create JWT token for anonymous user in no-auth mode"""
anonymous_user = User( anonymous_user = AnonymousUser()
user_id="anonymous",
email="anonymous@localhost",
name="Anonymous User",
picture=None,
provider="none",
created_at=datetime.now(),
last_login=datetime.now(),
)
return self.create_jwt_token(anonymous_user) return self.create_jwt_token(anonymous_user)

View file

@ -1 +1 @@
"""OpenRAG Terminal User Interface package.""" """OpenRAG Terminal User Interface package."""

View file

@ -3,6 +3,9 @@
import sys import sys
from pathlib import Path from pathlib import Path
from textual.app import App, ComposeResult from textual.app import App, ComposeResult
from utils.logging_config import get_logger
logger = get_logger(__name__)
from .screens.welcome import WelcomeScreen from .screens.welcome import WelcomeScreen
from .screens.config import ConfigScreen from .screens.config import ConfigScreen
@ -17,10 +20,10 @@ from .widgets.diagnostics_notification import notify_with_diagnostics
class OpenRAGTUI(App): class OpenRAGTUI(App):
"""OpenRAG Terminal User Interface application.""" """OpenRAG Terminal User Interface application."""
TITLE = "OpenRAG TUI" TITLE = "OpenRAG TUI"
SUB_TITLE = "Container Management & Configuration" SUB_TITLE = "Container Management & Configuration"
CSS = """ CSS = """
Screen { Screen {
background: $background; background: $background;
@ -172,13 +175,13 @@ class OpenRAGTUI(App):
padding: 1; padding: 1;
} }
""" """
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.platform_detector = PlatformDetector() self.platform_detector = PlatformDetector()
self.container_manager = ContainerManager() self.container_manager = ContainerManager()
self.env_manager = EnvManager() self.env_manager = EnvManager()
def on_mount(self) -> None: def on_mount(self) -> None:
"""Initialize the application.""" """Initialize the application."""
# Check for runtime availability and show appropriate screen # Check for runtime availability and show appropriate screen
@ -187,31 +190,33 @@ class OpenRAGTUI(App):
self, self,
"No container runtime found. Please install Docker or Podman.", "No container runtime found. Please install Docker or Podman.",
severity="warning", severity="warning",
timeout=10 timeout=10,
) )
# Load existing config if available # Load existing config if available
config_exists = self.env_manager.load_existing_env() config_exists = self.env_manager.load_existing_env()
# Start with welcome screen # Start with welcome screen
self.push_screen(WelcomeScreen()) self.push_screen(WelcomeScreen())
async def action_quit(self) -> None: async def action_quit(self) -> None:
"""Quit the application.""" """Quit the application."""
self.exit() self.exit()
def check_runtime_requirements(self) -> tuple[bool, str]: def check_runtime_requirements(self) -> tuple[bool, str]:
"""Check if runtime requirements are met.""" """Check if runtime requirements are met."""
if not self.container_manager.is_available(): if not self.container_manager.is_available():
return False, self.platform_detector.get_installation_instructions() return False, self.platform_detector.get_installation_instructions()
# Check Podman macOS memory if applicable # Check Podman macOS memory if applicable
runtime_info = self.container_manager.get_runtime_info() runtime_info = self.container_manager.get_runtime_info()
if runtime_info.runtime_type.value == "podman": if runtime_info.runtime_type.value == "podman":
is_sufficient, _, message = self.platform_detector.check_podman_macos_memory() is_sufficient, _, message = (
self.platform_detector.check_podman_macos_memory()
)
if not is_sufficient: if not is_sufficient:
return False, f"Podman VM memory insufficient:\n{message}" return False, f"Podman VM memory insufficient:\n{message}"
return True, "Runtime requirements satisfied" return True, "Runtime requirements satisfied"
@ -221,10 +226,10 @@ def run_tui():
app = OpenRAGTUI() app = OpenRAGTUI()
app.run() app.run()
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nOpenRAG TUI interrupted by user") logger.info("OpenRAG TUI interrupted by user")
sys.exit(0) sys.exit(0)
except Exception as e: except Exception as e:
print(f"Error running OpenRAG TUI: {e}") logger.error("Error running OpenRAG TUI", error=str(e))
sys.exit(1) sys.exit(1)

View file

@ -1 +1 @@
"""TUI managers package.""" """TUI managers package."""

View file

@ -8,6 +8,9 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, AsyncIterator from typing import Dict, List, Optional, AsyncIterator
from utils.logging_config import get_logger
logger = get_logger(__name__)
from ..utils.platform import PlatformDetector, RuntimeInfo, RuntimeType from ..utils.platform import PlatformDetector, RuntimeInfo, RuntimeType
from utils.gpu_detection import detect_gpu_devices from utils.gpu_detection import detect_gpu_devices
@ -15,6 +18,7 @@ from utils.gpu_detection import detect_gpu_devices
class ServiceStatus(Enum): class ServiceStatus(Enum):
"""Container service status.""" """Container service status."""
UNKNOWN = "unknown" UNKNOWN = "unknown"
RUNNING = "running" RUNNING = "running"
STOPPED = "stopped" STOPPED = "stopped"
@ -27,6 +31,7 @@ class ServiceStatus(Enum):
@dataclass @dataclass
class ServiceInfo: class ServiceInfo:
"""Container service information.""" """Container service information."""
name: str name: str
status: ServiceStatus status: ServiceStatus
health: Optional[str] = None health: Optional[str] = None
@ -34,7 +39,7 @@ class ServiceInfo:
image: Optional[str] = None image: Optional[str] = None
image_digest: Optional[str] = None image_digest: Optional[str] = None
created: Optional[str] = None created: Optional[str] = None
def __post_init__(self): def __post_init__(self):
if self.ports is None: if self.ports is None:
self.ports = [] self.ports = []
@ -42,7 +47,7 @@ class ServiceInfo:
class ContainerManager: class ContainerManager:
"""Manages Docker/Podman container lifecycle for OpenRAG.""" """Manages Docker/Podman container lifecycle for OpenRAG."""
def __init__(self, compose_file: Optional[Path] = None): def __init__(self, compose_file: Optional[Path] = None):
self.platform_detector = PlatformDetector() self.platform_detector = PlatformDetector()
self.runtime_info = self.platform_detector.detect_runtime() self.runtime_info = self.platform_detector.detect_runtime()
@ -56,138 +61,142 @@ class ContainerManager:
self.use_cpu_compose = not has_gpu self.use_cpu_compose = not has_gpu
except Exception: except Exception:
self.use_cpu_compose = True self.use_cpu_compose = True
# Expected services based on compose files # Expected services based on compose files
self.expected_services = [ self.expected_services = [
"openrag-backend", "openrag-backend",
"openrag-frontend", "openrag-frontend",
"opensearch", "opensearch",
"dashboards", "dashboards",
"langflow" "langflow",
] ]
# Map container names to service names # Map container names to service names
self.container_name_map = { self.container_name_map = {
"openrag-backend": "openrag-backend", "openrag-backend": "openrag-backend",
"openrag-frontend": "openrag-frontend", "openrag-frontend": "openrag-frontend",
"os": "opensearch", "os": "opensearch",
"osdash": "dashboards", "osdash": "dashboards",
"langflow": "langflow" "langflow": "langflow",
} }
def is_available(self) -> bool: def is_available(self) -> bool:
"""Check if container runtime is available.""" """Check if container runtime is available."""
return self.runtime_info.runtime_type != RuntimeType.NONE return self.runtime_info.runtime_type != RuntimeType.NONE
def get_runtime_info(self) -> RuntimeInfo: def get_runtime_info(self) -> RuntimeInfo:
"""Get container runtime information.""" """Get container runtime information."""
return self.runtime_info return self.runtime_info
def get_installation_help(self) -> str: def get_installation_help(self) -> str:
"""Get installation instructions if runtime is not available.""" """Get installation instructions if runtime is not available."""
return self.platform_detector.get_installation_instructions() return self.platform_detector.get_installation_instructions()
async def _run_compose_command(self, args: List[str], cpu_mode: Optional[bool] = None) -> tuple[bool, str, str]: async def _run_compose_command(
self, args: List[str], cpu_mode: Optional[bool] = None
) -> tuple[bool, str, str]:
"""Run a compose command and return (success, stdout, stderr).""" """Run a compose command and return (success, stdout, stderr)."""
if not self.is_available(): if not self.is_available():
return False, "", "No container runtime available" return False, "", "No container runtime available"
if cpu_mode is None: if cpu_mode is None:
cpu_mode = self.use_cpu_compose cpu_mode = self.use_cpu_compose
compose_file = self.cpu_compose_file if cpu_mode else self.compose_file compose_file = self.cpu_compose_file if cpu_mode else self.compose_file
cmd = self.runtime_info.compose_command + ["-f", str(compose_file)] + args cmd = self.runtime_info.compose_command + ["-f", str(compose_file)] + args
try: try:
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
cwd=Path.cwd() cwd=Path.cwd(),
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
stdout_text = stdout.decode() if stdout else "" stdout_text = stdout.decode() if stdout else ""
stderr_text = stderr.decode() if stderr else "" stderr_text = stderr.decode() if stderr else ""
success = process.returncode == 0 success = process.returncode == 0
return success, stdout_text, stderr_text return success, stdout_text, stderr_text
except Exception as e: except Exception as e:
return False, "", f"Command execution failed: {e}" return False, "", f"Command execution failed: {e}"
async def _run_compose_command_streaming(self, args: List[str], cpu_mode: Optional[bool] = None) -> AsyncIterator[str]: async def _run_compose_command_streaming(
self, args: List[str], cpu_mode: Optional[bool] = None
) -> AsyncIterator[str]:
"""Run a compose command and yield output lines in real-time.""" """Run a compose command and yield output lines in real-time."""
if not self.is_available(): if not self.is_available():
yield "No container runtime available" yield "No container runtime available"
return return
if cpu_mode is None: if cpu_mode is None:
cpu_mode = self.use_cpu_compose cpu_mode = self.use_cpu_compose
compose_file = self.cpu_compose_file if cpu_mode else self.compose_file compose_file = self.cpu_compose_file if cpu_mode else self.compose_file
cmd = self.runtime_info.compose_command + ["-f", str(compose_file)] + args cmd = self.runtime_info.compose_command + ["-f", str(compose_file)] + args
try: try:
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT, # Combine stderr with stdout for unified output stderr=asyncio.subprocess.STDOUT, # Combine stderr with stdout for unified output
cwd=Path.cwd() cwd=Path.cwd(),
) )
# Simple approach: read line by line and yield each one # Simple approach: read line by line and yield each one
while True: while True:
line = await process.stdout.readline() line = await process.stdout.readline()
if not line: if not line:
break break
line_text = line.decode().rstrip() line_text = line.decode().rstrip()
if line_text: if line_text:
yield line_text yield line_text
# Wait for process to complete # Wait for process to complete
await process.wait() await process.wait()
except Exception as e: except Exception as e:
yield f"Command execution failed: {e}" yield f"Command execution failed: {e}"
async def _run_runtime_command(self, args: List[str]) -> tuple[bool, str, str]: async def _run_runtime_command(self, args: List[str]) -> tuple[bool, str, str]:
"""Run a runtime command (docker/podman) and return (success, stdout, stderr).""" """Run a runtime command (docker/podman) and return (success, stdout, stderr)."""
if not self.is_available(): if not self.is_available():
return False, "", "No container runtime available" return False, "", "No container runtime available"
cmd = self.runtime_info.runtime_command + args cmd = self.runtime_info.runtime_command + args
try: try:
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
stdout_text = stdout.decode() if stdout else "" stdout_text = stdout.decode() if stdout else ""
stderr_text = stderr.decode() if stderr else "" stderr_text = stderr.decode() if stderr else ""
success = process.returncode == 0 success = process.returncode == 0
return success, stdout_text, stderr_text return success, stdout_text, stderr_text
except Exception as e: except Exception as e:
return False, "", f"Command execution failed: {e}" return False, "", f"Command execution failed: {e}"
def _process_service_json(self, service: Dict, services: Dict[str, ServiceInfo]) -> None: def _process_service_json(
self, service: Dict, services: Dict[str, ServiceInfo]
) -> None:
"""Process a service JSON object and add it to the services dict.""" """Process a service JSON object and add it to the services dict."""
# Debug print to see the actual service data # Debug print to see the actual service data
print(f"DEBUG: Processing service data: {json.dumps(service, indent=2)}") logger.debug("Processing service data", service_data=service)
container_name = service.get("Name", "") container_name = service.get("Name", "")
# Map container name to service name # Map container name to service name
service_name = self.container_name_map.get(container_name) service_name = self.container_name_map.get(container_name)
if not service_name: if not service_name:
return return
state = service.get("State", "").lower() state = service.get("State", "").lower()
# Map compose states to our status enum # Map compose states to our status enum
if "running" in state: if "running" in state:
status = ServiceStatus.RUNNING status = ServiceStatus.RUNNING
@ -197,17 +206,19 @@ class ContainerManager:
status = ServiceStatus.STARTING status = ServiceStatus.STARTING
else: else:
status = ServiceStatus.UNKNOWN status = ServiceStatus.UNKNOWN
# Extract health - use Status if Health is empty # Extract health - use Status if Health is empty
health = service.get("Health", "") or service.get("Status", "N/A") health = service.get("Health", "") or service.get("Status", "N/A")
# Extract ports # Extract ports
ports_str = service.get("Ports", "") ports_str = service.get("Ports", "")
ports = [p.strip() for p in ports_str.split(",") if p.strip()] if ports_str else [] ports = (
[p.strip() for p in ports_str.split(",") if p.strip()] if ports_str else []
)
# Extract image # Extract image
image = service.get("Image", "N/A") image = service.get("Image", "N/A")
services[service_name] = ServiceInfo( services[service_name] = ServiceInfo(
name=service_name, name=service_name,
status=status, status=status,
@ -215,23 +226,25 @@ class ContainerManager:
ports=ports, ports=ports,
image=image, image=image,
) )
async def get_service_status(self, force_refresh: bool = False) -> Dict[str, ServiceInfo]: async def get_service_status(
self, force_refresh: bool = False
) -> Dict[str, ServiceInfo]:
"""Get current status of all services.""" """Get current status of all services."""
current_time = time.time() current_time = time.time()
# Use cache if recent and not forcing refresh # Use cache if recent and not forcing refresh
if not force_refresh and current_time - self.last_status_update < 5: if not force_refresh and current_time - self.last_status_update < 5:
return self.services_cache return self.services_cache
services = {} services = {}
# Different approach for Podman vs Docker # Different approach for Podman vs Docker
if self.runtime_info.runtime_type == RuntimeType.PODMAN: if self.runtime_info.runtime_type == RuntimeType.PODMAN:
# For Podman, use direct podman ps command instead of compose # For Podman, use direct podman ps command instead of compose
cmd = ["ps", "--all", "--format", "json"] cmd = ["ps", "--all", "--format", "json"]
success, stdout, stderr = await self._run_runtime_command(cmd) success, stdout, stderr = await self._run_runtime_command(cmd)
if success and stdout.strip(): if success and stdout.strip():
try: try:
containers = json.loads(stdout.strip()) containers = json.loads(stdout.strip())
@ -240,12 +253,12 @@ class ContainerManager:
names = container.get("Names", []) names = container.get("Names", [])
if not names: if not names:
continue continue
container_name = names[0] container_name = names[0]
service_name = self.container_name_map.get(container_name) service_name = self.container_name_map.get(container_name)
if not service_name: if not service_name:
continue continue
# Get container state # Get container state
state = container.get("State", "").lower() state = container.get("State", "").lower()
if "running" in state: if "running" in state:
@ -256,7 +269,7 @@ class ContainerManager:
status = ServiceStatus.STARTING status = ServiceStatus.STARTING
else: else:
status = ServiceStatus.UNKNOWN status = ServiceStatus.UNKNOWN
# Get other container info # Get other container info
image = container.get("Image", "N/A") image = container.get("Image", "N/A")
ports = [] ports = []
@ -268,7 +281,7 @@ class ContainerManager:
container_port = port.get("container_port") container_port = port.get("container_port")
if host_port and container_port: if host_port and container_port:
ports.append(f"{host_port}:{container_port}") ports.append(f"{host_port}:{container_port}")
services[service_name] = ServiceInfo( services[service_name] = ServiceInfo(
name=service_name, name=service_name,
status=status, status=status,
@ -280,55 +293,63 @@ class ContainerManager:
pass pass
else: else:
# For Docker, use compose ps command # For Docker, use compose ps command
success, stdout, stderr = await self._run_compose_command(["ps", "--format", "json"]) success, stdout, stderr = await self._run_compose_command(
["ps", "--format", "json"]
)
if success and stdout.strip(): if success and stdout.strip():
try: try:
# Handle both single JSON object (Podman) and multiple JSON objects (Docker) # Handle both single JSON object (Podman) and multiple JSON objects (Docker)
if stdout.strip().startswith('[') and stdout.strip().endswith(']'): if stdout.strip().startswith("[") and stdout.strip().endswith("]"):
# JSON array format # JSON array format
service_list = json.loads(stdout.strip()) service_list = json.loads(stdout.strip())
for service in service_list: for service in service_list:
self._process_service_json(service, services) self._process_service_json(service, services)
else: else:
# Line-by-line JSON format # Line-by-line JSON format
for line in stdout.strip().split('\n'): for line in stdout.strip().split("\n"):
if line.strip() and line.startswith('{'): if line.strip() and line.startswith("{"):
service = json.loads(line) service = json.loads(line)
self._process_service_json(service, services) self._process_service_json(service, services)
except json.JSONDecodeError: except json.JSONDecodeError:
# Fallback to parsing text output # Fallback to parsing text output
lines = stdout.strip().split('\n') lines = stdout.strip().split("\n")
if len(lines) > 1: # Make sure we have at least a header and one line if (
len(lines) > 1
): # Make sure we have at least a header and one line
for line in lines[1:]: # Skip header for line in lines[1:]: # Skip header
if line.strip(): if line.strip():
parts = line.split() parts = line.split()
if len(parts) >= 3: if len(parts) >= 3:
name = parts[0] name = parts[0]
# Only include our expected services # Only include our expected services
if name not in self.expected_services: if name not in self.expected_services:
continue continue
state = parts[2].lower() state = parts[2].lower()
if "up" in state: if "up" in state:
status = ServiceStatus.RUNNING status = ServiceStatus.RUNNING
elif "exit" in state: elif "exit" in state:
status = ServiceStatus.STOPPED status = ServiceStatus.STOPPED
else: else:
status = ServiceStatus.UNKNOWN status = ServiceStatus.UNKNOWN
services[name] = ServiceInfo(name=name, status=status) services[name] = ServiceInfo(
name=name, status=status
)
# Add expected services that weren't found # Add expected services that weren't found
for expected in self.expected_services: for expected in self.expected_services:
if expected not in services: if expected not in services:
services[expected] = ServiceInfo(name=expected, status=ServiceStatus.MISSING) services[expected] = ServiceInfo(
name=expected, status=ServiceStatus.MISSING
)
self.services_cache = services self.services_cache = services
self.last_status_update = current_time self.last_status_update = current_time
return services return services
async def get_images_digests(self, images: List[str]) -> Dict[str, str]: async def get_images_digests(self, images: List[str]) -> Dict[str, str]:
@ -337,9 +358,9 @@ class ContainerManager:
for image in images: for image in images:
if not image or image in digests: if not image or image in digests:
continue continue
success, stdout, _ = await self._run_runtime_command([ success, stdout, _ = await self._run_runtime_command(
"image", "inspect", image, "--format", "{{.Id}}" ["image", "inspect", image, "--format", "{{.Id}}"]
]) )
if success and stdout.strip(): if success and stdout.strip():
digests[image] = stdout.strip().splitlines()[0] digests[image] = stdout.strip().splitlines()[0]
return digests return digests
@ -353,13 +374,15 @@ class ContainerManager:
continue continue
for line in compose.read_text().splitlines(): for line in compose.read_text().splitlines():
line = line.strip() line = line.strip()
if not line or line.startswith('#'): if not line or line.startswith("#"):
continue continue
if line.startswith('image:'): if line.startswith("image:"):
# image: repo/name:tag # image: repo/name:tag
val = line.split(':', 1)[1].strip() val = line.split(":", 1)[1].strip()
# Remove quotes if present # Remove quotes if present
if (val.startswith('"') and val.endswith('"')) or (val.startswith("'") and val.endswith("'")): if (val.startswith('"') and val.endswith('"')) or (
val.startswith("'") and val.endswith("'")
):
val = val[1:-1] val = val[1:-1]
images.add(val) images.add(val)
except Exception: except Exception:
@ -374,53 +397,61 @@ class ContainerManager:
expected = self._parse_compose_images() expected = self._parse_compose_images()
results: list[tuple[str, str]] = [] results: list[tuple[str, str]] = []
for image in expected: for image in expected:
digest = '-' digest = "-"
success, stdout, _ = await self._run_runtime_command([ success, stdout, _ = await self._run_runtime_command(
'image', 'inspect', image, '--format', '{{.Id}}' ["image", "inspect", image, "--format", "{{.Id}}"]
]) )
if success and stdout.strip(): if success and stdout.strip():
digest = stdout.strip().splitlines()[0] digest = stdout.strip().splitlines()[0]
results.append((image, digest)) results.append((image, digest))
results.sort(key=lambda x: x[0]) results.sort(key=lambda x: x[0])
return results return results
async def start_services(self, cpu_mode: bool = False) -> AsyncIterator[tuple[bool, str]]: async def start_services(
self, cpu_mode: bool = False
) -> AsyncIterator[tuple[bool, str]]:
"""Start all services and yield progress updates.""" """Start all services and yield progress updates."""
yield False, "Starting OpenRAG services..." yield False, "Starting OpenRAG services..."
success, stdout, stderr = await self._run_compose_command(["up", "-d"], cpu_mode) success, stdout, stderr = await self._run_compose_command(
["up", "-d"], cpu_mode
)
if success: if success:
yield True, "Services started successfully" yield True, "Services started successfully"
else: else:
yield False, f"Failed to start services: {stderr}" yield False, f"Failed to start services: {stderr}"
async def stop_services(self) -> AsyncIterator[tuple[bool, str]]: async def stop_services(self) -> AsyncIterator[tuple[bool, str]]:
"""Stop all services and yield progress updates.""" """Stop all services and yield progress updates."""
yield False, "Stopping OpenRAG services..." yield False, "Stopping OpenRAG services..."
success, stdout, stderr = await self._run_compose_command(["down"]) success, stdout, stderr = await self._run_compose_command(["down"])
if success: if success:
yield True, "Services stopped successfully" yield True, "Services stopped successfully"
else: else:
yield False, f"Failed to stop services: {stderr}" yield False, f"Failed to stop services: {stderr}"
async def restart_services(self, cpu_mode: bool = False) -> AsyncIterator[tuple[bool, str]]: async def restart_services(
self, cpu_mode: bool = False
) -> AsyncIterator[tuple[bool, str]]:
"""Restart all services and yield progress updates.""" """Restart all services and yield progress updates."""
yield False, "Restarting OpenRAG services..." yield False, "Restarting OpenRAG services..."
success, stdout, stderr = await self._run_compose_command(["restart"], cpu_mode) success, stdout, stderr = await self._run_compose_command(["restart"], cpu_mode)
if success: if success:
yield True, "Services restarted successfully" yield True, "Services restarted successfully"
else: else:
yield False, f"Failed to restart services: {stderr}" yield False, f"Failed to restart services: {stderr}"
async def upgrade_services(self, cpu_mode: bool = False) -> AsyncIterator[tuple[bool, str]]: async def upgrade_services(
self, cpu_mode: bool = False
) -> AsyncIterator[tuple[bool, str]]:
"""Upgrade services (pull latest images and restart) and yield progress updates.""" """Upgrade services (pull latest images and restart) and yield progress updates."""
yield False, "Pulling latest images..." yield False, "Pulling latest images..."
# Pull latest images with streaming output # Pull latest images with streaming output
pull_success = True pull_success = True
async for line in self._run_compose_command_streaming(["pull"], cpu_mode): async for line in self._run_compose_command_streaming(["pull"], cpu_mode):
@ -428,75 +459,89 @@ class ContainerManager:
# Check for error patterns in the output # Check for error patterns in the output
if "error" in line.lower() or "failed" in line.lower(): if "error" in line.lower() or "failed" in line.lower():
pull_success = False pull_success = False
if not pull_success: if not pull_success:
yield False, "Failed to pull some images, but continuing with restart..." yield False, "Failed to pull some images, but continuing with restart..."
yield False, "Images updated, restarting services..." yield False, "Images updated, restarting services..."
# Restart with new images using streaming output # Restart with new images using streaming output
restart_success = True restart_success = True
async for line in self._run_compose_command_streaming(["up", "-d", "--force-recreate"], cpu_mode): async for line in self._run_compose_command_streaming(
["up", "-d", "--force-recreate"], cpu_mode
):
yield False, line yield False, line
# Check for error patterns in the output # Check for error patterns in the output
if "error" in line.lower() or "failed" in line.lower(): if "error" in line.lower() or "failed" in line.lower():
restart_success = False restart_success = False
if restart_success: if restart_success:
yield True, "Services upgraded and restarted successfully" yield True, "Services upgraded and restarted successfully"
else: else:
yield False, "Some errors occurred during service restart" yield False, "Some errors occurred during service restart"
async def reset_services(self) -> AsyncIterator[tuple[bool, str]]: async def reset_services(self) -> AsyncIterator[tuple[bool, str]]:
"""Reset all services (stop, remove containers/volumes, clear data) and yield progress updates.""" """Reset all services (stop, remove containers/volumes, clear data) and yield progress updates."""
yield False, "Stopping all services..." yield False, "Stopping all services..."
# Stop and remove everything # Stop and remove everything
success, stdout, stderr = await self._run_compose_command([ success, stdout, stderr = await self._run_compose_command(
"down", ["down", "--volumes", "--remove-orphans", "--rmi", "local"]
"--volumes", )
"--remove-orphans",
"--rmi", "local"
])
if not success: if not success:
yield False, f"Failed to stop services: {stderr}" yield False, f"Failed to stop services: {stderr}"
return return
yield False, "Cleaning up container data..." yield False, "Cleaning up container data..."
# Additional cleanup - remove any remaining containers/volumes # Additional cleanup - remove any remaining containers/volumes
# This is more thorough than just compose down # This is more thorough than just compose down
await self._run_runtime_command(["system", "prune", "-f"]) await self._run_runtime_command(["system", "prune", "-f"])
yield True, "System reset completed - all containers, volumes, and local images removed" yield (
True,
async def get_service_logs(self, service_name: str, lines: int = 100) -> tuple[bool, str]: "System reset completed - all containers, volumes, and local images removed",
)
async def get_service_logs(
self, service_name: str, lines: int = 100
) -> tuple[bool, str]:
"""Get logs for a specific service.""" """Get logs for a specific service."""
success, stdout, stderr = await self._run_compose_command(["logs", "--tail", str(lines), service_name]) success, stdout, stderr = await self._run_compose_command(
["logs", "--tail", str(lines), service_name]
)
if success: if success:
return True, stdout return True, stdout
else: else:
return False, f"Failed to get logs: {stderr}" return False, f"Failed to get logs: {stderr}"
async def follow_service_logs(self, service_name: str) -> AsyncIterator[str]: async def follow_service_logs(self, service_name: str) -> AsyncIterator[str]:
"""Follow logs for a specific service.""" """Follow logs for a specific service."""
if not self.is_available(): if not self.is_available():
yield "No container runtime available" yield "No container runtime available"
return return
compose_file = self.cpu_compose_file if self.use_cpu_compose else self.compose_file compose_file = (
cmd = self.runtime_info.compose_command + ["-f", str(compose_file), "logs", "-f", service_name] self.cpu_compose_file if self.use_cpu_compose else self.compose_file
)
cmd = self.runtime_info.compose_command + [
"-f",
str(compose_file),
"logs",
"-f",
service_name,
]
try: try:
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT, stderr=asyncio.subprocess.STDOUT,
cwd=Path.cwd() cwd=Path.cwd(),
) )
if process.stdout: if process.stdout:
while True: while True:
line = await process.stdout.readline() line = await process.stdout.readline()
@ -506,20 +551,22 @@ class ContainerManager:
break break
else: else:
yield "Error: Unable to read process output" yield "Error: Unable to read process output"
except Exception as e: except Exception as e:
yield f"Error following logs: {e}" yield f"Error following logs: {e}"
async def get_system_stats(self) -> Dict[str, Dict[str, str]]: async def get_system_stats(self) -> Dict[str, Dict[str, str]]:
"""Get system resource usage statistics.""" """Get system resource usage statistics."""
stats = {} stats = {}
# Get container stats # Get container stats
success, stdout, stderr = await self._run_runtime_command(["stats", "--no-stream", "--format", "json"]) success, stdout, stderr = await self._run_runtime_command(
["stats", "--no-stream", "--format", "json"]
)
if success and stdout.strip(): if success and stdout.strip():
try: try:
for line in stdout.strip().split('\n'): for line in stdout.strip().split("\n"):
if line.strip(): if line.strip():
data = json.loads(line) data = json.loads(line)
name = data.get("Name", data.get("Container", "")) name = data.get("Name", data.get("Container", ""))
@ -533,14 +580,14 @@ class ContainerManager:
} }
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
return stats return stats
async def debug_podman_services(self) -> str: async def debug_podman_services(self) -> str:
"""Run a direct Podman command to check services status for debugging.""" """Run a direct Podman command to check services status for debugging."""
if self.runtime_info.runtime_type != RuntimeType.PODMAN: if self.runtime_info.runtime_type != RuntimeType.PODMAN:
return "Not using Podman" return "Not using Podman"
# Try direct podman command # Try direct podman command
cmd = ["podman", "ps", "--all", "--format", "json"] cmd = ["podman", "ps", "--all", "--format", "json"]
try: try:
@ -548,18 +595,18 @@ class ContainerManager:
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
cwd=Path.cwd() cwd=Path.cwd(),
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
stdout_text = stdout.decode() if stdout else "" stdout_text = stdout.decode() if stdout else ""
stderr_text = stderr.decode() if stderr else "" stderr_text = stderr.decode() if stderr else ""
result = f"Command: {' '.join(cmd)}\n" result = f"Command: {' '.join(cmd)}\n"
result += f"Return code: {process.returncode}\n" result += f"Return code: {process.returncode}\n"
result += f"Stdout: {stdout_text}\n" result += f"Stdout: {stdout_text}\n"
result += f"Stderr: {stderr_text}\n" result += f"Stderr: {stderr_text}\n"
# Try to parse the output # Try to parse the output
if stdout_text.strip(): if stdout_text.strip():
try: try:
@ -571,16 +618,18 @@ class ContainerManager:
result += f" - {name}: {state}\n" result += f" - {name}: {state}\n"
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
result += f"\nFailed to parse JSON: {e}\n" result += f"\nFailed to parse JSON: {e}\n"
return result return result
except Exception as e: except Exception as e:
return f"Error executing command: {e}" return f"Error executing command: {e}"
def check_podman_macos_memory(self) -> tuple[bool, str]: def check_podman_macos_memory(self) -> tuple[bool, str]:
"""Check if Podman VM has sufficient memory on macOS.""" """Check if Podman VM has sufficient memory on macOS."""
if self.runtime_info.runtime_type != RuntimeType.PODMAN: if self.runtime_info.runtime_type != RuntimeType.PODMAN:
return True, "Not using Podman" return True, "Not using Podman"
is_sufficient, memory_mb, message = self.platform_detector.check_podman_macos_memory() is_sufficient, memory_mb, message = (
self.platform_detector.check_podman_macos_memory()
)
return is_sufficient, message return is_sufficient, message

View file

@ -7,6 +7,9 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, List from typing import Dict, Optional, List
from dataclasses import dataclass, field from dataclasses import dataclass, field
from utils.logging_config import get_logger
logger = get_logger(__name__)
from ..utils.validation import ( from ..utils.validation import (
validate_openai_api_key, validate_openai_api_key,
@ -14,13 +17,14 @@ from ..utils.validation import (
validate_non_empty, validate_non_empty,
validate_url, validate_url,
validate_documents_paths, validate_documents_paths,
sanitize_env_value sanitize_env_value,
) )
@dataclass @dataclass
class EnvConfig: class EnvConfig:
"""Environment configuration data.""" """Environment configuration data."""
# Core settings # Core settings
openai_api_key: str = "" openai_api_key: str = ""
opensearch_password: str = "" opensearch_password: str = ""
@ -28,155 +32,186 @@ class EnvConfig:
langflow_superuser: str = "admin" langflow_superuser: str = "admin"
langflow_superuser_password: str = "" langflow_superuser_password: str = ""
flow_id: str = "1098eea1-6649-4e1d-aed1-b77249fb8dd0" flow_id: str = "1098eea1-6649-4e1d-aed1-b77249fb8dd0"
# OAuth settings # OAuth settings
google_oauth_client_id: str = "" google_oauth_client_id: str = ""
google_oauth_client_secret: str = "" google_oauth_client_secret: str = ""
microsoft_graph_oauth_client_id: str = "" microsoft_graph_oauth_client_id: str = ""
microsoft_graph_oauth_client_secret: str = "" microsoft_graph_oauth_client_secret: str = ""
# Optional settings # Optional settings
webhook_base_url: str = "" webhook_base_url: str = ""
aws_access_key_id: str = "" aws_access_key_id: str = ""
aws_secret_access_key: str = "" aws_secret_access_key: str = ""
langflow_public_url: str = "" langflow_public_url: str = ""
# Langflow auth settings # Langflow auth settings
langflow_auto_login: str = "False" langflow_auto_login: str = "False"
langflow_new_user_is_active: str = "False" langflow_new_user_is_active: str = "False"
langflow_enable_superuser_cli: str = "False" langflow_enable_superuser_cli: str = "False"
# Document paths (comma-separated) # Document paths (comma-separated)
openrag_documents_paths: str = "./documents" openrag_documents_paths: str = "./documents"
# Validation errors # Validation errors
validation_errors: Dict[str, str] = field(default_factory=dict) validation_errors: Dict[str, str] = field(default_factory=dict)
class EnvManager: class EnvManager:
"""Manages environment configuration for OpenRAG.""" """Manages environment configuration for OpenRAG."""
def __init__(self, env_file: Optional[Path] = None): def __init__(self, env_file: Optional[Path] = None):
self.env_file = env_file or Path(".env") self.env_file = env_file or Path(".env")
self.config = EnvConfig() self.config = EnvConfig()
def generate_secure_password(self) -> str: def generate_secure_password(self) -> str:
"""Generate a secure password for OpenSearch.""" """Generate a secure password for OpenSearch."""
# Generate a 16-character password with letters, digits, and symbols # Generate a 16-character password with letters, digits, and symbols
alphabet = string.ascii_letters + string.digits + "!@#$%^&*" alphabet = string.ascii_letters + string.digits + "!@#$%^&*"
return ''.join(secrets.choice(alphabet) for _ in range(16)) return "".join(secrets.choice(alphabet) for _ in range(16))
def generate_langflow_secret_key(self) -> str: def generate_langflow_secret_key(self) -> str:
"""Generate a secure secret key for Langflow.""" """Generate a secure secret key for Langflow."""
return secrets.token_urlsafe(32) return secrets.token_urlsafe(32)
def load_existing_env(self) -> bool: def load_existing_env(self) -> bool:
"""Load existing .env file if it exists.""" """Load existing .env file if it exists."""
if not self.env_file.exists(): if not self.env_file.exists():
return False return False
try: try:
with open(self.env_file, 'r') as f: with open(self.env_file, "r") as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if not line or line.startswith('#'): if not line or line.startswith("#"):
continue continue
if '=' in line: if "=" in line:
key, value = line.split('=', 1) key, value = line.split("=", 1)
key = key.strip() key = key.strip()
value = sanitize_env_value(value) value = sanitize_env_value(value)
# Map env vars to config attributes # Map env vars to config attributes
attr_map = { attr_map = {
'OPENAI_API_KEY': 'openai_api_key', "OPENAI_API_KEY": "openai_api_key",
'OPENSEARCH_PASSWORD': 'opensearch_password', "OPENSEARCH_PASSWORD": "opensearch_password",
'LANGFLOW_SECRET_KEY': 'langflow_secret_key', "LANGFLOW_SECRET_KEY": "langflow_secret_key",
'LANGFLOW_SUPERUSER': 'langflow_superuser', "LANGFLOW_SUPERUSER": "langflow_superuser",
'LANGFLOW_SUPERUSER_PASSWORD': 'langflow_superuser_password', "LANGFLOW_SUPERUSER_PASSWORD": "langflow_superuser_password",
'FLOW_ID': 'flow_id', "FLOW_ID": "flow_id",
'GOOGLE_OAUTH_CLIENT_ID': 'google_oauth_client_id', "GOOGLE_OAUTH_CLIENT_ID": "google_oauth_client_id",
'GOOGLE_OAUTH_CLIENT_SECRET': 'google_oauth_client_secret', "GOOGLE_OAUTH_CLIENT_SECRET": "google_oauth_client_secret",
'MICROSOFT_GRAPH_OAUTH_CLIENT_ID': 'microsoft_graph_oauth_client_id', "MICROSOFT_GRAPH_OAUTH_CLIENT_ID": "microsoft_graph_oauth_client_id",
'MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET': 'microsoft_graph_oauth_client_secret', "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET": "microsoft_graph_oauth_client_secret",
'WEBHOOK_BASE_URL': 'webhook_base_url', "WEBHOOK_BASE_URL": "webhook_base_url",
'AWS_ACCESS_KEY_ID': 'aws_access_key_id', "AWS_ACCESS_KEY_ID": "aws_access_key_id",
'AWS_SECRET_ACCESS_KEY': 'aws_secret_access_key', "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
'LANGFLOW_PUBLIC_URL': 'langflow_public_url', "LANGFLOW_PUBLIC_URL": "langflow_public_url",
'OPENRAG_DOCUMENTS_PATHS': 'openrag_documents_paths', "OPENRAG_DOCUMENTS_PATHS": "openrag_documents_paths",
'LANGFLOW_AUTO_LOGIN': 'langflow_auto_login', "LANGFLOW_AUTO_LOGIN": "langflow_auto_login",
'LANGFLOW_NEW_USER_IS_ACTIVE': 'langflow_new_user_is_active', "LANGFLOW_NEW_USER_IS_ACTIVE": "langflow_new_user_is_active",
'LANGFLOW_ENABLE_SUPERUSER_CLI': 'langflow_enable_superuser_cli', "LANGFLOW_ENABLE_SUPERUSER_CLI": "langflow_enable_superuser_cli",
} }
if key in attr_map: if key in attr_map:
setattr(self.config, attr_map[key], value) setattr(self.config, attr_map[key], value)
return True return True
except Exception as e: except Exception as e:
print(f"Error loading .env file: {e}") logger.error("Error loading .env file", error=str(e))
return False return False
def setup_secure_defaults(self) -> None: def setup_secure_defaults(self) -> None:
"""Set up secure default values for passwords and keys.""" """Set up secure default values for passwords and keys."""
if not self.config.opensearch_password: if not self.config.opensearch_password:
self.config.opensearch_password = self.generate_secure_password() self.config.opensearch_password = self.generate_secure_password()
if not self.config.langflow_secret_key: if not self.config.langflow_secret_key:
self.config.langflow_secret_key = self.generate_langflow_secret_key() self.config.langflow_secret_key = self.generate_langflow_secret_key()
if not self.config.langflow_superuser_password: if not self.config.langflow_superuser_password:
self.config.langflow_superuser_password = self.generate_secure_password() self.config.langflow_superuser_password = self.generate_secure_password()
def validate_config(self, mode: str = "full") -> bool: def validate_config(self, mode: str = "full") -> bool:
""" """
Validate the current configuration. Validate the current configuration.
Args: Args:
mode: "no_auth" for minimal validation, "full" for complete validation mode: "no_auth" for minimal validation, "full" for complete validation
""" """
self.config.validation_errors.clear() self.config.validation_errors.clear()
# Always validate OpenAI API key # Always validate OpenAI API key
if not validate_openai_api_key(self.config.openai_api_key): if not validate_openai_api_key(self.config.openai_api_key):
self.config.validation_errors['openai_api_key'] = "Invalid OpenAI API key format (should start with sk-)" self.config.validation_errors["openai_api_key"] = (
"Invalid OpenAI API key format (should start with sk-)"
)
# Validate documents paths only if provided (optional) # Validate documents paths only if provided (optional)
if self.config.openrag_documents_paths: if self.config.openrag_documents_paths:
is_valid, error_msg, _ = validate_documents_paths(self.config.openrag_documents_paths) is_valid, error_msg, _ = validate_documents_paths(
self.config.openrag_documents_paths
)
if not is_valid: if not is_valid:
self.config.validation_errors['openrag_documents_paths'] = error_msg self.config.validation_errors["openrag_documents_paths"] = error_msg
# Validate required fields # Validate required fields
if not validate_non_empty(self.config.opensearch_password): if not validate_non_empty(self.config.opensearch_password):
self.config.validation_errors['opensearch_password'] = "OpenSearch password is required" self.config.validation_errors["opensearch_password"] = (
"OpenSearch password is required"
)
# Langflow secret key is auto-generated; no user input required # Langflow secret key is auto-generated; no user input required
if not validate_non_empty(self.config.langflow_superuser_password): if not validate_non_empty(self.config.langflow_superuser_password):
self.config.validation_errors['langflow_superuser_password'] = "Langflow superuser password is required" self.config.validation_errors["langflow_superuser_password"] = (
"Langflow superuser password is required"
)
if mode == "full": if mode == "full":
# Validate OAuth settings if provided # Validate OAuth settings if provided
if self.config.google_oauth_client_id and not validate_google_oauth_client_id(self.config.google_oauth_client_id): if (
self.config.validation_errors['google_oauth_client_id'] = "Invalid Google OAuth client ID format" self.config.google_oauth_client_id
and not validate_google_oauth_client_id(
if self.config.google_oauth_client_id and not validate_non_empty(self.config.google_oauth_client_secret): self.config.google_oauth_client_id
self.config.validation_errors['google_oauth_client_secret'] = "Google OAuth client secret required when client ID is provided" )
):
if self.config.microsoft_graph_oauth_client_id and not validate_non_empty(self.config.microsoft_graph_oauth_client_secret): self.config.validation_errors["google_oauth_client_id"] = (
self.config.validation_errors['microsoft_graph_oauth_client_secret'] = "Microsoft Graph client secret required when client ID is provided" "Invalid Google OAuth client ID format"
)
if self.config.google_oauth_client_id and not validate_non_empty(
self.config.google_oauth_client_secret
):
self.config.validation_errors["google_oauth_client_secret"] = (
"Google OAuth client secret required when client ID is provided"
)
if self.config.microsoft_graph_oauth_client_id and not validate_non_empty(
self.config.microsoft_graph_oauth_client_secret
):
self.config.validation_errors["microsoft_graph_oauth_client_secret"] = (
"Microsoft Graph client secret required when client ID is provided"
)
# Validate optional URLs if provided # Validate optional URLs if provided
if self.config.webhook_base_url and not validate_url(self.config.webhook_base_url): if self.config.webhook_base_url and not validate_url(
self.config.validation_errors['webhook_base_url'] = "Invalid webhook URL format" self.config.webhook_base_url
):
if self.config.langflow_public_url and not validate_url(self.config.langflow_public_url): self.config.validation_errors["webhook_base_url"] = (
self.config.validation_errors['langflow_public_url'] = "Invalid Langflow public URL format" "Invalid webhook URL format"
)
if self.config.langflow_public_url and not validate_url(
self.config.langflow_public_url
):
self.config.validation_errors["langflow_public_url"] = (
"Invalid Langflow public URL format"
)
return len(self.config.validation_errors) == 0 return len(self.config.validation_errors) == 0
def save_env_file(self) -> bool: def save_env_file(self) -> bool:
"""Save current configuration to .env file.""" """Save current configuration to .env file."""
try: try:
@ -184,45 +219,67 @@ class EnvManager:
self.setup_secure_defaults() self.setup_secure_defaults()
# Create timestamped backup if file exists # Create timestamped backup if file exists
if self.env_file.exists(): if self.env_file.exists():
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_file = self.env_file.with_suffix(f'.env.backup.{timestamp}') backup_file = self.env_file.with_suffix(f".env.backup.{timestamp}")
self.env_file.rename(backup_file) self.env_file.rename(backup_file)
with open(self.env_file, 'w') as f: with open(self.env_file, "w") as f:
f.write("# OpenRAG Environment Configuration\n") f.write("# OpenRAG Environment Configuration\n")
f.write("# Generated by OpenRAG TUI\n\n") f.write("# Generated by OpenRAG TUI\n\n")
# Core settings # Core settings
f.write("# Core settings\n") f.write("# Core settings\n")
f.write(f"LANGFLOW_SECRET_KEY={self.config.langflow_secret_key}\n") f.write(f"LANGFLOW_SECRET_KEY={self.config.langflow_secret_key}\n")
f.write(f"LANGFLOW_SUPERUSER={self.config.langflow_superuser}\n") f.write(f"LANGFLOW_SUPERUSER={self.config.langflow_superuser}\n")
f.write(f"LANGFLOW_SUPERUSER_PASSWORD={self.config.langflow_superuser_password}\n") f.write(
f"LANGFLOW_SUPERUSER_PASSWORD={self.config.langflow_superuser_password}\n"
)
f.write(f"FLOW_ID={self.config.flow_id}\n") f.write(f"FLOW_ID={self.config.flow_id}\n")
f.write(f"OPENSEARCH_PASSWORD={self.config.opensearch_password}\n") f.write(f"OPENSEARCH_PASSWORD={self.config.opensearch_password}\n")
f.write(f"OPENAI_API_KEY={self.config.openai_api_key}\n") f.write(f"OPENAI_API_KEY={self.config.openai_api_key}\n")
f.write(f"OPENRAG_DOCUMENTS_PATHS={self.config.openrag_documents_paths}\n") f.write(
f"OPENRAG_DOCUMENTS_PATHS={self.config.openrag_documents_paths}\n"
)
f.write("\n") f.write("\n")
# Langflow auth settings # Langflow auth settings
f.write("# Langflow auth settings\n") f.write("# Langflow auth settings\n")
f.write(f"LANGFLOW_AUTO_LOGIN={self.config.langflow_auto_login}\n") f.write(f"LANGFLOW_AUTO_LOGIN={self.config.langflow_auto_login}\n")
f.write(f"LANGFLOW_NEW_USER_IS_ACTIVE={self.config.langflow_new_user_is_active}\n") f.write(
f.write(f"LANGFLOW_ENABLE_SUPERUSER_CLI={self.config.langflow_enable_superuser_cli}\n") f"LANGFLOW_NEW_USER_IS_ACTIVE={self.config.langflow_new_user_is_active}\n"
)
f.write(
f"LANGFLOW_ENABLE_SUPERUSER_CLI={self.config.langflow_enable_superuser_cli}\n"
)
f.write("\n") f.write("\n")
# OAuth settings # OAuth settings
if self.config.google_oauth_client_id or self.config.google_oauth_client_secret: if (
self.config.google_oauth_client_id
or self.config.google_oauth_client_secret
):
f.write("# Google OAuth settings\n") f.write("# Google OAuth settings\n")
f.write(f"GOOGLE_OAUTH_CLIENT_ID={self.config.google_oauth_client_id}\n") f.write(
f.write(f"GOOGLE_OAUTH_CLIENT_SECRET={self.config.google_oauth_client_secret}\n") f"GOOGLE_OAUTH_CLIENT_ID={self.config.google_oauth_client_id}\n"
)
f.write(
f"GOOGLE_OAUTH_CLIENT_SECRET={self.config.google_oauth_client_secret}\n"
)
f.write("\n") f.write("\n")
if self.config.microsoft_graph_oauth_client_id or self.config.microsoft_graph_oauth_client_secret: if (
self.config.microsoft_graph_oauth_client_id
or self.config.microsoft_graph_oauth_client_secret
):
f.write("# Microsoft Graph OAuth settings\n") f.write("# Microsoft Graph OAuth settings\n")
f.write(f"MICROSOFT_GRAPH_OAUTH_CLIENT_ID={self.config.microsoft_graph_oauth_client_id}\n") f.write(
f.write(f"MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET={self.config.microsoft_graph_oauth_client_secret}\n") f"MICROSOFT_GRAPH_OAUTH_CLIENT_ID={self.config.microsoft_graph_oauth_client_id}\n"
)
f.write(
f"MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET={self.config.microsoft_graph_oauth_client_secret}\n"
)
f.write("\n") f.write("\n")
# Optional settings # Optional settings
optional_vars = [ optional_vars = [
("WEBHOOK_BASE_URL", self.config.webhook_base_url), ("WEBHOOK_BASE_URL", self.config.webhook_base_url),
@ -230,7 +287,7 @@ class EnvManager:
("AWS_SECRET_ACCESS_KEY", self.config.aws_secret_access_key), ("AWS_SECRET_ACCESS_KEY", self.config.aws_secret_access_key),
("LANGFLOW_PUBLIC_URL", self.config.langflow_public_url), ("LANGFLOW_PUBLIC_URL", self.config.langflow_public_url),
] ]
optional_written = False optional_written = False
for var_name, var_value in optional_vars: for var_name, var_value in optional_vars:
if var_value: if var_value:
@ -238,52 +295,89 @@ class EnvManager:
f.write("# Optional settings\n") f.write("# Optional settings\n")
optional_written = True optional_written = True
f.write(f"{var_name}={var_value}\n") f.write(f"{var_name}={var_value}\n")
if optional_written: if optional_written:
f.write("\n") f.write("\n")
return True return True
except Exception as e: except Exception as e:
print(f"Error saving .env file: {e}") logger.error("Error saving .env file", error=str(e))
return False return False
def get_no_auth_setup_fields(self) -> List[tuple[str, str, str, bool]]: def get_no_auth_setup_fields(self) -> List[tuple[str, str, str, bool]]:
"""Get fields required for no-auth setup mode. Returns (field_name, display_name, placeholder, can_generate).""" """Get fields required for no-auth setup mode. Returns (field_name, display_name, placeholder, can_generate)."""
return [ return [
("openai_api_key", "OpenAI API Key", "sk-...", False), ("openai_api_key", "OpenAI API Key", "sk-...", False),
("opensearch_password", "OpenSearch Password", "Will be auto-generated if empty", True), (
("langflow_superuser_password", "Langflow Superuser Password", "Will be auto-generated if empty", True), "opensearch_password",
("openrag_documents_paths", "Documents Paths", "./documents,/path/to/more/docs", False), "OpenSearch Password",
"Will be auto-generated if empty",
True,
),
(
"langflow_superuser_password",
"Langflow Superuser Password",
"Will be auto-generated if empty",
True,
),
(
"openrag_documents_paths",
"Documents Paths",
"./documents,/path/to/more/docs",
False,
),
] ]
def get_full_setup_fields(self) -> List[tuple[str, str, str, bool]]: def get_full_setup_fields(self) -> List[tuple[str, str, str, bool]]:
"""Get all fields for full setup mode.""" """Get all fields for full setup mode."""
base_fields = self.get_no_auth_setup_fields() base_fields = self.get_no_auth_setup_fields()
oauth_fields = [ oauth_fields = [
("google_oauth_client_id", "Google OAuth Client ID", "xxx.apps.googleusercontent.com", False), (
"google_oauth_client_id",
"Google OAuth Client ID",
"xxx.apps.googleusercontent.com",
False,
),
("google_oauth_client_secret", "Google OAuth Client Secret", "", False), ("google_oauth_client_secret", "Google OAuth Client Secret", "", False),
("microsoft_graph_oauth_client_id", "Microsoft Graph Client ID", "", False), ("microsoft_graph_oauth_client_id", "Microsoft Graph Client ID", "", False),
("microsoft_graph_oauth_client_secret", "Microsoft Graph Client Secret", "", False), (
"microsoft_graph_oauth_client_secret",
"Microsoft Graph Client Secret",
"",
False,
),
] ]
optional_fields = [ optional_fields = [
("webhook_base_url", "Webhook Base URL (optional)", "https://your-domain.com", False), (
"webhook_base_url",
"Webhook Base URL (optional)",
"https://your-domain.com",
False,
),
("aws_access_key_id", "AWS Access Key ID (optional)", "", False), ("aws_access_key_id", "AWS Access Key ID (optional)", "", False),
("aws_secret_access_key", "AWS Secret Access Key (optional)", "", False), ("aws_secret_access_key", "AWS Secret Access Key (optional)", "", False),
("langflow_public_url", "Langflow Public URL (optional)", "http://localhost:7860", False), (
"langflow_public_url",
"Langflow Public URL (optional)",
"http://localhost:7860",
False,
),
] ]
return base_fields + oauth_fields + optional_fields return base_fields + oauth_fields + optional_fields
def generate_compose_volume_mounts(self) -> List[str]: def generate_compose_volume_mounts(self) -> List[str]:
"""Generate Docker Compose volume mount strings from documents paths.""" """Generate Docker Compose volume mount strings from documents paths."""
is_valid, _, validated_paths = validate_documents_paths(self.config.openrag_documents_paths) is_valid, _, validated_paths = validate_documents_paths(
self.config.openrag_documents_paths
)
if not is_valid: if not is_valid:
return ["./documents:/app/documents:Z"] # fallback return ["./documents:/app/documents:Z"] # fallback
volume_mounts = [] volume_mounts = []
for i, path in enumerate(validated_paths): for i, path in enumerate(validated_paths):
if i == 0: if i == 0:
@ -291,6 +385,6 @@ class EnvManager:
volume_mounts.append(f"{path}:/app/documents:Z") volume_mounts.append(f"{path}:/app/documents:Z")
else: else:
# Additional paths map to numbered directories # Additional paths map to numbered directories
volume_mounts.append(f"{path}:/app/documents{i+1}:Z") volume_mounts.append(f"{path}:/app/documents{i + 1}:Z")
return volume_mounts return volume_mounts

View file

@ -1 +1 @@
"""TUI screens package.""" """TUI screens package."""

View file

@ -3,7 +3,16 @@
from textual.app import ComposeResult from textual.app import ComposeResult
from textual.containers import Container, Vertical, Horizontal, ScrollableContainer from textual.containers import Container, Vertical, Horizontal, ScrollableContainer
from textual.screen import Screen from textual.screen import Screen
from textual.widgets import Header, Footer, Static, Button, Input, Label, TabbedContent, TabPane from textual.widgets import (
Header,
Footer,
Static,
Button,
Input,
Label,
TabbedContent,
TabPane,
)
from textual.validation import ValidationResult, Validator from textual.validation import ValidationResult, Validator
from rich.text import Text from rich.text import Text
from pathlib import Path from pathlib import Path
@ -15,11 +24,11 @@ from pathlib import Path
class OpenAIKeyValidator(Validator): class OpenAIKeyValidator(Validator):
"""Validator for OpenAI API keys.""" """Validator for OpenAI API keys."""
def validate(self, value: str) -> ValidationResult: def validate(self, value: str) -> ValidationResult:
if not value: if not value:
return self.success() return self.success()
if validate_openai_api_key(value): if validate_openai_api_key(value):
return self.success() return self.success()
else: else:
@ -28,12 +37,12 @@ class OpenAIKeyValidator(Validator):
class DocumentsPathValidator(Validator): class DocumentsPathValidator(Validator):
"""Validator for documents paths.""" """Validator for documents paths."""
def validate(self, value: str) -> ValidationResult: def validate(self, value: str) -> ValidationResult:
# Optional: allow empty value # Optional: allow empty value
if not value: if not value:
return self.success() return self.success()
is_valid, error_msg, _ = validate_documents_paths(value) is_valid, error_msg, _ = validate_documents_paths(value)
if is_valid: if is_valid:
return self.success() return self.success()
@ -43,22 +52,22 @@ class DocumentsPathValidator(Validator):
class ConfigScreen(Screen): class ConfigScreen(Screen):
"""Configuration screen for environment setup.""" """Configuration screen for environment setup."""
BINDINGS = [ BINDINGS = [
("escape", "back", "Back"), ("escape", "back", "Back"),
("ctrl+s", "save", "Save"), ("ctrl+s", "save", "Save"),
("ctrl+g", "generate", "Generate Passwords"), ("ctrl+g", "generate", "Generate Passwords"),
] ]
def __init__(self, mode: str = "full"): def __init__(self, mode: str = "full"):
super().__init__() super().__init__()
self.mode = mode # "no_auth" or "full" self.mode = mode # "no_auth" or "full"
self.env_manager = EnvManager() self.env_manager = EnvManager()
self.inputs = {} self.inputs = {}
# Load existing config if available # Load existing config if available
self.env_manager.load_existing_env() self.env_manager.load_existing_env()
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
"""Create the configuration screen layout.""" """Create the configuration screen layout."""
# Removed top header bar and header text # Removed top header bar and header text
@ -70,33 +79,37 @@ class ConfigScreen(Screen):
Button("Generate Passwords", variant="default", id="generate-btn"), Button("Generate Passwords", variant="default", id="generate-btn"),
Button("Save Configuration", variant="success", id="save-btn"), Button("Save Configuration", variant="success", id="save-btn"),
Button("Back", variant="default", id="back-btn"), Button("Back", variant="default", id="back-btn"),
classes="button-row" classes="button-row",
) )
yield Footer() yield Footer()
def _create_header_text(self) -> Text: def _create_header_text(self) -> Text:
"""Create the configuration header text.""" """Create the configuration header text."""
header_text = Text() header_text = Text()
if self.mode == "no_auth": if self.mode == "no_auth":
header_text.append("Quick Setup - No Authentication\n", style="bold green") header_text.append("Quick Setup - No Authentication\n", style="bold green")
header_text.append("Configure OpenRAG for local document processing only.\n\n", style="dim") header_text.append(
"Configure OpenRAG for local document processing only.\n\n", style="dim"
)
else: else:
header_text.append("Full Setup - OAuth Integration\n", style="bold cyan") header_text.append("Full Setup - OAuth Integration\n", style="bold cyan")
header_text.append("Configure OpenRAG with cloud service integrations.\n\n", style="dim") header_text.append(
"Configure OpenRAG with cloud service integrations.\n\n", style="dim"
)
header_text.append("Required fields are marked with *\n", style="yellow") header_text.append("Required fields are marked with *\n", style="yellow")
header_text.append("Use Ctrl+G to generate admin passwords\n", style="dim") header_text.append("Use Ctrl+G to generate admin passwords\n", style="dim")
return header_text return header_text
def _create_all_fields(self) -> ComposeResult: def _create_all_fields(self) -> ComposeResult:
"""Create all configuration fields in a single scrollable layout.""" """Create all configuration fields in a single scrollable layout."""
# Admin Credentials Section # Admin Credentials Section
yield Static("Admin Credentials", classes="tab-header") yield Static("Admin Credentials", classes="tab-header")
yield Static(" ") yield Static(" ")
# OpenSearch Admin Password # OpenSearch Admin Password
yield Label("OpenSearch Admin Password *") yield Label("OpenSearch Admin Password *")
current_value = getattr(self.env_manager.config, "opensearch_password", "") current_value = getattr(self.env_manager.config, "opensearch_password", "")
@ -104,64 +117,73 @@ class ConfigScreen(Screen):
placeholder="Auto-generated secure password", placeholder="Auto-generated secure password",
value=current_value, value=current_value,
password=True, password=True,
id="input-opensearch_password" id="input-opensearch_password",
) )
yield input_widget yield input_widget
self.inputs["opensearch_password"] = input_widget self.inputs["opensearch_password"] = input_widget
yield Static(" ") yield Static(" ")
# Langflow Admin Username # Langflow Admin Username
yield Label("Langflow Admin Username *") yield Label("Langflow Admin Username *")
current_value = getattr(self.env_manager.config, "langflow_superuser", "") current_value = getattr(self.env_manager.config, "langflow_superuser", "")
input_widget = Input( input_widget = Input(
placeholder="admin", placeholder="admin", value=current_value, id="input-langflow_superuser"
value=current_value,
id="input-langflow_superuser"
) )
yield input_widget yield input_widget
self.inputs["langflow_superuser"] = input_widget self.inputs["langflow_superuser"] = input_widget
yield Static(" ") yield Static(" ")
# Langflow Admin Password # Langflow Admin Password
yield Label("Langflow Admin Password *") yield Label("Langflow Admin Password *")
current_value = getattr(self.env_manager.config, "langflow_superuser_password", "") current_value = getattr(
self.env_manager.config, "langflow_superuser_password", ""
)
input_widget = Input( input_widget = Input(
placeholder="Auto-generated secure password", placeholder="Auto-generated secure password",
value=current_value, value=current_value,
password=True, password=True,
id="input-langflow_superuser_password" id="input-langflow_superuser_password",
) )
yield input_widget yield input_widget
self.inputs["langflow_superuser_password"] = input_widget self.inputs["langflow_superuser_password"] = input_widget
yield Static(" ") yield Static(" ")
yield Static(" ") yield Static(" ")
# API Keys Section # API Keys Section
yield Static("API Keys", classes="tab-header") yield Static("API Keys", classes="tab-header")
yield Static(" ") yield Static(" ")
# OpenAI API Key # OpenAI API Key
yield Label("OpenAI API Key *") yield Label("OpenAI API Key *")
# Where to create OpenAI keys (helper above the box) # Where to create OpenAI keys (helper above the box)
yield Static(Text("Get a key: https://platform.openai.com/api-keys", style="dim"), classes="helper-text") yield Static(
Text("Get a key: https://platform.openai.com/api-keys", style="dim"),
classes="helper-text",
)
current_value = getattr(self.env_manager.config, "openai_api_key", "") current_value = getattr(self.env_manager.config, "openai_api_key", "")
input_widget = Input( input_widget = Input(
placeholder="sk-...", placeholder="sk-...",
value=current_value, value=current_value,
password=True, password=True,
validators=[OpenAIKeyValidator()], validators=[OpenAIKeyValidator()],
id="input-openai_api_key" id="input-openai_api_key",
) )
yield input_widget yield input_widget
self.inputs["openai_api_key"] = input_widget self.inputs["openai_api_key"] = input_widget
yield Static(" ") yield Static(" ")
# Add OAuth fields only in full mode # Add OAuth fields only in full mode
if self.mode == "full": if self.mode == "full":
# Google OAuth Client ID # Google OAuth Client ID
yield Label("Google OAuth Client ID") yield Label("Google OAuth Client ID")
# Where to create Google OAuth credentials (helper above the box) # Where to create Google OAuth credentials (helper above the box)
yield Static(Text("Create credentials: https://console.cloud.google.com/apis/credentials", style="dim"), classes="helper-text") yield Static(
Text(
"Create credentials: https://console.cloud.google.com/apis/credentials",
style="dim",
),
classes="helper-text",
)
# Callback URL guidance for Google OAuth # Callback URL guidance for Google OAuth
yield Static( yield Static(
Text( Text(
@ -169,37 +191,47 @@ class ConfigScreen(Screen):
" - Local: http://localhost:3000/auth/callback\n" " - Local: http://localhost:3000/auth/callback\n"
" - Prod: https://your-domain.com/auth/callback\n" " - Prod: https://your-domain.com/auth/callback\n"
"If you use separate apps for login and connectors, add this URL to BOTH.", "If you use separate apps for login and connectors, add this URL to BOTH.",
style="dim" style="dim",
), ),
classes="helper-text" classes="helper-text",
)
current_value = getattr(
self.env_manager.config, "google_oauth_client_id", ""
) )
current_value = getattr(self.env_manager.config, "google_oauth_client_id", "")
input_widget = Input( input_widget = Input(
placeholder="xxx.apps.googleusercontent.com", placeholder="xxx.apps.googleusercontent.com",
value=current_value, value=current_value,
id="input-google_oauth_client_id" id="input-google_oauth_client_id",
) )
yield input_widget yield input_widget
self.inputs["google_oauth_client_id"] = input_widget self.inputs["google_oauth_client_id"] = input_widget
yield Static(" ") yield Static(" ")
# Google OAuth Client Secret # Google OAuth Client Secret
yield Label("Google OAuth Client Secret") yield Label("Google OAuth Client Secret")
current_value = getattr(self.env_manager.config, "google_oauth_client_secret", "") current_value = getattr(
self.env_manager.config, "google_oauth_client_secret", ""
)
input_widget = Input( input_widget = Input(
placeholder="", placeholder="",
value=current_value, value=current_value,
password=True, password=True,
id="input-google_oauth_client_secret" id="input-google_oauth_client_secret",
) )
yield input_widget yield input_widget
self.inputs["google_oauth_client_secret"] = input_widget self.inputs["google_oauth_client_secret"] = input_widget
yield Static(" ") yield Static(" ")
# Microsoft Graph Client ID # Microsoft Graph Client ID
yield Label("Microsoft Graph Client ID") yield Label("Microsoft Graph Client ID")
# Where to create Microsoft app registrations (helper above the box) # Where to create Microsoft app registrations (helper above the box)
yield Static(Text("Create app: https://portal.azure.com/#view/Microsoft_AAD_RegisteredApps/ApplicationsListBlade", style="dim"), classes="helper-text") yield Static(
Text(
"Create app: https://portal.azure.com/#view/Microsoft_AAD_RegisteredApps/ApplicationsListBlade",
style="dim",
),
classes="helper-text",
)
# Callback URL guidance for Microsoft OAuth # Callback URL guidance for Microsoft OAuth
yield Static( yield Static(
Text( Text(
@ -207,66 +239,76 @@ class ConfigScreen(Screen):
" - Local: http://localhost:3000/auth/callback\n" " - Local: http://localhost:3000/auth/callback\n"
" - Prod: https://your-domain.com/auth/callback\n" " - Prod: https://your-domain.com/auth/callback\n"
"If you use separate apps for login and connectors, add this URI to BOTH.", "If you use separate apps for login and connectors, add this URI to BOTH.",
style="dim" style="dim",
), ),
classes="helper-text" classes="helper-text",
)
current_value = getattr(
self.env_manager.config, "microsoft_graph_oauth_client_id", ""
) )
current_value = getattr(self.env_manager.config, "microsoft_graph_oauth_client_id", "")
input_widget = Input( input_widget = Input(
placeholder="", placeholder="",
value=current_value, value=current_value,
id="input-microsoft_graph_oauth_client_id" id="input-microsoft_graph_oauth_client_id",
) )
yield input_widget yield input_widget
self.inputs["microsoft_graph_oauth_client_id"] = input_widget self.inputs["microsoft_graph_oauth_client_id"] = input_widget
yield Static(" ") yield Static(" ")
# Microsoft Graph Client Secret # Microsoft Graph Client Secret
yield Label("Microsoft Graph Client Secret") yield Label("Microsoft Graph Client Secret")
current_value = getattr(self.env_manager.config, "microsoft_graph_oauth_client_secret", "") current_value = getattr(
self.env_manager.config, "microsoft_graph_oauth_client_secret", ""
)
input_widget = Input( input_widget = Input(
placeholder="", placeholder="",
value=current_value, value=current_value,
password=True, password=True,
id="input-microsoft_graph_oauth_client_secret" id="input-microsoft_graph_oauth_client_secret",
) )
yield input_widget yield input_widget
self.inputs["microsoft_graph_oauth_client_secret"] = input_widget self.inputs["microsoft_graph_oauth_client_secret"] = input_widget
yield Static(" ") yield Static(" ")
# AWS Access Key ID # AWS Access Key ID
yield Label("AWS Access Key ID") yield Label("AWS Access Key ID")
# Where to create AWS keys (helper above the box) # Where to create AWS keys (helper above the box)
yield Static(Text("Create keys: https://console.aws.amazon.com/iam/home#/security_credentials", style="dim"), classes="helper-text") yield Static(
Text(
"Create keys: https://console.aws.amazon.com/iam/home#/security_credentials",
style="dim",
),
classes="helper-text",
)
current_value = getattr(self.env_manager.config, "aws_access_key_id", "") current_value = getattr(self.env_manager.config, "aws_access_key_id", "")
input_widget = Input( input_widget = Input(
placeholder="", placeholder="", value=current_value, id="input-aws_access_key_id"
value=current_value,
id="input-aws_access_key_id"
) )
yield input_widget yield input_widget
self.inputs["aws_access_key_id"] = input_widget self.inputs["aws_access_key_id"] = input_widget
yield Static(" ") yield Static(" ")
# AWS Secret Access Key # AWS Secret Access Key
yield Label("AWS Secret Access Key") yield Label("AWS Secret Access Key")
current_value = getattr(self.env_manager.config, "aws_secret_access_key", "") current_value = getattr(
self.env_manager.config, "aws_secret_access_key", ""
)
input_widget = Input( input_widget = Input(
placeholder="", placeholder="",
value=current_value, value=current_value,
password=True, password=True,
id="input-aws_secret_access_key" id="input-aws_secret_access_key",
) )
yield input_widget yield input_widget
self.inputs["aws_secret_access_key"] = input_widget self.inputs["aws_secret_access_key"] = input_widget
yield Static(" ") yield Static(" ")
yield Static(" ") yield Static(" ")
# Other Settings Section # Other Settings Section
yield Static("Others", classes="tab-header") yield Static("Others", classes="tab-header")
yield Static(" ") yield Static(" ")
# Documents Paths (optional) + picker action button on next line # Documents Paths (optional) + picker action button on next line
yield Label("Documents Paths") yield Label("Documents Paths")
current_value = getattr(self.env_manager.config, "openrag_documents_paths", "") current_value = getattr(self.env_manager.config, "openrag_documents_paths", "")
@ -274,57 +316,63 @@ class ConfigScreen(Screen):
placeholder="./documents,/path/to/more/docs", placeholder="./documents,/path/to/more/docs",
value=current_value, value=current_value,
validators=[DocumentsPathValidator()], validators=[DocumentsPathValidator()],
id="input-openrag_documents_paths" id="input-openrag_documents_paths",
) )
yield input_widget yield input_widget
# Actions row with pick button # Actions row with pick button
yield Horizontal(Button("Pick…", id="pick-docs-btn"), id="docs-path-actions", classes="controls-row") yield Horizontal(
Button("Pick…", id="pick-docs-btn"),
id="docs-path-actions",
classes="controls-row",
)
self.inputs["openrag_documents_paths"] = input_widget self.inputs["openrag_documents_paths"] = input_widget
yield Static(" ") yield Static(" ")
# Langflow Auth Settings # Langflow Auth Settings
yield Static("Langflow Auth Settings", classes="tab-header") yield Static("Langflow Auth Settings", classes="tab-header")
yield Static(" ") yield Static(" ")
# Langflow Auto Login # Langflow Auto Login
yield Label("Langflow Auto Login") yield Label("Langflow Auto Login")
current_value = getattr(self.env_manager.config, "langflow_auto_login", "False") current_value = getattr(self.env_manager.config, "langflow_auto_login", "False")
input_widget = Input( input_widget = Input(
placeholder="False", placeholder="False", value=current_value, id="input-langflow_auto_login"
value=current_value,
id="input-langflow_auto_login"
) )
yield input_widget yield input_widget
self.inputs["langflow_auto_login"] = input_widget self.inputs["langflow_auto_login"] = input_widget
yield Static(" ") yield Static(" ")
# Langflow New User Is Active # Langflow New User Is Active
yield Label("Langflow New User Is Active") yield Label("Langflow New User Is Active")
current_value = getattr(self.env_manager.config, "langflow_new_user_is_active", "False") current_value = getattr(
self.env_manager.config, "langflow_new_user_is_active", "False"
)
input_widget = Input( input_widget = Input(
placeholder="False", placeholder="False",
value=current_value, value=current_value,
id="input-langflow_new_user_is_active" id="input-langflow_new_user_is_active",
) )
yield input_widget yield input_widget
self.inputs["langflow_new_user_is_active"] = input_widget self.inputs["langflow_new_user_is_active"] = input_widget
yield Static(" ") yield Static(" ")
# Langflow Enable Superuser CLI # Langflow Enable Superuser CLI
yield Label("Langflow Enable Superuser CLI") yield Label("Langflow Enable Superuser CLI")
current_value = getattr(self.env_manager.config, "langflow_enable_superuser_cli", "False") current_value = getattr(
self.env_manager.config, "langflow_enable_superuser_cli", "False"
)
input_widget = Input( input_widget = Input(
placeholder="False", placeholder="False",
value=current_value, value=current_value,
id="input-langflow_enable_superuser_cli" id="input-langflow_enable_superuser_cli",
) )
yield input_widget yield input_widget
self.inputs["langflow_enable_superuser_cli"] = input_widget self.inputs["langflow_enable_superuser_cli"] = input_widget
yield Static(" ") yield Static(" ")
yield Static(" ") yield Static(" ")
# Langflow Secret Key removed from UI; generated automatically on save # Langflow Secret Key removed from UI; generated automatically on save
# Add optional fields only in full mode # Add optional fields only in full mode
if self.mode == "full": if self.mode == "full":
# Webhook Base URL # Webhook Base URL
@ -333,36 +381,43 @@ class ConfigScreen(Screen):
input_widget = Input( input_widget = Input(
placeholder="https://your-domain.com", placeholder="https://your-domain.com",
value=current_value, value=current_value,
id="input-webhook_base_url" id="input-webhook_base_url",
) )
yield input_widget yield input_widget
self.inputs["webhook_base_url"] = input_widget self.inputs["webhook_base_url"] = input_widget
yield Static(" ") yield Static(" ")
# Langflow Public URL # Langflow Public URL
yield Label("Langflow Public URL") yield Label("Langflow Public URL")
current_value = getattr(self.env_manager.config, "langflow_public_url", "") current_value = getattr(self.env_manager.config, "langflow_public_url", "")
input_widget = Input( input_widget = Input(
placeholder="http://localhost:7860", placeholder="http://localhost:7860",
value=current_value, value=current_value,
id="input-langflow_public_url" id="input-langflow_public_url",
) )
yield input_widget yield input_widget
self.inputs["langflow_public_url"] = input_widget self.inputs["langflow_public_url"] = input_widget
yield Static(" ") yield Static(" ")
def _create_field(self, field_name: str, display_name: str, placeholder: str, can_generate: bool, required: bool = False) -> ComposeResult: def _create_field(
self,
field_name: str,
display_name: str,
placeholder: str,
can_generate: bool,
required: bool = False,
) -> ComposeResult:
"""Create a single form field.""" """Create a single form field."""
# Create label # Create label
label_text = f"{display_name}" label_text = f"{display_name}"
if required: if required:
label_text += " *" label_text += " *"
yield Label(label_text) yield Label(label_text)
# Get current value # Get current value
current_value = getattr(self.env_manager.config, field_name, "") current_value = getattr(self.env_manager.config, field_name, "")
# Create input with appropriate validator # Create input with appropriate validator
if field_name == "openai_api_key": if field_name == "openai_api_key":
input_widget = Input( input_widget = Input(
@ -370,35 +425,33 @@ class ConfigScreen(Screen):
value=current_value, value=current_value,
password=True, password=True,
validators=[OpenAIKeyValidator()], validators=[OpenAIKeyValidator()],
id=f"input-{field_name}" id=f"input-{field_name}",
) )
elif field_name == "openrag_documents_paths": elif field_name == "openrag_documents_paths":
input_widget = Input( input_widget = Input(
placeholder=placeholder, placeholder=placeholder,
value=current_value, value=current_value,
validators=[DocumentsPathValidator()], validators=[DocumentsPathValidator()],
id=f"input-{field_name}" id=f"input-{field_name}",
) )
elif "password" in field_name or "secret" in field_name: elif "password" in field_name or "secret" in field_name:
input_widget = Input( input_widget = Input(
placeholder=placeholder, placeholder=placeholder,
value=current_value, value=current_value,
password=True, password=True,
id=f"input-{field_name}" id=f"input-{field_name}",
) )
else: else:
input_widget = Input( input_widget = Input(
placeholder=placeholder, placeholder=placeholder, value=current_value, id=f"input-{field_name}"
value=current_value,
id=f"input-{field_name}"
) )
yield input_widget yield input_widget
self.inputs[field_name] = input_widget self.inputs[field_name] = input_widget
# Add spacing # Add spacing
yield Static(" ") yield Static(" ")
def on_mount(self) -> None: def on_mount(self) -> None:
"""Initialize the screen when mounted.""" """Initialize the screen when mounted."""
# Focus the first input field # Focus the first input field
@ -409,7 +462,7 @@ class ConfigScreen(Screen):
inputs[0].focus() inputs[0].focus()
except Exception: except Exception:
pass pass
def on_button_pressed(self, event: Button.Pressed) -> None: def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses.""" """Handle button presses."""
if event.button.id == "generate-btn": if event.button.id == "generate-btn":
@ -420,43 +473,47 @@ class ConfigScreen(Screen):
self.action_back() self.action_back()
elif event.button.id == "pick-docs-btn": elif event.button.id == "pick-docs-btn":
self.action_pick_documents_path() self.action_pick_documents_path()
def action_generate(self) -> None: def action_generate(self) -> None:
"""Generate secure passwords for admin accounts.""" """Generate secure passwords for admin accounts."""
self.env_manager.setup_secure_defaults() self.env_manager.setup_secure_defaults()
# Update input fields with generated values # Update input fields with generated values
for field_name, input_widget in self.inputs.items(): for field_name, input_widget in self.inputs.items():
if field_name in ["opensearch_password", "langflow_superuser_password"]: if field_name in ["opensearch_password", "langflow_superuser_password"]:
new_value = getattr(self.env_manager.config, field_name) new_value = getattr(self.env_manager.config, field_name)
input_widget.value = new_value input_widget.value = new_value
self.notify("Generated secure passwords", severity="information") self.notify("Generated secure passwords", severity="information")
def action_save(self) -> None: def action_save(self) -> None:
"""Save the configuration.""" """Save the configuration."""
# Update config from input fields # Update config from input fields
for field_name, input_widget in self.inputs.items(): for field_name, input_widget in self.inputs.items():
setattr(self.env_manager.config, field_name, input_widget.value) setattr(self.env_manager.config, field_name, input_widget.value)
# Validate the configuration # Validate the configuration
if not self.env_manager.validate_config(self.mode): if not self.env_manager.validate_config(self.mode):
error_messages = [] error_messages = []
for field, error in self.env_manager.config.validation_errors.items(): for field, error in self.env_manager.config.validation_errors.items():
error_messages.append(f"{field}: {error}") error_messages.append(f"{field}: {error}")
self.notify(f"Validation failed:\n" + "\n".join(error_messages[:3]), severity="error") self.notify(
f"Validation failed:\n" + "\n".join(error_messages[:3]),
severity="error",
)
return return
# Save to file # Save to file
if self.env_manager.save_env_file(): if self.env_manager.save_env_file():
self.notify("Configuration saved successfully!", severity="information") self.notify("Configuration saved successfully!", severity="information")
# Switch to monitor screen # Switch to monitor screen
from .monitor import MonitorScreen from .monitor import MonitorScreen
self.app.push_screen(MonitorScreen()) self.app.push_screen(MonitorScreen())
else: else:
self.notify("Failed to save configuration", severity="error") self.notify("Failed to save configuration", severity="error")
def action_back(self) -> None: def action_back(self) -> None:
"""Go back to welcome screen.""" """Go back to welcome screen."""
self.app.pop_screen() self.app.pop_screen()
@ -465,6 +522,7 @@ class ConfigScreen(Screen):
"""Open textual-fspicker to select a path and append it to the input.""" """Open textual-fspicker to select a path and append it to the input."""
try: try:
import importlib import importlib
fsp = importlib.import_module("textual_fspicker") fsp = importlib.import_module("textual_fspicker")
except Exception: except Exception:
self.notify("textual-fspicker not available", severity="warning") self.notify("textual-fspicker not available", severity="warning")
@ -479,9 +537,13 @@ class ConfigScreen(Screen):
start = Path(first).expanduser() start = Path(first).expanduser()
# Prefer SelectDirectory for directories; fallback to FileOpen # Prefer SelectDirectory for directories; fallback to FileOpen
PickerClass = getattr(fsp, "SelectDirectory", None) or getattr(fsp, "FileOpen", None) PickerClass = getattr(fsp, "SelectDirectory", None) or getattr(
fsp, "FileOpen", None
)
if PickerClass is None: if PickerClass is None:
self.notify("No compatible picker found in textual-fspicker", severity="warning") self.notify(
"No compatible picker found in textual-fspicker", severity="warning"
)
return return
try: try:
picker = PickerClass(location=start) picker = PickerClass(location=start)
@ -523,7 +585,7 @@ class ConfigScreen(Screen):
pass pass
except Exception: except Exception:
pass pass
def on_input_changed(self, event: Input.Changed) -> None: def on_input_changed(self, event: Input.Changed) -> None:
"""Handle input changes for real-time validation feedback.""" """Handle input changes for real-time validation feedback."""
# This will trigger validation display in real-time # This will trigger validation display in real-time

View file

@ -18,7 +18,7 @@ from ..managers.container_manager import ContainerManager
class DiagnosticsScreen(Screen): class DiagnosticsScreen(Screen):
"""Diagnostics screen for debugging OpenRAG.""" """Diagnostics screen for debugging OpenRAG."""
CSS = """ CSS = """
#diagnostics-log { #diagnostics-log {
border: solid $accent; border: solid $accent;
@ -40,20 +40,20 @@ class DiagnosticsScreen(Screen):
text-align: center; text-align: center;
} }
""" """
BINDINGS = [ BINDINGS = [
("escape", "back", "Back"), ("escape", "back", "Back"),
("r", "refresh", "Refresh"), ("r", "refresh", "Refresh"),
("ctrl+c", "copy", "Copy to Clipboard"), ("ctrl+c", "copy", "Copy to Clipboard"),
("ctrl+s", "save", "Save to File"), ("ctrl+s", "save", "Save to File"),
] ]
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.container_manager = ContainerManager() self.container_manager = ContainerManager()
self._logger = logging.getLogger("openrag.diagnostics") self._logger = logging.getLogger("openrag.diagnostics")
self._status_timer = None self._status_timer = None
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
"""Create the diagnostics screen layout.""" """Create the diagnostics screen layout."""
yield Header() yield Header()
@ -66,24 +66,24 @@ class DiagnosticsScreen(Screen):
yield Button("Copy to Clipboard", variant="default", id="copy-btn") yield Button("Copy to Clipboard", variant="default", id="copy-btn")
yield Button("Save to File", variant="default", id="save-btn") yield Button("Save to File", variant="default", id="save-btn")
yield Button("Back", variant="default", id="back-btn") yield Button("Back", variant="default", id="back-btn")
# Status indicator for copy/save operations # Status indicator for copy/save operations
yield Static("", id="copy-status", classes="copy-indicator") yield Static("", id="copy-status", classes="copy-indicator")
with ScrollableContainer(id="diagnostics-scroll"): with ScrollableContainer(id="diagnostics-scroll"):
yield Log(id="diagnostics-log", highlight=True) yield Log(id="diagnostics-log", highlight=True)
yield Footer() yield Footer()
def on_mount(self) -> None: def on_mount(self) -> None:
"""Initialize the screen.""" """Initialize the screen."""
self.run_diagnostics() self.run_diagnostics()
# Focus the first button (refresh-btn) # Focus the first button (refresh-btn)
try: try:
self.query_one("#refresh-btn").focus() self.query_one("#refresh-btn").focus()
except Exception: except Exception:
pass pass
def on_button_pressed(self, event: Button.Pressed) -> None: def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses.""" """Handle button presses."""
if event.button.id == "refresh-btn": if event.button.id == "refresh-btn":
@ -98,25 +98,26 @@ class DiagnosticsScreen(Screen):
self.save_to_file() self.save_to_file()
elif event.button.id == "back-btn": elif event.button.id == "back-btn":
self.action_back() self.action_back()
def action_refresh(self) -> None: def action_refresh(self) -> None:
"""Refresh diagnostics.""" """Refresh diagnostics."""
self.run_diagnostics() self.run_diagnostics()
def action_copy(self) -> None: def action_copy(self) -> None:
"""Copy log content to clipboard (keyboard shortcut).""" """Copy log content to clipboard (keyboard shortcut)."""
self.copy_to_clipboard() self.copy_to_clipboard()
def copy_to_clipboard(self) -> None: def copy_to_clipboard(self) -> None:
"""Copy log content to clipboard.""" """Copy log content to clipboard."""
try: try:
log = self.query_one("#diagnostics-log", Log) log = self.query_one("#diagnostics-log", Log)
content = "\n".join(str(line) for line in log.lines) content = "\n".join(str(line) for line in log.lines)
status = self.query_one("#copy-status", Static) status = self.query_one("#copy-status", Static)
# Try to use pyperclip if available # Try to use pyperclip if available
try: try:
import pyperclip import pyperclip
pyperclip.copy(content) pyperclip.copy(content)
self.notify("Copied to clipboard", severity="information") self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard") status.update("✓ Content copied to clipboard")
@ -124,23 +125,19 @@ class DiagnosticsScreen(Screen):
return return
except ImportError: except ImportError:
pass pass
# Fallback to platform-specific clipboard commands # Fallback to platform-specific clipboard commands
import subprocess import subprocess
import platform import platform
system = platform.system() system = platform.system()
if system == "Darwin": # macOS if system == "Darwin": # macOS
process = subprocess.Popen( process = subprocess.Popen(["pbcopy"], stdin=subprocess.PIPE, text=True)
["pbcopy"], stdin=subprocess.PIPE, text=True
)
process.communicate(input=content) process.communicate(input=content)
self.notify("Copied to clipboard", severity="information") self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard") status.update("✓ Content copied to clipboard")
elif system == "Windows": elif system == "Windows":
process = subprocess.Popen( process = subprocess.Popen(["clip"], stdin=subprocess.PIPE, text=True)
["clip"], stdin=subprocess.PIPE, text=True
)
process.communicate(input=content) process.communicate(input=content)
self.notify("Copied to clipboard", severity="information") self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard") status.update("✓ Content copied to clipboard")
@ -150,7 +147,7 @@ class DiagnosticsScreen(Screen):
process = subprocess.Popen( process = subprocess.Popen(
["xclip", "-selection", "clipboard"], ["xclip", "-selection", "clipboard"],
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
text=True text=True,
) )
process.communicate(input=content) process.communicate(input=content)
self.notify("Copied to clipboard", severity="information") self.notify("Copied to clipboard", severity="information")
@ -160,65 +157,78 @@ class DiagnosticsScreen(Screen):
process = subprocess.Popen( process = subprocess.Popen(
["xsel", "--clipboard", "--input"], ["xsel", "--clipboard", "--input"],
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
text=True text=True,
) )
process.communicate(input=content) process.communicate(input=content)
self.notify("Copied to clipboard", severity="information") self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard") status.update("✓ Content copied to clipboard")
except FileNotFoundError: except FileNotFoundError:
self.notify("Clipboard utilities not found. Install xclip or xsel.", severity="error") self.notify(
status.update("❌ Clipboard utilities not found. Install xclip or xsel.") "Clipboard utilities not found. Install xclip or xsel.",
severity="error",
)
status.update(
"❌ Clipboard utilities not found. Install xclip or xsel."
)
else: else:
self.notify("Clipboard not supported on this platform", severity="error") self.notify(
"Clipboard not supported on this platform", severity="error"
)
status.update("❌ Clipboard not supported on this platform") status.update("❌ Clipboard not supported on this platform")
self._hide_status_after_delay(status) self._hide_status_after_delay(status)
except Exception as e: except Exception as e:
self.notify(f"Failed to copy to clipboard: {e}", severity="error") self.notify(f"Failed to copy to clipboard: {e}", severity="error")
status = self.query_one("#copy-status", Static) status = self.query_one("#copy-status", Static)
status.update(f"❌ Failed to copy: {e}") status.update(f"❌ Failed to copy: {e}")
self._hide_status_after_delay(status) self._hide_status_after_delay(status)
def _hide_status_after_delay(self, status_widget: Static, delay: float = 3.0) -> None: def _hide_status_after_delay(
self, status_widget: Static, delay: float = 3.0
) -> None:
"""Hide the status message after a delay.""" """Hide the status message after a delay."""
# Cancel any existing timer # Cancel any existing timer
if self._status_timer: if self._status_timer:
self._status_timer.cancel() self._status_timer.cancel()
# Create and run the timer task # Create and run the timer task
self._status_timer = asyncio.create_task(self._clear_status_after_delay(status_widget, delay)) self._status_timer = asyncio.create_task(
self._clear_status_after_delay(status_widget, delay)
async def _clear_status_after_delay(self, status_widget: Static, delay: float) -> None: )
async def _clear_status_after_delay(
self, status_widget: Static, delay: float
) -> None:
"""Clear the status message after a delay.""" """Clear the status message after a delay."""
await asyncio.sleep(delay) await asyncio.sleep(delay)
status_widget.update("") status_widget.update("")
def action_save(self) -> None: def action_save(self) -> None:
"""Save log content to file (keyboard shortcut).""" """Save log content to file (keyboard shortcut)."""
self.save_to_file() self.save_to_file()
def save_to_file(self) -> None: def save_to_file(self) -> None:
"""Save log content to a file.""" """Save log content to a file."""
try: try:
log = self.query_one("#diagnostics-log", Log) log = self.query_one("#diagnostics-log", Log)
content = "\n".join(str(line) for line in log.lines) content = "\n".join(str(line) for line in log.lines)
status = self.query_one("#copy-status", Static) status = self.query_one("#copy-status", Static)
# Create logs directory if it doesn't exist # Create logs directory if it doesn't exist
logs_dir = Path("logs") logs_dir = Path("logs")
logs_dir.mkdir(exist_ok=True) logs_dir.mkdir(exist_ok=True)
# Create a timestamped filename # Create a timestamped filename
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = logs_dir / f"openrag_diagnostics_{timestamp}.txt" filename = logs_dir / f"openrag_diagnostics_{timestamp}.txt"
# Save to file # Save to file
with open(filename, "w") as f: with open(filename, "w") as f:
f.write(content) f.write(content)
self.notify(f"Saved to {filename}", severity="information") self.notify(f"Saved to {filename}", severity="information")
status.update(f"✓ Saved to {filename}") status.update(f"✓ Saved to {filename}")
# Log the save operation # Log the save operation
self._logger.info(f"Diagnostics saved to {filename}") self._logger.info(f"Diagnostics saved to {filename}")
self._hide_status_after_delay(status) self._hide_status_after_delay(status)
@ -226,55 +236,57 @@ class DiagnosticsScreen(Screen):
error_msg = f"Failed to save file: {e}" error_msg = f"Failed to save file: {e}"
self.notify(error_msg, severity="error") self.notify(error_msg, severity="error")
self._logger.error(error_msg) self._logger.error(error_msg)
status = self.query_one("#copy-status", Static) status = self.query_one("#copy-status", Static)
status.update(f"{error_msg}") status.update(f"{error_msg}")
self._hide_status_after_delay(status) self._hide_status_after_delay(status)
def action_back(self) -> None: def action_back(self) -> None:
"""Go back to previous screen.""" """Go back to previous screen."""
self.app.pop_screen() self.app.pop_screen()
def _get_system_info(self) -> Text: def _get_system_info(self) -> Text:
"""Get system information text.""" """Get system information text."""
info_text = Text() info_text = Text()
runtime_info = self.container_manager.get_runtime_info() runtime_info = self.container_manager.get_runtime_info()
info_text.append("Container Runtime Information\n", style="bold") info_text.append("Container Runtime Information\n", style="bold")
info_text.append("=" * 30 + "\n") info_text.append("=" * 30 + "\n")
info_text.append(f"Type: {runtime_info.runtime_type.value}\n") info_text.append(f"Type: {runtime_info.runtime_type.value}\n")
info_text.append(f"Compose Command: {' '.join(runtime_info.compose_command)}\n") info_text.append(f"Compose Command: {' '.join(runtime_info.compose_command)}\n")
info_text.append(f"Runtime Command: {' '.join(runtime_info.runtime_command)}\n") info_text.append(f"Runtime Command: {' '.join(runtime_info.runtime_command)}\n")
if runtime_info.version: if runtime_info.version:
info_text.append(f"Version: {runtime_info.version}\n") info_text.append(f"Version: {runtime_info.version}\n")
return info_text return info_text
def run_diagnostics(self) -> None: def run_diagnostics(self) -> None:
"""Run all diagnostics.""" """Run all diagnostics."""
log = self.query_one("#diagnostics-log", Log) log = self.query_one("#diagnostics-log", Log)
log.clear() log.clear()
# System information # System information
system_info = self._get_system_info() system_info = self._get_system_info()
log.write(str(system_info)) log.write(str(system_info))
log.write("") log.write("")
# Run async diagnostics # Run async diagnostics
asyncio.create_task(self._run_async_diagnostics()) asyncio.create_task(self._run_async_diagnostics())
async def _run_async_diagnostics(self) -> None: async def _run_async_diagnostics(self) -> None:
"""Run asynchronous diagnostics.""" """Run asynchronous diagnostics."""
log = self.query_one("#diagnostics-log", Log) log = self.query_one("#diagnostics-log", Log)
# Check services # Check services
log.write("[bold green]Service Status[/bold green]") log.write("[bold green]Service Status[/bold green]")
services = await self.container_manager.get_service_status(force_refresh=True) services = await self.container_manager.get_service_status(force_refresh=True)
for name, info in services.items(): for name, info in services.items():
status_color = "green" if info.status == "running" else "red" status_color = "green" if info.status == "running" else "red"
log.write(f"[bold]{name}[/bold]: [{status_color}]{info.status.value}[/{status_color}]") log.write(
f"[bold]{name}[/bold]: [{status_color}]{info.status.value}[/{status_color}]"
)
if info.health: if info.health:
log.write(f" Health: {info.health}") log.write(f" Health: {info.health}")
if info.ports: if info.ports:
@ -282,40 +294,38 @@ class DiagnosticsScreen(Screen):
if info.image: if info.image:
log.write(f" Image: {info.image}") log.write(f" Image: {info.image}")
log.write("") log.write("")
# Check for Podman-specific issues # Check for Podman-specific issues
if self.container_manager.runtime_info.runtime_type.name == "PODMAN": if self.container_manager.runtime_info.runtime_type.name == "PODMAN":
await self.check_podman() await self.check_podman()
async def check_podman(self) -> None: async def check_podman(self) -> None:
"""Run Podman-specific diagnostics.""" """Run Podman-specific diagnostics."""
log = self.query_one("#diagnostics-log", Log) log = self.query_one("#diagnostics-log", Log)
log.write("[bold green]Podman Diagnostics[/bold green]") log.write("[bold green]Podman Diagnostics[/bold green]")
# Check if using Podman # Check if using Podman
if self.container_manager.runtime_info.runtime_type.name != "PODMAN": if self.container_manager.runtime_info.runtime_type.name != "PODMAN":
log.write("[yellow]Not using Podman[/yellow]") log.write("[yellow]Not using Podman[/yellow]")
return return
# Check Podman version # Check Podman version
cmd = ["podman", "--version"] cmd = ["podman", "--version"]
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
log.write(f"Podman version: {stdout.decode().strip()}") log.write(f"Podman version: {stdout.decode().strip()}")
else: else:
log.write(f"[red]Failed to get Podman version: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to get Podman version: {stderr.decode().strip()}[/red]"
)
# Check Podman containers # Check Podman containers
cmd = ["podman", "ps", "--all"] cmd = ["podman", "ps", "--all"]
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
@ -323,15 +333,17 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"): for line in stdout.decode().strip().split("\n"):
log.write(f" {line}") log.write(f" {line}")
else: else:
log.write(f"[red]Failed to list Podman containers: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to list Podman containers: {stderr.decode().strip()}[/red]"
)
# Check Podman compose # Check Podman compose
cmd = ["podman", "compose", "ps"] cmd = ["podman", "compose", "ps"]
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
cwd=self.container_manager.compose_file.parent cwd=self.container_manager.compose_file.parent,
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
@ -339,39 +351,39 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"): for line in stdout.decode().strip().split("\n"):
log.write(f" {line}") log.write(f" {line}")
else: else:
log.write(f"[red]Failed to list Podman compose services: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to list Podman compose services: {stderr.decode().strip()}[/red]"
)
log.write("") log.write("")
async def check_docker(self) -> None: async def check_docker(self) -> None:
"""Run Docker-specific diagnostics.""" """Run Docker-specific diagnostics."""
log = self.query_one("#diagnostics-log", Log) log = self.query_one("#diagnostics-log", Log)
log.write("[bold green]Docker Diagnostics[/bold green]") log.write("[bold green]Docker Diagnostics[/bold green]")
# Check if using Docker # Check if using Docker
if "DOCKER" not in self.container_manager.runtime_info.runtime_type.name: if "DOCKER" not in self.container_manager.runtime_info.runtime_type.name:
log.write("[yellow]Not using Docker[/yellow]") log.write("[yellow]Not using Docker[/yellow]")
return return
# Check Docker version # Check Docker version
cmd = ["docker", "--version"] cmd = ["docker", "--version"]
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
log.write(f"Docker version: {stdout.decode().strip()}") log.write(f"Docker version: {stdout.decode().strip()}")
else: else:
log.write(f"[red]Failed to get Docker version: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to get Docker version: {stderr.decode().strip()}[/red]"
)
# Check Docker containers # Check Docker containers
cmd = ["docker", "ps", "--all"] cmd = ["docker", "ps", "--all"]
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
@ -379,15 +391,17 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"): for line in stdout.decode().strip().split("\n"):
log.write(f" {line}") log.write(f" {line}")
else: else:
log.write(f"[red]Failed to list Docker containers: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to list Docker containers: {stderr.decode().strip()}[/red]"
)
# Check Docker compose # Check Docker compose
cmd = ["docker", "compose", "ps"] cmd = ["docker", "compose", "ps"]
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
cwd=self.container_manager.compose_file.parent cwd=self.container_manager.compose_file.parent,
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
@ -395,8 +409,11 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"): for line in stdout.decode().strip().split("\n"):
log.write(f" {line}") log.write(f" {line}")
else: else:
log.write(f"[red]Failed to list Docker compose services: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to list Docker compose services: {stderr.decode().strip()}[/red]"
)
log.write("") log.write("")
# Made with Bob # Made with Bob

View file

@ -13,7 +13,7 @@ from ..managers.container_manager import ContainerManager
class LogsScreen(Screen): class LogsScreen(Screen):
"""Logs viewing and monitoring screen.""" """Logs viewing and monitoring screen."""
BINDINGS = [ BINDINGS = [
("escape", "back", "Back"), ("escape", "back", "Back"),
("f", "follow", "Follow Logs"), ("f", "follow", "Follow Logs"),
@ -27,44 +27,50 @@ class LogsScreen(Screen):
("ctrl+u", "scroll_page_up", "Page Up"), ("ctrl+u", "scroll_page_up", "Page Up"),
("ctrl+f", "scroll_page_down", "Page Down"), ("ctrl+f", "scroll_page_down", "Page Down"),
] ]
def __init__(self, initial_service: str = "openrag-backend"): def __init__(self, initial_service: str = "openrag-backend"):
super().__init__() super().__init__()
self.container_manager = ContainerManager() self.container_manager = ContainerManager()
# Validate the initial service against available options # Validate the initial service against available options
valid_services = ["openrag-backend", "openrag-frontend", "opensearch", "langflow", "dashboards"] valid_services = [
"openrag-backend",
"openrag-frontend",
"opensearch",
"langflow",
"dashboards",
]
if initial_service not in valid_services: if initial_service not in valid_services:
initial_service = "openrag-backend" # fallback initial_service = "openrag-backend" # fallback
self.current_service = initial_service self.current_service = initial_service
self.logs_area = None self.logs_area = None
self.following = False self.following = False
self.follow_task = None self.follow_task = None
self.auto_scroll = True self.auto_scroll = True
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
"""Create the logs screen layout.""" """Create the logs screen layout."""
yield Container( yield Container(
Vertical( Vertical(
Static(f"Service Logs: {self.current_service}", id="logs-title"), Static(f"Service Logs: {self.current_service}", id="logs-title"),
self._create_logs_area(), self._create_logs_area(),
id="logs-content" id="logs-content",
), ),
id="main-container" id="main-container",
) )
yield Footer() yield Footer()
def _create_logs_area(self) -> TextArea: def _create_logs_area(self) -> TextArea:
"""Create the logs text area.""" """Create the logs text area."""
self.logs_area = TextArea( self.logs_area = TextArea(
text="Loading logs...", text="Loading logs...",
read_only=True, read_only=True,
show_line_numbers=False, show_line_numbers=False,
id="logs-area" id="logs-area",
) )
return self.logs_area return self.logs_area
async def on_mount(self) -> None: async def on_mount(self) -> None:
"""Initialize the screen when mounted.""" """Initialize the screen when mounted."""
# Set the correct service in the select widget after a brief delay # Set the correct service in the select widget after a brief delay
@ -72,34 +78,40 @@ class LogsScreen(Screen):
select = self.query_one("#service-select") select = self.query_one("#service-select")
# Set a default first, then set the desired value # Set a default first, then set the desired value
select.value = "openrag-backend" select.value = "openrag-backend"
if self.current_service in ["openrag-backend", "openrag-frontend", "opensearch", "langflow", "dashboards"]: if self.current_service in [
"openrag-backend",
"openrag-frontend",
"opensearch",
"langflow",
"dashboards",
]:
select.value = self.current_service select.value = self.current_service
except Exception as e: except Exception as e:
# If setting the service fails, just use the default # If setting the service fails, just use the default
pass pass
await self._load_logs() await self._load_logs()
# Focus the logs area since there are no buttons # Focus the logs area since there are no buttons
try: try:
self.logs_area.focus() self.logs_area.focus()
except Exception: except Exception:
pass pass
def on_unmount(self) -> None: def on_unmount(self) -> None:
"""Clean up when unmounting.""" """Clean up when unmounting."""
self._stop_following() self._stop_following()
async def _load_logs(self, lines: int = 200) -> None: async def _load_logs(self, lines: int = 200) -> None:
"""Load recent logs for the current service.""" """Load recent logs for the current service."""
if not self.container_manager.is_available(): if not self.container_manager.is_available():
self.logs_area.text = "No container runtime available" self.logs_area.text = "No container runtime available"
return return
success, logs = await self.container_manager.get_service_logs(self.current_service, lines) success, logs = await self.container_manager.get_service_logs(
self.current_service, lines
)
if success: if success:
self.logs_area.text = logs self.logs_area.text = logs
# Scroll to bottom if auto scroll is enabled # Scroll to bottom if auto scroll is enabled
@ -107,67 +119,71 @@ class LogsScreen(Screen):
self.logs_area.scroll_end() self.logs_area.scroll_end()
else: else:
self.logs_area.text = f"Failed to load logs: {logs}" self.logs_area.text = f"Failed to load logs: {logs}"
def _stop_following(self) -> None: def _stop_following(self) -> None:
"""Stop following logs.""" """Stop following logs."""
self.following = False self.following = False
if self.follow_task and not self.follow_task.is_finished: if self.follow_task and not self.follow_task.is_finished:
self.follow_task.cancel() self.follow_task.cancel()
# No button to update since we removed it # No button to update since we removed it
async def _follow_logs(self) -> None: async def _follow_logs(self) -> None:
"""Follow logs in real-time.""" """Follow logs in real-time."""
if not self.container_manager.is_available(): if not self.container_manager.is_available():
return return
try: try:
async for log_line in self.container_manager.follow_service_logs(self.current_service): async for log_line in self.container_manager.follow_service_logs(
self.current_service
):
if not self.following: if not self.following:
break break
# Append new line to logs area # Append new line to logs area
current_text = self.logs_area.text current_text = self.logs_area.text
new_text = current_text + "\n" + log_line new_text = current_text + "\n" + log_line
# Keep only last 1000 lines to prevent memory issues # Keep only last 1000 lines to prevent memory issues
lines = new_text.split('\n') lines = new_text.split("\n")
if len(lines) > 1000: if len(lines) > 1000:
lines = lines[-1000:] lines = lines[-1000:]
new_text = '\n'.join(lines) new_text = "\n".join(lines)
self.logs_area.text = new_text self.logs_area.text = new_text
# Scroll to bottom if auto scroll is enabled # Scroll to bottom if auto scroll is enabled
if self.auto_scroll: if self.auto_scroll:
self.logs_area.scroll_end() self.logs_area.scroll_end()
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:
if self.following: # Only show error if we're still supposed to be following if (
self.following
): # Only show error if we're still supposed to be following
self.notify(f"Error following logs: {e}", severity="error") self.notify(f"Error following logs: {e}", severity="error")
finally: finally:
self.following = False self.following = False
def action_refresh(self) -> None: def action_refresh(self) -> None:
"""Refresh logs.""" """Refresh logs."""
self._stop_following() self._stop_following()
self.run_worker(self._load_logs()) self.run_worker(self._load_logs())
def action_follow(self) -> None: def action_follow(self) -> None:
"""Toggle log following.""" """Toggle log following."""
if self.following: if self.following:
self._stop_following() self._stop_following()
else: else:
self.following = True self.following = True
# Start following # Start following
self.follow_task = self.run_worker(self._follow_logs(), exclusive=False) self.follow_task = self.run_worker(self._follow_logs(), exclusive=False)
def action_clear(self) -> None: def action_clear(self) -> None:
"""Clear the logs area.""" """Clear the logs area."""
self.logs_area.text = "" self.logs_area.text = ""
def action_toggle_auto_scroll(self) -> None: def action_toggle_auto_scroll(self) -> None:
"""Toggle auto scroll on/off.""" """Toggle auto scroll on/off."""
self.auto_scroll = not self.auto_scroll self.auto_scroll = not self.auto_scroll
@ -201,13 +217,13 @@ class LogsScreen(Screen):
def on_key(self, event) -> None: def on_key(self, event) -> None:
"""Handle key presses that might be intercepted by TextArea.""" """Handle key presses that might be intercepted by TextArea."""
key = event.key key = event.key
# Handle keys that TextArea might intercept # Handle keys that TextArea might intercept
if key == "ctrl+u": if key == "ctrl+u":
self.action_scroll_page_up() self.action_scroll_page_up()
event.prevent_default() event.prevent_default()
elif key == "ctrl+f": elif key == "ctrl+f":
self.action_scroll_page_down() self.action_scroll_page_down()
event.prevent_default() event.prevent_default()
elif key.upper() == "G": elif key.upper() == "G":
self.action_scroll_bottom() self.action_scroll_bottom()
@ -216,4 +232,4 @@ class LogsScreen(Screen):
def action_back(self) -> None: def action_back(self) -> None:
"""Go back to previous screen.""" """Go back to previous screen."""
self._stop_following() self._stop_following()
self.app.pop_screen() self.app.pop_screen()

View file

@ -23,7 +23,7 @@ from ..widgets.diagnostics_notification import notify_with_diagnostics
class MonitorScreen(Screen): class MonitorScreen(Screen):
"""Service monitoring and control screen.""" """Service monitoring and control screen."""
BINDINGS = [ BINDINGS = [
("escape", "back", "Back"), ("escape", "back", "Back"),
("r", "refresh", "Refresh"), ("r", "refresh", "Refresh"),
@ -35,7 +35,7 @@ class MonitorScreen(Screen):
("j", "cursor_down", "Move Down"), ("j", "cursor_down", "Move Down"),
("k", "cursor_up", "Move Up"), ("k", "cursor_up", "Move Up"),
] ]
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.container_manager = ContainerManager() self.container_manager = ContainerManager()
@ -47,14 +47,14 @@ class MonitorScreen(Screen):
self._follow_task = None self._follow_task = None
self._follow_service = None self._follow_service = None
self._logs_buffer = [] self._logs_buffer = []
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
"""Create the monitoring screen layout.""" """Create the monitoring screen layout."""
# Just show the services content directly (no header, no tabs) # Just show the services content directly (no header, no tabs)
yield from self._create_services_tab() yield from self._create_services_tab()
yield Footer() yield Footer()
def _create_services_tab(self) -> ComposeResult: def _create_services_tab(self) -> ComposeResult:
"""Create the services monitoring tab.""" """Create the services monitoring tab."""
# Current mode indicator + toggle # Current mode indicator + toggle
@ -75,69 +75,73 @@ class MonitorScreen(Screen):
yield Horizontal(id="services-controls", classes="button-row") yield Horizontal(id="services-controls", classes="button-row")
# Create services table with image + digest info # Create services table with image + digest info
self.services_table = DataTable(id="services-table") self.services_table = DataTable(id="services-table")
self.services_table.add_columns("Service", "Status", "Health", "Ports", "Image", "Digest") self.services_table.add_columns(
"Service", "Status", "Health", "Ports", "Image", "Digest"
)
yield self.services_table yield self.services_table
def _get_runtime_status(self) -> Text: def _get_runtime_status(self) -> Text:
"""Get container runtime status text.""" """Get container runtime status text."""
status_text = Text() status_text = Text()
if not self.container_manager.is_available(): if not self.container_manager.is_available():
status_text.append("WARNING: No container runtime available\n", style="bold red") status_text.append(
status_text.append("Please install Docker or Podman to continue.\n", style="dim") "WARNING: No container runtime available\n", style="bold red"
)
status_text.append(
"Please install Docker or Podman to continue.\n", style="dim"
)
return status_text return status_text
runtime_info = self.container_manager.get_runtime_info() runtime_info = self.container_manager.get_runtime_info()
if runtime_info.runtime_type == RuntimeType.DOCKER: if runtime_info.runtime_type == RuntimeType.DOCKER:
status_text.append("Docker Runtime\n", style="bold blue") status_text.append("Docker Runtime\n", style="bold blue")
elif runtime_info.runtime_type == RuntimeType.PODMAN: elif runtime_info.runtime_type == RuntimeType.PODMAN:
status_text.append("Podman Runtime\n", style="bold purple") status_text.append("Podman Runtime\n", style="bold purple")
else: else:
status_text.append("Container Runtime\n", style="bold green") status_text.append("Container Runtime\n", style="bold green")
if runtime_info.version: if runtime_info.version:
status_text.append(f"Version: {runtime_info.version}\n", style="dim") status_text.append(f"Version: {runtime_info.version}\n", style="dim")
# Check Podman macOS memory if applicable # Check Podman macOS memory if applicable
if runtime_info.runtime_type == RuntimeType.PODMAN: if runtime_info.runtime_type == RuntimeType.PODMAN:
is_sufficient, message = self.container_manager.check_podman_macos_memory() is_sufficient, message = self.container_manager.check_podman_macos_memory()
if not is_sufficient: if not is_sufficient:
status_text.append(f"WARNING: {message}\n", style="bold yellow") status_text.append(f"WARNING: {message}\n", style="bold yellow")
return status_text return status_text
async def on_mount(self) -> None: async def on_mount(self) -> None:
"""Initialize the screen when mounted.""" """Initialize the screen when mounted."""
await self._refresh_services() await self._refresh_services()
# Set up auto-refresh every 5 seconds # Set up auto-refresh every 5 seconds
self.refresh_timer = self.set_interval(5.0, self._auto_refresh) self.refresh_timer = self.set_interval(5.0, self._auto_refresh)
# Focus the services table # Focus the services table
try: try:
self.services_table.focus() self.services_table.focus()
except Exception: except Exception:
pass pass
def on_unmount(self) -> None: def on_unmount(self) -> None:
"""Clean up when unmounting.""" """Clean up when unmounting."""
if self.refresh_timer: if self.refresh_timer:
self.refresh_timer.stop() self.refresh_timer.stop()
# Stop following logs if running # Stop following logs if running
self._stop_follow() self._stop_follow()
async def on_screen_resume(self) -> None: async def on_screen_resume(self) -> None:
"""Called when the screen is resumed (e.g., after a modal is closed).""" """Called when the screen is resumed (e.g., after a modal is closed)."""
# Refresh services when returning from a modal # Refresh services when returning from a modal
await self._refresh_services() await self._refresh_services()
async def _refresh_services(self) -> None: async def _refresh_services(self) -> None:
"""Refresh the services table.""" """Refresh the services table."""
if not self.container_manager.is_available(): if not self.container_manager.is_available():
return return
services = await self.container_manager.get_service_status(force_refresh=True) services = await self.container_manager.get_service_status(force_refresh=True)
# Collect images actually reported by running/stopped containers so names match runtime # Collect images actually reported by running/stopped containers so names match runtime
images_set = set() images_set = set()
@ -147,7 +151,9 @@ class MonitorScreen(Screen):
images_set.add(img) images_set.add(img)
# Ensure compose-declared images are also shown (e.g., langflow when stopped) # Ensure compose-declared images are also shown (e.g., langflow when stopped)
try: try:
for img in self.container_manager._parse_compose_images(): # best-effort, no YAML dep for img in (
self.container_manager._parse_compose_images()
): # best-effort, no YAML dep
if img: if img:
images_set.add(img) images_set.add(img)
except Exception: except Exception:
@ -155,23 +161,23 @@ class MonitorScreen(Screen):
images = list(images_set) images = list(images_set)
# Lookup digests/IDs for these image names # Lookup digests/IDs for these image names
digest_map = await self.container_manager.get_images_digests(images) digest_map = await self.container_manager.get_images_digests(images)
# Clear existing rows # Clear existing rows
self.services_table.clear() self.services_table.clear()
if self.images_table: if self.images_table:
self.images_table.clear() self.images_table.clear()
# Add service rows # Add service rows
for service_name, service_info in services.items(): for service_name, service_info in services.items():
status_style = self._get_status_style(service_info.status) status_style = self._get_status_style(service_info.status)
self.services_table.add_row( self.services_table.add_row(
service_info.name, service_info.name,
Text(service_info.status.value, style=status_style), Text(service_info.status.value, style=status_style),
service_info.health or "N/A", service_info.health or "N/A",
", ".join(service_info.ports) if service_info.ports else "N/A", ", ".join(service_info.ports) if service_info.ports else "N/A",
service_info.image or "N/A", service_info.image or "N/A",
digest_map.get(service_info.image or "", "-") digest_map.get(service_info.image or "", "-"),
) )
# Populate images table (unique images as reported by runtime) # Populate images table (unique images as reported by runtime)
if self.images_table: if self.images_table:
@ -181,7 +187,7 @@ class MonitorScreen(Screen):
self._update_controls(list(services.values())) self._update_controls(list(services.values()))
# Update mode indicator # Update mode indicator
self._update_mode_row() self._update_mode_row()
def _get_status_style(self, status: ServiceStatus) -> str: def _get_status_style(self, status: ServiceStatus) -> str:
"""Get the Rich style for a service status.""" """Get the Rich style for a service status."""
status_styles = { status_styles = {
@ -191,20 +197,20 @@ class MonitorScreen(Screen):
ServiceStatus.STOPPING: "bold yellow", ServiceStatus.STOPPING: "bold yellow",
ServiceStatus.ERROR: "bold red", ServiceStatus.ERROR: "bold red",
ServiceStatus.MISSING: "dim", ServiceStatus.MISSING: "dim",
ServiceStatus.UNKNOWN: "dim" ServiceStatus.UNKNOWN: "dim",
} }
return status_styles.get(status, "white") return status_styles.get(status, "white")
async def _auto_refresh(self) -> None: async def _auto_refresh(self) -> None:
"""Auto-refresh services if not in operation.""" """Auto-refresh services if not in operation."""
if not self.operation_in_progress: if not self.operation_in_progress:
await self._refresh_services() await self._refresh_services()
def on_button_pressed(self, event: Button.Pressed) -> None: def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses.""" """Handle button presses."""
button_id = event.button.id or "" button_id = event.button.id or ""
button_label = event.button.label or "" button_label = event.button.label or ""
# Use button ID prefixes to determine action, ignoring any random suffix # Use button ID prefixes to determine action, ignoring any random suffix
if button_id.startswith("start-btn"): if button_id.startswith("start-btn"):
self.run_worker(self._start_services()) self.run_worker(self._start_services())
@ -228,18 +234,18 @@ class MonitorScreen(Screen):
"logs-backend": "openrag-backend", "logs-backend": "openrag-backend",
"logs-frontend": "openrag-frontend", "logs-frontend": "openrag-frontend",
"logs-opensearch": "opensearch", "logs-opensearch": "opensearch",
"logs-langflow": "langflow" "logs-langflow": "langflow",
} }
# Extract the base button ID (without any suffix) # Extract the base button ID (without any suffix)
button_base_id = button_id.split("-")[0] + "-" + button_id.split("-")[1] button_base_id = button_id.split("-")[0] + "-" + button_id.split("-")[1]
service_name = service_mapping.get(button_base_id) service_name = service_mapping.get(button_base_id)
if service_name: if service_name:
# Load recent logs then start following # Load recent logs then start following
self.run_worker(self._show_logs(service_name)) self.run_worker(self._show_logs(service_name))
self._start_follow(service_name) self._start_follow(service_name)
async def _start_services(self, cpu_mode: bool = False) -> None: async def _start_services(self, cpu_mode: bool = False) -> None:
"""Start services with progress updates.""" """Start services with progress updates."""
self.operation_in_progress = True self.operation_in_progress = True
@ -249,12 +255,12 @@ class MonitorScreen(Screen):
modal = CommandOutputModal( modal = CommandOutputModal(
"Starting Services", "Starting Services",
command_generator, command_generator,
on_complete=None # We'll refresh in on_screen_resume instead on_complete=None, # We'll refresh in on_screen_resume instead
) )
self.app.push_screen(modal) self.app.push_screen(modal)
finally: finally:
self.operation_in_progress = False self.operation_in_progress = False
async def _stop_services(self) -> None: async def _stop_services(self) -> None:
"""Stop services with progress updates.""" """Stop services with progress updates."""
self.operation_in_progress = True self.operation_in_progress = True
@ -264,12 +270,12 @@ class MonitorScreen(Screen):
modal = CommandOutputModal( modal = CommandOutputModal(
"Stopping Services", "Stopping Services",
command_generator, command_generator,
on_complete=None # We'll refresh in on_screen_resume instead on_complete=None, # We'll refresh in on_screen_resume instead
) )
self.app.push_screen(modal) self.app.push_screen(modal)
finally: finally:
self.operation_in_progress = False self.operation_in_progress = False
async def _restart_services(self) -> None: async def _restart_services(self) -> None:
"""Restart services with progress updates.""" """Restart services with progress updates."""
self.operation_in_progress = True self.operation_in_progress = True
@ -279,12 +285,12 @@ class MonitorScreen(Screen):
modal = CommandOutputModal( modal = CommandOutputModal(
"Restarting Services", "Restarting Services",
command_generator, command_generator,
on_complete=None # We'll refresh in on_screen_resume instead on_complete=None, # We'll refresh in on_screen_resume instead
) )
self.app.push_screen(modal) self.app.push_screen(modal)
finally: finally:
self.operation_in_progress = False self.operation_in_progress = False
async def _upgrade_services(self) -> None: async def _upgrade_services(self) -> None:
"""Upgrade services with progress updates.""" """Upgrade services with progress updates."""
self.operation_in_progress = True self.operation_in_progress = True
@ -294,12 +300,12 @@ class MonitorScreen(Screen):
modal = CommandOutputModal( modal = CommandOutputModal(
"Upgrading Services", "Upgrading Services",
command_generator, command_generator,
on_complete=None # We'll refresh in on_screen_resume instead on_complete=None, # We'll refresh in on_screen_resume instead
) )
self.app.push_screen(modal) self.app.push_screen(modal)
finally: finally:
self.operation_in_progress = False self.operation_in_progress = False
async def _reset_services(self) -> None: async def _reset_services(self) -> None:
"""Reset services with progress updates.""" """Reset services with progress updates."""
self.operation_in_progress = True self.operation_in_progress = True
@ -309,17 +315,17 @@ class MonitorScreen(Screen):
modal = CommandOutputModal( modal = CommandOutputModal(
"Resetting Services", "Resetting Services",
command_generator, command_generator,
on_complete=None # We'll refresh in on_screen_resume instead on_complete=None, # We'll refresh in on_screen_resume instead
) )
self.app.push_screen(modal) self.app.push_screen(modal)
finally: finally:
self.operation_in_progress = False self.operation_in_progress = False
def _strip_ansi_codes(self, text: str) -> str: def _strip_ansi_codes(self, text: str) -> str:
"""Strip ANSI escape sequences from text.""" """Strip ANSI escape sequences from text."""
ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
return ansi_escape.sub('', text) return ansi_escape.sub("", text)
async def _show_logs(self, service_name: str) -> None: async def _show_logs(self, service_name: str) -> None:
"""Show logs for a service.""" """Show logs for a service."""
success, logs = await self.container_manager.get_service_logs(service_name) success, logs = await self.container_manager.get_service_logs(service_name)
@ -346,7 +352,7 @@ class MonitorScreen(Screen):
notify_with_diagnostics( notify_with_diagnostics(
self.app, self.app,
f"Failed to get logs for {service_name}: {logs}", f"Failed to get logs for {service_name}: {logs}",
severity="error" severity="error",
) )
def _stop_follow(self) -> None: def _stop_follow(self) -> None:
@ -391,11 +397,9 @@ class MonitorScreen(Screen):
pass pass
except Exception as e: except Exception as e:
notify_with_diagnostics( notify_with_diagnostics(
self.app, self.app, f"Error following logs: {e}", severity="error"
f"Error following logs: {e}",
severity="error"
) )
def action_refresh(self) -> None: def action_refresh(self) -> None:
"""Refresh services manually.""" """Refresh services manually."""
self.run_worker(self._refresh_services()) self.run_worker(self._refresh_services())
@ -431,14 +435,15 @@ class MonitorScreen(Screen):
try: try:
current = getattr(self.container_manager, "use_cpu_compose", True) current = getattr(self.container_manager, "use_cpu_compose", True)
self.container_manager.use_cpu_compose = not current self.container_manager.use_cpu_compose = not current
self.notify("Switched to GPU compose" if not current else "Switched to CPU compose", severity="information") self.notify(
"Switched to GPU compose" if not current else "Switched to CPU compose",
severity="information",
)
self._update_mode_row() self._update_mode_row()
self.action_refresh() self.action_refresh()
except Exception as e: except Exception as e:
notify_with_diagnostics( notify_with_diagnostics(
self.app, self.app, f"Failed to toggle mode: {e}", severity="error"
f"Failed to toggle mode: {e}",
severity="error"
) )
def _update_controls(self, services: list[ServiceInfo]) -> None: def _update_controls(self, services: list[ServiceInfo]) -> None:
@ -446,83 +451,93 @@ class MonitorScreen(Screen):
try: try:
# Get the controls container # Get the controls container
controls = self.query_one("#services-controls", Horizontal) controls = self.query_one("#services-controls", Horizontal)
# Check if any services are running # Check if any services are running
any_running = any(s.status == ServiceStatus.RUNNING for s in services) any_running = any(s.status == ServiceStatus.RUNNING for s in services)
# Clear existing buttons by removing all children # Clear existing buttons by removing all children
controls.remove_children() controls.remove_children()
# Use a single ID for each button type, but make them unique with a suffix # Use a single ID for each button type, but make them unique with a suffix
# This ensures we don't create duplicate IDs across refreshes # This ensures we don't create duplicate IDs across refreshes
import random import random
suffix = f"-{random.randint(10000, 99999)}" suffix = f"-{random.randint(10000, 99999)}"
# Add appropriate buttons based on service state # Add appropriate buttons based on service state
if any_running: if any_running:
# When services are running, show stop and restart # When services are running, show stop and restart
controls.mount(Button("Stop Services", variant="error", id=f"stop-btn{suffix}")) controls.mount(
controls.mount(Button("Restart", variant="primary", id=f"restart-btn{suffix}")) Button("Stop Services", variant="error", id=f"stop-btn{suffix}")
)
controls.mount(
Button("Restart", variant="primary", id=f"restart-btn{suffix}")
)
else: else:
# When services are not running, show start # When services are not running, show start
controls.mount(Button("Start Services", variant="success", id=f"start-btn{suffix}")) controls.mount(
Button("Start Services", variant="success", id=f"start-btn{suffix}")
)
# Always show upgrade and reset buttons # Always show upgrade and reset buttons
controls.mount(Button("Upgrade", variant="warning", id=f"upgrade-btn{suffix}")) controls.mount(
Button("Upgrade", variant="warning", id=f"upgrade-btn{suffix}")
)
controls.mount(Button("Reset", variant="error", id=f"reset-btn{suffix}")) controls.mount(Button("Reset", variant="error", id=f"reset-btn{suffix}"))
except Exception as e: except Exception as e:
notify_with_diagnostics( notify_with_diagnostics(
self.app, self.app, f"Error updating controls: {e}", severity="error"
f"Error updating controls: {e}",
severity="error"
) )
def action_back(self) -> None: def action_back(self) -> None:
"""Go back to previous screen.""" """Go back to previous screen."""
self.app.pop_screen() self.app.pop_screen()
def action_start(self) -> None: def action_start(self) -> None:
"""Start services.""" """Start services."""
self.run_worker(self._start_services()) self.run_worker(self._start_services())
def action_stop(self) -> None: def action_stop(self) -> None:
"""Stop services.""" """Stop services."""
self.run_worker(self._stop_services()) self.run_worker(self._stop_services())
def action_upgrade(self) -> None: def action_upgrade(self) -> None:
"""Upgrade services.""" """Upgrade services."""
self.run_worker(self._upgrade_services()) self.run_worker(self._upgrade_services())
def action_reset(self) -> None: def action_reset(self) -> None:
"""Reset services.""" """Reset services."""
self.run_worker(self._reset_services()) self.run_worker(self._reset_services())
def action_logs(self) -> None: def action_logs(self) -> None:
"""View logs for the selected service.""" """View logs for the selected service."""
try: try:
# Get the currently focused row in the services table # Get the currently focused row in the services table
table = self.query_one("#services-table", DataTable) table = self.query_one("#services-table", DataTable)
if table.cursor_row is not None and table.cursor_row >= 0: if table.cursor_row is not None and table.cursor_row >= 0:
# Get the service name from the first column of the selected row # Get the service name from the first column of the selected row
row_data = table.get_row_at(table.cursor_row) row_data = table.get_row_at(table.cursor_row)
if row_data: if row_data:
service_name = str(row_data[0]) # First column is service name service_name = str(row_data[0]) # First column is service name
# Map display names to actual service names # Map display names to actual service names
service_mapping = { service_mapping = {
"openrag-backend": "openrag-backend", "openrag-backend": "openrag-backend",
"openrag-frontend": "openrag-frontend", "openrag-frontend": "openrag-frontend",
"opensearch": "opensearch", "opensearch": "opensearch",
"langflow": "langflow", "langflow": "langflow",
"dashboards": "dashboards" "dashboards": "dashboards",
} }
actual_service_name = service_mapping.get(service_name, service_name) actual_service_name = service_mapping.get(
service_name, service_name
)
# Push the logs screen with the selected service # Push the logs screen with the selected service
from .logs import LogsScreen from .logs import LogsScreen
logs_screen = LogsScreen(initial_service=actual_service_name) logs_screen = LogsScreen(initial_service=actual_service_name)
self.app.push_screen(logs_screen) self.app.push_screen(logs_screen)
else: else:

View file

@ -16,7 +16,7 @@ from ..managers.env_manager import EnvManager
class WelcomeScreen(Screen): class WelcomeScreen(Screen):
"""Initial welcome screen with setup options.""" """Initial welcome screen with setup options."""
BINDINGS = [ BINDINGS = [
("q", "quit", "Quit"), ("q", "quit", "Quit"),
("enter", "default_action", "Continue"), ("enter", "default_action", "Continue"),
@ -25,7 +25,7 @@ class WelcomeScreen(Screen):
("3", "monitor", "Monitor Services"), ("3", "monitor", "Monitor Services"),
("4", "diagnostics", "Diagnostics"), ("4", "diagnostics", "Diagnostics"),
] ]
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.container_manager = ContainerManager() self.container_manager = ContainerManager()
@ -34,19 +34,19 @@ class WelcomeScreen(Screen):
self.has_oauth_config = False self.has_oauth_config = False
self.default_button_id = "basic-setup-btn" self.default_button_id = "basic-setup-btn"
self._state_checked = False self._state_checked = False
# Load .env file if it exists # Load .env file if it exists
load_dotenv() load_dotenv()
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
"""Create the welcome screen layout.""" """Create the welcome screen layout."""
yield Container( yield Container(
Vertical( Vertical(
Static(self._create_welcome_text(), id="welcome-text"), Static(self._create_welcome_text(), id="welcome-text"),
self._create_dynamic_buttons(), self._create_dynamic_buttons(),
id="welcome-container" id="welcome-container",
), ),
id="main-container" id="main-container",
) )
yield Footer() yield Footer()
@ -65,55 +65,67 @@ class WelcomeScreen(Screen):
welcome_text.append("Terminal User Interface for OpenRAG\n\n", style="dim") welcome_text.append("Terminal User Interface for OpenRAG\n\n", style="dim")
if self.services_running: if self.services_running:
welcome_text.append("✓ Services are currently running\n\n", style="bold green") welcome_text.append(
"✓ Services are currently running\n\n", style="bold green"
)
elif self.has_oauth_config: elif self.has_oauth_config:
welcome_text.append("OAuth credentials detected — Advanced Setup recommended\n\n", style="bold green") welcome_text.append(
"OAuth credentials detected — Advanced Setup recommended\n\n",
style="bold green",
)
else: else:
welcome_text.append("Select a setup below to continue\n\n", style="white") welcome_text.append("Select a setup below to continue\n\n", style="white")
return welcome_text return welcome_text
def _create_dynamic_buttons(self) -> Horizontal: def _create_dynamic_buttons(self) -> Horizontal:
"""Create buttons based on current state.""" """Create buttons based on current state."""
# Check OAuth config early to determine which buttons to show # Check OAuth config early to determine which buttons to show
has_oauth = ( has_oauth = bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or bool(
bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID")
bool(os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID"))
) )
buttons = [] buttons = []
if self.services_running: if self.services_running:
# Services running - only show monitor # Services running - only show monitor
buttons.append(Button("Monitor Services", variant="success", id="monitor-btn")) buttons.append(
Button("Monitor Services", variant="success", id="monitor-btn")
)
else: else:
# Services not running - show setup options # Services not running - show setup options
if has_oauth: if has_oauth:
# Only show advanced setup if OAuth is configured # Only show advanced setup if OAuth is configured
buttons.append(Button("Advanced Setup", variant="success", id="advanced-setup-btn")) buttons.append(
Button("Advanced Setup", variant="success", id="advanced-setup-btn")
)
else: else:
# Only show basic setup if no OAuth # Only show basic setup if no OAuth
buttons.append(Button("Basic Setup", variant="success", id="basic-setup-btn")) buttons.append(
Button("Basic Setup", variant="success", id="basic-setup-btn")
)
# Always show monitor option # Always show monitor option
buttons.append(Button("Monitor Services", variant="default", id="monitor-btn")) buttons.append(
Button("Monitor Services", variant="default", id="monitor-btn")
)
return Horizontal(*buttons, classes="button-row") return Horizontal(*buttons, classes="button-row")
async def on_mount(self) -> None: async def on_mount(self) -> None:
"""Initialize screen state when mounted.""" """Initialize screen state when mounted."""
# Check if services are running # Check if services are running
if self.container_manager.is_available(): if self.container_manager.is_available():
services = await self.container_manager.get_service_status() services = await self.container_manager.get_service_status()
running_services = [s.name for s in services.values() if s.status == ServiceStatus.RUNNING] running_services = [
s.name for s in services.values() if s.status == ServiceStatus.RUNNING
]
self.services_running = len(running_services) > 0 self.services_running = len(running_services) > 0
# Check for OAuth configuration # Check for OAuth configuration
self.has_oauth_config = ( self.has_oauth_config = bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or bool(
bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID")
bool(os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID"))
) )
# Set default button focus # Set default button focus
if self.services_running: if self.services_running:
self.default_button_id = "monitor-btn" self.default_button_id = "monitor-btn"
@ -121,12 +133,14 @@ class WelcomeScreen(Screen):
self.default_button_id = "advanced-setup-btn" self.default_button_id = "advanced-setup-btn"
else: else:
self.default_button_id = "basic-setup-btn" self.default_button_id = "basic-setup-btn"
# Update the welcome text and recompose with new state # Update the welcome text and recompose with new state
try: try:
welcome_widget = self.query_one("#welcome-text") welcome_widget = self.query_one("#welcome-text")
welcome_widget.update(self._create_welcome_text()) # This is fine for Static widgets welcome_widget.update(
self._create_welcome_text()
) # This is fine for Static widgets
# Focus the appropriate button # Focus the appropriate button
if self.services_running: if self.services_running:
try: try:
@ -143,10 +157,10 @@ class WelcomeScreen(Screen):
self.query_one("#basic-setup-btn").focus() self.query_one("#basic-setup-btn").focus()
except: except:
pass pass
except: except:
pass # Widgets might not be mounted yet pass # Widgets might not be mounted yet
def on_button_pressed(self, event: Button.Pressed) -> None: def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses.""" """Handle button presses."""
if event.button.id == "basic-setup-btn": if event.button.id == "basic-setup-btn":
@ -157,7 +171,7 @@ class WelcomeScreen(Screen):
self.action_monitor() self.action_monitor()
elif event.button.id == "diagnostics-btn": elif event.button.id == "diagnostics-btn":
self.action_diagnostics() self.action_diagnostics()
def action_default_action(self) -> None: def action_default_action(self) -> None:
"""Handle Enter key - go to default action based on state.""" """Handle Enter key - go to default action based on state."""
if self.services_running: if self.services_running:
@ -166,27 +180,31 @@ class WelcomeScreen(Screen):
self.action_full_setup() self.action_full_setup()
else: else:
self.action_no_auth_setup() self.action_no_auth_setup()
def action_no_auth_setup(self) -> None: def action_no_auth_setup(self) -> None:
"""Switch to basic configuration screen.""" """Switch to basic configuration screen."""
from .config import ConfigScreen from .config import ConfigScreen
self.app.push_screen(ConfigScreen(mode="no_auth")) self.app.push_screen(ConfigScreen(mode="no_auth"))
def action_full_setup(self) -> None: def action_full_setup(self) -> None:
"""Switch to advanced configuration screen.""" """Switch to advanced configuration screen."""
from .config import ConfigScreen from .config import ConfigScreen
self.app.push_screen(ConfigScreen(mode="full")) self.app.push_screen(ConfigScreen(mode="full"))
def action_monitor(self) -> None: def action_monitor(self) -> None:
"""Switch to monitoring screen.""" """Switch to monitoring screen."""
from .monitor import MonitorScreen from .monitor import MonitorScreen
self.app.push_screen(MonitorScreen()) self.app.push_screen(MonitorScreen())
def action_diagnostics(self) -> None: def action_diagnostics(self) -> None:
"""Switch to diagnostics screen.""" """Switch to diagnostics screen."""
from .diagnostics import DiagnosticsScreen from .diagnostics import DiagnosticsScreen
self.app.push_screen(DiagnosticsScreen()) self.app.push_screen(DiagnosticsScreen())
def action_quit(self) -> None: def action_quit(self) -> None:
"""Quit the application.""" """Quit the application."""
self.app.exit() self.app.exit()

View file

@ -1 +1 @@
"""TUI utilities package.""" """TUI utilities package."""

View file

@ -34,40 +34,66 @@ class PlatformDetector:
"""Detect available container runtime and compose capabilities.""" """Detect available container runtime and compose capabilities."""
# First check if we have podman installed # First check if we have podman installed
podman_version = self._get_podman_version() podman_version = self._get_podman_version()
# If we have podman, check if docker is actually podman in disguise # If we have podman, check if docker is actually podman in disguise
if podman_version: if podman_version:
docker_version = self._get_docker_version() docker_version = self._get_docker_version()
if docker_version and podman_version in docker_version: if docker_version and podman_version in docker_version:
# This is podman masquerading as docker # This is podman masquerading as docker
if self._check_command(["docker", "compose", "--help"]): if self._check_command(["docker", "compose", "--help"]):
return RuntimeInfo(RuntimeType.PODMAN, ["docker", "compose"], ["docker"], podman_version) return RuntimeInfo(
RuntimeType.PODMAN,
["docker", "compose"],
["docker"],
podman_version,
)
if self._check_command(["docker-compose", "--help"]): if self._check_command(["docker-compose", "--help"]):
return RuntimeInfo(RuntimeType.PODMAN, ["docker-compose"], ["docker"], podman_version) return RuntimeInfo(
RuntimeType.PODMAN,
["docker-compose"],
["docker"],
podman_version,
)
# Check for native podman compose # Check for native podman compose
if self._check_command(["podman", "compose", "--help"]): if self._check_command(["podman", "compose", "--help"]):
return RuntimeInfo(RuntimeType.PODMAN, ["podman", "compose"], ["podman"], podman_version) return RuntimeInfo(
RuntimeType.PODMAN,
["podman", "compose"],
["podman"],
podman_version,
)
# Check for actual docker # Check for actual docker
if self._check_command(["docker", "compose", "--help"]): if self._check_command(["docker", "compose", "--help"]):
version = self._get_docker_version() version = self._get_docker_version()
return RuntimeInfo(RuntimeType.DOCKER, ["docker", "compose"], ["docker"], version) return RuntimeInfo(
RuntimeType.DOCKER, ["docker", "compose"], ["docker"], version
)
if self._check_command(["docker-compose", "--help"]): if self._check_command(["docker-compose", "--help"]):
version = self._get_docker_version() version = self._get_docker_version()
return RuntimeInfo(RuntimeType.DOCKER_COMPOSE, ["docker-compose"], ["docker"], version) return RuntimeInfo(
RuntimeType.DOCKER_COMPOSE, ["docker-compose"], ["docker"], version
)
return RuntimeInfo(RuntimeType.NONE, [], []) return RuntimeInfo(RuntimeType.NONE, [], [])
def detect_gpu_available(self) -> bool: def detect_gpu_available(self) -> bool:
"""Best-effort detection of NVIDIA GPU availability for containers.""" """Best-effort detection of NVIDIA GPU availability for containers."""
try: try:
res = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=5) res = subprocess.run(
if res.returncode == 0 and any("GPU" in ln for ln in res.stdout.splitlines()): ["nvidia-smi", "-L"], capture_output=True, text=True, timeout=5
)
if res.returncode == 0 and any(
"GPU" in ln for ln in res.stdout.splitlines()
):
return True return True
except (subprocess.TimeoutExpired, FileNotFoundError): except (subprocess.TimeoutExpired, FileNotFoundError):
pass pass
for cmd in (["docker", "info", "--format", "{{json .Runtimes}}"], ["podman", "info", "--format", "json"]): for cmd in (
["docker", "info", "--format", "{{json .Runtimes}}"],
["podman", "info", "--format", "json"],
):
try: try:
res = subprocess.run(cmd, capture_output=True, text=True, timeout=5) res = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
if res.returncode == 0 and "nvidia" in res.stdout.lower(): if res.returncode == 0 and "nvidia" in res.stdout.lower():
@ -85,7 +111,9 @@ class PlatformDetector:
def _get_docker_version(self) -> Optional[str]: def _get_docker_version(self) -> Optional[str]:
try: try:
res = subprocess.run(["docker", "--version"], capture_output=True, text=True, timeout=5) res = subprocess.run(
["docker", "--version"], capture_output=True, text=True, timeout=5
)
if res.returncode == 0: if res.returncode == 0:
return res.stdout.strip() return res.stdout.strip()
except (subprocess.TimeoutExpired, FileNotFoundError): except (subprocess.TimeoutExpired, FileNotFoundError):
@ -94,7 +122,9 @@ class PlatformDetector:
def _get_podman_version(self) -> Optional[str]: def _get_podman_version(self) -> Optional[str]:
try: try:
res = subprocess.run(["podman", "--version"], capture_output=True, text=True, timeout=5) res = subprocess.run(
["podman", "--version"], capture_output=True, text=True, timeout=5
)
if res.returncode == 0: if res.returncode == 0:
return res.stdout.strip() return res.stdout.strip()
except (subprocess.TimeoutExpired, FileNotFoundError): except (subprocess.TimeoutExpired, FileNotFoundError):
@ -110,7 +140,12 @@ class PlatformDetector:
if self.platform_system != "Darwin": if self.platform_system != "Darwin":
return True, 0, "Not running on macOS" return True, 0, "Not running on macOS"
try: try:
result = subprocess.run(["podman", "machine", "inspect"], capture_output=True, text=True, timeout=10) result = subprocess.run(
["podman", "machine", "inspect"],
capture_output=True,
text=True,
timeout=10,
)
if result.returncode != 0: if result.returncode != 0:
return False, 0, "Could not inspect Podman machine" return False, 0, "Could not inspect Podman machine"
machines = json.loads(result.stdout) machines = json.loads(result.stdout)
@ -124,7 +159,11 @@ class PlatformDetector:
if not is_sufficient: if not is_sufficient:
status += "\nTo increase: podman machine stop && podman machine rm && podman machine init --memory 8192 && podman machine start" status += "\nTo increase: podman machine stop && podman machine rm && podman machine init --memory 8192 && podman machine start"
return is_sufficient, memory_mb, status return is_sufficient, memory_mb, status
except (subprocess.TimeoutExpired, FileNotFoundError, json.JSONDecodeError) as e: except (
subprocess.TimeoutExpired,
FileNotFoundError,
json.JSONDecodeError,
) as e:
return False, 0, f"Error checking Podman VM memory: {e}" return False, 0, f"Error checking Podman VM memory: {e}"
def get_installation_instructions(self) -> str: def get_installation_instructions(self) -> str:
@ -167,4 +206,4 @@ Or Podman Desktop:
No container runtime found. Please install Docker or Podman for your platform: No container runtime found. Please install Docker or Podman for your platform:
- Docker: https://docs.docker.com/get-docker/ - Docker: https://docs.docker.com/get-docker/
- Podman: https://podman.io/getting-started/installation - Podman: https://podman.io/getting-started/installation
""" """

View file

@ -8,28 +8,31 @@ from typing import Optional
class ValidationError(Exception): class ValidationError(Exception):
"""Validation error exception.""" """Validation error exception."""
pass pass
def validate_env_var_name(name: str) -> bool: def validate_env_var_name(name: str) -> bool:
"""Validate environment variable name format.""" """Validate environment variable name format."""
return bool(re.match(r'^[A-Z][A-Z0-9_]*$', name)) return bool(re.match(r"^[A-Z][A-Z0-9_]*$", name))
def validate_path(path: str, must_exist: bool = False, must_be_dir: bool = False) -> bool: def validate_path(
path: str, must_exist: bool = False, must_be_dir: bool = False
) -> bool:
"""Validate file/directory path.""" """Validate file/directory path."""
if not path: if not path:
return False return False
try: try:
path_obj = Path(path).expanduser().resolve() path_obj = Path(path).expanduser().resolve()
if must_exist and not path_obj.exists(): if must_exist and not path_obj.exists():
return False return False
if must_be_dir and path_obj.exists() and not path_obj.is_dir(): if must_be_dir and path_obj.exists() and not path_obj.is_dir():
return False return False
return True return True
except (OSError, ValueError): except (OSError, ValueError):
return False return False
@ -39,15 +42,17 @@ def validate_url(url: str) -> bool:
"""Validate URL format.""" """Validate URL format."""
if not url: if not url:
return False return False
url_pattern = re.compile( url_pattern = re.compile(
r'^https?://' # http:// or https:// r"^https?://" # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|' # domain r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|" # domain
r'localhost|' # localhost r"localhost|" # localhost
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # IP r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # IP
r'(?::\d+)?' # optional port r"(?::\d+)?" # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE) r"(?:/?|[/?]\S+)$",
re.IGNORECASE,
)
return bool(url_pattern.match(url)) return bool(url_pattern.match(url))
@ -55,14 +60,14 @@ def validate_openai_api_key(key: str) -> bool:
"""Validate OpenAI API key format.""" """Validate OpenAI API key format."""
if not key: if not key:
return False return False
return key.startswith('sk-') and len(key) > 20 return key.startswith("sk-") and len(key) > 20
def validate_google_oauth_client_id(client_id: str) -> bool: def validate_google_oauth_client_id(client_id: str) -> bool:
"""Validate Google OAuth client ID format.""" """Validate Google OAuth client ID format."""
if not client_id: if not client_id:
return False return False
return client_id.endswith('.apps.googleusercontent.com') return client_id.endswith(".apps.googleusercontent.com")
def validate_non_empty(value: str) -> bool: def validate_non_empty(value: str) -> bool:
@ -74,37 +79,38 @@ def sanitize_env_value(value: str) -> str:
"""Sanitize environment variable value.""" """Sanitize environment variable value."""
# Remove leading/trailing whitespace # Remove leading/trailing whitespace
value = value.strip() value = value.strip()
# Remove quotes if they wrap the entire value # Remove quotes if they wrap the entire value
if len(value) >= 2: if len(value) >= 2:
if (value.startswith('"') and value.endswith('"')) or \ if (value.startswith('"') and value.endswith('"')) or (
(value.startswith("'") and value.endswith("'")): value.startswith("'") and value.endswith("'")
):
value = value[1:-1] value = value[1:-1]
return value return value
def validate_documents_paths(paths_str: str) -> tuple[bool, str, list[str]]: def validate_documents_paths(paths_str: str) -> tuple[bool, str, list[str]]:
""" """
Validate comma-separated documents paths for volume mounting. Validate comma-separated documents paths for volume mounting.
Returns: Returns:
(is_valid, error_message, validated_paths) (is_valid, error_message, validated_paths)
""" """
if not paths_str: if not paths_str:
return False, "Documents paths cannot be empty", [] return False, "Documents paths cannot be empty", []
paths = [path.strip() for path in paths_str.split(',') if path.strip()] paths = [path.strip() for path in paths_str.split(",") if path.strip()]
if not paths: if not paths:
return False, "No valid paths provided", [] return False, "No valid paths provided", []
validated_paths = [] validated_paths = []
for path in paths: for path in paths:
try: try:
path_obj = Path(path).expanduser().resolve() path_obj = Path(path).expanduser().resolve()
# Check if path exists # Check if path exists
if not path_obj.exists(): if not path_obj.exists():
# Try to create it # Try to create it
@ -112,11 +118,11 @@ def validate_documents_paths(paths_str: str) -> tuple[bool, str, list[str]]:
path_obj.mkdir(parents=True, exist_ok=True) path_obj.mkdir(parents=True, exist_ok=True)
except (OSError, PermissionError) as e: except (OSError, PermissionError) as e:
return False, f"Cannot create directory '{path}': {e}", [] return False, f"Cannot create directory '{path}': {e}", []
# Check if it's a directory # Check if it's a directory
if not path_obj.is_dir(): if not path_obj.is_dir():
return False, f"Path '{path}' must be a directory", [] return False, f"Path '{path}' must be a directory", []
# Check if we can write to it # Check if we can write to it
try: try:
test_file = path_obj / ".openrag_test" test_file = path_obj / ".openrag_test"
@ -124,10 +130,10 @@ def validate_documents_paths(paths_str: str) -> tuple[bool, str, list[str]]:
test_file.unlink() test_file.unlink()
except (OSError, PermissionError): except (OSError, PermissionError):
return False, f"Directory '{path}' is not writable", [] return False, f"Directory '{path}' is not writable", []
validated_paths.append(str(path_obj)) validated_paths.append(str(path_obj))
except (OSError, ValueError) as e: except (OSError, ValueError) as e:
return False, f"Invalid path '{path}': {e}", [] return False, f"Invalid path '{path}': {e}", []
return True, "All paths valid", validated_paths return True, "All paths valid", validated_paths

View file

@ -65,13 +65,13 @@ class CommandOutputModal(ModalScreen):
""" """
def __init__( def __init__(
self, self,
title: str, title: str,
command_generator: AsyncIterator[tuple[bool, str]], command_generator: AsyncIterator[tuple[bool, str]],
on_complete: Optional[Callable] = None on_complete: Optional[Callable] = None,
): ):
"""Initialize the modal dialog. """Initialize the modal dialog.
Args: Args:
title: Title of the modal dialog title: Title of the modal dialog
command_generator: Async generator that yields (is_complete, message) tuples command_generator: Async generator that yields (is_complete, message) tuples
@ -104,29 +104,32 @@ class CommandOutputModal(ModalScreen):
async def _run_command(self) -> None: async def _run_command(self) -> None:
"""Run the command and update the output in real-time.""" """Run the command and update the output in real-time."""
output = self.query_one("#command-output", RichLog) output = self.query_one("#command-output", RichLog)
try: try:
async for is_complete, message in self.command_generator: async for is_complete, message in self.command_generator:
# Simple approach: just append each line as it comes # Simple approach: just append each line as it comes
output.write(message + "\n") output.write(message + "\n")
# Scroll to bottom # Scroll to bottom
container = self.query_one("#output-container", ScrollableContainer) container = self.query_one("#output-container", ScrollableContainer)
container.scroll_end(animate=False) container.scroll_end(animate=False)
# If command is complete, update UI # If command is complete, update UI
if is_complete: if is_complete:
output.write("[bold green]Command completed successfully[/bold green]\n") output.write(
"[bold green]Command completed successfully[/bold green]\n"
)
# Call the completion callback if provided # Call the completion callback if provided
if self.on_complete: if self.on_complete:
await asyncio.sleep(0.5) # Small delay for better UX await asyncio.sleep(0.5) # Small delay for better UX
self.on_complete() self.on_complete()
except Exception as e: except Exception as e:
output.write(f"[bold red]Error: {e}[/bold red]\n") output.write(f"[bold red]Error: {e}[/bold red]\n")
# Enable the close button and focus it # Enable the close button and focus it
close_btn = self.query_one("#close-btn", Button) close_btn = self.query_one("#close-btn", Button)
close_btn.disabled = False close_btn.disabled = False
close_btn.focus() close_btn.focus()
# Made with Bob # Made with Bob

View file

@ -9,10 +9,10 @@ def notify_with_diagnostics(
app: App, app: App,
message: str, message: str,
severity: Literal["information", "warning", "error"] = "error", severity: Literal["information", "warning", "error"] = "error",
timeout: float = 10.0 timeout: float = 10.0,
) -> None: ) -> None:
"""Show a notification with a button to open the diagnostics screen. """Show a notification with a button to open the diagnostics screen.
Args: Args:
app: The Textual app app: The Textual app
message: The notification message message: The notification message
@ -21,18 +21,20 @@ def notify_with_diagnostics(
""" """
# First show the notification # First show the notification
app.notify(message, severity=severity, timeout=timeout) app.notify(message, severity=severity, timeout=timeout)
# Then add a button to open diagnostics screen # Then add a button to open diagnostics screen
def open_diagnostics() -> None: def open_diagnostics() -> None:
from ..screens.diagnostics import DiagnosticsScreen from ..screens.diagnostics import DiagnosticsScreen
app.push_screen(DiagnosticsScreen()) app.push_screen(DiagnosticsScreen())
# Add a separate notification with just the button # Add a separate notification with just the button
app.notify( app.notify(
"Click to view diagnostics", "Click to view diagnostics",
severity="information", severity="information",
timeout=timeout, timeout=timeout,
title="Diagnostics" title="Diagnostics",
) )
# Made with Bob
# Made with Bob

View file

@ -9,10 +9,10 @@ def notify_with_diagnostics(
app: App, app: App,
message: str, message: str,
severity: Literal["information", "warning", "error"] = "error", severity: Literal["information", "warning", "error"] = "error",
timeout: float = 10.0 timeout: float = 10.0,
) -> None: ) -> None:
"""Show a notification with a button to open the diagnostics screen. """Show a notification with a button to open the diagnostics screen.
Args: Args:
app: The Textual app app: The Textual app
message: The notification message message: The notification message
@ -21,18 +21,20 @@ def notify_with_diagnostics(
""" """
# First show the notification # First show the notification
app.notify(message, severity=severity, timeout=timeout) app.notify(message, severity=severity, timeout=timeout)
# Then add a button to open diagnostics screen # Then add a button to open diagnostics screen
def open_diagnostics() -> None: def open_diagnostics() -> None:
from ..screens.diagnostics import DiagnosticsScreen from ..screens.diagnostics import DiagnosticsScreen
app.push_screen(DiagnosticsScreen()) app.push_screen(DiagnosticsScreen())
# Add a separate notification with just the button # Add a separate notification with just the button
app.notify( app.notify(
"Click to view diagnostics", "Click to view diagnostics",
severity="information", severity="information",
timeout=timeout, timeout=timeout,
title="Diagnostics" title="Diagnostics",
) )
# Made with Bob # Made with Bob

View file

@ -1,7 +1,12 @@
import hashlib import hashlib
import os import os
import sys
import platform
from collections import defaultdict from collections import defaultdict
from .gpu_detection import detect_gpu_devices from .gpu_detection import detect_gpu_devices
from utils.logging_config import get_logger
logger = get_logger(__name__)
# Global converter cache for worker processes # Global converter cache for worker processes
_worker_converter = None _worker_converter = None
@ -37,11 +42,11 @@ def get_worker_converter():
"1" # Still disable progress bars "1" # Still disable progress bars
) )
print( logger.info(
f"[WORKER {os.getpid()}] Initializing DocumentConverter in worker process" "Initializing DocumentConverter in worker process", worker_pid=os.getpid()
) )
_worker_converter = DocumentConverter() _worker_converter = DocumentConverter()
print(f"[WORKER {os.getpid()}] DocumentConverter ready in worker process") logger.info("DocumentConverter ready in worker process", worker_pid=os.getpid())
return _worker_converter return _worker_converter
@ -118,33 +123,45 @@ def process_document_sync(file_path: str):
start_memory = process.memory_info().rss / 1024 / 1024 # MB start_memory = process.memory_info().rss / 1024 / 1024 # MB
try: try:
print(f"[WORKER {os.getpid()}] Starting document processing: {file_path}") logger.info(
print(f"[WORKER {os.getpid()}] Initial memory usage: {start_memory:.1f} MB") "Starting document processing",
worker_pid=os.getpid(),
file_path=file_path,
initial_memory_mb=f"{start_memory:.1f}",
)
# Check file size # Check file size
try: try:
file_size = os.path.getsize(file_path) / 1024 / 1024 # MB file_size = os.path.getsize(file_path) / 1024 / 1024 # MB
print(f"[WORKER {os.getpid()}] File size: {file_size:.1f} MB") logger.info(
"File size determined",
worker_pid=os.getpid(),
file_size_mb=f"{file_size:.1f}",
)
except OSError as e: except OSError as e:
print(f"[WORKER {os.getpid()}] WARNING: Cannot get file size: {e}") logger.warning("Cannot get file size", worker_pid=os.getpid(), error=str(e))
file_size = 0 file_size = 0
# Get the cached converter for this worker # Get the cached converter for this worker
try: try:
print(f"[WORKER {os.getpid()}] Getting document converter...") logger.info("Getting document converter", worker_pid=os.getpid())
converter = get_worker_converter() converter = get_worker_converter()
memory_after_converter = process.memory_info().rss / 1024 / 1024 memory_after_converter = process.memory_info().rss / 1024 / 1024
print( logger.info(
f"[WORKER {os.getpid()}] Memory after converter init: {memory_after_converter:.1f} MB" "Memory after converter init",
worker_pid=os.getpid(),
memory_mb=f"{memory_after_converter:.1f}",
) )
except Exception as e: except Exception as e:
print(f"[WORKER {os.getpid()}] ERROR: Failed to initialize converter: {e}") logger.error(
"Failed to initialize converter", worker_pid=os.getpid(), error=str(e)
)
traceback.print_exc() traceback.print_exc()
raise raise
# Compute file hash # Compute file hash
try: try:
print(f"[WORKER {os.getpid()}] Computing file hash...") logger.info("Computing file hash", worker_pid=os.getpid())
sha256 = hashlib.sha256() sha256 = hashlib.sha256()
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
while True: while True:
@ -153,50 +170,67 @@ def process_document_sync(file_path: str):
break break
sha256.update(chunk) sha256.update(chunk)
file_hash = sha256.hexdigest() file_hash = sha256.hexdigest()
print(f"[WORKER {os.getpid()}] File hash computed: {file_hash[:12]}...") logger.info(
"File hash computed",
worker_pid=os.getpid(),
file_hash_prefix=file_hash[:12],
)
except Exception as e: except Exception as e:
print(f"[WORKER {os.getpid()}] ERROR: Failed to compute file hash: {e}") logger.error(
"Failed to compute file hash", worker_pid=os.getpid(), error=str(e)
)
traceback.print_exc() traceback.print_exc()
raise raise
# Convert with docling # Convert with docling
try: try:
print(f"[WORKER {os.getpid()}] Starting docling conversion...") logger.info("Starting docling conversion", worker_pid=os.getpid())
memory_before_convert = process.memory_info().rss / 1024 / 1024 memory_before_convert = process.memory_info().rss / 1024 / 1024
print( logger.info(
f"[WORKER {os.getpid()}] Memory before conversion: {memory_before_convert:.1f} MB" "Memory before conversion",
worker_pid=os.getpid(),
memory_mb=f"{memory_before_convert:.1f}",
) )
result = converter.convert(file_path) result = converter.convert(file_path)
memory_after_convert = process.memory_info().rss / 1024 / 1024 memory_after_convert = process.memory_info().rss / 1024 / 1024
print( logger.info(
f"[WORKER {os.getpid()}] Memory after conversion: {memory_after_convert:.1f} MB" "Memory after conversion",
worker_pid=os.getpid(),
memory_mb=f"{memory_after_convert:.1f}",
) )
print(f"[WORKER {os.getpid()}] Docling conversion completed") logger.info("Docling conversion completed", worker_pid=os.getpid())
full_doc = result.document.export_to_dict() full_doc = result.document.export_to_dict()
memory_after_export = process.memory_info().rss / 1024 / 1024 memory_after_export = process.memory_info().rss / 1024 / 1024
print( logger.info(
f"[WORKER {os.getpid()}] Memory after export: {memory_after_export:.1f} MB" "Memory after export",
worker_pid=os.getpid(),
memory_mb=f"{memory_after_export:.1f}",
) )
except Exception as e: except Exception as e:
print( current_memory = process.memory_info().rss / 1024 / 1024
f"[WORKER {os.getpid()}] ERROR: Failed during docling conversion: {e}" logger.error(
) "Failed during docling conversion",
print( worker_pid=os.getpid(),
f"[WORKER {os.getpid()}] Current memory usage: {process.memory_info().rss / 1024 / 1024:.1f} MB" error=str(e),
current_memory_mb=f"{current_memory:.1f}",
) )
traceback.print_exc() traceback.print_exc()
raise raise
# Extract relevant content (same logic as extract_relevant) # Extract relevant content (same logic as extract_relevant)
try: try:
print(f"[WORKER {os.getpid()}] Extracting relevant content...") logger.info("Extracting relevant content", worker_pid=os.getpid())
origin = full_doc.get("origin", {}) origin = full_doc.get("origin", {})
texts = full_doc.get("texts", []) texts = full_doc.get("texts", [])
print(f"[WORKER {os.getpid()}] Found {len(texts)} text fragments") logger.info(
"Found text fragments",
worker_pid=os.getpid(),
fragment_count=len(texts),
)
page_texts = defaultdict(list) page_texts = defaultdict(list)
for txt in texts: for txt in texts:
@ -210,22 +244,27 @@ def process_document_sync(file_path: str):
joined = "\n".join(page_texts[page]) joined = "\n".join(page_texts[page])
chunks.append({"page": page, "text": joined}) chunks.append({"page": page, "text": joined})
print( logger.info(
f"[WORKER {os.getpid()}] Created {len(chunks)} chunks from {len(page_texts)} pages" "Created chunks from pages",
worker_pid=os.getpid(),
chunk_count=len(chunks),
page_count=len(page_texts),
) )
except Exception as e: except Exception as e:
print( logger.error(
f"[WORKER {os.getpid()}] ERROR: Failed during content extraction: {e}" "Failed during content extraction", worker_pid=os.getpid(), error=str(e)
) )
traceback.print_exc() traceback.print_exc()
raise raise
final_memory = process.memory_info().rss / 1024 / 1024 final_memory = process.memory_info().rss / 1024 / 1024
memory_delta = final_memory - start_memory memory_delta = final_memory - start_memory
print(f"[WORKER {os.getpid()}] Document processing completed successfully") logger.info(
print( "Document processing completed successfully",
f"[WORKER {os.getpid()}] Final memory: {final_memory:.1f} MB (Delta +{memory_delta:.1f} MB)" worker_pid=os.getpid(),
final_memory_mb=f"{final_memory:.1f}",
memory_delta_mb=f"{memory_delta:.1f}",
) )
return { return {
@ -239,24 +278,29 @@ def process_document_sync(file_path: str):
except Exception as e: except Exception as e:
final_memory = process.memory_info().rss / 1024 / 1024 final_memory = process.memory_info().rss / 1024 / 1024
memory_delta = final_memory - start_memory memory_delta = final_memory - start_memory
print(f"[WORKER {os.getpid()}] FATAL ERROR in process_document_sync") logger.error(
print(f"[WORKER {os.getpid()}] File: {file_path}") "FATAL ERROR in process_document_sync",
print(f"[WORKER {os.getpid()}] Python version: {sys.version}") worker_pid=os.getpid(),
print( file_path=file_path,
f"[WORKER {os.getpid()}] Memory at crash: {final_memory:.1f} MB (Delta +{memory_delta:.1f} MB)" python_version=sys.version,
memory_at_crash_mb=f"{final_memory:.1f}",
memory_delta_mb=f"{memory_delta:.1f}",
error_type=type(e).__name__,
error=str(e),
) )
print(f"[WORKER {os.getpid()}] Error: {type(e).__name__}: {e}") logger.error("Full traceback:", worker_pid=os.getpid())
print(f"[WORKER {os.getpid()}] Full traceback:")
traceback.print_exc() traceback.print_exc()
# Try to get more system info before crashing # Try to get more system info before crashing
try: try:
import platform import platform
print( logger.error(
f"[WORKER {os.getpid()}] System: {platform.system()} {platform.release()}" "System info",
worker_pid=os.getpid(),
system=f"{platform.system()} {platform.release()}",
architecture=platform.machine(),
) )
print(f"[WORKER {os.getpid()}] Architecture: {platform.machine()}")
except: except:
pass pass

View file

@ -1,5 +1,8 @@
import multiprocessing import multiprocessing
import os import os
from utils.logging_config import get_logger
logger = get_logger(__name__)
def detect_gpu_devices(): def detect_gpu_devices():
@ -30,13 +33,15 @@ def get_worker_count():
if has_gpu_devices: if has_gpu_devices:
default_workers = min(4, multiprocessing.cpu_count() // 2) default_workers = min(4, multiprocessing.cpu_count() // 2)
print( logger.info(
f"GPU mode enabled with {gpu_count} GPU(s) - using limited concurrency ({default_workers} workers)" "GPU mode enabled with limited concurrency",
gpu_count=gpu_count,
worker_count=default_workers,
) )
else: else:
default_workers = multiprocessing.cpu_count() default_workers = multiprocessing.cpu_count()
print( logger.info(
f"CPU-only mode enabled - using full concurrency ({default_workers} workers)" "CPU-only mode enabled with full concurrency", worker_count=default_workers
) )
return int(os.getenv("MAX_WORKERS", default_workers)) return int(os.getenv("MAX_WORKERS", default_workers))

View file

@ -9,13 +9,15 @@ def configure_logging(
log_level: str = "INFO", log_level: str = "INFO",
json_logs: bool = False, json_logs: bool = False,
include_timestamps: bool = True, include_timestamps: bool = True,
service_name: str = "openrag" service_name: str = "openrag",
) -> None: ) -> None:
"""Configure structlog for the application.""" """Configure structlog for the application."""
# Convert string log level to actual level # Convert string log level to actual level
level = getattr(structlog.stdlib.logging, log_level.upper(), structlog.stdlib.logging.INFO) level = getattr(
structlog.stdlib.logging, log_level.upper(), structlog.stdlib.logging.INFO
)
# Base processors # Base processors
shared_processors = [ shared_processors = [
structlog.contextvars.merge_contextvars, structlog.contextvars.merge_contextvars,
@ -23,29 +25,65 @@ def configure_logging(
structlog.processors.StackInfoRenderer(), structlog.processors.StackInfoRenderer(),
structlog.dev.set_exc_info, structlog.dev.set_exc_info,
] ]
if include_timestamps: if include_timestamps:
shared_processors.append(structlog.processors.TimeStamper(fmt="iso")) shared_processors.append(structlog.processors.TimeStamper(fmt="iso"))
# Add service name to all logs # Add service name and file location to all logs
shared_processors.append( shared_processors.append(
structlog.processors.CallsiteParameterAdder( structlog.processors.CallsiteParameterAdder(
parameters=[structlog.processors.CallsiteParameter.FUNC_NAME] parameters=[
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.FILENAME,
structlog.processors.CallsiteParameter.LINENO,
structlog.processors.CallsiteParameter.PATHNAME,
]
) )
) )
# Console output configuration # Console output configuration
if json_logs or os.getenv("LOG_FORMAT", "").lower() == "json": if json_logs or os.getenv("LOG_FORMAT", "").lower() == "json":
# JSON output for production/containers # JSON output for production/containers
shared_processors.append(structlog.processors.JSONRenderer()) shared_processors.append(structlog.processors.JSONRenderer())
console_renderer = structlog.processors.JSONRenderer() console_renderer = structlog.processors.JSONRenderer()
else: else:
# Pretty colored output for development # Custom clean format: timestamp path/file:loc logentry
console_renderer = structlog.dev.ConsoleRenderer( def custom_formatter(logger, log_method, event_dict):
colors=sys.stderr.isatty(), timestamp = event_dict.pop("timestamp", "")
exception_formatter=structlog.dev.plain_traceback, pathname = event_dict.pop("pathname", "")
) filename = event_dict.pop("filename", "")
lineno = event_dict.pop("lineno", "")
level = event_dict.pop("level", "")
# Build file location - prefer pathname for full path, fallback to filename
if pathname and lineno:
location = f"{pathname}:{lineno}"
elif filename and lineno:
location = f"{filename}:{lineno}"
elif pathname:
location = pathname
elif filename:
location = filename
else:
location = "unknown"
# Build the main message
message_parts = []
event = event_dict.pop("event", "")
if event:
message_parts.append(event)
# Add any remaining context
for key, value in event_dict.items():
if key not in ["service", "func_name"]: # Skip internal fields
message_parts.append(f"{key}={value}")
message = " ".join(message_parts)
return f"{timestamp} {location} {message}"
console_renderer = custom_formatter
# Configure structlog # Configure structlog
structlog.configure( structlog.configure(
processors=shared_processors + [console_renderer], processors=shared_processors + [console_renderer],
@ -54,7 +92,7 @@ def configure_logging(
logger_factory=structlog.WriteLoggerFactory(sys.stderr), logger_factory=structlog.WriteLoggerFactory(sys.stderr),
cache_logger_on_first_use=True, cache_logger_on_first_use=True,
) )
# Add global context # Add global context
structlog.contextvars.clear_contextvars() structlog.contextvars.clear_contextvars()
structlog.contextvars.bind_contextvars(service=service_name) structlog.contextvars.bind_contextvars(service=service_name)
@ -73,9 +111,7 @@ def configure_from_env() -> None:
log_level = os.getenv("LOG_LEVEL", "INFO") log_level = os.getenv("LOG_LEVEL", "INFO")
json_logs = os.getenv("LOG_FORMAT", "").lower() == "json" json_logs = os.getenv("LOG_FORMAT", "").lower() == "json"
service_name = os.getenv("SERVICE_NAME", "openrag") service_name = os.getenv("SERVICE_NAME", "openrag")
configure_logging( configure_logging(
log_level=log_level, log_level=log_level, json_logs=json_logs, service_name=service_name
json_logs=json_logs, )
service_name=service_name
)

View file

@ -1,10 +1,13 @@
import os import os
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from utils.gpu_detection import get_worker_count from utils.gpu_detection import get_worker_count
from utils.logging_config import get_logger
logger = get_logger(__name__)
# Create shared process pool at import time (before CUDA initialization) # Create shared process pool at import time (before CUDA initialization)
# This avoids the "Cannot re-initialize CUDA in forked subprocess" error # This avoids the "Cannot re-initialize CUDA in forked subprocess" error
MAX_WORKERS = get_worker_count() MAX_WORKERS = get_worker_count()
process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS) process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)
print(f"Shared process pool initialized with {MAX_WORKERS} workers") logger.info("Shared process pool initialized", max_workers=MAX_WORKERS)

View file

@ -1,13 +1,17 @@
from docling.document_converter import DocumentConverter from docling.document_converter import DocumentConverter
import logging
print("Warming up docling models...") logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Warming up docling models")
try: try:
# Use the sample document to warm up docling # Use the sample document to warm up docling
test_file = "/app/warmup_ocr.pdf" test_file = "/app/warmup_ocr.pdf"
print(f"Using {test_file} to warm up docling...") logger.info(f"Using test file to warm up docling: {test_file}")
DocumentConverter().convert(test_file) DocumentConverter().convert(test_file)
print("Docling models warmed up successfully") logger.info("Docling models warmed up successfully")
except Exception as e: except Exception as e:
print(f"Docling warm-up completed with: {e}") logger.info(f"Docling warm-up completed with exception: {str(e)}")
# This is expected - we just want to trigger the model downloads # This is expected - we just want to trigger the model downloads