Merge branch 'feat-googledrive-enhancements' of github.com:langflow-ai/openrag into feat-googledrive-enhancements
This commit is contained in:
commit
86d8c9f5b2
41 changed files with 1730 additions and 967 deletions
BIN
.DS_Store
vendored
Normal file
BIN
.DS_Store
vendored
Normal file
Binary file not shown.
BIN
documents/ai-human-resources.pdf
Normal file
BIN
documents/ai-human-resources.pdf
Normal file
Binary file not shown.
96
src/agent.py
96
src/agent.py
|
|
@ -95,7 +95,9 @@ async def async_response_stream(
|
|||
chunk_count = 0
|
||||
async for chunk in response:
|
||||
chunk_count += 1
|
||||
logger.debug("Stream chunk received", chunk_count=chunk_count, chunk=str(chunk))
|
||||
logger.debug(
|
||||
"Stream chunk received", chunk_count=chunk_count, chunk=str(chunk)
|
||||
)
|
||||
|
||||
# Yield the raw event as JSON for the UI to process
|
||||
import json
|
||||
|
|
@ -241,7 +243,10 @@ async def async_langflow_stream(
|
|||
previous_response_id=previous_response_id,
|
||||
log_prefix="langflow",
|
||||
):
|
||||
logger.debug("Yielding chunk from langflow stream", chunk_preview=chunk[:100].decode('utf-8', errors='replace'))
|
||||
logger.debug(
|
||||
"Yielding chunk from langflow stream",
|
||||
chunk_preview=chunk[:100].decode("utf-8", errors="replace"),
|
||||
)
|
||||
yield chunk
|
||||
logger.debug("Langflow stream completed")
|
||||
except Exception as e:
|
||||
|
|
@ -260,18 +265,24 @@ async def async_chat(
|
|||
model: str = "gpt-4.1-mini",
|
||||
previous_response_id: str = None,
|
||||
):
|
||||
logger.debug("async_chat called", user_id=user_id, previous_response_id=previous_response_id)
|
||||
logger.debug(
|
||||
"async_chat called", user_id=user_id, previous_response_id=previous_response_id
|
||||
)
|
||||
|
||||
# Get the specific conversation thread (or create new one)
|
||||
conversation_state = get_conversation_thread(user_id, previous_response_id)
|
||||
logger.debug("Got conversation state", message_count=len(conversation_state['messages']))
|
||||
logger.debug(
|
||||
"Got conversation state", message_count=len(conversation_state["messages"])
|
||||
)
|
||||
|
||||
# Add user message to conversation with timestamp
|
||||
from datetime import datetime
|
||||
|
||||
user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()}
|
||||
conversation_state["messages"].append(user_message)
|
||||
logger.debug("Added user message", message_count=len(conversation_state['messages']))
|
||||
logger.debug(
|
||||
"Added user message", message_count=len(conversation_state["messages"])
|
||||
)
|
||||
|
||||
response_text, response_id = await async_response(
|
||||
async_client,
|
||||
|
|
@ -280,7 +291,9 @@ async def async_chat(
|
|||
previous_response_id=previous_response_id,
|
||||
log_prefix="agent",
|
||||
)
|
||||
logger.debug("Got response", response_preview=response_text[:50], response_id=response_id)
|
||||
logger.debug(
|
||||
"Got response", response_preview=response_text[:50], response_id=response_id
|
||||
)
|
||||
|
||||
# Add assistant response to conversation with response_id and timestamp
|
||||
assistant_message = {
|
||||
|
|
@ -290,17 +303,26 @@ async def async_chat(
|
|||
"timestamp": datetime.now(),
|
||||
}
|
||||
conversation_state["messages"].append(assistant_message)
|
||||
logger.debug("Added assistant message", message_count=len(conversation_state['messages']))
|
||||
logger.debug(
|
||||
"Added assistant message", message_count=len(conversation_state["messages"])
|
||||
)
|
||||
|
||||
# Store the conversation thread with its response_id
|
||||
if response_id:
|
||||
conversation_state["last_activity"] = datetime.now()
|
||||
store_conversation_thread(user_id, response_id, conversation_state)
|
||||
logger.debug("Stored conversation thread", user_id=user_id, response_id=response_id)
|
||||
logger.debug(
|
||||
"Stored conversation thread", user_id=user_id, response_id=response_id
|
||||
)
|
||||
|
||||
# Debug: Check what's in user_conversations now
|
||||
conversations = get_user_conversations(user_id)
|
||||
logger.debug("User conversations updated", user_id=user_id, conversation_count=len(conversations), conversation_ids=list(conversations.keys()))
|
||||
logger.debug(
|
||||
"User conversations updated",
|
||||
user_id=user_id,
|
||||
conversation_count=len(conversations),
|
||||
conversation_ids=list(conversations.keys()),
|
||||
)
|
||||
else:
|
||||
logger.warning("No response_id received, conversation not stored")
|
||||
|
||||
|
|
@ -363,7 +385,9 @@ async def async_chat_stream(
|
|||
if response_id:
|
||||
conversation_state["last_activity"] = datetime.now()
|
||||
store_conversation_thread(user_id, response_id, conversation_state)
|
||||
logger.debug("Stored conversation thread", user_id=user_id, response_id=response_id)
|
||||
logger.debug(
|
||||
"Stored conversation thread", user_id=user_id, response_id=response_id
|
||||
)
|
||||
|
||||
|
||||
# Async langflow function with conversation storage (non-streaming)
|
||||
|
|
@ -375,18 +399,28 @@ async def async_langflow_chat(
|
|||
extra_headers: dict = None,
|
||||
previous_response_id: str = None,
|
||||
):
|
||||
logger.debug("async_langflow_chat called", user_id=user_id, previous_response_id=previous_response_id)
|
||||
logger.debug(
|
||||
"async_langflow_chat called",
|
||||
user_id=user_id,
|
||||
previous_response_id=previous_response_id,
|
||||
)
|
||||
|
||||
# Get the specific conversation thread (or create new one)
|
||||
conversation_state = get_conversation_thread(user_id, previous_response_id)
|
||||
logger.debug("Got langflow conversation state", message_count=len(conversation_state['messages']))
|
||||
logger.debug(
|
||||
"Got langflow conversation state",
|
||||
message_count=len(conversation_state["messages"]),
|
||||
)
|
||||
|
||||
# Add user message to conversation with timestamp
|
||||
from datetime import datetime
|
||||
|
||||
user_message = {"role": "user", "content": prompt, "timestamp": datetime.now()}
|
||||
conversation_state["messages"].append(user_message)
|
||||
logger.debug("Added user message to langflow", message_count=len(conversation_state['messages']))
|
||||
logger.debug(
|
||||
"Added user message to langflow",
|
||||
message_count=len(conversation_state["messages"]),
|
||||
)
|
||||
|
||||
response_text, response_id = await async_response(
|
||||
langflow_client,
|
||||
|
|
@ -396,7 +430,11 @@ async def async_langflow_chat(
|
|||
previous_response_id=previous_response_id,
|
||||
log_prefix="langflow",
|
||||
)
|
||||
logger.debug("Got langflow response", response_preview=response_text[:50], response_id=response_id)
|
||||
logger.debug(
|
||||
"Got langflow response",
|
||||
response_preview=response_text[:50],
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
# Add assistant response to conversation with response_id and timestamp
|
||||
assistant_message = {
|
||||
|
|
@ -406,17 +444,29 @@ async def async_langflow_chat(
|
|||
"timestamp": datetime.now(),
|
||||
}
|
||||
conversation_state["messages"].append(assistant_message)
|
||||
logger.debug("Added assistant message to langflow", message_count=len(conversation_state['messages']))
|
||||
logger.debug(
|
||||
"Added assistant message to langflow",
|
||||
message_count=len(conversation_state["messages"]),
|
||||
)
|
||||
|
||||
# Store the conversation thread with its response_id
|
||||
if response_id:
|
||||
conversation_state["last_activity"] = datetime.now()
|
||||
store_conversation_thread(user_id, response_id, conversation_state)
|
||||
logger.debug("Stored langflow conversation thread", user_id=user_id, response_id=response_id)
|
||||
logger.debug(
|
||||
"Stored langflow conversation thread",
|
||||
user_id=user_id,
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
# Debug: Check what's in user_conversations now
|
||||
conversations = get_user_conversations(user_id)
|
||||
logger.debug("User conversations updated", user_id=user_id, conversation_count=len(conversations), conversation_ids=list(conversations.keys()))
|
||||
logger.debug(
|
||||
"User conversations updated",
|
||||
user_id=user_id,
|
||||
conversation_count=len(conversations),
|
||||
conversation_ids=list(conversations.keys()),
|
||||
)
|
||||
else:
|
||||
logger.warning("No response_id received from langflow, conversation not stored")
|
||||
|
||||
|
|
@ -432,7 +482,11 @@ async def async_langflow_chat_stream(
|
|||
extra_headers: dict = None,
|
||||
previous_response_id: str = None,
|
||||
):
|
||||
logger.debug("async_langflow_chat_stream called", user_id=user_id, previous_response_id=previous_response_id)
|
||||
logger.debug(
|
||||
"async_langflow_chat_stream called",
|
||||
user_id=user_id,
|
||||
previous_response_id=previous_response_id,
|
||||
)
|
||||
|
||||
# Get the specific conversation thread (or create new one)
|
||||
conversation_state = get_conversation_thread(user_id, previous_response_id)
|
||||
|
|
@ -483,4 +537,8 @@ async def async_langflow_chat_stream(
|
|||
if response_id:
|
||||
conversation_state["last_activity"] = datetime.now()
|
||||
store_conversation_thread(user_id, response_id, conversation_state)
|
||||
logger.debug("Stored langflow conversation thread", user_id=user_id, response_id=response_id)
|
||||
logger.debug(
|
||||
"Stored langflow conversation thread",
|
||||
user_id=user_id,
|
||||
response_id=response_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,6 +25,11 @@ async def connector_sync(request: Request, connector_service, session_manager):
|
|||
selected_files = data.get("selected_files")
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
"Starting connector sync",
|
||||
connector_type=connector_type,
|
||||
max_files=max_files,
|
||||
)
|
||||
user = request.state.user
|
||||
jwt_token = request.state.jwt_token
|
||||
|
||||
|
|
@ -44,6 +49,10 @@ async def connector_sync(request: Request, connector_service, session_manager):
|
|||
# Start sync tasks for all active connections
|
||||
task_ids = []
|
||||
for connection in active_connections:
|
||||
logger.debug(
|
||||
"About to call sync_connector_files for connection",
|
||||
connection_id=connection.connection_id,
|
||||
)
|
||||
if selected_files:
|
||||
task_id = await connector_service.sync_specific_files(
|
||||
connection.connection_id,
|
||||
|
|
@ -58,8 +67,6 @@ async def connector_sync(request: Request, connector_service, session_manager):
|
|||
max_files,
|
||||
jwt_token=jwt_token,
|
||||
)
|
||||
task_ids.append(task_id)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"task_ids": task_ids,
|
||||
|
|
@ -170,7 +177,9 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
|||
channel_id = None
|
||||
|
||||
if not channel_id:
|
||||
logger.warning("No channel ID found in webhook", connector_type=connector_type)
|
||||
logger.warning(
|
||||
"No channel ID found in webhook", connector_type=connector_type
|
||||
)
|
||||
return JSONResponse({"status": "ignored", "reason": "no_channel_id"})
|
||||
|
||||
# Find the specific connection for this webhook
|
||||
|
|
@ -180,7 +189,9 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
|||
)
|
||||
)
|
||||
if not connection or not connection.is_active:
|
||||
logger.info("Unknown webhook channel, will auto-expire", channel_id=channel_id)
|
||||
logger.info(
|
||||
"Unknown webhook channel, will auto-expire", channel_id=channel_id
|
||||
)
|
||||
return JSONResponse(
|
||||
{"status": "ignored_unknown_channel", "channel_id": channel_id}
|
||||
)
|
||||
|
|
@ -190,7 +201,10 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
|||
# Get the connector instance
|
||||
connector = await connector_service._get_connector(connection.connection_id)
|
||||
if not connector:
|
||||
logger.error("Could not get connector for connection", connection_id=connection.connection_id)
|
||||
logger.error(
|
||||
"Could not get connector for connection",
|
||||
connection_id=connection.connection_id,
|
||||
)
|
||||
return JSONResponse(
|
||||
{"status": "error", "reason": "connector_not_found"}
|
||||
)
|
||||
|
|
@ -199,7 +213,11 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
|||
affected_files = await connector.handle_webhook(payload)
|
||||
|
||||
if affected_files:
|
||||
logger.info("Webhook connection files affected", connection_id=connection.connection_id, affected_count=len(affected_files))
|
||||
logger.info(
|
||||
"Webhook connection files affected",
|
||||
connection_id=connection.connection_id,
|
||||
affected_count=len(affected_files),
|
||||
)
|
||||
|
||||
# Generate JWT token for the user (needed for OpenSearch authentication)
|
||||
user = session_manager.get_user(connection.user_id)
|
||||
|
|
@ -223,7 +241,10 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
|||
}
|
||||
else:
|
||||
# No specific files identified - just log the webhook
|
||||
logger.info("Webhook general change detected, no specific files", connection_id=connection.connection_id)
|
||||
logger.info(
|
||||
"Webhook general change detected, no specific files",
|
||||
connection_id=connection.connection_id,
|
||||
)
|
||||
|
||||
result = {
|
||||
"connection_id": connection.connection_id,
|
||||
|
|
@ -241,7 +262,15 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to process webhook for connection", connection_id=connection.connection_id, error=str(e))
|
||||
logger.error(
|
||||
"Failed to process webhook for connection",
|
||||
connection_id=connection.connection_id,
|
||||
error=str(e),
|
||||
)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"status": "error",
|
||||
|
|
|
|||
|
|
@ -395,15 +395,19 @@ async def knowledge_filter_webhook(
|
|||
# Get the webhook payload
|
||||
payload = await request.json()
|
||||
|
||||
logger.info("Knowledge filter webhook received",
|
||||
filter_id=filter_id,
|
||||
subscription_id=subscription_id,
|
||||
payload_size=len(str(payload)))
|
||||
logger.info(
|
||||
"Knowledge filter webhook received",
|
||||
filter_id=filter_id,
|
||||
subscription_id=subscription_id,
|
||||
payload_size=len(str(payload)),
|
||||
)
|
||||
|
||||
# Extract findings from the payload
|
||||
findings = payload.get("findings", [])
|
||||
if not findings:
|
||||
logger.info("No findings in webhook payload", subscription_id=subscription_id)
|
||||
logger.info(
|
||||
"No findings in webhook payload", subscription_id=subscription_id
|
||||
)
|
||||
return JSONResponse({"status": "no_findings"})
|
||||
|
||||
# Process the findings - these are the documents that matched the knowledge filter
|
||||
|
|
@ -420,14 +424,18 @@ async def knowledge_filter_webhook(
|
|||
)
|
||||
|
||||
# Log the matched documents
|
||||
logger.info("Knowledge filter matched documents",
|
||||
filter_id=filter_id,
|
||||
matched_count=len(matched_documents))
|
||||
logger.info(
|
||||
"Knowledge filter matched documents",
|
||||
filter_id=filter_id,
|
||||
matched_count=len(matched_documents),
|
||||
)
|
||||
for doc in matched_documents:
|
||||
logger.debug("Matched document",
|
||||
document_id=doc['document_id'],
|
||||
index=doc['index'],
|
||||
score=doc.get('score'))
|
||||
logger.debug(
|
||||
"Matched document",
|
||||
document_id=doc["document_id"],
|
||||
index=doc["index"],
|
||||
score=doc.get("score"),
|
||||
)
|
||||
|
||||
# Here you could add additional processing:
|
||||
# - Send notifications to external webhooks
|
||||
|
|
@ -446,10 +454,12 @@ async def knowledge_filter_webhook(
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to process knowledge filter webhook",
|
||||
filter_id=filter_id,
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
logger.error(
|
||||
"Failed to process knowledge filter webhook",
|
||||
filter_id=filter_id,
|
||||
subscription_id=subscription_id,
|
||||
error=str(e),
|
||||
)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
|
|
|||
|
|
@ -23,14 +23,16 @@ async def search(request: Request, search_service, session_manager):
|
|||
# Extract JWT token from auth middleware
|
||||
jwt_token = request.state.jwt_token
|
||||
|
||||
logger.debug("Search API request",
|
||||
user=str(user),
|
||||
user_id=user.user_id if user else None,
|
||||
has_jwt_token=jwt_token is not None,
|
||||
query=query,
|
||||
filters=filters,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold)
|
||||
logger.debug(
|
||||
"Search API request",
|
||||
user=str(user),
|
||||
user_id=user.user_id if user else None,
|
||||
has_jwt_token=jwt_token is not None,
|
||||
query=query,
|
||||
filters=filters,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
result = await search_service.search(
|
||||
query,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ async def upload(request: Request, document_service, session_manager):
|
|||
jwt_token = request.state.jwt_token
|
||||
|
||||
from config.settings import is_no_auth_mode
|
||||
|
||||
|
||||
# In no-auth mode, pass None for owner fields so documents have no owner
|
||||
# This allows all users to see them when switching to auth mode
|
||||
if is_no_auth_mode():
|
||||
|
|
@ -25,7 +25,7 @@ async def upload(request: Request, document_service, session_manager):
|
|||
owner_user_id = user.user_id
|
||||
owner_name = user.name
|
||||
owner_email = user.email
|
||||
|
||||
|
||||
result = await document_service.process_upload_file(
|
||||
upload_file,
|
||||
owner_user_id=owner_user_id,
|
||||
|
|
@ -61,9 +61,9 @@ async def upload_path(request: Request, task_service, session_manager):
|
|||
|
||||
user = request.state.user
|
||||
jwt_token = request.state.jwt_token
|
||||
|
||||
|
||||
from config.settings import is_no_auth_mode
|
||||
|
||||
|
||||
# In no-auth mode, pass None for owner fields so documents have no owner
|
||||
if is_no_auth_mode():
|
||||
owner_user_id = None
|
||||
|
|
@ -73,7 +73,7 @@ async def upload_path(request: Request, task_service, session_manager):
|
|||
owner_user_id = user.user_id
|
||||
owner_name = user.name
|
||||
owner_email = user.email
|
||||
|
||||
|
||||
task_id = await task_service.create_upload_task(
|
||||
owner_user_id,
|
||||
file_paths,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@ from starlette.responses import JSONResponse
|
|||
from typing import Optional
|
||||
from session_manager import User
|
||||
from config.settings import is_no_auth_mode
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_current_user(request: Request, session_manager) -> Optional[User]:
|
||||
|
|
@ -25,22 +28,15 @@ def require_auth(session_manager):
|
|||
async def wrapper(request: Request):
|
||||
# In no-auth mode, bypass authentication entirely
|
||||
if is_no_auth_mode():
|
||||
print(f"[DEBUG] No-auth mode: Creating anonymous user")
|
||||
logger.debug("No-auth mode: Creating anonymous user")
|
||||
# Create an anonymous user object so endpoints don't break
|
||||
from session_manager import User
|
||||
from datetime import datetime
|
||||
|
||||
request.state.user = User(
|
||||
user_id="anonymous",
|
||||
email="anonymous@localhost",
|
||||
name="Anonymous User",
|
||||
picture=None,
|
||||
provider="none",
|
||||
created_at=datetime.now(),
|
||||
last_login=datetime.now(),
|
||||
)
|
||||
from session_manager import AnonymousUser
|
||||
request.state.user = AnonymousUser()
|
||||
request.state.jwt_token = None # No JWT in no-auth mode
|
||||
print(f"[DEBUG] Set user_id=anonymous, jwt_token=None")
|
||||
logger.debug("Set user_id=anonymous, jwt_token=None")
|
||||
return await handler(request)
|
||||
|
||||
user = get_current_user(request, session_manager)
|
||||
|
|
@ -72,15 +68,8 @@ def optional_auth(session_manager):
|
|||
from session_manager import User
|
||||
from datetime import datetime
|
||||
|
||||
request.state.user = User(
|
||||
user_id="anonymous",
|
||||
email="anonymous@localhost",
|
||||
name="Anonymous User",
|
||||
picture=None,
|
||||
provider="none",
|
||||
created_at=datetime.now(),
|
||||
last_login=datetime.now(),
|
||||
)
|
||||
from session_manager import AnonymousUser
|
||||
request.state.user = AnonymousUser()
|
||||
request.state.jwt_token = None # No JWT in no-auth mode
|
||||
else:
|
||||
user = get_current_user(request, session_manager)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,12 @@ GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
|||
def is_no_auth_mode():
|
||||
"""Check if we're running in no-auth mode (OAuth credentials missing)"""
|
||||
result = not (GOOGLE_OAUTH_CLIENT_ID and GOOGLE_OAUTH_CLIENT_SECRET)
|
||||
logger.debug("Checking auth mode", no_auth_mode=result, has_client_id=GOOGLE_OAUTH_CLIENT_ID is not None, has_client_secret=GOOGLE_OAUTH_CLIENT_SECRET is not None)
|
||||
logger.debug(
|
||||
"Checking auth mode",
|
||||
no_auth_mode=result,
|
||||
has_client_id=GOOGLE_OAUTH_CLIENT_ID is not None,
|
||||
has_client_secret=GOOGLE_OAUTH_CLIENT_SECRET is not None,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -99,7 +104,9 @@ async def generate_langflow_api_key():
|
|||
return LANGFLOW_KEY
|
||||
|
||||
if not LANGFLOW_SUPERUSER or not LANGFLOW_SUPERUSER_PASSWORD:
|
||||
logger.warning("LANGFLOW_SUPERUSER and LANGFLOW_SUPERUSER_PASSWORD not set, skipping API key generation")
|
||||
logger.warning(
|
||||
"LANGFLOW_SUPERUSER and LANGFLOW_SUPERUSER_PASSWORD not set, skipping API key generation"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
|
|
@ -141,11 +148,19 @@ async def generate_langflow_api_key():
|
|||
raise KeyError("api_key")
|
||||
|
||||
LANGFLOW_KEY = api_key
|
||||
logger.info("Successfully generated Langflow API key", api_key_preview=api_key[:8])
|
||||
logger.info(
|
||||
"Successfully generated Langflow API key",
|
||||
api_key_preview=api_key[:8],
|
||||
)
|
||||
return api_key
|
||||
except (requests.exceptions.RequestException, KeyError) as e:
|
||||
last_error = e
|
||||
logger.warning("Attempt to generate Langflow API key failed", attempt=attempt, max_attempts=max_attempts, error=str(e))
|
||||
logger.warning(
|
||||
"Attempt to generate Langflow API key failed",
|
||||
attempt=attempt,
|
||||
max_attempts=max_attempts,
|
||||
error=str(e),
|
||||
)
|
||||
if attempt < max_attempts:
|
||||
time.sleep(delay_seconds)
|
||||
else:
|
||||
|
|
@ -195,7 +210,9 @@ class AppClients:
|
|||
logger.warning("Failed to initialize Langflow client", error=str(e))
|
||||
self.langflow_client = None
|
||||
if self.langflow_client is None:
|
||||
logger.warning("No Langflow client initialized yet, will attempt later on first use")
|
||||
logger.warning(
|
||||
"No Langflow client initialized yet, will attempt later on first use"
|
||||
)
|
||||
|
||||
# Initialize patched OpenAI client
|
||||
self.patched_async_client = patch_openai_with_mcp(AsyncOpenAI())
|
||||
|
|
@ -218,7 +235,9 @@ class AppClients:
|
|||
)
|
||||
logger.info("Langflow client initialized on-demand")
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize Langflow client on-demand", error=str(e))
|
||||
logger.error(
|
||||
"Failed to initialize Langflow client on-demand", error=str(e)
|
||||
)
|
||||
self.langflow_client = None
|
||||
return self.langflow_client
|
||||
|
||||
|
|
|
|||
|
|
@ -321,13 +321,18 @@ class ConnectionManager:
|
|||
if connection_config.config.get(
|
||||
"webhook_channel_id"
|
||||
) or connection_config.config.get("subscription_id"):
|
||||
logger.info("Webhook subscription already exists", connection_id=connection_id)
|
||||
logger.info(
|
||||
"Webhook subscription already exists", connection_id=connection_id
|
||||
)
|
||||
return
|
||||
|
||||
# Check if webhook URL is configured
|
||||
webhook_url = connection_config.config.get("webhook_url")
|
||||
if not webhook_url:
|
||||
logger.info("No webhook URL configured, skipping subscription setup", connection_id=connection_id)
|
||||
logger.info(
|
||||
"No webhook URL configured, skipping subscription setup",
|
||||
connection_id=connection_id,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
|
|
@ -345,10 +350,18 @@ class ConnectionManager:
|
|||
# Save updated connection config
|
||||
await self.save_connections()
|
||||
|
||||
logger.info("Successfully set up webhook subscription", connection_id=connection_id, subscription_id=subscription_id)
|
||||
logger.info(
|
||||
"Successfully set up webhook subscription",
|
||||
connection_id=connection_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to setup webhook subscription", connection_id=connection_id, error=str(e))
|
||||
logger.error(
|
||||
"Failed to setup webhook subscription",
|
||||
connection_id=connection_id,
|
||||
error=str(e),
|
||||
)
|
||||
# Don't fail the entire connection setup if webhook fails
|
||||
|
||||
async def _setup_webhook_for_new_connection(
|
||||
|
|
@ -356,12 +369,18 @@ class ConnectionManager:
|
|||
):
|
||||
"""Setup webhook subscription for a newly authenticated connection"""
|
||||
try:
|
||||
logger.info("Setting up subscription for newly authenticated connection", connection_id=connection_id)
|
||||
logger.info(
|
||||
"Setting up subscription for newly authenticated connection",
|
||||
connection_id=connection_id,
|
||||
)
|
||||
|
||||
# Create and authenticate connector
|
||||
connector = self._create_connector(connection_config)
|
||||
if not await connector.authenticate():
|
||||
logger.error("Failed to authenticate connector for webhook setup", connection_id=connection_id)
|
||||
logger.error(
|
||||
"Failed to authenticate connector for webhook setup",
|
||||
connection_id=connection_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Setup subscription
|
||||
|
|
@ -376,8 +395,16 @@ class ConnectionManager:
|
|||
# Save updated connection config
|
||||
await self.save_connections()
|
||||
|
||||
logger.info("Successfully set up webhook subscription", connection_id=connection_id, subscription_id=subscription_id)
|
||||
logger.info(
|
||||
"Successfully set up webhook subscription",
|
||||
connection_id=connection_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to setup webhook subscription for new connection", connection_id=connection_id, error=str(e))
|
||||
logger.error(
|
||||
"Failed to setup webhook subscription for new connection",
|
||||
connection_id=connection_id,
|
||||
error=str(e),
|
||||
)
|
||||
# Don't fail the connection setup if webhook fails
|
||||
|
|
|
|||
|
|
@ -4,6 +4,11 @@ from typing import Dict, Any, List, Optional
|
|||
|
||||
from .base import BaseConnector, ConnectorDocument
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
from .google_drive import GoogleDriveConnector
|
||||
from .sharepoint import SharePointConnector
|
||||
from .onedrive import OneDriveConnector
|
||||
from .connection_manager import ConnectionManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -62,7 +67,7 @@ class ConnectorService:
|
|||
|
||||
doc_service = DocumentService(session_manager=self.session_manager)
|
||||
|
||||
print(f"[DEBUG] Processing connector document with ID: {document.id}")
|
||||
logger.debug("Processing connector document", document_id=document.id)
|
||||
|
||||
# Process using the existing pipeline but with connector document metadata
|
||||
result = await doc_service.process_file_common(
|
||||
|
|
@ -77,7 +82,7 @@ class ConnectorService:
|
|||
connector_type=connector_type,
|
||||
)
|
||||
|
||||
print(f"[DEBUG] Document processing result: {result}")
|
||||
logger.debug("Document processing result", result=result)
|
||||
|
||||
# If successfully indexed or already exists, update the indexed documents with connector metadata
|
||||
if result["status"] in ["indexed", "unchanged"]:
|
||||
|
|
@ -104,7 +109,7 @@ class ConnectorService:
|
|||
jwt_token: str = None,
|
||||
):
|
||||
"""Update indexed chunks with connector-specific metadata"""
|
||||
print(f"[DEBUG] Looking for chunks with document_id: {document.id}")
|
||||
logger.debug("Looking for chunks", document_id=document.id)
|
||||
|
||||
# Find all chunks for this document
|
||||
query = {"query": {"term": {"document_id": document.id}}}
|
||||
|
|
@ -117,26 +122,34 @@ class ConnectorService:
|
|||
try:
|
||||
response = await opensearch_client.search(index=self.index_name, body=query)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[ERROR] OpenSearch search failed for connector metadata update: {e}"
|
||||
logger.error(
|
||||
"OpenSearch search failed for connector metadata update",
|
||||
error=str(e),
|
||||
query=query,
|
||||
)
|
||||
print(f"[ERROR] Search query: {query}")
|
||||
raise
|
||||
|
||||
print(f"[DEBUG] Search query: {query}")
|
||||
print(
|
||||
f"[DEBUG] Found {len(response['hits']['hits'])} chunks matching document_id: {document.id}"
|
||||
logger.debug(
|
||||
"Search query executed",
|
||||
query=query,
|
||||
chunks_found=len(response["hits"]["hits"]),
|
||||
document_id=document.id,
|
||||
)
|
||||
|
||||
# Update each chunk with connector metadata
|
||||
print(
|
||||
f"[DEBUG] Updating {len(response['hits']['hits'])} chunks with connector_type: {connector_type}"
|
||||
logger.debug(
|
||||
"Updating chunks with connector_type",
|
||||
chunk_count=len(response["hits"]["hits"]),
|
||||
connector_type=connector_type,
|
||||
)
|
||||
for hit in response["hits"]["hits"]:
|
||||
chunk_id = hit["_id"]
|
||||
current_connector_type = hit["_source"].get("connector_type", "unknown")
|
||||
print(
|
||||
f"[DEBUG] Chunk {chunk_id}: current connector_type = {current_connector_type}, updating to {connector_type}"
|
||||
logger.debug(
|
||||
"Updating chunk connector metadata",
|
||||
chunk_id=chunk_id,
|
||||
current_connector_type=current_connector_type,
|
||||
new_connector_type=connector_type,
|
||||
)
|
||||
|
||||
update_body = {
|
||||
|
|
@ -164,10 +177,14 @@ class ConnectorService:
|
|||
await opensearch_client.update(
|
||||
index=self.index_name, id=chunk_id, body=update_body
|
||||
)
|
||||
print(f"[DEBUG] Updated chunk {chunk_id} with connector metadata")
|
||||
logger.debug("Updated chunk with connector metadata", chunk_id=chunk_id)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] OpenSearch update failed for chunk {chunk_id}: {e}")
|
||||
print(f"[ERROR] Update body: {update_body}")
|
||||
logger.error(
|
||||
"OpenSearch update failed for chunk",
|
||||
chunk_id=chunk_id,
|
||||
error=str(e),
|
||||
update_body=update_body,
|
||||
)
|
||||
raise
|
||||
|
||||
def _get_file_extension(self, mimetype: str) -> str:
|
||||
|
|
@ -226,11 +243,11 @@ class ConnectorService:
|
|||
|
||||
while True:
|
||||
# List files from connector with limit
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"Calling list_files", page_size=page_size, page_token=page_token
|
||||
)
|
||||
file_list = await connector.list_files(page_token, max_files=page_size)
|
||||
logger.info(
|
||||
file_list = await connector.list_files(page_token, limit=page_size)
|
||||
logger.debug(
|
||||
"Got files from connector", file_count=len(file_list.get("files", []))
|
||||
)
|
||||
files = file_list["files"]
|
||||
|
|
|
|||
117
src/main.py
117
src/main.py
|
|
@ -3,11 +3,13 @@ import sys
|
|||
# Check for TUI flag FIRST, before any heavy imports
|
||||
if __name__ == "__main__" and len(sys.argv) > 1 and sys.argv[1] == "--tui":
|
||||
from tui.main import run_tui
|
||||
|
||||
run_tui()
|
||||
sys.exit(0)
|
||||
|
||||
# Configure structured logging early
|
||||
from utils.logging_config import configure_from_env, get_logger
|
||||
|
||||
configure_from_env()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
@ -25,6 +27,8 @@ import torch
|
|||
|
||||
# Configuration and setup
|
||||
from config.settings import clients, INDEX_NAME, INDEX_BODY, SESSION_SECRET
|
||||
from config.settings import is_no_auth_mode
|
||||
from utils.gpu_detection import detect_gpu_devices
|
||||
|
||||
# Services
|
||||
from services.document_service import DocumentService
|
||||
|
|
@ -56,8 +60,11 @@ from api import (
|
|||
# Set multiprocessing start method to 'spawn' for CUDA compatibility
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
logger.info("CUDA available", cuda_available=torch.cuda.is_available())
|
||||
logger.info("CUDA version PyTorch was built with", cuda_version=torch.version.cuda)
|
||||
logger.info(
|
||||
"CUDA device information",
|
||||
cuda_available=torch.cuda.is_available(),
|
||||
cuda_version=torch.version.cuda,
|
||||
)
|
||||
|
||||
|
||||
async def wait_for_opensearch():
|
||||
|
|
@ -71,7 +78,12 @@ async def wait_for_opensearch():
|
|||
logger.info("OpenSearch is ready")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning("OpenSearch not ready yet", attempt=attempt + 1, max_retries=max_retries, error=str(e))
|
||||
logger.warning(
|
||||
"OpenSearch not ready yet",
|
||||
attempt=attempt + 1,
|
||||
max_retries=max_retries,
|
||||
error=str(e),
|
||||
)
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
|
|
@ -93,7 +105,9 @@ async def configure_alerting_security():
|
|||
|
||||
# Use admin client (clients.opensearch uses admin credentials)
|
||||
response = await clients.opensearch.cluster.put_settings(body=alerting_settings)
|
||||
logger.info("Alerting security settings configured successfully", response=response)
|
||||
logger.info(
|
||||
"Alerting security settings configured successfully", response=response
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to configure alerting security settings", error=str(e))
|
||||
# Don't fail startup if alerting config fails
|
||||
|
|
@ -133,9 +147,14 @@ async def init_index():
|
|||
await clients.opensearch.indices.create(
|
||||
index=knowledge_filter_index_name, body=knowledge_filter_index_body
|
||||
)
|
||||
logger.info("Created knowledge filters index", index_name=knowledge_filter_index_name)
|
||||
logger.info(
|
||||
"Created knowledge filters index", index_name=knowledge_filter_index_name
|
||||
)
|
||||
else:
|
||||
logger.info("Knowledge filters index already exists, skipping creation", index_name=knowledge_filter_index_name)
|
||||
logger.info(
|
||||
"Knowledge filters index already exists, skipping creation",
|
||||
index_name=knowledge_filter_index_name,
|
||||
)
|
||||
|
||||
# Configure alerting plugin security settings
|
||||
await configure_alerting_security()
|
||||
|
|
@ -190,9 +209,59 @@ async def init_index_when_ready():
|
|||
logger.info("OpenSearch index initialization completed successfully")
|
||||
except Exception as e:
|
||||
logger.error("OpenSearch index initialization failed", error=str(e))
|
||||
logger.warning("OIDC endpoints will still work, but document operations may fail until OpenSearch is ready")
|
||||
logger.warning(
|
||||
"OIDC endpoints will still work, but document operations may fail until OpenSearch is ready"
|
||||
)
|
||||
|
||||
|
||||
async def ingest_default_documents_when_ready(services):
|
||||
"""Scan the local documents folder and ingest files like a non-auth upload."""
|
||||
try:
|
||||
logger.info("Ingesting default documents when ready")
|
||||
base_dir = os.path.abspath(os.path.join(os.getcwd(), "documents"))
|
||||
if not os.path.isdir(base_dir):
|
||||
logger.info("Default documents directory not found; skipping ingestion", base_dir=base_dir)
|
||||
return
|
||||
|
||||
# Collect files recursively
|
||||
file_paths = [
|
||||
os.path.join(root, fn)
|
||||
for root, _, files in os.walk(base_dir)
|
||||
for fn in files
|
||||
]
|
||||
|
||||
if not file_paths:
|
||||
logger.info("No default documents found; nothing to ingest", base_dir=base_dir)
|
||||
return
|
||||
|
||||
# Build a processor that DOES NOT set 'owner' on documents (owner_user_id=None)
|
||||
from models.processors import DocumentFileProcessor
|
||||
|
||||
processor = DocumentFileProcessor(
|
||||
services["document_service"],
|
||||
owner_user_id=None,
|
||||
jwt_token=None,
|
||||
owner_name=None,
|
||||
owner_email=None,
|
||||
)
|
||||
|
||||
task_id = await services["task_service"].create_custom_task(
|
||||
"anonymous", file_paths, processor
|
||||
)
|
||||
logger.info(
|
||||
"Started default documents ingestion task",
|
||||
task_id=task_id,
|
||||
file_count=len(file_paths),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Default documents ingestion failed", error=str(e))
|
||||
|
||||
async def startup_tasks(services):
|
||||
"""Startup tasks"""
|
||||
logger.info("Starting startup tasks")
|
||||
await init_index()
|
||||
await ingest_default_documents_when_ready(services)
|
||||
|
||||
async def initialize_services():
|
||||
"""Initialize all services and their dependencies"""
|
||||
# Generate JWT keys if they don't exist
|
||||
|
|
@ -237,9 +306,14 @@ async def initialize_services():
|
|||
try:
|
||||
await connector_service.initialize()
|
||||
loaded_count = len(connector_service.connection_manager.connections)
|
||||
logger.info("Loaded persisted connector connections on startup", loaded_count=loaded_count)
|
||||
logger.info(
|
||||
"Loaded persisted connector connections on startup",
|
||||
loaded_count=loaded_count,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load persisted connections on startup", error=str(e))
|
||||
logger.warning(
|
||||
"Failed to load persisted connections on startup", error=str(e)
|
||||
)
|
||||
else:
|
||||
logger.info("[CONNECTORS] Skipping connection loading in no-auth mode")
|
||||
|
||||
|
|
@ -639,12 +713,15 @@ async def create_app():
|
|||
|
||||
app = Starlette(debug=True, routes=routes)
|
||||
app.state.services = services # Store services for cleanup
|
||||
app.state.background_tasks = set()
|
||||
|
||||
# Add startup event handler
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
# Start index initialization in background to avoid blocking OIDC endpoints
|
||||
asyncio.create_task(init_index_when_ready())
|
||||
t1 = asyncio.create_task(startup_tasks(services))
|
||||
app.state.background_tasks.add(t1)
|
||||
t1.add_done_callback(app.state.background_tasks.discard)
|
||||
|
||||
# Add shutdown event handler
|
||||
@app.on_event("shutdown")
|
||||
|
|
@ -687,18 +764,30 @@ async def cleanup_subscriptions_proper(services):
|
|||
|
||||
for connection in active_connections:
|
||||
try:
|
||||
logger.info("Cancelling subscription for connection", connection_id=connection.connection_id)
|
||||
logger.info(
|
||||
"Cancelling subscription for connection",
|
||||
connection_id=connection.connection_id,
|
||||
)
|
||||
connector = await connector_service.get_connector(
|
||||
connection.connection_id
|
||||
)
|
||||
if connector:
|
||||
subscription_id = connection.config.get("webhook_channel_id")
|
||||
await connector.cleanup_subscription(subscription_id)
|
||||
logger.info("Cancelled subscription", subscription_id=subscription_id)
|
||||
logger.info(
|
||||
"Cancelled subscription", subscription_id=subscription_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel subscription", connection_id=connection.connection_id, error=str(e))
|
||||
logger.error(
|
||||
"Failed to cancel subscription",
|
||||
connection_id=connection.connection_id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
logger.info("Finished cancelling subscriptions", subscription_count=len(active_connections))
|
||||
logger.info(
|
||||
"Finished cancelling subscriptions",
|
||||
subscription_count=len(active_connections),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup subscriptions", error=str(e))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
from .tasks import UploadTask, FileTask
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TaskProcessor(ABC):
|
||||
|
|
@ -211,7 +214,7 @@ class S3FileProcessor(TaskProcessor):
|
|||
"connector_type": "s3", # S3 uploads
|
||||
"indexed_time": datetime.datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# Only set owner fields if owner_user_id is provided (for no-auth mode support)
|
||||
if self.owner_user_id is not None:
|
||||
chunk_doc["owner"] = self.owner_user_id
|
||||
|
|
@ -225,10 +228,12 @@ class S3FileProcessor(TaskProcessor):
|
|||
index=INDEX_NAME, id=chunk_id, body=chunk_doc
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[ERROR] OpenSearch indexing failed for S3 chunk {chunk_id}: {e}"
|
||||
logger.error(
|
||||
"OpenSearch indexing failed for S3 chunk",
|
||||
chunk_id=chunk_id,
|
||||
error=str(e),
|
||||
chunk_doc=chunk_doc,
|
||||
)
|
||||
print(f"[ERROR] Chunk document: {chunk_doc}")
|
||||
raise
|
||||
|
||||
result = {"status": "indexed", "id": slim_doc["id"]}
|
||||
|
|
|
|||
|
|
@ -111,7 +111,10 @@ class ChatService:
|
|||
|
||||
# Pass the complete filter expression as a single header to Langflow (only if we have something to send)
|
||||
if filter_expression:
|
||||
logger.info("Sending OpenRAG query filter to Langflow", filter_expression=filter_expression)
|
||||
logger.info(
|
||||
"Sending OpenRAG query filter to Langflow",
|
||||
filter_expression=filter_expression,
|
||||
)
|
||||
extra_headers["X-LANGFLOW-GLOBAL-VAR-OPENRAG-QUERY-FILTER"] = json.dumps(
|
||||
filter_expression
|
||||
)
|
||||
|
|
@ -201,7 +204,11 @@ class ChatService:
|
|||
return {"error": "User ID is required", "conversations": []}
|
||||
|
||||
conversations_dict = get_user_conversations(user_id)
|
||||
logger.debug("Getting chat history for user", user_id=user_id, conversation_count=len(conversations_dict))
|
||||
logger.debug(
|
||||
"Getting chat history for user",
|
||||
user_id=user_id,
|
||||
conversation_count=len(conversations_dict),
|
||||
)
|
||||
|
||||
# Convert conversations dict to list format with metadata
|
||||
conversations = []
|
||||
|
|
|
|||
|
|
@ -196,7 +196,11 @@ class DocumentService:
|
|||
index=INDEX_NAME, id=chunk_id, body=chunk_doc
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("OpenSearch indexing failed for chunk", chunk_id=chunk_id, error=str(e))
|
||||
logger.error(
|
||||
"OpenSearch indexing failed for chunk",
|
||||
chunk_id=chunk_id,
|
||||
error=str(e),
|
||||
)
|
||||
logger.error("Chunk document details", chunk_doc=chunk_doc)
|
||||
raise
|
||||
return {"status": "indexed", "id": file_hash}
|
||||
|
|
@ -232,7 +236,9 @@ class DocumentService:
|
|||
try:
|
||||
exists = await opensearch_client.exists(index=INDEX_NAME, id=file_hash)
|
||||
except Exception as e:
|
||||
logger.error("OpenSearch exists check failed", file_hash=file_hash, error=str(e))
|
||||
logger.error(
|
||||
"OpenSearch exists check failed", file_hash=file_hash, error=str(e)
|
||||
)
|
||||
raise
|
||||
if exists:
|
||||
return {"status": "unchanged", "id": file_hash}
|
||||
|
|
@ -372,7 +378,11 @@ class DocumentService:
|
|||
index=INDEX_NAME, id=chunk_id, body=chunk_doc
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("OpenSearch indexing failed for batch chunk", chunk_id=chunk_id, error=str(e))
|
||||
logger.error(
|
||||
"OpenSearch indexing failed for batch chunk",
|
||||
chunk_id=chunk_id,
|
||||
error=str(e),
|
||||
)
|
||||
logger.error("Chunk document details", chunk_doc=chunk_doc)
|
||||
raise
|
||||
|
||||
|
|
@ -388,9 +398,13 @@ class DocumentService:
|
|||
from concurrent.futures import BrokenExecutor
|
||||
|
||||
if isinstance(e, BrokenExecutor):
|
||||
logger.error("Process pool broken while processing file", file_path=file_path)
|
||||
logger.error(
|
||||
"Process pool broken while processing file", file_path=file_path
|
||||
)
|
||||
logger.info("Worker process likely crashed")
|
||||
logger.info("You should see detailed crash logs above from the worker process")
|
||||
logger.info(
|
||||
"You should see detailed crash logs above from the worker process"
|
||||
)
|
||||
|
||||
# Mark pool as broken for potential recreation
|
||||
self._process_pool_broken = True
|
||||
|
|
@ -399,11 +413,15 @@ class DocumentService:
|
|||
if self._recreate_process_pool():
|
||||
logger.info("Process pool successfully recreated")
|
||||
else:
|
||||
logger.warning("Failed to recreate process pool - future operations may fail")
|
||||
logger.warning(
|
||||
"Failed to recreate process pool - future operations may fail"
|
||||
)
|
||||
|
||||
file_task.error = f"Worker process crashed: {str(e)}"
|
||||
else:
|
||||
logger.error("Failed to process file", file_path=file_path, error=str(e))
|
||||
logger.error(
|
||||
"Failed to process file", file_path=file_path, error=str(e)
|
||||
)
|
||||
file_task.error = str(e)
|
||||
|
||||
logger.error("Full traceback available")
|
||||
|
|
|
|||
|
|
@ -195,7 +195,9 @@ class MonitorService:
|
|||
return monitors
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error listing monitors for user", user_id=user_id, error=str(e))
|
||||
logger.error(
|
||||
"Error listing monitors for user", user_id=user_id, error=str(e)
|
||||
)
|
||||
return []
|
||||
|
||||
async def list_monitors_for_filter(
|
||||
|
|
@ -236,7 +238,9 @@ class MonitorService:
|
|||
return monitors
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error listing monitors for filter", filter_id=filter_id, error=str(e))
|
||||
logger.error(
|
||||
"Error listing monitors for filter", filter_id=filter_id, error=str(e)
|
||||
)
|
||||
return []
|
||||
|
||||
async def _get_or_create_webhook_destination(
|
||||
|
|
|
|||
|
|
@ -138,7 +138,11 @@ class SearchService:
|
|||
search_body["min_score"] = score_threshold
|
||||
|
||||
# Authentication required - DLS will handle document filtering automatically
|
||||
logger.debug("search_service authentication info", user_id=user_id, has_jwt_token=jwt_token is not None)
|
||||
logger.debug(
|
||||
"search_service authentication info",
|
||||
user_id=user_id,
|
||||
has_jwt_token=jwt_token is not None,
|
||||
)
|
||||
if not user_id:
|
||||
logger.debug("search_service: user_id is None/empty, returning auth error")
|
||||
return {"results": [], "error": "Authentication required"}
|
||||
|
|
@ -151,7 +155,9 @@ class SearchService:
|
|||
try:
|
||||
results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
|
||||
except Exception as e:
|
||||
logger.error("OpenSearch query failed", error=str(e), search_body=search_body)
|
||||
logger.error(
|
||||
"OpenSearch query failed", error=str(e), search_body=search_body
|
||||
)
|
||||
# Re-raise the exception so the API returns the error to frontend
|
||||
raise
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,14 @@ import asyncio
|
|||
import uuid
|
||||
import time
|
||||
import random
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from models.tasks import TaskStatus, UploadTask, FileTask
|
||||
|
||||
from session_manager import AnonymousUser
|
||||
from src.utils.gpu_detection import get_worker_count
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TaskService:
|
||||
|
|
@ -104,7 +107,9 @@ class TaskService:
|
|||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Background upload processor failed for task {task_id}: {e}")
|
||||
logger.error(
|
||||
"Background upload processor failed", task_id=task_id, error=str(e)
|
||||
)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
|
@ -136,7 +141,9 @@ class TaskService:
|
|||
try:
|
||||
await processor.process_item(upload_task, item, file_task)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to process item {item}: {e}")
|
||||
logger.error(
|
||||
"Failed to process item", item=str(item), error=str(e)
|
||||
)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
|
@ -157,13 +164,15 @@ class TaskService:
|
|||
upload_task.updated_at = time.time()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
print(f"[INFO] Background processor for task {task_id} was cancelled")
|
||||
logger.info("Background processor cancelled", task_id=task_id)
|
||||
if user_id in self.task_store and task_id in self.task_store[user_id]:
|
||||
# Task status and pending files already handled by cancel_task()
|
||||
pass
|
||||
raise # Re-raise to properly handle cancellation
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Background custom processor failed for task {task_id}: {e}")
|
||||
logger.error(
|
||||
"Background custom processor failed", task_id=task_id, error=str(e)
|
||||
)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
|
@ -171,16 +180,29 @@ class TaskService:
|
|||
self.task_store[user_id][task_id].status = TaskStatus.FAILED
|
||||
self.task_store[user_id][task_id].updated_at = time.time()
|
||||
|
||||
def get_task_status(self, user_id: str, task_id: str) -> dict:
|
||||
"""Get the status of a specific upload task"""
|
||||
if (
|
||||
not task_id
|
||||
or user_id not in self.task_store
|
||||
or task_id not in self.task_store[user_id]
|
||||
):
|
||||
def get_task_status(self, user_id: str, task_id: str) -> Optional[dict]:
|
||||
"""Get the status of a specific upload task
|
||||
|
||||
Includes fallback to shared tasks stored under the "anonymous" user key
|
||||
so default system tasks are visible to all users.
|
||||
"""
|
||||
if not task_id:
|
||||
return None
|
||||
|
||||
upload_task = self.task_store[user_id][task_id]
|
||||
# Prefer the caller's user_id; otherwise check shared/anonymous tasks
|
||||
candidate_user_ids = [user_id, AnonymousUser().user_id]
|
||||
|
||||
upload_task = None
|
||||
for candidate_user_id in candidate_user_ids:
|
||||
if (
|
||||
candidate_user_id in self.task_store
|
||||
and task_id in self.task_store[candidate_user_id]
|
||||
):
|
||||
upload_task = self.task_store[candidate_user_id][task_id]
|
||||
break
|
||||
|
||||
if upload_task is None:
|
||||
return None
|
||||
|
||||
file_statuses = {}
|
||||
for file_path, file_task in upload_task.file_tasks.items():
|
||||
|
|
@ -206,14 +228,21 @@ class TaskService:
|
|||
}
|
||||
|
||||
def get_all_tasks(self, user_id: str) -> list:
|
||||
"""Get all tasks for a user"""
|
||||
if user_id not in self.task_store:
|
||||
return []
|
||||
"""Get all tasks for a user
|
||||
|
||||
tasks = []
|
||||
for task_id, upload_task in self.task_store[user_id].items():
|
||||
tasks.append(
|
||||
{
|
||||
Returns the union of the user's own tasks and shared default tasks stored
|
||||
under the "anonymous" user key. User-owned tasks take precedence
|
||||
if a task_id overlaps.
|
||||
"""
|
||||
tasks_by_id = {}
|
||||
|
||||
def add_tasks_from_store(store_user_id):
|
||||
if store_user_id not in self.task_store:
|
||||
return
|
||||
for task_id, upload_task in self.task_store[store_user_id].items():
|
||||
if task_id in tasks_by_id:
|
||||
continue
|
||||
tasks_by_id[task_id] = {
|
||||
"task_id": upload_task.task_id,
|
||||
"status": upload_task.status.value,
|
||||
"total_files": upload_task.total_files,
|
||||
|
|
@ -223,18 +252,36 @@ class TaskService:
|
|||
"created_at": upload_task.created_at,
|
||||
"updated_at": upload_task.updated_at,
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by creation time, most recent first
|
||||
# First, add user-owned tasks; then shared anonymous;
|
||||
add_tasks_from_store(user_id)
|
||||
add_tasks_from_store(AnonymousUser().user_id)
|
||||
|
||||
tasks = list(tasks_by_id.values())
|
||||
tasks.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
return tasks
|
||||
|
||||
def cancel_task(self, user_id: str, task_id: str) -> bool:
|
||||
"""Cancel a task if it exists and is not already completed"""
|
||||
if user_id not in self.task_store or task_id not in self.task_store[user_id]:
|
||||
"""Cancel a task if it exists and is not already completed.
|
||||
|
||||
Supports cancellation of shared default tasks stored under the anonymous user.
|
||||
"""
|
||||
# Check candidate user IDs first, then anonymous to find which user ID the task is mapped to
|
||||
candidate_user_ids = [user_id, AnonymousUser().user_id]
|
||||
|
||||
store_user_id = None
|
||||
for candidate_user_id in candidate_user_ids:
|
||||
if (
|
||||
candidate_user_id in self.task_store
|
||||
and task_id in self.task_store[candidate_user_id]
|
||||
):
|
||||
store_user_id = candidate_user_id
|
||||
break
|
||||
|
||||
if store_user_id is None:
|
||||
return False
|
||||
|
||||
upload_task = self.task_store[user_id][task_id]
|
||||
upload_task = self.task_store[store_user_id][task_id]
|
||||
|
||||
# Can only cancel pending or running tasks
|
||||
if upload_task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
|
||||
|
|
|
|||
|
|
@ -6,8 +6,13 @@ from typing import Dict, Optional, Any
|
|||
from dataclasses import dataclass, asdict
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
import os
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@dataclass
|
||||
class User:
|
||||
"""User information from OAuth provider"""
|
||||
|
|
@ -26,6 +31,19 @@ class User:
|
|||
if self.last_login is None:
|
||||
self.last_login = datetime.now()
|
||||
|
||||
class AnonymousUser(User):
|
||||
"""Anonymous user"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
user_id="anonymous",
|
||||
email="anonymous@localhost",
|
||||
name="Anonymous User",
|
||||
picture=None,
|
||||
provider="none",
|
||||
)
|
||||
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Manages user sessions and JWT tokens"""
|
||||
|
|
@ -80,13 +98,15 @@ class SessionManager:
|
|||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
print(
|
||||
f"Failed to get user info: {response.status_code} {response.text}"
|
||||
logger.error(
|
||||
"Failed to get user info",
|
||||
status_code=response.status_code,
|
||||
response_text=response.text,
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting user info: {e}")
|
||||
logger.error("Error getting user info", error=str(e))
|
||||
return None
|
||||
|
||||
async def create_user_session(
|
||||
|
|
@ -173,19 +193,24 @@ class SessionManager:
|
|||
"""Get or create OpenSearch client for user with their JWT"""
|
||||
from config.settings import is_no_auth_mode
|
||||
|
||||
print(
|
||||
f"[DEBUG] get_user_opensearch_client: user_id={user_id}, jwt_token={'None' if jwt_token is None else 'present'}, no_auth_mode={is_no_auth_mode()}"
|
||||
logger.debug(
|
||||
"get_user_opensearch_client",
|
||||
user_id=user_id,
|
||||
jwt_token_present=(jwt_token is not None),
|
||||
no_auth_mode=is_no_auth_mode(),
|
||||
)
|
||||
|
||||
# In no-auth mode, create anonymous JWT for OpenSearch DLS
|
||||
if is_no_auth_mode() and jwt_token is None:
|
||||
if jwt_token is None and (is_no_auth_mode() or user_id in (None, AnonymousUser().user_id)):
|
||||
if not hasattr(self, "_anonymous_jwt"):
|
||||
# Create anonymous JWT token for OpenSearch OIDC
|
||||
print(f"[DEBUG] Creating anonymous JWT...")
|
||||
logger.debug("Creating anonymous JWT")
|
||||
self._anonymous_jwt = self._create_anonymous_jwt()
|
||||
print(f"[DEBUG] Anonymous JWT created: {self._anonymous_jwt[:50]}...")
|
||||
logger.debug(
|
||||
"Anonymous JWT created", jwt_prefix=self._anonymous_jwt[:50]
|
||||
)
|
||||
jwt_token = self._anonymous_jwt
|
||||
print(f"[DEBUG] Using anonymous JWT for OpenSearch")
|
||||
logger.debug("Using anonymous JWT for OpenSearch")
|
||||
|
||||
# Check if we have a cached client for this user
|
||||
if user_id not in self.user_opensearch_clients:
|
||||
|
|
@ -199,14 +224,5 @@ class SessionManager:
|
|||
|
||||
def _create_anonymous_jwt(self) -> str:
|
||||
"""Create JWT token for anonymous user in no-auth mode"""
|
||||
anonymous_user = User(
|
||||
user_id="anonymous",
|
||||
email="anonymous@localhost",
|
||||
name="Anonymous User",
|
||||
picture=None,
|
||||
provider="none",
|
||||
created_at=datetime.now(),
|
||||
last_login=datetime.now(),
|
||||
)
|
||||
|
||||
anonymous_user = AnonymousUser()
|
||||
return self.create_jwt_token(anonymous_user)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
"""OpenRAG Terminal User Interface package."""
|
||||
"""OpenRAG Terminal User Interface package."""
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@
|
|||
import sys
|
||||
from pathlib import Path
|
||||
from textual.app import App, ComposeResult
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
from .screens.welcome import WelcomeScreen
|
||||
from .screens.config import ConfigScreen
|
||||
|
|
@ -17,10 +20,10 @@ from .widgets.diagnostics_notification import notify_with_diagnostics
|
|||
|
||||
class OpenRAGTUI(App):
|
||||
"""OpenRAG Terminal User Interface application."""
|
||||
|
||||
|
||||
TITLE = "OpenRAG TUI"
|
||||
SUB_TITLE = "Container Management & Configuration"
|
||||
|
||||
|
||||
CSS = """
|
||||
Screen {
|
||||
background: $background;
|
||||
|
|
@ -172,13 +175,13 @@ class OpenRAGTUI(App):
|
|||
padding: 1;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.platform_detector = PlatformDetector()
|
||||
self.container_manager = ContainerManager()
|
||||
self.env_manager = EnvManager()
|
||||
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""Initialize the application."""
|
||||
# Check for runtime availability and show appropriate screen
|
||||
|
|
@ -187,31 +190,33 @@ class OpenRAGTUI(App):
|
|||
self,
|
||||
"No container runtime found. Please install Docker or Podman.",
|
||||
severity="warning",
|
||||
timeout=10
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
|
||||
# Load existing config if available
|
||||
config_exists = self.env_manager.load_existing_env()
|
||||
|
||||
|
||||
# Start with welcome screen
|
||||
self.push_screen(WelcomeScreen())
|
||||
|
||||
|
||||
async def action_quit(self) -> None:
|
||||
"""Quit the application."""
|
||||
self.exit()
|
||||
|
||||
|
||||
def check_runtime_requirements(self) -> tuple[bool, str]:
|
||||
"""Check if runtime requirements are met."""
|
||||
if not self.container_manager.is_available():
|
||||
return False, self.platform_detector.get_installation_instructions()
|
||||
|
||||
|
||||
# Check Podman macOS memory if applicable
|
||||
runtime_info = self.container_manager.get_runtime_info()
|
||||
if runtime_info.runtime_type.value == "podman":
|
||||
is_sufficient, _, message = self.platform_detector.check_podman_macos_memory()
|
||||
is_sufficient, _, message = (
|
||||
self.platform_detector.check_podman_macos_memory()
|
||||
)
|
||||
if not is_sufficient:
|
||||
return False, f"Podman VM memory insufficient:\n{message}"
|
||||
|
||||
|
||||
return True, "Runtime requirements satisfied"
|
||||
|
||||
|
||||
|
|
@ -221,10 +226,10 @@ def run_tui():
|
|||
app = OpenRAGTUI()
|
||||
app.run()
|
||||
except KeyboardInterrupt:
|
||||
print("\nOpenRAG TUI interrupted by user")
|
||||
logger.info("OpenRAG TUI interrupted by user")
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
print(f"Error running OpenRAG TUI: {e}")
|
||||
logger.error("Error running OpenRAG TUI", error=str(e))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
"""TUI managers package."""
|
||||
"""TUI managers package."""
|
||||
|
|
|
|||
|
|
@ -8,6 +8,9 @@ from dataclasses import dataclass, field
|
|||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, AsyncIterator
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
from ..utils.platform import PlatformDetector, RuntimeInfo, RuntimeType
|
||||
from utils.gpu_detection import detect_gpu_devices
|
||||
|
|
@ -15,6 +18,7 @@ from utils.gpu_detection import detect_gpu_devices
|
|||
|
||||
class ServiceStatus(Enum):
|
||||
"""Container service status."""
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
RUNNING = "running"
|
||||
STOPPED = "stopped"
|
||||
|
|
@ -27,6 +31,7 @@ class ServiceStatus(Enum):
|
|||
@dataclass
|
||||
class ServiceInfo:
|
||||
"""Container service information."""
|
||||
|
||||
name: str
|
||||
status: ServiceStatus
|
||||
health: Optional[str] = None
|
||||
|
|
@ -34,7 +39,7 @@ class ServiceInfo:
|
|||
image: Optional[str] = None
|
||||
image_digest: Optional[str] = None
|
||||
created: Optional[str] = None
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.ports is None:
|
||||
self.ports = []
|
||||
|
|
@ -42,7 +47,7 @@ class ServiceInfo:
|
|||
|
||||
class ContainerManager:
|
||||
"""Manages Docker/Podman container lifecycle for OpenRAG."""
|
||||
|
||||
|
||||
def __init__(self, compose_file: Optional[Path] = None):
|
||||
self.platform_detector = PlatformDetector()
|
||||
self.runtime_info = self.platform_detector.detect_runtime()
|
||||
|
|
@ -56,138 +61,142 @@ class ContainerManager:
|
|||
self.use_cpu_compose = not has_gpu
|
||||
except Exception:
|
||||
self.use_cpu_compose = True
|
||||
|
||||
|
||||
# Expected services based on compose files
|
||||
self.expected_services = [
|
||||
"openrag-backend",
|
||||
"openrag-frontend",
|
||||
"openrag-frontend",
|
||||
"opensearch",
|
||||
"dashboards",
|
||||
"langflow"
|
||||
"langflow",
|
||||
]
|
||||
|
||||
|
||||
# Map container names to service names
|
||||
self.container_name_map = {
|
||||
"openrag-backend": "openrag-backend",
|
||||
"openrag-frontend": "openrag-frontend",
|
||||
"os": "opensearch",
|
||||
"os": "opensearch",
|
||||
"osdash": "dashboards",
|
||||
"langflow": "langflow"
|
||||
"langflow": "langflow",
|
||||
}
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if container runtime is available."""
|
||||
return self.runtime_info.runtime_type != RuntimeType.NONE
|
||||
|
||||
|
||||
def get_runtime_info(self) -> RuntimeInfo:
|
||||
"""Get container runtime information."""
|
||||
return self.runtime_info
|
||||
|
||||
|
||||
def get_installation_help(self) -> str:
|
||||
"""Get installation instructions if runtime is not available."""
|
||||
return self.platform_detector.get_installation_instructions()
|
||||
|
||||
async def _run_compose_command(self, args: List[str], cpu_mode: Optional[bool] = None) -> tuple[bool, str, str]:
|
||||
|
||||
async def _run_compose_command(
|
||||
self, args: List[str], cpu_mode: Optional[bool] = None
|
||||
) -> tuple[bool, str, str]:
|
||||
"""Run a compose command and return (success, stdout, stderr)."""
|
||||
if not self.is_available():
|
||||
return False, "", "No container runtime available"
|
||||
|
||||
|
||||
if cpu_mode is None:
|
||||
cpu_mode = self.use_cpu_compose
|
||||
compose_file = self.cpu_compose_file if cpu_mode else self.compose_file
|
||||
cmd = self.runtime_info.compose_command + ["-f", str(compose_file)] + args
|
||||
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=Path.cwd()
|
||||
cwd=Path.cwd(),
|
||||
)
|
||||
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
stdout_text = stdout.decode() if stdout else ""
|
||||
stderr_text = stderr.decode() if stderr else ""
|
||||
|
||||
|
||||
success = process.returncode == 0
|
||||
return success, stdout_text, stderr_text
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return False, "", f"Command execution failed: {e}"
|
||||
|
||||
async def _run_compose_command_streaming(self, args: List[str], cpu_mode: Optional[bool] = None) -> AsyncIterator[str]:
|
||||
|
||||
async def _run_compose_command_streaming(
|
||||
self, args: List[str], cpu_mode: Optional[bool] = None
|
||||
) -> AsyncIterator[str]:
|
||||
"""Run a compose command and yield output lines in real-time."""
|
||||
if not self.is_available():
|
||||
yield "No container runtime available"
|
||||
return
|
||||
|
||||
|
||||
if cpu_mode is None:
|
||||
cpu_mode = self.use_cpu_compose
|
||||
compose_file = self.cpu_compose_file if cpu_mode else self.compose_file
|
||||
cmd = self.runtime_info.compose_command + ["-f", str(compose_file)] + args
|
||||
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT, # Combine stderr with stdout for unified output
|
||||
cwd=Path.cwd()
|
||||
cwd=Path.cwd(),
|
||||
)
|
||||
|
||||
|
||||
# Simple approach: read line by line and yield each one
|
||||
while True:
|
||||
line = await process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
|
||||
|
||||
line_text = line.decode().rstrip()
|
||||
if line_text:
|
||||
yield line_text
|
||||
|
||||
|
||||
# Wait for process to complete
|
||||
await process.wait()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
yield f"Command execution failed: {e}"
|
||||
|
||||
|
||||
async def _run_runtime_command(self, args: List[str]) -> tuple[bool, str, str]:
|
||||
"""Run a runtime command (docker/podman) and return (success, stdout, stderr)."""
|
||||
if not self.is_available():
|
||||
return False, "", "No container runtime available"
|
||||
|
||||
|
||||
cmd = self.runtime_info.runtime_command + args
|
||||
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
stdout_text = stdout.decode() if stdout else ""
|
||||
stderr_text = stderr.decode() if stderr else ""
|
||||
|
||||
|
||||
success = process.returncode == 0
|
||||
return success, stdout_text, stderr_text
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return False, "", f"Command execution failed: {e}"
|
||||
|
||||
def _process_service_json(self, service: Dict, services: Dict[str, ServiceInfo]) -> None:
|
||||
|
||||
def _process_service_json(
|
||||
self, service: Dict, services: Dict[str, ServiceInfo]
|
||||
) -> None:
|
||||
"""Process a service JSON object and add it to the services dict."""
|
||||
# Debug print to see the actual service data
|
||||
print(f"DEBUG: Processing service data: {json.dumps(service, indent=2)}")
|
||||
|
||||
logger.debug("Processing service data", service_data=service)
|
||||
|
||||
container_name = service.get("Name", "")
|
||||
|
||||
|
||||
# Map container name to service name
|
||||
service_name = self.container_name_map.get(container_name)
|
||||
if not service_name:
|
||||
return
|
||||
|
||||
|
||||
state = service.get("State", "").lower()
|
||||
|
||||
|
||||
# Map compose states to our status enum
|
||||
if "running" in state:
|
||||
status = ServiceStatus.RUNNING
|
||||
|
|
@ -197,17 +206,19 @@ class ContainerManager:
|
|||
status = ServiceStatus.STARTING
|
||||
else:
|
||||
status = ServiceStatus.UNKNOWN
|
||||
|
||||
|
||||
# Extract health - use Status if Health is empty
|
||||
health = service.get("Health", "") or service.get("Status", "N/A")
|
||||
|
||||
|
||||
# Extract ports
|
||||
ports_str = service.get("Ports", "")
|
||||
ports = [p.strip() for p in ports_str.split(",") if p.strip()] if ports_str else []
|
||||
|
||||
ports = (
|
||||
[p.strip() for p in ports_str.split(",") if p.strip()] if ports_str else []
|
||||
)
|
||||
|
||||
# Extract image
|
||||
image = service.get("Image", "N/A")
|
||||
|
||||
|
||||
services[service_name] = ServiceInfo(
|
||||
name=service_name,
|
||||
status=status,
|
||||
|
|
@ -215,23 +226,25 @@ class ContainerManager:
|
|||
ports=ports,
|
||||
image=image,
|
||||
)
|
||||
|
||||
async def get_service_status(self, force_refresh: bool = False) -> Dict[str, ServiceInfo]:
|
||||
|
||||
async def get_service_status(
|
||||
self, force_refresh: bool = False
|
||||
) -> Dict[str, ServiceInfo]:
|
||||
"""Get current status of all services."""
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# Use cache if recent and not forcing refresh
|
||||
if not force_refresh and current_time - self.last_status_update < 5:
|
||||
return self.services_cache
|
||||
|
||||
|
||||
services = {}
|
||||
|
||||
|
||||
# Different approach for Podman vs Docker
|
||||
if self.runtime_info.runtime_type == RuntimeType.PODMAN:
|
||||
# For Podman, use direct podman ps command instead of compose
|
||||
cmd = ["ps", "--all", "--format", "json"]
|
||||
success, stdout, stderr = await self._run_runtime_command(cmd)
|
||||
|
||||
|
||||
if success and stdout.strip():
|
||||
try:
|
||||
containers = json.loads(stdout.strip())
|
||||
|
|
@ -240,12 +253,12 @@ class ContainerManager:
|
|||
names = container.get("Names", [])
|
||||
if not names:
|
||||
continue
|
||||
|
||||
|
||||
container_name = names[0]
|
||||
service_name = self.container_name_map.get(container_name)
|
||||
if not service_name:
|
||||
continue
|
||||
|
||||
|
||||
# Get container state
|
||||
state = container.get("State", "").lower()
|
||||
if "running" in state:
|
||||
|
|
@ -256,7 +269,7 @@ class ContainerManager:
|
|||
status = ServiceStatus.STARTING
|
||||
else:
|
||||
status = ServiceStatus.UNKNOWN
|
||||
|
||||
|
||||
# Get other container info
|
||||
image = container.get("Image", "N/A")
|
||||
ports = []
|
||||
|
|
@ -268,7 +281,7 @@ class ContainerManager:
|
|||
container_port = port.get("container_port")
|
||||
if host_port and container_port:
|
||||
ports.append(f"{host_port}:{container_port}")
|
||||
|
||||
|
||||
services[service_name] = ServiceInfo(
|
||||
name=service_name,
|
||||
status=status,
|
||||
|
|
@ -280,55 +293,63 @@ class ContainerManager:
|
|||
pass
|
||||
else:
|
||||
# For Docker, use compose ps command
|
||||
success, stdout, stderr = await self._run_compose_command(["ps", "--format", "json"])
|
||||
|
||||
success, stdout, stderr = await self._run_compose_command(
|
||||
["ps", "--format", "json"]
|
||||
)
|
||||
|
||||
if success and stdout.strip():
|
||||
try:
|
||||
# Handle both single JSON object (Podman) and multiple JSON objects (Docker)
|
||||
if stdout.strip().startswith('[') and stdout.strip().endswith(']'):
|
||||
if stdout.strip().startswith("[") and stdout.strip().endswith("]"):
|
||||
# JSON array format
|
||||
service_list = json.loads(stdout.strip())
|
||||
for service in service_list:
|
||||
self._process_service_json(service, services)
|
||||
else:
|
||||
# Line-by-line JSON format
|
||||
for line in stdout.strip().split('\n'):
|
||||
if line.strip() and line.startswith('{'):
|
||||
for line in stdout.strip().split("\n"):
|
||||
if line.strip() and line.startswith("{"):
|
||||
service = json.loads(line)
|
||||
self._process_service_json(service, services)
|
||||
except json.JSONDecodeError:
|
||||
# Fallback to parsing text output
|
||||
lines = stdout.strip().split('\n')
|
||||
if len(lines) > 1: # Make sure we have at least a header and one line
|
||||
lines = stdout.strip().split("\n")
|
||||
if (
|
||||
len(lines) > 1
|
||||
): # Make sure we have at least a header and one line
|
||||
for line in lines[1:]: # Skip header
|
||||
if line.strip():
|
||||
parts = line.split()
|
||||
if len(parts) >= 3:
|
||||
name = parts[0]
|
||||
|
||||
|
||||
# Only include our expected services
|
||||
if name not in self.expected_services:
|
||||
continue
|
||||
|
||||
|
||||
state = parts[2].lower()
|
||||
|
||||
|
||||
if "up" in state:
|
||||
status = ServiceStatus.RUNNING
|
||||
elif "exit" in state:
|
||||
status = ServiceStatus.STOPPED
|
||||
else:
|
||||
status = ServiceStatus.UNKNOWN
|
||||
|
||||
services[name] = ServiceInfo(name=name, status=status)
|
||||
|
||||
|
||||
services[name] = ServiceInfo(
|
||||
name=name, status=status
|
||||
)
|
||||
|
||||
# Add expected services that weren't found
|
||||
for expected in self.expected_services:
|
||||
if expected not in services:
|
||||
services[expected] = ServiceInfo(name=expected, status=ServiceStatus.MISSING)
|
||||
|
||||
services[expected] = ServiceInfo(
|
||||
name=expected, status=ServiceStatus.MISSING
|
||||
)
|
||||
|
||||
self.services_cache = services
|
||||
self.last_status_update = current_time
|
||||
|
||||
|
||||
return services
|
||||
|
||||
async def get_images_digests(self, images: List[str]) -> Dict[str, str]:
|
||||
|
|
@ -337,9 +358,9 @@ class ContainerManager:
|
|||
for image in images:
|
||||
if not image or image in digests:
|
||||
continue
|
||||
success, stdout, _ = await self._run_runtime_command([
|
||||
"image", "inspect", image, "--format", "{{.Id}}"
|
||||
])
|
||||
success, stdout, _ = await self._run_runtime_command(
|
||||
["image", "inspect", image, "--format", "{{.Id}}"]
|
||||
)
|
||||
if success and stdout.strip():
|
||||
digests[image] = stdout.strip().splitlines()[0]
|
||||
return digests
|
||||
|
|
@ -353,13 +374,15 @@ class ContainerManager:
|
|||
continue
|
||||
for line in compose.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if line.startswith('image:'):
|
||||
if line.startswith("image:"):
|
||||
# image: repo/name:tag
|
||||
val = line.split(':', 1)[1].strip()
|
||||
val = line.split(":", 1)[1].strip()
|
||||
# Remove quotes if present
|
||||
if (val.startswith('"') and val.endswith('"')) or (val.startswith("'") and val.endswith("'")):
|
||||
if (val.startswith('"') and val.endswith('"')) or (
|
||||
val.startswith("'") and val.endswith("'")
|
||||
):
|
||||
val = val[1:-1]
|
||||
images.add(val)
|
||||
except Exception:
|
||||
|
|
@ -374,53 +397,61 @@ class ContainerManager:
|
|||
expected = self._parse_compose_images()
|
||||
results: list[tuple[str, str]] = []
|
||||
for image in expected:
|
||||
digest = '-'
|
||||
success, stdout, _ = await self._run_runtime_command([
|
||||
'image', 'inspect', image, '--format', '{{.Id}}'
|
||||
])
|
||||
digest = "-"
|
||||
success, stdout, _ = await self._run_runtime_command(
|
||||
["image", "inspect", image, "--format", "{{.Id}}"]
|
||||
)
|
||||
if success and stdout.strip():
|
||||
digest = stdout.strip().splitlines()[0]
|
||||
results.append((image, digest))
|
||||
results.sort(key=lambda x: x[0])
|
||||
return results
|
||||
|
||||
async def start_services(self, cpu_mode: bool = False) -> AsyncIterator[tuple[bool, str]]:
|
||||
|
||||
async def start_services(
|
||||
self, cpu_mode: bool = False
|
||||
) -> AsyncIterator[tuple[bool, str]]:
|
||||
"""Start all services and yield progress updates."""
|
||||
yield False, "Starting OpenRAG services..."
|
||||
|
||||
success, stdout, stderr = await self._run_compose_command(["up", "-d"], cpu_mode)
|
||||
|
||||
|
||||
success, stdout, stderr = await self._run_compose_command(
|
||||
["up", "-d"], cpu_mode
|
||||
)
|
||||
|
||||
if success:
|
||||
yield True, "Services started successfully"
|
||||
else:
|
||||
yield False, f"Failed to start services: {stderr}"
|
||||
|
||||
|
||||
async def stop_services(self) -> AsyncIterator[tuple[bool, str]]:
|
||||
"""Stop all services and yield progress updates."""
|
||||
yield False, "Stopping OpenRAG services..."
|
||||
|
||||
|
||||
success, stdout, stderr = await self._run_compose_command(["down"])
|
||||
|
||||
|
||||
if success:
|
||||
yield True, "Services stopped successfully"
|
||||
else:
|
||||
yield False, f"Failed to stop services: {stderr}"
|
||||
|
||||
async def restart_services(self, cpu_mode: bool = False) -> AsyncIterator[tuple[bool, str]]:
|
||||
|
||||
async def restart_services(
|
||||
self, cpu_mode: bool = False
|
||||
) -> AsyncIterator[tuple[bool, str]]:
|
||||
"""Restart all services and yield progress updates."""
|
||||
yield False, "Restarting OpenRAG services..."
|
||||
|
||||
|
||||
success, stdout, stderr = await self._run_compose_command(["restart"], cpu_mode)
|
||||
|
||||
|
||||
if success:
|
||||
yield True, "Services restarted successfully"
|
||||
else:
|
||||
yield False, f"Failed to restart services: {stderr}"
|
||||
|
||||
async def upgrade_services(self, cpu_mode: bool = False) -> AsyncIterator[tuple[bool, str]]:
|
||||
|
||||
async def upgrade_services(
|
||||
self, cpu_mode: bool = False
|
||||
) -> AsyncIterator[tuple[bool, str]]:
|
||||
"""Upgrade services (pull latest images and restart) and yield progress updates."""
|
||||
yield False, "Pulling latest images..."
|
||||
|
||||
|
||||
# Pull latest images with streaming output
|
||||
pull_success = True
|
||||
async for line in self._run_compose_command_streaming(["pull"], cpu_mode):
|
||||
|
|
@ -428,75 +459,89 @@ class ContainerManager:
|
|||
# Check for error patterns in the output
|
||||
if "error" in line.lower() or "failed" in line.lower():
|
||||
pull_success = False
|
||||
|
||||
|
||||
if not pull_success:
|
||||
yield False, "Failed to pull some images, but continuing with restart..."
|
||||
|
||||
|
||||
yield False, "Images updated, restarting services..."
|
||||
|
||||
|
||||
# Restart with new images using streaming output
|
||||
restart_success = True
|
||||
async for line in self._run_compose_command_streaming(["up", "-d", "--force-recreate"], cpu_mode):
|
||||
async for line in self._run_compose_command_streaming(
|
||||
["up", "-d", "--force-recreate"], cpu_mode
|
||||
):
|
||||
yield False, line
|
||||
# Check for error patterns in the output
|
||||
if "error" in line.lower() or "failed" in line.lower():
|
||||
restart_success = False
|
||||
|
||||
|
||||
if restart_success:
|
||||
yield True, "Services upgraded and restarted successfully"
|
||||
else:
|
||||
yield False, "Some errors occurred during service restart"
|
||||
|
||||
|
||||
async def reset_services(self) -> AsyncIterator[tuple[bool, str]]:
|
||||
"""Reset all services (stop, remove containers/volumes, clear data) and yield progress updates."""
|
||||
yield False, "Stopping all services..."
|
||||
|
||||
|
||||
# Stop and remove everything
|
||||
success, stdout, stderr = await self._run_compose_command([
|
||||
"down",
|
||||
"--volumes",
|
||||
"--remove-orphans",
|
||||
"--rmi", "local"
|
||||
])
|
||||
|
||||
success, stdout, stderr = await self._run_compose_command(
|
||||
["down", "--volumes", "--remove-orphans", "--rmi", "local"]
|
||||
)
|
||||
|
||||
if not success:
|
||||
yield False, f"Failed to stop services: {stderr}"
|
||||
return
|
||||
|
||||
|
||||
yield False, "Cleaning up container data..."
|
||||
|
||||
|
||||
# Additional cleanup - remove any remaining containers/volumes
|
||||
# This is more thorough than just compose down
|
||||
await self._run_runtime_command(["system", "prune", "-f"])
|
||||
|
||||
yield True, "System reset completed - all containers, volumes, and local images removed"
|
||||
|
||||
async def get_service_logs(self, service_name: str, lines: int = 100) -> tuple[bool, str]:
|
||||
|
||||
yield (
|
||||
True,
|
||||
"System reset completed - all containers, volumes, and local images removed",
|
||||
)
|
||||
|
||||
async def get_service_logs(
|
||||
self, service_name: str, lines: int = 100
|
||||
) -> tuple[bool, str]:
|
||||
"""Get logs for a specific service."""
|
||||
success, stdout, stderr = await self._run_compose_command(["logs", "--tail", str(lines), service_name])
|
||||
|
||||
success, stdout, stderr = await self._run_compose_command(
|
||||
["logs", "--tail", str(lines), service_name]
|
||||
)
|
||||
|
||||
if success:
|
||||
return True, stdout
|
||||
else:
|
||||
return False, f"Failed to get logs: {stderr}"
|
||||
|
||||
|
||||
async def follow_service_logs(self, service_name: str) -> AsyncIterator[str]:
|
||||
"""Follow logs for a specific service."""
|
||||
if not self.is_available():
|
||||
yield "No container runtime available"
|
||||
return
|
||||
|
||||
compose_file = self.cpu_compose_file if self.use_cpu_compose else self.compose_file
|
||||
cmd = self.runtime_info.compose_command + ["-f", str(compose_file), "logs", "-f", service_name]
|
||||
|
||||
|
||||
compose_file = (
|
||||
self.cpu_compose_file if self.use_cpu_compose else self.compose_file
|
||||
)
|
||||
cmd = self.runtime_info.compose_command + [
|
||||
"-f",
|
||||
str(compose_file),
|
||||
"logs",
|
||||
"-f",
|
||||
service_name,
|
||||
]
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
cwd=Path.cwd()
|
||||
cwd=Path.cwd(),
|
||||
)
|
||||
|
||||
|
||||
if process.stdout:
|
||||
while True:
|
||||
line = await process.stdout.readline()
|
||||
|
|
@ -506,20 +551,22 @@ class ContainerManager:
|
|||
break
|
||||
else:
|
||||
yield "Error: Unable to read process output"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
yield f"Error following logs: {e}"
|
||||
|
||||
|
||||
async def get_system_stats(self) -> Dict[str, Dict[str, str]]:
|
||||
"""Get system resource usage statistics."""
|
||||
stats = {}
|
||||
|
||||
|
||||
# Get container stats
|
||||
success, stdout, stderr = await self._run_runtime_command(["stats", "--no-stream", "--format", "json"])
|
||||
|
||||
success, stdout, stderr = await self._run_runtime_command(
|
||||
["stats", "--no-stream", "--format", "json"]
|
||||
)
|
||||
|
||||
if success and stdout.strip():
|
||||
try:
|
||||
for line in stdout.strip().split('\n'):
|
||||
for line in stdout.strip().split("\n"):
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
name = data.get("Name", data.get("Container", ""))
|
||||
|
|
@ -533,14 +580,14 @@ class ContainerManager:
|
|||
}
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def debug_podman_services(self) -> str:
|
||||
"""Run a direct Podman command to check services status for debugging."""
|
||||
if self.runtime_info.runtime_type != RuntimeType.PODMAN:
|
||||
return "Not using Podman"
|
||||
|
||||
|
||||
# Try direct podman command
|
||||
cmd = ["podman", "ps", "--all", "--format", "json"]
|
||||
try:
|
||||
|
|
@ -548,18 +595,18 @@ class ContainerManager:
|
|||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=Path.cwd()
|
||||
cwd=Path.cwd(),
|
||||
)
|
||||
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
stdout_text = stdout.decode() if stdout else ""
|
||||
stderr_text = stderr.decode() if stderr else ""
|
||||
|
||||
|
||||
result = f"Command: {' '.join(cmd)}\n"
|
||||
result += f"Return code: {process.returncode}\n"
|
||||
result += f"Stdout: {stdout_text}\n"
|
||||
result += f"Stderr: {stderr_text}\n"
|
||||
|
||||
|
||||
# Try to parse the output
|
||||
if stdout_text.strip():
|
||||
try:
|
||||
|
|
@ -571,16 +618,18 @@ class ContainerManager:
|
|||
result += f" - {name}: {state}\n"
|
||||
except json.JSONDecodeError as e:
|
||||
result += f"\nFailed to parse JSON: {e}\n"
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing command: {e}"
|
||||
|
||||
|
||||
def check_podman_macos_memory(self) -> tuple[bool, str]:
|
||||
"""Check if Podman VM has sufficient memory on macOS."""
|
||||
if self.runtime_info.runtime_type != RuntimeType.PODMAN:
|
||||
return True, "Not using Podman"
|
||||
|
||||
is_sufficient, memory_mb, message = self.platform_detector.check_podman_macos_memory()
|
||||
|
||||
is_sufficient, memory_mb, message = (
|
||||
self.platform_detector.check_podman_macos_memory()
|
||||
)
|
||||
return is_sufficient, message
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@ from datetime import datetime
|
|||
from pathlib import Path
|
||||
from typing import Dict, Optional, List
|
||||
from dataclasses import dataclass, field
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
from ..utils.validation import (
|
||||
validate_openai_api_key,
|
||||
|
|
@ -14,13 +17,14 @@ from ..utils.validation import (
|
|||
validate_non_empty,
|
||||
validate_url,
|
||||
validate_documents_paths,
|
||||
sanitize_env_value
|
||||
sanitize_env_value,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvConfig:
|
||||
"""Environment configuration data."""
|
||||
|
||||
# Core settings
|
||||
openai_api_key: str = ""
|
||||
opensearch_password: str = ""
|
||||
|
|
@ -28,155 +32,186 @@ class EnvConfig:
|
|||
langflow_superuser: str = "admin"
|
||||
langflow_superuser_password: str = ""
|
||||
flow_id: str = "1098eea1-6649-4e1d-aed1-b77249fb8dd0"
|
||||
|
||||
|
||||
# OAuth settings
|
||||
google_oauth_client_id: str = ""
|
||||
google_oauth_client_secret: str = ""
|
||||
microsoft_graph_oauth_client_id: str = ""
|
||||
microsoft_graph_oauth_client_secret: str = ""
|
||||
|
||||
|
||||
# Optional settings
|
||||
webhook_base_url: str = ""
|
||||
aws_access_key_id: str = ""
|
||||
aws_secret_access_key: str = ""
|
||||
langflow_public_url: str = ""
|
||||
|
||||
|
||||
# Langflow auth settings
|
||||
langflow_auto_login: str = "False"
|
||||
langflow_new_user_is_active: str = "False"
|
||||
langflow_enable_superuser_cli: str = "False"
|
||||
|
||||
|
||||
# Document paths (comma-separated)
|
||||
openrag_documents_paths: str = "./documents"
|
||||
|
||||
|
||||
# Validation errors
|
||||
validation_errors: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
class EnvManager:
|
||||
"""Manages environment configuration for OpenRAG."""
|
||||
|
||||
|
||||
def __init__(self, env_file: Optional[Path] = None):
|
||||
self.env_file = env_file or Path(".env")
|
||||
self.config = EnvConfig()
|
||||
|
||||
|
||||
def generate_secure_password(self) -> str:
|
||||
"""Generate a secure password for OpenSearch."""
|
||||
# Generate a 16-character password with letters, digits, and symbols
|
||||
alphabet = string.ascii_letters + string.digits + "!@#$%^&*"
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(16))
|
||||
|
||||
return "".join(secrets.choice(alphabet) for _ in range(16))
|
||||
|
||||
def generate_langflow_secret_key(self) -> str:
|
||||
"""Generate a secure secret key for Langflow."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def load_existing_env(self) -> bool:
|
||||
"""Load existing .env file if it exists."""
|
||||
if not self.env_file.exists():
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
with open(self.env_file, 'r') as f:
|
||||
with open(self.env_file, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
if '=' in line:
|
||||
key, value = line.split('=', 1)
|
||||
|
||||
if "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
key = key.strip()
|
||||
value = sanitize_env_value(value)
|
||||
|
||||
|
||||
# Map env vars to config attributes
|
||||
attr_map = {
|
||||
'OPENAI_API_KEY': 'openai_api_key',
|
||||
'OPENSEARCH_PASSWORD': 'opensearch_password',
|
||||
'LANGFLOW_SECRET_KEY': 'langflow_secret_key',
|
||||
'LANGFLOW_SUPERUSER': 'langflow_superuser',
|
||||
'LANGFLOW_SUPERUSER_PASSWORD': 'langflow_superuser_password',
|
||||
'FLOW_ID': 'flow_id',
|
||||
'GOOGLE_OAUTH_CLIENT_ID': 'google_oauth_client_id',
|
||||
'GOOGLE_OAUTH_CLIENT_SECRET': 'google_oauth_client_secret',
|
||||
'MICROSOFT_GRAPH_OAUTH_CLIENT_ID': 'microsoft_graph_oauth_client_id',
|
||||
'MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET': 'microsoft_graph_oauth_client_secret',
|
||||
'WEBHOOK_BASE_URL': 'webhook_base_url',
|
||||
'AWS_ACCESS_KEY_ID': 'aws_access_key_id',
|
||||
'AWS_SECRET_ACCESS_KEY': 'aws_secret_access_key',
|
||||
'LANGFLOW_PUBLIC_URL': 'langflow_public_url',
|
||||
'OPENRAG_DOCUMENTS_PATHS': 'openrag_documents_paths',
|
||||
'LANGFLOW_AUTO_LOGIN': 'langflow_auto_login',
|
||||
'LANGFLOW_NEW_USER_IS_ACTIVE': 'langflow_new_user_is_active',
|
||||
'LANGFLOW_ENABLE_SUPERUSER_CLI': 'langflow_enable_superuser_cli',
|
||||
"OPENAI_API_KEY": "openai_api_key",
|
||||
"OPENSEARCH_PASSWORD": "opensearch_password",
|
||||
"LANGFLOW_SECRET_KEY": "langflow_secret_key",
|
||||
"LANGFLOW_SUPERUSER": "langflow_superuser",
|
||||
"LANGFLOW_SUPERUSER_PASSWORD": "langflow_superuser_password",
|
||||
"FLOW_ID": "flow_id",
|
||||
"GOOGLE_OAUTH_CLIENT_ID": "google_oauth_client_id",
|
||||
"GOOGLE_OAUTH_CLIENT_SECRET": "google_oauth_client_secret",
|
||||
"MICROSOFT_GRAPH_OAUTH_CLIENT_ID": "microsoft_graph_oauth_client_id",
|
||||
"MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET": "microsoft_graph_oauth_client_secret",
|
||||
"WEBHOOK_BASE_URL": "webhook_base_url",
|
||||
"AWS_ACCESS_KEY_ID": "aws_access_key_id",
|
||||
"AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
|
||||
"LANGFLOW_PUBLIC_URL": "langflow_public_url",
|
||||
"OPENRAG_DOCUMENTS_PATHS": "openrag_documents_paths",
|
||||
"LANGFLOW_AUTO_LOGIN": "langflow_auto_login",
|
||||
"LANGFLOW_NEW_USER_IS_ACTIVE": "langflow_new_user_is_active",
|
||||
"LANGFLOW_ENABLE_SUPERUSER_CLI": "langflow_enable_superuser_cli",
|
||||
}
|
||||
|
||||
|
||||
if key in attr_map:
|
||||
setattr(self.config, attr_map[key], value)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading .env file: {e}")
|
||||
logger.error("Error loading .env file", error=str(e))
|
||||
return False
|
||||
|
||||
|
||||
def setup_secure_defaults(self) -> None:
|
||||
"""Set up secure default values for passwords and keys."""
|
||||
if not self.config.opensearch_password:
|
||||
self.config.opensearch_password = self.generate_secure_password()
|
||||
|
||||
|
||||
if not self.config.langflow_secret_key:
|
||||
self.config.langflow_secret_key = self.generate_langflow_secret_key()
|
||||
|
||||
|
||||
if not self.config.langflow_superuser_password:
|
||||
self.config.langflow_superuser_password = self.generate_secure_password()
|
||||
|
||||
|
||||
def validate_config(self, mode: str = "full") -> bool:
|
||||
"""
|
||||
Validate the current configuration.
|
||||
|
||||
|
||||
Args:
|
||||
mode: "no_auth" for minimal validation, "full" for complete validation
|
||||
"""
|
||||
self.config.validation_errors.clear()
|
||||
|
||||
|
||||
# Always validate OpenAI API key
|
||||
if not validate_openai_api_key(self.config.openai_api_key):
|
||||
self.config.validation_errors['openai_api_key'] = "Invalid OpenAI API key format (should start with sk-)"
|
||||
|
||||
self.config.validation_errors["openai_api_key"] = (
|
||||
"Invalid OpenAI API key format (should start with sk-)"
|
||||
)
|
||||
|
||||
# Validate documents paths only if provided (optional)
|
||||
if self.config.openrag_documents_paths:
|
||||
is_valid, error_msg, _ = validate_documents_paths(self.config.openrag_documents_paths)
|
||||
is_valid, error_msg, _ = validate_documents_paths(
|
||||
self.config.openrag_documents_paths
|
||||
)
|
||||
if not is_valid:
|
||||
self.config.validation_errors['openrag_documents_paths'] = error_msg
|
||||
|
||||
self.config.validation_errors["openrag_documents_paths"] = error_msg
|
||||
|
||||
# Validate required fields
|
||||
if not validate_non_empty(self.config.opensearch_password):
|
||||
self.config.validation_errors['opensearch_password'] = "OpenSearch password is required"
|
||||
|
||||
self.config.validation_errors["opensearch_password"] = (
|
||||
"OpenSearch password is required"
|
||||
)
|
||||
|
||||
# Langflow secret key is auto-generated; no user input required
|
||||
|
||||
if not validate_non_empty(self.config.langflow_superuser_password):
|
||||
self.config.validation_errors['langflow_superuser_password'] = "Langflow superuser password is required"
|
||||
|
||||
self.config.validation_errors["langflow_superuser_password"] = (
|
||||
"Langflow superuser password is required"
|
||||
)
|
||||
|
||||
if mode == "full":
|
||||
# Validate OAuth settings if provided
|
||||
if self.config.google_oauth_client_id and not validate_google_oauth_client_id(self.config.google_oauth_client_id):
|
||||
self.config.validation_errors['google_oauth_client_id'] = "Invalid Google OAuth client ID format"
|
||||
|
||||
if self.config.google_oauth_client_id and not validate_non_empty(self.config.google_oauth_client_secret):
|
||||
self.config.validation_errors['google_oauth_client_secret'] = "Google OAuth client secret required when client ID is provided"
|
||||
|
||||
if self.config.microsoft_graph_oauth_client_id and not validate_non_empty(self.config.microsoft_graph_oauth_client_secret):
|
||||
self.config.validation_errors['microsoft_graph_oauth_client_secret'] = "Microsoft Graph client secret required when client ID is provided"
|
||||
|
||||
if (
|
||||
self.config.google_oauth_client_id
|
||||
and not validate_google_oauth_client_id(
|
||||
self.config.google_oauth_client_id
|
||||
)
|
||||
):
|
||||
self.config.validation_errors["google_oauth_client_id"] = (
|
||||
"Invalid Google OAuth client ID format"
|
||||
)
|
||||
|
||||
if self.config.google_oauth_client_id and not validate_non_empty(
|
||||
self.config.google_oauth_client_secret
|
||||
):
|
||||
self.config.validation_errors["google_oauth_client_secret"] = (
|
||||
"Google OAuth client secret required when client ID is provided"
|
||||
)
|
||||
|
||||
if self.config.microsoft_graph_oauth_client_id and not validate_non_empty(
|
||||
self.config.microsoft_graph_oauth_client_secret
|
||||
):
|
||||
self.config.validation_errors["microsoft_graph_oauth_client_secret"] = (
|
||||
"Microsoft Graph client secret required when client ID is provided"
|
||||
)
|
||||
|
||||
# Validate optional URLs if provided
|
||||
if self.config.webhook_base_url and not validate_url(self.config.webhook_base_url):
|
||||
self.config.validation_errors['webhook_base_url'] = "Invalid webhook URL format"
|
||||
|
||||
if self.config.langflow_public_url and not validate_url(self.config.langflow_public_url):
|
||||
self.config.validation_errors['langflow_public_url'] = "Invalid Langflow public URL format"
|
||||
|
||||
if self.config.webhook_base_url and not validate_url(
|
||||
self.config.webhook_base_url
|
||||
):
|
||||
self.config.validation_errors["webhook_base_url"] = (
|
||||
"Invalid webhook URL format"
|
||||
)
|
||||
|
||||
if self.config.langflow_public_url and not validate_url(
|
||||
self.config.langflow_public_url
|
||||
):
|
||||
self.config.validation_errors["langflow_public_url"] = (
|
||||
"Invalid Langflow public URL format"
|
||||
)
|
||||
|
||||
return len(self.config.validation_errors) == 0
|
||||
|
||||
|
||||
def save_env_file(self) -> bool:
|
||||
"""Save current configuration to .env file."""
|
||||
try:
|
||||
|
|
@ -184,45 +219,67 @@ class EnvManager:
|
|||
self.setup_secure_defaults()
|
||||
# Create timestamped backup if file exists
|
||||
if self.env_file.exists():
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
backup_file = self.env_file.with_suffix(f'.env.backup.{timestamp}')
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_file = self.env_file.with_suffix(f".env.backup.{timestamp}")
|
||||
self.env_file.rename(backup_file)
|
||||
|
||||
with open(self.env_file, 'w') as f:
|
||||
|
||||
with open(self.env_file, "w") as f:
|
||||
f.write("# OpenRAG Environment Configuration\n")
|
||||
f.write("# Generated by OpenRAG TUI\n\n")
|
||||
|
||||
|
||||
# Core settings
|
||||
f.write("# Core settings\n")
|
||||
f.write(f"LANGFLOW_SECRET_KEY={self.config.langflow_secret_key}\n")
|
||||
f.write(f"LANGFLOW_SUPERUSER={self.config.langflow_superuser}\n")
|
||||
f.write(f"LANGFLOW_SUPERUSER_PASSWORD={self.config.langflow_superuser_password}\n")
|
||||
f.write(
|
||||
f"LANGFLOW_SUPERUSER_PASSWORD={self.config.langflow_superuser_password}\n"
|
||||
)
|
||||
f.write(f"FLOW_ID={self.config.flow_id}\n")
|
||||
f.write(f"OPENSEARCH_PASSWORD={self.config.opensearch_password}\n")
|
||||
f.write(f"OPENAI_API_KEY={self.config.openai_api_key}\n")
|
||||
f.write(f"OPENRAG_DOCUMENTS_PATHS={self.config.openrag_documents_paths}\n")
|
||||
f.write(
|
||||
f"OPENRAG_DOCUMENTS_PATHS={self.config.openrag_documents_paths}\n"
|
||||
)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
# Langflow auth settings
|
||||
f.write("# Langflow auth settings\n")
|
||||
f.write(f"LANGFLOW_AUTO_LOGIN={self.config.langflow_auto_login}\n")
|
||||
f.write(f"LANGFLOW_NEW_USER_IS_ACTIVE={self.config.langflow_new_user_is_active}\n")
|
||||
f.write(f"LANGFLOW_ENABLE_SUPERUSER_CLI={self.config.langflow_enable_superuser_cli}\n")
|
||||
f.write(
|
||||
f"LANGFLOW_NEW_USER_IS_ACTIVE={self.config.langflow_new_user_is_active}\n"
|
||||
)
|
||||
f.write(
|
||||
f"LANGFLOW_ENABLE_SUPERUSER_CLI={self.config.langflow_enable_superuser_cli}\n"
|
||||
)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
# OAuth settings
|
||||
if self.config.google_oauth_client_id or self.config.google_oauth_client_secret:
|
||||
if (
|
||||
self.config.google_oauth_client_id
|
||||
or self.config.google_oauth_client_secret
|
||||
):
|
||||
f.write("# Google OAuth settings\n")
|
||||
f.write(f"GOOGLE_OAUTH_CLIENT_ID={self.config.google_oauth_client_id}\n")
|
||||
f.write(f"GOOGLE_OAUTH_CLIENT_SECRET={self.config.google_oauth_client_secret}\n")
|
||||
f.write(
|
||||
f"GOOGLE_OAUTH_CLIENT_ID={self.config.google_oauth_client_id}\n"
|
||||
)
|
||||
f.write(
|
||||
f"GOOGLE_OAUTH_CLIENT_SECRET={self.config.google_oauth_client_secret}\n"
|
||||
)
|
||||
f.write("\n")
|
||||
|
||||
if self.config.microsoft_graph_oauth_client_id or self.config.microsoft_graph_oauth_client_secret:
|
||||
|
||||
if (
|
||||
self.config.microsoft_graph_oauth_client_id
|
||||
or self.config.microsoft_graph_oauth_client_secret
|
||||
):
|
||||
f.write("# Microsoft Graph OAuth settings\n")
|
||||
f.write(f"MICROSOFT_GRAPH_OAUTH_CLIENT_ID={self.config.microsoft_graph_oauth_client_id}\n")
|
||||
f.write(f"MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET={self.config.microsoft_graph_oauth_client_secret}\n")
|
||||
f.write(
|
||||
f"MICROSOFT_GRAPH_OAUTH_CLIENT_ID={self.config.microsoft_graph_oauth_client_id}\n"
|
||||
)
|
||||
f.write(
|
||||
f"MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET={self.config.microsoft_graph_oauth_client_secret}\n"
|
||||
)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
# Optional settings
|
||||
optional_vars = [
|
||||
("WEBHOOK_BASE_URL", self.config.webhook_base_url),
|
||||
|
|
@ -230,7 +287,7 @@ class EnvManager:
|
|||
("AWS_SECRET_ACCESS_KEY", self.config.aws_secret_access_key),
|
||||
("LANGFLOW_PUBLIC_URL", self.config.langflow_public_url),
|
||||
]
|
||||
|
||||
|
||||
optional_written = False
|
||||
for var_name, var_value in optional_vars:
|
||||
if var_value:
|
||||
|
|
@ -238,52 +295,89 @@ class EnvManager:
|
|||
f.write("# Optional settings\n")
|
||||
optional_written = True
|
||||
f.write(f"{var_name}={var_value}\n")
|
||||
|
||||
|
||||
if optional_written:
|
||||
f.write("\n")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error saving .env file: {e}")
|
||||
logger.error("Error saving .env file", error=str(e))
|
||||
return False
|
||||
|
||||
|
||||
def get_no_auth_setup_fields(self) -> List[tuple[str, str, str, bool]]:
|
||||
"""Get fields required for no-auth setup mode. Returns (field_name, display_name, placeholder, can_generate)."""
|
||||
return [
|
||||
("openai_api_key", "OpenAI API Key", "sk-...", False),
|
||||
("opensearch_password", "OpenSearch Password", "Will be auto-generated if empty", True),
|
||||
("langflow_superuser_password", "Langflow Superuser Password", "Will be auto-generated if empty", True),
|
||||
("openrag_documents_paths", "Documents Paths", "./documents,/path/to/more/docs", False),
|
||||
(
|
||||
"opensearch_password",
|
||||
"OpenSearch Password",
|
||||
"Will be auto-generated if empty",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"langflow_superuser_password",
|
||||
"Langflow Superuser Password",
|
||||
"Will be auto-generated if empty",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"openrag_documents_paths",
|
||||
"Documents Paths",
|
||||
"./documents,/path/to/more/docs",
|
||||
False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_full_setup_fields(self) -> List[tuple[str, str, str, bool]]:
|
||||
"""Get all fields for full setup mode."""
|
||||
base_fields = self.get_no_auth_setup_fields()
|
||||
|
||||
|
||||
oauth_fields = [
|
||||
("google_oauth_client_id", "Google OAuth Client ID", "xxx.apps.googleusercontent.com", False),
|
||||
(
|
||||
"google_oauth_client_id",
|
||||
"Google OAuth Client ID",
|
||||
"xxx.apps.googleusercontent.com",
|
||||
False,
|
||||
),
|
||||
("google_oauth_client_secret", "Google OAuth Client Secret", "", False),
|
||||
("microsoft_graph_oauth_client_id", "Microsoft Graph Client ID", "", False),
|
||||
("microsoft_graph_oauth_client_secret", "Microsoft Graph Client Secret", "", False),
|
||||
(
|
||||
"microsoft_graph_oauth_client_secret",
|
||||
"Microsoft Graph Client Secret",
|
||||
"",
|
||||
False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
optional_fields = [
|
||||
("webhook_base_url", "Webhook Base URL (optional)", "https://your-domain.com", False),
|
||||
(
|
||||
"webhook_base_url",
|
||||
"Webhook Base URL (optional)",
|
||||
"https://your-domain.com",
|
||||
False,
|
||||
),
|
||||
("aws_access_key_id", "AWS Access Key ID (optional)", "", False),
|
||||
("aws_secret_access_key", "AWS Secret Access Key (optional)", "", False),
|
||||
("langflow_public_url", "Langflow Public URL (optional)", "http://localhost:7860", False),
|
||||
(
|
||||
"langflow_public_url",
|
||||
"Langflow Public URL (optional)",
|
||||
"http://localhost:7860",
|
||||
False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
return base_fields + oauth_fields + optional_fields
|
||||
|
||||
|
||||
def generate_compose_volume_mounts(self) -> List[str]:
|
||||
"""Generate Docker Compose volume mount strings from documents paths."""
|
||||
is_valid, _, validated_paths = validate_documents_paths(self.config.openrag_documents_paths)
|
||||
|
||||
is_valid, _, validated_paths = validate_documents_paths(
|
||||
self.config.openrag_documents_paths
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
return ["./documents:/app/documents:Z"] # fallback
|
||||
|
||||
|
||||
volume_mounts = []
|
||||
for i, path in enumerate(validated_paths):
|
||||
if i == 0:
|
||||
|
|
@ -291,6 +385,6 @@ class EnvManager:
|
|||
volume_mounts.append(f"{path}:/app/documents:Z")
|
||||
else:
|
||||
# Additional paths map to numbered directories
|
||||
volume_mounts.append(f"{path}:/app/documents{i+1}:Z")
|
||||
|
||||
volume_mounts.append(f"{path}:/app/documents{i + 1}:Z")
|
||||
|
||||
return volume_mounts
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
"""TUI screens package."""
|
||||
"""TUI screens package."""
|
||||
|
|
|
|||
|
|
@ -3,7 +3,16 @@
|
|||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Vertical, Horizontal, ScrollableContainer
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Header, Footer, Static, Button, Input, Label, TabbedContent, TabPane
|
||||
from textual.widgets import (
|
||||
Header,
|
||||
Footer,
|
||||
Static,
|
||||
Button,
|
||||
Input,
|
||||
Label,
|
||||
TabbedContent,
|
||||
TabPane,
|
||||
)
|
||||
from textual.validation import ValidationResult, Validator
|
||||
from rich.text import Text
|
||||
from pathlib import Path
|
||||
|
|
@ -15,11 +24,11 @@ from pathlib import Path
|
|||
|
||||
class OpenAIKeyValidator(Validator):
|
||||
"""Validator for OpenAI API keys."""
|
||||
|
||||
|
||||
def validate(self, value: str) -> ValidationResult:
|
||||
if not value:
|
||||
return self.success()
|
||||
|
||||
|
||||
if validate_openai_api_key(value):
|
||||
return self.success()
|
||||
else:
|
||||
|
|
@ -28,12 +37,12 @@ class OpenAIKeyValidator(Validator):
|
|||
|
||||
class DocumentsPathValidator(Validator):
|
||||
"""Validator for documents paths."""
|
||||
|
||||
|
||||
def validate(self, value: str) -> ValidationResult:
|
||||
# Optional: allow empty value
|
||||
if not value:
|
||||
return self.success()
|
||||
|
||||
|
||||
is_valid, error_msg, _ = validate_documents_paths(value)
|
||||
if is_valid:
|
||||
return self.success()
|
||||
|
|
@ -43,22 +52,22 @@ class DocumentsPathValidator(Validator):
|
|||
|
||||
class ConfigScreen(Screen):
|
||||
"""Configuration screen for environment setup."""
|
||||
|
||||
|
||||
BINDINGS = [
|
||||
("escape", "back", "Back"),
|
||||
("ctrl+s", "save", "Save"),
|
||||
("ctrl+g", "generate", "Generate Passwords"),
|
||||
]
|
||||
|
||||
|
||||
def __init__(self, mode: str = "full"):
|
||||
super().__init__()
|
||||
self.mode = mode # "no_auth" or "full"
|
||||
self.env_manager = EnvManager()
|
||||
self.inputs = {}
|
||||
|
||||
|
||||
# Load existing config if available
|
||||
self.env_manager.load_existing_env()
|
||||
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Create the configuration screen layout."""
|
||||
# Removed top header bar and header text
|
||||
|
|
@ -70,33 +79,37 @@ class ConfigScreen(Screen):
|
|||
Button("Generate Passwords", variant="default", id="generate-btn"),
|
||||
Button("Save Configuration", variant="success", id="save-btn"),
|
||||
Button("Back", variant="default", id="back-btn"),
|
||||
classes="button-row"
|
||||
classes="button-row",
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
|
||||
def _create_header_text(self) -> Text:
|
||||
"""Create the configuration header text."""
|
||||
header_text = Text()
|
||||
|
||||
|
||||
if self.mode == "no_auth":
|
||||
header_text.append("Quick Setup - No Authentication\n", style="bold green")
|
||||
header_text.append("Configure OpenRAG for local document processing only.\n\n", style="dim")
|
||||
header_text.append(
|
||||
"Configure OpenRAG for local document processing only.\n\n", style="dim"
|
||||
)
|
||||
else:
|
||||
header_text.append("Full Setup - OAuth Integration\n", style="bold cyan")
|
||||
header_text.append("Configure OpenRAG with cloud service integrations.\n\n", style="dim")
|
||||
|
||||
header_text.append(
|
||||
"Configure OpenRAG with cloud service integrations.\n\n", style="dim"
|
||||
)
|
||||
|
||||
header_text.append("Required fields are marked with *\n", style="yellow")
|
||||
header_text.append("Use Ctrl+G to generate admin passwords\n", style="dim")
|
||||
|
||||
|
||||
return header_text
|
||||
|
||||
|
||||
def _create_all_fields(self) -> ComposeResult:
|
||||
"""Create all configuration fields in a single scrollable layout."""
|
||||
|
||||
|
||||
# Admin Credentials Section
|
||||
yield Static("Admin Credentials", classes="tab-header")
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# OpenSearch Admin Password
|
||||
yield Label("OpenSearch Admin Password *")
|
||||
current_value = getattr(self.env_manager.config, "opensearch_password", "")
|
||||
|
|
@ -104,64 +117,73 @@ class ConfigScreen(Screen):
|
|||
placeholder="Auto-generated secure password",
|
||||
value=current_value,
|
||||
password=True,
|
||||
id="input-opensearch_password"
|
||||
id="input-opensearch_password",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["opensearch_password"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Langflow Admin Username
|
||||
yield Label("Langflow Admin Username *")
|
||||
current_value = getattr(self.env_manager.config, "langflow_superuser", "")
|
||||
input_widget = Input(
|
||||
placeholder="admin",
|
||||
value=current_value,
|
||||
id="input-langflow_superuser"
|
||||
placeholder="admin", value=current_value, id="input-langflow_superuser"
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["langflow_superuser"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Langflow Admin Password
|
||||
yield Label("Langflow Admin Password *")
|
||||
current_value = getattr(self.env_manager.config, "langflow_superuser_password", "")
|
||||
current_value = getattr(
|
||||
self.env_manager.config, "langflow_superuser_password", ""
|
||||
)
|
||||
input_widget = Input(
|
||||
placeholder="Auto-generated secure password",
|
||||
value=current_value,
|
||||
password=True,
|
||||
id="input-langflow_superuser_password"
|
||||
id="input-langflow_superuser_password",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["langflow_superuser_password"] = input_widget
|
||||
yield Static(" ")
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# API Keys Section
|
||||
yield Static("API Keys", classes="tab-header")
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# OpenAI API Key
|
||||
yield Label("OpenAI API Key *")
|
||||
# Where to create OpenAI keys (helper above the box)
|
||||
yield Static(Text("Get a key: https://platform.openai.com/api-keys", style="dim"), classes="helper-text")
|
||||
yield Static(
|
||||
Text("Get a key: https://platform.openai.com/api-keys", style="dim"),
|
||||
classes="helper-text",
|
||||
)
|
||||
current_value = getattr(self.env_manager.config, "openai_api_key", "")
|
||||
input_widget = Input(
|
||||
placeholder="sk-...",
|
||||
value=current_value,
|
||||
password=True,
|
||||
validators=[OpenAIKeyValidator()],
|
||||
id="input-openai_api_key"
|
||||
id="input-openai_api_key",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["openai_api_key"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Add OAuth fields only in full mode
|
||||
if self.mode == "full":
|
||||
# Google OAuth Client ID
|
||||
yield Label("Google OAuth Client ID")
|
||||
# Where to create Google OAuth credentials (helper above the box)
|
||||
yield Static(Text("Create credentials: https://console.cloud.google.com/apis/credentials", style="dim"), classes="helper-text")
|
||||
yield Static(
|
||||
Text(
|
||||
"Create credentials: https://console.cloud.google.com/apis/credentials",
|
||||
style="dim",
|
||||
),
|
||||
classes="helper-text",
|
||||
)
|
||||
# Callback URL guidance for Google OAuth
|
||||
yield Static(
|
||||
Text(
|
||||
|
|
@ -169,37 +191,47 @@ class ConfigScreen(Screen):
|
|||
" - Local: http://localhost:3000/auth/callback\n"
|
||||
" - Prod: https://your-domain.com/auth/callback\n"
|
||||
"If you use separate apps for login and connectors, add this URL to BOTH.",
|
||||
style="dim"
|
||||
style="dim",
|
||||
),
|
||||
classes="helper-text"
|
||||
classes="helper-text",
|
||||
)
|
||||
current_value = getattr(
|
||||
self.env_manager.config, "google_oauth_client_id", ""
|
||||
)
|
||||
current_value = getattr(self.env_manager.config, "google_oauth_client_id", "")
|
||||
input_widget = Input(
|
||||
placeholder="xxx.apps.googleusercontent.com",
|
||||
value=current_value,
|
||||
id="input-google_oauth_client_id"
|
||||
id="input-google_oauth_client_id",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["google_oauth_client_id"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Google OAuth Client Secret
|
||||
yield Label("Google OAuth Client Secret")
|
||||
current_value = getattr(self.env_manager.config, "google_oauth_client_secret", "")
|
||||
current_value = getattr(
|
||||
self.env_manager.config, "google_oauth_client_secret", ""
|
||||
)
|
||||
input_widget = Input(
|
||||
placeholder="",
|
||||
value=current_value,
|
||||
password=True,
|
||||
id="input-google_oauth_client_secret"
|
||||
id="input-google_oauth_client_secret",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["google_oauth_client_secret"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Microsoft Graph Client ID
|
||||
yield Label("Microsoft Graph Client ID")
|
||||
# Where to create Microsoft app registrations (helper above the box)
|
||||
yield Static(Text("Create app: https://portal.azure.com/#view/Microsoft_AAD_RegisteredApps/ApplicationsListBlade", style="dim"), classes="helper-text")
|
||||
yield Static(
|
||||
Text(
|
||||
"Create app: https://portal.azure.com/#view/Microsoft_AAD_RegisteredApps/ApplicationsListBlade",
|
||||
style="dim",
|
||||
),
|
||||
classes="helper-text",
|
||||
)
|
||||
# Callback URL guidance for Microsoft OAuth
|
||||
yield Static(
|
||||
Text(
|
||||
|
|
@ -207,66 +239,76 @@ class ConfigScreen(Screen):
|
|||
" - Local: http://localhost:3000/auth/callback\n"
|
||||
" - Prod: https://your-domain.com/auth/callback\n"
|
||||
"If you use separate apps for login and connectors, add this URI to BOTH.",
|
||||
style="dim"
|
||||
style="dim",
|
||||
),
|
||||
classes="helper-text"
|
||||
classes="helper-text",
|
||||
)
|
||||
current_value = getattr(
|
||||
self.env_manager.config, "microsoft_graph_oauth_client_id", ""
|
||||
)
|
||||
current_value = getattr(self.env_manager.config, "microsoft_graph_oauth_client_id", "")
|
||||
input_widget = Input(
|
||||
placeholder="",
|
||||
value=current_value,
|
||||
id="input-microsoft_graph_oauth_client_id"
|
||||
id="input-microsoft_graph_oauth_client_id",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["microsoft_graph_oauth_client_id"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Microsoft Graph Client Secret
|
||||
yield Label("Microsoft Graph Client Secret")
|
||||
current_value = getattr(self.env_manager.config, "microsoft_graph_oauth_client_secret", "")
|
||||
current_value = getattr(
|
||||
self.env_manager.config, "microsoft_graph_oauth_client_secret", ""
|
||||
)
|
||||
input_widget = Input(
|
||||
placeholder="",
|
||||
value=current_value,
|
||||
password=True,
|
||||
id="input-microsoft_graph_oauth_client_secret"
|
||||
id="input-microsoft_graph_oauth_client_secret",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["microsoft_graph_oauth_client_secret"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# AWS Access Key ID
|
||||
yield Label("AWS Access Key ID")
|
||||
# Where to create AWS keys (helper above the box)
|
||||
yield Static(Text("Create keys: https://console.aws.amazon.com/iam/home#/security_credentials", style="dim"), classes="helper-text")
|
||||
yield Static(
|
||||
Text(
|
||||
"Create keys: https://console.aws.amazon.com/iam/home#/security_credentials",
|
||||
style="dim",
|
||||
),
|
||||
classes="helper-text",
|
||||
)
|
||||
current_value = getattr(self.env_manager.config, "aws_access_key_id", "")
|
||||
input_widget = Input(
|
||||
placeholder="",
|
||||
value=current_value,
|
||||
id="input-aws_access_key_id"
|
||||
placeholder="", value=current_value, id="input-aws_access_key_id"
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["aws_access_key_id"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# AWS Secret Access Key
|
||||
yield Label("AWS Secret Access Key")
|
||||
current_value = getattr(self.env_manager.config, "aws_secret_access_key", "")
|
||||
current_value = getattr(
|
||||
self.env_manager.config, "aws_secret_access_key", ""
|
||||
)
|
||||
input_widget = Input(
|
||||
placeholder="",
|
||||
value=current_value,
|
||||
password=True,
|
||||
id="input-aws_secret_access_key"
|
||||
id="input-aws_secret_access_key",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["aws_secret_access_key"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Other Settings Section
|
||||
yield Static("Others", classes="tab-header")
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Documents Paths (optional) + picker action button on next line
|
||||
yield Label("Documents Paths")
|
||||
current_value = getattr(self.env_manager.config, "openrag_documents_paths", "")
|
||||
|
|
@ -274,57 +316,63 @@ class ConfigScreen(Screen):
|
|||
placeholder="./documents,/path/to/more/docs",
|
||||
value=current_value,
|
||||
validators=[DocumentsPathValidator()],
|
||||
id="input-openrag_documents_paths"
|
||||
id="input-openrag_documents_paths",
|
||||
)
|
||||
yield input_widget
|
||||
# Actions row with pick button
|
||||
yield Horizontal(Button("Pick…", id="pick-docs-btn"), id="docs-path-actions", classes="controls-row")
|
||||
yield Horizontal(
|
||||
Button("Pick…", id="pick-docs-btn"),
|
||||
id="docs-path-actions",
|
||||
classes="controls-row",
|
||||
)
|
||||
self.inputs["openrag_documents_paths"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Langflow Auth Settings
|
||||
yield Static("Langflow Auth Settings", classes="tab-header")
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Langflow Auto Login
|
||||
yield Label("Langflow Auto Login")
|
||||
current_value = getattr(self.env_manager.config, "langflow_auto_login", "False")
|
||||
input_widget = Input(
|
||||
placeholder="False",
|
||||
value=current_value,
|
||||
id="input-langflow_auto_login"
|
||||
placeholder="False", value=current_value, id="input-langflow_auto_login"
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["langflow_auto_login"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Langflow New User Is Active
|
||||
yield Label("Langflow New User Is Active")
|
||||
current_value = getattr(self.env_manager.config, "langflow_new_user_is_active", "False")
|
||||
current_value = getattr(
|
||||
self.env_manager.config, "langflow_new_user_is_active", "False"
|
||||
)
|
||||
input_widget = Input(
|
||||
placeholder="False",
|
||||
value=current_value,
|
||||
id="input-langflow_new_user_is_active"
|
||||
id="input-langflow_new_user_is_active",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["langflow_new_user_is_active"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Langflow Enable Superuser CLI
|
||||
yield Label("Langflow Enable Superuser CLI")
|
||||
current_value = getattr(self.env_manager.config, "langflow_enable_superuser_cli", "False")
|
||||
current_value = getattr(
|
||||
self.env_manager.config, "langflow_enable_superuser_cli", "False"
|
||||
)
|
||||
input_widget = Input(
|
||||
placeholder="False",
|
||||
value=current_value,
|
||||
id="input-langflow_enable_superuser_cli"
|
||||
id="input-langflow_enable_superuser_cli",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["langflow_enable_superuser_cli"] = input_widget
|
||||
yield Static(" ")
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Langflow Secret Key removed from UI; generated automatically on save
|
||||
|
||||
|
||||
# Add optional fields only in full mode
|
||||
if self.mode == "full":
|
||||
# Webhook Base URL
|
||||
|
|
@ -333,36 +381,43 @@ class ConfigScreen(Screen):
|
|||
input_widget = Input(
|
||||
placeholder="https://your-domain.com",
|
||||
value=current_value,
|
||||
id="input-webhook_base_url"
|
||||
id="input-webhook_base_url",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["webhook_base_url"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
# Langflow Public URL
|
||||
yield Label("Langflow Public URL")
|
||||
current_value = getattr(self.env_manager.config, "langflow_public_url", "")
|
||||
input_widget = Input(
|
||||
placeholder="http://localhost:7860",
|
||||
value=current_value,
|
||||
id="input-langflow_public_url"
|
||||
id="input-langflow_public_url",
|
||||
)
|
||||
yield input_widget
|
||||
self.inputs["langflow_public_url"] = input_widget
|
||||
yield Static(" ")
|
||||
|
||||
def _create_field(self, field_name: str, display_name: str, placeholder: str, can_generate: bool, required: bool = False) -> ComposeResult:
|
||||
|
||||
def _create_field(
|
||||
self,
|
||||
field_name: str,
|
||||
display_name: str,
|
||||
placeholder: str,
|
||||
can_generate: bool,
|
||||
required: bool = False,
|
||||
) -> ComposeResult:
|
||||
"""Create a single form field."""
|
||||
# Create label
|
||||
label_text = f"{display_name}"
|
||||
if required:
|
||||
label_text += " *"
|
||||
|
||||
|
||||
yield Label(label_text)
|
||||
|
||||
|
||||
# Get current value
|
||||
current_value = getattr(self.env_manager.config, field_name, "")
|
||||
|
||||
|
||||
# Create input with appropriate validator
|
||||
if field_name == "openai_api_key":
|
||||
input_widget = Input(
|
||||
|
|
@ -370,35 +425,33 @@ class ConfigScreen(Screen):
|
|||
value=current_value,
|
||||
password=True,
|
||||
validators=[OpenAIKeyValidator()],
|
||||
id=f"input-{field_name}"
|
||||
id=f"input-{field_name}",
|
||||
)
|
||||
elif field_name == "openrag_documents_paths":
|
||||
input_widget = Input(
|
||||
placeholder=placeholder,
|
||||
value=current_value,
|
||||
validators=[DocumentsPathValidator()],
|
||||
id=f"input-{field_name}"
|
||||
id=f"input-{field_name}",
|
||||
)
|
||||
elif "password" in field_name or "secret" in field_name:
|
||||
input_widget = Input(
|
||||
placeholder=placeholder,
|
||||
value=current_value,
|
||||
password=True,
|
||||
id=f"input-{field_name}"
|
||||
id=f"input-{field_name}",
|
||||
)
|
||||
else:
|
||||
input_widget = Input(
|
||||
placeholder=placeholder,
|
||||
value=current_value,
|
||||
id=f"input-{field_name}"
|
||||
placeholder=placeholder, value=current_value, id=f"input-{field_name}"
|
||||
)
|
||||
|
||||
|
||||
yield input_widget
|
||||
self.inputs[field_name] = input_widget
|
||||
|
||||
|
||||
# Add spacing
|
||||
yield Static(" ")
|
||||
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""Initialize the screen when mounted."""
|
||||
# Focus the first input field
|
||||
|
|
@ -409,7 +462,7 @@ class ConfigScreen(Screen):
|
|||
inputs[0].focus()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button presses."""
|
||||
if event.button.id == "generate-btn":
|
||||
|
|
@ -420,43 +473,47 @@ class ConfigScreen(Screen):
|
|||
self.action_back()
|
||||
elif event.button.id == "pick-docs-btn":
|
||||
self.action_pick_documents_path()
|
||||
|
||||
|
||||
def action_generate(self) -> None:
|
||||
"""Generate secure passwords for admin accounts."""
|
||||
self.env_manager.setup_secure_defaults()
|
||||
|
||||
|
||||
# Update input fields with generated values
|
||||
for field_name, input_widget in self.inputs.items():
|
||||
if field_name in ["opensearch_password", "langflow_superuser_password"]:
|
||||
new_value = getattr(self.env_manager.config, field_name)
|
||||
input_widget.value = new_value
|
||||
|
||||
|
||||
self.notify("Generated secure passwords", severity="information")
|
||||
|
||||
|
||||
def action_save(self) -> None:
|
||||
"""Save the configuration."""
|
||||
# Update config from input fields
|
||||
for field_name, input_widget in self.inputs.items():
|
||||
setattr(self.env_manager.config, field_name, input_widget.value)
|
||||
|
||||
|
||||
# Validate the configuration
|
||||
if not self.env_manager.validate_config(self.mode):
|
||||
error_messages = []
|
||||
for field, error in self.env_manager.config.validation_errors.items():
|
||||
error_messages.append(f"{field}: {error}")
|
||||
|
||||
self.notify(f"Validation failed:\n" + "\n".join(error_messages[:3]), severity="error")
|
||||
|
||||
self.notify(
|
||||
f"Validation failed:\n" + "\n".join(error_messages[:3]),
|
||||
severity="error",
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# Save to file
|
||||
if self.env_manager.save_env_file():
|
||||
self.notify("Configuration saved successfully!", severity="information")
|
||||
# Switch to monitor screen
|
||||
from .monitor import MonitorScreen
|
||||
|
||||
self.app.push_screen(MonitorScreen())
|
||||
else:
|
||||
self.notify("Failed to save configuration", severity="error")
|
||||
|
||||
|
||||
def action_back(self) -> None:
|
||||
"""Go back to welcome screen."""
|
||||
self.app.pop_screen()
|
||||
|
|
@ -465,6 +522,7 @@ class ConfigScreen(Screen):
|
|||
"""Open textual-fspicker to select a path and append it to the input."""
|
||||
try:
|
||||
import importlib
|
||||
|
||||
fsp = importlib.import_module("textual_fspicker")
|
||||
except Exception:
|
||||
self.notify("textual-fspicker not available", severity="warning")
|
||||
|
|
@ -479,9 +537,13 @@ class ConfigScreen(Screen):
|
|||
start = Path(first).expanduser()
|
||||
|
||||
# Prefer SelectDirectory for directories; fallback to FileOpen
|
||||
PickerClass = getattr(fsp, "SelectDirectory", None) or getattr(fsp, "FileOpen", None)
|
||||
PickerClass = getattr(fsp, "SelectDirectory", None) or getattr(
|
||||
fsp, "FileOpen", None
|
||||
)
|
||||
if PickerClass is None:
|
||||
self.notify("No compatible picker found in textual-fspicker", severity="warning")
|
||||
self.notify(
|
||||
"No compatible picker found in textual-fspicker", severity="warning"
|
||||
)
|
||||
return
|
||||
try:
|
||||
picker = PickerClass(location=start)
|
||||
|
|
@ -523,7 +585,7 @@ class ConfigScreen(Screen):
|
|||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def on_input_changed(self, event: Input.Changed) -> None:
|
||||
"""Handle input changes for real-time validation feedback."""
|
||||
# This will trigger validation display in real-time
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from ..managers.container_manager import ContainerManager
|
|||
|
||||
class DiagnosticsScreen(Screen):
|
||||
"""Diagnostics screen for debugging OpenRAG."""
|
||||
|
||||
|
||||
CSS = """
|
||||
#diagnostics-log {
|
||||
border: solid $accent;
|
||||
|
|
@ -40,20 +40,20 @@ class DiagnosticsScreen(Screen):
|
|||
text-align: center;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
BINDINGS = [
|
||||
("escape", "back", "Back"),
|
||||
("r", "refresh", "Refresh"),
|
||||
("ctrl+c", "copy", "Copy to Clipboard"),
|
||||
("ctrl+s", "save", "Save to File"),
|
||||
]
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.container_manager = ContainerManager()
|
||||
self._logger = logging.getLogger("openrag.diagnostics")
|
||||
self._status_timer = None
|
||||
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Create the diagnostics screen layout."""
|
||||
yield Header()
|
||||
|
|
@ -66,24 +66,24 @@ class DiagnosticsScreen(Screen):
|
|||
yield Button("Copy to Clipboard", variant="default", id="copy-btn")
|
||||
yield Button("Save to File", variant="default", id="save-btn")
|
||||
yield Button("Back", variant="default", id="back-btn")
|
||||
|
||||
|
||||
# Status indicator for copy/save operations
|
||||
yield Static("", id="copy-status", classes="copy-indicator")
|
||||
|
||||
|
||||
with ScrollableContainer(id="diagnostics-scroll"):
|
||||
yield Log(id="diagnostics-log", highlight=True)
|
||||
yield Footer()
|
||||
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""Initialize the screen."""
|
||||
self.run_diagnostics()
|
||||
|
||||
|
||||
# Focus the first button (refresh-btn)
|
||||
try:
|
||||
self.query_one("#refresh-btn").focus()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button presses."""
|
||||
if event.button.id == "refresh-btn":
|
||||
|
|
@ -98,25 +98,26 @@ class DiagnosticsScreen(Screen):
|
|||
self.save_to_file()
|
||||
elif event.button.id == "back-btn":
|
||||
self.action_back()
|
||||
|
||||
|
||||
def action_refresh(self) -> None:
|
||||
"""Refresh diagnostics."""
|
||||
self.run_diagnostics()
|
||||
|
||||
|
||||
def action_copy(self) -> None:
|
||||
"""Copy log content to clipboard (keyboard shortcut)."""
|
||||
self.copy_to_clipboard()
|
||||
|
||||
|
||||
def copy_to_clipboard(self) -> None:
|
||||
"""Copy log content to clipboard."""
|
||||
try:
|
||||
log = self.query_one("#diagnostics-log", Log)
|
||||
content = "\n".join(str(line) for line in log.lines)
|
||||
status = self.query_one("#copy-status", Static)
|
||||
|
||||
|
||||
# Try to use pyperclip if available
|
||||
try:
|
||||
import pyperclip
|
||||
|
||||
pyperclip.copy(content)
|
||||
self.notify("Copied to clipboard", severity="information")
|
||||
status.update("✓ Content copied to clipboard")
|
||||
|
|
@ -124,23 +125,19 @@ class DiagnosticsScreen(Screen):
|
|||
return
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# Fallback to platform-specific clipboard commands
|
||||
import subprocess
|
||||
import platform
|
||||
|
||||
|
||||
system = platform.system()
|
||||
if system == "Darwin": # macOS
|
||||
process = subprocess.Popen(
|
||||
["pbcopy"], stdin=subprocess.PIPE, text=True
|
||||
)
|
||||
process = subprocess.Popen(["pbcopy"], stdin=subprocess.PIPE, text=True)
|
||||
process.communicate(input=content)
|
||||
self.notify("Copied to clipboard", severity="information")
|
||||
status.update("✓ Content copied to clipboard")
|
||||
elif system == "Windows":
|
||||
process = subprocess.Popen(
|
||||
["clip"], stdin=subprocess.PIPE, text=True
|
||||
)
|
||||
process = subprocess.Popen(["clip"], stdin=subprocess.PIPE, text=True)
|
||||
process.communicate(input=content)
|
||||
self.notify("Copied to clipboard", severity="information")
|
||||
status.update("✓ Content copied to clipboard")
|
||||
|
|
@ -150,7 +147,7 @@ class DiagnosticsScreen(Screen):
|
|||
process = subprocess.Popen(
|
||||
["xclip", "-selection", "clipboard"],
|
||||
stdin=subprocess.PIPE,
|
||||
text=True
|
||||
text=True,
|
||||
)
|
||||
process.communicate(input=content)
|
||||
self.notify("Copied to clipboard", severity="information")
|
||||
|
|
@ -160,65 +157,78 @@ class DiagnosticsScreen(Screen):
|
|||
process = subprocess.Popen(
|
||||
["xsel", "--clipboard", "--input"],
|
||||
stdin=subprocess.PIPE,
|
||||
text=True
|
||||
text=True,
|
||||
)
|
||||
process.communicate(input=content)
|
||||
self.notify("Copied to clipboard", severity="information")
|
||||
status.update("✓ Content copied to clipboard")
|
||||
except FileNotFoundError:
|
||||
self.notify("Clipboard utilities not found. Install xclip or xsel.", severity="error")
|
||||
status.update("❌ Clipboard utilities not found. Install xclip or xsel.")
|
||||
self.notify(
|
||||
"Clipboard utilities not found. Install xclip or xsel.",
|
||||
severity="error",
|
||||
)
|
||||
status.update(
|
||||
"❌ Clipboard utilities not found. Install xclip or xsel."
|
||||
)
|
||||
else:
|
||||
self.notify("Clipboard not supported on this platform", severity="error")
|
||||
self.notify(
|
||||
"Clipboard not supported on this platform", severity="error"
|
||||
)
|
||||
status.update("❌ Clipboard not supported on this platform")
|
||||
|
||||
|
||||
self._hide_status_after_delay(status)
|
||||
except Exception as e:
|
||||
self.notify(f"Failed to copy to clipboard: {e}", severity="error")
|
||||
status = self.query_one("#copy-status", Static)
|
||||
status.update(f"❌ Failed to copy: {e}")
|
||||
self._hide_status_after_delay(status)
|
||||
|
||||
def _hide_status_after_delay(self, status_widget: Static, delay: float = 3.0) -> None:
|
||||
|
||||
def _hide_status_after_delay(
|
||||
self, status_widget: Static, delay: float = 3.0
|
||||
) -> None:
|
||||
"""Hide the status message after a delay."""
|
||||
# Cancel any existing timer
|
||||
if self._status_timer:
|
||||
self._status_timer.cancel()
|
||||
|
||||
|
||||
# Create and run the timer task
|
||||
self._status_timer = asyncio.create_task(self._clear_status_after_delay(status_widget, delay))
|
||||
|
||||
async def _clear_status_after_delay(self, status_widget: Static, delay: float) -> None:
|
||||
self._status_timer = asyncio.create_task(
|
||||
self._clear_status_after_delay(status_widget, delay)
|
||||
)
|
||||
|
||||
async def _clear_status_after_delay(
|
||||
self, status_widget: Static, delay: float
|
||||
) -> None:
|
||||
"""Clear the status message after a delay."""
|
||||
await asyncio.sleep(delay)
|
||||
status_widget.update("")
|
||||
|
||||
|
||||
def action_save(self) -> None:
|
||||
"""Save log content to file (keyboard shortcut)."""
|
||||
self.save_to_file()
|
||||
|
||||
|
||||
def save_to_file(self) -> None:
|
||||
"""Save log content to a file."""
|
||||
try:
|
||||
log = self.query_one("#diagnostics-log", Log)
|
||||
content = "\n".join(str(line) for line in log.lines)
|
||||
status = self.query_one("#copy-status", Static)
|
||||
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
logs_dir = Path("logs")
|
||||
logs_dir.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
# Create a timestamped filename
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = logs_dir / f"openrag_diagnostics_{timestamp}.txt"
|
||||
|
||||
|
||||
# Save to file
|
||||
with open(filename, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
self.notify(f"Saved to {filename}", severity="information")
|
||||
status.update(f"✓ Saved to {filename}")
|
||||
|
||||
|
||||
# Log the save operation
|
||||
self._logger.info(f"Diagnostics saved to {filename}")
|
||||
self._hide_status_after_delay(status)
|
||||
|
|
@ -226,55 +236,57 @@ class DiagnosticsScreen(Screen):
|
|||
error_msg = f"Failed to save file: {e}"
|
||||
self.notify(error_msg, severity="error")
|
||||
self._logger.error(error_msg)
|
||||
|
||||
|
||||
status = self.query_one("#copy-status", Static)
|
||||
status.update(f"❌ {error_msg}")
|
||||
self._hide_status_after_delay(status)
|
||||
|
||||
|
||||
def action_back(self) -> None:
|
||||
"""Go back to previous screen."""
|
||||
self.app.pop_screen()
|
||||
|
||||
|
||||
def _get_system_info(self) -> Text:
|
||||
"""Get system information text."""
|
||||
info_text = Text()
|
||||
|
||||
|
||||
runtime_info = self.container_manager.get_runtime_info()
|
||||
|
||||
|
||||
info_text.append("Container Runtime Information\n", style="bold")
|
||||
info_text.append("=" * 30 + "\n")
|
||||
info_text.append(f"Type: {runtime_info.runtime_type.value}\n")
|
||||
info_text.append(f"Compose Command: {' '.join(runtime_info.compose_command)}\n")
|
||||
info_text.append(f"Runtime Command: {' '.join(runtime_info.runtime_command)}\n")
|
||||
|
||||
|
||||
if runtime_info.version:
|
||||
info_text.append(f"Version: {runtime_info.version}\n")
|
||||
|
||||
|
||||
return info_text
|
||||
|
||||
|
||||
def run_diagnostics(self) -> None:
|
||||
"""Run all diagnostics."""
|
||||
log = self.query_one("#diagnostics-log", Log)
|
||||
log.clear()
|
||||
|
||||
|
||||
# System information
|
||||
system_info = self._get_system_info()
|
||||
log.write(str(system_info))
|
||||
log.write("")
|
||||
|
||||
|
||||
# Run async diagnostics
|
||||
asyncio.create_task(self._run_async_diagnostics())
|
||||
|
||||
|
||||
async def _run_async_diagnostics(self) -> None:
|
||||
"""Run asynchronous diagnostics."""
|
||||
log = self.query_one("#diagnostics-log", Log)
|
||||
|
||||
|
||||
# Check services
|
||||
log.write("[bold green]Service Status[/bold green]")
|
||||
services = await self.container_manager.get_service_status(force_refresh=True)
|
||||
for name, info in services.items():
|
||||
status_color = "green" if info.status == "running" else "red"
|
||||
log.write(f"[bold]{name}[/bold]: [{status_color}]{info.status.value}[/{status_color}]")
|
||||
log.write(
|
||||
f"[bold]{name}[/bold]: [{status_color}]{info.status.value}[/{status_color}]"
|
||||
)
|
||||
if info.health:
|
||||
log.write(f" Health: {info.health}")
|
||||
if info.ports:
|
||||
|
|
@ -282,40 +294,38 @@ class DiagnosticsScreen(Screen):
|
|||
if info.image:
|
||||
log.write(f" Image: {info.image}")
|
||||
log.write("")
|
||||
|
||||
|
||||
# Check for Podman-specific issues
|
||||
if self.container_manager.runtime_info.runtime_type.name == "PODMAN":
|
||||
await self.check_podman()
|
||||
|
||||
|
||||
async def check_podman(self) -> None:
|
||||
"""Run Podman-specific diagnostics."""
|
||||
log = self.query_one("#diagnostics-log", Log)
|
||||
log.write("[bold green]Podman Diagnostics[/bold green]")
|
||||
|
||||
|
||||
# Check if using Podman
|
||||
if self.container_manager.runtime_info.runtime_type.name != "PODMAN":
|
||||
log.write("[yellow]Not using Podman[/yellow]")
|
||||
return
|
||||
|
||||
|
||||
# Check Podman version
|
||||
cmd = ["podman", "--version"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
if process.returncode == 0:
|
||||
log.write(f"Podman version: {stdout.decode().strip()}")
|
||||
else:
|
||||
log.write(f"[red]Failed to get Podman version: {stderr.decode().strip()}[/red]")
|
||||
|
||||
log.write(
|
||||
f"[red]Failed to get Podman version: {stderr.decode().strip()}[/red]"
|
||||
)
|
||||
|
||||
# Check Podman containers
|
||||
cmd = ["podman", "ps", "--all"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
if process.returncode == 0:
|
||||
|
|
@ -323,15 +333,17 @@ class DiagnosticsScreen(Screen):
|
|||
for line in stdout.decode().strip().split("\n"):
|
||||
log.write(f" {line}")
|
||||
else:
|
||||
log.write(f"[red]Failed to list Podman containers: {stderr.decode().strip()}[/red]")
|
||||
|
||||
log.write(
|
||||
f"[red]Failed to list Podman containers: {stderr.decode().strip()}[/red]"
|
||||
)
|
||||
|
||||
# Check Podman compose
|
||||
cmd = ["podman", "compose", "ps"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=self.container_manager.compose_file.parent
|
||||
cwd=self.container_manager.compose_file.parent,
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
if process.returncode == 0:
|
||||
|
|
@ -339,39 +351,39 @@ class DiagnosticsScreen(Screen):
|
|||
for line in stdout.decode().strip().split("\n"):
|
||||
log.write(f" {line}")
|
||||
else:
|
||||
log.write(f"[red]Failed to list Podman compose services: {stderr.decode().strip()}[/red]")
|
||||
|
||||
log.write(
|
||||
f"[red]Failed to list Podman compose services: {stderr.decode().strip()}[/red]"
|
||||
)
|
||||
|
||||
log.write("")
|
||||
|
||||
|
||||
async def check_docker(self) -> None:
|
||||
"""Run Docker-specific diagnostics."""
|
||||
log = self.query_one("#diagnostics-log", Log)
|
||||
log.write("[bold green]Docker Diagnostics[/bold green]")
|
||||
|
||||
|
||||
# Check if using Docker
|
||||
if "DOCKER" not in self.container_manager.runtime_info.runtime_type.name:
|
||||
log.write("[yellow]Not using Docker[/yellow]")
|
||||
return
|
||||
|
||||
|
||||
# Check Docker version
|
||||
cmd = ["docker", "--version"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
if process.returncode == 0:
|
||||
log.write(f"Docker version: {stdout.decode().strip()}")
|
||||
else:
|
||||
log.write(f"[red]Failed to get Docker version: {stderr.decode().strip()}[/red]")
|
||||
|
||||
log.write(
|
||||
f"[red]Failed to get Docker version: {stderr.decode().strip()}[/red]"
|
||||
)
|
||||
|
||||
# Check Docker containers
|
||||
cmd = ["docker", "ps", "--all"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
if process.returncode == 0:
|
||||
|
|
@ -379,15 +391,17 @@ class DiagnosticsScreen(Screen):
|
|||
for line in stdout.decode().strip().split("\n"):
|
||||
log.write(f" {line}")
|
||||
else:
|
||||
log.write(f"[red]Failed to list Docker containers: {stderr.decode().strip()}[/red]")
|
||||
|
||||
log.write(
|
||||
f"[red]Failed to list Docker containers: {stderr.decode().strip()}[/red]"
|
||||
)
|
||||
|
||||
# Check Docker compose
|
||||
cmd = ["docker", "compose", "ps"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=self.container_manager.compose_file.parent
|
||||
cwd=self.container_manager.compose_file.parent,
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
if process.returncode == 0:
|
||||
|
|
@ -395,8 +409,11 @@ class DiagnosticsScreen(Screen):
|
|||
for line in stdout.decode().strip().split("\n"):
|
||||
log.write(f" {line}")
|
||||
else:
|
||||
log.write(f"[red]Failed to list Docker compose services: {stderr.decode().strip()}[/red]")
|
||||
|
||||
log.write(
|
||||
f"[red]Failed to list Docker compose services: {stderr.decode().strip()}[/red]"
|
||||
)
|
||||
|
||||
log.write("")
|
||||
|
||||
|
||||
# Made with Bob
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from ..managers.container_manager import ContainerManager
|
|||
|
||||
class LogsScreen(Screen):
|
||||
"""Logs viewing and monitoring screen."""
|
||||
|
||||
|
||||
BINDINGS = [
|
||||
("escape", "back", "Back"),
|
||||
("f", "follow", "Follow Logs"),
|
||||
|
|
@ -27,44 +27,50 @@ class LogsScreen(Screen):
|
|||
("ctrl+u", "scroll_page_up", "Page Up"),
|
||||
("ctrl+f", "scroll_page_down", "Page Down"),
|
||||
]
|
||||
|
||||
|
||||
def __init__(self, initial_service: str = "openrag-backend"):
|
||||
super().__init__()
|
||||
self.container_manager = ContainerManager()
|
||||
|
||||
|
||||
# Validate the initial service against available options
|
||||
valid_services = ["openrag-backend", "openrag-frontend", "opensearch", "langflow", "dashboards"]
|
||||
valid_services = [
|
||||
"openrag-backend",
|
||||
"openrag-frontend",
|
||||
"opensearch",
|
||||
"langflow",
|
||||
"dashboards",
|
||||
]
|
||||
if initial_service not in valid_services:
|
||||
initial_service = "openrag-backend" # fallback
|
||||
|
||||
|
||||
self.current_service = initial_service
|
||||
self.logs_area = None
|
||||
self.following = False
|
||||
self.follow_task = None
|
||||
self.auto_scroll = True
|
||||
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Create the logs screen layout."""
|
||||
yield Container(
|
||||
Vertical(
|
||||
Static(f"Service Logs: {self.current_service}", id="logs-title"),
|
||||
self._create_logs_area(),
|
||||
id="logs-content"
|
||||
id="logs-content",
|
||||
),
|
||||
id="main-container"
|
||||
id="main-container",
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
|
||||
def _create_logs_area(self) -> TextArea:
|
||||
"""Create the logs text area."""
|
||||
self.logs_area = TextArea(
|
||||
text="Loading logs...",
|
||||
read_only=True,
|
||||
show_line_numbers=False,
|
||||
id="logs-area"
|
||||
id="logs-area",
|
||||
)
|
||||
return self.logs_area
|
||||
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
"""Initialize the screen when mounted."""
|
||||
# Set the correct service in the select widget after a brief delay
|
||||
|
|
@ -72,34 +78,40 @@ class LogsScreen(Screen):
|
|||
select = self.query_one("#service-select")
|
||||
# Set a default first, then set the desired value
|
||||
select.value = "openrag-backend"
|
||||
if self.current_service in ["openrag-backend", "openrag-frontend", "opensearch", "langflow", "dashboards"]:
|
||||
if self.current_service in [
|
||||
"openrag-backend",
|
||||
"openrag-frontend",
|
||||
"opensearch",
|
||||
"langflow",
|
||||
"dashboards",
|
||||
]:
|
||||
select.value = self.current_service
|
||||
except Exception as e:
|
||||
# If setting the service fails, just use the default
|
||||
pass
|
||||
|
||||
|
||||
await self._load_logs()
|
||||
|
||||
|
||||
# Focus the logs area since there are no buttons
|
||||
try:
|
||||
self.logs_area.focus()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def on_unmount(self) -> None:
|
||||
"""Clean up when unmounting."""
|
||||
self._stop_following()
|
||||
|
||||
|
||||
|
||||
|
||||
async def _load_logs(self, lines: int = 200) -> None:
|
||||
"""Load recent logs for the current service."""
|
||||
if not self.container_manager.is_available():
|
||||
self.logs_area.text = "No container runtime available"
|
||||
return
|
||||
|
||||
success, logs = await self.container_manager.get_service_logs(self.current_service, lines)
|
||||
|
||||
|
||||
success, logs = await self.container_manager.get_service_logs(
|
||||
self.current_service, lines
|
||||
)
|
||||
|
||||
if success:
|
||||
self.logs_area.text = logs
|
||||
# Scroll to bottom if auto scroll is enabled
|
||||
|
|
@ -107,67 +119,71 @@ class LogsScreen(Screen):
|
|||
self.logs_area.scroll_end()
|
||||
else:
|
||||
self.logs_area.text = f"Failed to load logs: {logs}"
|
||||
|
||||
|
||||
def _stop_following(self) -> None:
|
||||
"""Stop following logs."""
|
||||
self.following = False
|
||||
if self.follow_task and not self.follow_task.is_finished:
|
||||
self.follow_task.cancel()
|
||||
|
||||
|
||||
# No button to update since we removed it
|
||||
|
||||
|
||||
async def _follow_logs(self) -> None:
|
||||
"""Follow logs in real-time."""
|
||||
if not self.container_manager.is_available():
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
async for log_line in self.container_manager.follow_service_logs(self.current_service):
|
||||
async for log_line in self.container_manager.follow_service_logs(
|
||||
self.current_service
|
||||
):
|
||||
if not self.following:
|
||||
break
|
||||
|
||||
|
||||
# Append new line to logs area
|
||||
current_text = self.logs_area.text
|
||||
new_text = current_text + "\n" + log_line
|
||||
|
||||
|
||||
# Keep only last 1000 lines to prevent memory issues
|
||||
lines = new_text.split('\n')
|
||||
lines = new_text.split("\n")
|
||||
if len(lines) > 1000:
|
||||
lines = lines[-1000:]
|
||||
new_text = '\n'.join(lines)
|
||||
|
||||
new_text = "\n".join(lines)
|
||||
|
||||
self.logs_area.text = new_text
|
||||
# Scroll to bottom if auto scroll is enabled
|
||||
if self.auto_scroll:
|
||||
self.logs_area.scroll_end()
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
if self.following: # Only show error if we're still supposed to be following
|
||||
if (
|
||||
self.following
|
||||
): # Only show error if we're still supposed to be following
|
||||
self.notify(f"Error following logs: {e}", severity="error")
|
||||
finally:
|
||||
self.following = False
|
||||
|
||||
|
||||
def action_refresh(self) -> None:
|
||||
"""Refresh logs."""
|
||||
self._stop_following()
|
||||
self.run_worker(self._load_logs())
|
||||
|
||||
|
||||
def action_follow(self) -> None:
|
||||
"""Toggle log following."""
|
||||
if self.following:
|
||||
self._stop_following()
|
||||
else:
|
||||
self.following = True
|
||||
|
||||
|
||||
# Start following
|
||||
self.follow_task = self.run_worker(self._follow_logs(), exclusive=False)
|
||||
|
||||
|
||||
def action_clear(self) -> None:
|
||||
"""Clear the logs area."""
|
||||
self.logs_area.text = ""
|
||||
|
||||
|
||||
def action_toggle_auto_scroll(self) -> None:
|
||||
"""Toggle auto scroll on/off."""
|
||||
self.auto_scroll = not self.auto_scroll
|
||||
|
|
@ -201,13 +217,13 @@ class LogsScreen(Screen):
|
|||
def on_key(self, event) -> None:
|
||||
"""Handle key presses that might be intercepted by TextArea."""
|
||||
key = event.key
|
||||
|
||||
|
||||
# Handle keys that TextArea might intercept
|
||||
if key == "ctrl+u":
|
||||
self.action_scroll_page_up()
|
||||
event.prevent_default()
|
||||
elif key == "ctrl+f":
|
||||
self.action_scroll_page_down()
|
||||
self.action_scroll_page_down()
|
||||
event.prevent_default()
|
||||
elif key.upper() == "G":
|
||||
self.action_scroll_bottom()
|
||||
|
|
@ -216,4 +232,4 @@ class LogsScreen(Screen):
|
|||
def action_back(self) -> None:
|
||||
"""Go back to previous screen."""
|
||||
self._stop_following()
|
||||
self.app.pop_screen()
|
||||
self.app.pop_screen()
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from ..widgets.diagnostics_notification import notify_with_diagnostics
|
|||
|
||||
class MonitorScreen(Screen):
|
||||
"""Service monitoring and control screen."""
|
||||
|
||||
|
||||
BINDINGS = [
|
||||
("escape", "back", "Back"),
|
||||
("r", "refresh", "Refresh"),
|
||||
|
|
@ -35,7 +35,7 @@ class MonitorScreen(Screen):
|
|||
("j", "cursor_down", "Move Down"),
|
||||
("k", "cursor_up", "Move Up"),
|
||||
]
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.container_manager = ContainerManager()
|
||||
|
|
@ -47,14 +47,14 @@ class MonitorScreen(Screen):
|
|||
self._follow_task = None
|
||||
self._follow_service = None
|
||||
self._logs_buffer = []
|
||||
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Create the monitoring screen layout."""
|
||||
# Just show the services content directly (no header, no tabs)
|
||||
yield from self._create_services_tab()
|
||||
|
||||
|
||||
yield Footer()
|
||||
|
||||
|
||||
def _create_services_tab(self) -> ComposeResult:
|
||||
"""Create the services monitoring tab."""
|
||||
# Current mode indicator + toggle
|
||||
|
|
@ -75,69 +75,73 @@ class MonitorScreen(Screen):
|
|||
yield Horizontal(id="services-controls", classes="button-row")
|
||||
# Create services table with image + digest info
|
||||
self.services_table = DataTable(id="services-table")
|
||||
self.services_table.add_columns("Service", "Status", "Health", "Ports", "Image", "Digest")
|
||||
self.services_table.add_columns(
|
||||
"Service", "Status", "Health", "Ports", "Image", "Digest"
|
||||
)
|
||||
yield self.services_table
|
||||
|
||||
|
||||
|
||||
def _get_runtime_status(self) -> Text:
|
||||
"""Get container runtime status text."""
|
||||
status_text = Text()
|
||||
|
||||
|
||||
if not self.container_manager.is_available():
|
||||
status_text.append("WARNING: No container runtime available\n", style="bold red")
|
||||
status_text.append("Please install Docker or Podman to continue.\n", style="dim")
|
||||
status_text.append(
|
||||
"WARNING: No container runtime available\n", style="bold red"
|
||||
)
|
||||
status_text.append(
|
||||
"Please install Docker or Podman to continue.\n", style="dim"
|
||||
)
|
||||
return status_text
|
||||
|
||||
|
||||
runtime_info = self.container_manager.get_runtime_info()
|
||||
|
||||
|
||||
if runtime_info.runtime_type == RuntimeType.DOCKER:
|
||||
status_text.append("Docker Runtime\n", style="bold blue")
|
||||
elif runtime_info.runtime_type == RuntimeType.PODMAN:
|
||||
status_text.append("Podman Runtime\n", style="bold purple")
|
||||
else:
|
||||
status_text.append("Container Runtime\n", style="bold green")
|
||||
|
||||
|
||||
if runtime_info.version:
|
||||
status_text.append(f"Version: {runtime_info.version}\n", style="dim")
|
||||
|
||||
|
||||
# Check Podman macOS memory if applicable
|
||||
if runtime_info.runtime_type == RuntimeType.PODMAN:
|
||||
is_sufficient, message = self.container_manager.check_podman_macos_memory()
|
||||
if not is_sufficient:
|
||||
status_text.append(f"WARNING: {message}\n", style="bold yellow")
|
||||
|
||||
|
||||
return status_text
|
||||
|
||||
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
"""Initialize the screen when mounted."""
|
||||
await self._refresh_services()
|
||||
# Set up auto-refresh every 5 seconds
|
||||
self.refresh_timer = self.set_interval(5.0, self._auto_refresh)
|
||||
|
||||
|
||||
# Focus the services table
|
||||
try:
|
||||
self.services_table.focus()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def on_unmount(self) -> None:
|
||||
"""Clean up when unmounting."""
|
||||
if self.refresh_timer:
|
||||
self.refresh_timer.stop()
|
||||
# Stop following logs if running
|
||||
self._stop_follow()
|
||||
|
||||
|
||||
async def on_screen_resume(self) -> None:
|
||||
"""Called when the screen is resumed (e.g., after a modal is closed)."""
|
||||
# Refresh services when returning from a modal
|
||||
await self._refresh_services()
|
||||
|
||||
|
||||
async def _refresh_services(self) -> None:
|
||||
"""Refresh the services table."""
|
||||
if not self.container_manager.is_available():
|
||||
return
|
||||
|
||||
|
||||
services = await self.container_manager.get_service_status(force_refresh=True)
|
||||
# Collect images actually reported by running/stopped containers so names match runtime
|
||||
images_set = set()
|
||||
|
|
@ -147,7 +151,9 @@ class MonitorScreen(Screen):
|
|||
images_set.add(img)
|
||||
# Ensure compose-declared images are also shown (e.g., langflow when stopped)
|
||||
try:
|
||||
for img in self.container_manager._parse_compose_images(): # best-effort, no YAML dep
|
||||
for img in (
|
||||
self.container_manager._parse_compose_images()
|
||||
): # best-effort, no YAML dep
|
||||
if img:
|
||||
images_set.add(img)
|
||||
except Exception:
|
||||
|
|
@ -155,23 +161,23 @@ class MonitorScreen(Screen):
|
|||
images = list(images_set)
|
||||
# Lookup digests/IDs for these image names
|
||||
digest_map = await self.container_manager.get_images_digests(images)
|
||||
|
||||
|
||||
# Clear existing rows
|
||||
self.services_table.clear()
|
||||
if self.images_table:
|
||||
self.images_table.clear()
|
||||
|
||||
|
||||
# Add service rows
|
||||
for service_name, service_info in services.items():
|
||||
status_style = self._get_status_style(service_info.status)
|
||||
|
||||
|
||||
self.services_table.add_row(
|
||||
service_info.name,
|
||||
Text(service_info.status.value, style=status_style),
|
||||
service_info.health or "N/A",
|
||||
", ".join(service_info.ports) if service_info.ports else "N/A",
|
||||
service_info.image or "N/A",
|
||||
digest_map.get(service_info.image or "", "-")
|
||||
digest_map.get(service_info.image or "", "-"),
|
||||
)
|
||||
# Populate images table (unique images as reported by runtime)
|
||||
if self.images_table:
|
||||
|
|
@ -181,7 +187,7 @@ class MonitorScreen(Screen):
|
|||
self._update_controls(list(services.values()))
|
||||
# Update mode indicator
|
||||
self._update_mode_row()
|
||||
|
||||
|
||||
def _get_status_style(self, status: ServiceStatus) -> str:
|
||||
"""Get the Rich style for a service status."""
|
||||
status_styles = {
|
||||
|
|
@ -191,20 +197,20 @@ class MonitorScreen(Screen):
|
|||
ServiceStatus.STOPPING: "bold yellow",
|
||||
ServiceStatus.ERROR: "bold red",
|
||||
ServiceStatus.MISSING: "dim",
|
||||
ServiceStatus.UNKNOWN: "dim"
|
||||
ServiceStatus.UNKNOWN: "dim",
|
||||
}
|
||||
return status_styles.get(status, "white")
|
||||
|
||||
|
||||
async def _auto_refresh(self) -> None:
|
||||
"""Auto-refresh services if not in operation."""
|
||||
if not self.operation_in_progress:
|
||||
await self._refresh_services()
|
||||
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button presses."""
|
||||
button_id = event.button.id or ""
|
||||
button_label = event.button.label or ""
|
||||
|
||||
|
||||
# Use button ID prefixes to determine action, ignoring any random suffix
|
||||
if button_id.startswith("start-btn"):
|
||||
self.run_worker(self._start_services())
|
||||
|
|
@ -228,18 +234,18 @@ class MonitorScreen(Screen):
|
|||
"logs-backend": "openrag-backend",
|
||||
"logs-frontend": "openrag-frontend",
|
||||
"logs-opensearch": "opensearch",
|
||||
"logs-langflow": "langflow"
|
||||
"logs-langflow": "langflow",
|
||||
}
|
||||
|
||||
|
||||
# Extract the base button ID (without any suffix)
|
||||
button_base_id = button_id.split("-")[0] + "-" + button_id.split("-")[1]
|
||||
|
||||
|
||||
service_name = service_mapping.get(button_base_id)
|
||||
if service_name:
|
||||
# Load recent logs then start following
|
||||
self.run_worker(self._show_logs(service_name))
|
||||
self._start_follow(service_name)
|
||||
|
||||
|
||||
async def _start_services(self, cpu_mode: bool = False) -> None:
|
||||
"""Start services with progress updates."""
|
||||
self.operation_in_progress = True
|
||||
|
|
@ -249,12 +255,12 @@ class MonitorScreen(Screen):
|
|||
modal = CommandOutputModal(
|
||||
"Starting Services",
|
||||
command_generator,
|
||||
on_complete=None # We'll refresh in on_screen_resume instead
|
||||
on_complete=None, # We'll refresh in on_screen_resume instead
|
||||
)
|
||||
self.app.push_screen(modal)
|
||||
finally:
|
||||
self.operation_in_progress = False
|
||||
|
||||
|
||||
async def _stop_services(self) -> None:
|
||||
"""Stop services with progress updates."""
|
||||
self.operation_in_progress = True
|
||||
|
|
@ -264,12 +270,12 @@ class MonitorScreen(Screen):
|
|||
modal = CommandOutputModal(
|
||||
"Stopping Services",
|
||||
command_generator,
|
||||
on_complete=None # We'll refresh in on_screen_resume instead
|
||||
on_complete=None, # We'll refresh in on_screen_resume instead
|
||||
)
|
||||
self.app.push_screen(modal)
|
||||
finally:
|
||||
self.operation_in_progress = False
|
||||
|
||||
|
||||
async def _restart_services(self) -> None:
|
||||
"""Restart services with progress updates."""
|
||||
self.operation_in_progress = True
|
||||
|
|
@ -279,12 +285,12 @@ class MonitorScreen(Screen):
|
|||
modal = CommandOutputModal(
|
||||
"Restarting Services",
|
||||
command_generator,
|
||||
on_complete=None # We'll refresh in on_screen_resume instead
|
||||
on_complete=None, # We'll refresh in on_screen_resume instead
|
||||
)
|
||||
self.app.push_screen(modal)
|
||||
finally:
|
||||
self.operation_in_progress = False
|
||||
|
||||
|
||||
async def _upgrade_services(self) -> None:
|
||||
"""Upgrade services with progress updates."""
|
||||
self.operation_in_progress = True
|
||||
|
|
@ -294,12 +300,12 @@ class MonitorScreen(Screen):
|
|||
modal = CommandOutputModal(
|
||||
"Upgrading Services",
|
||||
command_generator,
|
||||
on_complete=None # We'll refresh in on_screen_resume instead
|
||||
on_complete=None, # We'll refresh in on_screen_resume instead
|
||||
)
|
||||
self.app.push_screen(modal)
|
||||
finally:
|
||||
self.operation_in_progress = False
|
||||
|
||||
|
||||
async def _reset_services(self) -> None:
|
||||
"""Reset services with progress updates."""
|
||||
self.operation_in_progress = True
|
||||
|
|
@ -309,17 +315,17 @@ class MonitorScreen(Screen):
|
|||
modal = CommandOutputModal(
|
||||
"Resetting Services",
|
||||
command_generator,
|
||||
on_complete=None # We'll refresh in on_screen_resume instead
|
||||
on_complete=None, # We'll refresh in on_screen_resume instead
|
||||
)
|
||||
self.app.push_screen(modal)
|
||||
finally:
|
||||
self.operation_in_progress = False
|
||||
|
||||
|
||||
def _strip_ansi_codes(self, text: str) -> str:
|
||||
"""Strip ANSI escape sequences from text."""
|
||||
ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
|
||||
return ansi_escape.sub('', text)
|
||||
|
||||
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
|
||||
return ansi_escape.sub("", text)
|
||||
|
||||
async def _show_logs(self, service_name: str) -> None:
|
||||
"""Show logs for a service."""
|
||||
success, logs = await self.container_manager.get_service_logs(service_name)
|
||||
|
|
@ -346,7 +352,7 @@ class MonitorScreen(Screen):
|
|||
notify_with_diagnostics(
|
||||
self.app,
|
||||
f"Failed to get logs for {service_name}: {logs}",
|
||||
severity="error"
|
||||
severity="error",
|
||||
)
|
||||
|
||||
def _stop_follow(self) -> None:
|
||||
|
|
@ -391,11 +397,9 @@ class MonitorScreen(Screen):
|
|||
pass
|
||||
except Exception as e:
|
||||
notify_with_diagnostics(
|
||||
self.app,
|
||||
f"Error following logs: {e}",
|
||||
severity="error"
|
||||
self.app, f"Error following logs: {e}", severity="error"
|
||||
)
|
||||
|
||||
|
||||
def action_refresh(self) -> None:
|
||||
"""Refresh services manually."""
|
||||
self.run_worker(self._refresh_services())
|
||||
|
|
@ -431,14 +435,15 @@ class MonitorScreen(Screen):
|
|||
try:
|
||||
current = getattr(self.container_manager, "use_cpu_compose", True)
|
||||
self.container_manager.use_cpu_compose = not current
|
||||
self.notify("Switched to GPU compose" if not current else "Switched to CPU compose", severity="information")
|
||||
self.notify(
|
||||
"Switched to GPU compose" if not current else "Switched to CPU compose",
|
||||
severity="information",
|
||||
)
|
||||
self._update_mode_row()
|
||||
self.action_refresh()
|
||||
except Exception as e:
|
||||
notify_with_diagnostics(
|
||||
self.app,
|
||||
f"Failed to toggle mode: {e}",
|
||||
severity="error"
|
||||
self.app, f"Failed to toggle mode: {e}", severity="error"
|
||||
)
|
||||
|
||||
def _update_controls(self, services: list[ServiceInfo]) -> None:
|
||||
|
|
@ -446,83 +451,93 @@ class MonitorScreen(Screen):
|
|||
try:
|
||||
# Get the controls container
|
||||
controls = self.query_one("#services-controls", Horizontal)
|
||||
|
||||
|
||||
# Check if any services are running
|
||||
any_running = any(s.status == ServiceStatus.RUNNING for s in services)
|
||||
|
||||
|
||||
# Clear existing buttons by removing all children
|
||||
controls.remove_children()
|
||||
|
||||
|
||||
# Use a single ID for each button type, but make them unique with a suffix
|
||||
# This ensures we don't create duplicate IDs across refreshes
|
||||
import random
|
||||
|
||||
suffix = f"-{random.randint(10000, 99999)}"
|
||||
|
||||
|
||||
# Add appropriate buttons based on service state
|
||||
if any_running:
|
||||
# When services are running, show stop and restart
|
||||
controls.mount(Button("Stop Services", variant="error", id=f"stop-btn{suffix}"))
|
||||
controls.mount(Button("Restart", variant="primary", id=f"restart-btn{suffix}"))
|
||||
controls.mount(
|
||||
Button("Stop Services", variant="error", id=f"stop-btn{suffix}")
|
||||
)
|
||||
controls.mount(
|
||||
Button("Restart", variant="primary", id=f"restart-btn{suffix}")
|
||||
)
|
||||
else:
|
||||
# When services are not running, show start
|
||||
controls.mount(Button("Start Services", variant="success", id=f"start-btn{suffix}"))
|
||||
|
||||
controls.mount(
|
||||
Button("Start Services", variant="success", id=f"start-btn{suffix}")
|
||||
)
|
||||
|
||||
# Always show upgrade and reset buttons
|
||||
controls.mount(Button("Upgrade", variant="warning", id=f"upgrade-btn{suffix}"))
|
||||
controls.mount(
|
||||
Button("Upgrade", variant="warning", id=f"upgrade-btn{suffix}")
|
||||
)
|
||||
controls.mount(Button("Reset", variant="error", id=f"reset-btn{suffix}"))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
notify_with_diagnostics(
|
||||
self.app,
|
||||
f"Error updating controls: {e}",
|
||||
severity="error"
|
||||
self.app, f"Error updating controls: {e}", severity="error"
|
||||
)
|
||||
|
||||
|
||||
def action_back(self) -> None:
|
||||
"""Go back to previous screen."""
|
||||
self.app.pop_screen()
|
||||
|
||||
|
||||
def action_start(self) -> None:
|
||||
"""Start services."""
|
||||
self.run_worker(self._start_services())
|
||||
|
||||
|
||||
def action_stop(self) -> None:
|
||||
"""Stop services."""
|
||||
self.run_worker(self._stop_services())
|
||||
|
||||
|
||||
def action_upgrade(self) -> None:
|
||||
"""Upgrade services."""
|
||||
self.run_worker(self._upgrade_services())
|
||||
|
||||
|
||||
def action_reset(self) -> None:
|
||||
"""Reset services."""
|
||||
self.run_worker(self._reset_services())
|
||||
|
||||
|
||||
def action_logs(self) -> None:
|
||||
"""View logs for the selected service."""
|
||||
try:
|
||||
# Get the currently focused row in the services table
|
||||
table = self.query_one("#services-table", DataTable)
|
||||
|
||||
|
||||
if table.cursor_row is not None and table.cursor_row >= 0:
|
||||
# Get the service name from the first column of the selected row
|
||||
row_data = table.get_row_at(table.cursor_row)
|
||||
if row_data:
|
||||
service_name = str(row_data[0]) # First column is service name
|
||||
|
||||
|
||||
# Map display names to actual service names
|
||||
service_mapping = {
|
||||
"openrag-backend": "openrag-backend",
|
||||
"openrag-frontend": "openrag-frontend",
|
||||
"openrag-frontend": "openrag-frontend",
|
||||
"opensearch": "opensearch",
|
||||
"langflow": "langflow",
|
||||
"dashboards": "dashboards"
|
||||
"dashboards": "dashboards",
|
||||
}
|
||||
|
||||
actual_service_name = service_mapping.get(service_name, service_name)
|
||||
|
||||
|
||||
actual_service_name = service_mapping.get(
|
||||
service_name, service_name
|
||||
)
|
||||
|
||||
# Push the logs screen with the selected service
|
||||
from .logs import LogsScreen
|
||||
|
||||
logs_screen = LogsScreen(initial_service=actual_service_name)
|
||||
self.app.push_screen(logs_screen)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from ..managers.env_manager import EnvManager
|
|||
|
||||
class WelcomeScreen(Screen):
|
||||
"""Initial welcome screen with setup options."""
|
||||
|
||||
|
||||
BINDINGS = [
|
||||
("q", "quit", "Quit"),
|
||||
("enter", "default_action", "Continue"),
|
||||
|
|
@ -25,7 +25,7 @@ class WelcomeScreen(Screen):
|
|||
("3", "monitor", "Monitor Services"),
|
||||
("4", "diagnostics", "Diagnostics"),
|
||||
]
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.container_manager = ContainerManager()
|
||||
|
|
@ -34,19 +34,19 @@ class WelcomeScreen(Screen):
|
|||
self.has_oauth_config = False
|
||||
self.default_button_id = "basic-setup-btn"
|
||||
self._state_checked = False
|
||||
|
||||
|
||||
# Load .env file if it exists
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""Create the welcome screen layout."""
|
||||
yield Container(
|
||||
Vertical(
|
||||
Static(self._create_welcome_text(), id="welcome-text"),
|
||||
self._create_dynamic_buttons(),
|
||||
id="welcome-container"
|
||||
id="welcome-container",
|
||||
),
|
||||
id="main-container"
|
||||
id="main-container",
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
|
|
@ -65,55 +65,67 @@ class WelcomeScreen(Screen):
|
|||
welcome_text.append("Terminal User Interface for OpenRAG\n\n", style="dim")
|
||||
|
||||
if self.services_running:
|
||||
welcome_text.append("✓ Services are currently running\n\n", style="bold green")
|
||||
welcome_text.append(
|
||||
"✓ Services are currently running\n\n", style="bold green"
|
||||
)
|
||||
elif self.has_oauth_config:
|
||||
welcome_text.append("OAuth credentials detected — Advanced Setup recommended\n\n", style="bold green")
|
||||
welcome_text.append(
|
||||
"OAuth credentials detected — Advanced Setup recommended\n\n",
|
||||
style="bold green",
|
||||
)
|
||||
else:
|
||||
welcome_text.append("Select a setup below to continue\n\n", style="white")
|
||||
return welcome_text
|
||||
|
||||
|
||||
def _create_dynamic_buttons(self) -> Horizontal:
|
||||
"""Create buttons based on current state."""
|
||||
# Check OAuth config early to determine which buttons to show
|
||||
has_oauth = (
|
||||
bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or
|
||||
bool(os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID"))
|
||||
has_oauth = bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or bool(
|
||||
os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID")
|
||||
)
|
||||
|
||||
|
||||
buttons = []
|
||||
|
||||
|
||||
if self.services_running:
|
||||
# Services running - only show monitor
|
||||
buttons.append(Button("Monitor Services", variant="success", id="monitor-btn"))
|
||||
buttons.append(
|
||||
Button("Monitor Services", variant="success", id="monitor-btn")
|
||||
)
|
||||
else:
|
||||
# Services not running - show setup options
|
||||
if has_oauth:
|
||||
# Only show advanced setup if OAuth is configured
|
||||
buttons.append(Button("Advanced Setup", variant="success", id="advanced-setup-btn"))
|
||||
buttons.append(
|
||||
Button("Advanced Setup", variant="success", id="advanced-setup-btn")
|
||||
)
|
||||
else:
|
||||
# Only show basic setup if no OAuth
|
||||
buttons.append(Button("Basic Setup", variant="success", id="basic-setup-btn"))
|
||||
|
||||
buttons.append(
|
||||
Button("Basic Setup", variant="success", id="basic-setup-btn")
|
||||
)
|
||||
|
||||
# Always show monitor option
|
||||
buttons.append(Button("Monitor Services", variant="default", id="monitor-btn"))
|
||||
|
||||
buttons.append(
|
||||
Button("Monitor Services", variant="default", id="monitor-btn")
|
||||
)
|
||||
|
||||
return Horizontal(*buttons, classes="button-row")
|
||||
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
"""Initialize screen state when mounted."""
|
||||
# Check if services are running
|
||||
if self.container_manager.is_available():
|
||||
services = await self.container_manager.get_service_status()
|
||||
running_services = [s.name for s in services.values() if s.status == ServiceStatus.RUNNING]
|
||||
running_services = [
|
||||
s.name for s in services.values() if s.status == ServiceStatus.RUNNING
|
||||
]
|
||||
self.services_running = len(running_services) > 0
|
||||
|
||||
|
||||
|
||||
# Check for OAuth configuration
|
||||
self.has_oauth_config = (
|
||||
bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or
|
||||
bool(os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID"))
|
||||
self.has_oauth_config = bool(os.getenv("GOOGLE_OAUTH_CLIENT_ID")) or bool(
|
||||
os.getenv("MICROSOFT_GRAPH_OAUTH_CLIENT_ID")
|
||||
)
|
||||
|
||||
|
||||
# Set default button focus
|
||||
if self.services_running:
|
||||
self.default_button_id = "monitor-btn"
|
||||
|
|
@ -121,12 +133,14 @@ class WelcomeScreen(Screen):
|
|||
self.default_button_id = "advanced-setup-btn"
|
||||
else:
|
||||
self.default_button_id = "basic-setup-btn"
|
||||
|
||||
|
||||
# Update the welcome text and recompose with new state
|
||||
try:
|
||||
welcome_widget = self.query_one("#welcome-text")
|
||||
welcome_widget.update(self._create_welcome_text()) # This is fine for Static widgets
|
||||
|
||||
welcome_widget.update(
|
||||
self._create_welcome_text()
|
||||
) # This is fine for Static widgets
|
||||
|
||||
# Focus the appropriate button
|
||||
if self.services_running:
|
||||
try:
|
||||
|
|
@ -143,10 +157,10 @@ class WelcomeScreen(Screen):
|
|||
self.query_one("#basic-setup-btn").focus()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
except:
|
||||
pass # Widgets might not be mounted yet
|
||||
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""Handle button presses."""
|
||||
if event.button.id == "basic-setup-btn":
|
||||
|
|
@ -157,7 +171,7 @@ class WelcomeScreen(Screen):
|
|||
self.action_monitor()
|
||||
elif event.button.id == "diagnostics-btn":
|
||||
self.action_diagnostics()
|
||||
|
||||
|
||||
def action_default_action(self) -> None:
|
||||
"""Handle Enter key - go to default action based on state."""
|
||||
if self.services_running:
|
||||
|
|
@ -166,27 +180,31 @@ class WelcomeScreen(Screen):
|
|||
self.action_full_setup()
|
||||
else:
|
||||
self.action_no_auth_setup()
|
||||
|
||||
|
||||
def action_no_auth_setup(self) -> None:
|
||||
"""Switch to basic configuration screen."""
|
||||
from .config import ConfigScreen
|
||||
|
||||
self.app.push_screen(ConfigScreen(mode="no_auth"))
|
||||
|
||||
|
||||
def action_full_setup(self) -> None:
|
||||
"""Switch to advanced configuration screen."""
|
||||
from .config import ConfigScreen
|
||||
|
||||
self.app.push_screen(ConfigScreen(mode="full"))
|
||||
|
||||
|
||||
def action_monitor(self) -> None:
|
||||
"""Switch to monitoring screen."""
|
||||
from .monitor import MonitorScreen
|
||||
|
||||
self.app.push_screen(MonitorScreen())
|
||||
|
||||
|
||||
def action_diagnostics(self) -> None:
|
||||
"""Switch to diagnostics screen."""
|
||||
from .diagnostics import DiagnosticsScreen
|
||||
|
||||
self.app.push_screen(DiagnosticsScreen())
|
||||
|
||||
|
||||
def action_quit(self) -> None:
|
||||
"""Quit the application."""
|
||||
self.app.exit()
|
||||
self.app.exit()
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
"""TUI utilities package."""
|
||||
"""TUI utilities package."""
|
||||
|
|
|
|||
|
|
@ -34,40 +34,66 @@ class PlatformDetector:
|
|||
"""Detect available container runtime and compose capabilities."""
|
||||
# First check if we have podman installed
|
||||
podman_version = self._get_podman_version()
|
||||
|
||||
|
||||
# If we have podman, check if docker is actually podman in disguise
|
||||
if podman_version:
|
||||
docker_version = self._get_docker_version()
|
||||
if docker_version and podman_version in docker_version:
|
||||
# This is podman masquerading as docker
|
||||
if self._check_command(["docker", "compose", "--help"]):
|
||||
return RuntimeInfo(RuntimeType.PODMAN, ["docker", "compose"], ["docker"], podman_version)
|
||||
return RuntimeInfo(
|
||||
RuntimeType.PODMAN,
|
||||
["docker", "compose"],
|
||||
["docker"],
|
||||
podman_version,
|
||||
)
|
||||
if self._check_command(["docker-compose", "--help"]):
|
||||
return RuntimeInfo(RuntimeType.PODMAN, ["docker-compose"], ["docker"], podman_version)
|
||||
|
||||
return RuntimeInfo(
|
||||
RuntimeType.PODMAN,
|
||||
["docker-compose"],
|
||||
["docker"],
|
||||
podman_version,
|
||||
)
|
||||
|
||||
# Check for native podman compose
|
||||
if self._check_command(["podman", "compose", "--help"]):
|
||||
return RuntimeInfo(RuntimeType.PODMAN, ["podman", "compose"], ["podman"], podman_version)
|
||||
|
||||
return RuntimeInfo(
|
||||
RuntimeType.PODMAN,
|
||||
["podman", "compose"],
|
||||
["podman"],
|
||||
podman_version,
|
||||
)
|
||||
|
||||
# Check for actual docker
|
||||
if self._check_command(["docker", "compose", "--help"]):
|
||||
version = self._get_docker_version()
|
||||
return RuntimeInfo(RuntimeType.DOCKER, ["docker", "compose"], ["docker"], version)
|
||||
return RuntimeInfo(
|
||||
RuntimeType.DOCKER, ["docker", "compose"], ["docker"], version
|
||||
)
|
||||
if self._check_command(["docker-compose", "--help"]):
|
||||
version = self._get_docker_version()
|
||||
return RuntimeInfo(RuntimeType.DOCKER_COMPOSE, ["docker-compose"], ["docker"], version)
|
||||
|
||||
return RuntimeInfo(
|
||||
RuntimeType.DOCKER_COMPOSE, ["docker-compose"], ["docker"], version
|
||||
)
|
||||
|
||||
return RuntimeInfo(RuntimeType.NONE, [], [])
|
||||
|
||||
def detect_gpu_available(self) -> bool:
|
||||
"""Best-effort detection of NVIDIA GPU availability for containers."""
|
||||
try:
|
||||
res = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=5)
|
||||
if res.returncode == 0 and any("GPU" in ln for ln in res.stdout.splitlines()):
|
||||
res = subprocess.run(
|
||||
["nvidia-smi", "-L"], capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if res.returncode == 0 and any(
|
||||
"GPU" in ln for ln in res.stdout.splitlines()
|
||||
):
|
||||
return True
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
for cmd in (["docker", "info", "--format", "{{json .Runtimes}}"], ["podman", "info", "--format", "json"]):
|
||||
for cmd in (
|
||||
["docker", "info", "--format", "{{json .Runtimes}}"],
|
||||
["podman", "info", "--format", "json"],
|
||||
):
|
||||
try:
|
||||
res = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
|
||||
if res.returncode == 0 and "nvidia" in res.stdout.lower():
|
||||
|
|
@ -85,7 +111,9 @@ class PlatformDetector:
|
|||
|
||||
def _get_docker_version(self) -> Optional[str]:
|
||||
try:
|
||||
res = subprocess.run(["docker", "--version"], capture_output=True, text=True, timeout=5)
|
||||
res = subprocess.run(
|
||||
["docker", "--version"], capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if res.returncode == 0:
|
||||
return res.stdout.strip()
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
|
|
@ -94,7 +122,9 @@ class PlatformDetector:
|
|||
|
||||
def _get_podman_version(self) -> Optional[str]:
|
||||
try:
|
||||
res = subprocess.run(["podman", "--version"], capture_output=True, text=True, timeout=5)
|
||||
res = subprocess.run(
|
||||
["podman", "--version"], capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if res.returncode == 0:
|
||||
return res.stdout.strip()
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
|
|
@ -110,7 +140,12 @@ class PlatformDetector:
|
|||
if self.platform_system != "Darwin":
|
||||
return True, 0, "Not running on macOS"
|
||||
try:
|
||||
result = subprocess.run(["podman", "machine", "inspect"], capture_output=True, text=True, timeout=10)
|
||||
result = subprocess.run(
|
||||
["podman", "machine", "inspect"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return False, 0, "Could not inspect Podman machine"
|
||||
machines = json.loads(result.stdout)
|
||||
|
|
@ -124,7 +159,11 @@ class PlatformDetector:
|
|||
if not is_sufficient:
|
||||
status += "\nTo increase: podman machine stop && podman machine rm && podman machine init --memory 8192 && podman machine start"
|
||||
return is_sufficient, memory_mb, status
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, json.JSONDecodeError) as e:
|
||||
except (
|
||||
subprocess.TimeoutExpired,
|
||||
FileNotFoundError,
|
||||
json.JSONDecodeError,
|
||||
) as e:
|
||||
return False, 0, f"Error checking Podman VM memory: {e}"
|
||||
|
||||
def get_installation_instructions(self) -> str:
|
||||
|
|
@ -167,4 +206,4 @@ Or Podman Desktop:
|
|||
No container runtime found. Please install Docker or Podman for your platform:
|
||||
- Docker: https://docs.docker.com/get-docker/
|
||||
- Podman: https://podman.io/getting-started/installation
|
||||
"""
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,28 +8,31 @@ from typing import Optional
|
|||
|
||||
class ValidationError(Exception):
|
||||
"""Validation error exception."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def validate_env_var_name(name: str) -> bool:
|
||||
"""Validate environment variable name format."""
|
||||
return bool(re.match(r'^[A-Z][A-Z0-9_]*$', name))
|
||||
return bool(re.match(r"^[A-Z][A-Z0-9_]*$", name))
|
||||
|
||||
|
||||
def validate_path(path: str, must_exist: bool = False, must_be_dir: bool = False) -> bool:
|
||||
def validate_path(
|
||||
path: str, must_exist: bool = False, must_be_dir: bool = False
|
||||
) -> bool:
|
||||
"""Validate file/directory path."""
|
||||
if not path:
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
path_obj = Path(path).expanduser().resolve()
|
||||
|
||||
|
||||
if must_exist and not path_obj.exists():
|
||||
return False
|
||||
|
||||
|
||||
if must_be_dir and path_obj.exists() and not path_obj.is_dir():
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
except (OSError, ValueError):
|
||||
return False
|
||||
|
|
@ -39,15 +42,17 @@ def validate_url(url: str) -> bool:
|
|||
"""Validate URL format."""
|
||||
if not url:
|
||||
return False
|
||||
|
||||
|
||||
url_pattern = re.compile(
|
||||
r'^https?://' # http:// or https://
|
||||
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|' # domain
|
||||
r'localhost|' # localhost
|
||||
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # IP
|
||||
r'(?::\d+)?' # optional port
|
||||
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
|
||||
|
||||
r"^https?://" # http:// or https://
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|" # domain
|
||||
r"localhost|" # localhost
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # IP
|
||||
r"(?::\d+)?" # optional port
|
||||
r"(?:/?|[/?]\S+)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
return bool(url_pattern.match(url))
|
||||
|
||||
|
||||
|
|
@ -55,14 +60,14 @@ def validate_openai_api_key(key: str) -> bool:
|
|||
"""Validate OpenAI API key format."""
|
||||
if not key:
|
||||
return False
|
||||
return key.startswith('sk-') and len(key) > 20
|
||||
return key.startswith("sk-") and len(key) > 20
|
||||
|
||||
|
||||
def validate_google_oauth_client_id(client_id: str) -> bool:
|
||||
"""Validate Google OAuth client ID format."""
|
||||
if not client_id:
|
||||
return False
|
||||
return client_id.endswith('.apps.googleusercontent.com')
|
||||
return client_id.endswith(".apps.googleusercontent.com")
|
||||
|
||||
|
||||
def validate_non_empty(value: str) -> bool:
|
||||
|
|
@ -74,37 +79,38 @@ def sanitize_env_value(value: str) -> str:
|
|||
"""Sanitize environment variable value."""
|
||||
# Remove leading/trailing whitespace
|
||||
value = value.strip()
|
||||
|
||||
|
||||
# Remove quotes if they wrap the entire value
|
||||
if len(value) >= 2:
|
||||
if (value.startswith('"') and value.endswith('"')) or \
|
||||
(value.startswith("'") and value.endswith("'")):
|
||||
if (value.startswith('"') and value.endswith('"')) or (
|
||||
value.startswith("'") and value.endswith("'")
|
||||
):
|
||||
value = value[1:-1]
|
||||
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def validate_documents_paths(paths_str: str) -> tuple[bool, str, list[str]]:
|
||||
"""
|
||||
Validate comma-separated documents paths for volume mounting.
|
||||
|
||||
|
||||
Returns:
|
||||
(is_valid, error_message, validated_paths)
|
||||
"""
|
||||
if not paths_str:
|
||||
return False, "Documents paths cannot be empty", []
|
||||
|
||||
paths = [path.strip() for path in paths_str.split(',') if path.strip()]
|
||||
|
||||
|
||||
paths = [path.strip() for path in paths_str.split(",") if path.strip()]
|
||||
|
||||
if not paths:
|
||||
return False, "No valid paths provided", []
|
||||
|
||||
|
||||
validated_paths = []
|
||||
|
||||
|
||||
for path in paths:
|
||||
try:
|
||||
path_obj = Path(path).expanduser().resolve()
|
||||
|
||||
|
||||
# Check if path exists
|
||||
if not path_obj.exists():
|
||||
# Try to create it
|
||||
|
|
@ -112,11 +118,11 @@ def validate_documents_paths(paths_str: str) -> tuple[bool, str, list[str]]:
|
|||
path_obj.mkdir(parents=True, exist_ok=True)
|
||||
except (OSError, PermissionError) as e:
|
||||
return False, f"Cannot create directory '{path}': {e}", []
|
||||
|
||||
|
||||
# Check if it's a directory
|
||||
if not path_obj.is_dir():
|
||||
return False, f"Path '{path}' must be a directory", []
|
||||
|
||||
|
||||
# Check if we can write to it
|
||||
try:
|
||||
test_file = path_obj / ".openrag_test"
|
||||
|
|
@ -124,10 +130,10 @@ def validate_documents_paths(paths_str: str) -> tuple[bool, str, list[str]]:
|
|||
test_file.unlink()
|
||||
except (OSError, PermissionError):
|
||||
return False, f"Directory '{path}' is not writable", []
|
||||
|
||||
|
||||
validated_paths.append(str(path_obj))
|
||||
|
||||
|
||||
except (OSError, ValueError) as e:
|
||||
return False, f"Invalid path '{path}': {e}", []
|
||||
|
||||
return True, "All paths valid", validated_paths
|
||||
|
||||
return True, "All paths valid", validated_paths
|
||||
|
|
|
|||
|
|
@ -65,13 +65,13 @@ class CommandOutputModal(ModalScreen):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
title: str,
|
||||
self,
|
||||
title: str,
|
||||
command_generator: AsyncIterator[tuple[bool, str]],
|
||||
on_complete: Optional[Callable] = None
|
||||
on_complete: Optional[Callable] = None,
|
||||
):
|
||||
"""Initialize the modal dialog.
|
||||
|
||||
|
||||
Args:
|
||||
title: Title of the modal dialog
|
||||
command_generator: Async generator that yields (is_complete, message) tuples
|
||||
|
|
@ -104,29 +104,32 @@ class CommandOutputModal(ModalScreen):
|
|||
async def _run_command(self) -> None:
|
||||
"""Run the command and update the output in real-time."""
|
||||
output = self.query_one("#command-output", RichLog)
|
||||
|
||||
|
||||
try:
|
||||
async for is_complete, message in self.command_generator:
|
||||
# Simple approach: just append each line as it comes
|
||||
output.write(message + "\n")
|
||||
|
||||
|
||||
# Scroll to bottom
|
||||
container = self.query_one("#output-container", ScrollableContainer)
|
||||
container.scroll_end(animate=False)
|
||||
|
||||
|
||||
# If command is complete, update UI
|
||||
if is_complete:
|
||||
output.write("[bold green]Command completed successfully[/bold green]\n")
|
||||
output.write(
|
||||
"[bold green]Command completed successfully[/bold green]\n"
|
||||
)
|
||||
# Call the completion callback if provided
|
||||
if self.on_complete:
|
||||
await asyncio.sleep(0.5) # Small delay for better UX
|
||||
self.on_complete()
|
||||
except Exception as e:
|
||||
output.write(f"[bold red]Error: {e}[/bold red]\n")
|
||||
|
||||
|
||||
# Enable the close button and focus it
|
||||
close_btn = self.query_one("#close-btn", Button)
|
||||
close_btn.disabled = False
|
||||
close_btn.focus()
|
||||
|
||||
|
||||
# Made with Bob
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ def notify_with_diagnostics(
|
|||
app: App,
|
||||
message: str,
|
||||
severity: Literal["information", "warning", "error"] = "error",
|
||||
timeout: float = 10.0
|
||||
timeout: float = 10.0,
|
||||
) -> None:
|
||||
"""Show a notification with a button to open the diagnostics screen.
|
||||
|
||||
|
||||
Args:
|
||||
app: The Textual app
|
||||
message: The notification message
|
||||
|
|
@ -21,18 +21,20 @@ def notify_with_diagnostics(
|
|||
"""
|
||||
# First show the notification
|
||||
app.notify(message, severity=severity, timeout=timeout)
|
||||
|
||||
|
||||
# Then add a button to open diagnostics screen
|
||||
def open_diagnostics() -> None:
|
||||
from ..screens.diagnostics import DiagnosticsScreen
|
||||
|
||||
app.push_screen(DiagnosticsScreen())
|
||||
|
||||
|
||||
# Add a separate notification with just the button
|
||||
app.notify(
|
||||
"Click to view diagnostics",
|
||||
severity="information",
|
||||
timeout=timeout,
|
||||
title="Diagnostics"
|
||||
title="Diagnostics",
|
||||
)
|
||||
|
||||
# Made with Bob
|
||||
|
||||
# Made with Bob
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ def notify_with_diagnostics(
|
|||
app: App,
|
||||
message: str,
|
||||
severity: Literal["information", "warning", "error"] = "error",
|
||||
timeout: float = 10.0
|
||||
timeout: float = 10.0,
|
||||
) -> None:
|
||||
"""Show a notification with a button to open the diagnostics screen.
|
||||
|
||||
|
||||
Args:
|
||||
app: The Textual app
|
||||
message: The notification message
|
||||
|
|
@ -21,18 +21,20 @@ def notify_with_diagnostics(
|
|||
"""
|
||||
# First show the notification
|
||||
app.notify(message, severity=severity, timeout=timeout)
|
||||
|
||||
|
||||
# Then add a button to open diagnostics screen
|
||||
def open_diagnostics() -> None:
|
||||
from ..screens.diagnostics import DiagnosticsScreen
|
||||
|
||||
app.push_screen(DiagnosticsScreen())
|
||||
|
||||
|
||||
# Add a separate notification with just the button
|
||||
app.notify(
|
||||
"Click to view diagnostics",
|
||||
severity="information",
|
||||
timeout=timeout,
|
||||
title="Diagnostics"
|
||||
title="Diagnostics",
|
||||
)
|
||||
|
||||
|
||||
# Made with Bob
|
||||
|
|
|
|||
|
|
@ -1,7 +1,12 @@
|
|||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
from collections import defaultdict
|
||||
from .gpu_detection import detect_gpu_devices
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Global converter cache for worker processes
|
||||
_worker_converter = None
|
||||
|
|
@ -37,11 +42,11 @@ def get_worker_converter():
|
|||
"1" # Still disable progress bars
|
||||
)
|
||||
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] Initializing DocumentConverter in worker process"
|
||||
logger.info(
|
||||
"Initializing DocumentConverter in worker process", worker_pid=os.getpid()
|
||||
)
|
||||
_worker_converter = DocumentConverter()
|
||||
print(f"[WORKER {os.getpid()}] DocumentConverter ready in worker process")
|
||||
logger.info("DocumentConverter ready in worker process", worker_pid=os.getpid())
|
||||
|
||||
return _worker_converter
|
||||
|
||||
|
|
@ -118,33 +123,45 @@ def process_document_sync(file_path: str):
|
|||
start_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||
|
||||
try:
|
||||
print(f"[WORKER {os.getpid()}] Starting document processing: {file_path}")
|
||||
print(f"[WORKER {os.getpid()}] Initial memory usage: {start_memory:.1f} MB")
|
||||
logger.info(
|
||||
"Starting document processing",
|
||||
worker_pid=os.getpid(),
|
||||
file_path=file_path,
|
||||
initial_memory_mb=f"{start_memory:.1f}",
|
||||
)
|
||||
|
||||
# Check file size
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) / 1024 / 1024 # MB
|
||||
print(f"[WORKER {os.getpid()}] File size: {file_size:.1f} MB")
|
||||
logger.info(
|
||||
"File size determined",
|
||||
worker_pid=os.getpid(),
|
||||
file_size_mb=f"{file_size:.1f}",
|
||||
)
|
||||
except OSError as e:
|
||||
print(f"[WORKER {os.getpid()}] WARNING: Cannot get file size: {e}")
|
||||
logger.warning("Cannot get file size", worker_pid=os.getpid(), error=str(e))
|
||||
file_size = 0
|
||||
|
||||
# Get the cached converter for this worker
|
||||
try:
|
||||
print(f"[WORKER {os.getpid()}] Getting document converter...")
|
||||
logger.info("Getting document converter", worker_pid=os.getpid())
|
||||
converter = get_worker_converter()
|
||||
memory_after_converter = process.memory_info().rss / 1024 / 1024
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] Memory after converter init: {memory_after_converter:.1f} MB"
|
||||
logger.info(
|
||||
"Memory after converter init",
|
||||
worker_pid=os.getpid(),
|
||||
memory_mb=f"{memory_after_converter:.1f}",
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[WORKER {os.getpid()}] ERROR: Failed to initialize converter: {e}")
|
||||
logger.error(
|
||||
"Failed to initialize converter", worker_pid=os.getpid(), error=str(e)
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
# Compute file hash
|
||||
try:
|
||||
print(f"[WORKER {os.getpid()}] Computing file hash...")
|
||||
logger.info("Computing file hash", worker_pid=os.getpid())
|
||||
sha256 = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
while True:
|
||||
|
|
@ -153,50 +170,67 @@ def process_document_sync(file_path: str):
|
|||
break
|
||||
sha256.update(chunk)
|
||||
file_hash = sha256.hexdigest()
|
||||
print(f"[WORKER {os.getpid()}] File hash computed: {file_hash[:12]}...")
|
||||
logger.info(
|
||||
"File hash computed",
|
||||
worker_pid=os.getpid(),
|
||||
file_hash_prefix=file_hash[:12],
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[WORKER {os.getpid()}] ERROR: Failed to compute file hash: {e}")
|
||||
logger.error(
|
||||
"Failed to compute file hash", worker_pid=os.getpid(), error=str(e)
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
# Convert with docling
|
||||
try:
|
||||
print(f"[WORKER {os.getpid()}] Starting docling conversion...")
|
||||
logger.info("Starting docling conversion", worker_pid=os.getpid())
|
||||
memory_before_convert = process.memory_info().rss / 1024 / 1024
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] Memory before conversion: {memory_before_convert:.1f} MB"
|
||||
logger.info(
|
||||
"Memory before conversion",
|
||||
worker_pid=os.getpid(),
|
||||
memory_mb=f"{memory_before_convert:.1f}",
|
||||
)
|
||||
|
||||
result = converter.convert(file_path)
|
||||
|
||||
memory_after_convert = process.memory_info().rss / 1024 / 1024
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] Memory after conversion: {memory_after_convert:.1f} MB"
|
||||
logger.info(
|
||||
"Memory after conversion",
|
||||
worker_pid=os.getpid(),
|
||||
memory_mb=f"{memory_after_convert:.1f}",
|
||||
)
|
||||
print(f"[WORKER {os.getpid()}] Docling conversion completed")
|
||||
logger.info("Docling conversion completed", worker_pid=os.getpid())
|
||||
|
||||
full_doc = result.document.export_to_dict()
|
||||
memory_after_export = process.memory_info().rss / 1024 / 1024
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] Memory after export: {memory_after_export:.1f} MB"
|
||||
logger.info(
|
||||
"Memory after export",
|
||||
worker_pid=os.getpid(),
|
||||
memory_mb=f"{memory_after_export:.1f}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] ERROR: Failed during docling conversion: {e}"
|
||||
)
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] Current memory usage: {process.memory_info().rss / 1024 / 1024:.1f} MB"
|
||||
current_memory = process.memory_info().rss / 1024 / 1024
|
||||
logger.error(
|
||||
"Failed during docling conversion",
|
||||
worker_pid=os.getpid(),
|
||||
error=str(e),
|
||||
current_memory_mb=f"{current_memory:.1f}",
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
# Extract relevant content (same logic as extract_relevant)
|
||||
try:
|
||||
print(f"[WORKER {os.getpid()}] Extracting relevant content...")
|
||||
logger.info("Extracting relevant content", worker_pid=os.getpid())
|
||||
origin = full_doc.get("origin", {})
|
||||
texts = full_doc.get("texts", [])
|
||||
print(f"[WORKER {os.getpid()}] Found {len(texts)} text fragments")
|
||||
logger.info(
|
||||
"Found text fragments",
|
||||
worker_pid=os.getpid(),
|
||||
fragment_count=len(texts),
|
||||
)
|
||||
|
||||
page_texts = defaultdict(list)
|
||||
for txt in texts:
|
||||
|
|
@ -210,22 +244,27 @@ def process_document_sync(file_path: str):
|
|||
joined = "\n".join(page_texts[page])
|
||||
chunks.append({"page": page, "text": joined})
|
||||
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] Created {len(chunks)} chunks from {len(page_texts)} pages"
|
||||
logger.info(
|
||||
"Created chunks from pages",
|
||||
worker_pid=os.getpid(),
|
||||
chunk_count=len(chunks),
|
||||
page_count=len(page_texts),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] ERROR: Failed during content extraction: {e}"
|
||||
logger.error(
|
||||
"Failed during content extraction", worker_pid=os.getpid(), error=str(e)
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
final_memory = process.memory_info().rss / 1024 / 1024
|
||||
memory_delta = final_memory - start_memory
|
||||
print(f"[WORKER {os.getpid()}] Document processing completed successfully")
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] Final memory: {final_memory:.1f} MB (Delta +{memory_delta:.1f} MB)"
|
||||
logger.info(
|
||||
"Document processing completed successfully",
|
||||
worker_pid=os.getpid(),
|
||||
final_memory_mb=f"{final_memory:.1f}",
|
||||
memory_delta_mb=f"{memory_delta:.1f}",
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -239,24 +278,29 @@ def process_document_sync(file_path: str):
|
|||
except Exception as e:
|
||||
final_memory = process.memory_info().rss / 1024 / 1024
|
||||
memory_delta = final_memory - start_memory
|
||||
print(f"[WORKER {os.getpid()}] FATAL ERROR in process_document_sync")
|
||||
print(f"[WORKER {os.getpid()}] File: {file_path}")
|
||||
print(f"[WORKER {os.getpid()}] Python version: {sys.version}")
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] Memory at crash: {final_memory:.1f} MB (Delta +{memory_delta:.1f} MB)"
|
||||
logger.error(
|
||||
"FATAL ERROR in process_document_sync",
|
||||
worker_pid=os.getpid(),
|
||||
file_path=file_path,
|
||||
python_version=sys.version,
|
||||
memory_at_crash_mb=f"{final_memory:.1f}",
|
||||
memory_delta_mb=f"{memory_delta:.1f}",
|
||||
error_type=type(e).__name__,
|
||||
error=str(e),
|
||||
)
|
||||
print(f"[WORKER {os.getpid()}] Error: {type(e).__name__}: {e}")
|
||||
print(f"[WORKER {os.getpid()}] Full traceback:")
|
||||
logger.error("Full traceback:", worker_pid=os.getpid())
|
||||
traceback.print_exc()
|
||||
|
||||
# Try to get more system info before crashing
|
||||
try:
|
||||
import platform
|
||||
|
||||
print(
|
||||
f"[WORKER {os.getpid()}] System: {platform.system()} {platform.release()}"
|
||||
logger.error(
|
||||
"System info",
|
||||
worker_pid=os.getpid(),
|
||||
system=f"{platform.system()} {platform.release()}",
|
||||
architecture=platform.machine(),
|
||||
)
|
||||
print(f"[WORKER {os.getpid()}] Architecture: {platform.machine()}")
|
||||
except:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
import multiprocessing
|
||||
import os
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def detect_gpu_devices():
|
||||
|
|
@ -30,13 +33,15 @@ def get_worker_count():
|
|||
|
||||
if has_gpu_devices:
|
||||
default_workers = min(4, multiprocessing.cpu_count() // 2)
|
||||
print(
|
||||
f"GPU mode enabled with {gpu_count} GPU(s) - using limited concurrency ({default_workers} workers)"
|
||||
logger.info(
|
||||
"GPU mode enabled with limited concurrency",
|
||||
gpu_count=gpu_count,
|
||||
worker_count=default_workers,
|
||||
)
|
||||
else:
|
||||
default_workers = multiprocessing.cpu_count()
|
||||
print(
|
||||
f"CPU-only mode enabled - using full concurrency ({default_workers} workers)"
|
||||
logger.info(
|
||||
"CPU-only mode enabled with full concurrency", worker_count=default_workers
|
||||
)
|
||||
|
||||
return int(os.getenv("MAX_WORKERS", default_workers))
|
||||
|
|
|
|||
|
|
@ -9,13 +9,15 @@ def configure_logging(
|
|||
log_level: str = "INFO",
|
||||
json_logs: bool = False,
|
||||
include_timestamps: bool = True,
|
||||
service_name: str = "openrag"
|
||||
service_name: str = "openrag",
|
||||
) -> None:
|
||||
"""Configure structlog for the application."""
|
||||
|
||||
|
||||
# Convert string log level to actual level
|
||||
level = getattr(structlog.stdlib.logging, log_level.upper(), structlog.stdlib.logging.INFO)
|
||||
|
||||
level = getattr(
|
||||
structlog.stdlib.logging, log_level.upper(), structlog.stdlib.logging.INFO
|
||||
)
|
||||
|
||||
# Base processors
|
||||
shared_processors = [
|
||||
structlog.contextvars.merge_contextvars,
|
||||
|
|
@ -23,29 +25,65 @@ def configure_logging(
|
|||
structlog.processors.StackInfoRenderer(),
|
||||
structlog.dev.set_exc_info,
|
||||
]
|
||||
|
||||
|
||||
if include_timestamps:
|
||||
shared_processors.append(structlog.processors.TimeStamper(fmt="iso"))
|
||||
|
||||
# Add service name to all logs
|
||||
|
||||
# Add service name and file location to all logs
|
||||
shared_processors.append(
|
||||
structlog.processors.CallsiteParameterAdder(
|
||||
parameters=[structlog.processors.CallsiteParameter.FUNC_NAME]
|
||||
parameters=[
|
||||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||||
structlog.processors.CallsiteParameter.FILENAME,
|
||||
structlog.processors.CallsiteParameter.LINENO,
|
||||
structlog.processors.CallsiteParameter.PATHNAME,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Console output configuration
|
||||
if json_logs or os.getenv("LOG_FORMAT", "").lower() == "json":
|
||||
# JSON output for production/containers
|
||||
shared_processors.append(structlog.processors.JSONRenderer())
|
||||
console_renderer = structlog.processors.JSONRenderer()
|
||||
else:
|
||||
# Pretty colored output for development
|
||||
console_renderer = structlog.dev.ConsoleRenderer(
|
||||
colors=sys.stderr.isatty(),
|
||||
exception_formatter=structlog.dev.plain_traceback,
|
||||
)
|
||||
|
||||
# Custom clean format: timestamp path/file:loc logentry
|
||||
def custom_formatter(logger, log_method, event_dict):
|
||||
timestamp = event_dict.pop("timestamp", "")
|
||||
pathname = event_dict.pop("pathname", "")
|
||||
filename = event_dict.pop("filename", "")
|
||||
lineno = event_dict.pop("lineno", "")
|
||||
level = event_dict.pop("level", "")
|
||||
|
||||
# Build file location - prefer pathname for full path, fallback to filename
|
||||
if pathname and lineno:
|
||||
location = f"{pathname}:{lineno}"
|
||||
elif filename and lineno:
|
||||
location = f"{filename}:{lineno}"
|
||||
elif pathname:
|
||||
location = pathname
|
||||
elif filename:
|
||||
location = filename
|
||||
else:
|
||||
location = "unknown"
|
||||
|
||||
# Build the main message
|
||||
message_parts = []
|
||||
event = event_dict.pop("event", "")
|
||||
if event:
|
||||
message_parts.append(event)
|
||||
|
||||
# Add any remaining context
|
||||
for key, value in event_dict.items():
|
||||
if key not in ["service", "func_name"]: # Skip internal fields
|
||||
message_parts.append(f"{key}={value}")
|
||||
|
||||
message = " ".join(message_parts)
|
||||
|
||||
return f"{timestamp} {location} {message}"
|
||||
|
||||
console_renderer = custom_formatter
|
||||
|
||||
# Configure structlog
|
||||
structlog.configure(
|
||||
processors=shared_processors + [console_renderer],
|
||||
|
|
@ -54,7 +92,7 @@ def configure_logging(
|
|||
logger_factory=structlog.WriteLoggerFactory(sys.stderr),
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
|
||||
|
||||
# Add global context
|
||||
structlog.contextvars.clear_contextvars()
|
||||
structlog.contextvars.bind_contextvars(service=service_name)
|
||||
|
|
@ -73,9 +111,7 @@ def configure_from_env() -> None:
|
|||
log_level = os.getenv("LOG_LEVEL", "INFO")
|
||||
json_logs = os.getenv("LOG_FORMAT", "").lower() == "json"
|
||||
service_name = os.getenv("SERVICE_NAME", "openrag")
|
||||
|
||||
|
||||
configure_logging(
|
||||
log_level=log_level,
|
||||
json_logs=json_logs,
|
||||
service_name=service_name
|
||||
)
|
||||
log_level=log_level, json_logs=json_logs, service_name=service_name
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from utils.gpu_detection import get_worker_count
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Create shared process pool at import time (before CUDA initialization)
|
||||
# This avoids the "Cannot re-initialize CUDA in forked subprocess" error
|
||||
MAX_WORKERS = get_worker_count()
|
||||
process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)
|
||||
|
||||
print(f"Shared process pool initialized with {MAX_WORKERS} workers")
|
||||
logger.info("Shared process pool initialized", max_workers=MAX_WORKERS)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,17 @@
|
|||
from docling.document_converter import DocumentConverter
|
||||
import logging
|
||||
|
||||
print("Warming up docling models...")
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger.info("Warming up docling models")
|
||||
|
||||
try:
|
||||
# Use the sample document to warm up docling
|
||||
test_file = "/app/warmup_ocr.pdf"
|
||||
print(f"Using {test_file} to warm up docling...")
|
||||
logger.info(f"Using test file to warm up docling: {test_file}")
|
||||
DocumentConverter().convert(test_file)
|
||||
print("Docling models warmed up successfully")
|
||||
logger.info("Docling models warmed up successfully")
|
||||
except Exception as e:
|
||||
print(f"Docling warm-up completed with: {e}")
|
||||
logger.info(f"Docling warm-up completed with exception: {str(e)}")
|
||||
# This is expected - we just want to trigger the model downloads
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue