Merge branch 'feat-googledrive-enhancements' of github.com:langflow-ai/openrag into feat-googledrive-enhancements

This commit is contained in:
Mike Fortman 2025-09-05 16:13:00 -05:00
commit 86d8c9f5b2
41 changed files with 1730 additions and 967 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

Binary file not shown.

View file

@ -95,7 +95,9 @@ async def async_response_stream(
chunk_count = 0
async for chunk in response:
chunk_count += 1
logger.debug("Stream chunk received", chunk_count=chunk_count, chunk=str(chunk))
logger.debug(
"Stream chunk received", chunk_count=chunk_count, chunk=str(chunk)
)
# Yield the raw event as JSON for the UI to process
import json
@ -241,7 +243,10 @@ async def async_langflow_stream(
previous_response_id=previous_response_id,
log_prefix="langflow",
):
logger.debug("Yielding chunk from langflow stream", chunk_preview=chunk[:100].decode('utf-8', errors='replace'))
logger.debug(
"Yielding chunk from langflow stream",
chunk_preview=chunk[:100].decode("utf-8", errors="replace"),
)
yield chunk
logger.debug("Langflow stream completed")
except Exception as e:
@ -260,18 +265,24 @@ async def async_chat(
model: str = "gpt-4.1-mini",
previous_response_id: str = None,
):
logger.debug("async_chat called", user_id=user_id, previous_response_id=previous_response_id)
logger.debug(
"async_chat called", user_id=user_id, previous_response_id=previous_response_id
)
# Get the specific conversation thread (or create new one)
conversation_state = get_conversation_thread(user_id, previous_response_id)
logger.debug("Got conversation state", message_count=len(conversation_state['messages']))
logger.debug(
"Got conversation state", message_count=len(conversation_state["messages"])
)
# Add user message to conversation with timestamp
from datetime import datetime
user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()}
conversation_state["messages"].append(user_message)
logger.debug("Added user message", message_count=len(conversation_state['messages']))
logger.debug(
"Added user message", message_count=len(conversation_state["messages"])
)
response_text, response_id = await async_response(
async_client,
@ -280,7 +291,9 @@ async def async_chat(
previous_response_id=previous_response_id,
log_prefix="agent",
)
logger.debug("Got response", response_preview=response_text[:50], response_id=response_id)
logger.debug(
"Got response", response_preview=response_text[:50], response_id=response_id
)
# Add assistant response to conversation with response_id and timestamp
assistant_message = {
@ -290,17 +303,26 @@ async def async_chat(
"timestamp": datetime.now(),
}
conversation_state["messages"].append(assistant_message)
logger.debug("Added assistant message", message_count=len(conversation_state['messages']))
logger.debug(
"Added assistant message", message_count=len(conversation_state["messages"])
)
# Store the conversation thread with its response_id
if response_id:
conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored conversation thread", user_id=user_id, response_id=response_id)
logger.debug(
"Stored conversation thread", user_id=user_id, response_id=response_id
)
# Debug: Check what's in user_conversations now
conversations = get_user_conversations(user_id)
logger.debug("User conversations updated", user_id=user_id, conversation_count=len(conversations), conversation_ids=list(conversations.keys()))
logger.debug(
"User conversations updated",
user_id=user_id,
conversation_count=len(conversations),
conversation_ids=list(conversations.keys()),
)
else:
logger.warning("No response_id received, conversation not stored")
@ -363,7 +385,9 @@ async def async_chat_stream(
if response_id:
conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored conversation thread", user_id=user_id, response_id=response_id)
logger.debug(
"Stored conversation thread", user_id=user_id, response_id=response_id
)
# Async langflow function with conversation storage (non-streaming)
@ -375,18 +399,28 @@ async def async_langflow_chat(
extra_headers: dict = None,
previous_response_id: str = None,
):
logger.debug("async_langflow_chat called", user_id=user_id, previous_response_id=previous_response_id)
logger.debug(
"async_langflow_chat called",
user_id=user_id,
previous_response_id=previous_response_id,
)
# Get the specific conversation thread (or create new one)
conversation_state = get_conversation_thread(user_id, previous_response_id)
logger.debug("Got langflow conversation state", message_count=len(conversation_state['messages']))
logger.debug(
"Got langflow conversation state",
message_count=len(conversation_state["messages"]),
)
# Add user message to conversation with timestamp
from datetime import datetime
user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()}
conversation_state["messages"].append(user_message)
logger.debug("Added user message to langflow", message_count=len(conversation_state['messages']))
logger.debug(
"Added user message to langflow",
message_count=len(conversation_state["messages"]),
)
response_text, response_id = await async_response(
langflow_client,
@ -396,7 +430,11 @@ async def async_langflow_chat(
previous_response_id=previous_response_id,
log_prefix="langflow",
)
logger.debug("Got langflow response", response_preview=response_text[:50], response_id=response_id)
logger.debug(
"Got langflow response",
response_preview=response_text[:50],
response_id=response_id,
)
# Add assistant response to conversation with response_id and timestamp
assistant_message = {
@ -406,17 +444,29 @@ async def async_langflow_chat(
"timestamp": datetime.now(),
}
conversation_state["messages"].append(assistant_message)
logger.debug("Added assistant message to langflow", message_count=len(conversation_state['messages']))
logger.debug(
"Added assistant message to langflow",
message_count=len(conversation_state["messages"]),
)
# Store the conversation thread with its response_id
if response_id:
conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored langflow conversation thread", user_id=user_id, response_id=response_id)
logger.debug(
"Stored langflow conversation thread",
user_id=user_id,
response_id=response_id,
)
# Debug: Check what's in user_conversations now
conversations = get_user_conversations(user_id)
logger.debug("User conversations updated", user_id=user_id, conversation_count=len(conversations), conversation_ids=list(conversations.keys()))
logger.debug(
"User conversations updated",
user_id=user_id,
conversation_count=len(conversations),
conversation_ids=list(conversations.keys()),
)
else:
logger.warning("No response_id received from langflow, conversation not stored")
@ -432,7 +482,11 @@ async def async_langflow_chat_stream(
extra_headers: dict = None,
previous_response_id: str = None,
):
logger.debug("async_langflow_chat_stream called", user_id=user_id, previous_response_id=previous_response_id)
logger.debug(
"async_langflow_chat_stream called",
user_id=user_id,
previous_response_id=previous_response_id,
)
# Get the specific conversation thread (or create new one)
conversation_state = get_conversation_thread(user_id, previous_response_id)
@ -483,4 +537,8 @@ async def async_langflow_chat_stream(
if response_id:
conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored langflow conversation thread", user_id=user_id, response_id=response_id)
logger.debug(
"Stored langflow conversation thread",
user_id=user_id,
response_id=response_id,
)

View file

@ -25,6 +25,11 @@ async def connector_sync(request: Request, connector_service, session_manager):
selected_files = data.get("selected_files")
try:
logger.debug(
"Starting connector sync",
connector_type=connector_type,
max_files=max_files,
)
user = request.state.user
jwt_token = request.state.jwt_token
@ -44,6 +49,10 @@ async def connector_sync(request: Request, connector_service, session_manager):
# Start sync tasks for all active connections
task_ids = []
for connection in active_connections:
logger.debug(
"About to call sync_connector_files for connection",
connection_id=connection.connection_id,
)
if selected_files:
task_id = await connector_service.sync_specific_files(
connection.connection_id,
@ -58,8 +67,6 @@ async def connector_sync(request: Request, connector_service, session_manager):
max_files,
jwt_token=jwt_token,
)
task_ids.append(task_id)
return JSONResponse(
{
"task_ids": task_ids,
@ -170,7 +177,9 @@ async def connector_webhook(request: Request, connector_service, session_manager
channel_id = None
if not channel_id:
logger.warning("No channel ID found in webhook", connector_type=connector_type)
logger.warning(
"No channel ID found in webhook", connector_type=connector_type
)
return JSONResponse({"status": "ignored", "reason": "no_channel_id"})
# Find the specific connection for this webhook
@ -180,7 +189,9 @@ async def connector_webhook(request: Request, connector_service, session_manager
)
)
if not connection or not connection.is_active:
logger.info("Unknown webhook channel, will auto-expire", channel_id=channel_id)
logger.info(
"Unknown webhook channel, will auto-expire", channel_id=channel_id
)
return JSONResponse(
{"status": "ignored_unknown_channel", "channel_id": channel_id}
)
@ -190,7 +201,10 @@ async def connector_webhook(request: Request, connector_service, session_manager
# Get the connector instance
connector = await connector_service._get_connector(connection.connection_id)
if not connector:
logger.error("Could not get connector for connection", connection_id=connection.connection_id)
logger.error(
"Could not get connector for connection",
connection_id=connection.connection_id,
)
return JSONResponse(
{"status": "error", "reason": "connector_not_found"}
)
@ -199,7 +213,11 @@ async def connector_webhook(request: Request, connector_service, session_manager
affected_files = await connector.handle_webhook(payload)
if affected_files:
logger.info("Webhook connection files affected", connection_id=connection.connection_id, affected_count=len(affected_files))
logger.info(
"Webhook connection files affected",
connection_id=connection.connection_id,
affected_count=len(affected_files),
)
# Generate JWT token for the user (needed for OpenSearch authentication)
user = session_manager.get_user(connection.user_id)
@ -223,7 +241,10 @@ async def connector_webhook(request: Request, connector_service, session_manager
}
else:
# No specific files identified - just log the webhook
logger.info("Webhook general change detected, no specific files", connection_id=connection.connection_id)
logger.info(
"Webhook general change detected, no specific files",
connection_id=connection.connection_id,
)
result = {
"connection_id": connection.connection_id,
@ -241,7 +262,15 @@ async def connector_webhook(request: Request, connector_service, session_manager
)
except Exception as e:
logger.error("Failed to process webhook for connection", connection_id=connection.connection_id, error=str(e))
logger.error(
"Failed to process webhook for connection",
connection_id=connection.connection_id,
error=str(e),
)
import traceback
traceback.print_exc()
return JSONResponse(
{
"status": "error",

View file

@ -395,15 +395,19 @@ async def knowledge_filter_webhook(
# Get the webhook payload
payload = await request.json()
logger.info("Knowledge filter webhook received",
filter_id=filter_id,
subscription_id=subscription_id,
payload_size=len(str(payload)))
logger.info(
"Knowledge filter webhook received",
filter_id=filter_id,
subscription_id=subscription_id,
payload_size=len(str(payload)),
)
# Extract findings from the payload
findings = payload.get("findings", [])
if not findings:
logger.info("No findings in webhook payload", subscription_id=subscription_id)
logger.info(
"No findings in webhook payload", subscription_id=subscription_id
)
return JSONResponse({"status": "no_findings"})
# Process the findings - these are the documents that matched the knowledge filter
@ -420,14 +424,18 @@ async def knowledge_filter_webhook(
)
# Log the matched documents
logger.info("Knowledge filter matched documents",
filter_id=filter_id,
matched_count=len(matched_documents))
logger.info(
"Knowledge filter matched documents",
filter_id=filter_id,
matched_count=len(matched_documents),
)
for doc in matched_documents:
logger.debug("Matched document",
document_id=doc['document_id'],
index=doc['index'],
score=doc.get('score'))
logger.debug(
"Matched document",
document_id=doc["document_id"],
index=doc["index"],
score=doc.get("score"),
)
# Here you could add additional processing:
# - Send notifications to external webhooks
@ -446,10 +454,12 @@ async def knowledge_filter_webhook(
)
except Exception as e:
logger.error("Failed to process knowledge filter webhook",
filter_id=filter_id,
subscription_id=subscription_id,
error=str(e))
logger.error(
"Failed to process knowledge filter webhook",
filter_id=filter_id,
subscription_id=subscription_id,
error=str(e),
)
import traceback
traceback.print_exc()

View file

@ -23,14 +23,16 @@ async def search(request: Request, search_service, session_manager):
# Extract JWT token from auth middleware
jwt_token = request.state.jwt_token
logger.debug("Search API request",
user=str(user),
user_id=user.user_id if user else None,
has_jwt_token=jwt_token is not None,
query=query,
filters=filters,
limit=limit,
score_threshold=score_threshold)
logger.debug(
"Search API request",
user=str(user),
user_id=user.user_id if user else None,
has_jwt_token=jwt_token is not None,
query=query,
filters=filters,
limit=limit,
score_threshold=score_threshold,
)
result = await search_service.search(
query,

View file

@ -14,7 +14,7 @@ async def upload(request: Request, document_service, session_manager):
jwt_token = request.state.jwt_token
from config.settings import is_no_auth_mode
# In no-auth mode, pass None for owner fields so documents have no owner
# This allows all users to see them when switching to auth mode
if is_no_auth_mode():
@ -25,7 +25,7 @@ async def upload(request: Request, document_service, session_manager):
owner_user_id = user.user_id
owner_name = user.name
owner_email = user.email
result = await document_service.process_upload_file(
upload_file,
owner_user_id=owner_user_id,
@ -61,9 +61,9 @@ async def upload_path(request: Request, task_service, session_manager):
user = request.state.user
jwt_token = request.state.jwt_token
from config.settings import is_no_auth_mode
# In no-auth mode, pass None for owner fields so documents have no owner
if is_no_auth_mode():
owner_user_id = None
@ -73,7 +73,7 @@ async def upload_path(request: Request, task_service, session_manager):
owner_user_id = user.user_id
owner_name = user.name
owner_email = user.email
task_id = await task_service.create_upload_task(
owner_user_id,
file_paths,

View file

@ -3,6 +3,9 @@ from starlette.responses import JSONResponse
from typing import Optional
from session_manager import User
from config.settings import is_no_auth_mode
from utils.logging_config import get_logger
logger = get_logger(__name__)
def get_current_user(request: Request, session_manager) -> Optional[User]:
@ -25,22 +28,15 @@ def require_auth(session_manager):
async def wrapper(request: Request):
# In no-auth mode, bypass authentication entirely
if is_no_auth_mode():
print(f"[DEBUG] No-auth mode: Creating anonymous user")
logger.debug("No-auth mode: Creating anonymous user")
# Create an anonymous user object so endpoints don't break
from session_manager import User
from datetime import datetime
request.state.user = User(
user_id="anonymous",
email="anonymous@localhost",
name="Anonymous User",
picture=None,
provider="none",
created_at=datetime.now(),
last_login=datetime.now(),
)
from session_manager import AnonymousUser
request.state.user = AnonymousUser()
request.state.jwt_token = None # No JWT in no-auth mode
print(f"[DEBUG] Set user_id=anonymous, jwt_token=None")
logger.debug("Set user_id=anonymous, jwt_token=None")
return await handler(request)
user = get_current_user(request, session_manager)
@ -72,15 +68,8 @@ def optional_auth(session_manager):
from session_manager import User
from datetime import datetime
request.state.user = User(
user_id="anonymous",
email="anonymous@localhost",
name="Anonymous User",
picture=None,
provider="none",
created_at=datetime.now(),
last_login=datetime.now(),
)
from session_manager import AnonymousUser
request.state.user = AnonymousUser()
request.state.jwt_token = None # No JWT in no-auth mode
else:
user = get_current_user(request, session_manager)

View file

@ -36,7 +36,12 @@ GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
def is_no_auth_mode():
"""Check if we're running in no-auth mode (OAuth credentials missing)"""
result = not (GOOGLE_OAUTH_CLIENT_ID and GOOGLE_OAUTH_CLIENT_SECRET)
logger.debug("Checking auth mode", no_auth_mode=result, has_client_id=GOOGLE_OAUTH_CLIENT_ID is not None, has_client_secret=GOOGLE_OAUTH_CLIENT_SECRET is not None)
logger.debug(
"Checking auth mode",
no_auth_mode=result,
has_client_id=GOOGLE_OAUTH_CLIENT_ID is not None,
has_client_secret=GOOGLE_OAUTH_CLIENT_SECRET is not None,
)
return result
@ -99,7 +104,9 @@ async def generate_langflow_api_key():
return LANGFLOW_KEY
if not LANGFLOW_SUPERUSER or not LANGFLOW_SUPERUSER_PASSWORD:
logger.warning("LANGFLOW_SUPERUSER and LANGFLOW_SUPERUSER_PASSWORD not set, skipping API key generation")
logger.warning(
"LANGFLOW_SUPERUSER and LANGFLOW_SUPERUSER_PASSWORD not set, skipping API key generation"
)
return None
try:
@ -141,11 +148,19 @@ async def generate_langflow_api_key():
raise KeyError("api_key")
LANGFLOW_KEY = api_key
logger.info("Successfully generated Langflow API key", api_key_preview=api_key[:8])
logger.info(
"Successfully generated Langflow API key",
api_key_preview=api_key[:8],
)
return api_key
except (requests.exceptions.RequestException, KeyError) as e:
last_error = e
logger.warning("Attempt to generate Langflow API key failed", attempt=attempt, max_attempts=max_attempts, error=str(e))
logger.warning(
"Attempt to generate Langflow API key failed",
attempt=attempt,
max_attempts=max_attempts,
error=str(e),
)
if attempt < max_attempts:
time.sleep(delay_seconds)
else:
@ -195,7 +210,9 @@ class AppClients:
logger.warning("Failed to initialize Langflow client", error=str(e))
self.langflow_client = None
if self.langflow_client is None:
logger.warning("No Langflow client initialized yet, will attempt later on first use")
logger.warning(
"No Langflow client initialized yet, will attempt later on first use"
)
# Initialize patched OpenAI client
self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI())
@ -218,7 +235,9 @@ class AppClients:
)
logger.info("Langflow client initialized on-demand")
except Exception as e:
logger.error("Failed to initialize Langflow client on-demand", error=str(e))
logger.error(
"Failed to initialize Langflow client on-demand", error=str(e)
)
self.langflow_client = None
return self.langflow_client

View file

@ -321,13 +321,18 @@ class ConnectionManager:
if connection_config.config.get(
"webhook_channel_id"
) or connection_config.config.get("subscription_id"):
logger.info("Webhook subscription already exists", connection_id=connection_id)
logger.info(
"Webhook subscription already exists", connection_id=connection_id
)
return
# Check if webhook URL is configured
webhook_url = connection_config.config.get("webhook_url")
if not webhook_url:
logger.info("No webhook URL configured, skipping subscription setup", connection_id=connection_id)
logger.info(
"No webhook URL configured, skipping subscription setup",
connection_id=connection_id,
)
return
try:
@ -345,10 +350,18 @@ class ConnectionManager:
# Save updated connection config
await self.save_connections()
logger.info("Successfully set up webhook subscription", connection_id=connection_id, subscription_id=subscription_id)
logger.info(
"Successfully set up webhook subscription",
connection_id=connection_id,
subscription_id=subscription_id,
)
except Exception as e:
logger.error("Failed to setup webhook subscription", connection_id=connection_id, error=str(e))
logger.error(
"Failed to setup webhook subscription",
connection_id=connection_id,
error=str(e),
)
# Don't fail the entire connection setup if webhook fails
async def _setup_webhook_for_new_connection(
@ -356,12 +369,18 @@ class ConnectionManager:
):
"""Setup webhook subscription for a newly authenticated connection"""
try:
logger.info("Setting up subscription for newly authenticated connection", connection_id=connection_id)
logger.info(
"Setting up subscription for newly authenticated connection",
connection_id=connection_id,
)
# Create and authenticate connector
connector = self._create_connector(connection_config)
if not await connector.authenticate():
logger.error("Failed to authenticate connector for webhook setup", connection_id=connection_id)
logger.error(
"Failed to authenticate connector for webhook setup",
connection_id=connection_id,
)
return
# Setup subscription
@ -376,8 +395,16 @@ class ConnectionManager:
# Save updated connection config
await self.save_connections()
logger.info("Successfully set up webhook subscription", connection_id=connection_id, subscription_id=subscription_id)
logger.info(
"Successfully set up webhook subscription",
connection_id=connection_id,
subscription_id=subscription_id,
)
except Exception as e:
logger.error("Failed to setup webhook subscription for new connection", connection_id=connection_id, error=str(e))
logger.error(
"Failed to setup webhook subscription for new connection",
connection_id=connection_id,
error=str(e),
)
# Don't fail the connection setup if webhook fails

View file

@ -4,6 +4,11 @@ from typing import Dict, Any, List, Optional
from .base import BaseConnector, ConnectorDocument
from utils.logging_config import get_logger
logger = get_logger(__name__)
from .google_drive import GoogleDriveConnector
from .sharepoint import SharePointConnector
from .onedrive import OneDriveConnector
from .connection_manager import ConnectionManager
logger = get_logger(__name__)
@ -62,7 +67,7 @@ class ConnectorService:
doc_service = DocumentService(session_manager=self.session_manager)
print(f"[DEBUG] Processing connector document with ID: {document.id}")
logger.debug("Processing connector document", document_id=document.id)
# Process using the existing pipeline but with connector document metadata
result = await doc_service.process_file_common(
@ -77,7 +82,7 @@ class ConnectorService:
connector_type=connector_type,
)
print(f"[DEBUG] Document processing result: {result}")
logger.debug("Document processing result", result=result)
# If successfully indexed or already exists, update the indexed documents with connector metadata
if result["status"] in ["indexed", "unchanged"]:
@ -104,7 +109,7 @@ class ConnectorService:
jwt_token: str = None,
):
"""Update indexed chunks with connector-specific metadata"""
print(f"[DEBUG] Looking for chunks with document_id: {document.id}")
logger.debug("Looking for chunks", document_id=document.id)
# Find all chunks for this document
query = {"query": {"term": {"document_id": document.id}}}
@ -117,26 +122,34 @@ class ConnectorService:
try:
response = await opensearch_client.search(index=self.index_name, body=query)
except Exception as e:
print(
f"[ERROR] OpenSearch search failed for connector metadata update: {e}"
logger.error(
"OpenSearch search failed for connector metadata update",
error=str(e),
query=query,
)
print(f"[ERROR] Search query: {query}")
raise
print(f"[DEBUG] Search query: {query}")
print(
f"[DEBUG] Found {len(response['hits']['hits'])} chunks matching document_id: {document.id}"
logger.debug(
"Search query executed",
query=query,
chunks_found=len(response["hits"]["hits"]),
document_id=document.id,
)
# Update each chunk with connector metadata
print(
f"[DEBUG] Updating {len(response['hits']['hits'])} chunks with connector_type: {connector_type}"
logger.debug(
"Updating chunks with connector_type",
chunk_count=len(response["hits"]["hits"]),
connector_type=connector_type,
)
for hit in response["hits"]["hits"]:
chunk_id = hit["_id"]
current_connector_type = hit["_source"].get("connector_type", "unknown")
print(
f"[DEBUG] Chunk {chunk_id}: current connector_type = {current_connector_type}, updating to {connector_type}"
logger.debug(
"Updating chunk connector metadata",
chunk_id=chunk_id,
current_connector_type=current_connector_type,
new_connector_type=connector_type,
)
update_body = {
@ -164,10 +177,14 @@ class ConnectorService:
await opensearch_client.update(
index=self.index_name, id=chunk_id, body=update_body
)
print(f"[DEBUG] Updated chunk {chunk_id} with connector metadata")
logger.debug("Updated chunk with connector metadata", chunk_id=chunk_id)
except Exception as e:
print(f"[ERROR] OpenSearch update failed for chunk {chunk_id}: {e}")
print(f"[ERROR] Update body: {update_body}")
logger.error(
"OpenSearch update failed for chunk",
chunk_id=chunk_id,
error=str(e),
update_body=update_body,
)
raise
def _get_file_extension(self, mimetype: str) -> str:
@ -226,11 +243,11 @@ class ConnectorService:
while True:
# List files from connector with limit
logger.info(
logger.debug(
"Calling list_files", page_size=page_size, page_token=page_token
)
file_list = await connector.list_files(page_token, max_files=page_size)
logger.info(
file_list = await connector.list_files(page_token, limit=page_size)
logger.debug(
"Got files from connector", file_count=len(file_list.get("files", []))
)
files = file_list["files"]

View file

@ -3,11 +3,13 @@ import sys
# Check for TUI flag FIRST, before any heavy imports
if __name__ == "__main__" and len(sys.argv) > 1 and sys.argv[1] == "--tui":
from tui.main import run_tui
run_tui()
sys.exit(0)
# Configure structured logging early
from utils.logging_config import configure_from_env, get_logger
configure_from_env()
logger = get_logger(__name__)
@ -25,6 +27,8 @@ import torch
# Configuration and setup
from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET
from config.settings import is_no_auth_mode
from utils.gpu_detection import detect_gpu_devices
# Services
from services.document_service import DocumentService
@ -56,8 +60,11 @@ from api import (
# Set multiprocessing start method to 'spawn' for CUDA compatibility
multiprocessing.set_start_method("spawn", force=True)
logger.info("CUDA available", cuda_available=torch.cuda.is_available())
logger.info("CUDA version PyTorch was built with", cuda_version=torch.version.cuda)
logger.info(
"CUDA device information",
cuda_available=torch.cuda.is_available(),
cuda_version=torch.version.cuda,
)
async def wait_for_opensearch():
@ -71,7 +78,12 @@ async def wait_for_opensearch():
logger.info("OpenSearch is ready")
return
except Exception as e:
logger.warning("OpenSearch not ready yet", attempt=attempt + 1, max_retries=max_retries, error=str(e))
logger.warning(
"OpenSearch not ready yet",
attempt=attempt + 1,
max_retries=max_retries,
error=str(e),
)
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
else:
@ -93,7 +105,9 @@ async def configure_alerting_security():
# Use admin client (clients.opensearch uses admin credentials)
response = await clients.opensearch.cluster.put_settings(body=alerting_settings)
logger.info("Alerting security settings configured successfully", response=response)
logger.info(
"Alerting security settings configured successfully", response=response
)
except Exception as e:
logger.warning("Failed to configure alerting security settings", error=str(e))
# Don't fail startup if alerting config fails
@ -133,9 +147,14 @@ async def init_index():
await clients.opensearch.indices.create(
index=knowledge_filter_index_name, body=knowledge_filter_index_body
)
logger.info("Created knowledge filters index", index_name=knowledge_filter_index_name)
logger.info(
"Created knowledge filters index", index_name=knowledge_filter_index_name
)
else:
logger.info("Knowledge filters index already exists, skipping creation", index_name=knowledge_filter_index_name)
logger.info(
"Knowledge filters index already exists, skipping creation",
index_name=knowledge_filter_index_name,
)
# Configure alerting plugin security settings
await configure_alerting_security()
@ -190,9 +209,59 @@ async def init_index_when_ready():
logger.info("OpenSearch index initialization completed successfully")
except Exception as e:
logger.error("OpenSearch index initialization failed", error=str(e))
logger.warning("OIDC endpoints will still work, but document operations may fail until OpenSearch is ready")
logger.warning(
"OIDC endpoints will still work, but document operations may fail until OpenSearch is ready"
)
async def ingest_default_documents_when_ready(services):
    """Queue every file under ``./documents`` for ingestion without an owner.

    The ``documents`` folder is resolved against the current working
    directory and walked recursively. All files found are handed to the task
    service as one custom task whose processor is built with every owner
    field set to ``None``.

    Args:
        services: Mapping that must provide the ``"document_service"`` and
            ``"task_service"`` entries used below.

    Returns:
        None. The function returns early when the folder is missing or empty.
        Any exception raised along the way is logged and swallowed.
    """
    try:
        logger.info("Ingesting default documents when ready")

        docs_root = os.path.abspath(os.path.join(os.getcwd(), "documents"))
        if not os.path.isdir(docs_root):
            logger.info(
                "Default documents directory not found; skipping ingestion",
                base_dir=docs_root,
            )
            return

        # Gather every file below the root, in os.walk order; no filtering
        # by name or extension is applied here.
        discovered = []
        for folder, _subdirs, names in os.walk(docs_root):
            discovered.extend(os.path.join(folder, name) for name in names)

        if not discovered:
            logger.info(
                "No default documents found; nothing to ingest",
                base_dir=docs_root,
            )
            return

        # Imported lazily, only once there is something to ingest.
        from models.processors import DocumentFileProcessor

        # All owner fields stay None so the indexed documents carry no owner.
        ownerless_processor = DocumentFileProcessor(
            services["document_service"],
            owner_user_id=None,
            jwt_token=None,
            owner_name=None,
            owner_email=None,
        )

        # NOTE(review): "anonymous" is presumably the user id the task is
        # filed under - confirm against create_custom_task's signature.
        task_service = services["task_service"]
        new_task_id = await task_service.create_custom_task(
            "anonymous", discovered, ownerless_processor
        )
        logger.info(
            "Started default documents ingestion task",
            task_id=new_task_id,
            file_count=len(discovered),
        )
    except Exception as e:
        # Best-effort: a failed default ingestion is only reported, not raised.
        logger.error("Default documents ingestion failed", error=str(e))
async def startup_tasks(services):
    """Run the startup sequence: index initialization, then default ingestion.

    Args:
        services: Service mapping forwarded unchanged to
            ``ingest_default_documents_when_ready``.
    """
    logger.info("Starting startup tasks")

    # The two steps are awaited one after the other, so ingestion is only
    # attempted once init_index() has returned without raising.
    await init_index()
    await ingest_default_documents_when_ready(services)
async def initialize_services():
"""Initialize all services and their dependencies"""
# Generate JWT keys if they don't exist
@ -237,9 +306,14 @@ async def initialize_services():
try:
await connector_service.initialize()
loaded_count = len(connector_service.connection_manager.connections)
logger.info("Loaded persisted connector connections on startup", loaded_count=loaded_count)
logger.info(
"Loaded persisted connector connections on startup",
loaded_count=loaded_count,
)
except Exception as e:
logger.warning("Failed to load persisted connections on startup", error=str(e))
logger.warning(
"Failed to load persisted connections on startup", error=str(e)
)
else:
logger.info("[CONNECTORS] Skipping connection loading in no-auth mode")
@ -639,12 +713,15 @@ async def create_app():
app = Starlette(debug=True, routes=routes)
app.state.services = services # Store services for cleanup
app.state.background_tasks = set()
# Add startup event handler
@app.on_event("startup")
async def startup_event():
# Start index initialization in background to avoid blocking OIDC endpoints
asyncio.create_task(init_index_when_ready())
t1 = asyncio.create_task(startup_tasks(services))
app.state.background_tasks.add(t1)
t1.add_done_callback(app.state.background_tasks.discard)
# Add shutdown event handler
@app.on_event("shutdown")
@ -687,18 +764,30 @@ async def cleanup_subscriptions_proper(services):
for connection in active_connections:
try:
logger.info("Cancelling subscription for connection", connection_id=connection.connection_id)
logger.info(
"Cancelling subscription for connection",
connection_id=connection.connection_id,
)
connector = await connector_service.get_connector(
connection.connection_id
)
if connector:
subscription_id = connection.config.get("webhook_channel_id")
await connector.cleanup_subscription(subscription_id)
logger.info("Cancelled subscription", subscription_id=subscription_id)
logger.info(
"Cancelled subscription", subscription_id=subscription_id
)
except Exception as e:
logger.error("Failed to cancel subscription", connection_id=connection.connection_id, error=str(e))
logger.error(
"Failed to cancel subscription",
connection_id=connection.connection_id,
error=str(e),
)
logger.info("Finished cancelling subscriptions", subscription_count=len(active_connections))
logger.info(
"Finished cancelling subscriptions",
subscription_count=len(active_connections),
)
except Exception as e:
logger.error("Failed to cleanup subscriptions", error=str(e))

View file

@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any, Dict
from .tasks import UploadTask, FileTask
from utils.logging_config import get_logger
logger = get_logger(__name__)
class TaskProcessor(ABC):
@ -211,7 +214,7 @@ class S3FileProcessor(TaskProcessor):
"connector_type": "s3", # S3 uploads
"indexed_time": datetime.datetime.now().isoformat(),
}
# Only set owner fields if owner_user_id is provided (for no-auth mode support)
if self.owner_user_id is not None:
chunk_doc["owner"] = self.owner_user_id
@ -225,10 +228,12 @@ class S3FileProcessor(TaskProcessor):
index=INDEX_NAME, id=chunk_id, body=chunk_doc
)
except Exception as e:
print(
f"[ERROR] OpenSearch indexing failed for S3 chunk {chunk_id}: {e}"
logger.error(
"OpenSearch indexing failed for S3 chunk",
chunk_id=chunk_id,
error=str(e),
chunk_doc=chunk_doc,
)
print(f"[ERROR] Chunk document: {chunk_doc}")
raise
result = {"status": "indexed", "id": slim_doc["id"]}

View file

@ -111,7 +111,10 @@ class ChatService:
# Pass the complete filter expression as a single header to Langflow (only if we have something to send)
if filter_expression:
logger.info("Sending OpenRAG query filter to Langflow", filter_expression=filter_expression)
logger.info(
"Sending OpenRAG query filter to Langflow",
filter_expression=filter_expression,
)
extra_headers["X-LANGFLOW-GLOBAL-VAR-OPENRAG-QUERY-FILTER"] = json.dumps(
filter_expression
)
@ -201,7 +204,11 @@ class ChatService:
return {"error": "User ID is required", "conversations": []}
conversations_dict = get_user_conversations(user_id)
logger.debug("Getting chat history for user", user_id=user_id, conversation_count=len(conversations_dict))
logger.debug(
"Getting chat history for user",
user_id=user_id,
conversation_count=len(conversations_dict),
)
# Convert conversations dict to list format with metadata
conversations = []

View file

@ -196,7 +196,11 @@ class DocumentService:
index=INDEX_NAME, id=chunk_id, body=chunk_doc
)
except Exception as e:
logger.error("OpenSearch indexing failed for chunk", chunk_id=chunk_id, error=str(e))
logger.error(
"OpenSearch indexing failed for chunk",
chunk_id=chunk_id,
error=str(e),
)
logger.error("Chunk document details", chunk_doc=chunk_doc)
raise
return {"status": "indexed", "id": file_hash}
@ -232,7 +236,9 @@ class DocumentService:
try:
exists = await opensearch_client.exists(index=INDEX_NAME, id=file_hash)
except Exception as e:
logger.error("OpenSearch exists check failed", file_hash=file_hash, error=str(e))
logger.error(
"OpenSearch exists check failed", file_hash=file_hash, error=str(e)
)
raise
if exists:
return {"status": "unchanged", "id": file_hash}
@ -372,7 +378,11 @@ class DocumentService:
index=INDEX_NAME, id=chunk_id, body=chunk_doc
)
except Exception as e:
logger.error("OpenSearch indexing failed for batch chunk", chunk_id=chunk_id, error=str(e))
logger.error(
"OpenSearch indexing failed for batch chunk",
chunk_id=chunk_id,
error=str(e),
)
logger.error("Chunk document details", chunk_doc=chunk_doc)
raise
@ -388,9 +398,13 @@ class DocumentService:
from concurrent.futures import BrokenExecutor
if isinstance(e, BrokenExecutor):
logger.error("Process pool broken while processing file", file_path=file_path)
logger.error(
"Process pool broken while processing file", file_path=file_path
)
logger.info("Worker process likely crashed")
logger.info("You should see detailed crash logs above from the worker process")
logger.info(
"You should see detailed crash logs above from the worker process"
)
# Mark pool as broken for potential recreation
self._process_pool_broken = True
@ -399,11 +413,15 @@ class DocumentService:
if self._recreate_process_pool():
logger.info("Process pool successfully recreated")
else:
logger.warning("Failed to recreate process pool - future operations may fail")
logger.warning(
"Failed to recreate process pool - future operations may fail"
)
file_task.error = f"Worker process crashed: {str(e)}"
else:
logger.error("Failed to process file", file_path=file_path, error=str(e))
logger.error(
"Failed to process file", file_path=file_path, error=str(e)
)
file_task.error = str(e)
logger.error("Full traceback available")

View file

@ -195,7 +195,9 @@ class MonitorService:
return monitors
except Exception as e:
logger.error("Error listing monitors for user", user_id=user_id, error=str(e))
logger.error(
"Error listing monitors for user", user_id=user_id, error=str(e)
)
return []
async def list_monitors_for_filter(
@ -236,7 +238,9 @@ class MonitorService:
return monitors
except Exception as e:
logger.error("Error listing monitors for filter", filter_id=filter_id, error=str(e))
logger.error(
"Error listing monitors for filter", filter_id=filter_id, error=str(e)
)
return []
async def _get_or_create_webhook_destination(

View file

@ -138,7 +138,11 @@ class SearchService:
search_body["min_score"] = score_threshold
# Authentication required - DLS will handle document filtering automatically
logger.debug("search_service authentication info", user_id=user_id, has_jwt_token=jwt_token is not None)
logger.debug(
"search_service authentication info",
user_id=user_id,
has_jwt_token=jwt_token is not None,
)
if not user_id:
logger.debug("search_service: user_id is None/empty, returning auth error")
return {"results": [], "error": "Authentication required"}
@ -151,7 +155,9 @@ class SearchService:
try:
results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
except Exception as e:
logger.error("OpenSearch query failed", error=str(e), search_body=search_body)
logger.error(
"OpenSearch query failed", error=str(e), search_body=search_body
)
# Re-raise the exception so the API returns the error to frontend
raise

View file

@ -2,11 +2,14 @@ import asyncio
import uuid
import time
import random
from typing import Dict
from typing import Dict, Optional
from models.tasks import TaskStatus, UploadTask, FileTask
from session_manager import AnonymousUser
from src.utils.gpu_detection import get_worker_count
from utils.logging_config import get_logger
logger = get_logger(__name__)
class TaskService:
@ -104,7 +107,9 @@ class TaskService:
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
print(f"[ERROR] Background upload processor failed for task {task_id}: {e}")
logger.error(
"Background upload processor failed", task_id=task_id, error=str(e)
)
import traceback
traceback.print_exc()
@ -136,7 +141,9 @@ class TaskService:
try:
await processor.process_item(upload_task, item, file_task)
except Exception as e:
print(f"[ERROR] Failed to process item {item}: {e}")
logger.error(
"Failed to process item", item=str(item), error=str(e)
)
import traceback
traceback.print_exc()
@ -157,13 +164,15 @@ class TaskService:
upload_task.updated_at = time.time()
except asyncio.CancelledError:
print(f"[INFO] Background processor for task {task_id} was cancelled")
logger.info("Background processor cancelled", task_id=task_id)
if user_id in self.task_store and task_id in self.task_store[user_id]:
# Task status and pending files already handled by cancel_task()
pass
raise # Re-raise to properly handle cancellation
except Exception as e:
print(f"[ERROR] Background custom processor failed for task {task_id}: {e}")
logger.error(
"Background custom processor failed", task_id=task_id, error=str(e)
)
import traceback
traceback.print_exc()
@ -171,16 +180,29 @@ class TaskService:
self.task_store[user_id][task_id].status = TaskStatus.FAILED
self.task_store[user_id][task_id].updated_at = time.time()
def get_task_status(self, user_id: str, task_id: str) -> dict:
"""Get the status of a specific upload task"""
if (
not task_id
or user_id not in self.task_store
or task_id not in self.task_store[user_id]
):
def get_task_status(self, user_id: str, task_id: str) -> Optional[dict]:
"""Get the status of a specific upload task
Includes fallback to shared tasks stored under the "anonymous" user key
so default system tasks are visible to all users.
"""
if not task_id:
return None
upload_task = self.task_store[user_id][task_id]
# Prefer the caller's user_id; otherwise check shared/anonymous tasks
candidate_user_ids = [user_id, AnonymousUser().user_id]
upload_task = None
for candidate_user_id in candidate_user_ids:
if (
candidate_user_id in self.task_store
and task_id in self.task_store[candidate_user_id]
):
upload_task = self.task_store[candidate_user_id][task_id]
break
if upload_task is None:
return None
file_statuses = {}
for file_path, file_task in upload_task.file_tasks.items():
@ -206,14 +228,21 @@ class TaskService:
}
def get_all_tasks(self, user_id: str) -> list:
"""Get all tasks for a user"""
if user_id not in self.task_store:
return []
"""Get all tasks for a user
tasks = []
for task_id, upload_task in self.task_store[user_id].items():
tasks.append(
{
Returns the union of the user's own tasks and shared default tasks stored
under the "anonymous" user key. User-owned tasks take precedence
if a task_id overlaps.
"""
tasks_by_id = {}
def add_tasks_from_store(store_user_id):
if store_user_id not in self.task_store:
return
for task_id, upload_task in self.task_store[store_user_id].items():
if task_id in tasks_by_id:
continue
tasks_by_id[task_id] = {
"task_id": upload_task.task_id,
"status": upload_task.status.value,
"total_files": upload_task.total_files,
@ -223,18 +252,36 @@ class TaskService:
"created_at": upload_task.created_at,
"updated_at": upload_task.updated_at,
}
)
# Sort by creation time, most recent first
# First, add user-owned tasks; then shared anonymous;
add_tasks_from_store(user_id)
add_tasks_from_store(AnonymousUser().user_id)
tasks = list(tasks_by_id.values())
tasks.sort(key=lambda x: x["created_at"], reverse=True)
return tasks
def cancel_task(self, user_id: str, task_id: str) -> bool:
"""Cancel a task if it exists and is not already completed"""
if user_id not in self.task_store or task_id not in self.task_store[user_id]:
"""Cancel a task if it exists and is not already completed.
Supports cancellation of shared default tasks stored under the anonymous user.
"""
# Check candidate user IDs first, then anonymous to find which user ID the task is mapped to
candidate_user_ids = [user_id, AnonymousUser().user_id]
store_user_id = None
for candidate_user_id in candidate_user_ids:
if (
candidate_user_id in self.task_store
and task_id in self.task_store[candidate_user_id]
):
store_user_id = candidate_user_id
break
if store_user_id is None:
return False
upload_task = self.task_store[user_id][task_id]
upload_task = self.task_store[store_user_id][task_id]
# Can only cancel pending or running tasks
if upload_task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:

View file

@ -6,8 +6,13 @@ from typing import Dict, Optional, Any
from dataclasses import dataclass, asdict
from cryptography.hazmat.primitives import serialization
import os
from utils.logging_config import get_logger
logger = get_logger(__name__)
from utils.logging_config import get_logger
logger = get_logger(__name__)
@dataclass
class User:
"""User information from OAuth provider"""
@ -26,6 +31,19 @@ class User:
if self.last_login is None:
self.last_login = datetime.now()
class AnonymousUser(User):
    """``User`` preset carrying fixed placeholder identity values.

    Takes no constructor arguments: every instance is created with
    ``user_id="anonymous"``, ``email="anonymous@localhost"``,
    ``name="Anonymous User"``, no picture and ``provider="none"``.
    """

    def __init__(self):
        # Delegate to User's initializer with the constant anonymous identity.
        super().__init__(
            user_id="anonymous",
            email="anonymous@localhost",
            name="Anonymous User",
            picture=None,
            provider="none",
        )
class SessionManager:
"""Manages user sessions and JWT tokens"""
@ -80,13 +98,15 @@ class SessionManager:
if response.status_code == 200:
return response.json()
else:
print(
f"Failed to get user info: {response.status_code} {response.text}"
logger.error(
"Failed to get user info",
status_code=response.status_code,
response_text=response.text,
)
return None
except Exception as e:
print(f"Error getting user info: {e}")
logger.error("Error getting user info", error=str(e))
return None
async def create_user_session(
@ -173,19 +193,24 @@ class SessionManager:
"""Get or create OpenSearch client for user with their JWT"""
from config.settings import is_no_auth_mode
print(
f"[DEBUG] get_user_opensearch_client: user_id={user_id}, jwt_token={'None' if jwt_token is None else 'present'}, no_auth_mode={is_no_auth_mode()}"
logger.debug(
"get_user_opensearch_client",
user_id=user_id,
jwt_token_present=(jwt_token is not None),
no_auth_mode=is_no_auth_mode(),
)
# In no-auth mode, create anonymous JWT for OpenSearch DLS
if is_no_auth_mode() and jwt_token is None:
if jwt_token is None and (is_no_auth_mode() or user_id in (None, AnonymousUser().user_id)):
if not hasattr(self, "_anonymous_jwt"):
# Create anonymous JWT token for OpenSearch OIDC
print(f"[DEBUG] Creating anonymous JWT...")
logger.debug("Creating anonymous JWT")
self._anonymous_jwt = self._create_anonymous_jwt()
print(f"[DEBUG] Anonymous JWT created: {self._anonymous_jwt[:50]}...")
logger.debug(
"Anonymous JWT created", jwt_prefix=self._anonymous_jwt[:50]
)
jwt_token = self._anonymous_jwt
print(f"[DEBUG] Using anonymous JWT for OpenSearch")
logger.debug("Using anonymous JWT for OpenSearch")
# Check if we have a cached client for this user
if user_id not in self.user_opensearch_clients:
@ -199,14 +224,5 @@ class SessionManager:
def _create_anonymous_jwt(self) -> str:
    """Create JWT token for anonymous user in no-auth mode.

    Returns:
        Whatever ``self.create_jwt_token`` returns for a freshly constructed
        ``AnonymousUser``.
    """
    # A hand-built ``User(...)`` used to be assigned here and was immediately
    # overwritten; ``AnonymousUser`` is the single source of those values.
    return self.create_jwt_token(AnonymousUser())

View file

@ -1 +1 @@
"""OpenRAG Terminal User Interface package."""
"""OpenRAG Terminal User Interface package."""

View file

@ -3,6 +3,9 @@
import sys
from pathlib import Path
from textual.app import App, ComposeResult
from utils.logging_config import get_logger
logger = get_logger(__name__)
from .screens.welcome import WelcomeScreen
from .screens.config import ConfigScreen
@ -17,10 +20,10 @@ from .widgets.diagnostics_notification import notify_with_diagnostics
class OpenRAGTUI(App):
"""OpenRAG Terminal User Interface application."""
TITLE = "OpenRAG TUI"
SUB_TITLE = "Container Management & Configuration"
CSS = """
Screen {
background: $background;
@ -172,13 +175,13 @@ class OpenRAGTUI(App):
padding: 1;
}
"""
def __init__(self):
    """Initialise the base ``App`` and create the helper objects.

    ``PlatformDetector``, ``ContainerManager`` and ``EnvManager`` are each
    instantiated without arguments and stored on the instance.
    """
    super().__init__()

    # Helpers used by the other methods of this app.
    self.platform_detector = PlatformDetector()
    self.container_manager = ContainerManager()
    self.env_manager = EnvManager()
def on_mount(self) -> None:
"""Initialize the application."""
# Check for runtime availability and show appropriate screen
@ -187,31 +190,33 @@ class OpenRAGTUI(App):
self,
"No container runtime found. Please install Docker or Podman.",
severity="warning",
timeout=10
timeout=10,
)
# Load existing config if available
config_exists = self.env_manager.load_existing_env()
# Start with welcome screen
self.push_screen(WelcomeScreen())
async def action_quit(self) -> None:
    """Handle the quit action by calling ``self.exit()``."""
    self.exit()
def check_runtime_requirements(self) -> tuple[bool, str]:
    """Report whether the container runtime can be used.

    Returns:
        ``(ok, message)``. ``ok`` is false when the container manager reports
        no runtime (the message is then the platform detector's installation
        instructions) or when the runtime type is ``"podman"`` and the
        detector's Podman memory check fails. Otherwise ``ok`` is true.
    """
    manager = self.container_manager

    # Guard: nothing else can be checked without a runtime.
    if not manager.is_available():
        return False, self.platform_detector.get_installation_instructions()

    # Check Podman macOS memory if applicable
    runtime_kind = manager.get_runtime_info().runtime_type.value
    if runtime_kind == "podman":
        memory_ok, _, detail = self.platform_detector.check_podman_macos_memory()
        if not memory_ok:
            return False, f"Podman VM memory insufficient:\n{detail}"

    return True, "Runtime requirements satisfied"
@ -221,10 +226,10 @@ def run_tui():
app = OpenRAGTUI()
app.run()
except KeyboardInterrupt:
print("\nOpenRAG TUI interrupted by user")
logger.info("OpenRAG TUI interrupted by user")
sys.exit(0)
except Exception as e:
print(f"Error running OpenRAG TUI: {e}")
logger.error("Error running OpenRAG TUI", error=str(e))
sys.exit(1)

View file

@ -1 +1 @@
"""TUI managers package."""
"""TUI managers package."""

View file

@ -8,6 +8,9 @@ from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, AsyncIterator
from utils.logging_config import get_logger
logger = get_logger(__name__)
from ..utils.platform import PlatformDetector, RuntimeInfo, RuntimeType
from utils.gpu_detection import detect_gpu_devices
@ -15,6 +18,7 @@ from utils.gpu_detection import detect_gpu_devices
class ServiceStatus(Enum):
"""Container service status."""
UNKNOWN = "unknown"
RUNNING = "running"
STOPPED = "stopped"
@ -27,6 +31,7 @@ class ServiceStatus(Enum):
@dataclass
class ServiceInfo:
"""Container service information."""
name: str
status: ServiceStatus
health: Optional[str] = None
@ -34,7 +39,7 @@ class ServiceInfo:
image: Optional[str] = None
image_digest: Optional[str] = None
created: Optional[str] = None
def __post_init__(self):
if self.ports is None:
self.ports = []
@ -42,7 +47,7 @@ class ServiceInfo:
class ContainerManager:
"""Manages Docker/Podman container lifecycle for OpenRAG."""
def __init__(self, compose_file: Optional[Path] = None):
self.platform_detector = PlatformDetector()
self.runtime_info = self.platform_detector.detect_runtime()
@ -56,138 +61,142 @@ class ContainerManager:
self.use_cpu_compose = not has_gpu
except Exception:
self.use_cpu_compose = True
# Expected services based on compose files
self.expected_services = [
"openrag-backend",
"openrag-frontend",
"openrag-frontend",
"opensearch",
"dashboards",
"langflow"
"langflow",
]
# Map container names to service names
self.container_name_map = {
"openrag-backend": "openrag-backend",
"openrag-frontend": "openrag-frontend",
"os": "opensearch",
"os": "opensearch",
"osdash": "dashboards",
"langflow": "langflow"
"langflow": "langflow",
}
def is_available(self) -> bool:
    """Return True unless the detected runtime type is ``RuntimeType.NONE``."""
    return self.runtime_info.runtime_type != RuntimeType.NONE
def get_runtime_info(self) -> RuntimeInfo:
    """Return the ``runtime_info`` object stored on this manager."""
    return self.runtime_info
def get_installation_help(self) -> str:
    """Return the platform detector's installation instructions.

    Intended for the case where no container runtime is available.
    """
    return self.platform_detector.get_installation_instructions()
async def _run_compose_command(self, args: List[str], cpu_mode: Optional[bool] = None) -> tuple[bool, str, str]:
async def _run_compose_command(
self, args: List[str], cpu_mode: Optional[bool] = None
) -> tuple[bool, str, str]:
"""Run a compose command and return (success, stdout, stderr)."""
if not self.is_available():
return False, "", "No container runtime available"
if cpu_mode is None:
cpu_mode = self.use_cpu_compose
compose_file = self.cpu_compose_file if cpu_mode else self.compose_file
cmd = self.runtime_info.compose_command + ["-f", str(compose_file)] + args
try:
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=Path.cwd()
cwd=Path.cwd(),
)
stdout, stderr = await process.communicate()
stdout_text = stdout.decode() if stdout else ""
stderr_text = stderr.decode() if stderr else ""
success = process.returncode == 0
return success, stdout_text, stderr_text
except Exception as e:
return False, "", f"Command execution failed: {e}"
async def _run_compose_command_streaming(self, args: List[str], cpu_mode: Optional[bool] = None) -> AsyncIterator[str]:
async def _run_compose_command_streaming(
self, args: List[str], cpu_mode: Optional[bool] = None
) -> AsyncIterator[str]:
"""Run a compose command and yield output lines in real-time."""
if not self.is_available():
yield "No container runtime available"
return
if cpu_mode is None:
cpu_mode = self.use_cpu_compose
compose_file = self.cpu_compose_file if cpu_mode else self.compose_file
cmd = self.runtime_info.compose_command + ["-f", str(compose_file)] + args
try:
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT, # Combine stderr with stdout for unified output
cwd=Path.cwd()
cwd=Path.cwd(),
)
# Simple approach: read line by line and yield each one
while True:
line = await process.stdout.readline()
if not line:
break
line_text = line.decode().rstrip()
if line_text:
yield line_text
# Wait for process to complete
await process.wait()
except Exception as e:
yield f"Command execution failed: {e}"
async def _run_runtime_command(self, args: List[str]) -> tuple[bool, str, str]:
"""Run a runtime command (docker/podman) and return (success, stdout, stderr)."""
if not self.is_available():
return False, "", "No container runtime available"
cmd = self.runtime_info.runtime_command + args
try:
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
stdout_text = stdout.decode() if stdout else ""
stderr_text = stderr.decode() if stderr else ""
success = process.returncode == 0
return success, stdout_text, stderr_text
except Exception as e:
return False, "", f"Command execution failed: {e}"
def _process_service_json(self, service: Dict, services: Dict[str, ServiceInfo]) -> None:
def _process_service_json(
self, service: Dict, services: Dict[str, ServiceInfo]
) -> None:
"""Process a service JSON object and add it to the services dict."""
# Debug print to see the actual service data
print(f"DEBUG: Processing service data: {json.dumps(service, indent=2)}")
logger.debug("Processing service data", service_data=service)
container_name = service.get("Name", "")
# Map container name to service name
service_name = self.container_name_map.get(container_name)
if not service_name:
return
state = service.get("State", "").lower()
# Map compose states to our status enum
if "running" in state:
status = ServiceStatus.RUNNING
@ -197,17 +206,19 @@ class ContainerManager:
status = ServiceStatus.STARTING
else:
status = ServiceStatus.UNKNOWN
# Extract health - use Status if Health is empty
health = service.get("Health", "") or service.get("Status", "N/A")
# Extract ports
ports_str = service.get("Ports", "")
ports = [p.strip() for p in ports_str.split(",") if p.strip()] if ports_str else []
ports = (
[p.strip() for p in ports_str.split(",") if p.strip()] if ports_str else []
)
# Extract image
image = service.get("Image", "N/A")
services[service_name] = ServiceInfo(
name=service_name,
status=status,
@ -215,23 +226,25 @@ class ContainerManager:
ports=ports,
image=image,
)
async def get_service_status(self, force_refresh: bool = False) -> Dict[str, ServiceInfo]:
async def get_service_status(
self, force_refresh: bool = False
) -> Dict[str, ServiceInfo]:
"""Get current status of all services."""
current_time = time.time()
# Use cache if recent and not forcing refresh
if not force_refresh and current_time - self.last_status_update < 5:
return self.services_cache
services = {}
# Different approach for Podman vs Docker
if self.runtime_info.runtime_type == RuntimeType.PODMAN:
# For Podman, use direct podman ps command instead of compose
cmd = ["ps", "--all", "--format", "json"]
success, stdout, stderr = await self._run_runtime_command(cmd)
if success and stdout.strip():
try:
containers = json.loads(stdout.strip())
@ -240,12 +253,12 @@ class ContainerManager:
names = container.get("Names", [])
if not names:
continue
container_name = names[0]
service_name = self.container_name_map.get(container_name)
if not service_name:
continue
# Get container state
state = container.get("State", "").lower()
if "running" in state:
@ -256,7 +269,7 @@ class ContainerManager:
status = ServiceStatus.STARTING
else:
status = ServiceStatus.UNKNOWN
# Get other container info
image = container.get("Image", "N/A")
ports = []
@ -268,7 +281,7 @@ class ContainerManager:
container_port = port.get("container_port")
if host_port and container_port:
ports.append(f"{host_port}:{container_port}")
services[service_name] = ServiceInfo(
name=service_name,
status=status,
@ -280,55 +293,63 @@ class ContainerManager:
pass
else:
# For Docker, use compose ps command
success, stdout, stderr = await self._run_compose_command(["ps", "--format", "json"])
success, stdout, stderr = await self._run_compose_command(
["ps", "--format", "json"]
)
if success and stdout.strip():
try:
# Handle both single JSON object (Podman) and multiple JSON objects (Docker)
if stdout.strip().startswith('[') and stdout.strip().endswith(']'):
if stdout.strip().startswith("[") and stdout.strip().endswith("]"):
# JSON array format
service_list = json.loads(stdout.strip())
for service in service_list:
self._process_service_json(service, services)
else:
# Line-by-line JSON format
for line in stdout.strip().split('\n'):
if line.strip() and line.startswith('{'):
for line in stdout.strip().split("\n"):
if line.strip() and line.startswith("{"):
service = json.loads(line)
self._process_service_json(service, services)
except json.JSONDecodeError:
# Fallback to parsing text output
lines = stdout.strip().split('\n')
if len(lines) > 1: # Make sure we have at least a header and one line
lines = stdout.strip().split("\n")
if (
len(lines) > 1
): # Make sure we have at least a header and one line
for line in lines[1:]: # Skip header
if line.strip():
parts = line.split()
if len(parts) >= 3:
name = parts[0]
# Only include our expected services
if name not in self.expected_services:
continue
state = parts[2].lower()
if "up" in state:
status = ServiceStatus.RUNNING
elif "exit" in state:
status = ServiceStatus.STOPPED
else:
status = ServiceStatus.UNKNOWN
services[name] = ServiceInfo(name=name, status=status)
services[name] = ServiceInfo(
name=name, status=status
)
# Add expected services that weren't found
for expected in self.expected_services:
if expected not in services:
services[expected] = ServiceInfo(name=expected, status=ServiceStatus.MISSING)
services[expected] = ServiceInfo(
name=expected, status=ServiceStatus.MISSING
)
self.services_cache = services
self.last_status_update = current_time
return services
async def get_images_digests(self, images: List[str]) -> Dict[str, str]:
@ -337,9 +358,9 @@ class ContainerManager:
for image in images:
if not image or image in digests:
continue
success, stdout, _ = await self._run_runtime_command([
"image", "inspect", image, "--format", "{{.Id}}"
])
success, stdout, _ = await self._run_runtime_command(
["image", "inspect", image, "--format", "{{.Id}}"]
)
if success and stdout.strip():
digests[image] = stdout.strip().splitlines()[0]
return digests
@ -353,13 +374,15 @@ class ContainerManager:
continue
for line in compose.read_text().splitlines():
line = line.strip()
if not line or line.startswith('#'):
if not line or line.startswith("#"):
continue
if line.startswith('image:'):
if line.startswith("image:"):
# image: repo/name:tag
val = line.split(':', 1)[1].strip()
val = line.split(":", 1)[1].strip()
# Remove quotes if present
if (val.startswith('"') and val.endswith('"')) or (val.startswith("'") and val.endswith("'")):
if (val.startswith('"') and val.endswith('"')) or (
val.startswith("'") and val.endswith("'")
):
val = val[1:-1]
images.add(val)
except Exception:
@ -374,53 +397,61 @@ class ContainerManager:
expected = self._parse_compose_images()
results: list[tuple[str, str]] = []
for image in expected:
digest = '-'
success, stdout, _ = await self._run_runtime_command([
'image', 'inspect', image, '--format', '{{.Id}}'
])
digest = "-"
success, stdout, _ = await self._run_runtime_command(
["image", "inspect", image, "--format", "{{.Id}}"]
)
if success and stdout.strip():
digest = stdout.strip().splitlines()[0]
results.append((image, digest))
results.sort(key=lambda x: x[0])
return results
async def start_services(
    self, cpu_mode: bool = False
) -> AsyncIterator[tuple[bool, str]]:
    """Bring the services up with ``up -d`` and report progress.

    Args:
        cpu_mode: Forwarded to ``_run_compose_command``.

    Yields:
        ``(done, message)`` pairs: first a progress notice with
        ``done=False``, then either ``(True, ...)`` on success or
        ``(False, ...)`` carrying the command's stderr on failure.
    """
    yield False, "Starting OpenRAG services..."

    started, _out, err = await self._run_compose_command(["up", "-d"], cpu_mode)

    if not started:
        yield False, f"Failed to start services: {err}"
        return

    yield True, "Services started successfully"
async def stop_services(self) -> AsyncIterator[tuple[bool, str]]:
    """Take the services down with ``down`` and report progress.

    Yields:
        ``(done, message)`` pairs: first a progress notice with
        ``done=False``, then either ``(True, ...)`` on success or
        ``(False, ...)`` carrying the command's stderr on failure.
    """
    yield False, "Stopping OpenRAG services..."

    stopped, _out, err = await self._run_compose_command(["down"])

    if not stopped:
        yield False, f"Failed to stop services: {err}"
        return

    yield True, "Services stopped successfully"
async def restart_services(
    self, cpu_mode: bool = False
) -> AsyncIterator[tuple[bool, str]]:
    """Restart the services with ``restart`` and report progress.

    Args:
        cpu_mode: Forwarded to ``_run_compose_command``.

    Yields:
        ``(done, message)`` pairs: first a progress notice with
        ``done=False``, then either ``(True, ...)`` on success or
        ``(False, ...)`` carrying the command's stderr on failure.
    """
    yield False, "Restarting OpenRAG services..."

    restarted, _out, err = await self._run_compose_command(["restart"], cpu_mode)

    if not restarted:
        yield False, f"Failed to restart services: {err}"
        return

    yield True, "Services restarted successfully"
async def upgrade_services(self, cpu_mode: bool = False) -> AsyncIterator[tuple[bool, str]]:
async def upgrade_services(
self, cpu_mode: bool = False
) -> AsyncIterator[tuple[bool, str]]:
"""Upgrade services (pull latest images and restart) and yield progress updates."""
yield False, "Pulling latest images..."
# Pull latest images with streaming output
pull_success = True
async for line in self._run_compose_command_streaming(["pull"], cpu_mode):
@ -428,75 +459,89 @@ class ContainerManager:
# Check for error patterns in the output
if "error" in line.lower() or "failed" in line.lower():
pull_success = False
if not pull_success:
yield False, "Failed to pull some images, but continuing with restart..."
yield False, "Images updated, restarting services..."
# Restart with new images using streaming output
restart_success = True
async for line in self._run_compose_command_streaming(["up", "-d", "--force-recreate"], cpu_mode):
async for line in self._run_compose_command_streaming(
["up", "-d", "--force-recreate"], cpu_mode
):
yield False, line
# Check for error patterns in the output
if "error" in line.lower() or "failed" in line.lower():
restart_success = False
if restart_success:
yield True, "Services upgraded and restarted successfully"
else:
yield False, "Some errors occurred during service restart"
async def reset_services(self) -> AsyncIterator[tuple[bool, str]]:
    """Reset all services (stop, remove containers/volumes, clear data) and yield progress updates.

    Yields:
        ``(done, message)`` pairs. ``done`` is true only on the final
        completion message. If ``compose down`` fails, a failure message is
        yielded and the generator stops. If the extra prune step fails, a
        warning is yielded before the completion message, matching how
        ``upgrade_services`` reports non-fatal problems.
    """
    yield False, "Stopping all services..."

    # Stop and remove everything
    success, _stdout, stderr = await self._run_compose_command(
        ["down", "--volumes", "--remove-orphans", "--rmi", "local"]
    )

    if not success:
        yield False, f"Failed to stop services: {stderr}"
        return

    yield False, "Cleaning up container data..."

    # Additional cleanup - remove any remaining containers/volumes
    # This is more thorough than just compose down
    # NOTE(review): `system prune -f` is not scoped to this compose project;
    # confirm that pruning other unused runtime resources is intended.
    pruned, _prune_out, prune_err = await self._run_runtime_command(
        ["system", "prune", "-f"]
    )
    if not pruned:
        # Non-fatal: compose down already succeeded, but the caller should
        # learn that the extra cleanup did not run.
        yield False, f"Warning: additional cleanup failed: {prune_err}"

    yield (
        True,
        "System reset completed - all containers, volumes, and local images removed",
    )
async def get_service_logs(
    self, service_name: str, lines: int = 100
) -> tuple[bool, str]:
    """Fetch the most recent log lines of one service through compose.

    Args:
        service_name: Name of the service passed to ``logs``.
        lines: Value passed to ``--tail``.

    Returns:
        ``(True, stdout)`` when the compose command succeeds, otherwise
        ``(False, "Failed to get logs: <stderr>")``.
    """
    tail_args = ["logs", "--tail", str(lines), service_name]
    ok, output, error_text = await self._run_compose_command(tail_args)
    if not ok:
        return False, f"Failed to get logs: {error_text}"
    return True, output
async def follow_service_logs(self, service_name: str) -> AsyncIterator[str]:
"""Follow logs for a specific service."""
if not self.is_available():
yield "No container runtime available"
return
compose_file = self.cpu_compose_file if self.use_cpu_compose else self.compose_file
cmd = self.runtime_info.compose_command + ["-f", str(compose_file), "logs", "-f", service_name]
compose_file = (
self.cpu_compose_file if self.use_cpu_compose else self.compose_file
)
cmd = self.runtime_info.compose_command + [
"-f",
str(compose_file),
"logs",
"-f",
service_name,
]
try:
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
cwd=Path.cwd()
cwd=Path.cwd(),
)
if process.stdout:
while True:
line = await process.stdout.readline()
@ -506,20 +551,22 @@ class ContainerManager:
break
else:
yield "Error: Unable to read process output"
except Exception as e:
yield f"Error following logs: {e}"
async def get_system_stats(self) -> Dict[str, Dict[str, str]]:
"""Get system resource usage statistics."""
stats = {}
# Get container stats
success, stdout, stderr = await self._run_runtime_command(["stats", "--no-stream", "--format", "json"])
success, stdout, stderr = await self._run_runtime_command(
["stats", "--no-stream", "--format", "json"]
)
if success and stdout.strip():
try:
for line in stdout.strip().split('\n'):
for line in stdout.strip().split("\n"):
if line.strip():
data = json.loads(line)
name = data.get("Name", data.get("Container", ""))
@ -533,14 +580,14 @@ class ContainerManager:
}
except json.JSONDecodeError:
pass
return stats
async def debug_podman_services(self) -> str:
"""Run a direct Podman command to check services status for debugging."""
if self.runtime_info.runtime_type != RuntimeType.PODMAN:
return "Not using Podman"
# Try direct podman command
cmd = ["podman", "ps", "--all", "--format", "json"]
try:
@ -548,18 +595,18 @@ class ContainerManager:
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=Path.cwd()
cwd=Path.cwd(),
)
stdout, stderr = await process.communicate()
stdout_text = stdout.decode() if stdout else ""
stderr_text = stderr.decode() if stderr else ""
result = f"Command: {' '.join(cmd)}\n"
result += f"Return code: {process.returncode}\n"
result += f"Stdout: {stdout_text}\n"
result += f"Stderr: {stderr_text}\n"
# Try to parse the output
if stdout_text.strip():
try:
@ -571,16 +618,18 @@ class ContainerManager:
result += f" - {name}: {state}\n"
except json.JSONDecodeError as e:
result += f"\nFailed to parse JSON: {e}\n"
return result
except Exception as e:
return f"Error executing command: {e}"
def check_podman_macos_memory(self) -> tuple[bool, str]:
    """Report whether the Podman VM has sufficient memory on macOS.

    Returns:
        ``(True, "Not using Podman")`` when the detected runtime is not
        Podman; otherwise the sufficiency flag and message produced by
        ``self.platform_detector.check_podman_macos_memory()``.
    """
    if self.runtime_info.runtime_type != RuntimeType.PODMAN:
        return True, "Not using Podman"

    # The detector returns a 3-tuple; the middle element (memory in MB) is
    # not part of this method's result.
    sufficient, _memory_mb, detail = (
        self.platform_detector.check_podman_macos_memory()
    )
    return sufficient, detail

View file

@ -7,6 +7,9 @@ from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, List
from dataclasses import dataclass, field
from utils.logging_config import get_logger
logger = get_logger(__name__)
from ..utils.validation import (
validate_openai_api_key,
@ -14,13 +17,14 @@ from ..utils.validation import (
validate_non_empty,
validate_url,
validate_documents_paths,
sanitize_env_value
sanitize_env_value,
)
@dataclass
class EnvConfig:
"""Environment configuration data."""
# Core settings
openai_api_key: str = ""
opensearch_password: str = ""
@ -28,155 +32,186 @@ class EnvConfig:
langflow_superuser: str = "admin"
langflow_superuser_password: str = ""
flow_id: str = "1098eea1-6649-4e1d-aed1-b77249fb8dd0"
# OAuth settings
google_oauth_client_id: str = ""
google_oauth_client_secret: str = ""
microsoft_graph_oauth_client_id: str = ""
microsoft_graph_oauth_client_secret: str = ""
# Optional settings
webhook_base_url: str = ""
aws_access_key_id: str = ""
aws_secret_access_key: str = ""
langflow_public_url: str = ""
# Langflow auth settings
langflow_auto_login: str = "False"
langflow_new_user_is_active: str = "False"
langflow_enable_superuser_cli: str = "False"
# Document paths (comma-separated)
openrag_documents_paths: str = "./documents"
# Validation errors
validation_errors: Dict[str, str] = field(default_factory=dict)
class EnvManager:
"""Manages environment configuration for OpenRAG."""
def __init__(self, env_file: Optional[Path] = None):
    """Bind the manager to an env file and start from a default config.

    Args:
        env_file: Path of the env file to manage.  A falsy value (the
            default ``None``) selects ``Path(".env")``.
    """
    self.config = EnvConfig()
    self.env_file = env_file if env_file else Path(".env")
def generate_secure_password(self) -> str:
    """Generate a random 16-character password.

    The password is drawn from ASCII letters, digits and ``!@#$%^&*`` using
    the ``secrets`` module, and is guaranteed to contain at least one
    lowercase letter, one uppercase letter, one digit and one symbol.

    Sixteen independent draws alone can miss a whole character class, and a
    password policy that demands every class would then reject the generated
    value.
    NOTE(review): OpenSearch's admin-password strength check is the policy in
    mind here - confirm against the deployed OpenSearch version.

    Returns:
        The generated password.
    """
    symbols = "!@#$%^&*"
    required_classes = (
        string.ascii_lowercase,
        string.ascii_uppercase,
        string.digits,
        symbols,
    )
    alphabet = string.ascii_letters + string.digits + symbols

    # Rejection sampling: redraw until every class is represented, so the
    # result stays uniformly distributed over all acceptable passwords.
    while True:
        password = "".join(secrets.choice(alphabet) for _ in range(16))
        if all(any(ch in group for ch in password) for group in required_classes):
            return password
def generate_langflow_secret_key(self) -> str:
    """Return a fresh URL-safe random token to use as the Langflow secret key."""
    # Number of random bytes behind the token; the text form is longer
    # because it is Base64-encoded.
    entropy_bytes = 32
    return secrets.token_urlsafe(entropy_bytes)
def load_existing_env(self) -> bool:
    """Load settings from the env file into ``self.config``.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    skipped.  Each remaining line is split on the first ``=``; when the
    stripped key is one of the recognised variables, its value is passed
    through ``sanitize_env_value`` and stored on the matching config
    attribute.  Unrecognised keys are ignored.

    Returns:
        True if the file exists and was read without error; False if the
        file does not exist or an exception occurred (the error is logged).
    """
    if not self.env_file.exists():
        return False

    # Map env vars to config attributes.  Built once per call instead of
    # once per line of the file.
    attr_map = {
        "OPENAI_API_KEY": "openai_api_key",
        "OPENSEARCH_PASSWORD": "opensearch_password",
        "LANGFLOW_SECRET_KEY": "langflow_secret_key",
        "LANGFLOW_SUPERUSER": "langflow_superuser",
        "LANGFLOW_SUPERUSER_PASSWORD": "langflow_superuser_password",
        "FLOW_ID": "flow_id",
        "GOOGLE_OAUTH_CLIENT_ID": "google_oauth_client_id",
        "GOOGLE_OAUTH_CLIENT_SECRET": "google_oauth_client_secret",
        "MICROSOFT_GRAPH_OAUTH_CLIENT_ID": "microsoft_graph_oauth_client_id",
        "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET": "microsoft_graph_oauth_client_secret",
        "WEBHOOK_BASE_URL": "webhook_base_url",
        "AWS_ACCESS_KEY_ID": "aws_access_key_id",
        "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
        "LANGFLOW_PUBLIC_URL": "langflow_public_url",
        "OPENRAG_DOCUMENTS_PATHS": "openrag_documents_paths",
        "LANGFLOW_AUTO_LOGIN": "langflow_auto_login",
        "LANGFLOW_NEW_USER_IS_ACTIVE": "langflow_new_user_is_active",
        "LANGFLOW_ENABLE_SUPERUSER_CLI": "langflow_enable_superuser_cli",
    }

    try:
        with open(self.env_file, "r") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#"):
                    continue

                # Split on the first "=" only, so values may contain "=".
                key, separator, value = line.partition("=")
                if not separator:
                    continue

                attr_name = attr_map.get(key.strip())
                if attr_name is not None:
                    setattr(self.config, attr_name, sanitize_env_value(value))

        return True

    except Exception as e:
        logger.error("Error loading .env file", error=str(e))
        return False
def setup_secure_defaults(self) -> None:
    """Fill in any empty password or secret-key fields on ``self.config``.

    Fields that already hold a non-empty value are left untouched.
    """
    cfg = self.config
    generators = (
        ("opensearch_password", self.generate_secure_password),
        ("langflow_secret_key", self.generate_langflow_secret_key),
        ("langflow_superuser_password", self.generate_secure_password),
    )
    for attr_name, make_value in generators:
        if not getattr(cfg, attr_name):
            setattr(cfg, attr_name, make_value())
def validate_config(self, mode: str = "full") -> bool:
    """Check the current configuration and record any problems found.

    ``self.config.validation_errors`` is cleared first and then filled with
    one message per failing field, keyed by field name.

    Args:
        mode: ``"full"`` additionally validates the OAuth settings and the
            optional URLs; any other value (e.g. ``"no_auth"``) runs only
            the core checks.

    Returns:
        True when no validation errors were recorded.
    """
    cfg = self.config
    errors = cfg.validation_errors
    errors.clear()

    # The OpenAI API key is validated in every mode.
    if not validate_openai_api_key(cfg.openai_api_key):
        errors["openai_api_key"] = (
            "Invalid OpenAI API key format (should start with sk-)"
        )

    # Documents paths are optional: only validated when a value is present.
    if cfg.openrag_documents_paths:
        paths_ok, paths_error, _ = validate_documents_paths(
            cfg.openrag_documents_paths
        )
        if not paths_ok:
            errors["openrag_documents_paths"] = paths_error

    # Required fields.  The Langflow secret key is auto-generated, so it is
    # not checked here.
    required_fields = (
        ("opensearch_password", "OpenSearch password is required"),
        (
            "langflow_superuser_password",
            "Langflow superuser password is required",
        ),
    )
    for field_name, message in required_fields:
        if not validate_non_empty(getattr(cfg, field_name)):
            errors[field_name] = message

    if mode == "full":
        # OAuth settings are only checked when a client ID was supplied.
        google_client_id = cfg.google_oauth_client_id
        if google_client_id:
            if not validate_google_oauth_client_id(google_client_id):
                errors["google_oauth_client_id"] = (
                    "Invalid Google OAuth client ID format"
                )
            if not validate_non_empty(cfg.google_oauth_client_secret):
                errors["google_oauth_client_secret"] = (
                    "Google OAuth client secret required when client ID is provided"
                )

        if cfg.microsoft_graph_oauth_client_id and not validate_non_empty(
            cfg.microsoft_graph_oauth_client_secret
        ):
            errors["microsoft_graph_oauth_client_secret"] = (
                "Microsoft Graph client secret required when client ID is provided"
            )

        # Optional URLs are validated only when provided.
        url_fields = (
            ("webhook_base_url", "Invalid webhook URL format"),
            ("langflow_public_url", "Invalid Langflow public URL format"),
        )
        for field_name, message in url_fields:
            url_value = getattr(cfg, field_name)
            if url_value and not validate_url(url_value):
                errors[field_name] = message

    return not errors
def save_env_file(self) -> bool:
"""Save current configuration to .env file."""
try:
@ -184,45 +219,67 @@ class EnvManager:
self.setup_secure_defaults()
# Create timestamped backup if file exists
if self.env_file.exists():
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
backup_file = self.env_file.with_suffix(f'.env.backup.{timestamp}')
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_file = self.env_file.with_suffix(f".env.backup.{timestamp}")
self.env_file.rename(backup_file)
with open(self.env_file, 'w') as f:
with open(self.env_file, "w") as f:
f.write("# OpenRAG Environment Configuration\n")
f.write("# Generated by OpenRAG TUI\n\n")
# Core settings
f.write("# Core settings\n")
f.write(f"LANGFLOW_SECRET_KEY={self.config.langflow_secret_key}\n")
f.write(f"LANGFLOW_SUPERUSER={self.config.langflow_superuser}\n")
f.write(f"LANGFLOW_SUPERUSER_PASSWORD={self.config.langflow_superuser_password}\n")
f.write(
f"LANGFLOW_SUPERUSER_PASSWORD={self.config.langflow_superuser_password}\n"
)
f.write(f"FLOW_ID={self.config.flow_id}\n")
f.write(f"OPENSEARCH_PASSWORD={self.config.opensearch_password}\n")
f.write(f"OPENAI_API_KEY={self.config.openai_api_key}\n")
f.write(f"OPENRAG_DOCUMENTS_PATHS={self.config.openrag_documents_paths}\n")
f.write(
f"OPENRAG_DOCUMENTS_PATHS={self.config.openrag_documents_paths}\n"
)
f.write("\n")
# Langflow auth settings
f.write("# Langflow auth settings\n")
f.write(f"LANGFLOW_AUTO_LOGIN={self.config.langflow_auto_login}\n")
f.write(f"LANGFLOW_NEW_USER_IS_ACTIVE={self.config.langflow_new_user_is_active}\n")
f.write(f"LANGFLOW_ENABLE_SUPERUSER_CLI={self.config.langflow_enable_superuser_cli}\n")
f.write(
f"LANGFLOW_NEW_USER_IS_ACTIVE={self.config.langflow_new_user_is_active}\n"
)
f.write(
f"LANGFLOW_ENABLE_SUPERUSER_CLI={self.config.langflow_enable_superuser_cli}\n"
)
f.write("\n")
# OAuth settings
if self.config.google_oauth_client_id or self.config.google_oauth_client_secret:
if (
self.config.google_oauth_client_id
or self.config.google_oauth_client_secret
):
f.write("# Google OAuth settings\n")
f.write(f"GOOGLE_OAUTH_CLIENT_ID={self.config.google_oauth_client_id}\n")
f.write(f"GOOGLE_OAUTH_CLIENT_SECRET={self.config.google_oauth_client_secret}\n")
f.write(
f"GOOGLE_OAUTH_CLIENT_ID={self.config.google_oauth_client_id}\n"
)
f.write(
f"GOOGLE_OAUTH_CLIENT_SECRET={self.config.google_oauth_client_secret}\n"
)
f.write("\n")
if self.config.microsoft_graph_oauth_client_id or self.config.microsoft_graph_oauth_client_secret:
if (
self.config.microsoft_graph_oauth_client_id
or self.config.microsoft_graph_oauth_client_secret
):
f.write("# Microsoft Graph OAuth settings\n")
f.write(f"MICROSOFT_GRAPH_OAUTH_CLIENT_ID={self.config.microsoft_graph_oauth_client_id}\n")
f.write(f"MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET={self.config.microsoft_graph_oauth_client_secret}\n")
f.write(
f"MICROSOFT_GRAPH_OAUTH_CLIENT_ID={self.config.microsoft_graph_oauth_client_id}\n"
)
f.write(
f"MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET={self.config.microsoft_graph_oauth_client_secret}\n"
)
f.write("\n")
# Optional settings
optional_vars = [
("WEBHOOK_BASE_URL", self.config.webhook_base_url),
@ -230,7 +287,7 @@ class EnvManager:
("AWS_SECRET_ACCESS_KEY", self.config.aws_secret_access_key),
("LANGFLOW_PUBLIC_URL", self.config.langflow_public_url),
]
optional_written = False
for var_name, var_value in optional_vars:
if var_value:
@ -238,52 +295,89 @@ class EnvManager:
f.write("# Optional settings\n")
optional_written = True
f.write(f"{var_name}={var_value}\n")
if optional_written:
f.write("\n")
return True
except Exception as e:
print(f"Error saving .env file: {e}")
logger.error("Error saving .env file", error=str(e))
return False
def get_no_auth_setup_fields(self) -> List[tuple[str, str, str, bool]]:
    """Describe the fields needed for no-auth setup mode.

    Returns:
        A list of ``(field_name, display_name, placeholder, can_generate)``
        tuples, in display order.
    """
    auto_generated_hint = "Will be auto-generated if empty"

    fields = []
    fields.append(("openai_api_key", "OpenAI API Key", "sk-...", False))
    fields.append(
        ("opensearch_password", "OpenSearch Password", auto_generated_hint, True)
    )
    fields.append(
        (
            "langflow_superuser_password",
            "Langflow Superuser Password",
            auto_generated_hint,
            True,
        )
    )
    fields.append(
        (
            "openrag_documents_paths",
            "Documents Paths",
            "./documents,/path/to/more/docs",
            False,
        )
    )
    return fields
def get_full_setup_fields(self) -> List[tuple[str, str, str, bool]]:
    """Describe every field shown in full setup mode.

    Returns:
        The no-auth fields followed by the OAuth fields and then the
        optional fields, each as a
        ``(field_name, display_name, placeholder, can_generate)`` tuple.
    """
    oauth_section = [
        (
            "google_oauth_client_id",
            "Google OAuth Client ID",
            "xxx.apps.googleusercontent.com",
            False,
        ),
        ("google_oauth_client_secret", "Google OAuth Client Secret", "", False),
        ("microsoft_graph_oauth_client_id", "Microsoft Graph Client ID", "", False),
        (
            "microsoft_graph_oauth_client_secret",
            "Microsoft Graph Client Secret",
            "",
            False,
        ),
    ]

    optional_section = [
        (
            "webhook_base_url",
            "Webhook Base URL (optional)",
            "https://your-domain.com",
            False,
        ),
        ("aws_access_key_id", "AWS Access Key ID (optional)", "", False),
        ("aws_secret_access_key", "AWS Secret Access Key (optional)", "", False),
        (
            "langflow_public_url",
            "Langflow Public URL (optional)",
            "http://localhost:7860",
            False,
        ),
    ]

    return [*self.get_no_auth_setup_fields(), *oauth_section, *optional_section]
def generate_compose_volume_mounts(self) -> List[str]:
"""Generate Docker Compose volume mount strings from documents paths."""
is_valid, _, validated_paths = validate_documents_paths(self.config.openrag_documents_paths)
is_valid, _, validated_paths = validate_documents_paths(
self.config.openrag_documents_paths
)
if not is_valid:
return ["./documents:/app/documents:Z"] # fallback
volume_mounts = []
for i, path in enumerate(validated_paths):
if i == 0:
@ -291,6 +385,6 @@ class EnvManager:
volume_mounts.append(f"{path}:/app/documents:Z")
else:
# Additional paths map to numbered directories
volume_mounts.append(f"{path}:/app/documents{i+1}:Z")
volume_mounts.append(f"{path}:/app/documents{i + 1}:Z")
return volume_mounts

View file

@ -1 +1 @@
"""TUI screens package."""
"""TUI screens package."""

View file

@ -3,7 +3,16 @@
from textual.app import ComposeResult
from textual.containers import Container, Vertical, Horizontal, ScrollableContainer
from textual.screen import Screen
from textual.widgets import Header, Footer, Static, Button, Input, Label, TabbedContent, TabPane
from textual.widgets import (
Header,
Footer,
Static,
Button,
Input,
Label,
TabbedContent,
TabPane,
)
from textual.validation import ValidationResult, Validator
from rich.text import Text
from pathlib import Path
@ -15,11 +24,11 @@ from pathlib import Path
class OpenAIKeyValidator(Validator):
"""Validator for OpenAI API keys."""
def validate(self, value: str) -> ValidationResult:
if not value:
return self.success()
if validate_openai_api_key(value):
return self.success()
else:
@ -28,12 +37,12 @@ class OpenAIKeyValidator(Validator):
class DocumentsPathValidator(Validator):
"""Validator for documents paths."""
def validate(self, value: str) -> ValidationResult:
# Optional: allow empty value
if not value:
return self.success()
is_valid, error_msg, _ = validate_documents_paths(value)
if is_valid:
return self.success()
@ -43,22 +52,22 @@ class DocumentsPathValidator(Validator):
class ConfigScreen(Screen):
"""Configuration screen for environment setup."""
BINDINGS = [
("escape", "back", "Back"),
("ctrl+s", "save", "Save"),
("ctrl+g", "generate", "Generate Passwords"),
]
def __init__(self, mode: str = "full"):
    """Create the configuration screen for the given setup mode.

    Args:
        mode: ``"no_auth"`` or ``"full"``.
    """
    super().__init__()
    self.mode = mode  # "no_auth" or "full"
    self.inputs = {}

    # Load existing config if available, so fields start from saved values.
    manager = EnvManager()
    manager.load_existing_env()
    self.env_manager = manager
def compose(self) -> ComposeResult:
"""Create the configuration screen layout."""
# Removed top header bar and header text
@ -70,33 +79,37 @@ class ConfigScreen(Screen):
Button("Generate Passwords", variant="default", id="generate-btn"),
Button("Save Configuration", variant="success", id="save-btn"),
Button("Back", variant="default", id="back-btn"),
classes="button-row"
classes="button-row",
)
yield Footer()
def _create_header_text(self) -> Text:
    """Build the styled introduction text for the current setup mode.

    Returns:
        A ``Text`` holding a mode-specific title and description followed by
        the required-field and Ctrl+G hints.
    """
    if self.mode == "no_auth":
        title = "Quick Setup - No Authentication\n"
        title_style = "bold green"
        description = "Configure OpenRAG for local document processing only.\n\n"
    else:
        title = "Full Setup - OAuth Integration\n"
        title_style = "bold cyan"
        description = "Configure OpenRAG with cloud service integrations.\n\n"

    segments = (
        (title, title_style),
        (description, "dim"),
        ("Required fields are marked with *\n", "yellow"),
        ("Use Ctrl+G to generate admin passwords\n", "dim"),
    )

    banner = Text()
    for content, segment_style in segments:
        banner.append(content, style=segment_style)
    return banner
def _create_all_fields(self) -> ComposeResult:
"""Create all configuration fields in a single scrollable layout."""
# Admin Credentials Section
yield Static("Admin Credentials", classes="tab-header")
yield Static(" ")
# OpenSearch Admin Password
yield Label("OpenSearch Admin Password *")
current_value = getattr(self.env_manager.config, "opensearch_password", "")
@ -104,64 +117,73 @@ class ConfigScreen(Screen):
placeholder="Auto-generated secure password",
value=current_value,
password=True,
id="input-opensearch_password"
id="input-opensearch_password",
)
yield input_widget
self.inputs["opensearch_password"] = input_widget
yield Static(" ")
# Langflow Admin Username
yield Label("Langflow Admin Username *")
current_value = getattr(self.env_manager.config, "langflow_superuser", "")
input_widget = Input(
placeholder="admin",
value=current_value,
id="input-langflow_superuser"
placeholder="admin", value=current_value, id="input-langflow_superuser"
)
yield input_widget
self.inputs["langflow_superuser"] = input_widget
yield Static(" ")
# Langflow Admin Password
yield Label("Langflow Admin Password *")
current_value = getattr(self.env_manager.config, "langflow_superuser_password", "")
current_value = getattr(
self.env_manager.config, "langflow_superuser_password", ""
)
input_widget = Input(
placeholder="Auto-generated secure password",
value=current_value,
password=True,
id="input-langflow_superuser_password"
id="input-langflow_superuser_password",
)
yield input_widget
self.inputs["langflow_superuser_password"] = input_widget
yield Static(" ")
yield Static(" ")
# API Keys Section
yield Static("API Keys", classes="tab-header")
yield Static(" ")
# OpenAI API Key
yield Label("OpenAI API Key *")
# Where to create OpenAI keys (helper above the box)
yield Static(Text("Get a key: https://platform.openai.com/api-keys", style="dim"), classes="helper-text")
yield Static(
Text("Get a key: https://platform.openai.com/api-keys", style="dim"),
classes="helper-text",
)
current_value = getattr(self.env_manager.config, "openai_api_key", "")
input_widget = Input(
placeholder="sk-...",
value=current_value,
password=True,
validators=[OpenAIKeyValidator()],
id="input-openai_api_key"
id="input-openai_api_key",
)
yield input_widget
self.inputs["openai_api_key"] = input_widget
yield Static(" ")
# Add OAuth fields only in full mode
if self.mode == "full":
# Google OAuth Client ID
yield Label("Google OAuth Client ID")
# Where to create Google OAuth credentials (helper above the box)
yield Static(Text("Create credentials: https://console.cloud.google.com/apis/credentials", style="dim"), classes="helper-text")
yield Static(
Text(
"Create credentials: https://console.cloud.google.com/apis/credentials",
style="dim",
),
classes="helper-text",
)
# Callback URL guidance for Google OAuth
yield Static(
Text(
@ -169,37 +191,47 @@ class ConfigScreen(Screen):
" - Local: http://localhost:3000/auth/callback\n"
" - Prod: https://your-domain.com/auth/callback\n"
"If you use separate apps for login and connectors, add this URL to BOTH.",
style="dim"
style="dim",
),
classes="helper-text"
classes="helper-text",
)
current_value = getattr(
self.env_manager.config, "google_oauth_client_id", ""
)
current_value = getattr(self.env_manager.config, "google_oauth_client_id", "")
input_widget = Input(
placeholder="xxx.apps.googleusercontent.com",
value=current_value,
id="input-google_oauth_client_id"
id="input-google_oauth_client_id",
)
yield input_widget
self.inputs["google_oauth_client_id"] = input_widget
yield Static(" ")
# Google OAuth Client Secret
yield Label("Google OAuth Client Secret")
current_value = getattr(self.env_manager.config, "google_oauth_client_secret", "")
current_value = getattr(
self.env_manager.config, "google_oauth_client_secret", ""
)
input_widget = Input(
placeholder="",
value=current_value,
password=True,
id="input-google_oauth_client_secret"
id="input-google_oauth_client_secret",
)
yield input_widget
self.inputs["google_oauth_client_secret"] = input_widget
yield Static(" ")
# Microsoft Graph Client ID
yield Label("Microsoft Graph Client ID")
# Where to create Microsoft app registrations (helper above the box)
yield Static(Text("Create app: https://portal.azure.com/#view/Microsoft_AAD_RegisteredApps/ApplicationsListBlade", style="dim"), classes="helper-text")
yield Static(
Text(
"Create app: https://portal.azure.com/#view/Microsoft_AAD_RegisteredApps/ApplicationsListBlade",
style="dim",
),
classes="helper-text",
)
# Callback URL guidance for Microsoft OAuth
yield Static(
Text(
@ -207,66 +239,76 @@ class ConfigScreen(Screen):
" - Local: http://localhost:3000/auth/callback\n"
" - Prod: https://your-domain.com/auth/callback\n"
"If you use separate apps for login and connectors, add this URI to BOTH.",
style="dim"
style="dim",
),
classes="helper-text"
classes="helper-text",
)
current_value = getattr(
self.env_manager.config, "microsoft_graph_oauth_client_id", ""
)
current_value = getattr(self.env_manager.config, "microsoft_graph_oauth_client_id", "")
input_widget = Input(
placeholder="",
value=current_value,
id="input-microsoft_graph_oauth_client_id"
id="input-microsoft_graph_oauth_client_id",
)
yield input_widget
self.inputs["microsoft_graph_oauth_client_id"] = input_widget
yield Static(" ")
# Microsoft Graph Client Secret
yield Label("Microsoft Graph Client Secret")
current_value = getattr(self.env_manager.config, "microsoft_graph_oauth_client_secret", "")
current_value = getattr(
self.env_manager.config, "microsoft_graph_oauth_client_secret", ""
)
input_widget = Input(
placeholder="",
value=current_value,
password=True,
id="input-microsoft_graph_oauth_client_secret"
id="input-microsoft_graph_oauth_client_secret",
)
yield input_widget
self.inputs["microsoft_graph_oauth_client_secret"] = input_widget
yield Static(" ")
# AWS Access Key ID
yield Label("AWS Access Key ID")
# Where to create AWS keys (helper above the box)
yield Static(Text("Create keys: https://console.aws.amazon.com/iam/home#/security_credentials", style="dim"), classes="helper-text")
yield Static(
Text(
"Create keys: https://console.aws.amazon.com/iam/home#/security_credentials",
style="dim",
),
classes="helper-text",
)
current_value = getattr(self.env_manager.config, "aws_access_key_id", "")
input_widget = Input(
placeholder="",
value=current_value,
id="input-aws_access_key_id"
placeholder="", value=current_value, id="input-aws_access_key_id"
)
yield input_widget
self.inputs["aws_access_key_id"] = input_widget
yield Static(" ")
# AWS Secret Access Key
yield Label("AWS Secret Access Key")
current_value = getattr(self.env_manager.config, "aws_secret_access_key", "")
current_value = getattr(
self.env_manager.config, "aws_secret_access_key", ""
)
input_widget = Input(
placeholder="",
value=current_value,
password=True,
id="input-aws_secret_access_key"
id="input-aws_secret_access_key",
)
yield input_widget
self.inputs["aws_secret_access_key"] = input_widget
yield Static(" ")
yield Static(" ")
# Other Settings Section
yield Static("Others", classes="tab-header")
yield Static(" ")
# Documents Paths (optional) + picker action button on next line
yield Label("Documents Paths")
current_value = getattr(self.env_manager.config, "openrag_documents_paths", "")
@ -274,57 +316,63 @@ class ConfigScreen(Screen):
placeholder="./documents,/path/to/more/docs",
value=current_value,
validators=[DocumentsPathValidator()],
id="input-openrag_documents_paths"
id="input-openrag_documents_paths",
)
yield input_widget
# Actions row with pick button
yield Horizontal(Button("Pick…", id="pick-docs-btn"), id="docs-path-actions", classes="controls-row")
yield Horizontal(
Button("Pick…", id="pick-docs-btn"),
id="docs-path-actions",
classes="controls-row",
)
self.inputs["openrag_documents_paths"] = input_widget
yield Static(" ")
# Langflow Auth Settings
yield Static("Langflow Auth Settings", classes="tab-header")
yield Static(" ")
# Langflow Auto Login
yield Label("Langflow Auto Login")
current_value = getattr(self.env_manager.config, "langflow_auto_login", "False")
input_widget = Input(
placeholder="False",
value=current_value,
id="input-langflow_auto_login"
placeholder="False", value=current_value, id="input-langflow_auto_login"
)
yield input_widget
self.inputs["langflow_auto_login"] = input_widget
yield Static(" ")
# Langflow New User Is Active
yield Label("Langflow New User Is Active")
current_value = getattr(self.env_manager.config, "langflow_new_user_is_active", "False")
current_value = getattr(
self.env_manager.config, "langflow_new_user_is_active", "False"
)
input_widget = Input(
placeholder="False",
value=current_value,
id="input-langflow_new_user_is_active"
id="input-langflow_new_user_is_active",
)
yield input_widget
self.inputs["langflow_new_user_is_active"] = input_widget
yield Static(" ")
# Langflow Enable Superuser CLI
yield Label("Langflow Enable Superuser CLI")
current_value = getattr(self.env_manager.config, "langflow_enable_superuser_cli", "False")
current_value = getattr(
self.env_manager.config, "langflow_enable_superuser_cli", "False"
)
input_widget = Input(
placeholder="False",
value=current_value,
id="input-langflow_enable_superuser_cli"
id="input-langflow_enable_superuser_cli",
)
yield input_widget
self.inputs["langflow_enable_superuser_cli"] = input_widget
yield Static(" ")
yield Static(" ")
# Langflow Secret Key removed from UI; generated automatically on save
# Add optional fields only in full mode
if self.mode == "full":
# Webhook Base URL
@ -333,36 +381,43 @@ class ConfigScreen(Screen):
input_widget = Input(
placeholder="https://your-domain.com",
value=current_value,
id="input-webhook_base_url"
id="input-webhook_base_url",
)
yield input_widget
self.inputs["webhook_base_url"] = input_widget
yield Static(" ")
# Langflow Public URL
yield Label("Langflow Public URL")
current_value = getattr(self.env_manager.config, "langflow_public_url", "")
input_widget = Input(
placeholder="http://localhost:7860",
value=current_value,
id="input-langflow_public_url"
id="input-langflow_public_url",
)
yield input_widget
self.inputs["langflow_public_url"] = input_widget
yield Static(" ")
def _create_field(self, field_name: str, display_name: str, placeholder: str, can_generate: bool, required: bool = False) -> ComposeResult:
def _create_field(
self,
field_name: str,
display_name: str,
placeholder: str,
can_generate: bool,
required: bool = False,
) -> ComposeResult:
"""Create a single form field."""
# Create label
label_text = f"{display_name}"
if required:
label_text += " *"
yield Label(label_text)
# Get current value
current_value = getattr(self.env_manager.config, field_name, "")
# Create input with appropriate validator
if field_name == "openai_api_key":
input_widget = Input(
@ -370,35 +425,33 @@ class ConfigScreen(Screen):
value=current_value,
password=True,
validators=[OpenAIKeyValidator()],
id=f"input-{field_name}"
id=f"input-{field_name}",
)
elif field_name == "openrag_documents_paths":
input_widget = Input(
placeholder=placeholder,
value=current_value,
validators=[DocumentsPathValidator()],
id=f"input-{field_name}"
id=f"input-{field_name}",
)
elif "password" in field_name or "secret" in field_name:
input_widget = Input(
placeholder=placeholder,
value=current_value,
password=True,
id=f"input-{field_name}"
id=f"input-{field_name}",
)
else:
input_widget = Input(
placeholder=placeholder,
value=current_value,
id=f"input-{field_name}"
placeholder=placeholder, value=current_value, id=f"input-{field_name}"
)
yield input_widget
self.inputs[field_name] = input_widget
# Add spacing
yield Static(" ")
def on_mount(self) -> None:
"""Initialize the screen when mounted."""
# Focus the first input field
@ -409,7 +462,7 @@ class ConfigScreen(Screen):
inputs[0].focus()
except Exception:
pass
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses."""
if event.button.id == "generate-btn":
@ -420,43 +473,47 @@ class ConfigScreen(Screen):
self.action_back()
elif event.button.id == "pick-docs-btn":
self.action_pick_documents_path()
def action_generate(self) -> None:
    """Regenerate the admin passwords and show them in the form.

    Asks the environment manager to set up its secure defaults, then copies
    the ``opensearch_password`` and ``langflow_superuser_password`` values
    from the config into whichever of those inputs exist on this screen.
    """
    self.env_manager.setup_secure_defaults()

    generated_fields = ("opensearch_password", "langflow_superuser_password")
    for name, widget in self.inputs.items():
        if name not in generated_fields:
            continue
        # Mirror the freshly generated config value into the visible input.
        widget.value = getattr(self.env_manager.config, name)

    self.notify("Generated secure passwords", severity="information")
def action_save(self) -> None:
    """Write the form values to the config, validate, and save.

    Every input's current value is first stored on the config object. If
    validation for the current mode fails, up to three ``field: error``
    messages are shown and nothing is saved. After a successful save the
    monitor screen is pushed.
    """
    for name, widget in self.inputs.items():
        setattr(self.env_manager.config, name, widget.value)

    if not self.env_manager.validate_config(self.mode):
        problems = [
            f"{field}: {error}"
            for field, error in self.env_manager.config.validation_errors.items()
        ]
        # Only the first three problems are shown.
        self.notify(
            "Validation failed:\n" + "\n".join(problems[:3]),
            severity="error",
        )
        return

    if not self.env_manager.save_env_file():
        self.notify("Failed to save configuration", severity="error")
        return

    self.notify("Configuration saved successfully!", severity="information")

    # Local import kept from the original.
    # NOTE(review): presumably avoids a circular import with .monitor; confirm.
    from .monitor import MonitorScreen

    self.app.push_screen(MonitorScreen())
def action_back(self) -> None:
    """Leave the configuration screen by popping it off the app's screen stack."""
    self.app.pop_screen()
@ -465,6 +522,7 @@ class ConfigScreen(Screen):
"""Open textual-fspicker to select a path and append it to the input."""
try:
import importlib
fsp = importlib.import_module("textual_fspicker")
except Exception:
self.notify("textual-fspicker not available", severity="warning")
@ -479,9 +537,13 @@ class ConfigScreen(Screen):
start = Path(first).expanduser()
# Prefer SelectDirectory for directories; fallback to FileOpen
PickerClass = getattr(fsp, "SelectDirectory", None) or getattr(fsp, "FileOpen", None)
PickerClass = getattr(fsp, "SelectDirectory", None) or getattr(
fsp, "FileOpen", None
)
if PickerClass is None:
self.notify("No compatible picker found in textual-fspicker", severity="warning")
self.notify(
"No compatible picker found in textual-fspicker", severity="warning"
)
return
try:
picker = PickerClass(location=start)
@ -523,7 +585,7 @@ class ConfigScreen(Screen):
pass
except Exception:
pass
def on_input_changed(self, event: Input.Changed) -> None:
"""Handle input changes for real-time validation feedback."""
# This will trigger validation display in real-time

View file

@ -18,7 +18,7 @@ from ..managers.container_manager import ContainerManager
class DiagnosticsScreen(Screen):
"""Diagnostics screen for debugging OpenRAG."""
CSS = """
#diagnostics-log {
border: solid $accent;
@ -40,20 +40,20 @@ class DiagnosticsScreen(Screen):
text-align: center;
}
"""
BINDINGS = [
("escape", "back", "Back"),
("r", "refresh", "Refresh"),
("ctrl+c", "copy", "Copy to Clipboard"),
("ctrl+s", "save", "Save to File"),
]
def __init__(self):
    """Set up the container manager, logger and status-timer slot."""
    super().__init__()
    # Source of runtime and service information for the diagnostics output.
    self.container_manager = ContainerManager()
    self._logger = logging.getLogger("openrag.diagnostics")
    # Holds the pending "clear status message" task, if any.
    self._status_timer = None
def compose(self) -> ComposeResult:
"""Create the diagnostics screen layout."""
yield Header()
@ -66,24 +66,24 @@ class DiagnosticsScreen(Screen):
yield Button("Copy to Clipboard", variant="default", id="copy-btn")
yield Button("Save to File", variant="default", id="save-btn")
yield Button("Back", variant="default", id="back-btn")
# Status indicator for copy/save operations
yield Static("", id="copy-status", classes="copy-indicator")
with ScrollableContainer(id="diagnostics-scroll"):
yield Log(id="diagnostics-log", highlight=True)
yield Footer()
def on_mount(self) -> None:
    """Run the diagnostics once and move focus to the refresh button."""
    self.run_diagnostics()

    # Focusing is best-effort: any failure to find or focus the button
    # is ignored.
    try:
        refresh_button = self.query_one("#refresh-btn")
        refresh_button.focus()
    except Exception:
        pass
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses."""
if event.button.id == "refresh-btn":
@ -98,25 +98,26 @@ class DiagnosticsScreen(Screen):
self.save_to_file()
elif event.button.id == "back-btn":
self.action_back()
def action_refresh(self) -> None:
    """Re-run the diagnostics (bound to the refresh action)."""
    self.run_diagnostics()
def action_copy(self) -> None:
    """Keyboard-shortcut entry point that delegates to ``copy_to_clipboard``."""
    self.copy_to_clipboard()
def copy_to_clipboard(self) -> None:
"""Copy log content to clipboard."""
try:
log = self.query_one("#diagnostics-log", Log)
content = "\n".join(str(line) for line in log.lines)
status = self.query_one("#copy-status", Static)
# Try to use pyperclip if available
try:
import pyperclip
pyperclip.copy(content)
self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard")
@ -124,23 +125,19 @@ class DiagnosticsScreen(Screen):
return
except ImportError:
pass
# Fallback to platform-specific clipboard commands
import subprocess
import platform
system = platform.system()
if system == "Darwin": # macOS
process = subprocess.Popen(
["pbcopy"], stdin=subprocess.PIPE, text=True
)
process = subprocess.Popen(["pbcopy"], stdin=subprocess.PIPE, text=True)
process.communicate(input=content)
self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard")
elif system == "Windows":
process = subprocess.Popen(
["clip"], stdin=subprocess.PIPE, text=True
)
process = subprocess.Popen(["clip"], stdin=subprocess.PIPE, text=True)
process.communicate(input=content)
self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard")
@ -150,7 +147,7 @@ class DiagnosticsScreen(Screen):
process = subprocess.Popen(
["xclip", "-selection", "clipboard"],
stdin=subprocess.PIPE,
text=True
text=True,
)
process.communicate(input=content)
self.notify("Copied to clipboard", severity="information")
@ -160,65 +157,78 @@ class DiagnosticsScreen(Screen):
process = subprocess.Popen(
["xsel", "--clipboard", "--input"],
stdin=subprocess.PIPE,
text=True
text=True,
)
process.communicate(input=content)
self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard")
except FileNotFoundError:
self.notify("Clipboard utilities not found. Install xclip or xsel.", severity="error")
status.update("❌ Clipboard utilities not found. Install xclip or xsel.")
self.notify(
"Clipboard utilities not found. Install xclip or xsel.",
severity="error",
)
status.update(
"❌ Clipboard utilities not found. Install xclip or xsel."
)
else:
self.notify("Clipboard not supported on this platform", severity="error")
self.notify(
"Clipboard not supported on this platform", severity="error"
)
status.update("❌ Clipboard not supported on this platform")
self._hide_status_after_delay(status)
except Exception as e:
self.notify(f"Failed to copy to clipboard: {e}", severity="error")
status = self.query_one("#copy-status", Static)
status.update(f"❌ Failed to copy: {e}")
self._hide_status_after_delay(status)
def _hide_status_after_delay(
    self, status_widget: Static, delay: float = 3.0
) -> None:
    """Schedule ``status_widget`` to be blanked after ``delay`` seconds.

    Any previously scheduled clear is cancelled first, so only the most
    recent request stays pending.
    """
    pending = self._status_timer
    if pending:
        pending.cancel()

    # Keep the new task on the instance so the next call can cancel it.
    clear_later = self._clear_status_after_delay(status_widget, delay)
    self._status_timer = asyncio.create_task(clear_later)

async def _clear_status_after_delay(
    self, status_widget: Static, delay: float
) -> None:
    """Sleep for ``delay`` seconds, then set the widget's content to ""."""
    await asyncio.sleep(delay)
    status_widget.update("")
def action_save(self) -> None:
    """Keyboard-shortcut entry point that delegates to ``save_to_file``."""
    self.save_to_file()
def save_to_file(self) -> None:
"""Save log content to a file."""
try:
log = self.query_one("#diagnostics-log", Log)
content = "\n".join(str(line) for line in log.lines)
status = self.query_one("#copy-status", Static)
# Create logs directory if it doesn't exist
logs_dir = Path("logs")
logs_dir.mkdir(exist_ok=True)
# Create a timestamped filename
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = logs_dir / f"openrag_diagnostics_{timestamp}.txt"
# Save to file
with open(filename, "w") as f:
f.write(content)
self.notify(f"Saved to {filename}", severity="information")
status.update(f"✓ Saved to {filename}")
# Log the save operation
self._logger.info(f"Diagnostics saved to {filename}")
self._hide_status_after_delay(status)
@ -226,55 +236,57 @@ class DiagnosticsScreen(Screen):
error_msg = f"Failed to save file: {e}"
self.notify(error_msg, severity="error")
self._logger.error(error_msg)
status = self.query_one("#copy-status", Static)
status.update(f"{error_msg}")
self._hide_status_after_delay(status)
def action_back(self) -> None:
    """Return to the previous screen by popping this one off the app's stack."""
    self.app.pop_screen()
def _get_system_info(self) -> Text:
    """Build the container-runtime summary shown at the top of the log.

    Returns:
        A ``Text`` with a bold heading, a separator line, the runtime type,
        the compose and runtime commands, and the version line when the
        runtime reports one.
    """
    summary = Text()
    runtime = self.container_manager.get_runtime_info()

    summary.append("Container Runtime Information\n", style="bold")

    # The remaining lines carry no style, so they are appended in one loop.
    plain_lines = [
        "=" * 30 + "\n",
        f"Type: {runtime.runtime_type.value}\n",
        f"Compose Command: {' '.join(runtime.compose_command)}\n",
        f"Runtime Command: {' '.join(runtime.runtime_command)}\n",
    ]
    if runtime.version:
        plain_lines.append(f"Version: {runtime.version}\n")
    for line in plain_lines:
        summary.append(line)

    return summary
def run_diagnostics(self) -> None:
    """Clear the diagnostics log and start a fresh diagnostics run.

    The runtime summary from ``_get_system_info`` is written immediately.
    The remaining checks run in ``_run_async_diagnostics`` as an asyncio
    task, so this must be called while an event loop is running.
    """
    log = self.query_one("#diagnostics-log", Log)

    # A previous run that is still in flight would keep writing into the
    # log that is about to be cleared, so cancel it first.
    previous = getattr(self, "_diagnostics_task", None)
    if previous is not None and not previous.done():
        previous.cancel()

    log.clear()

    # System information
    log.write(str(self._get_system_info()))
    log.write("")

    # Keep a reference to the task: the event loop holds tasks only weakly,
    # so an unreferenced task may be garbage-collected before it finishes.
    self._diagnostics_task = asyncio.create_task(self._run_async_diagnostics())
async def _run_async_diagnostics(self) -> None:
"""Run asynchronous diagnostics."""
log = self.query_one("#diagnostics-log", Log)
# Check services
log.write("[bold green]Service Status[/bold green]")
services = await self.container_manager.get_service_status(force_refresh=True)
for name, info in services.items():
status_color = "green" if info.status == "running" else "red"
log.write(f"[bold]{name}[/bold]: [{status_color}]{info.status.value}[/{status_color}]")
log.write(
f"[bold]{name}[/bold]: [{status_color}]{info.status.value}[/{status_color}]"
)
if info.health:
log.write(f" Health: {info.health}")
if info.ports:
@ -282,40 +294,38 @@ class DiagnosticsScreen(Screen):
if info.image:
log.write(f" Image: {info.image}")
log.write("")
# Check for Podman-specific issues
if self.container_manager.runtime_info.runtime_type.name == "PODMAN":
await self.check_podman()
async def check_podman(self) -> None:
"""Run Podman-specific diagnostics."""
log = self.query_one("#diagnostics-log", Log)
log.write("[bold green]Podman Diagnostics[/bold green]")
# Check if using Podman
if self.container_manager.runtime_info.runtime_type.name != "PODMAN":
log.write("[yellow]Not using Podman[/yellow]")
return
# Check Podman version
cmd = ["podman", "--version"]
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
log.write(f"Podman version: {stdout.decode().strip()}")
else:
log.write(f"[red]Failed to get Podman version: {stderr.decode().strip()}[/red]")
log.write(
f"[red]Failed to get Podman version: {stderr.decode().strip()}[/red]"
)
# Check Podman containers
cmd = ["podman", "ps", "--all"]
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
@ -323,15 +333,17 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"):
log.write(f" {line}")
else:
log.write(f"[red]Failed to list Podman containers: {stderr.decode().strip()}[/red]")
log.write(
f"[red]Failed to list Podman containers: {stderr.decode().strip()}[/red]"
)
# Check Podman compose
cmd = ["podman", "compose", "ps"]
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=self.container_manager.compose_file.parent
cwd=self.container_manager.compose_file.parent,
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
@ -339,39 +351,39 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"):
log.write(f" {line}")
else:
log.write(f"[red]Failed to list Podman compose services: {stderr.decode().strip()}[/red]")
log.write(
f"[red]Failed to list Podman compose services: {stderr.decode().strip()}[/red]"
)
log.write("")
async def check_docker(self) -> None:
"""Run Docker-specific diagnostics."""
log = self.query_one("#diagnostics-log", Log)
log.write("[bold green]Docker Diagnostics[/bold green]")
# Check if using Docker
if "DOCKER" not in self.container_manager.runtime_info.runtime_type.name:
log.write("[yellow]Not using Docker[/yellow]")
return
# Check Docker version
cmd = ["docker", "--version"]
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
log.write(f"Docker version: {stdout.decode().strip()}")
else:
log.write(f"[red]Failed to get Docker version: {stderr.decode().strip()}[/red]")
log.write(
f"[red]Failed to get Docker version: {stderr.decode().strip()}[/red]"
)
# Check Docker containers
cmd = ["docker", "ps", "--all"]
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
@ -379,15 +391,17 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"):
log.write(f" {line}")
else:
log.write(f"[red]Failed to list Docker containers: {stderr.decode().strip()}[/red]")
log.write(
f"[red]Failed to list Docker containers: {stderr.decode().strip()}[/red]"
)
# Check Docker compose
cmd = ["docker", "compose", "ps"]
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=self.container_manager.compose_file.parent
cwd=self.container_manager.compose_file.parent,
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
@ -395,8 +409,11 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"):
log.write(f" {line}")
else:
log.write(f"[red]Failed to list Docker compose services: {stderr.decode().strip()}[/red]")
log.write(
f"[red]Failed to list Docker compose services: {stderr.decode().strip()}[/red]"
)
log.write("")
# Made with Bob

View file

@ -13,7 +13,7 @@ from ..managers.container_manager import ContainerManager
class LogsScreen(Screen):
"""Logs viewing and monitoring screen."""
BINDINGS = [
("escape", "back", "Back"),
("f", "follow", "Follow Logs"),
@ -27,44 +27,50 @@ class LogsScreen(Screen):
("ctrl+u", "scroll_page_up", "Page Up"),
("ctrl+f", "scroll_page_down", "Page Down"),
]
def __init__(self, initial_service: str = "openrag-backend"):
    """Create the logs screen for one service.

    Args:
        initial_service: Service whose logs are shown first. Names outside
            the known set fall back to ``"openrag-backend"``.
    """
    super().__init__()
    self.container_manager = ContainerManager()

    known_services = (
        "openrag-backend",
        "openrag-frontend",
        "opensearch",
        "langflow",
        "dashboards",
    )
    self.current_service = (
        initial_service if initial_service in known_services else "openrag-backend"
    )

    # Widget and follow-state slots, populated later by other methods.
    self.logs_area = None
    self.following = False
    self.follow_task = None
    self.auto_scroll = True
def compose(self) -> ComposeResult:
    """Yield the screen's widgets: a title over the logs area, then a footer."""
    title = Static(f"Service Logs: {self.current_service}", id="logs-title")
    logs_area = self._create_logs_area()
    content = Vertical(title, logs_area, id="logs-content")

    yield Container(content, id="main-container")
    yield Footer()
def _create_logs_area(self) -> TextArea:
    """Create the read-only logs text area and remember it on the instance.

    Returns:
        The new ``TextArea``, also stored as ``self.logs_area``.
    """
    area = TextArea(
        text="Loading logs...",
        read_only=True,
        show_line_numbers=False,
        id="logs-area",
    )
    self.logs_area = area
    return area
async def on_mount(self) -> None:
"""Initialize the screen when mounted."""
# Set the correct service in the select widget after a brief delay
@ -72,34 +78,40 @@ class LogsScreen(Screen):
select = self.query_one("#service-select")
# Set a default first, then set the desired value
select.value = "openrag-backend"
if self.current_service in ["openrag-backend", "openrag-frontend", "opensearch", "langflow", "dashboards"]:
if self.current_service in [
"openrag-backend",
"openrag-frontend",
"opensearch",
"langflow",
"dashboards",
]:
select.value = self.current_service
except Exception as e:
# If setting the service fails, just use the default
pass
await self._load_logs()
# Focus the logs area since there are no buttons
try:
self.logs_area.focus()
except Exception:
pass
def on_unmount(self) -> None:
    """Stop any active log following when the screen is unmounted."""
    self._stop_following()
async def _load_logs(self, lines: int = 200) -> None:
"""Load recent logs for the current service."""
if not self.container_manager.is_available():
self.logs_area.text = "No container runtime available"
return
success, logs = await self.container_manager.get_service_logs(self.current_service, lines)
success, logs = await self.container_manager.get_service_logs(
self.current_service, lines
)
if success:
self.logs_area.text = logs
# Scroll to bottom if auto scroll is enabled
@ -107,67 +119,71 @@ class LogsScreen(Screen):
self.logs_area.scroll_end()
else:
self.logs_area.text = f"Failed to load logs: {logs}"
def _stop_following(self) -> None:
"""Stop following logs."""
self.following = False
if self.follow_task and not self.follow_task.is_finished:
self.follow_task.cancel()
# No button to update since we removed it
async def _follow_logs(self) -> None:
"""Follow logs in real-time."""
if not self.container_manager.is_available():
return
try:
async for log_line in self.container_manager.follow_service_logs(self.current_service):
async for log_line in self.container_manager.follow_service_logs(
self.current_service
):
if not self.following:
break
# Append new line to logs area
current_text = self.logs_area.text
new_text = current_text + "\n" + log_line
# Keep only last 1000 lines to prevent memory issues
lines = new_text.split('\n')
lines = new_text.split("\n")
if len(lines) > 1000:
lines = lines[-1000:]
new_text = '\n'.join(lines)
new_text = "\n".join(lines)
self.logs_area.text = new_text
# Scroll to bottom if auto scroll is enabled
if self.auto_scroll:
self.logs_area.scroll_end()
except asyncio.CancelledError:
pass
except Exception as e:
if self.following: # Only show error if we're still supposed to be following
if (
self.following
): # Only show error if we're still supposed to be following
self.notify(f"Error following logs: {e}", severity="error")
finally:
self.following = False
def action_refresh(self) -> None:
    """Stop following and reload the recent logs in a worker."""
    self._stop_following()
    reload_logs = self._load_logs()
    self.run_worker(reload_logs)
def action_follow(self) -> None:
    """Toggle follow mode: stop if following, otherwise start a follow worker."""
    if self.following:
        self._stop_following()
        return

    # Set the flag before starting the worker; ``_follow_logs`` checks it
    # for every incoming line.
    self.following = True
    self.follow_task = self.run_worker(self._follow_logs(), exclusive=False)
def action_clear(self) -> None:
    """Empty the logs area by setting its text to an empty string."""
    self.logs_area.text = ""
def action_toggle_auto_scroll(self) -> None:
"""Toggle auto scroll on/off."""
self.auto_scroll = not self.auto_scroll
@ -201,13 +217,13 @@ class LogsScreen(Screen):
def on_key(self, event) -> None:
"""Handle key presses that might be intercepted by TextArea."""
key = event.key
# Handle keys that TextArea might intercept
if key == "ctrl+u":
self.action_scroll_page_up()
event.prevent_default()
elif key == "ctrl+f":
self.action_scroll_page_down()
self.action_scroll_page_down()
event.prevent_default()
elif key.upper() == "G":
self.action_scroll_bottom()
@ -216,4 +232,4 @@ class LogsScreen(Screen):
def action_back(self) -> None:
    """Stop following logs, then pop this screen off the app's stack."""
    self._stop_following()
    self.app.pop_screen()

View file

@ -23,7 +23,7 @@ from ..widgets.diagnostics_notification import notify_with_diagnostics
class MonitorScreen(Screen):
"""Service monitoring and control screen."""
BINDINGS = [
("escape", "back", "Back"),
("r", "refresh", "Refresh"),
@ -35,7 +35,7 @@ class MonitorScreen(Screen):
("j", "cursor_down", "Move Down"),
("k", "cursor_up", "Move Up"),
]
def __init__(self):
super().__init__()
self.container_manager = ContainerManager()
@ -47,14 +47,14 @@ class MonitorScreen(Screen):
self._follow_task = None
self._follow_service = None
self._logs_buffer = []
def compose(self) -> ComposeResult:
    """Yield the services widgets followed by the footer.

    The services content is shown directly, with no header or tabs.
    """
    for widget in self._create_services_tab():
        yield widget
    yield Footer()
def _create_services_tab(self) -> ComposeResult:
"""Create the services monitoring tab."""
# Current mode indicator + toggle
@ -75,69 +75,73 @@ class MonitorScreen(Screen):
yield Horizontal(id="services-controls", classes="button-row")
# Create services table with image + digest info
self.services_table = DataTable(id="services-table")
self.services_table.add_columns("Service", "Status", "Health", "Ports", "Image", "Digest")
self.services_table.add_columns(
"Service", "Status", "Health", "Ports", "Image", "Digest"
)
yield self.services_table
def _get_runtime_status(self) -> Text:
    """Describe the container runtime as styled text.

    Returns:
        A ``Text`` holding either a "no runtime available" warning, or a
        runtime heading followed by the version (when reported) and, for
        Podman, a memory warning when ``check_podman_macos_memory`` reports
        a problem.
    """
    report = Text()

    if not self.container_manager.is_available():
        for line, style in (
            ("WARNING: No container runtime available\n", "bold red"),
            ("Please install Docker or Podman to continue.\n", "dim"),
        ):
            report.append(line, style=style)
        return report

    runtime = self.container_manager.get_runtime_info()
    using_podman = runtime.runtime_type == RuntimeType.PODMAN

    # Heading and colour depend on the detected runtime type.
    if runtime.runtime_type == RuntimeType.DOCKER:
        heading, heading_style = "Docker Runtime\n", "bold blue"
    elif using_podman:
        heading, heading_style = "Podman Runtime\n", "bold purple"
    else:
        heading, heading_style = "Container Runtime\n", "bold green"
    report.append(heading, style=heading_style)

    if runtime.version:
        report.append(f"Version: {runtime.version}\n", style="dim")

    # Check Podman macOS memory if applicable
    if using_podman:
        enough_memory, detail = self.container_manager.check_podman_macos_memory()
        if not enough_memory:
            report.append(f"WARNING: {detail}\n", style="bold yellow")

    return report
async def on_mount(self) -> None:
    """Load the services table, start the auto-refresh timer and focus the table."""
    await self._refresh_services()

    # Re-run the auto refresh every 5 seconds; the timer handle is kept so
    # it can be stopped later.
    refresh_seconds = 5.0
    self.refresh_timer = self.set_interval(refresh_seconds, self._auto_refresh)

    # Focusing the table is best-effort.
    try:
        table = self.services_table
        table.focus()
    except Exception:
        pass
def on_unmount(self) -> None:
    """Stop the auto-refresh timer (if set) and any active log following."""
    timer = self.refresh_timer
    if timer:
        timer.stop()

    # Stop following logs if running
    self._stop_follow()
async def on_screen_resume(self) -> None:
    """Refresh the services table when this screen is resumed.

    This covers the return from a modal, such as the command output modals
    pushed by the service operations.
    """
    await self._refresh_services()
async def _refresh_services(self) -> None:
"""Refresh the services table."""
if not self.container_manager.is_available():
return
services = await self.container_manager.get_service_status(force_refresh=True)
# Collect images actually reported by running/stopped containers so names match runtime
images_set = set()
@ -147,7 +151,9 @@ class MonitorScreen(Screen):
images_set.add(img)
# Ensure compose-declared images are also shown (e.g., langflow when stopped)
try:
for img in self.container_manager._parse_compose_images(): # best-effort, no YAML dep
for img in (
self.container_manager._parse_compose_images()
): # best-effort, no YAML dep
if img:
images_set.add(img)
except Exception:
@ -155,23 +161,23 @@ class MonitorScreen(Screen):
images = list(images_set)
# Lookup digests/IDs for these image names
digest_map = await self.container_manager.get_images_digests(images)
# Clear existing rows
self.services_table.clear()
if self.images_table:
self.images_table.clear()
# Add service rows
for service_name, service_info in services.items():
status_style = self._get_status_style(service_info.status)
self.services_table.add_row(
service_info.name,
Text(service_info.status.value, style=status_style),
service_info.health or "N/A",
", ".join(service_info.ports) if service_info.ports else "N/A",
service_info.image or "N/A",
digest_map.get(service_info.image or "", "-")
digest_map.get(service_info.image or "", "-"),
)
# Populate images table (unique images as reported by runtime)
if self.images_table:
@ -181,7 +187,7 @@ class MonitorScreen(Screen):
self._update_controls(list(services.values()))
# Update mode indicator
self._update_mode_row()
def _get_status_style(self, status: ServiceStatus) -> str:
"""Get the Rich style for a service status."""
status_styles = {
@ -191,20 +197,20 @@ class MonitorScreen(Screen):
ServiceStatus.STOPPING: "bold yellow",
ServiceStatus.ERROR: "bold red",
ServiceStatus.MISSING: "dim",
ServiceStatus.UNKNOWN: "dim"
ServiceStatus.UNKNOWN: "dim",
}
return status_styles.get(status, "white")
async def _auto_refresh(self) -> None:
"""Auto-refresh services if not in operation."""
if not self.operation_in_progress:
await self._refresh_services()
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses."""
button_id = event.button.id or ""
button_label = event.button.label or ""
# Use button ID prefixes to determine action, ignoring any random suffix
if button_id.startswith("start-btn"):
self.run_worker(self._start_services())
@ -228,18 +234,18 @@ class MonitorScreen(Screen):
"logs-backend": "openrag-backend",
"logs-frontend": "openrag-frontend",
"logs-opensearch": "opensearch",
"logs-langflow": "langflow"
"logs-langflow": "langflow",
}
# Extract the base button ID (without any suffix)
button_base_id = button_id.split("-")[0] + "-" + button_id.split("-")[1]
service_name = service_mapping.get(button_base_id)
if service_name:
# Load recent logs then start following
self.run_worker(self._show_logs(service_name))
self._start_follow(service_name)
async def _start_services(self, cpu_mode: bool = False) -> None:
"""Start services with progress updates."""
self.operation_in_progress = True
@ -249,12 +255,12 @@ class MonitorScreen(Screen):
modal = CommandOutputModal(
"Starting Services",
command_generator,
on_complete=None # We'll refresh in on_screen_resume instead
on_complete=None, # We'll refresh in on_screen_resume instead
)
self.app.push_screen(modal)
finally:
self.operation_in_progress = False
async def _stop_services(self) -> None:
"""Stop services with progress updates."""
self.operation_in_progress = True
@ -264,12 +270,12 @@ class MonitorScreen(Screen):
modal = CommandOutputModal(
"Stopping Services",
command_generator,
on_complete=None # We'll refresh in on_screen_resume instead
on_complete=None, # We'll refresh in on_screen_resume instead
)
self.app.push_screen(modal)
finally:
self.operation_in_progress = False
async def _restart_services(self) -> None:
"""Restart services with progress updates."""
self.operation_in_progress = True
@ -279,12 +285,12 @@ class MonitorScreen(Screen):
modal = CommandOutputModal(
"Restarting Services",
command_generator,
on_complete=None # We'll refresh in on_screen_resume instead
on_complete=None, # We'll refresh in on_screen_resume instead
)
self.app.push_screen(modal)
finally:
self.operation_in_progress = False
async def _upgrade_services(self) -> None:
"""Upgrade services with progress updates."""
self.operation_in_progress = True
@ -294,12 +300,12 @@ class MonitorScreen(Screen):
modal = CommandOutputModal(
"Upgrading Services",
command_generator,
on_complete=None # We'll refresh in on_screen_resume instead
on_complete=None, # We'll refresh in on_screen_resume instead
)
self.app.push_screen(modal)
finally:
self.operation_in_progress = False
async def _reset_services(self) -> None:
"""Reset services with progress updates."""
self.operation_in_progress = True
@ -309,17 +315,17 @@ class MonitorScreen(Screen):
modal = CommandOutputModal(
"Resetting Services",
command_generator,
on_complete=None # We'll refresh in on_screen_resume instead
on_complete=None, # We'll refresh in on_screen_resume instead
)
self.app.push_screen(modal)
finally:
self.operation_in_progress = False
def _strip_ansi_codes(self, text: str) -> str:
"""Strip ANSI escape sequences from text."""
ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
return ansi_escape.sub('', text)
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
return ansi_escape.sub("", text)
async def _show_logs(self, service_name: str) -> None:
"""Show logs for a service."""
success, logs = await self.container_manager.get_service_logs(service_name)
@ -346,7 +352,7 @@ class MonitorScreen(Screen):
notify_with_diagnostics(
self.app,
f"Failed to get logs for {service_name}: {logs}",
severity="error"
severity="error",
)
def _stop_follow(self) -> None:
@ -391,11 +397,9 @@ class MonitorScreen(Screen):
pass
except Exception as e:
notify_with_diagnostics(
self.app,
f"Error following logs: {e}",
severity="error"
self.app, f"Error following logs: {e}", severity="error"
)
def action_refresh(self) -> None:
    """Manually refresh the services table by running the refresh in a worker."""
    refresh = self._refresh_services()
    self.run_worker(refresh)
@ -431,14 +435,15 @@ class MonitorScreen(Screen):
try:
current = getattr(self.container_manager, "use_cpu_compose", True)
self.container_manager.use_cpu_compose = not current
self.notify("Switched to GPU compose" if not current else "Switched to CPU compose", severity="information")
self.notify(
"Switched to GPU compose" if not current else "Switched to CPU compose",
severity="information",
)
self._update_mode_row()
self.action_refresh()
except Exception as e:
notify_with_diagnostics(
self.app,
f"Failed to toggle mode: {e}",
severity="error"
self.app, f"Failed to toggle mode: {e}", severity="error"
)
def _update_controls(self, services: list[ServiceInfo]) -> None:
@ -446,83 +451,93 @@ class MonitorScreen(Screen):
try:
# Get the controls container
controls = self.query_one("#services-controls", Horizontal)
# Check if any services are running
any_running = any(s.status == ServiceStatus.RUNNING for s in services)
# Clear existing buttons by removing all children
controls.remove_children()
# Use a single ID for each button type, but make them unique with a suffix
# This ensures we don't create duplicate IDs across refreshes
import random
suffix = f"-{random.randint(10000, 99999)}"
# Add appropriate buttons based on service state
if any_running:
# When services are running, show stop and restart
controls.mount(Button("Stop Services", variant="error", id=f"stop-btn{suffix}"))
controls.mount(Button("Restart", variant="primary", id=f"restart-btn{suffix}"))
controls.mount(
Button("Stop Services", variant="error", id=f"stop-btn{suffix}")
)
controls.mount(
Button("Restart", variant="primary", id=f"restart-btn{suffix}")
)
else:
# When services are not running, show start
controls.mount(Button("Start Services", variant="success", id=f"start-btn{suffix}"))
controls.mount(
Button("Start Services", variant="success", id=f"start-btn{suffix}")
)
# Always show upgrade and reset buttons
controls.mount(Button("Upgrade", variant="warning", id=f"upgrade-btn{suffix}"))
controls.mount(
Button("Upgrade", variant="warning", id=f"upgrade-btn{suffix}")
)
controls.mount(Button("Reset", variant="error", id=f"reset-btn{suffix}"))
except Exception as e:
notify_with_diagnostics(
self.app,
f"Error updating controls: {e}",
severity="error"
self.app, f"Error updating controls: {e}", severity="error"
)
def action_back(self) -> None:
    """Return to the previous screen by popping the app's current screen."""
    app = self.app
    app.pop_screen()
def action_start(self) -> None:
    """Hand the ``_start_services`` coroutine to ``run_worker``."""
    start_job = self._start_services()
    self.run_worker(start_job)
def action_stop(self) -> None:
    """Hand the ``_stop_services`` coroutine to ``run_worker``."""
    stop_job = self._stop_services()
    self.run_worker(stop_job)
def action_upgrade(self) -> None:
    """Hand the ``_upgrade_services`` coroutine to ``run_worker``."""
    upgrade_job = self._upgrade_services()
    self.run_worker(upgrade_job)
def action_reset(self) -> None:
    """Hand the ``_reset_services`` coroutine to ``run_worker``."""
    reset_job = self._reset_services()
    self.run_worker(reset_job)
def action_logs(self) -> None:
"""View logs for the selected service."""
try:
# Get the currently focused row in the services table
table = self.query_one("#services-table", DataTable)
if table.cursor_row is not None and table.cursor_row >= 0:
# Get the service name from the first column of the selected row
row_data = table.get_row_at(table.cursor_row)
if row_data:
service_name = str(row_data[0]) # First column is service name
# Map display names to actual service names
service_mapping = {
"openrag-backend": "openrag-backend",
"openrag-frontend": "openrag-frontend",
"openrag-frontend": "openrag-frontend",
"opensearch": "opensearch",
"langflow": "langflow",
"dashboards": "dashboards"
"dashboards": "dashboards",
}
actual_service_name = service_mapping.get(service_name, service_name)
actual_service_name = service_mapping.get(
service_name, service_name
)
# Push the logs screen with the selected service
from .logs import LogsScreen
logs_screen = LogsScreen(initial_service=actual_service_name)
self.app.push_screen(logs_screen)
else:

View file

@ -16,7 +16,7 @@ from ..managers.env_manager import EnvManager
class WelcomeScreen(Screen):
"""Initial welcome screen with setup options."""
BINDINGS = [
("q", "quit", "Quit"),
("enter", "default_action", "Continue"),
@ -25,7 +25,7 @@ class WelcomeScreen(Screen):
("3", "monitor", "Monitor Services"),
("4", "diagnostics", "Diagnostics"),
]
def __init__(self):
super().__init__()
self.container_manager = ContainerManager()
@ -34,19 +34,19 @@ class WelcomeScreen(Screen):
self.has_oauth_config = False
self.default_button_id = "basic-setup-btn"
self._state_checked = False
# Load .env file if it exists
load_dotenv()
def compose(self) -> ComposeResult:
"""Create the welcome screen layout."""
yield Container(
Vertical(
Static(self._create_welcome_text(), id="welcome-text"),
self._create_dynamic_buttons(),
id="welcome-container"
id="welcome-container",
),
id="main-container"
id="main-container",
)
yield Footer()
@ -65,55 +65,67 @@ class WelcomeScreen(Screen):
welcome_text.append("Terminal User Interface for OpenRAG\n\n", style="dim")
if self.services_running:
welcome_text.append("✓ Services are currently running\n\n", style="bold green")
welcome_text.append(
"✓ Services are currently running\n\n", style="bold green"
)
elif self.has_oauth_config:
welcome_text.append("OAuth credentials detected — Advanced Setup recommended\n\n", style="bold green")
welcome_text.append(
"OAuth credentials detected — Advanced Setup recommended\n\n",
style="bold green",
)
else:
welcome_text.append("Select a setup below to continue\n\n", style="white")
return welcome_text
def _create_dynamic_buttons(self) -> Horizontal:
    """Build the button row that matches the current state.

    Returns:
        A ``Horizontal`` container with class ``button-row`` holding:

        * only "Monitor Services" when ``self.services_running`` is truthy;
        * otherwise a single setup button ("Advanced Setup" when an OAuth
          client id is present in the environment, else "Basic Setup")
          followed by "Monitor Services".
    """
    # OAuth presence is read from the environment here rather than from
    # self.has_oauth_config.
    has_oauth = bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or bool(
        os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID")
    )

    # Each button is created exactly once: the block previously carried two
    # copies of every append, which would mount widgets with duplicate ids.
    buttons = []

    if self.services_running:
        # Services running - only show monitor
        buttons.append(
            Button("Monitor Services", variant="success", id="monitor-btn")
        )
    else:
        # Services not running - show exactly one setup option
        if has_oauth:
            buttons.append(
                Button("Advanced Setup", variant="success", id="advanced-setup-btn")
            )
        else:
            buttons.append(
                Button("Basic Setup", variant="success", id="basic-setup-btn")
            )

        # Always show monitor option next to the setup button
        buttons.append(
            Button("Monitor Services", variant="default", id="monitor-btn")
        )

    return Horizontal(*buttons, classes="button-row")
async def on_mount(self) -> None:
"""Initialize screen state when mounted."""
# Check if services are running
if self.container_manager.is_available():
services = await self.container_manager.get_service_status()
running_services = [s.name for s in services.values() if s.status == ServiceStatus.RUNNING]
running_services = [
s.name for s in services.values() if s.status == ServiceStatus.RUNNING
]
self.services_running = len(running_services) > 0
# Check for OAuth configuration
self.has_oauth_config = (
bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or
bool(os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID"))
self.has_oauth_config = bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or bool(
os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID")
)
# Set default button focus
if self.services_running:
self.default_button_id = "monitor-btn"
@ -121,12 +133,14 @@ class WelcomeScreen(Screen):
self.default_button_id = "advanced-setup-btn"
else:
self.default_button_id = "basic-setup-btn"
# Update the welcome text and recompose with new state
try:
welcome_widget = self.query_one("#welcome-text")
welcome_widget.update(self._create_welcome_text()) # This is fine for Static widgets
welcome_widget.update(
self._create_welcome_text()
) # This is fine for Static widgets
# Focus the appropriate button
if self.services_running:
try:
@ -143,10 +157,10 @@ class WelcomeScreen(Screen):
self.query_one("#basic-setup-btn").focus()
except:
pass
except:
pass # Widgets might not be mounted yet
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button presses."""
if event.button.id == "basic-setup-btn":
@ -157,7 +171,7 @@ class WelcomeScreen(Screen):
self.action_monitor()
elif event.button.id == "diagnostics-btn":
self.action_diagnostics()
def action_default_action(self) -> None:
"""Handle Enter key - go to default action based on state."""
if self.services_running:
@ -166,27 +180,31 @@ class WelcomeScreen(Screen):
self.action_full_setup()
else:
self.action_no_auth_setup()
def action_no_auth_setup(self) -> None:
    """Push a ``ConfigScreen`` created with ``mode="no_auth"`` (basic setup)."""
    # Imported inside the method, as in the original.
    from .config import ConfigScreen

    basic_config = ConfigScreen(mode="no_auth")
    self.app.push_screen(basic_config)
def action_full_setup(self) -> None:
    """Push a ``ConfigScreen`` created with ``mode="full"`` (advanced setup)."""
    # Imported inside the method, as in the original.
    from .config import ConfigScreen

    advanced_config = ConfigScreen(mode="full")
    self.app.push_screen(advanced_config)
def action_monitor(self) -> None:
    """Push a new ``MonitorScreen`` onto the app's screen stack."""
    # Imported inside the method, as in the original.
    from .monitor import MonitorScreen

    monitor_screen = MonitorScreen()
    self.app.push_screen(monitor_screen)
def action_diagnostics(self) -> None:
    """Push a new ``DiagnosticsScreen`` onto the app's screen stack."""
    # Imported inside the method, as in the original.
    from .diagnostics import DiagnosticsScreen

    diagnostics_screen = DiagnosticsScreen()
    self.app.push_screen(diagnostics_screen)
def action_quit(self) -> None:
    """Quit the application by asking the app to exit."""
    # Exit is requested exactly once; the block previously repeated the call.
    self.app.exit()

View file

@ -1 +1 @@
"""TUI utilities package."""
"""TUI utilities package."""

View file

@ -34,40 +34,66 @@ class PlatformDetector:
"""Detect available container runtime and compose capabilities."""
# First check if we have podman installed
podman_version = self._get_podman_version()
# If we have podman, check if docker is actually podman in disguise
if podman_version:
docker_version = self._get_docker_version()
if docker_version and podman_version in docker_version:
# This is podman masquerading as docker
if self._check_command(["docker", "compose", "--help"]):
return RuntimeInfo(RuntimeType.PODMAN, ["docker", "compose"], ["docker"], podman_version)
return RuntimeInfo(
RuntimeType.PODMAN,
["docker", "compose"],
["docker"],
podman_version,
)
if self._check_command(["docker-compose", "--help"]):
return RuntimeInfo(RuntimeType.PODMAN, ["docker-compose"], ["docker"], podman_version)
return RuntimeInfo(
RuntimeType.PODMAN,
["docker-compose"],
["docker"],
podman_version,
)
# Check for native podman compose
if self._check_command(["podman", "compose", "--help"]):
return RuntimeInfo(RuntimeType.PODMAN, ["podman", "compose"], ["podman"], podman_version)
return RuntimeInfo(
RuntimeType.PODMAN,
["podman", "compose"],
["podman"],
podman_version,
)
# Check for actual docker
if self._check_command(["docker", "compose", "--help"]):
version = self._get_docker_version()
return RuntimeInfo(RuntimeType.DOCKER, ["docker", "compose"], ["docker"], version)
return RuntimeInfo(
RuntimeType.DOCKER, ["docker", "compose"], ["docker"], version
)
if self._check_command(["docker-compose", "--help"]):
version = self._get_docker_version()
return RuntimeInfo(RuntimeType.DOCKER_COMPOSE, ["docker-compose"], ["docker"], version)
return RuntimeInfo(
RuntimeType.DOCKER_COMPOSE, ["docker-compose"], ["docker"], version
)
return RuntimeInfo(RuntimeType.NONE, [], [])
def detect_gpu_available(self) -> bool:
"""Best-effort detection of NVIDIA GPU availability for containers."""
try:
res = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=5)
if res.returncode == 0 and any("GPU" in ln for ln in res.stdout.splitlines()):
res = subprocess.run(
["nvidia-smi", "-L"], capture_output=True, text=True, timeout=5
)
if res.returncode == 0 and any(
"GPU" in ln for ln in res.stdout.splitlines()
):
return True
except (subprocess.TimeoutExpired, FileNotFoundError):
pass
for cmd in (["docker", "info", "--format", "{{json .Runtimes}}"], ["podman", "info", "--format", "json"]):
for cmd in (
["docker", "info", "--format", "{{json .Runtimes}}"],
["podman", "info", "--format", "json"],
):
try:
res = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
if res.returncode == 0 and "nvidia" in res.stdout.lower():
@ -85,7 +111,9 @@ class PlatformDetector:
def _get_docker_version(self) -> Optional[str]:
try:
res = subprocess.run(["docker", "--version"], capture_output=True, text=True, timeout=5)
res = subprocess.run(
["docker", "--version"], capture_output=True, text=True, timeout=5
)
if res.returncode == 0:
return res.stdout.strip()
except (subprocess.TimeoutExpired, FileNotFoundError):
@ -94,7 +122,9 @@ class PlatformDetector:
def _get_podman_version(self) -> Optional[str]:
try:
res = subprocess.run(["podman", "--version"], capture_output=True, text=True, timeout=5)
res = subprocess.run(
["podman", "--version"], capture_output=True, text=True, timeout=5
)
if res.returncode == 0:
return res.stdout.strip()
except (subprocess.TimeoutExpired, FileNotFoundError):
@ -110,7 +140,12 @@ class PlatformDetector:
if self.platform_system != "Darwin":
return True, 0, "Not running on macOS"
try:
result = subprocess.run(["podman", "machine", "inspect"], capture_output=True, text=True, timeout=10)
result = subprocess.run(
["podman", "machine", "inspect"],
capture_output=True,
text=True,
timeout=10,
)
if result.returncode != 0:
return False, 0, "Could not inspect Podman machine"
machines = json.loads(result.stdout)
@ -124,7 +159,11 @@ class PlatformDetector:
if not is_sufficient:
status += "\nTo increase: podman machine stop && podman machine rm && podman machine init --memory 8192 && podman machine start"
return is_sufficient, memory_mb, status
except (subprocess.TimeoutExpired, FileNotFoundError, json.JSONDecodeError) as e:
except (
subprocess.TimeoutExpired,
FileNotFoundError,
json.JSONDecodeError,
) as e:
return False, 0, f"Error checking Podman VM memory: {e}"
def get_installation_instructions(self) -> str:
@ -167,4 +206,4 @@ Or Podman Desktop:
No container runtime found. Please install Docker or Podman for your platform:
- Docker: https://docs.docker.com/get-docker/
- Podman: https://podman.io/getting-started/installation
"""
"""

View file

@ -8,28 +8,31 @@ from typing import Optional
class ValidationError(Exception):
    """Exception type representing a validation failure."""
def validate_env_var_name(name: str) -> bool:
    """Validate environment variable name format.

    A valid name starts with an uppercase ASCII letter and continues with
    uppercase ASCII letters, digits or underscores only.

    Args:
        name: Candidate variable name.

    Returns:
        True when the whole string matches the format, otherwise False.
    """
    # fullmatch rather than match + "$": "$" also matches just before a
    # trailing newline, which let names such as "FOO\n" through.
    return re.fullmatch(r"[A-Z][A-Z0-9_]*", name) is not None
def validate_path(
    path: str, must_exist: bool = False, must_be_dir: bool = False
) -> bool:
    """Validate file/directory path.

    Args:
        path: Path string; ``~`` is expanded and the result is resolved
            before any check is made.
        must_exist: When True, reject a path that does not exist.
        must_be_dir: When True, reject a path that exists but is not a
            directory. A non-existent path still passes this check unless
            ``must_exist`` is also set.

    Returns:
        True when the path meets the requested constraints, False when it
        is empty, fails a constraint, or cannot be expanded/resolved.
    """
    if not path:
        return False

    try:
        path_obj = Path(path).expanduser().resolve()

        if must_exist and not path_obj.exists():
            return False

        if must_be_dir and path_obj.exists() and not path_obj.is_dir():
            return False
    except (OSError, ValueError, RuntimeError):
        # RuntimeError is raised by expanduser() when the home directory
        # cannot be determined (e.g. "~unknown_user"); a validator should
        # report that as "invalid" instead of propagating the exception.
        return False

    return True
@ -39,15 +42,17 @@ def validate_url(url: str) -> bool:
"""Validate URL format."""
if not url:
return False
url_pattern = re.compile(
r'^https?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|' # domain
r'localhost|' # localhost
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # IP
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
r"^https?://" # http:// or https://
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|" # domain
r"localhost|" # localhost
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # IP
r"(?::\d+)?" # optional port
r"(?:/?|[/?]\S+)$",
re.IGNORECASE,
)
return bool(url_pattern.match(url))
@ -55,14 +60,14 @@ def validate_openai_api_key(key: str) -> bool:
"""Validate OpenAI API key format."""
if not key:
return False
return key.startswith('sk-') and len(key) > 20
return key.startswith("sk-") and len(key) > 20
def validate_google_oauth_client_id(client_id: str) -> bool:
    """Validate Google OAuth client ID format.

    Args:
        client_id: Candidate client id.

    Returns:
        True when ``client_id`` ends with ``.apps.googleusercontent.com``
        and has at least one character before that suffix; False for an
        empty value, a different suffix, or the bare suffix on its own.
    """
    suffix = ".apps.googleusercontent.com"
    if not client_id:
        return False
    # The bare suffix is not an id: require something in front of it.
    return client_id.endswith(suffix) and len(client_id) > len(suffix)
def validate_non_empty(value: str) -> bool:
@ -74,37 +79,38 @@ def sanitize_env_value(value: str) -> str:
"""Sanitize environment variable value."""
# Remove leading/trailing whitespace
value = value.strip()
# Remove quotes if they wrap the entire value
if len(value) >= 2:
if (value.startswith('"') and value.endswith('"')) or \
(value.startswith("'") and value.endswith("'")):
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
value = value[1:-1]
return value
def validate_documents_paths(paths_str: str) -> tuple[bool, str, list[str]]:
"""
Validate comma-separated documents paths for volume mounting.
Returns:
(is_valid, error_message, validated_paths)
"""
if not paths_str:
return False, "Documents paths cannot be empty", []
paths = [path.strip() for path in paths_str.split(',') if path.strip()]
paths = [path.strip() for path in paths_str.split(",") if path.strip()]
if not paths:
return False, "No valid paths provided", []
validated_paths = []
for path in paths:
try:
path_obj = Path(path).expanduser().resolve()
# Check if path exists
if not path_obj.exists():
# Try to create it
@ -112,11 +118,11 @@ def validate_documents_paths(paths_str: str) -> tuple[bool, str, list[str]]:
path_obj.mkdir(parents=True, exist_ok=True)
except (OSError, PermissionError) as e:
return False, f"Cannot create directory '{path}': {e}", []
# Check if it's a directory
if not path_obj.is_dir():
return False, f"Path '{path}' must be a directory", []
# Check if we can write to it
try:
test_file = path_obj / ".openrag_test"
@ -124,10 +130,10 @@ def validate_documents_paths(paths_str: str) -> tuple[bool, str, list[str]]:
test_file.unlink()
except (OSError, PermissionError):
return False, f"Directory '{path}' is not writable", []
validated_paths.append(str(path_obj))
except (OSError, ValueError) as e:
return False, f"Invalid path '{path}': {e}", []
return True, "All paths valid", validated_paths
return True, "All paths valid", validated_paths

View file

@ -65,13 +65,13 @@ class CommandOutputModal(ModalScreen):
"""
def __init__(
self,
title: str,
self,
title: str,
command_generator: AsyncIterator[tuple[bool, str]],
on_complete: Optional[Callable] = None
on_complete: Optional[Callable] = None,
):
"""Initialize the modal dialog.
Args:
title: Title of the modal dialog
command_generator: Async generator that yields (is_complete, message) tuples
@ -104,29 +104,32 @@ class CommandOutputModal(ModalScreen):
async def _run_command(self) -> None:
    """Run the command and update the output in real-time.

    Iterates ``self.command_generator``, which yields
    ``(is_complete, message)`` tuples. Every message is appended to the
    ``#command-output`` log and the ``#output-container`` is scrolled to
    the end. When a tuple reports completion, a success line is written
    and, if ``self.on_complete`` is set, it is called after a short delay.
    Any ``Exception`` raised while iterating is written to the log instead
    of propagating. Afterwards the ``#close-btn`` button is enabled and
    focused.
    """
    output = self.query_one("#command-output", RichLog)

    try:
        async for is_complete, message in self.command_generator:
            # Simple approach: just append each line as it comes
            output.write(message + "\n")

            # Scroll to bottom
            container = self.query_one("#output-container", ScrollableContainer)
            container.scroll_end(animate=False)

            # If command is complete, update UI
            if is_complete:
                # Written once; the block previously carried two copies of
                # this call, which would log the success line twice.
                output.write(
                    "[bold green]Command completed successfully[/bold green]\n"
                )

                # Call the completion callback if provided
                if self.on_complete:
                    await asyncio.sleep(0.5)  # Small delay for better UX
                    self.on_complete()
    except Exception as e:
        output.write(f"[bold red]Error: {e}[/bold red]\n")

    # Enable the close button and focus it
    close_btn = self.query_one("#close-btn", Button)
    close_btn.disabled = False
    close_btn.focus()
# Made with Bob

View file

@ -9,10 +9,10 @@ def notify_with_diagnostics(
app: App,
message: str,
severity: Literal["information", "warning", "error"] = "error",
timeout: float = 10.0
timeout: float = 10.0,
) -> None:
"""Show a notification with a button to open the diagnostics screen.
Args:
app: The Textual app
message: The notification message
@ -21,18 +21,20 @@ def notify_with_diagnostics(
"""
# First show the notification
app.notify(message, severity=severity, timeout=timeout)
# Then add a button to open diagnostics screen
def open_diagnostics() -> None:
from ..screens.diagnostics import DiagnosticsScreen
app.push_screen(DiagnosticsScreen())
# Add a separate notification with just the button
app.notify(
"Click to view diagnostics",
severity="information",
timeout=timeout,
title="Diagnostics"
title="Diagnostics",
)
# Made with Bob
# Made with Bob

View file

@ -9,10 +9,10 @@ def notify_with_diagnostics(
app: App,
message: str,
severity: Literal["information", "warning", "error"] = "error",
timeout: float = 10.0
timeout: float = 10.0,
) -> None:
"""Show a notification with a button to open the diagnostics screen.
Args:
app: The Textual app
message: The notification message
@ -21,18 +21,20 @@ def notify_with_diagnostics(
"""
# First show the notification
app.notify(message, severity=severity, timeout=timeout)
# Then add a button to open diagnostics screen
def open_diagnostics() -> None:
from ..screens.diagnostics import DiagnosticsScreen
app.push_screen(DiagnosticsScreen())
# Add a separate notification with just the button
app.notify(
"Click to view diagnostics",
severity="information",
timeout=timeout,
title="Diagnostics"
title="Diagnostics",
)
# Made with Bob

View file

@ -1,7 +1,12 @@
import hashlib
import os
import sys
import platform
from collections import defaultdict
from .gpu_detection import detect_gpu_devices
from utils.logging_config import get_logger
logger = get_logger(__name__)
# Global converter cache for worker processes
_worker_converter = None
@ -37,11 +42,11 @@ def get_worker_converter():
"1" # Still disable progress bars
)
print(
f"[WORKER {os.getpid()}] Initializing DocumentConverter in worker process"
logger.info(
"Initializing DocumentConverter in worker process", worker_pid=os.getpid()
)
_worker_converter = DocumentConverter()
print(f"[WORKER {os.getpid()}] DocumentConverter ready in worker process")
logger.info("DocumentConverter ready in worker process", worker_pid=os.getpid())
return _worker_converter
@ -118,33 +123,45 @@ def process_document_sync(file_path: str):
start_memory = process.memory_info().rss / 1024 / 1024 # MB
try:
print(f"[WORKER {os.getpid()}] Starting document processing: {file_path}")
print(f"[WORKER {os.getpid()}] Initial memory usage: {start_memory:.1f} MB")
logger.info(
"Starting document processing",
worker_pid=os.getpid(),
file_path=file_path,
initial_memory_mb=f"{start_memory:.1f}",
)
# Check file size
try:
file_size = os.path.getsize(file_path) / 1024 / 1024 # MB
print(f"[WORKER {os.getpid()}] File size: {file_size:.1f} MB")
logger.info(
"File size determined",
worker_pid=os.getpid(),
file_size_mb=f"{file_size:.1f}",
)
except OSError as e:
print(f"[WORKER {os.getpid()}] WARNING: Cannot get file size: {e}")
logger.warning("Cannot get file size", worker_pid=os.getpid(), error=str(e))
file_size = 0
# Get the cached converter for this worker
try:
print(f"[WORKER {os.getpid()}] Getting document converter...")
logger.info("Getting document converter", worker_pid=os.getpid())
converter = get_worker_converter()
memory_after_converter = process.memory_info().rss / 1024 / 1024
print(
f"[WORKER {os.getpid()}] Memory after converter init: {memory_after_converter:.1f} MB"
logger.info(
"Memory after converter init",
worker_pid=os.getpid(),
memory_mb=f"{memory_after_converter:.1f}",
)
except Exception as e:
print(f"[WORKER {os.getpid()}] ERROR: Failed to initialize converter: {e}")
logger.error(
"Failed to initialize converter", worker_pid=os.getpid(), error=str(e)
)
traceback.print_exc()
raise
# Compute file hash
try:
print(f"[WORKER {os.getpid()}] Computing file hash...")
logger.info("Computing file hash", worker_pid=os.getpid())
sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
while True:
@ -153,50 +170,67 @@ def process_document_sync(file_path: str):
break
sha256.update(chunk)
file_hash = sha256.hexdigest()
print(f"[WORKER {os.getpid()}] File hash computed: {file_hash[:12]}...")
logger.info(
"File hash computed",
worker_pid=os.getpid(),
file_hash_prefix=file_hash[:12],
)
except Exception as e:
print(f"[WORKER {os.getpid()}] ERROR: Failed to compute file hash: {e}")
logger.error(
"Failed to compute file hash", worker_pid=os.getpid(), error=str(e)
)
traceback.print_exc()
raise
# Convert with docling
try:
print(f"[WORKER {os.getpid()}] Starting docling conversion...")
logger.info("Starting docling conversion", worker_pid=os.getpid())
memory_before_convert = process.memory_info().rss / 1024 / 1024
print(
f"[WORKER {os.getpid()}] Memory before conversion: {memory_before_convert:.1f} MB"
logger.info(
"Memory before conversion",
worker_pid=os.getpid(),
memory_mb=f"{memory_before_convert:.1f}",
)
result = converter.convert(file_path)
memory_after_convert = process.memory_info().rss / 1024 / 1024
print(
f"[WORKER {os.getpid()}] Memory after conversion: {memory_after_convert:.1f} MB"
logger.info(
"Memory after conversion",
worker_pid=os.getpid(),
memory_mb=f"{memory_after_convert:.1f}",
)
print(f"[WORKER {os.getpid()}] Docling conversion completed")
logger.info("Docling conversion completed", worker_pid=os.getpid())
full_doc = result.document.export_to_dict()
memory_after_export = process.memory_info().rss / 1024 / 1024
print(
f"[WORKER {os.getpid()}] Memory after export: {memory_after_export:.1f} MB"
logger.info(
"Memory after export",
worker_pid=os.getpid(),
memory_mb=f"{memory_after_export:.1f}",
)
except Exception as e:
print(
f"[WORKER {os.getpid()}] ERROR: Failed during docling conversion: {e}"
)
print(
f"[WORKER {os.getpid()}] Current memory usage: {process.memory_info().rss / 1024 / 1024:.1f} MB"
current_memory = process.memory_info().rss / 1024 / 1024
logger.error(
"Failed during docling conversion",
worker_pid=os.getpid(),
error=str(e),
current_memory_mb=f"{current_memory:.1f}",
)
traceback.print_exc()
raise
# Extract relevant content (same logic as extract_relevant)
try:
print(f"[WORKER {os.getpid()}] Extracting relevant content...")
logger.info("Extracting relevant content", worker_pid=os.getpid())
origin = full_doc.get("origin", {})
texts = full_doc.get("texts", [])
print(f"[WORKER {os.getpid()}] Found {len(texts)} text fragments")
logger.info(
"Found text fragments",
worker_pid=os.getpid(),
fragment_count=len(texts),
)
page_texts = defaultdict(list)
for txt in texts:
@ -210,22 +244,27 @@ def process_document_sync(file_path: str):
joined = "\n".join(page_texts[page])
chunks.append({"page": page, "text": joined})
print(
f"[WORKER {os.getpid()}] Created {len(chunks)} chunks from {len(page_texts)} pages"
logger.info(
"Created chunks from pages",
worker_pid=os.getpid(),
chunk_count=len(chunks),
page_count=len(page_texts),
)
except Exception as e:
print(
f"[WORKER {os.getpid()}] ERROR: Failed during content extraction: {e}"
logger.error(
"Failed during content extraction", worker_pid=os.getpid(), error=str(e)
)
traceback.print_exc()
raise
final_memory = process.memory_info().rss / 1024 / 1024
memory_delta = final_memory - start_memory
print(f"[WORKER {os.getpid()}] Document processing completed successfully")
print(
f"[WORKER {os.getpid()}] Final memory: {final_memory:.1f} MB (Delta +{memory_delta:.1f} MB)"
logger.info(
"Document processing completed successfully",
worker_pid=os.getpid(),
final_memory_mb=f"{final_memory:.1f}",
memory_delta_mb=f"{memory_delta:.1f}",
)
return {
@ -239,24 +278,29 @@ def process_document_sync(file_path: str):
except Exception as e:
final_memory = process.memory_info().rss / 1024 / 1024
memory_delta = final_memory - start_memory
print(f"[WORKER {os.getpid()}] FATAL ERROR in process_document_sync")
print(f"[WORKER {os.getpid()}] File: {file_path}")
print(f"[WORKER {os.getpid()}] Python version: {sys.version}")
print(
f"[WORKER {os.getpid()}] Memory at crash: {final_memory:.1f} MB (Delta +{memory_delta:.1f} MB)"
logger.error(
"FATAL ERROR in process_document_sync",
worker_pid=os.getpid(),
file_path=file_path,
python_version=sys.version,
memory_at_crash_mb=f"{final_memory:.1f}",
memory_delta_mb=f"{memory_delta:.1f}",
error_type=type(e).__name__,
error=str(e),
)
print(f"[WORKER {os.getpid()}] Error: {type(e).__name__}: {e}")
print(f"[WORKER {os.getpid()}] Full traceback:")
logger.error("Full traceback:", worker_pid=os.getpid())
traceback.print_exc()
# Try to get more system info before crashing
try:
import platform
print(
f"[WORKER {os.getpid()}] System: {platform.system()} {platform.release()}"
logger.error(
"System info",
worker_pid=os.getpid(),
system=f"{platform.system()} {platform.release()}",
architecture=platform.machine(),
)
print(f"[WORKER {os.getpid()}] Architecture: {platform.machine()}")
except:
pass

View file

@ -1,5 +1,8 @@
import multiprocessing
import os
from utils.logging_config import get_logger
logger = get_logger(__name__)
def detect_gpu_devices():
@ -30,13 +33,15 @@ def get_worker_count():
if has_gpu_devices:
default_workers = min(4, multiprocessing.cpu_count() // 2)
print(
f"GPU mode enabled with {gpu_count} GPU(s) - using limited concurrency ({default_workers} workers)"
logger.info(
"GPU mode enabled with limited concurrency",
gpu_count=gpu_count,
worker_count=default_workers,
)
else:
default_workers = multiprocessing.cpu_count()
print(
f"CPU-only mode enabled - using full concurrency ({default_workers} workers)"
logger.info(
"CPU-only mode enabled with full concurrency", worker_count=default_workers
)
return int(os.getenv("MAX_WORKERS", default_workers))

View file

@ -9,13 +9,15 @@ def configure_logging(
log_level: str = "INFO",
json_logs: bool = False,
include_timestamps: bool = True,
service_name: str = "openrag"
service_name: str = "openrag",
) -> None:
"""Configure structlog for the application."""
# Convert string log level to actual level
level = getattr(structlog.stdlib.logging, log_level.upper(), structlog.stdlib.logging.INFO)
level = getattr(
structlog.stdlib.logging, log_level.upper(), structlog.stdlib.logging.INFO
)
# Base processors
shared_processors = [
structlog.contextvars.merge_contextvars,
@ -23,29 +25,65 @@ def configure_logging(
structlog.processors.StackInfoRenderer(),
structlog.dev.set_exc_info,
]
if include_timestamps:
shared_processors.append(structlog.processors.TimeStamper(fmt="iso"))
# Add service name to all logs
# Add service name and file location to all logs
shared_processors.append(
structlog.processors.CallsiteParameterAdder(
parameters=[structlog.processors.CallsiteParameter.FUNC_NAME]
parameters=[
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.FILENAME,
structlog.processors.CallsiteParameter.LINENO,
structlog.processors.CallsiteParameter.PATHNAME,
]
)
)
# Console output configuration
if json_logs or os.getenv("LOG_FORMAT", "").lower() == "json":
# JSON output for production/containers
shared_processors.append(structlog.processors.JSONRenderer())
console_renderer = structlog.processors.JSONRenderer()
else:
# Pretty colored output for development
console_renderer = structlog.dev.ConsoleRenderer(
colors=sys.stderr.isatty(),
exception_formatter=structlog.dev.plain_traceback,
)
# Custom clean format: timestamp path/file:loc logentry
def custom_formatter(logger, log_method, event_dict):
    """Render an event dict as ``"<timestamp> <location> <message>"``.

    ``location`` is ``pathname`` (preferred) or ``filename``, suffixed with
    ``:lineno`` when a line number is present, or ``"unknown"`` when neither
    is available. ``message`` is the event text followed by the remaining
    ``key=value`` pairs, space-separated; ``service`` and ``func_name`` are
    left out. ``level`` is removed from the dict and not rendered.

    The consumed keys are popped from ``event_dict`` in place.
    """
    stamp = event_dict.pop("timestamp", "")
    full_path = event_dict.pop("pathname", "")
    base_name = event_dict.pop("filename", "")
    line_no = event_dict.pop("lineno", "")
    event_dict.pop("level", "")  # dropped so it is not echoed as context

    # Prefer the full path, fall back to the bare file name.
    source = full_path or base_name
    if not source:
        location = "unknown"
    elif line_no:
        location = f"{source}:{line_no}"
    else:
        location = source

    # Event text first, then every remaining non-internal field.
    event = event_dict.pop("event", "")
    hidden = ("service", "func_name")
    parts = [event] if event else []
    parts.extend(
        f"{key}={value}" for key, value in event_dict.items() if key not in hidden
    )
    message = " ".join(parts)

    return f"{stamp} {location} {message}"
console_renderer = custom_formatter
# Configure structlog
structlog.configure(
processors=shared_processors + [console_renderer],
@ -54,7 +92,7 @@ def configure_logging(
logger_factory=structlog.WriteLoggerFactory(sys.stderr),
cache_logger_on_first_use=True,
)
# Add global context
structlog.contextvars.clear_contextvars()
structlog.contextvars.bind_contextvars(service=service_name)
@ -73,9 +111,7 @@ def configure_from_env() -> None:
log_level = os.getenv("LOG_LEVEL", "INFO")
json_logs = os.getenv("LOG_FORMAT", "").lower() == "json"
service_name = os.getenv("SERVICE_NAME", "openrag")
configure_logging(
log_level=log_level,
json_logs=json_logs,
service_name=service_name
)
log_level=log_level, json_logs=json_logs, service_name=service_name
)

View file

@ -1,10 +1,13 @@
import os
from concurrent.futures import ProcessPoolExecutor
from utils.gpu_detection import get_worker_count
from utils.logging_config import get_logger
logger = get_logger(__name__)
# Create shared process pool at import time (before CUDA initialization)
# This avoids the "Cannot re-initialize CUDA in forked subprocess" error
# The pool size is whatever utils.gpu_detection.get_worker_count() returns.
MAX_WORKERS = get_worker_count()
# Module-level executor: it is constructed as a side effect of importing
# this module.
process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)
# NOTE(review): the print() call and the logger.info() call below report the
# same event; they look like the removed and added sides of one change, so
# presumably only the logger.info() call remains in the final file - confirm.
print(f"Shared process pool initialized with {MAX_WORKERS} workers")
logger.info("Shared process pool initialized", max_workers=MAX_WORKERS)

View file

@ -1,13 +1,17 @@
from docling.document_converter import DocumentConverter
import logging
# Warm-up script: converts a sample PDF with docling's DocumentConverter.
# Per the closing comment, the goal is to trigger the model downloads, so a
# failed conversion is logged and otherwise ignored.
# NOTE(review): each print() call below is paired with a logger call carrying
# the same message; these look like the removed and added sides of one change,
# so presumably only the logger calls remain in the final file - confirm.
print("Warming up docling models...")
# Configure the root logger at INFO so the messages below are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Warming up docling models")
try:
# Use the sample document to warm up docling
test_file = "/app/warmup_ocr.pdf"
print(f"Using {test_file} to warm up docling...")
logger.info(f"Using test file to warm up docling: {test_file}")
# The conversion result is discarded; only the side effects matter here.
DocumentConverter().convert(test_file)
print("Docling models warmed up successfully")
logger.info("Docling models warmed up successfully")
# Any Exception is caught, reported at INFO level and not re-raised.
except Exception as e:
print(f"Docling warm-up completed with: {e}")
logger.info(f"Docling warm-up completed with exception: {str(e)}")
# This is expected - we just want to trigger the model downloads