From 7b635df9d0aa7da91bf675845ecaf68c3c8402bd Mon Sep 17 00:00:00 2001 From: Lucas Oliveira <62335616+lucaseduoli@users.noreply.github.com> Date: Wed, 29 Oct 2025 15:59:10 -0300 Subject: [PATCH] fix: added better onboarding error handling, added probing api keys and models (#326) * Added error showing to onboarding card * Added error state on animated provider steps * removed toast on error * Fixed animation on onboarding card * fixed animation time * Implemented provider validation * Added provider validation before ingestion * Changed error border * remove log --------- Co-authored-by: Mike Fortman --- frontend/src/app/chat/page.tsx | 2 - .../components/animated-provider-steps.tsx | 24 +- .../onboarding/components/onboarding-card.tsx | 319 +++++++------ src/api/provider_validation.py | 437 ++++++++++++++++++ src/api/settings.py | 26 +- 5 files changed, 664 insertions(+), 144 deletions(-) create mode 100644 src/api/provider_validation.py diff --git a/frontend/src/app/chat/page.tsx b/frontend/src/app/chat/page.tsx index f34b1efb..5b7a33f1 100644 --- a/frontend/src/app/chat/page.tsx +++ b/frontend/src/app/chat/page.tsx @@ -1154,8 +1154,6 @@ function ChatPage() { } }; - console.log(messages) - return ( <> {/* Debug header - only show in debug mode */} diff --git a/frontend/src/app/onboarding/components/animated-provider-steps.tsx b/frontend/src/app/onboarding/components/animated-provider-steps.tsx index fecf2c72..3092cc0d 100644 --- a/frontend/src/app/onboarding/components/animated-provider-steps.tsx +++ b/frontend/src/app/onboarding/components/animated-provider-steps.tsx @@ -1,7 +1,7 @@ "use client"; import { AnimatePresence, motion } from "framer-motion"; -import { CheckIcon } from "lucide-react"; +import { CheckIcon, XIcon } from "lucide-react"; import { useEffect, useState } from "react"; import { @@ -20,6 +20,7 @@ export function AnimatedProviderSteps({ steps, storageKey = "provider-steps", processingStartTime, + hasError = false, }: { currentStep: number; isCompleted: boolean; @@ -27,6 +28,7 @@ export function AnimatedProviderSteps({ steps: string[]; storageKey?: string; processingStartTime?: number | null; + hasError?: boolean; }) { const [startTime, setStartTime] = useState(null); const [elapsedTime, setElapsedTime] = useState(0); @@ -63,7 +65,7 @@ export function AnimatedProviderSteps({ } }, [isCompleted, startTime, storageKey]); - const isDone = currentStep >= steps.length && !isCompleted; + const isDone = currentStep >= steps.length && !isCompleted && !hasError; return ( @@ -79,8 +81,8 @@ export function AnimatedProviderSteps({
+
-
- {isDone ? "Done" : "Thinking"}
+
+ {hasError ? "Error" : isDone ? "Done" : "Thinking"}
- {!isDone && (
+ {!isDone && !hasError && (
- - - - -
- -
- OpenAI -
-
- - -
- -
- IBM watsonx.ai -
-
- - -
- -
- Ollama -
-
-
- - - - - - - - - -
- - - -
- -
-
- {!isComplete && ( - - {isLoadingModels - ? "Loading models..." - : !!settings.llm_model && - !!settings.embedding_model && - !isDoclingHealthy - ? "docling-serve must be running to continue" - : "Please fill in all required fields"} - +
+ + {error && ( + +
+ + + {error} + +
+
)} - +
+
+ + + + +
+ +
+ OpenAI +
+
+ + +
+ +
+ IBM watsonx.ai +
+
+ + +
+ +
+ Ollama +
+
+
+ + + + + + + + + +
+ + + +
+ +
+
+ {!isComplete && ( + + {isLoadingModels + ? "Loading models..." + : !!settings.llm_model && + !!settings.embedding_model && + !isDoclingHealthy + ? "docling-serve must be running to continue" + : "Please fill in all required fields"} + + )} +
+
) : ( @@ -321,6 +372,7 @@ const OnboardingCard = ({ key="provider-steps" initial={{ opacity: 0, y: 24 }} animate={{ opacity: 1, y: 0 }} + exit={{ opacity: 0, y: 24 }} transition={{ duration: 0.4, ease: "easeInOut" }} > )} diff --git a/src/api/provider_validation.py b/src/api/provider_validation.py new file mode 100644 index 00000000..7814d45f --- /dev/null +++ b/src/api/provider_validation.py @@ -0,0 +1,437 @@ +"""Provider validation utilities for testing API keys and models during onboarding.""" + +import httpx +from utils.container_utils import transform_localhost_url +from utils.logging_config import get_logger + +logger = get_logger(__name__) + + +async def validate_provider_setup( + provider: str, + api_key: str = None, + embedding_model: str = None, + llm_model: str = None, + endpoint: str = None, + project_id: str = None, +) -> None: + """ + Validate provider setup by testing completion with tool calling and embedding. + + Args: + provider: Provider name ('openai', 'watsonx', 'ollama') + api_key: API key for the provider (optional for ollama) + embedding_model: Embedding model to test + llm_model: LLM model to test + endpoint: Provider endpoint (required for ollama and watsonx) + project_id: Project ID (required for watsonx) + + Raises: + Exception: If validation fails with message "Setup failed, please try again or select a different provider." + """ + provider_lower = provider.lower() + + try: + logger.info(f"Starting validation for provider: {provider_lower}") + + # Test completion with tool calling + await test_completion_with_tools( + provider=provider_lower, + api_key=api_key, + llm_model=llm_model, + endpoint=endpoint, + project_id=project_id, + ) + + # Test embedding + await test_embedding( + provider=provider_lower, + api_key=api_key, + embedding_model=embedding_model, + endpoint=endpoint, + project_id=project_id, + ) + + logger.info(f"Validation successful for provider: {provider_lower}") + + except Exception as e: + logger.error(f"Validation failed for provider {provider_lower}: {str(e)}") + raise Exception("Setup failed, please try again or select a different provider.") + + +async def test_completion_with_tools( + provider: str, + api_key: str = None, + llm_model: str = None, + endpoint: str = None, + project_id: str = None, +) -> None: + """Test completion with tool calling for the provider.""" + + if provider == "openai": + await _test_openai_completion_with_tools(api_key, llm_model) + elif provider == "watsonx": + await _test_watsonx_completion_with_tools(api_key, llm_model, endpoint, project_id) + elif provider == "ollama": + await _test_ollama_completion_with_tools(llm_model, endpoint) + else: + raise ValueError(f"Unknown provider: {provider}") + + +async def test_embedding( + provider: str, + api_key: str = None, + embedding_model: str = None, + endpoint: str = None, + project_id: str = None, +) -> None: + """Test embedding generation for the provider.""" + + if provider == "openai": + await _test_openai_embedding(api_key, embedding_model) + elif provider == "watsonx": + await _test_watsonx_embedding(api_key, embedding_model, endpoint, project_id) + elif provider == "ollama": + await _test_ollama_embedding(embedding_model, endpoint) + else: + raise ValueError(f"Unknown provider: {provider}") + + +# OpenAI validation functions +async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> None: + """Test OpenAI completion with tool calling.""" + try: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + 
# Simple tool calling test + payload = { + "model": llm_model, + "messages": [ + {"role": "user", "content": "What's the weather like?"} + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state" + } + }, + "required": ["location"] + } + } + } + ], + "max_tokens": 50, + } + + async with httpx.AsyncClient() as client: + response = await client.post( + "https://api.openai.com/v1/chat/completions", + headers=headers, + json=payload, + timeout=30.0, + ) + + if response.status_code != 200: + logger.error(f"OpenAI completion test failed: {response.status_code} - {response.text}") + raise Exception(f"OpenAI API error: {response.status_code}") + + logger.info("OpenAI completion with tool calling test passed") + + except httpx.TimeoutException: + logger.error("OpenAI completion test timed out") + raise Exception("Request timed out") + except Exception as e: + logger.error(f"OpenAI completion test failed: {str(e)}") + raise + + +async def _test_openai_embedding(api_key: str, embedding_model: str) -> None: + """Test OpenAI embedding generation.""" + try: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + payload = { + "model": embedding_model, + "input": "test embedding", + } + + async with httpx.AsyncClient() as client: + response = await client.post( + "https://api.openai.com/v1/embeddings", + headers=headers, + json=payload, + timeout=30.0, + ) + + if response.status_code != 200: + logger.error(f"OpenAI embedding test failed: {response.status_code} - {response.text}") + raise Exception(f"OpenAI API error: {response.status_code}") + + data = response.json() + if not data.get("data") or len(data["data"]) == 0: + raise Exception("No embedding data returned") + + logger.info("OpenAI embedding test passed") + + except httpx.TimeoutException: + logger.error("OpenAI embedding test timed out") + raise Exception("Request timed out") + except Exception as e: + logger.error(f"OpenAI embedding test failed: {str(e)}") + raise + + +# IBM Watson validation functions +async def _test_watsonx_completion_with_tools( + api_key: str, llm_model: str, endpoint: str, project_id: str +) -> None: + """Test IBM Watson completion with tool calling.""" + try: + # Get bearer token from IBM IAM + async with httpx.AsyncClient() as client: + token_response = await client.post( + "https://iam.cloud.ibm.com/identity/token", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "grant_type": "urn:ibm:params:oauth:grant-type:apikey", + "apikey": api_key, + }, + timeout=30.0, + ) + + if token_response.status_code != 200: + logger.error(f"IBM IAM token request failed: {token_response.status_code}") + raise Exception("Failed to authenticate with IBM Watson") + + bearer_token = token_response.json().get("access_token") + if not bearer_token: + raise Exception("No access token received from IBM") + + headers = { + "Authorization": f"Bearer {bearer_token}", + "Content-Type": "application/json", + } + + # Test completion with tools + url = f"{endpoint}/ml/v1/text/chat" + params = {"version": "2024-09-16"} + payload = { + "model_id": llm_model, + "project_id": project_id, + "messages": [ + {"role": "user", "content": "What's the weather like?"} + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather", + 
"parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state" + } + }, + "required": ["location"] + } + } + } + ], + "max_tokens": 50, + } + + async with httpx.AsyncClient() as client: + response = await client.post( + url, + headers=headers, + params=params, + json=payload, + timeout=30.0, + ) + + if response.status_code != 200: + logger.error(f"IBM Watson completion test failed: {response.status_code} - {response.text}") + raise Exception(f"IBM Watson API error: {response.status_code}") + + logger.info("IBM Watson completion with tool calling test passed") + + except httpx.TimeoutException: + logger.error("IBM Watson completion test timed out") + raise Exception("Request timed out") + except Exception as e: + logger.error(f"IBM Watson completion test failed: {str(e)}") + raise + + +async def _test_watsonx_embedding( + api_key: str, embedding_model: str, endpoint: str, project_id: str +) -> None: + """Test IBM Watson embedding generation.""" + try: + # Get bearer token from IBM IAM + async with httpx.AsyncClient() as client: + token_response = await client.post( + "https://iam.cloud.ibm.com/identity/token", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "grant_type": "urn:ibm:params:oauth:grant-type:apikey", + "apikey": api_key, + }, + timeout=30.0, + ) + + if token_response.status_code != 200: + logger.error(f"IBM IAM token request failed: {token_response.status_code}") + raise Exception("Failed to authenticate with IBM Watson") + + bearer_token = token_response.json().get("access_token") + if not bearer_token: + raise Exception("No access token received from IBM") + + headers = { + "Authorization": f"Bearer {bearer_token}", + "Content-Type": "application/json", + } + + # Test embedding + url = f"{endpoint}/ml/v1/text/embeddings" + params = {"version": "2024-09-16"} + payload = { + "model_id": embedding_model, + "project_id": project_id, + "inputs": ["test embedding"], + } + + async with httpx.AsyncClient() as client: + response = await client.post( + url, + headers=headers, + params=params, + json=payload, + timeout=30.0, + ) + + if response.status_code != 200: + logger.error(f"IBM Watson embedding test failed: {response.status_code} - {response.text}") + raise Exception(f"IBM Watson API error: {response.status_code}") + + data = response.json() + if not data.get("results") or len(data["results"]) == 0: + raise Exception("No embedding data returned") + + logger.info("IBM Watson embedding test passed") + + except httpx.TimeoutException: + logger.error("IBM Watson embedding test timed out") + raise Exception("Request timed out") + except Exception as e: + logger.error(f"IBM Watson embedding test failed: {str(e)}") + raise + + +# Ollama validation functions +async def _test_ollama_completion_with_tools(llm_model: str, endpoint: str) -> None: + """Test Ollama completion with tool calling.""" + try: + ollama_url = transform_localhost_url(endpoint) + url = f"{ollama_url}/api/chat" + + payload = { + "model": llm_model, + "messages": [ + {"role": "user", "content": "What's the weather like?"} + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state" + } + }, + "required": ["location"] + } + } + } + ], + "stream": False, + } + + async with httpx.AsyncClient() as client: + response = await client.post( + 
url, + json=payload, + timeout=30.0, + ) + + if response.status_code != 200: + logger.error(f"Ollama completion test failed: {response.status_code} - {response.text}") + raise Exception(f"Ollama API error: {response.status_code}") + + logger.info("Ollama completion with tool calling test passed") + + except httpx.TimeoutException: + logger.error("Ollama completion test timed out") + raise Exception("Request timed out") + except Exception as e: + logger.error(f"Ollama completion test failed: {str(e)}") + raise + + +async def _test_ollama_embedding(embedding_model: str, endpoint: str) -> None: + """Test Ollama embedding generation.""" + try: + ollama_url = transform_localhost_url(endpoint) + url = f"{ollama_url}/api/embeddings" + + payload = { + "model": embedding_model, + "prompt": "test embedding", + } + + async with httpx.AsyncClient() as client: + response = await client.post( + url, + json=payload, + timeout=30.0, + ) + + if response.status_code != 200: + logger.error(f"Ollama embedding test failed: {response.status_code} - {response.text}") + raise Exception(f"Ollama API error: {response.status_code}") + + data = response.json() + if not data.get("embedding"): + raise Exception("No embedding data returned") + + logger.info("Ollama embedding test passed") + + except httpx.TimeoutException: + logger.error("Ollama embedding test timed out") + raise Exception("Request timed out") + except Exception as e: + logger.error(f"Ollama embedding test failed: {str(e)}") + raise diff --git a/src/api/settings.py b/src/api/settings.py index 5fc30cf5..398bc97b 100644 --- a/src/api/settings.py +++ b/src/api/settings.py @@ -1,5 +1,6 @@ import json import platform +import time from starlette.responses import JSONResponse from utils.container_utils import transform_localhost_url from utils.logging_config import get_logger @@ -533,6 +534,29 @@ async def onboarding(request, flows_service): {"error": "No valid fields provided for update"}, status_code=400 ) + # Validate provider setup before initializing OpenSearch index + try: + from api.provider_validation import validate_provider_setup + + provider = current_config.provider.model_provider.lower() if current_config.provider.model_provider else "openai" + + logger.info(f"Validating provider setup for {provider}") + await validate_provider_setup( + provider=provider, + api_key=current_config.provider.api_key, + embedding_model=current_config.knowledge.embedding_model, + llm_model=current_config.agent.llm_model, + endpoint=current_config.provider.endpoint, + project_id=current_config.provider.project_id, + ) + logger.info(f"Provider setup validation completed successfully for {provider}") + except Exception as e: + logger.error(f"Provider validation failed: {str(e)}") + return JSONResponse( + {"error": str(e)}, + status_code=400, + ) + # Initialize the OpenSearch index now that we have the embedding model configured try: # Import here to avoid circular imports @@ -694,7 +718,7 @@ async def onboarding(request, flows_service): except Exception as e: logger.error("Failed to update onboarding settings", error=str(e)) return JSONResponse( - {"error": f"Failed to update onboarding settings: {str(e)}"}, + {"error": str(e)}, status_code=500, )
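Note for reviewers (not part of the patch itself): the new validator can also be exercised on its own, outside the onboarding endpoint. A minimal sketch, assuming it is run from src/ so that api.provider_validation is importable, that OPENAI_API_KEY is set in the environment, and with placeholder model names standing in for whatever the settings UI would submit:

import asyncio
import os

from api.provider_validation import validate_provider_setup


async def main() -> None:
    # Placeholder models; swap in whatever the onboarding form would actually send.
    try:
        await validate_provider_setup(
            provider="openai",
            api_key=os.environ["OPENAI_API_KEY"],
            embedding_model="text-embedding-3-small",
            llm_model="gpt-4o-mini",
        )
        print("provider validation passed")
    except Exception as err:
        # validate_provider_setup collapses every failure into one generic user-facing message.
        print(f"provider validation failed: {err}")


if __name__ == "__main__":
    asyncio.run(main())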