fix: added better onboarding error handling, added probing api keys and models (#326)
* Added error showing to onboarding card * Added error state on animated provider steps * removed toast on error * Fixed animation on onboarding card * fixed animation time * Implemented provider validation * Added provider validation before ingestion * Changed error border * remove log --------- Co-authored-by: Mike Fortman <michael.fortman@datastax.com>
This commit is contained in:
parent
3296a60b3b
commit
7b635df9d0
5 changed files with 664 additions and 144 deletions
|
|
@ -1154,8 +1154,6 @@ function ChatPage() {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
console.log(messages)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
{/* Debug header - only show in debug mode */}
|
{/* Debug header - only show in debug mode */}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { AnimatePresence, motion } from "framer-motion";
|
import { AnimatePresence, motion } from "framer-motion";
|
||||||
import { CheckIcon } from "lucide-react";
|
import { CheckIcon, XIcon } from "lucide-react";
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
|
|
@ -20,6 +20,7 @@ export function AnimatedProviderSteps({
|
||||||
steps,
|
steps,
|
||||||
storageKey = "provider-steps",
|
storageKey = "provider-steps",
|
||||||
processingStartTime,
|
processingStartTime,
|
||||||
|
hasError = false,
|
||||||
}: {
|
}: {
|
||||||
currentStep: number;
|
currentStep: number;
|
||||||
isCompleted: boolean;
|
isCompleted: boolean;
|
||||||
|
|
@ -27,6 +28,7 @@ export function AnimatedProviderSteps({
|
||||||
steps: string[];
|
steps: string[];
|
||||||
storageKey?: string;
|
storageKey?: string;
|
||||||
processingStartTime?: number | null;
|
processingStartTime?: number | null;
|
||||||
|
hasError?: boolean;
|
||||||
}) {
|
}) {
|
||||||
const [startTime, setStartTime] = useState<number | null>(null);
|
const [startTime, setStartTime] = useState<number | null>(null);
|
||||||
const [elapsedTime, setElapsedTime] = useState<number>(0);
|
const [elapsedTime, setElapsedTime] = useState<number>(0);
|
||||||
|
|
@ -63,7 +65,7 @@ export function AnimatedProviderSteps({
|
||||||
}
|
}
|
||||||
}, [isCompleted, startTime, storageKey]);
|
}, [isCompleted, startTime, storageKey]);
|
||||||
|
|
||||||
const isDone = currentStep >= steps.length && !isCompleted;
|
const isDone = currentStep >= steps.length && !isCompleted && !hasError;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<AnimatePresence mode="wait">
|
<AnimatePresence mode="wait">
|
||||||
|
|
@ -79,8 +81,8 @@ export function AnimatedProviderSteps({
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
"transition-all duration-150 relative",
|
"transition-all duration-300 relative",
|
||||||
isDone ? "w-3.5 h-3.5" : "w-6 h-6",
|
isDone || hasError ? "w-3.5 h-3.5" : "w-6 h-6",
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<CheckIcon
|
<CheckIcon
|
||||||
|
|
@ -89,21 +91,27 @@ export function AnimatedProviderSteps({
|
||||||
isDone ? "opacity-100" : "opacity-0",
|
isDone ? "opacity-100" : "opacity-0",
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
|
<XIcon
|
||||||
|
className={cn(
|
||||||
|
"text-accent-red-foreground shrink-0 w-3.5 h-3.5 absolute inset-0 transition-all duration-150",
|
||||||
|
hasError ? "opacity-100" : "opacity-0",
|
||||||
|
)}
|
||||||
|
/>
|
||||||
<AnimatedProcessingIcon
|
<AnimatedProcessingIcon
|
||||||
className={cn(
|
className={cn(
|
||||||
"text-current shrink-0 absolute inset-0 transition-all duration-150",
|
"text-current shrink-0 absolute inset-0 transition-all duration-150",
|
||||||
isDone ? "opacity-0" : "opacity-100",
|
isDone || hasError ? "opacity-0" : "opacity-100",
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<span className="text-mmd font-medium text-muted-foreground">
|
<span className="!text-mmd font-medium text-muted-foreground">
|
||||||
{isDone ? "Done" : "Thinking"}
|
{hasError ? "Error" : isDone ? "Done" : "Thinking"}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<div className="overflow-hidden">
|
<div className="overflow-hidden">
|
||||||
<AnimatePresence>
|
<AnimatePresence>
|
||||||
{!isDone && (
|
{!isDone && !hasError && (
|
||||||
<motion.div
|
<motion.div
|
||||||
initial={{ opacity: 1, y: 0, height: "auto" }}
|
initial={{ opacity: 1, y: 0, height: "auto" }}
|
||||||
exit={{ opacity: 0, y: -24, height: 0 }}
|
exit={{ opacity: 0, y: -24, height: 0 }}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { AnimatePresence, motion } from "framer-motion";
|
import { AnimatePresence, motion } from "framer-motion";
|
||||||
|
import { X } from "lucide-react";
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import {
|
import {
|
||||||
|
|
@ -70,6 +71,7 @@ const OnboardingCard = ({
|
||||||
embedding_model: "",
|
embedding_model: "",
|
||||||
llm_model: "",
|
llm_model: "",
|
||||||
});
|
});
|
||||||
|
setError(null);
|
||||||
};
|
};
|
||||||
|
|
||||||
const [settings, setSettings] = useState<OnboardingVariables>({
|
const [settings, setSettings] = useState<OnboardingVariables>({
|
||||||
|
|
@ -86,6 +88,8 @@ const OnboardingCard = ({
|
||||||
null,
|
null,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
|
||||||
// Query tasks to track completion
|
// Query tasks to track completion
|
||||||
const { data: tasks } = useGetTasksQuery({
|
const { data: tasks } = useGetTasksQuery({
|
||||||
enabled: currentStep !== null, // Only poll when onboarding has started
|
enabled: currentStep !== null, // Only poll when onboarding has started
|
||||||
|
|
@ -126,11 +130,15 @@ const OnboardingCard = ({
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
console.log("Onboarding completed successfully", data);
|
console.log("Onboarding completed successfully", data);
|
||||||
setCurrentStep(0);
|
setCurrentStep(0);
|
||||||
|
setError(null);
|
||||||
},
|
},
|
||||||
onError: (error) => {
|
onError: (error) => {
|
||||||
toast.error("Failed to complete onboarding", {
|
setError(error.message);
|
||||||
description: error.message,
|
setCurrentStep(TOTAL_PROVIDER_STEPS);
|
||||||
});
|
// Reset to provider selection after 1 second
|
||||||
|
setTimeout(() => {
|
||||||
|
setCurrentStep(null);
|
||||||
|
}, 1000);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -144,6 +152,9 @@ const OnboardingCard = ({
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear any previous error
|
||||||
|
setError(null);
|
||||||
|
|
||||||
// Prepare onboarding data
|
// Prepare onboarding data
|
||||||
const onboardingData: OnboardingVariables = {
|
const onboardingData: OnboardingVariables = {
|
||||||
model_provider: settings.model_provider,
|
model_provider: settings.model_provider,
|
||||||
|
|
@ -181,139 +192,179 @@ const OnboardingCard = ({
|
||||||
{currentStep === null ? (
|
{currentStep === null ? (
|
||||||
<motion.div
|
<motion.div
|
||||||
key="onboarding-form"
|
key="onboarding-form"
|
||||||
initial={{ opacity: 1, y: 0 }}
|
initial={{ opacity: 0, y: -24 }}
|
||||||
exit={{ opacity: 0, y: -24 }}
|
animate={{ opacity: 1, y: 0 }}
|
||||||
|
exit={{ opacity: 0, y: 24 }}
|
||||||
transition={{ duration: 0.4, ease: "easeInOut" }}
|
transition={{ duration: 0.4, ease: "easeInOut" }}
|
||||||
>
|
>
|
||||||
<div className={`w-full max-w-[600px] flex flex-col gap-6`}>
|
<div className={`w-full max-w-[600px] flex flex-col`}>
|
||||||
<Tabs
|
<AnimatePresence mode="wait">
|
||||||
defaultValue={modelProvider}
|
{error && (
|
||||||
onValueChange={handleSetModelProvider}
|
<motion.div
|
||||||
>
|
key="error"
|
||||||
<TabsList className="mb-4">
|
initial={{ opacity: 1, y: 0, height: "auto" }}
|
||||||
<TabsTrigger value="openai">
|
exit={{ opacity: 0, y: -10, height: 0 }}
|
||||||
<TabTrigger
|
>
|
||||||
selected={modelProvider === "openai"}
|
<div className="pb-6 flex items-center gap-4">
|
||||||
isLoading={isLoadingModels}
|
<X className="w-4 h-4 text-destructive shrink-0" />
|
||||||
>
|
<span className="text-mmd text-muted-foreground">
|
||||||
<div
|
{error}
|
||||||
className={cn(
|
</span>
|
||||||
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
</div>
|
||||||
modelProvider === "openai" ? "bg-white" : "bg-muted",
|
</motion.div>
|
||||||
)}
|
|
||||||
>
|
|
||||||
<OpenAILogo
|
|
||||||
className={cn(
|
|
||||||
"w-4 h-4 shrink-0",
|
|
||||||
modelProvider === "openai"
|
|
||||||
? "text-black"
|
|
||||||
: "text-muted-foreground",
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
OpenAI
|
|
||||||
</TabTrigger>
|
|
||||||
</TabsTrigger>
|
|
||||||
<TabsTrigger value="watsonx">
|
|
||||||
<TabTrigger
|
|
||||||
selected={modelProvider === "watsonx"}
|
|
||||||
isLoading={isLoadingModels}
|
|
||||||
>
|
|
||||||
<div
|
|
||||||
className={cn(
|
|
||||||
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
|
||||||
modelProvider === "watsonx"
|
|
||||||
? "bg-[#1063FE]"
|
|
||||||
: "bg-muted",
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
<IBMLogo
|
|
||||||
className={cn(
|
|
||||||
"w-4 h-4 shrink-0",
|
|
||||||
modelProvider === "watsonx"
|
|
||||||
? "text-white"
|
|
||||||
: "text-muted-foreground",
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
IBM watsonx.ai
|
|
||||||
</TabTrigger>
|
|
||||||
</TabsTrigger>
|
|
||||||
<TabsTrigger value="ollama">
|
|
||||||
<TabTrigger
|
|
||||||
selected={modelProvider === "ollama"}
|
|
||||||
isLoading={isLoadingModels}
|
|
||||||
>
|
|
||||||
<div
|
|
||||||
className={cn(
|
|
||||||
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
|
||||||
modelProvider === "ollama" ? "bg-white" : "bg-muted",
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
<OllamaLogo
|
|
||||||
className={cn(
|
|
||||||
"w-4 h-4 shrink-0",
|
|
||||||
modelProvider === "ollama"
|
|
||||||
? "text-black"
|
|
||||||
: "text-muted-foreground",
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
Ollama
|
|
||||||
</TabTrigger>
|
|
||||||
</TabsTrigger>
|
|
||||||
</TabsList>
|
|
||||||
<TabsContent value="openai">
|
|
||||||
<OpenAIOnboarding
|
|
||||||
setSettings={setSettings}
|
|
||||||
sampleDataset={sampleDataset}
|
|
||||||
setSampleDataset={setSampleDataset}
|
|
||||||
setIsLoadingModels={setIsLoadingModels}
|
|
||||||
/>
|
|
||||||
</TabsContent>
|
|
||||||
<TabsContent value="watsonx">
|
|
||||||
<IBMOnboarding
|
|
||||||
setSettings={setSettings}
|
|
||||||
sampleDataset={sampleDataset}
|
|
||||||
setSampleDataset={setSampleDataset}
|
|
||||||
setIsLoadingModels={setIsLoadingModels}
|
|
||||||
/>
|
|
||||||
</TabsContent>
|
|
||||||
<TabsContent value="ollama">
|
|
||||||
<OllamaOnboarding
|
|
||||||
setSettings={setSettings}
|
|
||||||
sampleDataset={sampleDataset}
|
|
||||||
setSampleDataset={setSampleDataset}
|
|
||||||
setIsLoadingModels={setIsLoadingModels}
|
|
||||||
/>
|
|
||||||
</TabsContent>
|
|
||||||
</Tabs>
|
|
||||||
|
|
||||||
<Tooltip>
|
|
||||||
<TooltipTrigger asChild>
|
|
||||||
<div>
|
|
||||||
<Button
|
|
||||||
size="sm"
|
|
||||||
onClick={handleComplete}
|
|
||||||
disabled={!isComplete || isLoadingModels}
|
|
||||||
loading={onboardingMutation.isPending}
|
|
||||||
>
|
|
||||||
<span className="select-none">Complete</span>
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
</TooltipTrigger>
|
|
||||||
{!isComplete && (
|
|
||||||
<TooltipContent>
|
|
||||||
{isLoadingModels
|
|
||||||
? "Loading models..."
|
|
||||||
: !!settings.llm_model &&
|
|
||||||
!!settings.embedding_model &&
|
|
||||||
!isDoclingHealthy
|
|
||||||
? "docling-serve must be running to continue"
|
|
||||||
: "Please fill in all required fields"}
|
|
||||||
</TooltipContent>
|
|
||||||
)}
|
)}
|
||||||
</Tooltip>
|
</AnimatePresence>
|
||||||
|
<div className={`w-full flex flex-col gap-6`}>
|
||||||
|
<Tabs
|
||||||
|
defaultValue={modelProvider}
|
||||||
|
onValueChange={handleSetModelProvider}
|
||||||
|
>
|
||||||
|
<TabsList className="mb-4">
|
||||||
|
<TabsTrigger
|
||||||
|
value="openai"
|
||||||
|
className={cn(
|
||||||
|
error &&
|
||||||
|
modelProvider === "openai" &&
|
||||||
|
"data-[state=active]:border-destructive",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<TabTrigger
|
||||||
|
selected={modelProvider === "openai"}
|
||||||
|
isLoading={isLoadingModels}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
||||||
|
modelProvider === "openai" ? "bg-white" : "bg-muted",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<OpenAILogo
|
||||||
|
className={cn(
|
||||||
|
"w-4 h-4 shrink-0",
|
||||||
|
modelProvider === "openai"
|
||||||
|
? "text-black"
|
||||||
|
: "text-muted-foreground",
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
OpenAI
|
||||||
|
</TabTrigger>
|
||||||
|
</TabsTrigger>
|
||||||
|
<TabsTrigger
|
||||||
|
value="watsonx"
|
||||||
|
className={cn(
|
||||||
|
error &&
|
||||||
|
modelProvider === "watsonx" &&
|
||||||
|
"data-[state=active]:border-destructive",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<TabTrigger
|
||||||
|
selected={modelProvider === "watsonx"}
|
||||||
|
isLoading={isLoadingModels}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
||||||
|
modelProvider === "watsonx"
|
||||||
|
? "bg-[#1063FE]"
|
||||||
|
: "bg-muted",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<IBMLogo
|
||||||
|
className={cn(
|
||||||
|
"w-4 h-4 shrink-0",
|
||||||
|
modelProvider === "watsonx"
|
||||||
|
? "text-white"
|
||||||
|
: "text-muted-foreground",
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
IBM watsonx.ai
|
||||||
|
</TabTrigger>
|
||||||
|
</TabsTrigger>
|
||||||
|
<TabsTrigger
|
||||||
|
value="ollama"
|
||||||
|
className={cn(
|
||||||
|
error &&
|
||||||
|
modelProvider === "ollama" &&
|
||||||
|
"data-[state=active]:border-destructive",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<TabTrigger
|
||||||
|
selected={modelProvider === "ollama"}
|
||||||
|
isLoading={isLoadingModels}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
||||||
|
modelProvider === "ollama" ? "bg-white" : "bg-muted",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<OllamaLogo
|
||||||
|
className={cn(
|
||||||
|
"w-4 h-4 shrink-0",
|
||||||
|
modelProvider === "ollama"
|
||||||
|
? "text-black"
|
||||||
|
: "text-muted-foreground",
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
Ollama
|
||||||
|
</TabTrigger>
|
||||||
|
</TabsTrigger>
|
||||||
|
</TabsList>
|
||||||
|
<TabsContent value="openai">
|
||||||
|
<OpenAIOnboarding
|
||||||
|
setSettings={setSettings}
|
||||||
|
sampleDataset={sampleDataset}
|
||||||
|
setSampleDataset={setSampleDataset}
|
||||||
|
setIsLoadingModels={setIsLoadingModels}
|
||||||
|
/>
|
||||||
|
</TabsContent>
|
||||||
|
<TabsContent value="watsonx">
|
||||||
|
<IBMOnboarding
|
||||||
|
setSettings={setSettings}
|
||||||
|
sampleDataset={sampleDataset}
|
||||||
|
setSampleDataset={setSampleDataset}
|
||||||
|
setIsLoadingModels={setIsLoadingModels}
|
||||||
|
/>
|
||||||
|
</TabsContent>
|
||||||
|
<TabsContent value="ollama">
|
||||||
|
<OllamaOnboarding
|
||||||
|
setSettings={setSettings}
|
||||||
|
sampleDataset={sampleDataset}
|
||||||
|
setSampleDataset={setSampleDataset}
|
||||||
|
setIsLoadingModels={setIsLoadingModels}
|
||||||
|
/>
|
||||||
|
</TabsContent>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<div>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
onClick={handleComplete}
|
||||||
|
disabled={!isComplete || isLoadingModels}
|
||||||
|
loading={onboardingMutation.isPending}
|
||||||
|
>
|
||||||
|
<span className="select-none">Complete</span>
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</TooltipTrigger>
|
||||||
|
{!isComplete && (
|
||||||
|
<TooltipContent>
|
||||||
|
{isLoadingModels
|
||||||
|
? "Loading models..."
|
||||||
|
: !!settings.llm_model &&
|
||||||
|
!!settings.embedding_model &&
|
||||||
|
!isDoclingHealthy
|
||||||
|
? "docling-serve must be running to continue"
|
||||||
|
: "Please fill in all required fields"}
|
||||||
|
</TooltipContent>
|
||||||
|
)}
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</motion.div>
|
</motion.div>
|
||||||
) : (
|
) : (
|
||||||
|
|
@ -321,6 +372,7 @@ const OnboardingCard = ({
|
||||||
key="provider-steps"
|
key="provider-steps"
|
||||||
initial={{ opacity: 0, y: 24 }}
|
initial={{ opacity: 0, y: 24 }}
|
||||||
animate={{ opacity: 1, y: 0 }}
|
animate={{ opacity: 1, y: 0 }}
|
||||||
|
exit={{ opacity: 0, y: 24 }}
|
||||||
transition={{ duration: 0.4, ease: "easeInOut" }}
|
transition={{ duration: 0.4, ease: "easeInOut" }}
|
||||||
>
|
>
|
||||||
<AnimatedProviderSteps
|
<AnimatedProviderSteps
|
||||||
|
|
@ -329,6 +381,7 @@ const OnboardingCard = ({
|
||||||
setCurrentStep={setCurrentStep}
|
setCurrentStep={setCurrentStep}
|
||||||
steps={STEP_LIST}
|
steps={STEP_LIST}
|
||||||
processingStartTime={processingStartTime}
|
processingStartTime={processingStartTime}
|
||||||
|
hasError={!!error}
|
||||||
/>
|
/>
|
||||||
</motion.div>
|
</motion.div>
|
||||||
)}
|
)}
|
||||||
|
|
|
||||||
437
src/api/provider_validation.py
Normal file
437
src/api/provider_validation.py
Normal file
|
|
@ -0,0 +1,437 @@
|
||||||
|
"""Provider validation utilities for testing API keys and models during onboarding."""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from utils.container_utils import transform_localhost_url
|
||||||
|
from utils.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_provider_setup(
|
||||||
|
provider: str,
|
||||||
|
api_key: str = None,
|
||||||
|
embedding_model: str = None,
|
||||||
|
llm_model: str = None,
|
||||||
|
endpoint: str = None,
|
||||||
|
project_id: str = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Validate provider setup by testing completion with tool calling and embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Provider name ('openai', 'watsonx', 'ollama')
|
||||||
|
api_key: API key for the provider (optional for ollama)
|
||||||
|
embedding_model: Embedding model to test
|
||||||
|
llm_model: LLM model to test
|
||||||
|
endpoint: Provider endpoint (required for ollama and watsonx)
|
||||||
|
project_id: Project ID (required for watsonx)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If validation fails with message "Setup failed, please try again or select a different provider."
|
||||||
|
"""
|
||||||
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Starting validation for provider: {provider_lower}")
|
||||||
|
|
||||||
|
# Test completion with tool calling
|
||||||
|
await test_completion_with_tools(
|
||||||
|
provider=provider_lower,
|
||||||
|
api_key=api_key,
|
||||||
|
llm_model=llm_model,
|
||||||
|
endpoint=endpoint,
|
||||||
|
project_id=project_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test embedding
|
||||||
|
await test_embedding(
|
||||||
|
provider=provider_lower,
|
||||||
|
api_key=api_key,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
endpoint=endpoint,
|
||||||
|
project_id=project_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Validation successful for provider: {provider_lower}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Validation failed for provider {provider_lower}: {str(e)}")
|
||||||
|
raise Exception("Setup failed, please try again or select a different provider.")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_completion_with_tools(
|
||||||
|
provider: str,
|
||||||
|
api_key: str = None,
|
||||||
|
llm_model: str = None,
|
||||||
|
endpoint: str = None,
|
||||||
|
project_id: str = None,
|
||||||
|
) -> None:
|
||||||
|
"""Test completion with tool calling for the provider."""
|
||||||
|
|
||||||
|
if provider == "openai":
|
||||||
|
await _test_openai_completion_with_tools(api_key, llm_model)
|
||||||
|
elif provider == "watsonx":
|
||||||
|
await _test_watsonx_completion_with_tools(api_key, llm_model, endpoint, project_id)
|
||||||
|
elif provider == "ollama":
|
||||||
|
await _test_ollama_completion_with_tools(llm_model, endpoint)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_embedding(
|
||||||
|
provider: str,
|
||||||
|
api_key: str = None,
|
||||||
|
embedding_model: str = None,
|
||||||
|
endpoint: str = None,
|
||||||
|
project_id: str = None,
|
||||||
|
) -> None:
|
||||||
|
"""Test embedding generation for the provider."""
|
||||||
|
|
||||||
|
if provider == "openai":
|
||||||
|
await _test_openai_embedding(api_key, embedding_model)
|
||||||
|
elif provider == "watsonx":
|
||||||
|
await _test_watsonx_embedding(api_key, embedding_model, endpoint, project_id)
|
||||||
|
elif provider == "ollama":
|
||||||
|
await _test_ollama_embedding(embedding_model, endpoint)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
|
||||||
|
# OpenAI validation functions
|
||||||
|
async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> None:
|
||||||
|
"""Test OpenAI completion with tool calling."""
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Simple tool calling test
|
||||||
|
payload = {
|
||||||
|
"model": llm_model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like?"}
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 50,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
"https://api.openai.com/v1/chat/completions",
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"OpenAI completion test failed: {response.status_code} - {response.text}")
|
||||||
|
raise Exception(f"OpenAI API error: {response.status_code}")
|
||||||
|
|
||||||
|
logger.info("OpenAI completion with tool calling test passed")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("OpenAI completion test timed out")
|
||||||
|
raise Exception("Request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OpenAI completion test failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _test_openai_embedding(api_key: str, embedding_model: str) -> None:
|
||||||
|
"""Test OpenAI embedding generation."""
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": embedding_model,
|
||||||
|
"input": "test embedding",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
"https://api.openai.com/v1/embeddings",
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"OpenAI embedding test failed: {response.status_code} - {response.text}")
|
||||||
|
raise Exception(f"OpenAI API error: {response.status_code}")
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
if not data.get("data") or len(data["data"]) == 0:
|
||||||
|
raise Exception("No embedding data returned")
|
||||||
|
|
||||||
|
logger.info("OpenAI embedding test passed")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("OpenAI embedding test timed out")
|
||||||
|
raise Exception("Request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OpenAI embedding test failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# IBM Watson validation functions
|
||||||
|
async def _test_watsonx_completion_with_tools(
|
||||||
|
api_key: str, llm_model: str, endpoint: str, project_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Test IBM Watson completion with tool calling."""
|
||||||
|
try:
|
||||||
|
# Get bearer token from IBM IAM
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
token_response = await client.post(
|
||||||
|
"https://iam.cloud.ibm.com/identity/token",
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
data={
|
||||||
|
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
|
||||||
|
"apikey": api_key,
|
||||||
|
},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if token_response.status_code != 200:
|
||||||
|
logger.error(f"IBM IAM token request failed: {token_response.status_code}")
|
||||||
|
raise Exception("Failed to authenticate with IBM Watson")
|
||||||
|
|
||||||
|
bearer_token = token_response.json().get("access_token")
|
||||||
|
if not bearer_token:
|
||||||
|
raise Exception("No access token received from IBM")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {bearer_token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test completion with tools
|
||||||
|
url = f"{endpoint}/ml/v1/text/chat"
|
||||||
|
params = {"version": "2024-09-16"}
|
||||||
|
payload = {
|
||||||
|
"model_id": llm_model,
|
||||||
|
"project_id": project_id,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like?"}
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 50,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
params=params,
|
||||||
|
json=payload,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"IBM Watson completion test failed: {response.status_code} - {response.text}")
|
||||||
|
raise Exception(f"IBM Watson API error: {response.status_code}")
|
||||||
|
|
||||||
|
logger.info("IBM Watson completion with tool calling test passed")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("IBM Watson completion test timed out")
|
||||||
|
raise Exception("Request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"IBM Watson completion test failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _test_watsonx_embedding(
|
||||||
|
api_key: str, embedding_model: str, endpoint: str, project_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Test IBM Watson embedding generation."""
|
||||||
|
try:
|
||||||
|
# Get bearer token from IBM IAM
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
token_response = await client.post(
|
||||||
|
"https://iam.cloud.ibm.com/identity/token",
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
data={
|
||||||
|
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
|
||||||
|
"apikey": api_key,
|
||||||
|
},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if token_response.status_code != 200:
|
||||||
|
logger.error(f"IBM IAM token request failed: {token_response.status_code}")
|
||||||
|
raise Exception("Failed to authenticate with IBM Watson")
|
||||||
|
|
||||||
|
bearer_token = token_response.json().get("access_token")
|
||||||
|
if not bearer_token:
|
||||||
|
raise Exception("No access token received from IBM")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {bearer_token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test embedding
|
||||||
|
url = f"{endpoint}/ml/v1/text/embeddings"
|
||||||
|
params = {"version": "2024-09-16"}
|
||||||
|
payload = {
|
||||||
|
"model_id": embedding_model,
|
||||||
|
"project_id": project_id,
|
||||||
|
"inputs": ["test embedding"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
params=params,
|
||||||
|
json=payload,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"IBM Watson embedding test failed: {response.status_code} - {response.text}")
|
||||||
|
raise Exception(f"IBM Watson API error: {response.status_code}")
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
if not data.get("results") or len(data["results"]) == 0:
|
||||||
|
raise Exception("No embedding data returned")
|
||||||
|
|
||||||
|
logger.info("IBM Watson embedding test passed")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("IBM Watson embedding test timed out")
|
||||||
|
raise Exception("Request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"IBM Watson embedding test failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Ollama validation functions
|
||||||
|
async def _test_ollama_completion_with_tools(llm_model: str, endpoint: str) -> None:
|
||||||
|
"""Test Ollama completion with tool calling."""
|
||||||
|
try:
|
||||||
|
ollama_url = transform_localhost_url(endpoint)
|
||||||
|
url = f"{ollama_url}/api/chat"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": llm_model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like?"}
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
url,
|
||||||
|
json=payload,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"Ollama completion test failed: {response.status_code} - {response.text}")
|
||||||
|
raise Exception(f"Ollama API error: {response.status_code}")
|
||||||
|
|
||||||
|
logger.info("Ollama completion with tool calling test passed")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("Ollama completion test timed out")
|
||||||
|
raise Exception("Request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Ollama completion test failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _test_ollama_embedding(embedding_model: str, endpoint: str) -> None:
|
||||||
|
"""Test Ollama embedding generation."""
|
||||||
|
try:
|
||||||
|
ollama_url = transform_localhost_url(endpoint)
|
||||||
|
url = f"{ollama_url}/api/embeddings"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": embedding_model,
|
||||||
|
"prompt": "test embedding",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
url,
|
||||||
|
json=payload,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"Ollama embedding test failed: {response.status_code} - {response.text}")
|
||||||
|
raise Exception(f"Ollama API error: {response.status_code}")
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
if not data.get("embedding"):
|
||||||
|
raise Exception("No embedding data returned")
|
||||||
|
|
||||||
|
logger.info("Ollama embedding test passed")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("Ollama embedding test timed out")
|
||||||
|
raise Exception("Request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Ollama embedding test failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import platform
|
import platform
|
||||||
|
import time
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from utils.container_utils import transform_localhost_url
|
from utils.container_utils import transform_localhost_url
|
||||||
from utils.logging_config import get_logger
|
from utils.logging_config import get_logger
|
||||||
|
|
@ -533,6 +534,29 @@ async def onboarding(request, flows_service):
|
||||||
{"error": "No valid fields provided for update"}, status_code=400
|
{"error": "No valid fields provided for update"}, status_code=400
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate provider setup before initializing OpenSearch index
|
||||||
|
try:
|
||||||
|
from api.provider_validation import validate_provider_setup
|
||||||
|
|
||||||
|
provider = current_config.provider.model_provider.lower() if current_config.provider.model_provider else "openai"
|
||||||
|
|
||||||
|
logger.info(f"Validating provider setup for {provider}")
|
||||||
|
await validate_provider_setup(
|
||||||
|
provider=provider,
|
||||||
|
api_key=current_config.provider.api_key,
|
||||||
|
embedding_model=current_config.knowledge.embedding_model,
|
||||||
|
llm_model=current_config.agent.llm_model,
|
||||||
|
endpoint=current_config.provider.endpoint,
|
||||||
|
project_id=current_config.provider.project_id,
|
||||||
|
)
|
||||||
|
logger.info(f"Provider setup validation completed successfully for {provider}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Provider validation failed: {str(e)}")
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": str(e)},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize the OpenSearch index now that we have the embedding model configured
|
# Initialize the OpenSearch index now that we have the embedding model configured
|
||||||
try:
|
try:
|
||||||
# Import here to avoid circular imports
|
# Import here to avoid circular imports
|
||||||
|
|
@ -694,7 +718,7 @@ async def onboarding(request, flows_service):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to update onboarding settings", error=str(e))
|
logger.error("Failed to update onboarding settings", error=str(e))
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{"error": f"Failed to update onboarding settings: {str(e)}"},
|
{"error": str(e)},
|
||||||
status_code=500,
|
status_code=500,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue