Merge branch 'feat-googledrive-enhancements' of github.com:langflow-ai/openrag into feat-googledrive-enhancements

This commit is contained in:
Mike Fortman 2025-09-05 16:13:00 -05:00
commit 86d8c9f5b2
41 changed files with 1730 additions and 967 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

Binary file not shown.

View file

@ -95,7 +95,9 @@ async def async_response_stream(
chunk_count = 0 chunk_count = 0
async for chunk in response: async for chunk in response:
chunk_count += 1 chunk_count += 1
logger.debug("Stream chunk received", chunk_count=chunk_count, chunk=str(chunk)) logger.debug(
"Stream chunk received", chunk_count=chunk_count, chunk=str(chunk)
)
# Yield the raw event as JSON for the UI to process # Yield the raw event as JSON for the UI to process
import json import json
@ -241,7 +243,10 @@ async def async_langflow_stream(
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
log_prefix="langflow", log_prefix="langflow",
): ):
logger.debug("Yielding chunk from langflow stream", chunk_preview=chunk[:100].decode('utf-8', errors='replace')) logger.debug(
"Yielding chunk from langflow stream",
chunk_preview=chunk[:100].decode("utf-8", errors="replace"),
)
yield chunk yield chunk
logger.debug("Langflow stream completed") logger.debug("Langflow stream completed")
except Exception as e: except Exception as e:
@ -260,18 +265,24 @@ async def async_chat(
model: str = "gpt-4.1-mini", model: str = "gpt-4.1-mini",
previous_response_id: str = None, previous_response_id: str = None,
): ):
logger.debug("async_chat called", user_id=user_id, previous_response_id=previous_response_id) logger.debug(
"async_chat called", user_id=user_id, previous_response_id=previous_response_id
)
# Get the specific conversation thread (or create new one) # Get the specific conversation thread (or create new one)
conversation_state = get_conversation_thread(user_id, previous_response_id) conversation_state = get_conversation_thread(user_id, previous_response_id)
logger.debug("Got conversation state", message_count=len(conversation_state['messages'])) logger.debug(
"Got conversation state", message_count=len(conversation_state["messages"])
)
# Add user message to conversation with timestamp # Add user message to conversation with timestamp
from datetime import datetime from datetime import datetime
user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()} user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()}
conversation_state["messages"].append(user_message) conversation_state["messages"].append(user_message)
logger.debug("Added user message", message_count=len(conversation_state['messages'])) logger.debug(
"Added user message", message_count=len(conversation_state["messages"])
)
response_text, response_id = await async_response( response_text, response_id = await async_response(
async_client, async_client,
@ -280,7 +291,9 @@ async def async_chat(
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
log_prefix="agent", log_prefix="agent",
) )
logger.debug("Got response", response_preview=response_text[:50], response_id=response_id) logger.debug(
"Got response", response_preview=response_text[:50], response_id=response_id
)
# Add assistant response to conversation with response_id and timestamp # Add assistant response to conversation with response_id and timestamp
assistant_message = { assistant_message = {
@ -290,17 +303,26 @@ async def async_chat(
"timestamp": datetime.now(), "timestamp": datetime.now(),
} }
conversation_state["messages"].append(assistant_message) conversation_state["messages"].append(assistant_message)
logger.debug("Added assistant message", message_count=len(conversation_state['messages'])) logger.debug(
"Added assistant message", message_count=len(conversation_state["messages"])
)
# Store the conversation thread with its response_id # Store the conversation thread with its response_id
if response_id: if response_id:
conversation_state["last_activity"] = datetime.now() conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state) store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored conversation thread", user_id=user_id, response_id=response_id) logger.debug(
"Stored conversation thread", user_id=user_id, response_id=response_id
)
# Debug: Check what's in user_conversations now # Debug: Check what's in user_conversations now
conversations = get_user_conversations(user_id) conversations = get_user_conversations(user_id)
logger.debug("User conversations updated", user_id=user_id, conversation_count=len(conversations), conversation_ids=list(conversations.keys())) logger.debug(
"User conversations updated",
user_id=user_id,
conversation_count=len(conversations),
conversation_ids=list(conversations.keys()),
)
else: else:
logger.warning("No response_id received, conversation not stored") logger.warning("No response_id received, conversation not stored")
@ -363,7 +385,9 @@ async def async_chat_stream(
if response_id: if response_id:
conversation_state["last_activity"] = datetime.now() conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state) store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored conversation thread", user_id=user_id, response_id=response_id) logger.debug(
"Stored conversation thread", user_id=user_id, response_id=response_id
)
# Async langflow function with conversation storage (non-streaming) # Async langflow function with conversation storage (non-streaming)
@ -375,18 +399,28 @@ async def async_langflow_chat(
extra_headers: dict = None, extra_headers: dict = None,
previous_response_id: str = None, previous_response_id: str = None,
): ):
logger.debug("async_langflow_chat called", user_id=user_id, previous_response_id=previous_response_id) logger.debug(
"async_langflow_chat called",
user_id=user_id,
previous_response_id=previous_response_id,
)
# Get the specific conversation thread (or create new one) # Get the specific conversation thread (or create new one)
conversation_state = get_conversation_thread(user_id, previous_response_id) conversation_state = get_conversation_thread(user_id, previous_response_id)
logger.debug("Got langflow conversation state", message_count=len(conversation_state['messages'])) logger.debug(
"Got langflow conversation state",
message_count=len(conversation_state["messages"]),
)
# Add user message to conversation with timestamp # Add user message to conversation with timestamp
from datetime import datetime from datetime import datetime
user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()} user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()}
conversation_state["messages"].append(user_message) conversation_state["messages"].append(user_message)
logger.debug("Added user message to langflow", message_count=len(conversation_state['messages'])) logger.debug(
"Added user message to langflow",
message_count=len(conversation_state["messages"]),
)
response_text, response_id = await async_response( response_text, response_id = await async_response(
langflow_client, langflow_client,
@ -396,7 +430,11 @@ async def async_langflow_chat(
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
log_prefix="langflow", log_prefix="langflow",
) )
logger.debug("Got langflow response", response_preview=response_text[:50], response_id=response_id) logger.debug(
"Got langflow response",
response_preview=response_text[:50],
response_id=response_id,
)
# Add assistant response to conversation with response_id and timestamp # Add assistant response to conversation with response_id and timestamp
assistant_message = { assistant_message = {
@ -406,17 +444,29 @@ async def async_langflow_chat(
"timestamp": datetime.now(), "timestamp": datetime.now(),
} }
conversation_state["messages"].append(assistant_message) conversation_state["messages"].append(assistant_message)
logger.debug("Added assistant message to langflow", message_count=len(conversation_state['messages'])) logger.debug(
"Added assistant message to langflow",
message_count=len(conversation_state["messages"]),
)
# Store the conversation thread with its response_id # Store the conversation thread with its response_id
if response_id: if response_id:
conversation_state["last_activity"] = datetime.now() conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state) store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored langflow conversation thread", user_id=user_id, response_id=response_id) logger.debug(
"Stored langflow conversation thread",
user_id=user_id,
response_id=response_id,
)
# Debug: Check what's in user_conversations now # Debug: Check what's in user_conversations now
conversations = get_user_conversations(user_id) conversations = get_user_conversations(user_id)
logger.debug("User conversations updated", user_id=user_id, conversation_count=len(conversations), conversation_ids=list(conversations.keys())) logger.debug(
"User conversations updated",
user_id=user_id,
conversation_count=len(conversations),
conversation_ids=list(conversations.keys()),
)
else: else:
logger.warning("No response_id received from langflow, conversation not stored") logger.warning("No response_id received from langflow, conversation not stored")
@ -432,7 +482,11 @@ async def async_langflow_chat_stream(
extra_headers: dict = None, extra_headers: dict = None,
previous_response_id: str = None, previous_response_id: str = None,
): ):
logger.debug("async_langflow_chat_stream called", user_id=user_id, previous_response_id=previous_response_id) logger.debug(
"async_langflow_chat_stream called",
user_id=user_id,
previous_response_id=previous_response_id,
)
# Get the specific conversation thread (or create new one) # Get the specific conversation thread (or create new one)
conversation_state = get_conversation_thread(user_id, previous_response_id) conversation_state = get_conversation_thread(user_id, previous_response_id)
@ -483,4 +537,8 @@ async def async_langflow_chat_stream(
if response_id: if response_id:
conversation_state["last_activity"] = datetime.now() conversation_state["last_activity"] = datetime.now()
store_conversation_thread(user_id, response_id, conversation_state) store_conversation_thread(user_id, response_id, conversation_state)
logger.debug("Stored langflow conversation thread", user_id=user_id, response_id=response_id) logger.debug(
"Stored langflow conversation thread",
user_id=user_id,
response_id=response_id,
)

View file

@ -25,6 +25,11 @@ async def connector_sync(request: Request, connector_service, session_manager):
selected_files = data.get("selected_files") selected_files = data.get("selected_files")
try: try:
logger.debug(
"Starting connector sync",
connector_type=connector_type,
max_files=max_files,
)
user = request.state.user user = request.state.user
jwt_token = request.state.jwt_token jwt_token = request.state.jwt_token
@ -44,6 +49,10 @@ async def connector_sync(request: Request, connector_service, session_manager):
# Start sync tasks for all active connections # Start sync tasks for all active connections
task_ids = [] task_ids = []
for connection in active_connections: for connection in active_connections:
logger.debug(
"About to call sync_connector_files for connection",
connection_id=connection.connection_id,
)
if selected_files: if selected_files:
task_id = await connector_service.sync_specific_files( task_id = await connector_service.sync_specific_files(
connection.connection_id, connection.connection_id,
@ -58,8 +67,6 @@ async def connector_sync(request: Request, connector_service, session_manager):
max_files, max_files,
jwt_token=jwt_token, jwt_token=jwt_token,
) )
task_ids.append(task_id)
return JSONResponse( return JSONResponse(
{ {
"task_ids": task_ids, "task_ids": task_ids,
@ -170,7 +177,9 @@ async def connector_webhook(request: Request, connector_service, session_manager
channel_id = None channel_id = None
if not channel_id: if not channel_id:
logger.warning("No channel ID found in webhook", connector_type=connector_type) logger.warning(
"No channel ID found in webhook", connector_type=connector_type
)
return JSONResponse({"status": "ignored", "reason": "no_channel_id"}) return JSONResponse({"status": "ignored", "reason": "no_channel_id"})
# Find the specific connection for this webhook # Find the specific connection for this webhook
@ -180,7 +189,9 @@ async def connector_webhook(request: Request, connector_service, session_manager
) )
) )
if not connection or not connection.is_active: if not connection or not connection.is_active:
logger.info("Unknown webhook channel, will auto-expire", channel_id=channel_id) logger.info(
"Unknown webhook channel, will auto-expire", channel_id=channel_id
)
return JSONResponse( return JSONResponse(
{"status": "ignored_unknown_channel", "channel_id": channel_id} {"status": "ignored_unknown_channel", "channel_id": channel_id}
) )
@ -190,7 +201,10 @@ async def connector_webhook(request: Request, connector_service, session_manager
# Get the connector instance # Get the connector instance
connector = await connector_service._get_connector(connection.connection_id) connector = await connector_service._get_connector(connection.connection_id)
if not connector: if not connector:
logger.error("Could not get connector for connection", connection_id=connection.connection_id) logger.error(
"Could not get connector for connection",
connection_id=connection.connection_id,
)
return JSONResponse( return JSONResponse(
{"status": "error", "reason": "connector_not_found"} {"status": "error", "reason": "connector_not_found"}
) )
@ -199,7 +213,11 @@ async def connector_webhook(request: Request, connector_service, session_manager
affected_files = await connector.handle_webhook(payload) affected_files = await connector.handle_webhook(payload)
if affected_files: if affected_files:
logger.info("Webhook connection files affected", connection_id=connection.connection_id, affected_count=len(affected_files)) logger.info(
"Webhook connection files affected",
connection_id=connection.connection_id,
affected_count=len(affected_files),
)
# Generate JWT token for the user (needed for OpenSearch authentication) # Generate JWT token for the user (needed for OpenSearch authentication)
user = session_manager.get_user(connection.user_id) user = session_manager.get_user(connection.user_id)
@ -223,7 +241,10 @@ async def connector_webhook(request: Request, connector_service, session_manager
} }
else: else:
# No specific files identified - just log the webhook # No specific files identified - just log the webhook
logger.info("Webhook general change detected, no specific files", connection_id=connection.connection_id) logger.info(
"Webhook general change detected, no specific files",
connection_id=connection.connection_id,
)
result = { result = {
"connection_id": connection.connection_id, "connection_id": connection.connection_id,
@ -241,7 +262,15 @@ async def connector_webhook(request: Request, connector_service, session_manager
) )
except Exception as e: except Exception as e:
logger.error("Failed to process webhook for connection", connection_id=connection.connection_id, error=str(e)) logger.error(
"Failed to process webhook for connection",
connection_id=connection.connection_id,
error=str(e),
)
import traceback
traceback.print_exc()
return JSONResponse( return JSONResponse(
{ {
"status": "error", "status": "error",

View file

@ -395,15 +395,19 @@ async def knowledge_filter_webhook(
# Get the webhook payload # Get the webhook payload
payload = await request.json() payload = await request.json()
logger.info("Knowledge filter webhook received", logger.info(
"Knowledge filter webhook received",
filter_id=filter_id, filter_id=filter_id,
subscription_id=subscription_id, subscription_id=subscription_id,
payload_size=len(str(payload))) payload_size=len(str(payload)),
)
# Extract findings from the payload # Extract findings from the payload
findings = payload.get("findings", []) findings = payload.get("findings", [])
if not findings: if not findings:
logger.info("No findings in webhook payload", subscription_id=subscription_id) logger.info(
"No findings in webhook payload", subscription_id=subscription_id
)
return JSONResponse({"status": "no_findings"}) return JSONResponse({"status": "no_findings"})
# Process the findings - these are the documents that matched the knowledge filter # Process the findings - these are the documents that matched the knowledge filter
@ -420,14 +424,18 @@ async def knowledge_filter_webhook(
) )
# Log the matched documents # Log the matched documents
logger.info("Knowledge filter matched documents", logger.info(
"Knowledge filter matched documents",
filter_id=filter_id, filter_id=filter_id,
matched_count=len(matched_documents)) matched_count=len(matched_documents),
)
for doc in matched_documents: for doc in matched_documents:
logger.debug("Matched document", logger.debug(
document_id=doc['document_id'], "Matched document",
index=doc['index'], document_id=doc["document_id"],
score=doc.get('score')) index=doc["index"],
score=doc.get("score"),
)
# Here you could add additional processing: # Here you could add additional processing:
# - Send notifications to external webhooks # - Send notifications to external webhooks
@ -446,10 +454,12 @@ async def knowledge_filter_webhook(
) )
except Exception as e: except Exception as e:
logger.error("Failed to process knowledge filter webhook", logger.error(
"Failed to process knowledge filter webhook",
filter_id=filter_id, filter_id=filter_id,
subscription_id=subscription_id, subscription_id=subscription_id,
error=str(e)) error=str(e),
)
import traceback import traceback
traceback.print_exc() traceback.print_exc()

View file

@ -23,14 +23,16 @@ async def search(request: Request, search_service, session_manager):
# Extract JWT token from auth middleware # Extract JWT token from auth middleware
jwt_token = request.state.jwt_token jwt_token = request.state.jwt_token
logger.debug("Search API request", logger.debug(
"Search API request",
user=str(user), user=str(user),
user_id=user.user_id if user else None, user_id=user.user_id if user else None,
has_jwt_token=jwt_token is not None, has_jwt_token=jwt_token is not None,
query=query, query=query,
filters=filters, filters=filters,
limit=limit, limit=limit,
score_threshold=score_threshold) score_threshold=score_threshold,
)
result = await search_service.search( result = await search_service.search(
query, query,

View file

@ -3,6 +3,9 @@ from starlette.responses import JSONResponse
from typing import Optional from typing import Optional
from session_manager import User from session_manager import User
from config.settings import is_no_auth_mode from config.settings import is_no_auth_mode
from utils.logging_config import get_logger
logger = get_logger(__name__)
def get_current_user(request: Request, session_manager) -> Optional[User]: def get_current_user(request: Request, session_manager) -> Optional[User]:
@ -25,22 +28,15 @@ def require_auth(session_manager):
async def wrapper(request: Request): async def wrapper(request: Request):
# In no-auth mode, bypass authentication entirely # In no-auth mode, bypass authentication entirely
if is_no_auth_mode(): if is_no_auth_mode():
print(f"[DEBUG] No-auth mode: Creating anonymous user") logger.debug("No-auth mode: Creating anonymous user")
# Create an anonymous user object so endpoints don't break # Create an anonymous user object so endpoints don't break
from session_manager import User from session_manager import User
from datetime import datetime from datetime import datetime
request.state.user = User( from session_manager import AnonymousUser
user_id="anonymous", request.state.user = AnonymousUser()
email="anonymous@localhost",
name="Anonymous User",
picture=None,
provider="none",
created_at=datetime.now(),
last_login=datetime.now(),
)
request.state.jwt_token = None # No JWT in no-auth mode request.state.jwt_token = None # No JWT in no-auth mode
print(f"[DEBUG] Set user_id=anonymous, jwt_token=None") logger.debug("Set user_id=anonymous, jwt_token=None")
return await handler(request) return await handler(request)
user = get_current_user(request, session_manager) user = get_current_user(request, session_manager)
@ -72,15 +68,8 @@ def optional_auth(session_manager):
from session_manager import User from session_manager import User
from datetime import datetime from datetime import datetime
request.state.user = User( from session_manager import AnonymousUser
user_id="anonymous", request.state.user = AnonymousUser()
email="anonymous@localhost",
name="Anonymous User",
picture=None,
provider="none",
created_at=datetime.now(),
last_login=datetime.now(),
)
request.state.jwt_token = None # No JWT in no-auth mode request.state.jwt_token = None # No JWT in no-auth mode
else: else:
user = get_current_user(request, session_manager) user = get_current_user(request, session_manager)

View file

@ -36,7 +36,12 @@ GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
def is_no_auth_mode(): def is_no_auth_mode():
"""Check if we're running in no-auth mode (OAuth credentials missing)""" """Check if we're running in no-auth mode (OAuth credentials missing)"""
result = not (GOOGLE_OAUTH_CLIENT_ID and GOOGLE_OAUTH_CLIENT_SECRET) result = not (GOOGLE_OAUTH_CLIENT_ID and GOOGLE_OAUTH_CLIENT_SECRET)
logger.debug("Checking auth mode", no_auth_mode=result, has_client_id=GOOGLE_OAUTH_CLIENT_ID is not None, has_client_secret=GOOGLE_OAUTH_CLIENT_SECRET is not None) logger.debug(
"Checking auth mode",
no_auth_mode=result,
has_client_id=GOOGLE_OAUTH_CLIENT_ID is not None,
has_client_secret=GOOGLE_OAUTH_CLIENT_SECRET is not None,
)
return result return result
@ -99,7 +104,9 @@ async def generate_langflow_api_key():
return LANGFLOW_KEY return LANGFLOW_KEY
if not LANGFLOW_SUPERUSER or not LANGFLOW_SUPERUSER_PASSWORD: if not LANGFLOW_SUPERUSER or not LANGFLOW_SUPERUSER_PASSWORD:
logger.warning("LANGFLOW_SUPERUSER and LANGFLOW_SUPERUSER_PASSWORD not set, skipping API key generation") logger.warning(
"LANGFLOW_SUPERUSER and LANGFLOW_SUPERUSER_PASSWORD not set, skipping API key generation"
)
return None return None
try: try:
@ -141,11 +148,19 @@ async def generate_langflow_api_key():
raise KeyError("api_key") raise KeyError("api_key")
LANGFLOW_KEY = api_key LANGFLOW_KEY = api_key
logger.info("Successfully generated Langflow API key", api_key_preview=api_key[:8]) logger.info(
"Successfully generated Langflow API key",
api_key_preview=api_key[:8],
)
return api_key return api_key
except (requests.exceptions.RequestException, KeyError) as e: except (requests.exceptions.RequestException, KeyError) as e:
last_error = e last_error = e
logger.warning("Attempt to generate Langflow API key failed", attempt=attempt, max_attempts=max_attempts, error=str(e)) logger.warning(
"Attempt to generate Langflow API key failed",
attempt=attempt,
max_attempts=max_attempts,
error=str(e),
)
if attempt < max_attempts: if attempt < max_attempts:
time.sleep(delay_seconds) time.sleep(delay_seconds)
else: else:
@ -195,7 +210,9 @@ class AppClients:
logger.warning("Failed to initialize Langflow client", error=str(e)) logger.warning("Failed to initialize Langflow client", error=str(e))
self.langflow_client = None self.langflow_client = None
if self.langflow_client is None: if self.langflow_client is None:
logger.warning("No Langflow client initialized yet, will attempt later on first use") logger.warning(
"No Langflow client initialized yet, will attempt later on first use"
)
# Initialize patched OpenAI client # Initialize patched OpenAI client
self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI()) self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI())
@ -218,7 +235,9 @@ class AppClients:
) )
logger.info("Langflow client initialized on-demand") logger.info("Langflow client initialized on-demand")
except Exception as e: except Exception as e:
logger.error("Failed to initialize Langflow client on-demand", error=str(e)) logger.error(
"Failed to initialize Langflow client on-demand", error=str(e)
)
self.langflow_client = None self.langflow_client = None
return self.langflow_client return self.langflow_client

View file

@ -321,13 +321,18 @@ class ConnectionManager:
if connection_config.config.get( if connection_config.config.get(
"webhook_channel_id" "webhook_channel_id"
) or connection_config.config.get("subscription_id"): ) or connection_config.config.get("subscription_id"):
logger.info("Webhook subscription already exists", connection_id=connection_id) logger.info(
"Webhook subscription already exists", connection_id=connection_id
)
return return
# Check if webhook URL is configured # Check if webhook URL is configured
webhook_url = connection_config.config.get("webhook_url") webhook_url = connection_config.config.get("webhook_url")
if not webhook_url: if not webhook_url:
logger.info("No webhook URL configured, skipping subscription setup", connection_id=connection_id) logger.info(
"No webhook URL configured, skipping subscription setup",
connection_id=connection_id,
)
return return
try: try:
@ -345,10 +350,18 @@ class ConnectionManager:
# Save updated connection config # Save updated connection config
await self.save_connections() await self.save_connections()
logger.info("Successfully set up webhook subscription", connection_id=connection_id, subscription_id=subscription_id) logger.info(
"Successfully set up webhook subscription",
connection_id=connection_id,
subscription_id=subscription_id,
)
except Exception as e: except Exception as e:
logger.error("Failed to setup webhook subscription", connection_id=connection_id, error=str(e)) logger.error(
"Failed to setup webhook subscription",
connection_id=connection_id,
error=str(e),
)
# Don't fail the entire connection setup if webhook fails # Don't fail the entire connection setup if webhook fails
async def _setup_webhook_for_new_connection( async def _setup_webhook_for_new_connection(
@ -356,12 +369,18 @@ class ConnectionManager:
): ):
"""Setup webhook subscription for a newly authenticated connection""" """Setup webhook subscription for a newly authenticated connection"""
try: try:
logger.info("Setting up subscription for newly authenticated connection", connection_id=connection_id) logger.info(
"Setting up subscription for newly authenticated connection",
connection_id=connection_id,
)
# Create and authenticate connector # Create and authenticate connector
connector = self._create_connector(connection_config) connector = self._create_connector(connection_config)
if not await connector.authenticate(): if not await connector.authenticate():
logger.error("Failed to authenticate connector for webhook setup", connection_id=connection_id) logger.error(
"Failed to authenticate connector for webhook setup",
connection_id=connection_id,
)
return return
# Setup subscription # Setup subscription
@ -376,8 +395,16 @@ class ConnectionManager:
# Save updated connection config # Save updated connection config
await self.save_connections() await self.save_connections()
logger.info("Successfully set up webhook subscription", connection_id=connection_id, subscription_id=subscription_id) logger.info(
"Successfully set up webhook subscription",
connection_id=connection_id,
subscription_id=subscription_id,
)
except Exception as e: except Exception as e:
logger.error("Failed to setup webhook subscription for new connection", connection_id=connection_id, error=str(e)) logger.error(
"Failed to setup webhook subscription for new connection",
connection_id=connection_id,
error=str(e),
)
# Don't fail the connection setup if webhook fails # Don't fail the connection setup if webhook fails

View file

@ -4,6 +4,11 @@ from typing import Dict, Any, List, Optional
from .base import BaseConnector, ConnectorDocument from .base import BaseConnector, ConnectorDocument
from utils.logging_config import get_logger from utils.logging_config import get_logger
logger = get_logger(__name__)
from .google_drive import GoogleDriveConnector
from .sharepoint import SharePointConnector
from .onedrive import OneDriveConnector
from .connection_manager import ConnectionManager from .connection_manager import ConnectionManager
logger = get_logger(__name__) logger = get_logger(__name__)
@ -62,7 +67,7 @@ class ConnectorService:
doc_service = DocumentService(session_manager=self.session_manager) doc_service = DocumentService(session_manager=self.session_manager)
print(f"[DEBUG] Processing connector document with ID: {document.id}") logger.debug("Processing connector document", document_id=document.id)
# Process using the existing pipeline but with connector document metadata # Process using the existing pipeline but with connector document metadata
result = await doc_service.process_file_common( result = await doc_service.process_file_common(
@ -77,7 +82,7 @@ class ConnectorService:
connector_type=connector_type, connector_type=connector_type,
) )
print(f"[DEBUG] Document processing result: {result}") logger.debug("Document processing result", result=result)
# If successfully indexed or already exists, update the indexed documents with connector metadata # If successfully indexed or already exists, update the indexed documents with connector metadata
if result["status"] in ["indexed", "unchanged"]: if result["status"] in ["indexed", "unchanged"]:
@ -104,7 +109,7 @@ class ConnectorService:
jwt_token: str = None, jwt_token: str = None,
): ):
"""Update indexed chunks with connector-specific metadata""" """Update indexed chunks with connector-specific metadata"""
print(f"[DEBUG] Looking for chunks with document_id: {document.id}") logger.debug("Looking for chunks", document_id=document.id)
# Find all chunks for this document # Find all chunks for this document
query = {"query": {"term": {"document_id": document.id}}} query = {"query": {"term": {"document_id": document.id}}}
@ -117,26 +122,34 @@ class ConnectorService:
try: try:
response = await opensearch_client.search(index=self.index_name, body=query) response = await opensearch_client.search(index=self.index_name, body=query)
except Exception as e: except Exception as e:
print( logger.error(
f"[ERROR] OpenSearch search failed for connector metadata update: {e}" "OpenSearch search failed for connector metadata update",
error=str(e),
query=query,
) )
print(f"[ERROR] Search query: {query}")
raise raise
print(f"[DEBUG] Search query: {query}") logger.debug(
print( "Search query executed",
f"[DEBUG] Found {len(response['hits']['hits'])} chunks matching document_id: {document.id}" query=query,
chunks_found=len(response["hits"]["hits"]),
document_id=document.id,
) )
# Update each chunk with connector metadata # Update each chunk with connector metadata
print( logger.debug(
f"[DEBUG] Updating {len(response['hits']['hits'])} chunks with connector_type: {connector_type}" "Updating chunks with connector_type",
chunk_count=len(response["hits"]["hits"]),
connector_type=connector_type,
) )
for hit in response["hits"]["hits"]: for hit in response["hits"]["hits"]:
chunk_id = hit["_id"] chunk_id = hit["_id"]
current_connector_type = hit["_source"].get("connector_type", "unknown") current_connector_type = hit["_source"].get("connector_type", "unknown")
print( logger.debug(
f"[DEBUG] Chunk {chunk_id}: current connector_type = {current_connector_type}, updating to {connector_type}" "Updating chunk connector metadata",
chunk_id=chunk_id,
current_connector_type=current_connector_type,
new_connector_type=connector_type,
) )
update_body = { update_body = {
@ -164,10 +177,14 @@ class ConnectorService:
await opensearch_client.update( await opensearch_client.update(
index=self.index_name, id=chunk_id, body=update_body index=self.index_name, id=chunk_id, body=update_body
) )
print(f"[DEBUG] Updated chunk {chunk_id} with connector metadata") logger.debug("Updated chunk with connector metadata", chunk_id=chunk_id)
except Exception as e: except Exception as e:
print(f"[ERROR] OpenSearch update failed for chunk {chunk_id}: {e}") logger.error(
print(f"[ERROR] Update body: {update_body}") "OpenSearch update failed for chunk",
chunk_id=chunk_id,
error=str(e),
update_body=update_body,
)
raise raise
def _get_file_extension(self, mimetype: str) -> str: def _get_file_extension(self, mimetype: str) -> str:
@ -226,11 +243,11 @@ class ConnectorService:
while True: while True:
# List files from connector with limit # List files from connector with limit
logger.info( logger.debug(
"Calling list_files", page_size=page_size, page_token=page_token "Calling list_files", page_size=page_size, page_token=page_token
) )
file_list = await connector.list_files(page_token, max_files=page_size) file_list = await connector.list_files(page_token, limit=page_size)
logger.info( logger.debug(
"Got files from connector", file_count=len(file_list.get("files", [])) "Got files from connector", file_count=len(file_list.get("files", []))
) )
files = file_list["files"] files = file_list["files"]

View file

@ -3,11 +3,13 @@ import sys
# Check for TUI flag FIRST, before any heavy imports # Check for TUI flag FIRST, before any heavy imports
if __name__ == "__main__" and len(sys.argv) > 1 and sys.argv[1] == "--tui": if __name__ == "__main__" and len(sys.argv) > 1 and sys.argv[1] == "--tui":
from tui.main import run_tui from tui.main import run_tui
run_tui() run_tui()
sys.exit(0) sys.exit(0)
# Configure structured logging early # Configure structured logging early
from utils.logging_config import configure_from_env, get_logger from utils.logging_config import configure_from_env, get_logger
configure_from_env() configure_from_env()
logger = get_logger(__name__) logger = get_logger(__name__)
@ -25,6 +27,8 @@ import torch
# Configuration and setup # Configuration and setup
from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET
from config.settings import is_no_auth_mode
from utils.gpu_detection import detect_gpu_devices
# Services # Services
from services.document_service import DocumentService from services.document_service import DocumentService
@ -56,8 +60,11 @@ from api import (
# Set multiprocessing start method to 'spawn' for CUDA compatibility # Set multiprocessing start method to 'spawn' for CUDA compatibility
multiprocessing.set_start_method("spawn", force=True) multiprocessing.set_start_method("spawn", force=True)
logger.info("CUDA available", cuda_available=torch.cuda.is_available()) logger.info(
logger.info("CUDA version PyTorch was built with", cuda_version=torch.version.cuda) "CUDA device information",
cuda_available=torch.cuda.is_available(),
cuda_version=torch.version.cuda,
)
async def wait_for_opensearch(): async def wait_for_opensearch():
@ -71,7 +78,12 @@ async def wait_for_opensearch():
logger.info("OpenSearch is ready") logger.info("OpenSearch is ready")
return return
except Exception as e: except Exception as e:
logger.warning("OpenSearch not ready yet", attempt=attempt + 1, max_retries=max_retries, error=str(e)) logger.warning(
"OpenSearch not ready yet",
attempt=attempt + 1,
max_retries=max_retries,
error=str(e),
)
if attempt < max_retries - 1: if attempt < max_retries - 1:
await asyncio.sleep(retry_delay) await asyncio.sleep(retry_delay)
else: else:
@ -93,7 +105,9 @@ async def configure_alerting_security():
# Use admin client (clients.opensearch uses admin credentials) # Use admin client (clients.opensearch uses admin credentials)
response = await clients.opensearch.cluster.put_settings(body=alerting_settings) response = await clients.opensearch.cluster.put_settings(body=alerting_settings)
logger.info("Alerting security settings configured successfully", response=response) logger.info(
"Alerting security settings configured successfully", response=response
)
except Exception as e: except Exception as e:
logger.warning("Failed to configure alerting security settings", error=str(e)) logger.warning("Failed to configure alerting security settings", error=str(e))
# Don't fail startup if alerting config fails # Don't fail startup if alerting config fails
@ -133,9 +147,14 @@ async def init_index():
await clients.opensearch.indices.create( await clients.opensearch.indices.create(
index=knowledge_filter_index_name, body=knowledge_filter_index_body index=knowledge_filter_index_name, body=knowledge_filter_index_body
) )
logger.info("Created knowledge filters index", index_name=knowledge_filter_index_name) logger.info(
"Created knowledge filters index", index_name=knowledge_filter_index_name
)
else: else:
logger.info("Knowledge filters index already exists, skipping creation", index_name=knowledge_filter_index_name) logger.info(
"Knowledge filters index already exists, skipping creation",
index_name=knowledge_filter_index_name,
)
# Configure alerting plugin security settings # Configure alerting plugin security settings
await configure_alerting_security() await configure_alerting_security()
@ -190,9 +209,59 @@ async def init_index_when_ready():
logger.info("OpenSearch index initialization completed successfully") logger.info("OpenSearch index initialization completed successfully")
except Exception as e: except Exception as e:
logger.error("OpenSearch index initialization failed", error=str(e)) logger.error("OpenSearch index initialization failed", error=str(e))
logger.warning("OIDC endpoints will still work, but document operations may fail until OpenSearch is ready") logger.warning(
"OIDC endpoints will still work, but document operations may fail until OpenSearch is ready"
)
async def ingest_default_documents_when_ready(services):
"""Scan the local documents folder and ingest files like a non-auth upload."""
try:
logger.info("Ingesting default documents when ready")
base_dir = os.path.abspath(os.path.join(os.getcwd(), "documents"))
if not os.path.isdir(base_dir):
logger.info("Default documents directory not found; skipping ingestion", base_dir=base_dir)
return
# Collect files recursively
file_paths = [
os.path.join(root, fn)
for root, _, files in os.walk(base_dir)
for fn in files
]
if not file_paths:
logger.info("No default documents found; nothing to ingest", base_dir=base_dir)
return
# Build a processor that DOES NOT set 'owner' on documents (owner_user_id=None)
from models.processors import DocumentFileProcessor
processor = DocumentFileProcessor(
services["document_service"],
owner_user_id=None,
jwt_token=None,
owner_name=None,
owner_email=None,
)
task_id = await services["task_service"].create_custom_task(
"anonymous", file_paths, processor
)
logger.info(
"Started default documents ingestion task",
task_id=task_id,
file_count=len(file_paths),
)
except Exception as e:
logger.error("Default documents ingestion failed", error=str(e))
async def startup_tasks(services):
"""Startup tasks"""
logger.info("Starting startup tasks")
await init_index()
await ingest_default_documents_when_ready(services)
async def initialize_services(): async def initialize_services():
"""Initialize all services and their dependencies""" """Initialize all services and their dependencies"""
# Generate JWT keys if they don't exist # Generate JWT keys if they don't exist
@ -237,9 +306,14 @@ async def initialize_services():
try: try:
await connector_service.initialize() await connector_service.initialize()
loaded_count = len(connector_service.connection_manager.connections) loaded_count = len(connector_service.connection_manager.connections)
logger.info("Loaded persisted connector connections on startup", loaded_count=loaded_count) logger.info(
"Loaded persisted connector connections on startup",
loaded_count=loaded_count,
)
except Exception as e: except Exception as e:
logger.warning("Failed to load persisted connections on startup", error=str(e)) logger.warning(
"Failed to load persisted connections on startup", error=str(e)
)
else: else:
logger.info("[CONNECTORS] Skipping connection loading in no-auth mode") logger.info("[CONNECTORS] Skipping connection loading in no-auth mode")
@ -639,12 +713,15 @@ async def create_app():
app = Starlette(debug=True, routes=routes) app = Starlette(debug=True, routes=routes)
app.state.services = services # Store services for cleanup app.state.services = services # Store services for cleanup
app.state.background_tasks = set()
# Add startup event handler # Add startup event handler
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
# Start index initialization in background to avoid blocking OIDC endpoints # Start index initialization in background to avoid blocking OIDC endpoints
asyncio.create_task(init_index_when_ready()) t1 = asyncio.create_task(startup_tasks(services))
app.state.background_tasks.add(t1)
t1.add_done_callback(app.state.background_tasks.discard)
# Add shutdown event handler # Add shutdown event handler
@app.on_event("shutdown") @app.on_event("shutdown")
@ -687,18 +764,30 @@ async def cleanup_subscriptions_proper(services):
for connection in active_connections: for connection in active_connections:
try: try:
logger.info("Cancelling subscription for connection", connection_id=connection.connection_id) logger.info(
"Cancelling subscription for connection",
connection_id=connection.connection_id,
)
connector = await connector_service.get_connector( connector = await connector_service.get_connector(
connection.connection_id connection.connection_id
) )
if connector: if connector:
subscription_id = connection.config.get("webhook_channel_id") subscription_id = connection.config.get("webhook_channel_id")
await connector.cleanup_subscription(subscription_id) await connector.cleanup_subscription(subscription_id)
logger.info("Cancelled subscription", subscription_id=subscription_id) logger.info(
"Cancelled subscription", subscription_id=subscription_id
)
except Exception as e: except Exception as e:
logger.error("Failed to cancel subscription", connection_id=connection.connection_id, error=str(e)) logger.error(
"Failed to cancel subscription",
connection_id=connection.connection_id,
error=str(e),
)
logger.info("Finished cancelling subscriptions", subscription_count=len(active_connections)) logger.info(
"Finished cancelling subscriptions",
subscription_count=len(active_connections),
)
except Exception as e: except Exception as e:
logger.error("Failed to cleanup subscriptions", error=str(e)) logger.error("Failed to cleanup subscriptions", error=str(e))

View file

@ -1,6 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict from typing import Any, Dict
from .tasks import UploadTask, FileTask from .tasks import UploadTask, FileTask
from utils.logging_config import get_logger
logger = get_logger(__name__)
class TaskProcessor(ABC): class TaskProcessor(ABC):
@ -225,10 +228,12 @@ class S3FileProcessor(TaskProcessor):
index=INDEX_NAME, id=chunk_id, body=chunk_doc index=INDEX_NAME, id=chunk_id, body=chunk_doc
) )
except Exception as e: except Exception as e:
print( logger.error(
f"[ERROR] OpenSearch indexing failed for S3 chunk {chunk_id}: {e}" "OpenSearch indexing failed for S3 chunk",
chunk_id=chunk_id,
error=str(e),
chunk_doc=chunk_doc,
) )
print(f"[ERROR] Chunk document: {chunk_doc}")
raise raise
result = {"status": "indexed", "id": slim_doc["id"]} result = {"status": "indexed", "id": slim_doc["id"]}

View file

@ -111,7 +111,10 @@ class ChatService:
# Pass the complete filter expression as a single header to Langflow (only if we have something to send) # Pass the complete filter expression as a single header to Langflow (only if we have something to send)
if filter_expression: if filter_expression:
logger.info("Sending OpenRAG query filter to Langflow", filter_expression=filter_expression) logger.info(
"Sending OpenRAG query filter to Langflow",
filter_expression=filter_expression,
)
extra_headers["X-LANGFLOW-GLOBAL-VAR-OPENRAG-QUERY-FILTER"] = json.dumps( extra_headers["X-LANGFLOW-GLOBAL-VAR-OPENRAG-QUERY-FILTER"] = json.dumps(
filter_expression filter_expression
) )
@ -201,7 +204,11 @@ class ChatService:
return {"error": "User ID is required", "conversations": []} return {"error": "User ID is required", "conversations": []}
conversations_dict = get_user_conversations(user_id) conversations_dict = get_user_conversations(user_id)
logger.debug("Getting chat history for user", user_id=user_id, conversation_count=len(conversations_dict)) logger.debug(
"Getting chat history for user",
user_id=user_id,
conversation_count=len(conversations_dict),
)
# Convert conversations dict to list format with metadata # Convert conversations dict to list format with metadata
conversations = [] conversations = []

View file

@ -196,7 +196,11 @@ class DocumentService:
index=INDEX_NAME, id=chunk_id, body=chunk_doc index=INDEX_NAME, id=chunk_id, body=chunk_doc
) )
except Exception as e: except Exception as e:
logger.error("OpenSearch indexing failed for chunk", chunk_id=chunk_id, error=str(e)) logger.error(
"OpenSearch indexing failed for chunk",
chunk_id=chunk_id,
error=str(e),
)
logger.error("Chunk document details", chunk_doc=chunk_doc) logger.error("Chunk document details", chunk_doc=chunk_doc)
raise raise
return {"status": "indexed", "id": file_hash} return {"status": "indexed", "id": file_hash}
@ -232,7 +236,9 @@ class DocumentService:
try: try:
exists = await opensearch_client.exists(index=INDEX_NAME, id=file_hash) exists = await opensearch_client.exists(index=INDEX_NAME, id=file_hash)
except Exception as e: except Exception as e:
logger.error("OpenSearch exists check failed", file_hash=file_hash, error=str(e)) logger.error(
"OpenSearch exists check failed", file_hash=file_hash, error=str(e)
)
raise raise
if exists: if exists:
return {"status": "unchanged", "id": file_hash} return {"status": "unchanged", "id": file_hash}
@ -372,7 +378,11 @@ class DocumentService:
index=INDEX_NAME, id=chunk_id, body=chunk_doc index=INDEX_NAME, id=chunk_id, body=chunk_doc
) )
except Exception as e: except Exception as e:
logger.error("OpenSearch indexing failed for batch chunk", chunk_id=chunk_id, error=str(e)) logger.error(
"OpenSearch indexing failed for batch chunk",
chunk_id=chunk_id,
error=str(e),
)
logger.error("Chunk document details", chunk_doc=chunk_doc) logger.error("Chunk document details", chunk_doc=chunk_doc)
raise raise
@ -388,9 +398,13 @@ class DocumentService:
from concurrent.futures import BrokenExecutor from concurrent.futures import BrokenExecutor
if isinstance(e, BrokenExecutor): if isinstance(e, BrokenExecutor):
logger.error("Process pool broken while processing file", file_path=file_path) logger.error(
"Process pool broken while processing file", file_path=file_path
)
logger.info("Worker process likely crashed") logger.info("Worker process likely crashed")
logger.info("You should see detailed crash logs above from the worker process") logger.info(
"You should see detailed crash logs above from the worker process"
)
# Mark pool as broken for potential recreation # Mark pool as broken for potential recreation
self._process_pool_broken = True self._process_pool_broken = True
@ -399,11 +413,15 @@ class DocumentService:
if self._recreate_process_pool(): if self._recreate_process_pool():
logger.info("Process pool successfully recreated") logger.info("Process pool successfully recreated")
else: else:
logger.warning("Failed to recreate process pool - future operations may fail") logger.warning(
"Failed to recreate process pool - future operations may fail"
)
file_task.error = f"Worker process crashed: {str(e)}" file_task.error = f"Worker process crashed: {str(e)}"
else: else:
logger.error("Failed to process file", file_path=file_path, error=str(e)) logger.error(
"Failed to process file", file_path=file_path, error=str(e)
)
file_task.error = str(e) file_task.error = str(e)
logger.error("Full traceback available") logger.error("Full traceback available")

View file

@ -195,7 +195,9 @@ class MonitorService:
return monitors return monitors
except Exception as e: except Exception as e:
logger.error("Error listing monitors for user", user_id=user_id, error=str(e)) logger.error(
"Error listing monitors for user", user_id=user_id, error=str(e)
)
return [] return []
async def list_monitors_for_filter( async def list_monitors_for_filter(
@ -236,7 +238,9 @@ class MonitorService:
return monitors return monitors
except Exception as e: except Exception as e:
logger.error("Error listing monitors for filter", filter_id=filter_id, error=str(e)) logger.error(
"Error listing monitors for filter", filter_id=filter_id, error=str(e)
)
return [] return []
async def _get_or_create_webhook_destination( async def _get_or_create_webhook_destination(

View file

@ -138,7 +138,11 @@ class SearchService:
search_body["min_score"] = score_threshold search_body["min_score"] = score_threshold
# Authentication required - DLS will handle document filtering automatically # Authentication required - DLS will handle document filtering automatically
logger.debug("search_service authentication info", user_id=user_id, has_jwt_token=jwt_token is not None) logger.debug(
"search_service authentication info",
user_id=user_id,
has_jwt_token=jwt_token is not None,
)
if not user_id: if not user_id:
logger.debug("search_service: user_id is None/empty, returning auth error") logger.debug("search_service: user_id is None/empty, returning auth error")
return {"results": [], "error": "Authentication required"} return {"results": [], "error": "Authentication required"}
@ -151,7 +155,9 @@ class SearchService:
try: try:
results = await opensearch_client.search(index=INDEX_NAME, body=search_body) results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
except Exception as e: except Exception as e:
logger.error("OpenSearch query failed", error=str(e), search_body=search_body) logger.error(
"OpenSearch query failed", error=str(e), search_body=search_body
)
# Re-raise the exception so the API returns the error to frontend # Re-raise the exception so the API returns the error to frontend
raise raise

View file

@ -2,11 +2,14 @@ import asyncio
import uuid import uuid
import time import time
import random import random
from typing import Dict from typing import Dict, Optional
from models.tasks import TaskStatus, UploadTask, FileTask from models.tasks import TaskStatus, UploadTask, FileTask
from session_manager import AnonymousUser
from src.utils.gpu_detection import get_worker_count from src.utils.gpu_detection import get_worker_count
from utils.logging_config import get_logger
logger = get_logger(__name__)
class TaskService: class TaskService:
@ -104,7 +107,9 @@ class TaskService:
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e: except Exception as e:
print(f"[ERROR] Background upload processor failed for task {task_id}: {e}") logger.error(
"Background upload processor failed", task_id=task_id, error=str(e)
)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@ -136,7 +141,9 @@ class TaskService:
try: try:
await processor.process_item(upload_task, item, file_task) await processor.process_item(upload_task, item, file_task)
except Exception as e: except Exception as e:
print(f"[ERROR] Failed to process item {item}: {e}") logger.error(
"Failed to process item", item=str(item), error=str(e)
)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@ -157,13 +164,15 @@ class TaskService:
upload_task.updated_at = time.time() upload_task.updated_at = time.time()
except asyncio.CancelledError: except asyncio.CancelledError:
print(f"[INFO] Background processor for task {task_id} was cancelled") logger.info("Background processor cancelled", task_id=task_id)
if user_id in self.task_store and task_id in self.task_store[user_id]: if user_id in self.task_store and task_id in self.task_store[user_id]:
# Task status and pending files already handled by cancel_task() # Task status and pending files already handled by cancel_task()
pass pass
raise # Re-raise to properly handle cancellation raise # Re-raise to properly handle cancellation
except Exception as e: except Exception as e:
print(f"[ERROR] Background custom processor failed for task {task_id}: {e}") logger.error(
"Background custom processor failed", task_id=task_id, error=str(e)
)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@ -171,16 +180,29 @@ class TaskService:
self.task_store[user_id][task_id].status = TaskStatus.FAILED self.task_store[user_id][task_id].status = TaskStatus.FAILED
self.task_store[user_id][task_id].updated_at = time.time() self.task_store[user_id][task_id].updated_at = time.time()
def get_task_status(self, user_id: str, task_id: str) -> dict: def get_task_status(self, user_id: str, task_id: str) -> Optional[dict]:
"""Get the status of a specific upload task""" """Get the status of a specific upload task
if (
not task_id Includes fallback to shared tasks stored under the "anonymous" user key
or user_id not in self.task_store so default system tasks are visible to all users.
or task_id not in self.task_store[user_id] """
): if not task_id:
return None return None
upload_task = self.task_store[user_id][task_id] # Prefer the caller's user_id; otherwise check shared/anonymous tasks
candidate_user_ids = [user_id, AnonymousUser().user_id]
upload_task = None
for candidate_user_id in candidate_user_ids:
if (
candidate_user_id in self.task_store
and task_id in self.task_store[candidate_user_id]
):
upload_task = self.task_store[candidate_user_id][task_id]
break
if upload_task is None:
return None
file_statuses = {} file_statuses = {}
for file_path, file_task in upload_task.file_tasks.items(): for file_path, file_task in upload_task.file_tasks.items():
@ -206,14 +228,21 @@ class TaskService:
} }
def get_all_tasks(self, user_id: str) -> list: def get_all_tasks(self, user_id: str) -> list:
"""Get all tasks for a user""" """Get all tasks for a user
if user_id not in self.task_store:
return []
tasks = [] Returns the union of the user's own tasks and shared default tasks stored
for task_id, upload_task in self.task_store[user_id].items(): under the "anonymous" user key. User-owned tasks take precedence
tasks.append( if a task_id overlaps.
{ """
tasks_by_id = {}
def add_tasks_from_store(store_user_id):
if store_user_id not in self.task_store:
return
for task_id, upload_task in self.task_store[store_user_id].items():
if task_id in tasks_by_id:
continue
tasks_by_id[task_id] = {
"task_id": upload_task.task_id, "task_id": upload_task.task_id,
"status": upload_task.status.value, "status": upload_task.status.value,
"total_files": upload_task.total_files, "total_files": upload_task.total_files,
@ -223,18 +252,36 @@ class TaskService:
"created_at": upload_task.created_at, "created_at": upload_task.created_at,
"updated_at": upload_task.updated_at, "updated_at": upload_task.updated_at,
} }
)
# Sort by creation time, most recent first # First, add user-owned tasks; then shared anonymous;
add_tasks_from_store(user_id)
add_tasks_from_store(AnonymousUser().user_id)
tasks = list(tasks_by_id.values())
tasks.sort(key=lambda x: x["created_at"], reverse=True) tasks.sort(key=lambda x: x["created_at"], reverse=True)
return tasks return tasks
def cancel_task(self, user_id: str, task_id: str) -> bool: def cancel_task(self, user_id: str, task_id: str) -> bool:
"""Cancel a task if it exists and is not already completed""" """Cancel a task if it exists and is not already completed.
if user_id not in self.task_store or task_id not in self.task_store[user_id]:
Supports cancellation of shared default tasks stored under the anonymous user.
"""
# Check candidate user IDs first, then anonymous to find which user ID the task is mapped to
candidate_user_ids = [user_id, AnonymousUser().user_id]
store_user_id = None
for candidate_user_id in candidate_user_ids:
if (
candidate_user_id in self.task_store
and task_id in self.task_store[candidate_user_id]
):
store_user_id = candidate_user_id
break
if store_user_id is None:
return False return False
upload_task = self.task_store[user_id][task_id] upload_task = self.task_store[store_user_id][task_id]
# Can only cancel pending or running tasks # Can only cancel pending or running tasks
if upload_task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]: if upload_task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:

View file

@ -6,8 +6,13 @@ from typing import Dict, Optional, Any
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
import os import os
from utils.logging_config import get_logger
logger = get_logger(__name__)
from utils.logging_config import get_logger
logger = get_logger(__name__)
@dataclass @dataclass
class User: class User:
"""User information from OAuth provider""" """User information from OAuth provider"""
@ -26,6 +31,19 @@ class User:
if self.last_login is None: if self.last_login is None:
self.last_login = datetime.now() self.last_login = datetime.now()
class AnonymousUser(User):
"""Anonymous user"""
def __init__(self):
super().__init__(
user_id="anonymous",
email="anonymous@localhost",
name="Anonymous User",
picture=None,
provider="none",
)
class SessionManager: class SessionManager:
"""Manages user sessions and JWT tokens""" """Manages user sessions and JWT tokens"""
@ -80,13 +98,15 @@ class SessionManager:
if response.status_code == 200: if response.status_code == 200:
return response.json() return response.json()
else: else:
print( logger.error(
f"Failed to get user info: {response.status_code} {response.text}" "Failed to get user info",
status_code=response.status_code,
response_text=response.text,
) )
return None return None
except Exception as e: except Exception as e:
print(f"Error getting user info: {e}") logger.error("Error getting user info", error=str(e))
return None return None
async def create_user_session( async def create_user_session(
@ -173,19 +193,24 @@ class SessionManager:
"""Get or create OpenSearch client for user with their JWT""" """Get or create OpenSearch client for user with their JWT"""
from config.settings import is_no_auth_mode from config.settings import is_no_auth_mode
print( logger.debug(
f"[DEBUG] get_user_opensearch_client: user_id={user_id}, jwt_token={'None' if jwt_token is None else 'present'}, no_auth_mode={is_no_auth_mode()}" "get_user_opensearch_client",
user_id=user_id,
jwt_token_present=(jwt_token is not None),
no_auth_mode=is_no_auth_mode(),
) )
# In no-auth mode, create anonymous JWT for OpenSearch DLS # In no-auth mode, create anonymous JWT for OpenSearch DLS
if is_no_auth_mode() and jwt_token is None: if jwt_token is None and (is_no_auth_mode() or user_id in (None, AnonymousUser().user_id)):
if not hasattr(self, "_anonymous_jwt"): if not hasattr(self, "_anonymous_jwt"):
# Create anonymous JWT token for OpenSearch OIDC # Create anonymous JWT token for OpenSearch OIDC
print(f"[DEBUG] Creating anonymous JWT...") logger.debug("Creating anonymous JWT")
self._anonymous_jwt = self._create_anonymous_jwt() self._anonymous_jwt = self._create_anonymous_jwt()
print(f"[DEBUG] Anonymous JWT created: {self._anonymous_jwt[:50]}...") logger.debug(
"Anonymous JWT created", jwt_prefix=self._anonymous_jwt[:50]
)
jwt_token = self._anonymous_jwt jwt_token = self._anonymous_jwt
print(f"[DEBUG] Using anonymous JWT for OpenSearch") logger.debug("Using anonymous JWT for OpenSearch")
# Check if we have a cached client for this user # Check if we have a cached client for this user
if user_id not in self.user_opensearch_clients: if user_id not in self.user_opensearch_clients:
@ -199,14 +224,5 @@ class SessionManager:
def _create_anonymous_jwt(self) -> str: def _create_anonymous_jwt(self) -> str:
"""Create JWT token for anonymous user in no-auth mode""" """Create JWT token for anonymous user in no-auth mode"""
anonymous_user = User( anonymous_user = AnonymousUser()
user_id="anonymous",
email="anonymous@localhost",
name="Anonymous User",
picture=None,
provider="none",
created_at=datetime.now(),
last_login=datetime.now(),
)
return self.create_jwt_token(anonymous_user) return self.create_jwt_token(anonymous_user)

View file

@ -3,6 +3,9 @@
import sys import sys
from pathlib import Path from pathlib import Path
from textual.app import App, ComposeResult from textual.app import App, ComposeResult
from utils.logging_config import get_logger
logger = get_logger(__name__)
from .screens.welcome import WelcomeScreen from .screens.welcome import WelcomeScreen
from .screens.config import ConfigScreen from .screens.config import ConfigScreen
@ -187,7 +190,7 @@ class OpenRAGTUI(App):
self, self,
"No container runtime found. Please install Docker or Podman.", "No container runtime found. Please install Docker or Podman.",
severity="warning", severity="warning",
timeout=10 timeout=10,
) )
# Load existing config if available # Load existing config if available
@ -208,7 +211,9 @@ class OpenRAGTUI(App):
# Check Podman macOS memory if applicable # Check Podman macOS memory if applicable
runtime_info = self.container_manager.get_runtime_info() runtime_info = self.container_manager.get_runtime_info()
if runtime_info.runtime_type.value == "podman": if runtime_info.runtime_type.value == "podman":
is_sufficient, _, message = self.platform_detector.check_podman_macos_memory() is_sufficient, _, message = (
self.platform_detector.check_podman_macos_memory()
)
if not is_sufficient: if not is_sufficient:
return False, f"Podman VM memory insufficient:\n{message}" return False, f"Podman VM memory insufficient:\n{message}"
@ -221,10 +226,10 @@ def run_tui():
app = OpenRAGTUI() app = OpenRAGTUI()
app.run() app.run()
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nOpenRAG TUI interrupted by user") logger.info("OpenRAG TUI interrupted by user")
sys.exit(0) sys.exit(0)
except Exception as e: except Exception as e:
print(f"Error running OpenRAG TUI: {e}") logger.error("Error running OpenRAG TUI", error=str(e))
sys.exit(1) sys.exit(1)

View file

@ -8,6 +8,9 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, AsyncIterator from typing import Dict, List, Optional, AsyncIterator
from utils.logging_config import get_logger
logger = get_logger(__name__)
from ..utils.platform import PlatformDetector, RuntimeInfo, RuntimeType from ..utils.platform import PlatformDetector, RuntimeInfo, RuntimeType
from utils.gpu_detection import detect_gpu_devices from utils.gpu_detection import detect_gpu_devices
@ -15,6 +18,7 @@ from utils.gpu_detection import detect_gpu_devices
class ServiceStatus(Enum): class ServiceStatus(Enum):
"""Container service status.""" """Container service status."""
UNKNOWN = "unknown" UNKNOWN = "unknown"
RUNNING = "running" RUNNING = "running"
STOPPED = "stopped" STOPPED = "stopped"
@ -27,6 +31,7 @@ class ServiceStatus(Enum):
@dataclass @dataclass
class ServiceInfo: class ServiceInfo:
"""Container service information.""" """Container service information."""
name: str name: str
status: ServiceStatus status: ServiceStatus
health: Optional[str] = None health: Optional[str] = None
@ -63,7 +68,7 @@ class ContainerManager:
"openrag-frontend", "openrag-frontend",
"opensearch", "opensearch",
"dashboards", "dashboards",
"langflow" "langflow",
] ]
# Map container names to service names # Map container names to service names
@ -72,7 +77,7 @@ class ContainerManager:
"openrag-frontend": "openrag-frontend", "openrag-frontend": "openrag-frontend",
"os": "opensearch", "os": "opensearch",
"osdash": "dashboards", "osdash": "dashboards",
"langflow": "langflow" "langflow": "langflow",
} }
def is_available(self) -> bool: def is_available(self) -> bool:
@ -87,7 +92,9 @@ class ContainerManager:
"""Get installation instructions if runtime is not available.""" """Get installation instructions if runtime is not available."""
return self.platform_detector.get_installation_instructions() return self.platform_detector.get_installation_instructions()
async def _run_compose_command(self, args: List[str], cpu_mode: Optional[bool] = None) -> tuple[bool, str, str]: async def _run_compose_command(
self, args: List[str], cpu_mode: Optional[bool] = None
) -> tuple[bool, str, str]:
"""Run a compose command and return (success, stdout, stderr).""" """Run a compose command and return (success, stdout, stderr)."""
if not self.is_available(): if not self.is_available():
return False, "", "No container runtime available" return False, "", "No container runtime available"
@ -102,7 +109,7 @@ class ContainerManager:
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
cwd=Path.cwd() cwd=Path.cwd(),
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
@ -115,7 +122,9 @@ class ContainerManager:
except Exception as e: except Exception as e:
return False, "", f"Command execution failed: {e}" return False, "", f"Command execution failed: {e}"
async def _run_compose_command_streaming(self, args: List[str], cpu_mode: Optional[bool] = None) -> AsyncIterator[str]: async def _run_compose_command_streaming(
self, args: List[str], cpu_mode: Optional[bool] = None
) -> AsyncIterator[str]:
"""Run a compose command and yield output lines in real-time.""" """Run a compose command and yield output lines in real-time."""
if not self.is_available(): if not self.is_available():
yield "No container runtime available" yield "No container runtime available"
@ -131,7 +140,7 @@ class ContainerManager:
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT, # Combine stderr with stdout for unified output stderr=asyncio.subprocess.STDOUT, # Combine stderr with stdout for unified output
cwd=Path.cwd() cwd=Path.cwd(),
) )
# Simple approach: read line by line and yield each one # Simple approach: read line by line and yield each one
@ -159,9 +168,7 @@ class ContainerManager:
try: try:
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
@ -174,10 +181,12 @@ class ContainerManager:
except Exception as e: except Exception as e:
return False, "", f"Command execution failed: {e}" return False, "", f"Command execution failed: {e}"
def _process_service_json(self, service: Dict, services: Dict[str, ServiceInfo]) -> None: def _process_service_json(
self, service: Dict, services: Dict[str, ServiceInfo]
) -> None:
"""Process a service JSON object and add it to the services dict.""" """Process a service JSON object and add it to the services dict."""
# Debug print to see the actual service data # Debug print to see the actual service data
print(f"DEBUG: Processing service data: {json.dumps(service, indent=2)}") logger.debug("Processing service data", service_data=service)
container_name = service.get("Name", "") container_name = service.get("Name", "")
@ -203,7 +212,9 @@ class ContainerManager:
# Extract ports # Extract ports
ports_str = service.get("Ports", "") ports_str = service.get("Ports", "")
ports = [p.strip() for p in ports_str.split(",") if p.strip()] if ports_str else [] ports = (
[p.strip() for p in ports_str.split(",") if p.strip()] if ports_str else []
)
# Extract image # Extract image
image = service.get("Image", "N/A") image = service.get("Image", "N/A")
@ -216,7 +227,9 @@ class ContainerManager:
image=image, image=image,
) )
async def get_service_status(self, force_refresh: bool = False) -> Dict[str, ServiceInfo]: async def get_service_status(
self, force_refresh: bool = False
) -> Dict[str, ServiceInfo]:
"""Get current status of all services.""" """Get current status of all services."""
current_time = time.time() current_time = time.time()
@ -280,26 +293,30 @@ class ContainerManager:
pass pass
else: else:
# For Docker, use compose ps command # For Docker, use compose ps command
success, stdout, stderr = await self._run_compose_command(["ps", "--format", "json"]) success, stdout, stderr = await self._run_compose_command(
["ps", "--format", "json"]
)
if success and stdout.strip(): if success and stdout.strip():
try: try:
# Handle both single JSON object (Podman) and multiple JSON objects (Docker) # Handle both single JSON object (Podman) and multiple JSON objects (Docker)
if stdout.strip().startswith('[') and stdout.strip().endswith(']'): if stdout.strip().startswith("[") and stdout.strip().endswith("]"):
# JSON array format # JSON array format
service_list = json.loads(stdout.strip()) service_list = json.loads(stdout.strip())
for service in service_list: for service in service_list:
self._process_service_json(service, services) self._process_service_json(service, services)
else: else:
# Line-by-line JSON format # Line-by-line JSON format
for line in stdout.strip().split('\n'): for line in stdout.strip().split("\n"):
if line.strip() and line.startswith('{'): if line.strip() and line.startswith("{"):
service = json.loads(line) service = json.loads(line)
self._process_service_json(service, services) self._process_service_json(service, services)
except json.JSONDecodeError: except json.JSONDecodeError:
# Fallback to parsing text output # Fallback to parsing text output
lines = stdout.strip().split('\n') lines = stdout.strip().split("\n")
if len(lines) > 1: # Make sure we have at least a header and one line if (
len(lines) > 1
): # Make sure we have at least a header and one line
for line in lines[1:]: # Skip header for line in lines[1:]: # Skip header
if line.strip(): if line.strip():
parts = line.split() parts = line.split()
@ -319,12 +336,16 @@ class ContainerManager:
else: else:
status = ServiceStatus.UNKNOWN status = ServiceStatus.UNKNOWN
services[name] = ServiceInfo(name=name, status=status) services[name] = ServiceInfo(
name=name, status=status
)
# Add expected services that weren't found # Add expected services that weren't found
for expected in self.expected_services: for expected in self.expected_services:
if expected not in services: if expected not in services:
services[expected] = ServiceInfo(name=expected, status=ServiceStatus.MISSING) services[expected] = ServiceInfo(
name=expected, status=ServiceStatus.MISSING
)
self.services_cache = services self.services_cache = services
self.last_status_update = current_time self.last_status_update = current_time
@ -337,9 +358,9 @@ class ContainerManager:
for image in images: for image in images:
if not image or image in digests: if not image or image in digests:
continue continue
success, stdout, _ = await self._run_runtime_command([ success, stdout, _ = await self._run_runtime_command(
"image", "inspect", image, "--format", "{{.Id}}" ["image", "inspect", image, "--format", "{{.Id}}"]
]) )
if success and stdout.strip(): if success and stdout.strip():
digests[image] = stdout.strip().splitlines()[0] digests[image] = stdout.strip().splitlines()[0]
return digests return digests
@ -353,13 +374,15 @@ class ContainerManager:
continue continue
for line in compose.read_text().splitlines(): for line in compose.read_text().splitlines():
line = line.strip() line = line.strip()
if not line or line.startswith('#'): if not line or line.startswith("#"):
continue continue
if line.startswith('image:'): if line.startswith("image:"):
# image: repo/name:tag # image: repo/name:tag
val = line.split(':', 1)[1].strip() val = line.split(":", 1)[1].strip()
# Remove quotes if present # Remove quotes if present
if (val.startswith('"') and val.endswith('"')) or (val.startswith("'") and val.endswith("'")): if (val.startswith('"') and val.endswith('"')) or (
val.startswith("'") and val.endswith("'")
):
val = val[1:-1] val = val[1:-1]
images.add(val) images.add(val)
except Exception: except Exception:
@ -374,21 +397,25 @@ class ContainerManager:
expected = self._parse_compose_images() expected = self._parse_compose_images()
results: list[tuple[str, str]] = [] results: list[tuple[str, str]] = []
for image in expected: for image in expected:
digest = '-' digest = "-"
success, stdout, _ = await self._run_runtime_command([ success, stdout, _ = await self._run_runtime_command(
'image', 'inspect', image, '--format', '{{.Id}}' ["image", "inspect", image, "--format", "{{.Id}}"]
]) )
if success and stdout.strip(): if success and stdout.strip():
digest = stdout.strip().splitlines()[0] digest = stdout.strip().splitlines()[0]
results.append((image, digest)) results.append((image, digest))
results.sort(key=lambda x: x[0]) results.sort(key=lambda x: x[0])
return results return results
async def start_services(self, cpu_mode: bool = False) -> AsyncIterator[tuple[bool, str]]: async def start_services(
self, cpu_mode: bool = False
) -> AsyncIterator[tuple[bool, str]]:
"""Start all services and yield progress updates.""" """Start all services and yield progress updates."""
yield False, "Starting OpenRAG services..." yield False, "Starting OpenRAG services..."
success, stdout, stderr = await self._run_compose_command(["up", "-d"], cpu_mode) success, stdout, stderr = await self._run_compose_command(
["up", "-d"], cpu_mode
)
if success: if success:
yield True, "Services started successfully" yield True, "Services started successfully"
@ -406,7 +433,9 @@ class ContainerManager:
else: else:
yield False, f"Failed to stop services: {stderr}" yield False, f"Failed to stop services: {stderr}"
async def restart_services(self, cpu_mode: bool = False) -> AsyncIterator[tuple[bool, str]]: async def restart_services(
self, cpu_mode: bool = False
) -> AsyncIterator[tuple[bool, str]]:
"""Restart all services and yield progress updates.""" """Restart all services and yield progress updates."""
yield False, "Restarting OpenRAG services..." yield False, "Restarting OpenRAG services..."
@ -417,7 +446,9 @@ class ContainerManager:
else: else:
yield False, f"Failed to restart services: {stderr}" yield False, f"Failed to restart services: {stderr}"
async def upgrade_services(self, cpu_mode: bool = False) -> AsyncIterator[tuple[bool, str]]: async def upgrade_services(
self, cpu_mode: bool = False
) -> AsyncIterator[tuple[bool, str]]:
"""Upgrade services (pull latest images and restart) and yield progress updates.""" """Upgrade services (pull latest images and restart) and yield progress updates."""
yield False, "Pulling latest images..." yield False, "Pulling latest images..."
@ -436,7 +467,9 @@ class ContainerManager:
# Restart with new images using streaming output # Restart with new images using streaming output
restart_success = True restart_success = True
async for line in self._run_compose_command_streaming(["up", "-d", "--force-recreate"], cpu_mode): async for line in self._run_compose_command_streaming(
["up", "-d", "--force-recreate"], cpu_mode
):
yield False, line yield False, line
# Check for error patterns in the output # Check for error patterns in the output
if "error" in line.lower() or "failed" in line.lower(): if "error" in line.lower() or "failed" in line.lower():
@ -452,12 +485,9 @@ class ContainerManager:
yield False, "Stopping all services..." yield False, "Stopping all services..."
# Stop and remove everything # Stop and remove everything
success, stdout, stderr = await self._run_compose_command([ success, stdout, stderr = await self._run_compose_command(
"down", ["down", "--volumes", "--remove-orphans", "--rmi", "local"]
"--volumes", )
"--remove-orphans",
"--rmi", "local"
])
if not success: if not success:
yield False, f"Failed to stop services: {stderr}" yield False, f"Failed to stop services: {stderr}"
@ -469,11 +499,18 @@ class ContainerManager:
# This is more thorough than just compose down # This is more thorough than just compose down
await self._run_runtime_command(["system", "prune", "-f"]) await self._run_runtime_command(["system", "prune", "-f"])
yield True, "System reset completed - all containers, volumes, and local images removed" yield (
True,
"System reset completed - all containers, volumes, and local images removed",
)
async def get_service_logs(self, service_name: str, lines: int = 100) -> tuple[bool, str]: async def get_service_logs(
self, service_name: str, lines: int = 100
) -> tuple[bool, str]:
"""Get logs for a specific service.""" """Get logs for a specific service."""
success, stdout, stderr = await self._run_compose_command(["logs", "--tail", str(lines), service_name]) success, stdout, stderr = await self._run_compose_command(
["logs", "--tail", str(lines), service_name]
)
if success: if success:
return True, stdout return True, stdout
@ -486,15 +523,23 @@ class ContainerManager:
yield "No container runtime available" yield "No container runtime available"
return return
compose_file = self.cpu_compose_file if self.use_cpu_compose else self.compose_file compose_file = (
cmd = self.runtime_info.compose_command + ["-f", str(compose_file), "logs", "-f", service_name] self.cpu_compose_file if self.use_cpu_compose else self.compose_file
)
cmd = self.runtime_info.compose_command + [
"-f",
str(compose_file),
"logs",
"-f",
service_name,
]
try: try:
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT, stderr=asyncio.subprocess.STDOUT,
cwd=Path.cwd() cwd=Path.cwd(),
) )
if process.stdout: if process.stdout:
@ -515,11 +560,13 @@ class ContainerManager:
stats = {} stats = {}
# Get container stats # Get container stats
success, stdout, stderr = await self._run_runtime_command(["stats", "--no-stream", "--format", "json"]) success, stdout, stderr = await self._run_runtime_command(
["stats", "--no-stream", "--format", "json"]
)
if success and stdout.strip(): if success and stdout.strip():
try: try:
for line in stdout.strip().split('\n'): for line in stdout.strip().split("\n"):
if line.strip(): if line.strip():
data = json.loads(line) data = json.loads(line)
name = data.get("Name", data.get("Container", "")) name = data.get("Name", data.get("Container", ""))
@ -548,7 +595,7 @@ class ContainerManager:
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
cwd=Path.cwd() cwd=Path.cwd(),
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
@ -582,5 +629,7 @@ class ContainerManager:
if self.runtime_info.runtime_type != RuntimeType.PODMAN: if self.runtime_info.runtime_type != RuntimeType.PODMAN:
return True, "Not using Podman" return True, "Not using Podman"
is_sufficient, memory_mb, message = self.platform_detector.check_podman_macos_memory() is_sufficient, memory_mb, message = (
self.platform_detector.check_podman_macos_memory()
)
return is_sufficient, message return is_sufficient, message

View file

@ -7,6 +7,9 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, List from typing import Dict, Optional, List
from dataclasses import dataclass, field from dataclasses import dataclass, field
from utils.logging_config import get_logger
logger = get_logger(__name__)
from ..utils.validation import ( from ..utils.validation import (
validate_openai_api_key, validate_openai_api_key,
@ -14,13 +17,14 @@ from ..utils.validation import (
validate_non_empty, validate_non_empty,
validate_url, validate_url,
validate_documents_paths, validate_documents_paths,
sanitize_env_value sanitize_env_value,
) )
@dataclass @dataclass
class EnvConfig: class EnvConfig:
"""Environment configuration data.""" """Environment configuration data."""
# Core settings # Core settings
openai_api_key: str = "" openai_api_key: str = ""
opensearch_password: str = "" opensearch_password: str = ""
@ -64,7 +68,7 @@ class EnvManager:
"""Generate a secure password for OpenSearch.""" """Generate a secure password for OpenSearch."""
# Generate a 16-character password with letters, digits, and symbols # Generate a 16-character password with letters, digits, and symbols
alphabet = string.ascii_letters + string.digits + "!@#$%^&*" alphabet = string.ascii_letters + string.digits + "!@#$%^&*"
return ''.join(secrets.choice(alphabet) for _ in range(16)) return "".join(secrets.choice(alphabet) for _ in range(16))
def generate_langflow_secret_key(self) -> str: def generate_langflow_secret_key(self) -> str:
"""Generate a secure secret key for Langflow.""" """Generate a secure secret key for Langflow."""
@ -76,37 +80,37 @@ class EnvManager:
return False return False
try: try:
with open(self.env_file, 'r') as f: with open(self.env_file, "r") as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if not line or line.startswith('#'): if not line or line.startswith("#"):
continue continue
if '=' in line: if "=" in line:
key, value = line.split('=', 1) key, value = line.split("=", 1)
key = key.strip() key = key.strip()
value = sanitize_env_value(value) value = sanitize_env_value(value)
# Map env vars to config attributes # Map env vars to config attributes
attr_map = { attr_map = {
'OPENAI_API_KEY': 'openai_api_key', "OPENAI_API_KEY": "openai_api_key",
'OPENSEARCH_PASSWORD': 'opensearch_password', "OPENSEARCH_PASSWORD": "opensearch_password",
'LANGFLOW_SECRET_KEY': 'langflow_secret_key', "LANGFLOW_SECRET_KEY": "langflow_secret_key",
'LANGFLOW_SUPERUSER': 'langflow_superuser', "LANGFLOW_SUPERUSER": "langflow_superuser",
'LANGFLOW_SUPERUSER_PASSWORD': 'langflow_superuser_password', "LANGFLOW_SUPERUSER_PASSWORD": "langflow_superuser_password",
'FLOW_ID': 'flow_id', "FLOW_ID": "flow_id",
'GOOGLE_OAUTH_CLIENT_ID': 'google_oauth_client_id', "GOOGLE_OAUTH_CLIENT_ID": "google_oauth_client_id",
'GOOGLE_OAUTH_CLIENT_SECRET': 'google_oauth_client_secret', "GOOGLE_OAUTH_CLIENT_SECRET": "google_oauth_client_secret",
'MICROSOFT_GRAPH_OAUTH_CLIENT_ID': 'microsoft_graph_oauth_client_id', "MICROSOFT_GRAPH_OAUTH_CLIENT_ID": "microsoft_graph_oauth_client_id",
'MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET': 'microsoft_graph_oauth_client_secret', "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET": "microsoft_graph_oauth_client_secret",
'WEBHOOK_BASE_URL': 'webhook_base_url', "WEBHOOK_BASE_URL": "webhook_base_url",
'AWS_ACCESS_KEY_ID': 'aws_access_key_id', "AWS_ACCESS_KEY_ID": "aws_access_key_id",
'AWS_SECRET_ACCESS_KEY': 'aws_secret_access_key', "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
'LANGFLOW_PUBLIC_URL': 'langflow_public_url', "LANGFLOW_PUBLIC_URL": "langflow_public_url",
'OPENRAG_DOCUMENTS_PATHS': 'openrag_documents_paths', "OPENRAG_DOCUMENTS_PATHS": "openrag_documents_paths",
'LANGFLOW_AUTO_LOGIN': 'langflow_auto_login', "LANGFLOW_AUTO_LOGIN": "langflow_auto_login",
'LANGFLOW_NEW_USER_IS_ACTIVE': 'langflow_new_user_is_active', "LANGFLOW_NEW_USER_IS_ACTIVE": "langflow_new_user_is_active",
'LANGFLOW_ENABLE_SUPERUSER_CLI': 'langflow_enable_superuser_cli', "LANGFLOW_ENABLE_SUPERUSER_CLI": "langflow_enable_superuser_cli",
} }
if key in attr_map: if key in attr_map:
@ -115,7 +119,7 @@ class EnvManager:
return True return True
except Exception as e: except Exception as e:
print(f"Error loading .env file: {e}") logger.error("Error loading .env file", error=str(e))
return False return False
def setup_secure_defaults(self) -> None: def setup_secure_defaults(self) -> None:
@ -140,40 +144,71 @@ class EnvManager:
# Always validate OpenAI API key # Always validate OpenAI API key
if not validate_openai_api_key(self.config.openai_api_key): if not validate_openai_api_key(self.config.openai_api_key):
self.config.validation_errors['openai_api_key'] = "Invalid OpenAI API key format (should start with sk-)" self.config.validation_errors["openai_api_key"] = (
"Invalid OpenAI API key format (should start with sk-)"
)
# Validate documents paths only if provided (optional) # Validate documents paths only if provided (optional)
if self.config.openrag_documents_paths: if self.config.openrag_documents_paths:
is_valid, error_msg, _ = validate_documents_paths(self.config.openrag_documents_paths) is_valid, error_msg, _ = validate_documents_paths(
self.config.openrag_documents_paths
)
if not is_valid: if not is_valid:
self.config.validation_errors['openrag_documents_paths'] = error_msg self.config.validation_errors["openrag_documents_paths"] = error_msg
# Validate required fields # Validate required fields
if not validate_non_empty(self.config.opensearch_password): if not validate_non_empty(self.config.opensearch_password):
self.config.validation_errors['opensearch_password'] = "OpenSearch password is required" self.config.validation_errors["opensearch_password"] = (
"OpenSearch password is required"
)
# Langflow secret key is auto-generated; no user input required # Langflow secret key is auto-generated; no user input required
if not validate_non_empty(self.config.langflow_superuser_password): if not validate_non_empty(self.config.langflow_superuser_password):
self.config.validation_errors['langflow_superuser_password'] = "Langflow superuser password is required" self.config.validation_errors["langflow_superuser_password"] = (
"Langflow superuser password is required"
)
if mode == "full": if mode == "full":
# Validate OAuth settings if provided # Validate OAuth settings if provided
if self.config.google_oauth_client_id and not validate_google_oauth_client_id(self.config.google_oauth_client_id): if (
self.config.validation_errors['google_oauth_client_id'] = "Invalid Google OAuth client ID format" self.config.google_oauth_client_id
and not validate_google_oauth_client_id(
self.config.google_oauth_client_id
)
):
self.config.validation_errors["google_oauth_client_id"] = (
"Invalid Google OAuth client ID format"
)
if self.config.google_oauth_client_id and not validate_non_empty(self.config.google_oauth_client_secret): if self.config.google_oauth_client_id and not validate_non_empty(
self.config.validation_errors['google_oauth_client_secret'] = "Google OAuth client secret required when client ID is provided" self.config.google_oauth_client_secret
):
self.config.validation_errors["google_oauth_client_secret"] = (
"Google OAuth client secret required when client ID is provided"
)
if self.config.microsoft_graph_oauth_client_id and not validate_non_empty(self.config.microsoft_graph_oauth_client_secret): if self.config.microsoft_graph_oauth_client_id and not validate_non_empty(
self.config.validation_errors['microsoft_graph_oauth_client_secret'] = "Microsoft Graph client secret required when client ID is provided" self.config.microsoft_graph_oauth_client_secret
):
self.config.validation_errors["microsoft_graph_oauth_client_secret"] = (
"Microsoft Graph client secret required when client ID is provided"
)
# Validate optional URLs if provided # Validate optional URLs if provided
if self.config.webhook_base_url and not validate_url(self.config.webhook_base_url): if self.config.webhook_base_url and not validate_url(
self.config.validation_errors['webhook_base_url'] = "Invalid webhook URL format" self.config.webhook_base_url
):
self.config.validation_errors["webhook_base_url"] = (
"Invalid webhook URL format"
)
if self.config.langflow_public_url and not validate_url(self.config.langflow_public_url): if self.config.langflow_public_url and not validate_url(
self.config.validation_errors['langflow_public_url'] = "Invalid Langflow public URL format" self.config.langflow_public_url
):
self.config.validation_errors["langflow_public_url"] = (
"Invalid Langflow public URL format"
)
return len(self.config.validation_errors) == 0 return len(self.config.validation_errors) == 0
@ -184,11 +219,11 @@ class EnvManager:
self.setup_secure_defaults() self.setup_secure_defaults()
# Create timestamped backup if file exists # Create timestamped backup if file exists
if self.env_file.exists(): if self.env_file.exists():
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_file = self.env_file.with_suffix(f'.env.backup.{timestamp}') backup_file = self.env_file.with_suffix(f".env.backup.{timestamp}")
self.env_file.rename(backup_file) self.env_file.rename(backup_file)
with open(self.env_file, 'w') as f: with open(self.env_file, "w") as f:
f.write("# OpenRAG Environment Configuration\n") f.write("# OpenRAG Environment Configuration\n")
f.write("# Generated by OpenRAG TUI\n\n") f.write("# Generated by OpenRAG TUI\n\n")
@ -196,31 +231,53 @@ class EnvManager:
f.write("# Core settings\n") f.write("# Core settings\n")
f.write(f"LANGFLOW_SECRET_KEY={self.config.langflow_secret_key}\n") f.write(f"LANGFLOW_SECRET_KEY={self.config.langflow_secret_key}\n")
f.write(f"LANGFLOW_SUPERUSER={self.config.langflow_superuser}\n") f.write(f"LANGFLOW_SUPERUSER={self.config.langflow_superuser}\n")
f.write(f"LANGFLOW_SUPERUSER_PASSWORD={self.config.langflow_superuser_password}\n") f.write(
f"LANGFLOW_SUPERUSER_PASSWORD={self.config.langflow_superuser_password}\n"
)
f.write(f"FLOW_ID={self.config.flow_id}\n") f.write(f"FLOW_ID={self.config.flow_id}\n")
f.write(f"OPENSEARCH_PASSWORD={self.config.opensearch_password}\n") f.write(f"OPENSEARCH_PASSWORD={self.config.opensearch_password}\n")
f.write(f"OPENAI_API_KEY={self.config.openai_api_key}\n") f.write(f"OPENAI_API_KEY={self.config.openai_api_key}\n")
f.write(f"OPENRAG_DOCUMENTS_PATHS={self.config.openrag_documents_paths}\n") f.write(
f"OPENRAG_DOCUMENTS_PATHS={self.config.openrag_documents_paths}\n"
)
f.write("\n") f.write("\n")
# Langflow auth settings # Langflow auth settings
f.write("# Langflow auth settings\n") f.write("# Langflow auth settings\n")
f.write(f"LANGFLOW_AUTO_LOGIN={self.config.langflow_auto_login}\n") f.write(f"LANGFLOW_AUTO_LOGIN={self.config.langflow_auto_login}\n")
f.write(f"LANGFLOW_NEW_USER_IS_ACTIVE={self.config.langflow_new_user_is_active}\n") f.write(
f.write(f"LANGFLOW_ENABLE_SUPERUSER_CLI={self.config.langflow_enable_superuser_cli}\n") f"LANGFLOW_NEW_USER_IS_ACTIVE={self.config.langflow_new_user_is_active}\n"
)
f.write(
f"LANGFLOW_ENABLE_SUPERUSER_CLI={self.config.langflow_enable_superuser_cli}\n"
)
f.write("\n") f.write("\n")
# OAuth settings # OAuth settings
if self.config.google_oauth_client_id or self.config.google_oauth_client_secret: if (
self.config.google_oauth_client_id
or self.config.google_oauth_client_secret
):
f.write("# Google OAuth settings\n") f.write("# Google OAuth settings\n")
f.write(f"GOOGLE_OAUTH_CLIENT_ID={self.config.google_oauth_client_id}\n") f.write(
f.write(f"GOOGLE_OAUTH_CLIENT_SECRET={self.config.google_oauth_client_secret}\n") f"GOOGLE_OAUTH_CLIENT_ID={self.config.google_oauth_client_id}\n"
)
f.write(
f"GOOGLE_OAUTH_CLIENT_SECRET={self.config.google_oauth_client_secret}\n"
)
f.write("\n") f.write("\n")
if self.config.microsoft_graph_oauth_client_id or self.config.microsoft_graph_oauth_client_secret: if (
self.config.microsoft_graph_oauth_client_id
or self.config.microsoft_graph_oauth_client_secret
):
f.write("# Microsoft Graph OAuth settings\n") f.write("# Microsoft Graph OAuth settings\n")
f.write(f"MICROSOFT_GRAPH_OAUTH_CLIENT_ID={self.config.microsoft_graph_oauth_client_id}\n") f.write(
f.write(f"MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET={self.config.microsoft_graph_oauth_client_secret}\n") f"MICROSOFT_GRAPH_OAUTH_CLIENT_ID={self.config.microsoft_graph_oauth_client_id}\n"
)
f.write(
f"MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET={self.config.microsoft_graph_oauth_client_secret}\n"
)
f.write("\n") f.write("\n")
# Optional settings # Optional settings
@ -245,16 +302,31 @@ class EnvManager:
return True return True
except Exception as e: except Exception as e:
print(f"Error saving .env file: {e}") logger.error("Error saving .env file", error=str(e))
return False return False
def get_no_auth_setup_fields(self) -> List[tuple[str, str, str, bool]]: def get_no_auth_setup_fields(self) -> List[tuple[str, str, str, bool]]:
"""Get fields required for no-auth setup mode. Returns (field_name, display_name, placeholder, can_generate).""" """Get fields required for no-auth setup mode. Returns (field_name, display_name, placeholder, can_generate)."""
return [ return [
("openai_api_key", "OpenAI API Key", "sk-...", False), ("openai_api_key", "OpenAI API Key", "sk-...", False),
("opensearch_password", "OpenSearch Password", "Will be auto-generated if empty", True), (
("langflow_superuser_password", "Langflow Superuser Password", "Will be auto-generated if empty", True), "opensearch_password",
("openrag_documents_paths", "Documents Paths", "./documents,/path/to/more/docs", False), "OpenSearch Password",
"Will be auto-generated if empty",
True,
),
(
"langflow_superuser_password",
"Langflow Superuser Password",
"Will be auto-generated if empty",
True,
),
(
"openrag_documents_paths",
"Documents Paths",
"./documents,/path/to/more/docs",
False,
),
] ]
def get_full_setup_fields(self) -> List[tuple[str, str, str, bool]]: def get_full_setup_fields(self) -> List[tuple[str, str, str, bool]]:
@ -262,24 +334,46 @@ class EnvManager:
base_fields = self.get_no_auth_setup_fields() base_fields = self.get_no_auth_setup_fields()
oauth_fields = [ oauth_fields = [
("google_oauth_client_id", "Google OAuth Client ID", "xxx.apps.googleusercontent.com", False), (
"google_oauth_client_id",
"Google OAuth Client ID",
"xxx.apps.googleusercontent.com",
False,
),
("google_oauth_client_secret", "Google OAuth Client Secret", "", False), ("google_oauth_client_secret", "Google OAuth Client Secret", "", False),
("microsoft_graph_oauth_client_id", "Microsoft Graph Client ID", "", False), ("microsoft_graph_oauth_client_id", "Microsoft Graph Client ID", "", False),
("microsoft_graph_oauth_client_secret", "Microsoft Graph Client Secret", "", False), (
"microsoft_graph_oauth_client_secret",
"Microsoft Graph Client Secret",
"",
False,
),
] ]
optional_fields = [ optional_fields = [
("webhook_base_url", "Webhook Base URL (optional)", "https://your-domain.com", False), (
"webhook_base_url",
"Webhook Base URL (optional)",
"https://your-domain.com",
False,
),
("aws_access_key_id", "AWS Access Key ID (optional)", "", False), ("aws_access_key_id", "AWS Access Key ID (optional)", "", False),
("aws_secret_access_key", "AWS Secret Access Key (optional)", "", False), ("aws_secret_access_key", "AWS Secret Access Key (optional)", "", False),
("langflow_public_url", "Langflow Public URL (optional)", "http://localhost:7860", False), (
"langflow_public_url",
"Langflow Public URL (optional)",
"http://localhost:7860",
False,
),
] ]
return base_fields + oauth_fields + optional_fields return base_fields + oauth_fields + optional_fields
def generate_compose_volume_mounts(self) -> List[str]: def generate_compose_volume_mounts(self) -> List[str]:
"""Generate Docker Compose volume mount strings from documents paths.""" """Generate Docker Compose volume mount strings from documents paths."""
is_valid, _, validated_paths = validate_documents_paths(self.config.openrag_documents_paths) is_valid, _, validated_paths = validate_documents_paths(
self.config.openrag_documents_paths
)
if not is_valid: if not is_valid:
return ["./documents:/app/documents:Z"] # fallback return ["./documents:/app/documents:Z"] # fallback
@ -291,6 +385,6 @@ class EnvManager:
volume_mounts.append(f"{path}:/app/documents:Z") volume_mounts.append(f"{path}:/app/documents:Z")
else: else:
# Additional paths map to numbered directories # Additional paths map to numbered directories
volume_mounts.append(f"{path}:/app/documents{i+1}:Z") volume_mounts.append(f"{path}:/app/documents{i + 1}:Z")
return volume_mounts return volume_mounts

View file

@ -3,7 +3,16 @@
from textual.app import ComposeResult from textual.app import ComposeResult
from textual.containers import Container, Vertical, Horizontal, ScrollableContainer from textual.containers import Container, Vertical, Horizontal, ScrollableContainer
from textual.screen import Screen from textual.screen import Screen
from textual.widgets import Header, Footer, Static, Button, Input, Label, TabbedContent, TabPane from textual.widgets import (
Header,
Footer,
Static,
Button,
Input,
Label,
TabbedContent,
TabPane,
)
from textual.validation import ValidationResult, Validator from textual.validation import ValidationResult, Validator
from rich.text import Text from rich.text import Text
from pathlib import Path from pathlib import Path
@ -70,7 +79,7 @@ class ConfigScreen(Screen):
Button("Generate Passwords", variant="default", id="generate-btn"), Button("Generate Passwords", variant="default", id="generate-btn"),
Button("Save Configuration", variant="success", id="save-btn"), Button("Save Configuration", variant="success", id="save-btn"),
Button("Back", variant="default", id="back-btn"), Button("Back", variant="default", id="back-btn"),
classes="button-row" classes="button-row",
) )
yield Footer() yield Footer()
@ -80,10 +89,14 @@ class ConfigScreen(Screen):
if self.mode == "no_auth": if self.mode == "no_auth":
header_text.append("Quick Setup - No Authentication\n", style="bold green") header_text.append("Quick Setup - No Authentication\n", style="bold green")
header_text.append("Configure OpenRAG for local document processing only.\n\n", style="dim") header_text.append(
"Configure OpenRAG for local document processing only.\n\n", style="dim"
)
else: else:
header_text.append("Full Setup - OAuth Integration\n", style="bold cyan") header_text.append("Full Setup - OAuth Integration\n", style="bold cyan")
header_text.append("Configure OpenRAG with cloud service integrations.\n\n", style="dim") header_text.append(
"Configure OpenRAG with cloud service integrations.\n\n", style="dim"
)
header_text.append("Required fields are marked with *\n", style="yellow") header_text.append("Required fields are marked with *\n", style="yellow")
header_text.append("Use Ctrl+G to generate admin passwords\n", style="dim") header_text.append("Use Ctrl+G to generate admin passwords\n", style="dim")
@ -104,7 +117,7 @@ class ConfigScreen(Screen):
placeholder="Auto-generated secure password", placeholder="Auto-generated secure password",
value=current_value, value=current_value,
password=True, password=True,
id="input-opensearch_password" id="input-opensearch_password",
) )
yield input_widget yield input_widget
self.inputs["opensearch_password"] = input_widget self.inputs["opensearch_password"] = input_widget
@ -114,9 +127,7 @@ class ConfigScreen(Screen):
yield Label("Langflow Admin Username *") yield Label("Langflow Admin Username *")
current_value = getattr(self.env_manager.config, "langflow_superuser", "") current_value = getattr(self.env_manager.config, "langflow_superuser", "")
input_widget = Input( input_widget = Input(
placeholder="admin", placeholder="admin", value=current_value, id="input-langflow_superuser"
value=current_value,
id="input-langflow_superuser"
) )
yield input_widget yield input_widget
self.inputs["langflow_superuser"] = input_widget self.inputs["langflow_superuser"] = input_widget
@ -124,12 +135,14 @@ class ConfigScreen(Screen):
# Langflow Admin Password # Langflow Admin Password
yield Label("Langflow Admin Password *") yield Label("Langflow Admin Password *")
current_value = getattr(self.env_manager.config, "langflow_superuser_password", "") current_value = getattr(
self.env_manager.config, "langflow_superuser_password", ""
)
input_widget = Input( input_widget = Input(
placeholder="Auto-generated secure password", placeholder="Auto-generated secure password",
value=current_value, value=current_value,
password=True, password=True,
id="input-langflow_superuser_password" id="input-langflow_superuser_password",
) )
yield input_widget yield input_widget
self.inputs["langflow_superuser_password"] = input_widget self.inputs["langflow_superuser_password"] = input_widget
@ -143,14 +156,17 @@ class ConfigScreen(Screen):
# OpenAI API Key # OpenAI API Key
yield Label("OpenAI API Key *") yield Label("OpenAI API Key *")
# Where to create OpenAI keys (helper above the box) # Where to create OpenAI keys (helper above the box)
yield Static(Text("Get a key: https://platform.openai.com/api-keys", style="dim"), classes="helper-text") yield Static(
Text("Get a key: https://platform.openai.com/api-keys", style="dim"),
classes="helper-text",
)
current_value = getattr(self.env_manager.config, "openai_api_key", "") current_value = getattr(self.env_manager.config, "openai_api_key", "")
input_widget = Input( input_widget = Input(
placeholder="sk-...", placeholder="sk-...",
value=current_value, value=current_value,
password=True, password=True,
validators=[OpenAIKeyValidator()], validators=[OpenAIKeyValidator()],
id="input-openai_api_key" id="input-openai_api_key",
) )
yield input_widget yield input_widget
self.inputs["openai_api_key"] = input_widget self.inputs["openai_api_key"] = input_widget
@ -161,7 +177,13 @@ class ConfigScreen(Screen):
# Google OAuth Client ID # Google OAuth Client ID
yield Label("Google OAuth Client ID") yield Label("Google OAuth Client ID")
# Where to create Google OAuth credentials (helper above the box) # Where to create Google OAuth credentials (helper above the box)
yield Static(Text("Create credentials: https://console.cloud.google.com/apis/credentials", style="dim"), classes="helper-text") yield Static(
Text(
"Create credentials: https://console.cloud.google.com/apis/credentials",
style="dim",
),
classes="helper-text",
)
# Callback URL guidance for Google OAuth # Callback URL guidance for Google OAuth
yield Static( yield Static(
Text( Text(
@ -169,15 +191,17 @@ class ConfigScreen(Screen):
" - Local: http://localhost:3000/auth/callback\n" " - Local: http://localhost:3000/auth/callback\n"
" - Prod: https://your-domain.com/auth/callback\n" " - Prod: https://your-domain.com/auth/callback\n"
"If you use separate apps for login and connectors, add this URL to BOTH.", "If you use separate apps for login and connectors, add this URL to BOTH.",
style="dim" style="dim",
), ),
classes="helper-text" classes="helper-text",
)
current_value = getattr(
self.env_manager.config, "google_oauth_client_id", ""
) )
current_value = getattr(self.env_manager.config, "google_oauth_client_id", "")
input_widget = Input( input_widget = Input(
placeholder="xxx.apps.googleusercontent.com", placeholder="xxx.apps.googleusercontent.com",
value=current_value, value=current_value,
id="input-google_oauth_client_id" id="input-google_oauth_client_id",
) )
yield input_widget yield input_widget
self.inputs["google_oauth_client_id"] = input_widget self.inputs["google_oauth_client_id"] = input_widget
@ -185,12 +209,14 @@ class ConfigScreen(Screen):
# Google OAuth Client Secret # Google OAuth Client Secret
yield Label("Google OAuth Client Secret") yield Label("Google OAuth Client Secret")
current_value = getattr(self.env_manager.config, "google_oauth_client_secret", "") current_value = getattr(
self.env_manager.config, "google_oauth_client_secret", ""
)
input_widget = Input( input_widget = Input(
placeholder="", placeholder="",
value=current_value, value=current_value,
password=True, password=True,
id="input-google_oauth_client_secret" id="input-google_oauth_client_secret",
) )
yield input_widget yield input_widget
self.inputs["google_oauth_client_secret"] = input_widget self.inputs["google_oauth_client_secret"] = input_widget
@ -199,7 +225,13 @@ class ConfigScreen(Screen):
# Microsoft Graph Client ID # Microsoft Graph Client ID
yield Label("Microsoft Graph Client ID") yield Label("Microsoft Graph Client ID")
# Where to create Microsoft app registrations (helper above the box) # Where to create Microsoft app registrations (helper above the box)
yield Static(Text("Create app: https://portal.azure.com/#view/Microsoft_AAD_RegisteredApps/ApplicationsListBlade", style="dim"), classes="helper-text") yield Static(
Text(
"Create app: https://portal.azure.com/#view/Microsoft_AAD_RegisteredApps/ApplicationsListBlade",
style="dim",
),
classes="helper-text",
)
# Callback URL guidance for Microsoft OAuth # Callback URL guidance for Microsoft OAuth
yield Static( yield Static(
Text( Text(
@ -207,15 +239,17 @@ class ConfigScreen(Screen):
" - Local: http://localhost:3000/auth/callback\n" " - Local: http://localhost:3000/auth/callback\n"
" - Prod: https://your-domain.com/auth/callback\n" " - Prod: https://your-domain.com/auth/callback\n"
"If you use separate apps for login and connectors, add this URI to BOTH.", "If you use separate apps for login and connectors, add this URI to BOTH.",
style="dim" style="dim",
), ),
classes="helper-text" classes="helper-text",
)
current_value = getattr(
self.env_manager.config, "microsoft_graph_oauth_client_id", ""
) )
current_value = getattr(self.env_manager.config, "microsoft_graph_oauth_client_id", "")
input_widget = Input( input_widget = Input(
placeholder="", placeholder="",
value=current_value, value=current_value,
id="input-microsoft_graph_oauth_client_id" id="input-microsoft_graph_oauth_client_id",
) )
yield input_widget yield input_widget
self.inputs["microsoft_graph_oauth_client_id"] = input_widget self.inputs["microsoft_graph_oauth_client_id"] = input_widget
@ -223,12 +257,14 @@ class ConfigScreen(Screen):
# Microsoft Graph Client Secret # Microsoft Graph Client Secret
yield Label("Microsoft Graph Client Secret") yield Label("Microsoft Graph Client Secret")
current_value = getattr(self.env_manager.config, "microsoft_graph_oauth_client_secret", "") current_value = getattr(
self.env_manager.config, "microsoft_graph_oauth_client_secret", ""
)
input_widget = Input( input_widget = Input(
placeholder="", placeholder="",
value=current_value, value=current_value,
password=True, password=True,
id="input-microsoft_graph_oauth_client_secret" id="input-microsoft_graph_oauth_client_secret",
) )
yield input_widget yield input_widget
self.inputs["microsoft_graph_oauth_client_secret"] = input_widget self.inputs["microsoft_graph_oauth_client_secret"] = input_widget
@ -237,12 +273,16 @@ class ConfigScreen(Screen):
# AWS Access Key ID # AWS Access Key ID
yield Label("AWS Access Key ID") yield Label("AWS Access Key ID")
# Where to create AWS keys (helper above the box) # Where to create AWS keys (helper above the box)
yield Static(Text("Create keys: https://console.aws.amazon.com/iam/home#/security_credentials", style="dim"), classes="helper-text") yield Static(
Text(
"Create keys: https://console.aws.amazon.com/iam/home#/security_credentials",
style="dim",
),
classes="helper-text",
)
current_value = getattr(self.env_manager.config, "aws_access_key_id", "") current_value = getattr(self.env_manager.config, "aws_access_key_id", "")
input_widget = Input( input_widget = Input(
placeholder="", placeholder="", value=current_value, id="input-aws_access_key_id"
value=current_value,
id="input-aws_access_key_id"
) )
yield input_widget yield input_widget
self.inputs["aws_access_key_id"] = input_widget self.inputs["aws_access_key_id"] = input_widget
@ -250,12 +290,14 @@ class ConfigScreen(Screen):
# AWS Secret Access Key # AWS Secret Access Key
yield Label("AWS Secret Access Key") yield Label("AWS Secret Access Key")
current_value = getattr(self.env_manager.config, "aws_secret_access_key", "") current_value = getattr(
self.env_manager.config, "aws_secret_access_key", ""
)
input_widget = Input( input_widget = Input(
placeholder="", placeholder="",
value=current_value, value=current_value,
password=True, password=True,
id="input-aws_secret_access_key" id="input-aws_secret_access_key",
) )
yield input_widget yield input_widget
self.inputs["aws_secret_access_key"] = input_widget self.inputs["aws_secret_access_key"] = input_widget
@ -274,11 +316,15 @@ class ConfigScreen(Screen):
placeholder="./documents,/path/to/more/docs", placeholder="./documents,/path/to/more/docs",
value=current_value, value=current_value,
validators=[DocumentsPathValidator()], validators=[DocumentsPathValidator()],
id="input-openrag_documents_paths" id="input-openrag_documents_paths",
) )
yield input_widget yield input_widget
# Actions row with pick button # Actions row with pick button
yield Horizontal(Button("Pick…", id="pick-docs-btn"), id="docs-path-actions", classes="controls-row") yield Horizontal(
Button("Pick…", id="pick-docs-btn"),
id="docs-path-actions",
classes="controls-row",
)
self.inputs["openrag_documents_paths"] = input_widget self.inputs["openrag_documents_paths"] = input_widget
yield Static(" ") yield Static(" ")
@ -290,9 +336,7 @@ class ConfigScreen(Screen):
yield Label("Langflow Auto Login") yield Label("Langflow Auto Login")
current_value = getattr(self.env_manager.config, "langflow_auto_login", "False") current_value = getattr(self.env_manager.config, "langflow_auto_login", "False")
input_widget = Input( input_widget = Input(
placeholder="False", placeholder="False", value=current_value, id="input-langflow_auto_login"
value=current_value,
id="input-langflow_auto_login"
) )
yield input_widget yield input_widget
self.inputs["langflow_auto_login"] = input_widget self.inputs["langflow_auto_login"] = input_widget
@ -300,11 +344,13 @@ class ConfigScreen(Screen):
# Langflow New User Is Active # Langflow New User Is Active
yield Label("Langflow New User Is Active") yield Label("Langflow New User Is Active")
current_value = getattr(self.env_manager.config, "langflow_new_user_is_active", "False") current_value = getattr(
self.env_manager.config, "langflow_new_user_is_active", "False"
)
input_widget = Input( input_widget = Input(
placeholder="False", placeholder="False",
value=current_value, value=current_value,
id="input-langflow_new_user_is_active" id="input-langflow_new_user_is_active",
) )
yield input_widget yield input_widget
self.inputs["langflow_new_user_is_active"] = input_widget self.inputs["langflow_new_user_is_active"] = input_widget
@ -312,11 +358,13 @@ class ConfigScreen(Screen):
# Langflow Enable Superuser CLI # Langflow Enable Superuser CLI
yield Label("Langflow Enable Superuser CLI") yield Label("Langflow Enable Superuser CLI")
current_value = getattr(self.env_manager.config, "langflow_enable_superuser_cli", "False") current_value = getattr(
self.env_manager.config, "langflow_enable_superuser_cli", "False"
)
input_widget = Input( input_widget = Input(
placeholder="False", placeholder="False",
value=current_value, value=current_value,
id="input-langflow_enable_superuser_cli" id="input-langflow_enable_superuser_cli",
) )
yield input_widget yield input_widget
self.inputs["langflow_enable_superuser_cli"] = input_widget self.inputs["langflow_enable_superuser_cli"] = input_widget
@ -333,7 +381,7 @@ class ConfigScreen(Screen):
input_widget = Input( input_widget = Input(
placeholder="https://your-domain.com", placeholder="https://your-domain.com",
value=current_value, value=current_value,
id="input-webhook_base_url" id="input-webhook_base_url",
) )
yield input_widget yield input_widget
self.inputs["webhook_base_url"] = input_widget self.inputs["webhook_base_url"] = input_widget
@ -345,13 +393,20 @@ class ConfigScreen(Screen):
input_widget = Input( input_widget = Input(
placeholder="http://localhost:7860", placeholder="http://localhost:7860",
value=current_value, value=current_value,
id="input-langflow_public_url" id="input-langflow_public_url",
) )
yield input_widget yield input_widget
self.inputs["langflow_public_url"] = input_widget self.inputs["langflow_public_url"] = input_widget
yield Static(" ") yield Static(" ")
def _create_field(self, field_name: str, display_name: str, placeholder: str, can_generate: bool, required: bool = False) -> ComposeResult: def _create_field(
self,
field_name: str,
display_name: str,
placeholder: str,
can_generate: bool,
required: bool = False,
) -> ComposeResult:
"""Create a single form field.""" """Create a single form field."""
# Create label # Create label
label_text = f"{display_name}" label_text = f"{display_name}"
@ -370,27 +425,25 @@ class ConfigScreen(Screen):
value=current_value, value=current_value,
password=True, password=True,
validators=[OpenAIKeyValidator()], validators=[OpenAIKeyValidator()],
id=f"input-{field_name}" id=f"input-{field_name}",
) )
elif field_name == "openrag_documents_paths": elif field_name == "openrag_documents_paths":
input_widget = Input( input_widget = Input(
placeholder=placeholder, placeholder=placeholder,
value=current_value, value=current_value,
validators=[DocumentsPathValidator()], validators=[DocumentsPathValidator()],
id=f"input-{field_name}" id=f"input-{field_name}",
) )
elif "password" in field_name or "secret" in field_name: elif "password" in field_name or "secret" in field_name:
input_widget = Input( input_widget = Input(
placeholder=placeholder, placeholder=placeholder,
value=current_value, value=current_value,
password=True, password=True,
id=f"input-{field_name}" id=f"input-{field_name}",
) )
else: else:
input_widget = Input( input_widget = Input(
placeholder=placeholder, placeholder=placeholder, value=current_value, id=f"input-{field_name}"
value=current_value,
id=f"input-{field_name}"
) )
yield input_widget yield input_widget
@ -445,7 +498,10 @@ class ConfigScreen(Screen):
for field, error in self.env_manager.config.validation_errors.items(): for field, error in self.env_manager.config.validation_errors.items():
error_messages.append(f"{field}: {error}") error_messages.append(f"{field}: {error}")
self.notify(f"Validation failed:\n" + "\n".join(error_messages[:3]), severity="error") self.notify(
f"Validation failed:\n" + "\n".join(error_messages[:3]),
severity="error",
)
return return
# Save to file # Save to file
@ -453,6 +509,7 @@ class ConfigScreen(Screen):
self.notify("Configuration saved successfully!", severity="information") self.notify("Configuration saved successfully!", severity="information")
# Switch to monitor screen # Switch to monitor screen
from .monitor import MonitorScreen from .monitor import MonitorScreen
self.app.push_screen(MonitorScreen()) self.app.push_screen(MonitorScreen())
else: else:
self.notify("Failed to save configuration", severity="error") self.notify("Failed to save configuration", severity="error")
@ -465,6 +522,7 @@ class ConfigScreen(Screen):
"""Open textual-fspicker to select a path and append it to the input.""" """Open textual-fspicker to select a path and append it to the input."""
try: try:
import importlib import importlib
fsp = importlib.import_module("textual_fspicker") fsp = importlib.import_module("textual_fspicker")
except Exception: except Exception:
self.notify("textual-fspicker not available", severity="warning") self.notify("textual-fspicker not available", severity="warning")
@ -479,9 +537,13 @@ class ConfigScreen(Screen):
start = Path(first).expanduser() start = Path(first).expanduser()
# Prefer SelectDirectory for directories; fallback to FileOpen # Prefer SelectDirectory for directories; fallback to FileOpen
PickerClass = getattr(fsp, "SelectDirectory", None) or getattr(fsp, "FileOpen", None) PickerClass = getattr(fsp, "SelectDirectory", None) or getattr(
fsp, "FileOpen", None
)
if PickerClass is None: if PickerClass is None:
self.notify("No compatible picker found in textual-fspicker", severity="warning") self.notify(
"No compatible picker found in textual-fspicker", severity="warning"
)
return return
try: try:
picker = PickerClass(location=start) picker = PickerClass(location=start)

View file

@ -117,6 +117,7 @@ class DiagnosticsScreen(Screen):
# Try to use pyperclip if available # Try to use pyperclip if available
try: try:
import pyperclip import pyperclip
pyperclip.copy(content) pyperclip.copy(content)
self.notify("Copied to clipboard", severity="information") self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard") status.update("✓ Content copied to clipboard")
@ -131,16 +132,12 @@ class DiagnosticsScreen(Screen):
system = platform.system() system = platform.system()
if system == "Darwin": # macOS if system == "Darwin": # macOS
process = subprocess.Popen( process = subprocess.Popen(["pbcopy"], stdin=subprocess.PIPE, text=True)
["pbcopy"], stdin=subprocess.PIPE, text=True
)
process.communicate(input=content) process.communicate(input=content)
self.notify("Copied to clipboard", severity="information") self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard") status.update("✓ Content copied to clipboard")
elif system == "Windows": elif system == "Windows":
process = subprocess.Popen( process = subprocess.Popen(["clip"], stdin=subprocess.PIPE, text=True)
["clip"], stdin=subprocess.PIPE, text=True
)
process.communicate(input=content) process.communicate(input=content)
self.notify("Copied to clipboard", severity="information") self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard") status.update("✓ Content copied to clipboard")
@ -150,7 +147,7 @@ class DiagnosticsScreen(Screen):
process = subprocess.Popen( process = subprocess.Popen(
["xclip", "-selection", "clipboard"], ["xclip", "-selection", "clipboard"],
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
text=True text=True,
) )
process.communicate(input=content) process.communicate(input=content)
self.notify("Copied to clipboard", severity="information") self.notify("Copied to clipboard", severity="information")
@ -160,16 +157,23 @@ class DiagnosticsScreen(Screen):
process = subprocess.Popen( process = subprocess.Popen(
["xsel", "--clipboard", "--input"], ["xsel", "--clipboard", "--input"],
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
text=True text=True,
) )
process.communicate(input=content) process.communicate(input=content)
self.notify("Copied to clipboard", severity="information") self.notify("Copied to clipboard", severity="information")
status.update("✓ Content copied to clipboard") status.update("✓ Content copied to clipboard")
except FileNotFoundError: except FileNotFoundError:
self.notify("Clipboard utilities not found. Install xclip or xsel.", severity="error") self.notify(
status.update("❌ Clipboard utilities not found. Install xclip or xsel.") "Clipboard utilities not found. Install xclip or xsel.",
severity="error",
)
status.update(
"❌ Clipboard utilities not found. Install xclip or xsel."
)
else: else:
self.notify("Clipboard not supported on this platform", severity="error") self.notify(
"Clipboard not supported on this platform", severity="error"
)
status.update("❌ Clipboard not supported on this platform") status.update("❌ Clipboard not supported on this platform")
self._hide_status_after_delay(status) self._hide_status_after_delay(status)
@ -179,16 +183,22 @@ class DiagnosticsScreen(Screen):
status.update(f"❌ Failed to copy: {e}") status.update(f"❌ Failed to copy: {e}")
self._hide_status_after_delay(status) self._hide_status_after_delay(status)
def _hide_status_after_delay(self, status_widget: Static, delay: float = 3.0) -> None: def _hide_status_after_delay(
self, status_widget: Static, delay: float = 3.0
) -> None:
"""Hide the status message after a delay.""" """Hide the status message after a delay."""
# Cancel any existing timer # Cancel any existing timer
if self._status_timer: if self._status_timer:
self._status_timer.cancel() self._status_timer.cancel()
# Create and run the timer task # Create and run the timer task
self._status_timer = asyncio.create_task(self._clear_status_after_delay(status_widget, delay)) self._status_timer = asyncio.create_task(
self._clear_status_after_delay(status_widget, delay)
)
async def _clear_status_after_delay(self, status_widget: Static, delay: float) -> None: async def _clear_status_after_delay(
self, status_widget: Static, delay: float
) -> None:
"""Clear the status message after a delay.""" """Clear the status message after a delay."""
await asyncio.sleep(delay) await asyncio.sleep(delay)
status_widget.update("") status_widget.update("")
@ -274,7 +284,9 @@ class DiagnosticsScreen(Screen):
services = await self.container_manager.get_service_status(force_refresh=True) services = await self.container_manager.get_service_status(force_refresh=True)
for name, info in services.items(): for name, info in services.items():
status_color = "green" if info.status == "running" else "red" status_color = "green" if info.status == "running" else "red"
log.write(f"[bold]{name}[/bold]: [{status_color}]{info.status.value}[/{status_color}]") log.write(
f"[bold]{name}[/bold]: [{status_color}]{info.status.value}[/{status_color}]"
)
if info.health: if info.health:
log.write(f" Health: {info.health}") log.write(f" Health: {info.health}")
if info.ports: if info.ports:
@ -300,22 +312,20 @@ class DiagnosticsScreen(Screen):
# Check Podman version # Check Podman version
cmd = ["podman", "--version"] cmd = ["podman", "--version"]
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
log.write(f"Podman version: {stdout.decode().strip()}") log.write(f"Podman version: {stdout.decode().strip()}")
else: else:
log.write(f"[red]Failed to get Podman version: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to get Podman version: {stderr.decode().strip()}[/red]"
)
# Check Podman containers # Check Podman containers
cmd = ["podman", "ps", "--all"] cmd = ["podman", "ps", "--all"]
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
@ -323,7 +333,9 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"): for line in stdout.decode().strip().split("\n"):
log.write(f" {line}") log.write(f" {line}")
else: else:
log.write(f"[red]Failed to list Podman containers: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to list Podman containers: {stderr.decode().strip()}[/red]"
)
# Check Podman compose # Check Podman compose
cmd = ["podman", "compose", "ps"] cmd = ["podman", "compose", "ps"]
@ -331,7 +343,7 @@ class DiagnosticsScreen(Screen):
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
cwd=self.container_manager.compose_file.parent cwd=self.container_manager.compose_file.parent,
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
@ -339,7 +351,9 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"): for line in stdout.decode().strip().split("\n"):
log.write(f" {line}") log.write(f" {line}")
else: else:
log.write(f"[red]Failed to list Podman compose services: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to list Podman compose services: {stderr.decode().strip()}[/red]"
)
log.write("") log.write("")
@ -356,22 +370,20 @@ class DiagnosticsScreen(Screen):
# Check Docker version # Check Docker version
cmd = ["docker", "--version"] cmd = ["docker", "--version"]
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
log.write(f"Docker version: {stdout.decode().strip()}") log.write(f"Docker version: {stdout.decode().strip()}")
else: else:
log.write(f"[red]Failed to get Docker version: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to get Docker version: {stderr.decode().strip()}[/red]"
)
# Check Docker containers # Check Docker containers
cmd = ["docker", "ps", "--all"] cmd = ["docker", "ps", "--all"]
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
@ -379,7 +391,9 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"): for line in stdout.decode().strip().split("\n"):
log.write(f" {line}") log.write(f" {line}")
else: else:
log.write(f"[red]Failed to list Docker containers: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to list Docker containers: {stderr.decode().strip()}[/red]"
)
# Check Docker compose # Check Docker compose
cmd = ["docker", "compose", "ps"] cmd = ["docker", "compose", "ps"]
@ -387,7 +401,7 @@ class DiagnosticsScreen(Screen):
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
cwd=self.container_manager.compose_file.parent cwd=self.container_manager.compose_file.parent,
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode == 0: if process.returncode == 0:
@ -395,8 +409,11 @@ class DiagnosticsScreen(Screen):
for line in stdout.decode().strip().split("\n"): for line in stdout.decode().strip().split("\n"):
log.write(f" {line}") log.write(f" {line}")
else: else:
log.write(f"[red]Failed to list Docker compose services: {stderr.decode().strip()}[/red]") log.write(
f"[red]Failed to list Docker compose services: {stderr.decode().strip()}[/red]"
)
log.write("") log.write("")
# Made with Bob # Made with Bob

View file

@ -33,7 +33,13 @@ class LogsScreen(Screen):
self.container_manager = ContainerManager() self.container_manager = ContainerManager()
# Validate the initial service against available options # Validate the initial service against available options
valid_services = ["openrag-backend", "openrag-frontend", "opensearch", "langflow", "dashboards"] valid_services = [
"openrag-backend",
"openrag-frontend",
"opensearch",
"langflow",
"dashboards",
]
if initial_service not in valid_services: if initial_service not in valid_services:
initial_service = "openrag-backend" # fallback initial_service = "openrag-backend" # fallback
@ -49,9 +55,9 @@ class LogsScreen(Screen):
Vertical( Vertical(
Static(f"Service Logs: {self.current_service}", id="logs-title"), Static(f"Service Logs: {self.current_service}", id="logs-title"),
self._create_logs_area(), self._create_logs_area(),
id="logs-content" id="logs-content",
), ),
id="main-container" id="main-container",
) )
yield Footer() yield Footer()
@ -61,7 +67,7 @@ class LogsScreen(Screen):
text="Loading logs...", text="Loading logs...",
read_only=True, read_only=True,
show_line_numbers=False, show_line_numbers=False,
id="logs-area" id="logs-area",
) )
return self.logs_area return self.logs_area
@ -72,7 +78,13 @@ class LogsScreen(Screen):
select = self.query_one("#service-select") select = self.query_one("#service-select")
# Set a default first, then set the desired value # Set a default first, then set the desired value
select.value = "openrag-backend" select.value = "openrag-backend"
if self.current_service in ["openrag-backend", "openrag-frontend", "opensearch", "langflow", "dashboards"]: if self.current_service in [
"openrag-backend",
"openrag-frontend",
"opensearch",
"langflow",
"dashboards",
]:
select.value = self.current_service select.value = self.current_service
except Exception as e: except Exception as e:
# If setting the service fails, just use the default # If setting the service fails, just use the default
@ -90,15 +102,15 @@ class LogsScreen(Screen):
"""Clean up when unmounting.""" """Clean up when unmounting."""
self._stop_following() self._stop_following()
async def _load_logs(self, lines: int = 200) -> None: async def _load_logs(self, lines: int = 200) -> None:
"""Load recent logs for the current service.""" """Load recent logs for the current service."""
if not self.container_manager.is_available(): if not self.container_manager.is_available():
self.logs_area.text = "No container runtime available" self.logs_area.text = "No container runtime available"
return return
success, logs = await self.container_manager.get_service_logs(self.current_service, lines) success, logs = await self.container_manager.get_service_logs(
self.current_service, lines
)
if success: if success:
self.logs_area.text = logs self.logs_area.text = logs
@ -122,7 +134,9 @@ class LogsScreen(Screen):
return return
try: try:
async for log_line in self.container_manager.follow_service_logs(self.current_service): async for log_line in self.container_manager.follow_service_logs(
self.current_service
):
if not self.following: if not self.following:
break break
@ -131,10 +145,10 @@ class LogsScreen(Screen):
new_text = current_text + "\n" + log_line new_text = current_text + "\n" + log_line
# Keep only last 1000 lines to prevent memory issues # Keep only last 1000 lines to prevent memory issues
lines = new_text.split('\n') lines = new_text.split("\n")
if len(lines) > 1000: if len(lines) > 1000:
lines = lines[-1000:] lines = lines[-1000:]
new_text = '\n'.join(lines) new_text = "\n".join(lines)
self.logs_area.text = new_text self.logs_area.text = new_text
# Scroll to bottom if auto scroll is enabled # Scroll to bottom if auto scroll is enabled
@ -144,7 +158,9 @@ class LogsScreen(Screen):
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:
if self.following: # Only show error if we're still supposed to be following if (
self.following
): # Only show error if we're still supposed to be following
self.notify(f"Error following logs: {e}", severity="error") self.notify(f"Error following logs: {e}", severity="error")
finally: finally:
self.following = False self.following = False

View file

@ -75,17 +75,22 @@ class MonitorScreen(Screen):
yield Horizontal(id="services-controls", classes="button-row") yield Horizontal(id="services-controls", classes="button-row")
# Create services table with image + digest info # Create services table with image + digest info
self.services_table = DataTable(id="services-table") self.services_table = DataTable(id="services-table")
self.services_table.add_columns("Service", "Status", "Health", "Ports", "Image", "Digest") self.services_table.add_columns(
"Service", "Status", "Health", "Ports", "Image", "Digest"
)
yield self.services_table yield self.services_table
def _get_runtime_status(self) -> Text: def _get_runtime_status(self) -> Text:
"""Get container runtime status text.""" """Get container runtime status text."""
status_text = Text() status_text = Text()
if not self.container_manager.is_available(): if not self.container_manager.is_available():
status_text.append("WARNING: No container runtime available\n", style="bold red") status_text.append(
status_text.append("Please install Docker or Podman to continue.\n", style="dim") "WARNING: No container runtime available\n", style="bold red"
)
status_text.append(
"Please install Docker or Podman to continue.\n", style="dim"
)
return status_text return status_text
runtime_info = self.container_manager.get_runtime_info() runtime_info = self.container_manager.get_runtime_info()
@ -108,7 +113,6 @@ class MonitorScreen(Screen):
return status_text return status_text
async def on_mount(self) -> None: async def on_mount(self) -> None:
"""Initialize the screen when mounted.""" """Initialize the screen when mounted."""
await self._refresh_services() await self._refresh_services()
@ -147,7 +151,9 @@ class MonitorScreen(Screen):
images_set.add(img) images_set.add(img)
# Ensure compose-declared images are also shown (e.g., langflow when stopped) # Ensure compose-declared images are also shown (e.g., langflow when stopped)
try: try:
for img in self.container_manager._parse_compose_images(): # best-effort, no YAML dep for img in (
self.container_manager._parse_compose_images()
): # best-effort, no YAML dep
if img: if img:
images_set.add(img) images_set.add(img)
except Exception: except Exception:
@ -171,7 +177,7 @@ class MonitorScreen(Screen):
service_info.health or "N/A", service_info.health or "N/A",
", ".join(service_info.ports) if service_info.ports else "N/A", ", ".join(service_info.ports) if service_info.ports else "N/A",
service_info.image or "N/A", service_info.image or "N/A",
digest_map.get(service_info.image or "", "-") digest_map.get(service_info.image or "", "-"),
) )
# Populate images table (unique images as reported by runtime) # Populate images table (unique images as reported by runtime)
if self.images_table: if self.images_table:
@ -191,7 +197,7 @@ class MonitorScreen(Screen):
ServiceStatus.STOPPING: "bold yellow", ServiceStatus.STOPPING: "bold yellow",
ServiceStatus.ERROR: "bold red", ServiceStatus.ERROR: "bold red",
ServiceStatus.MISSING: "dim", ServiceStatus.MISSING: "dim",
ServiceStatus.UNKNOWN: "dim" ServiceStatus.UNKNOWN: "dim",
} }
return status_styles.get(status, "white") return status_styles.get(status, "white")
@ -228,7 +234,7 @@ class MonitorScreen(Screen):
"logs-backend": "openrag-backend", "logs-backend": "openrag-backend",
"logs-frontend": "openrag-frontend", "logs-frontend": "openrag-frontend",
"logs-opensearch": "opensearch", "logs-opensearch": "opensearch",
"logs-langflow": "langflow" "logs-langflow": "langflow",
} }
# Extract the base button ID (without any suffix) # Extract the base button ID (without any suffix)
@ -249,7 +255,7 @@ class MonitorScreen(Screen):
modal = CommandOutputModal( modal = CommandOutputModal(
"Starting Services", "Starting Services",
command_generator, command_generator,
on_complete=None # We'll refresh in on_screen_resume instead on_complete=None, # We'll refresh in on_screen_resume instead
) )
self.app.push_screen(modal) self.app.push_screen(modal)
finally: finally:
@ -264,7 +270,7 @@ class MonitorScreen(Screen):
modal = CommandOutputModal( modal = CommandOutputModal(
"Stopping Services", "Stopping Services",
command_generator, command_generator,
on_complete=None # We'll refresh in on_screen_resume instead on_complete=None, # We'll refresh in on_screen_resume instead
) )
self.app.push_screen(modal) self.app.push_screen(modal)
finally: finally:
@ -279,7 +285,7 @@ class MonitorScreen(Screen):
modal = CommandOutputModal( modal = CommandOutputModal(
"Restarting Services", "Restarting Services",
command_generator, command_generator,
on_complete=None # We'll refresh in on_screen_resume instead on_complete=None, # We'll refresh in on_screen_resume instead
) )
self.app.push_screen(modal) self.app.push_screen(modal)
finally: finally:
@ -294,7 +300,7 @@ class MonitorScreen(Screen):
modal = CommandOutputModal( modal = CommandOutputModal(
"Upgrading Services", "Upgrading Services",
command_generator, command_generator,
on_complete=None # We'll refresh in on_screen_resume instead on_complete=None, # We'll refresh in on_screen_resume instead
) )
self.app.push_screen(modal) self.app.push_screen(modal)
finally: finally:
@ -309,7 +315,7 @@ class MonitorScreen(Screen):
modal = CommandOutputModal( modal = CommandOutputModal(
"Resetting Services", "Resetting Services",
command_generator, command_generator,
on_complete=None # We'll refresh in on_screen_resume instead on_complete=None, # We'll refresh in on_screen_resume instead
) )
self.app.push_screen(modal) self.app.push_screen(modal)
finally: finally:
@ -317,8 +323,8 @@ class MonitorScreen(Screen):
def _strip_ansi_codes(self, text: str) -> str: def _strip_ansi_codes(self, text: str) -> str:
"""Strip ANSI escape sequences from text.""" """Strip ANSI escape sequences from text."""
ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
return ansi_escape.sub('', text) return ansi_escape.sub("", text)
async def _show_logs(self, service_name: str) -> None: async def _show_logs(self, service_name: str) -> None:
"""Show logs for a service.""" """Show logs for a service."""
@ -346,7 +352,7 @@ class MonitorScreen(Screen):
notify_with_diagnostics( notify_with_diagnostics(
self.app, self.app,
f"Failed to get logs for {service_name}: {logs}", f"Failed to get logs for {service_name}: {logs}",
severity="error" severity="error",
) )
def _stop_follow(self) -> None: def _stop_follow(self) -> None:
@ -391,9 +397,7 @@ class MonitorScreen(Screen):
pass pass
except Exception as e: except Exception as e:
notify_with_diagnostics( notify_with_diagnostics(
self.app, self.app, f"Error following logs: {e}", severity="error"
f"Error following logs: {e}",
severity="error"
) )
def action_refresh(self) -> None: def action_refresh(self) -> None:
@ -431,14 +435,15 @@ class MonitorScreen(Screen):
try: try:
current = getattr(self.container_manager, "use_cpu_compose", True) current = getattr(self.container_manager, "use_cpu_compose", True)
self.container_manager.use_cpu_compose = not current self.container_manager.use_cpu_compose = not current
self.notify("Switched to GPU compose" if not current else "Switched to CPU compose", severity="information") self.notify(
"Switched to GPU compose" if not current else "Switched to CPU compose",
severity="information",
)
self._update_mode_row() self._update_mode_row()
self.action_refresh() self.action_refresh()
except Exception as e: except Exception as e:
notify_with_diagnostics( notify_with_diagnostics(
self.app, self.app, f"Failed to toggle mode: {e}", severity="error"
f"Failed to toggle mode: {e}",
severity="error"
) )
def _update_controls(self, services: list[ServiceInfo]) -> None: def _update_controls(self, services: list[ServiceInfo]) -> None:
@ -456,26 +461,33 @@ class MonitorScreen(Screen):
# Use a single ID for each button type, but make them unique with a suffix # Use a single ID for each button type, but make them unique with a suffix
# This ensures we don't create duplicate IDs across refreshes # This ensures we don't create duplicate IDs across refreshes
import random import random
suffix = f"-{random.randint(10000, 99999)}" suffix = f"-{random.randint(10000, 99999)}"
# Add appropriate buttons based on service state # Add appropriate buttons based on service state
if any_running: if any_running:
# When services are running, show stop and restart # When services are running, show stop and restart
controls.mount(Button("Stop Services", variant="error", id=f"stop-btn{suffix}")) controls.mount(
controls.mount(Button("Restart", variant="primary", id=f"restart-btn{suffix}")) Button("Stop Services", variant="error", id=f"stop-btn{suffix}")
)
controls.mount(
Button("Restart", variant="primary", id=f"restart-btn{suffix}")
)
else: else:
# When services are not running, show start # When services are not running, show start
controls.mount(Button("Start Services", variant="success", id=f"start-btn{suffix}")) controls.mount(
Button("Start Services", variant="success", id=f"start-btn{suffix}")
)
# Always show upgrade and reset buttons # Always show upgrade and reset buttons
controls.mount(Button("Upgrade", variant="warning", id=f"upgrade-btn{suffix}")) controls.mount(
Button("Upgrade", variant="warning", id=f"upgrade-btn{suffix}")
)
controls.mount(Button("Reset", variant="error", id=f"reset-btn{suffix}")) controls.mount(Button("Reset", variant="error", id=f"reset-btn{suffix}"))
except Exception as e: except Exception as e:
notify_with_diagnostics( notify_with_diagnostics(
self.app, self.app, f"Error updating controls: {e}", severity="error"
f"Error updating controls: {e}",
severity="error"
) )
def action_back(self) -> None: def action_back(self) -> None:
@ -516,13 +528,16 @@ class MonitorScreen(Screen):
"openrag-frontend": "openrag-frontend", "openrag-frontend": "openrag-frontend",
"opensearch": "opensearch", "opensearch": "opensearch",
"langflow": "langflow", "langflow": "langflow",
"dashboards": "dashboards" "dashboards": "dashboards",
} }
actual_service_name = service_mapping.get(service_name, service_name) actual_service_name = service_mapping.get(
service_name, service_name
)
# Push the logs screen with the selected service # Push the logs screen with the selected service
from .logs import LogsScreen from .logs import LogsScreen
logs_screen = LogsScreen(initial_service=actual_service_name) logs_screen = LogsScreen(initial_service=actual_service_name)
self.app.push_screen(logs_screen) self.app.push_screen(logs_screen)
else: else:

View file

@ -44,9 +44,9 @@ class WelcomeScreen(Screen):
Vertical( Vertical(
Static(self._create_welcome_text(), id="welcome-text"), Static(self._create_welcome_text(), id="welcome-text"),
self._create_dynamic_buttons(), self._create_dynamic_buttons(),
id="welcome-container" id="welcome-container",
), ),
id="main-container" id="main-container",
) )
yield Footer() yield Footer()
@ -65,9 +65,14 @@ class WelcomeScreen(Screen):
welcome_text.append("Terminal User Interface for OpenRAG\n\n", style="dim") welcome_text.append("Terminal User Interface for OpenRAG\n\n", style="dim")
if self.services_running: if self.services_running:
welcome_text.append("✓ Services are currently running\n\n", style="bold green") welcome_text.append(
"✓ Services are currently running\n\n", style="bold green"
)
elif self.has_oauth_config: elif self.has_oauth_config:
welcome_text.append("OAuth credentials detected — Advanced Setup recommended\n\n", style="bold green") welcome_text.append(
"OAuth credentials detected — Advanced Setup recommended\n\n",
style="bold green",
)
else: else:
welcome_text.append("Select a setup below to continue\n\n", style="white") welcome_text.append("Select a setup below to continue\n\n", style="white")
return welcome_text return welcome_text
@ -75,27 +80,34 @@ class WelcomeScreen(Screen):
def _create_dynamic_buttons(self) -> Horizontal: def _create_dynamic_buttons(self) -> Horizontal:
"""Create buttons based on current state.""" """Create buttons based on current state."""
# Check OAuth config early to determine which buttons to show # Check OAuth config early to determine which buttons to show
has_oauth = ( has_oauth = bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or bool(
bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID")
bool(os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID"))
) )
buttons = [] buttons = []
if self.services_running: if self.services_running:
# Services running - only show monitor # Services running - only show monitor
buttons.append(Button("Monitor Services", variant="success", id="monitor-btn")) buttons.append(
Button("Monitor Services", variant="success", id="monitor-btn")
)
else: else:
# Services not running - show setup options # Services not running - show setup options
if has_oauth: if has_oauth:
# Only show advanced setup if OAuth is configured # Only show advanced setup if OAuth is configured
buttons.append(Button("Advanced Setup", variant="success", id="advanced-setup-btn")) buttons.append(
Button("Advanced Setup", variant="success", id="advanced-setup-btn")
)
else: else:
# Only show basic setup if no OAuth # Only show basic setup if no OAuth
buttons.append(Button("Basic Setup", variant="success", id="basic-setup-btn")) buttons.append(
Button("Basic Setup", variant="success", id="basic-setup-btn")
)
# Always show monitor option # Always show monitor option
buttons.append(Button("Monitor Services", variant="default", id="monitor-btn")) buttons.append(
Button("Monitor Services", variant="default", id="monitor-btn")
)
return Horizontal(*buttons, classes="button-row") return Horizontal(*buttons, classes="button-row")
@ -104,14 +116,14 @@ class WelcomeScreen(Screen):
# Check if services are running # Check if services are running
if self.container_manager.is_available(): if self.container_manager.is_available():
services = await self.container_manager.get_service_status() services = await self.container_manager.get_service_status()
running_services = [s.name for s in services.values() if s.status == ServiceStatus.RUNNING] running_services = [
s.name for s in services.values() if s.status == ServiceStatus.RUNNING
]
self.services_running = len(running_services) > 0 self.services_running = len(running_services) > 0
# Check for OAuth configuration # Check for OAuth configuration
self.has_oauth_config = ( self.has_oauth_config = bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or bool(
bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID")
bool(os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID"))
) )
# Set default button focus # Set default button focus
@ -125,7 +137,9 @@ class WelcomeScreen(Screen):
# Update the welcome text and recompose with new state # Update the welcome text and recompose with new state
try: try:
welcome_widget = self.query_one("#welcome-text") welcome_widget = self.query_one("#welcome-text")
welcome_widget.update(self._create_welcome_text()) # This is fine for Static widgets welcome_widget.update(
self._create_welcome_text()
) # This is fine for Static widgets
# Focus the appropriate button # Focus the appropriate button
if self.services_running: if self.services_running:
@ -170,21 +184,25 @@ class WelcomeScreen(Screen):
def action_no_auth_setup(self) -> None: def action_no_auth_setup(self) -> None:
"""Switch to basic configuration screen.""" """Switch to basic configuration screen."""
from .config import ConfigScreen from .config import ConfigScreen
self.app.push_screen(ConfigScreen(mode="no_auth")) self.app.push_screen(ConfigScreen(mode="no_auth"))
def action_full_setup(self) -> None: def action_full_setup(self) -> None:
"""Switch to advanced configuration screen.""" """Switch to advanced configuration screen."""
from .config import ConfigScreen from .config import ConfigScreen
self.app.push_screen(ConfigScreen(mode="full")) self.app.push_screen(ConfigScreen(mode="full"))
def action_monitor(self) -> None: def action_monitor(self) -> None:
"""Switch to monitoring screen.""" """Switch to monitoring screen."""
from .monitor import MonitorScreen from .monitor import MonitorScreen
self.app.push_screen(MonitorScreen()) self.app.push_screen(MonitorScreen())
def action_diagnostics(self) -> None: def action_diagnostics(self) -> None:
"""Switch to diagnostics screen.""" """Switch to diagnostics screen."""
from .diagnostics import DiagnosticsScreen from .diagnostics import DiagnosticsScreen
self.app.push_screen(DiagnosticsScreen()) self.app.push_screen(DiagnosticsScreen())
def action_quit(self) -> None: def action_quit(self) -> None:

View file

@ -41,33 +41,59 @@ class PlatformDetector:
if docker_version and podman_version in docker_version: if docker_version and podman_version in docker_version:
# This is podman masquerading as docker # This is podman masquerading as docker
if self._check_command(["docker", "compose", "--help"]): if self._check_command(["docker", "compose", "--help"]):
return RuntimeInfo(RuntimeType.PODMAN, ["docker", "compose"], ["docker"], podman_version) return RuntimeInfo(
RuntimeType.PODMAN,
["docker", "compose"],
["docker"],
podman_version,
)
if self._check_command(["docker-compose", "--help"]): if self._check_command(["docker-compose", "--help"]):
return RuntimeInfo(RuntimeType.PODMAN, ["docker-compose"], ["docker"], podman_version) return RuntimeInfo(
RuntimeType.PODMAN,
["docker-compose"],
["docker"],
podman_version,
)
# Check for native podman compose # Check for native podman compose
if self._check_command(["podman", "compose", "--help"]): if self._check_command(["podman", "compose", "--help"]):
return RuntimeInfo(RuntimeType.PODMAN, ["podman", "compose"], ["podman"], podman_version) return RuntimeInfo(
RuntimeType.PODMAN,
["podman", "compose"],
["podman"],
podman_version,
)
# Check for actual docker # Check for actual docker
if self._check_command(["docker", "compose", "--help"]): if self._check_command(["docker", "compose", "--help"]):
version = self._get_docker_version() version = self._get_docker_version()
return RuntimeInfo(RuntimeType.DOCKER, ["docker", "compose"], ["docker"], version) return RuntimeInfo(
RuntimeType.DOCKER, ["docker", "compose"], ["docker"], version
)
if self._check_command(["docker-compose", "--help"]): if self._check_command(["docker-compose", "--help"]):
version = self._get_docker_version() version = self._get_docker_version()
return RuntimeInfo(RuntimeType.DOCKER_COMPOSE, ["docker-compose"], ["docker"], version) return RuntimeInfo(
RuntimeType.DOCKER_COMPOSE, ["docker-compose"], ["docker"], version
)
return RuntimeInfo(RuntimeType.NONE, [], []) return RuntimeInfo(RuntimeType.NONE, [], [])
def detect_gpu_available(self) -> bool: def detect_gpu_available(self) -> bool:
"""Best-effort detection of NVIDIA GPU availability for containers.""" """Best-effort detection of NVIDIA GPU availability for containers."""
try: try:
res = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=5) res = subprocess.run(
if res.returncode == 0 and any("GPU" in ln for ln in res.stdout.splitlines()): ["nvidia-smi", "-L"], capture_output=True, text=True, timeout=5
)
if res.returncode == 0 and any(
"GPU" in ln for ln in res.stdout.splitlines()
):
return True return True
except (subprocess.TimeoutExpired, FileNotFoundError): except (subprocess.TimeoutExpired, FileNotFoundError):
pass pass
for cmd in (["docker", "info", "--format", "{{json .Runtimes}}"], ["podman", "info", "--format", "json"]): for cmd in (
["docker", "info", "--format", "{{json .Runtimes}}"],
["podman", "info", "--format", "json"],
):
try: try:
res = subprocess.run(cmd, capture_output=True, text=True, timeout=5) res = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
if res.returncode == 0 and "nvidia" in res.stdout.lower(): if res.returncode == 0 and "nvidia" in res.stdout.lower():
@ -85,7 +111,9 @@ class PlatformDetector:
def _get_docker_version(self) -> Optional[str]: def _get_docker_version(self) -> Optional[str]:
try: try:
res = subprocess.run(["docker", "--version"], capture_output=True, text=True, timeout=5) res = subprocess.run(
["docker", "--version"], capture_output=True, text=True, timeout=5
)
if res.returncode == 0: if res.returncode == 0:
return res.stdout.strip() return res.stdout.strip()
except (subprocess.TimeoutExpired, FileNotFoundError): except (subprocess.TimeoutExpired, FileNotFoundError):
@ -94,7 +122,9 @@ class PlatformDetector:
def _get_podman_version(self) -> Optional[str]: def _get_podman_version(self) -> Optional[str]:
try: try:
res = subprocess.run(["podman", "--version"], capture_output=True, text=True, timeout=5) res = subprocess.run(
["podman", "--version"], capture_output=True, text=True, timeout=5
)
if res.returncode == 0: if res.returncode == 0:
return res.stdout.strip() return res.stdout.strip()
except (subprocess.TimeoutExpired, FileNotFoundError): except (subprocess.TimeoutExpired, FileNotFoundError):
@ -110,7 +140,12 @@ class PlatformDetector:
if self.platform_system != "Darwin": if self.platform_system != "Darwin":
return True, 0, "Not running on macOS" return True, 0, "Not running on macOS"
try: try:
result = subprocess.run(["podman", "machine", "inspect"], capture_output=True, text=True, timeout=10) result = subprocess.run(
["podman", "machine", "inspect"],
capture_output=True,
text=True,
timeout=10,
)
if result.returncode != 0: if result.returncode != 0:
return False, 0, "Could not inspect Podman machine" return False, 0, "Could not inspect Podman machine"
machines = json.loads(result.stdout) machines = json.loads(result.stdout)
@ -124,7 +159,11 @@ class PlatformDetector:
if not is_sufficient: if not is_sufficient:
status += "\nTo increase: podman machine stop && podman machine rm && podman machine init --memory 8192 && podman machine start" status += "\nTo increase: podman machine stop && podman machine rm && podman machine init --memory 8192 && podman machine start"
return is_sufficient, memory_mb, status return is_sufficient, memory_mb, status
except (subprocess.TimeoutExpired, FileNotFoundError, json.JSONDecodeError) as e: except (
subprocess.TimeoutExpired,
FileNotFoundError,
json.JSONDecodeError,
) as e:
return False, 0, f"Error checking Podman VM memory: {e}" return False, 0, f"Error checking Podman VM memory: {e}"
def get_installation_instructions(self) -> str: def get_installation_instructions(self) -> str:

View file

@ -8,15 +8,18 @@ from typing import Optional
class ValidationError(Exception): class ValidationError(Exception):
"""Validation error exception.""" """Validation error exception."""
pass pass
def validate_env_var_name(name: str) -> bool: def validate_env_var_name(name: str) -> bool:
"""Validate environment variable name format.""" """Validate environment variable name format."""
return bool(re.match(r'^[A-Z][A-Z0-9_]*$', name)) return bool(re.match(r"^[A-Z][A-Z0-9_]*$", name))
def validate_path(path: str, must_exist: bool = False, must_be_dir: bool = False) -> bool: def validate_path(
path: str, must_exist: bool = False, must_be_dir: bool = False
) -> bool:
"""Validate file/directory path.""" """Validate file/directory path."""
if not path: if not path:
return False return False
@ -41,12 +44,14 @@ def validate_url(url: str) -> bool:
return False return False
url_pattern = re.compile( url_pattern = re.compile(
r'^https?://' # http:// or https:// r"^https?://" # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|' # domain r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|" # domain
r'localhost|' # localhost r"localhost|" # localhost
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # IP r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # IP
r'(?::\d+)?' # optional port r"(?::\d+)?" # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE) r"(?:/?|[/?]\S+)$",
re.IGNORECASE,
)
return bool(url_pattern.match(url)) return bool(url_pattern.match(url))
@ -55,14 +60,14 @@ def validate_openai_api_key(key: str) -> bool:
"""Validate OpenAI API key format.""" """Validate OpenAI API key format."""
if not key: if not key:
return False return False
return key.startswith('sk-') and len(key) > 20 return key.startswith("sk-") and len(key) > 20
def validate_google_oauth_client_id(client_id: str) -> bool: def validate_google_oauth_client_id(client_id: str) -> bool:
"""Validate Google OAuth client ID format.""" """Validate Google OAuth client ID format."""
if not client_id: if not client_id:
return False return False
return client_id.endswith('.apps.googleusercontent.com') return client_id.endswith(".apps.googleusercontent.com")
def validate_non_empty(value: str) -> bool: def validate_non_empty(value: str) -> bool:
@ -77,8 +82,9 @@ def sanitize_env_value(value: str) -> str:
# Remove quotes if they wrap the entire value # Remove quotes if they wrap the entire value
if len(value) >= 2: if len(value) >= 2:
if (value.startswith('"') and value.endswith('"')) or \ if (value.startswith('"') and value.endswith('"')) or (
(value.startswith("'") and value.endswith("'")): value.startswith("'") and value.endswith("'")
):
value = value[1:-1] value = value[1:-1]
return value return value
@ -94,7 +100,7 @@ def validate_documents_paths(paths_str: str) -> tuple[bool, str, list[str]]:
if not paths_str: if not paths_str:
return False, "Documents paths cannot be empty", [] return False, "Documents paths cannot be empty", []
paths = [path.strip() for path in paths_str.split(',') if path.strip()] paths = [path.strip() for path in paths_str.split(",") if path.strip()]
if not paths: if not paths:
return False, "No valid paths provided", [] return False, "No valid paths provided", []

View file

@ -68,7 +68,7 @@ class CommandOutputModal(ModalScreen):
self, self,
title: str, title: str,
command_generator: AsyncIterator[tuple[bool, str]], command_generator: AsyncIterator[tuple[bool, str]],
on_complete: Optional[Callable] = None on_complete: Optional[Callable] = None,
): ):
"""Initialize the modal dialog. """Initialize the modal dialog.
@ -116,7 +116,9 @@ class CommandOutputModal(ModalScreen):
# If command is complete, update UI # If command is complete, update UI
if is_complete: if is_complete:
output.write("[bold green]Command completed successfully[/bold green]\n") output.write(
"[bold green]Command completed successfully[/bold green]\n"
)
# Call the completion callback if provided # Call the completion callback if provided
if self.on_complete: if self.on_complete:
await asyncio.sleep(0.5) # Small delay for better UX await asyncio.sleep(0.5) # Small delay for better UX
@ -129,4 +131,5 @@ class CommandOutputModal(ModalScreen):
close_btn.disabled = False close_btn.disabled = False
close_btn.focus() close_btn.focus()
# Made with Bob # Made with Bob

View file

@ -9,7 +9,7 @@ def notify_with_diagnostics(
app: App, app: App,
message: str, message: str,
severity: Literal["information", "warning", "error"] = "error", severity: Literal["information", "warning", "error"] = "error",
timeout: float = 10.0 timeout: float = 10.0,
) -> None: ) -> None:
"""Show a notification with a button to open the diagnostics screen. """Show a notification with a button to open the diagnostics screen.
@ -25,6 +25,7 @@ def notify_with_diagnostics(
# Then add a button to open diagnostics screen # Then add a button to open diagnostics screen
def open_diagnostics() -> None: def open_diagnostics() -> None:
from ..screens.diagnostics import DiagnosticsScreen from ..screens.diagnostics import DiagnosticsScreen
app.push_screen(DiagnosticsScreen()) app.push_screen(DiagnosticsScreen())
# Add a separate notification with just the button # Add a separate notification with just the button
@ -32,7 +33,8 @@ def notify_with_diagnostics(
"Click to view diagnostics", "Click to view diagnostics",
severity="information", severity="information",
timeout=timeout, timeout=timeout,
title="Diagnostics" title="Diagnostics",
) )
# Made with Bob # Made with Bob

View file

@ -9,7 +9,7 @@ def notify_with_diagnostics(
app: App, app: App,
message: str, message: str,
severity: Literal["information", "warning", "error"] = "error", severity: Literal["information", "warning", "error"] = "error",
timeout: float = 10.0 timeout: float = 10.0,
) -> None: ) -> None:
"""Show a notification with a button to open the diagnostics screen. """Show a notification with a button to open the diagnostics screen.
@ -25,6 +25,7 @@ def notify_with_diagnostics(
# Then add a button to open diagnostics screen # Then add a button to open diagnostics screen
def open_diagnostics() -> None: def open_diagnostics() -> None:
from ..screens.diagnostics import DiagnosticsScreen from ..screens.diagnostics import DiagnosticsScreen
app.push_screen(DiagnosticsScreen()) app.push_screen(DiagnosticsScreen())
# Add a separate notification with just the button # Add a separate notification with just the button
@ -32,7 +33,8 @@ def notify_with_diagnostics(
"Click to view diagnostics", "Click to view diagnostics",
severity="information", severity="information",
timeout=timeout, timeout=timeout,
title="Diagnostics" title="Diagnostics",
) )
# Made with Bob # Made with Bob

View file

@ -1,7 +1,12 @@
import hashlib import hashlib
import os import os
import sys
import platform
from collections import defaultdict from collections import defaultdict
from .gpu_detection import detect_gpu_devices from .gpu_detection import detect_gpu_devices
from utils.logging_config import get_logger
logger = get_logger(__name__)
# Global converter cache for worker processes # Global converter cache for worker processes
_worker_converter = None _worker_converter = None
@ -37,11 +42,11 @@ def get_worker_converter():
"1" # Still disable progress bars "1" # Still disable progress bars
) )
print( logger.info(
f"[WORKER {os.getpid()}] Initializing DocumentConverter in worker process" "Initializing DocumentConverter in worker process", worker_pid=os.getpid()
) )
_worker_converter = DocumentConverter() _worker_converter = DocumentConverter()
print(f"[WORKER {os.getpid()}] DocumentConverter ready in worker process") logger.info("DocumentConverter ready in worker process", worker_pid=os.getpid())
return _worker_converter return _worker_converter
@ -118,33 +123,45 @@ def process_document_sync(file_path: str):
start_memory = process.memory_info().rss / 1024 / 1024 # MB start_memory = process.memory_info().rss / 1024 / 1024 # MB
try: try:
print(f"[WORKER {os.getpid()}] Starting document processing: {file_path}") logger.info(
print(f"[WORKER {os.getpid()}] Initial memory usage: {start_memory:.1f} MB") "Starting document processing",
worker_pid=os.getpid(),
file_path=file_path,
initial_memory_mb=f"{start_memory:.1f}",
)
# Check file size # Check file size
try: try:
file_size = os.path.getsize(file_path) / 1024 / 1024 # MB file_size = os.path.getsize(file_path) / 1024 / 1024 # MB
print(f"[WORKER {os.getpid()}] File size: {file_size:.1f} MB") logger.info(
"File size determined",
worker_pid=os.getpid(),
file_size_mb=f"{file_size:.1f}",
)
except OSError as e: except OSError as e:
print(f"[WORKER {os.getpid()}] WARNING: Cannot get file size: {e}") logger.warning("Cannot get file size", worker_pid=os.getpid(), error=str(e))
file_size = 0 file_size = 0
# Get the cached converter for this worker # Get the cached converter for this worker
try: try:
print(f"[WORKER {os.getpid()}] Getting document converter...") logger.info("Getting document converter", worker_pid=os.getpid())
converter = get_worker_converter() converter = get_worker_converter()
memory_after_converter = process.memory_info().rss / 1024 / 1024 memory_after_converter = process.memory_info().rss / 1024 / 1024
print( logger.info(
f"[WORKER {os.getpid()}] Memory after converter init: {memory_after_converter:.1f} MB" "Memory after converter init",
worker_pid=os.getpid(),
memory_mb=f"{memory_after_converter:.1f}",
) )
except Exception as e: except Exception as e:
print(f"[WORKER {os.getpid()}] ERROR: Failed to initialize converter: {e}") logger.error(
"Failed to initialize converter", worker_pid=os.getpid(), error=str(e)
)
traceback.print_exc() traceback.print_exc()
raise raise
# Compute file hash # Compute file hash
try: try:
print(f"[WORKER {os.getpid()}] Computing file hash...") logger.info("Computing file hash", worker_pid=os.getpid())
sha256 = hashlib.sha256() sha256 = hashlib.sha256()
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
while True: while True:
@ -153,50 +170,67 @@ def process_document_sync(file_path: str):
break break
sha256.update(chunk) sha256.update(chunk)
file_hash = sha256.hexdigest() file_hash = sha256.hexdigest()
print(f"[WORKER {os.getpid()}] File hash computed: {file_hash[:12]}...") logger.info(
"File hash computed",
worker_pid=os.getpid(),
file_hash_prefix=file_hash[:12],
)
except Exception as e: except Exception as e:
print(f"[WORKER {os.getpid()}] ERROR: Failed to compute file hash: {e}") logger.error(
"Failed to compute file hash", worker_pid=os.getpid(), error=str(e)
)
traceback.print_exc() traceback.print_exc()
raise raise
# Convert with docling # Convert with docling
try: try:
print(f"[WORKER {os.getpid()}] Starting docling conversion...") logger.info("Starting docling conversion", worker_pid=os.getpid())
memory_before_convert = process.memory_info().rss / 1024 / 1024 memory_before_convert = process.memory_info().rss / 1024 / 1024
print( logger.info(
f"[WORKER {os.getpid()}] Memory before conversion: {memory_before_convert:.1f} MB" "Memory before conversion",
worker_pid=os.getpid(),
memory_mb=f"{memory_before_convert:.1f}",
) )
result = converter.convert(file_path) result = converter.convert(file_path)
memory_after_convert = process.memory_info().rss / 1024 / 1024 memory_after_convert = process.memory_info().rss / 1024 / 1024
print( logger.info(
f"[WORKER {os.getpid()}] Memory after conversion: {memory_after_convert:.1f} MB" "Memory after conversion",
worker_pid=os.getpid(),
memory_mb=f"{memory_after_convert:.1f}",
) )
print(f"[WORKER {os.getpid()}] Docling conversion completed") logger.info("Docling conversion completed", worker_pid=os.getpid())
full_doc = result.document.export_to_dict() full_doc = result.document.export_to_dict()
memory_after_export = process.memory_info().rss / 1024 / 1024 memory_after_export = process.memory_info().rss / 1024 / 1024
print( logger.info(
f"[WORKER {os.getpid()}] Memory after export: {memory_after_export:.1f} MB" "Memory after export",
worker_pid=os.getpid(),
memory_mb=f"{memory_after_export:.1f}",
) )
except Exception as e: except Exception as e:
print( current_memory = process.memory_info().rss / 1024 / 1024
f"[WORKER {os.getpid()}] ERROR: Failed during docling conversion: {e}" logger.error(
) "Failed during docling conversion",
print( worker_pid=os.getpid(),
f"[WORKER {os.getpid()}] Current memory usage: {process.memory_info().rss / 1024 / 1024:.1f} MB" error=str(e),
current_memory_mb=f"{current_memory:.1f}",
) )
traceback.print_exc() traceback.print_exc()
raise raise
# Extract relevant content (same logic as extract_relevant) # Extract relevant content (same logic as extract_relevant)
try: try:
print(f"[WORKER {os.getpid()}] Extracting relevant content...") logger.info("Extracting relevant content", worker_pid=os.getpid())
origin = full_doc.get("origin", {}) origin = full_doc.get("origin", {})
texts = full_doc.get("texts", []) texts = full_doc.get("texts", [])
print(f"[WORKER {os.getpid()}] Found {len(texts)} text fragments") logger.info(
"Found text fragments",
worker_pid=os.getpid(),
fragment_count=len(texts),
)
page_texts = defaultdict(list) page_texts = defaultdict(list)
for txt in texts: for txt in texts:
@ -210,22 +244,27 @@ def process_document_sync(file_path: str):
joined = "\n".join(page_texts[page]) joined = "\n".join(page_texts[page])
chunks.append({"page": page, "text": joined}) chunks.append({"page": page, "text": joined})
print( logger.info(
f"[WORKER {os.getpid()}] Created {len(chunks)} chunks from {len(page_texts)} pages" "Created chunks from pages",
worker_pid=os.getpid(),
chunk_count=len(chunks),
page_count=len(page_texts),
) )
except Exception as e: except Exception as e:
print( logger.error(
f"[WORKER {os.getpid()}] ERROR: Failed during content extraction: {e}" "Failed during content extraction", worker_pid=os.getpid(), error=str(e)
) )
traceback.print_exc() traceback.print_exc()
raise raise
final_memory = process.memory_info().rss / 1024 / 1024 final_memory = process.memory_info().rss / 1024 / 1024
memory_delta = final_memory - start_memory memory_delta = final_memory - start_memory
print(f"[WORKER {os.getpid()}] Document processing completed successfully") logger.info(
print( "Document processing completed successfully",
f"[WORKER {os.getpid()}] Final memory: {final_memory:.1f} MB (Delta +{memory_delta:.1f} MB)" worker_pid=os.getpid(),
final_memory_mb=f"{final_memory:.1f}",
memory_delta_mb=f"{memory_delta:.1f}",
) )
return { return {
@ -239,24 +278,29 @@ def process_document_sync(file_path: str):
except Exception as e: except Exception as e:
final_memory = process.memory_info().rss / 1024 / 1024 final_memory = process.memory_info().rss / 1024 / 1024
memory_delta = final_memory - start_memory memory_delta = final_memory - start_memory
print(f"[WORKER {os.getpid()}] FATAL ERROR in process_document_sync") logger.error(
print(f"[WORKER {os.getpid()}] File: {file_path}") "FATAL ERROR in process_document_sync",
print(f"[WORKER {os.getpid()}] Python version: {sys.version}") worker_pid=os.getpid(),
print( file_path=file_path,
f"[WORKER {os.getpid()}] Memory at crash: {final_memory:.1f} MB (Delta +{memory_delta:.1f} MB)" python_version=sys.version,
memory_at_crash_mb=f"{final_memory:.1f}",
memory_delta_mb=f"{memory_delta:.1f}",
error_type=type(e).__name__,
error=str(e),
) )
print(f"[WORKER {os.getpid()}] Error: {type(e).__name__}: {e}") logger.error("Full traceback:", worker_pid=os.getpid())
print(f"[WORKER {os.getpid()}] Full traceback:")
traceback.print_exc() traceback.print_exc()
# Try to get more system info before crashing # Try to get more system info before crashing
try: try:
import platform import platform
print( logger.error(
f"[WORKER {os.getpid()}] System: {platform.system()} {platform.release()}" "System info",
worker_pid=os.getpid(),
system=f"{platform.system()} {platform.release()}",
architecture=platform.machine(),
) )
print(f"[WORKER {os.getpid()}] Architecture: {platform.machine()}")
except: except:
pass pass

View file

@ -1,5 +1,8 @@
import multiprocessing import multiprocessing
import os import os
from utils.logging_config import get_logger
logger = get_logger(__name__)
def detect_gpu_devices(): def detect_gpu_devices():
@ -30,13 +33,15 @@ def get_worker_count():
if has_gpu_devices: if has_gpu_devices:
default_workers = min(4, multiprocessing.cpu_count() // 2) default_workers = min(4, multiprocessing.cpu_count() // 2)
print( logger.info(
f"GPU mode enabled with {gpu_count} GPU(s) - using limited concurrency ({default_workers} workers)" "GPU mode enabled with limited concurrency",
gpu_count=gpu_count,
worker_count=default_workers,
) )
else: else:
default_workers = multiprocessing.cpu_count() default_workers = multiprocessing.cpu_count()
print( logger.info(
f"CPU-only mode enabled - using full concurrency ({default_workers} workers)" "CPU-only mode enabled with full concurrency", worker_count=default_workers
) )
return int(os.getenv("MAX_WORKERS", default_workers)) return int(os.getenv("MAX_WORKERS", default_workers))

View file

@ -9,12 +9,14 @@ def configure_logging(
log_level: str = "INFO", log_level: str = "INFO",
json_logs: bool = False, json_logs: bool = False,
include_timestamps: bool = True, include_timestamps: bool = True,
service_name: str = "openrag" service_name: str = "openrag",
) -> None: ) -> None:
"""Configure structlog for the application.""" """Configure structlog for the application."""
# Convert string log level to actual level # Convert string log level to actual level
level = getattr(structlog.stdlib.logging, log_level.upper(), structlog.stdlib.logging.INFO) level = getattr(
structlog.stdlib.logging, log_level.upper(), structlog.stdlib.logging.INFO
)
# Base processors # Base processors
shared_processors = [ shared_processors = [
@ -27,10 +29,15 @@ def configure_logging(
if include_timestamps: if include_timestamps:
shared_processors.append(structlog.processors.TimeStamper(fmt="iso")) shared_processors.append(structlog.processors.TimeStamper(fmt="iso"))
# Add service name to all logs # Add service name and file location to all logs
shared_processors.append( shared_processors.append(
structlog.processors.CallsiteParameterAdder( structlog.processors.CallsiteParameterAdder(
parameters=[structlog.processors.CallsiteParameter.FUNC_NAME] parameters=[
structlog.processors.CallsiteParameter.FUNC_NAME,
structlog.processors.CallsiteParameter.FILENAME,
structlog.processors.CallsiteParameter.LINENO,
structlog.processors.CallsiteParameter.PATHNAME,
]
) )
) )
@ -40,11 +47,42 @@ def configure_logging(
shared_processors.append(structlog.processors.JSONRenderer()) shared_processors.append(structlog.processors.JSONRenderer())
console_renderer = structlog.processors.JSONRenderer() console_renderer = structlog.processors.JSONRenderer()
else: else:
# Pretty colored output for development # Custom clean format: timestamp path/file:loc logentry
console_renderer = structlog.dev.ConsoleRenderer( def custom_formatter(logger, log_method, event_dict):
colors=sys.stderr.isatty(), timestamp = event_dict.pop("timestamp", "")
exception_formatter=structlog.dev.plain_traceback, pathname = event_dict.pop("pathname", "")
) filename = event_dict.pop("filename", "")
lineno = event_dict.pop("lineno", "")
level = event_dict.pop("level", "")
# Build file location - prefer pathname for full path, fallback to filename
if pathname and lineno:
location = f"{pathname}:{lineno}"
elif filename and lineno:
location = f"{filename}:{lineno}"
elif pathname:
location = pathname
elif filename:
location = filename
else:
location = "unknown"
# Build the main message
message_parts = []
event = event_dict.pop("event", "")
if event:
message_parts.append(event)
# Add any remaining context
for key, value in event_dict.items():
if key not in ["service", "func_name"]: # Skip internal fields
message_parts.append(f"{key}={value}")
message = " ".join(message_parts)
return f"{timestamp} {location} {message}"
console_renderer = custom_formatter
# Configure structlog # Configure structlog
structlog.configure( structlog.configure(
@ -75,7 +113,5 @@ def configure_from_env() -> None:
service_name = os.getenv("SERVICE_NAME", "openrag") service_name = os.getenv("SERVICE_NAME", "openrag")
configure_logging( configure_logging(
log_level=log_level, log_level=log_level, json_logs=json_logs, service_name=service_name
json_logs=json_logs,
service_name=service_name
) )

View file

@ -1,10 +1,13 @@
import os import os
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from utils.gpu_detection import get_worker_count from utils.gpu_detection import get_worker_count
from utils.logging_config import get_logger
logger = get_logger(__name__)
# Create shared process pool at import time (before CUDA initialization) # Create shared process pool at import time (before CUDA initialization)
# This avoids the "Cannot re-initialize CUDA in forked subprocess" error # This avoids the "Cannot re-initialize CUDA in forked subprocess" error
MAX_WORKERS = get_worker_count() MAX_WORKERS = get_worker_count()
process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS) process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)
print(f"Shared process pool initialized with {MAX_WORKERS} workers") logger.info("Shared process pool initialized", max_workers=MAX_WORKERS)

View file

@ -1,13 +1,17 @@
from docling.document_converter import DocumentConverter from docling.document_converter import DocumentConverter
import logging
print("Warming up docling models...") logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Warming up docling models")
try: try:
# Use the sample document to warm up docling # Use the sample document to warm up docling
test_file = "/app/warmup_ocr.pdf" test_file = "/app/warmup_ocr.pdf"
print(f"Using {test_file} to warm up docling...") logger.info(f"Using test file to warm up docling: {test_file}")
DocumentConverter().convert(test_file) DocumentConverter().convert(test_file)
print("Docling models warmed up successfully") logger.info("Docling models warmed up successfully")
except Exception as e: except Exception as e:
print(f"Docling warm-up completed with: {e}") logger.info(f"Docling warm-up completed with exception: {str(e)}")
# This is expected - we just want to trigger the model downloads # This is expected - we just want to trigger the model downloads