fix: added better onboarding error handling, added probing api keys and models (#326)
* Added error display to onboarding card
* Added error state on animated provider steps
* Removed toast on error
* Fixed animation on onboarding card
* Fixed animation time
* Implemented provider validation
* Added provider validation before ingestion
* Changed error border
* Removed log
---------
Co-authored-by: Mike Fortman <michael.fortman@datastax.com>
This commit is contained in:
parent
3296a60b3b
commit
7b635df9d0
5 changed files with 664 additions and 144 deletions
|
|
@ -1154,8 +1154,6 @@ function ChatPage() {
|
|||
}
|
||||
};
|
||||
|
||||
console.log(messages)
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Debug header - only show in debug mode */}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { CheckIcon } from "lucide-react";
|
||||
import { CheckIcon, XIcon } from "lucide-react";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
import {
|
||||
|
|
@ -20,6 +20,7 @@ export function AnimatedProviderSteps({
|
|||
steps,
|
||||
storageKey = "provider-steps",
|
||||
processingStartTime,
|
||||
hasError = false,
|
||||
}: {
|
||||
currentStep: number;
|
||||
isCompleted: boolean;
|
||||
|
|
@ -27,6 +28,7 @@ export function AnimatedProviderSteps({
|
|||
steps: string[];
|
||||
storageKey?: string;
|
||||
processingStartTime?: number | null;
|
||||
hasError?: boolean;
|
||||
}) {
|
||||
const [startTime, setStartTime] = useState<number | null>(null);
|
||||
const [elapsedTime, setElapsedTime] = useState<number>(0);
|
||||
|
|
@ -63,7 +65,7 @@ export function AnimatedProviderSteps({
|
|||
}
|
||||
}, [isCompleted, startTime, storageKey]);
|
||||
|
||||
const isDone = currentStep >= steps.length && !isCompleted;
|
||||
const isDone = currentStep >= steps.length && !isCompleted && !hasError;
|
||||
|
||||
return (
|
||||
<AnimatePresence mode="wait">
|
||||
|
|
@ -79,8 +81,8 @@ export function AnimatedProviderSteps({
|
|||
<div className="flex items-center gap-2">
|
||||
<div
|
||||
className={cn(
|
||||
"transition-all duration-150 relative",
|
||||
isDone ? "w-3.5 h-3.5" : "w-6 h-6",
|
||||
"transition-all duration-300 relative",
|
||||
isDone || hasError ? "w-3.5 h-3.5" : "w-6 h-6",
|
||||
)}
|
||||
>
|
||||
<CheckIcon
|
||||
|
|
@ -89,21 +91,27 @@ export function AnimatedProviderSteps({
|
|||
isDone ? "opacity-100" : "opacity-0",
|
||||
)}
|
||||
/>
|
||||
<XIcon
|
||||
className={cn(
|
||||
"text-accent-red-foreground shrink-0 w-3.5 h-3.5 absolute inset-0 transition-all duration-150",
|
||||
hasError ? "opacity-100" : "opacity-0",
|
||||
)}
|
||||
/>
|
||||
<AnimatedProcessingIcon
|
||||
className={cn(
|
||||
"text-current shrink-0 absolute inset-0 transition-all duration-150",
|
||||
isDone ? "opacity-0" : "opacity-100",
|
||||
isDone || hasError ? "opacity-0" : "opacity-100",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<span className="text-mmd font-medium text-muted-foreground">
|
||||
{isDone ? "Done" : "Thinking"}
|
||||
<span className="!text-mmd font-medium text-muted-foreground">
|
||||
{hasError ? "Error" : isDone ? "Done" : "Thinking"}
|
||||
</span>
|
||||
</div>
|
||||
<div className="overflow-hidden">
|
||||
<AnimatePresence>
|
||||
{!isDone && (
|
||||
{!isDone && !hasError && (
|
||||
<motion.div
|
||||
initial={{ opacity: 1, y: 0, height: "auto" }}
|
||||
exit={{ opacity: 0, y: -24, height: 0 }}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { X } from "lucide-react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import {
|
||||
|
|
@ -70,6 +71,7 @@ const OnboardingCard = ({
|
|||
embedding_model: "",
|
||||
llm_model: "",
|
||||
});
|
||||
setError(null);
|
||||
};
|
||||
|
||||
const [settings, setSettings] = useState<OnboardingVariables>({
|
||||
|
|
@ -86,6 +88,8 @@ const OnboardingCard = ({
|
|||
null,
|
||||
);
|
||||
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
// Query tasks to track completion
|
||||
const { data: tasks } = useGetTasksQuery({
|
||||
enabled: currentStep !== null, // Only poll when onboarding has started
|
||||
|
|
@ -126,11 +130,15 @@ const OnboardingCard = ({
|
|||
onSuccess: (data) => {
|
||||
console.log("Onboarding completed successfully", data);
|
||||
setCurrentStep(0);
|
||||
setError(null);
|
||||
},
|
||||
onError: (error) => {
|
||||
toast.error("Failed to complete onboarding", {
|
||||
description: error.message,
|
||||
});
|
||||
setError(error.message);
|
||||
setCurrentStep(TOTAL_PROVIDER_STEPS);
|
||||
// Reset to provider selection after 1 second
|
||||
setTimeout(() => {
|
||||
setCurrentStep(null);
|
||||
}, 1000);
|
||||
},
|
||||
});
|
||||
|
||||
|
|
@ -144,6 +152,9 @@ const OnboardingCard = ({
|
|||
return;
|
||||
}
|
||||
|
||||
// Clear any previous error
|
||||
setError(null);
|
||||
|
||||
// Prepare onboarding data
|
||||
const onboardingData: OnboardingVariables = {
|
||||
model_provider: settings.model_provider,
|
||||
|
|
@ -181,139 +192,179 @@ const OnboardingCard = ({
|
|||
{currentStep === null ? (
|
||||
<motion.div
|
||||
key="onboarding-form"
|
||||
initial={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -24 }}
|
||||
initial={{ opacity: 0, y: -24 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: 24 }}
|
||||
transition={{ duration: 0.4, ease: "easeInOut" }}
|
||||
>
|
||||
<div className={`w-full max-w-[600px] flex flex-col gap-6`}>
|
||||
<Tabs
|
||||
defaultValue={modelProvider}
|
||||
onValueChange={handleSetModelProvider}
|
||||
>
|
||||
<TabsList className="mb-4">
|
||||
<TabsTrigger value="openai">
|
||||
<TabTrigger
|
||||
selected={modelProvider === "openai"}
|
||||
isLoading={isLoadingModels}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
||||
modelProvider === "openai" ? "bg-white" : "bg-muted",
|
||||
)}
|
||||
>
|
||||
<OpenAILogo
|
||||
className={cn(
|
||||
"w-4 h-4 shrink-0",
|
||||
modelProvider === "openai"
|
||||
? "text-black"
|
||||
: "text-muted-foreground",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
OpenAI
|
||||
</TabTrigger>
|
||||
</TabsTrigger>
|
||||
<TabsTrigger value="watsonx">
|
||||
<TabTrigger
|
||||
selected={modelProvider === "watsonx"}
|
||||
isLoading={isLoadingModels}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
||||
modelProvider === "watsonx"
|
||||
? "bg-[#1063FE]"
|
||||
: "bg-muted",
|
||||
)}
|
||||
>
|
||||
<IBMLogo
|
||||
className={cn(
|
||||
"w-4 h-4 shrink-0",
|
||||
modelProvider === "watsonx"
|
||||
? "text-white"
|
||||
: "text-muted-foreground",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
IBM watsonx.ai
|
||||
</TabTrigger>
|
||||
</TabsTrigger>
|
||||
<TabsTrigger value="ollama">
|
||||
<TabTrigger
|
||||
selected={modelProvider === "ollama"}
|
||||
isLoading={isLoadingModels}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
||||
modelProvider === "ollama" ? "bg-white" : "bg-muted",
|
||||
)}
|
||||
>
|
||||
<OllamaLogo
|
||||
className={cn(
|
||||
"w-4 h-4 shrink-0",
|
||||
modelProvider === "ollama"
|
||||
? "text-black"
|
||||
: "text-muted-foreground",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
Ollama
|
||||
</TabTrigger>
|
||||
</TabsTrigger>
|
||||
</TabsList>
|
||||
<TabsContent value="openai">
|
||||
<OpenAIOnboarding
|
||||
setSettings={setSettings}
|
||||
sampleDataset={sampleDataset}
|
||||
setSampleDataset={setSampleDataset}
|
||||
setIsLoadingModels={setIsLoadingModels}
|
||||
/>
|
||||
</TabsContent>
|
||||
<TabsContent value="watsonx">
|
||||
<IBMOnboarding
|
||||
setSettings={setSettings}
|
||||
sampleDataset={sampleDataset}
|
||||
setSampleDataset={setSampleDataset}
|
||||
setIsLoadingModels={setIsLoadingModels}
|
||||
/>
|
||||
</TabsContent>
|
||||
<TabsContent value="ollama">
|
||||
<OllamaOnboarding
|
||||
setSettings={setSettings}
|
||||
sampleDataset={sampleDataset}
|
||||
setSampleDataset={setSampleDataset}
|
||||
setIsLoadingModels={setIsLoadingModels}
|
||||
/>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={handleComplete}
|
||||
disabled={!isComplete || isLoadingModels}
|
||||
loading={onboardingMutation.isPending}
|
||||
>
|
||||
<span className="select-none">Complete</span>
|
||||
</Button>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{!isComplete && (
|
||||
<TooltipContent>
|
||||
{isLoadingModels
|
||||
? "Loading models..."
|
||||
: !!settings.llm_model &&
|
||||
!!settings.embedding_model &&
|
||||
!isDoclingHealthy
|
||||
? "docling-serve must be running to continue"
|
||||
: "Please fill in all required fields"}
|
||||
</TooltipContent>
|
||||
<div className={`w-full max-w-[600px] flex flex-col`}>
|
||||
<AnimatePresence mode="wait">
|
||||
{error && (
|
||||
<motion.div
|
||||
key="error"
|
||||
initial={{ opacity: 1, y: 0, height: "auto" }}
|
||||
exit={{ opacity: 0, y: -10, height: 0 }}
|
||||
>
|
||||
<div className="pb-6 flex items-center gap-4">
|
||||
<X className="w-4 h-4 text-destructive shrink-0" />
|
||||
<span className="text-mmd text-muted-foreground">
|
||||
{error}
|
||||
</span>
|
||||
</div>
|
||||
</motion.div>
|
||||
)}
|
||||
</Tooltip>
|
||||
</AnimatePresence>
|
||||
<div className={`w-full flex flex-col gap-6`}>
|
||||
<Tabs
|
||||
defaultValue={modelProvider}
|
||||
onValueChange={handleSetModelProvider}
|
||||
>
|
||||
<TabsList className="mb-4">
|
||||
<TabsTrigger
|
||||
value="openai"
|
||||
className={cn(
|
||||
error &&
|
||||
modelProvider === "openai" &&
|
||||
"data-[state=active]:border-destructive",
|
||||
)}
|
||||
>
|
||||
<TabTrigger
|
||||
selected={modelProvider === "openai"}
|
||||
isLoading={isLoadingModels}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
||||
modelProvider === "openai" ? "bg-white" : "bg-muted",
|
||||
)}
|
||||
>
|
||||
<OpenAILogo
|
||||
className={cn(
|
||||
"w-4 h-4 shrink-0",
|
||||
modelProvider === "openai"
|
||||
? "text-black"
|
||||
: "text-muted-foreground",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
OpenAI
|
||||
</TabTrigger>
|
||||
</TabsTrigger>
|
||||
<TabsTrigger
|
||||
value="watsonx"
|
||||
className={cn(
|
||||
error &&
|
||||
modelProvider === "watsonx" &&
|
||||
"data-[state=active]:border-destructive",
|
||||
)}
|
||||
>
|
||||
<TabTrigger
|
||||
selected={modelProvider === "watsonx"}
|
||||
isLoading={isLoadingModels}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
||||
modelProvider === "watsonx"
|
||||
? "bg-[#1063FE]"
|
||||
: "bg-muted",
|
||||
)}
|
||||
>
|
||||
<IBMLogo
|
||||
className={cn(
|
||||
"w-4 h-4 shrink-0",
|
||||
modelProvider === "watsonx"
|
||||
? "text-white"
|
||||
: "text-muted-foreground",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
IBM watsonx.ai
|
||||
</TabTrigger>
|
||||
</TabsTrigger>
|
||||
<TabsTrigger
|
||||
value="ollama"
|
||||
className={cn(
|
||||
error &&
|
||||
modelProvider === "ollama" &&
|
||||
"data-[state=active]:border-destructive",
|
||||
)}
|
||||
>
|
||||
<TabTrigger
|
||||
selected={modelProvider === "ollama"}
|
||||
isLoading={isLoadingModels}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center gap-2 w-8 h-8 rounded-md",
|
||||
modelProvider === "ollama" ? "bg-white" : "bg-muted",
|
||||
)}
|
||||
>
|
||||
<OllamaLogo
|
||||
className={cn(
|
||||
"w-4 h-4 shrink-0",
|
||||
modelProvider === "ollama"
|
||||
? "text-black"
|
||||
: "text-muted-foreground",
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
Ollama
|
||||
</TabTrigger>
|
||||
</TabsTrigger>
|
||||
</TabsList>
|
||||
<TabsContent value="openai">
|
||||
<OpenAIOnboarding
|
||||
setSettings={setSettings}
|
||||
sampleDataset={sampleDataset}
|
||||
setSampleDataset={setSampleDataset}
|
||||
setIsLoadingModels={setIsLoadingModels}
|
||||
/>
|
||||
</TabsContent>
|
||||
<TabsContent value="watsonx">
|
||||
<IBMOnboarding
|
||||
setSettings={setSettings}
|
||||
sampleDataset={sampleDataset}
|
||||
setSampleDataset={setSampleDataset}
|
||||
setIsLoadingModels={setIsLoadingModels}
|
||||
/>
|
||||
</TabsContent>
|
||||
<TabsContent value="ollama">
|
||||
<OllamaOnboarding
|
||||
setSettings={setSettings}
|
||||
sampleDataset={sampleDataset}
|
||||
setSampleDataset={setSampleDataset}
|
||||
setIsLoadingModels={setIsLoadingModels}
|
||||
/>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={handleComplete}
|
||||
disabled={!isComplete || isLoadingModels}
|
||||
loading={onboardingMutation.isPending}
|
||||
>
|
||||
<span className="select-none">Complete</span>
|
||||
</Button>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{!isComplete && (
|
||||
<TooltipContent>
|
||||
{isLoadingModels
|
||||
? "Loading models..."
|
||||
: !!settings.llm_model &&
|
||||
!!settings.embedding_model &&
|
||||
!isDoclingHealthy
|
||||
? "docling-serve must be running to continue"
|
||||
: "Please fill in all required fields"}
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
) : (
|
||||
|
|
@ -321,6 +372,7 @@ const OnboardingCard = ({
|
|||
key="provider-steps"
|
||||
initial={{ opacity: 0, y: 24 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: 24 }}
|
||||
transition={{ duration: 0.4, ease: "easeInOut" }}
|
||||
>
|
||||
<AnimatedProviderSteps
|
||||
|
|
@ -329,6 +381,7 @@ const OnboardingCard = ({
|
|||
setCurrentStep={setCurrentStep}
|
||||
steps={STEP_LIST}
|
||||
processingStartTime={processingStartTime}
|
||||
hasError={!!error}
|
||||
/>
|
||||
</motion.div>
|
||||
)}
|
||||
|
|
|
|||
437
src/api/provider_validation.py
Normal file
437
src/api/provider_validation.py
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
"""Provider validation utilities for testing API keys and models during onboarding."""
|
||||
|
||||
import httpx
|
||||
from utils.container_utils import transform_localhost_url
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def validate_provider_setup(
    provider: str,
    api_key: str = None,
    embedding_model: str = None,
    llm_model: str = None,
    endpoint: str = None,
    project_id: str = None,
) -> None:
    """
    Validate provider setup by testing completion with tool calling and embedding.

    The completion probe runs first; the embedding probe only runs if the
    completion probe succeeded. Any failure (including an unknown or missing
    provider name) is logged with its original message and then replaced by a
    single generic, user-facing error.

    Args:
        provider: Provider name ('openai', 'watsonx', 'ollama'); compared case-insensitively
        api_key: API key for the provider (optional for ollama)
        embedding_model: Embedding model to test
        llm_model: LLM model to test
        endpoint: Provider endpoint (required for ollama and watsonx)
        project_id: Project ID (required for watsonx)

    Raises:
        Exception: If validation fails with message "Setup failed, please try again or select a different provider."
            The underlying error is attached as ``__cause__``.
    """
    # Normalise defensively: a None/empty provider must not raise AttributeError
    # before the try block; it falls through to the "Unknown provider" path and
    # surfaces as the same generic error as every other failure.
    provider_lower = (provider or "").lower()

    try:
        logger.info(f"Starting validation for provider: {provider_lower}")

        # Test completion with tool calling
        await test_completion_with_tools(
            provider=provider_lower,
            api_key=api_key,
            llm_model=llm_model,
            endpoint=endpoint,
            project_id=project_id,
        )

        # Test embedding
        await test_embedding(
            provider=provider_lower,
            api_key=api_key,
            embedding_model=embedding_model,
            endpoint=endpoint,
            project_id=project_id,
        )

        logger.info(f"Validation successful for provider: {provider_lower}")

    except Exception as e:
        logger.error(f"Validation failed for provider {provider_lower}: {str(e)}")
        # Keep the user-facing message generic, but chain the real cause so it
        # stays visible in tracebacks.
        raise Exception(
            "Setup failed, please try again or select a different provider."
        ) from e
|
||||
|
||||
|
||||
async def test_completion_with_tools(
    provider: str,
    api_key: str = None,
    llm_model: str = None,
    endpoint: str = None,
    project_id: str = None,
) -> None:
    """Run the tool-calling completion probe that matches ``provider``.

    Args:
        provider: One of 'openai', 'watsonx' or 'ollama' (matched exactly as given).
        api_key: Credential forwarded to the OpenAI and watsonx probes.
        llm_model: Model identifier forwarded to every probe.
        endpoint: Base URL forwarded to the watsonx and Ollama probes.
        project_id: Project identifier forwarded to the watsonx probe.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "openai":
        await _test_openai_completion_with_tools(api_key, llm_model)
        return

    if provider == "watsonx":
        await _test_watsonx_completion_with_tools(
            api_key, llm_model, endpoint, project_id
        )
        return

    if provider == "ollama":
        await _test_ollama_completion_with_tools(llm_model, endpoint)
        return

    raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
|
||||
async def test_embedding(
    provider: str,
    api_key: str = None,
    embedding_model: str = None,
    endpoint: str = None,
    project_id: str = None,
) -> None:
    """Run the embedding probe that matches ``provider``.

    Args:
        provider: One of 'openai', 'watsonx' or 'ollama' (matched exactly as given).
        api_key: Credential forwarded to the OpenAI and watsonx probes.
        embedding_model: Model identifier forwarded to every probe.
        endpoint: Base URL forwarded to the watsonx and Ollama probes.
        project_id: Project identifier forwarded to the watsonx probe.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "openai":
        await _test_openai_embedding(api_key, embedding_model)
        return

    if provider == "watsonx":
        await _test_watsonx_embedding(api_key, embedding_model, endpoint, project_id)
        return

    if provider == "ollama":
        await _test_ollama_embedding(embedding_model, endpoint)
        return

    raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
|
||||
# OpenAI validation functions
|
||||
async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> None:
    """Probe the OpenAI chat-completions endpoint with a single function tool.

    Sends one short user message together with a ``get_weather`` tool
    definition and treats any HTTP status other than 200 as a failure. The
    response body is not inspected on success.

    Args:
        api_key: Bearer token placed in the Authorization header.
        llm_model: Model name sent in the request payload.

    Raises:
        Exception: "Request timed out" on an httpx timeout, or
            "OpenAI API error: <status>" on a non-200 response. Any other
            error is logged and re-raised unchanged.
    """
    # Minimal function tool used only to confirm the model accepts tool definitions.
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state",
                    }
                },
                "required": ["location"],
            },
        },
    }

    try:
        request_headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        # NOTE(review): some newer OpenAI models may expect
        # `max_completion_tokens` instead of `max_tokens` — confirm against
        # the models offered during onboarding.
        request_body = {
            "model": llm_model,
            "messages": [{"role": "user", "content": "What's the weather like?"}],
            "tools": [weather_tool],
            "max_tokens": 50,
        }

        async with httpx.AsyncClient() as http_client:
            resp = await http_client.post(
                "https://api.openai.com/v1/chat/completions",
                headers=request_headers,
                json=request_body,
                timeout=30.0,
            )

            if resp.status_code == 200:
                logger.info("OpenAI completion with tool calling test passed")
            else:
                logger.error(
                    f"OpenAI completion test failed: {resp.status_code} - {resp.text}"
                )
                raise Exception(f"OpenAI API error: {resp.status_code}")

    except httpx.TimeoutException:
        logger.error("OpenAI completion test timed out")
        raise Exception("Request timed out")
    except Exception as exc:
        # Also reached for the non-200 error raised above, so that case is logged twice.
        logger.error(f"OpenAI completion test failed: {str(exc)}")
        raise
|
||||
|
||||
|
||||
async def _test_openai_embedding(api_key: str, embedding_model: str) -> None:
    """Probe the OpenAI embeddings endpoint with a fixed test string.

    Succeeds only when the endpoint answers with HTTP 200 and the JSON body
    carries a non-empty ``data`` entry.

    Args:
        api_key: Bearer token placed in the Authorization header.
        embedding_model: Model name sent in the request payload.

    Raises:
        Exception: "Request timed out" on an httpx timeout,
            "OpenAI API error: <status>" on a non-200 response, or
            "No embedding data returned" when ``data`` is missing or empty.
            Any other error is logged and re-raised unchanged.
    """
    try:
        request_headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        request_body = {"model": embedding_model, "input": "test embedding"}

        async with httpx.AsyncClient() as http_client:
            resp = await http_client.post(
                "https://api.openai.com/v1/embeddings",
                headers=request_headers,
                json=request_body,
                timeout=30.0,
            )

            if resp.status_code != 200:
                logger.error(
                    f"OpenAI embedding test failed: {resp.status_code} - {resp.text}"
                )
                raise Exception(f"OpenAI API error: {resp.status_code}")

            body = resp.json()
            embeddings = body.get("data")
            if not embeddings or len(embeddings) == 0:
                raise Exception("No embedding data returned")

            logger.info("OpenAI embedding test passed")

    except httpx.TimeoutException:
        logger.error("OpenAI embedding test timed out")
        raise Exception("Request timed out")
    except Exception as exc:
        # Also reached for the errors raised above, so those cases are logged here as well.
        logger.error(f"OpenAI embedding test failed: {str(exc)}")
        raise
|
||||
|
||||
|
||||
# IBM Watson validation functions
|
||||
async def _test_watsonx_completion_with_tools(
    api_key: str, llm_model: str, endpoint: str, project_id: str
) -> None:
    """Test IBM Watson completion with tool calling.

    Exchanges ``api_key`` for an IAM bearer token, then posts one short chat
    request carrying a single ``get_weather`` function tool to
    ``<endpoint>/ml/v1/text/chat``. Any HTTP status other than 200 on either
    request is treated as a failure; the chat response body is not inspected.

    Args:
        api_key: IBM Cloud API key exchanged at the IAM token endpoint.
        llm_model: Value sent as ``model_id`` in the chat payload.
        endpoint: Base URL of the watsonx.ai service; a trailing slash is tolerated.
        project_id: Value sent as ``project_id`` in the chat payload.

    Raises:
        Exception: On a missing endpoint, failed authentication, a missing
            access token, a non-200 chat response, or a timeout
            ("Request timed out"). Other errors are logged and re-raised unchanged.
    """
    try:
        # Fail fast with a readable message instead of posting to "None/ml/v1/...".
        if not endpoint:
            raise Exception("IBM Watson endpoint is required")
        # Avoid building ".../​/ml/v1/text/chat" when the endpoint ends with "/".
        base_url = endpoint.rstrip("/")

        # Get bearer token from IBM IAM
        async with httpx.AsyncClient() as client:
            token_response = await client.post(
                "https://iam.cloud.ibm.com/identity/token",
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data={
                    "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
                    "apikey": api_key,
                },
                timeout=30.0,
            )

            if token_response.status_code != 200:
                logger.error(f"IBM IAM token request failed: {token_response.status_code}")
                raise Exception("Failed to authenticate with IBM Watson")

            bearer_token = token_response.json().get("access_token")
            if not bearer_token:
                raise Exception("No access token received from IBM")

        headers = {
            "Authorization": f"Bearer {bearer_token}",
            "Content-Type": "application/json",
        }

        # Test completion with tools
        url = f"{base_url}/ml/v1/text/chat"
        params = {"version": "2024-09-16"}
        payload = {
            "model_id": llm_model,
            "project_id": project_id,
            "messages": [
                {"role": "user", "content": "What's the weather like?"}
            ],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "The city and state",
                                }
                            },
                            "required": ["location"],
                        },
                    },
                }
            ],
            "max_tokens": 50,
        }

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url,
                headers=headers,
                params=params,
                json=payload,
                timeout=30.0,
            )

            if response.status_code != 200:
                logger.error(f"IBM Watson completion test failed: {response.status_code} - {response.text}")
                raise Exception(f"IBM Watson API error: {response.status_code}")

            logger.info("IBM Watson completion with tool calling test passed")

    except httpx.TimeoutException:
        logger.error("IBM Watson completion test timed out")
        raise Exception("Request timed out")
    except Exception as e:
        # Also reached for the errors raised above, so those cases are logged here as well.
        logger.error(f"IBM Watson completion test failed: {str(e)}")
        raise
|
||||
|
||||
|
||||
async def _test_watsonx_embedding(
    api_key: str, embedding_model: str, endpoint: str, project_id: str
) -> None:
    """Test IBM Watson embedding generation.

    Exchanges ``api_key`` for an IAM bearer token, then posts a fixed test
    string to ``<endpoint>/ml/v1/text/embeddings``. Succeeds only when the
    endpoint answers with HTTP 200 and the JSON body carries a non-empty
    ``results`` entry.

    Args:
        api_key: IBM Cloud API key exchanged at the IAM token endpoint.
        embedding_model: Value sent as ``model_id`` in the embeddings payload.
        endpoint: Base URL of the watsonx.ai service; a trailing slash is tolerated.
        project_id: Value sent as ``project_id`` in the embeddings payload.

    Raises:
        Exception: On a missing endpoint, failed authentication, a missing
            access token, a non-200 response, an empty result
            ("No embedding data returned"), or a timeout ("Request timed out").
            Other errors are logged and re-raised unchanged.
    """
    try:
        # Fail fast with a readable message instead of posting to "None/ml/v1/...".
        if not endpoint:
            raise Exception("IBM Watson endpoint is required")
        # Avoid building ".../​/ml/v1/text/embeddings" when the endpoint ends with "/".
        base_url = endpoint.rstrip("/")

        # Get bearer token from IBM IAM
        async with httpx.AsyncClient() as client:
            token_response = await client.post(
                "https://iam.cloud.ibm.com/identity/token",
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data={
                    "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
                    "apikey": api_key,
                },
                timeout=30.0,
            )

            if token_response.status_code != 200:
                logger.error(f"IBM IAM token request failed: {token_response.status_code}")
                raise Exception("Failed to authenticate with IBM Watson")

            bearer_token = token_response.json().get("access_token")
            if not bearer_token:
                raise Exception("No access token received from IBM")

        headers = {
            "Authorization": f"Bearer {bearer_token}",
            "Content-Type": "application/json",
        }

        # Test embedding
        url = f"{base_url}/ml/v1/text/embeddings"
        params = {"version": "2024-09-16"}
        payload = {
            "model_id": embedding_model,
            "project_id": project_id,
            "inputs": ["test embedding"],
        }

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url,
                headers=headers,
                params=params,
                json=payload,
                timeout=30.0,
            )

            if response.status_code != 200:
                logger.error(f"IBM Watson embedding test failed: {response.status_code} - {response.text}")
                raise Exception(f"IBM Watson API error: {response.status_code}")

            data = response.json()
            if not data.get("results") or len(data["results"]) == 0:
                raise Exception("No embedding data returned")

            logger.info("IBM Watson embedding test passed")

    except httpx.TimeoutException:
        logger.error("IBM Watson embedding test timed out")
        raise Exception("Request timed out")
    except Exception as e:
        # Also reached for the errors raised above, so those cases are logged here as well.
        logger.error(f"IBM Watson embedding test failed: {str(e)}")
        raise
|
||||
|
||||
|
||||
# Ollama validation functions
|
||||
async def _test_ollama_completion_with_tools(llm_model: str, endpoint: str) -> None:
    """Test Ollama completion with tool calling.

    Posts one short, non-streaming chat request carrying a single
    ``get_weather`` function tool to ``<endpoint>/api/chat``. Any HTTP status
    other than 200 is treated as a failure; the response body is not inspected.

    Args:
        llm_model: Model name sent in the request payload.
        endpoint: Base URL of the Ollama server; a trailing slash is tolerated.

    Raises:
        Exception: "Ollama API error: <status>" on a non-200 response, or
            "Request timed out" on an httpx timeout. Other errors are logged
            and re-raised unchanged.
    """
    try:
        # transform_localhost_url comes from utils.container_utils; presumably it
        # rewrites localhost addresses for container networking — verify there.
        # Strip any trailing "/" so the path below never becomes "//api/chat",
        # which a server may reject or answer with a redirect (httpx does not
        # follow redirects by default).
        ollama_url = str(transform_localhost_url(endpoint)).rstrip("/")
        url = f"{ollama_url}/api/chat"

        payload = {
            "model": llm_model,
            "messages": [
                {"role": "user", "content": "What's the weather like?"}
            ],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "The city and state",
                                }
                            },
                            "required": ["location"],
                        },
                    },
                }
            ],
            "stream": False,
        }

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url,
                json=payload,
                timeout=30.0,
            )

            if response.status_code != 200:
                logger.error(f"Ollama completion test failed: {response.status_code} - {response.text}")
                raise Exception(f"Ollama API error: {response.status_code}")

            logger.info("Ollama completion with tool calling test passed")

    except httpx.TimeoutException:
        logger.error("Ollama completion test timed out")
        raise Exception("Request timed out")
    except Exception as e:
        # Also reached for the non-200 error raised above, so that case is logged twice.
        logger.error(f"Ollama completion test failed: {str(e)}")
        raise
|
||||
|
||||
|
||||
async def _test_ollama_embedding(embedding_model: str, endpoint: str) -> None:
    """Test Ollama embedding generation.

    Posts a fixed test prompt to ``<endpoint>/api/embeddings`` and succeeds
    only when the server answers with HTTP 200 and the JSON body carries a
    non-empty ``embedding`` entry.

    Args:
        embedding_model: Model name sent in the request payload.
        endpoint: Base URL of the Ollama server; a trailing slash is tolerated.

    Raises:
        Exception: "Ollama API error: <status>" on a non-200 response,
            "No embedding data returned" when ``embedding`` is missing or
            empty, or "Request timed out" on an httpx timeout. Other errors
            are logged and re-raised unchanged.
    """
    try:
        # transform_localhost_url comes from utils.container_utils; presumably it
        # rewrites localhost addresses for container networking — verify there.
        # Strip any trailing "/" so the path below never becomes
        # "//api/embeddings", which a server may reject or answer with a
        # redirect (httpx does not follow redirects by default).
        ollama_url = str(transform_localhost_url(endpoint)).rstrip("/")
        url = f"{ollama_url}/api/embeddings"

        payload = {
            "model": embedding_model,
            "prompt": "test embedding",
        }

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url,
                json=payload,
                timeout=30.0,
            )

            if response.status_code != 200:
                logger.error(f"Ollama embedding test failed: {response.status_code} - {response.text}")
                raise Exception(f"Ollama API error: {response.status_code}")

            data = response.json()
            if not data.get("embedding"):
                raise Exception("No embedding data returned")

            logger.info("Ollama embedding test passed")

    except httpx.TimeoutException:
        logger.error("Ollama embedding test timed out")
        raise Exception("Request timed out")
    except Exception as e:
        # Also reached for the errors raised above, so those cases are logged here as well.
        logger.error(f"Ollama embedding test failed: {str(e)}")
        raise
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import platform
|
||||
import time
|
||||
from starlette.responses import JSONResponse
|
||||
from utils.container_utils import transform_localhost_url
|
||||
from utils.logging_config import get_logger
|
||||
|
|
@ -533,6 +534,29 @@ async def onboarding(request, flows_service):
|
|||
{"error": "No valid fields provided for update"}, status_code=400
|
||||
)
|
||||
|
||||
# Validate provider setup before initializing OpenSearch index
|
||||
try:
|
||||
from api.provider_validation import validate_provider_setup
|
||||
|
||||
provider = current_config.provider.model_provider.lower() if current_config.provider.model_provider else "openai"
|
||||
|
||||
logger.info(f"Validating provider setup for {provider}")
|
||||
await validate_provider_setup(
|
||||
provider=provider,
|
||||
api_key=current_config.provider.api_key,
|
||||
embedding_model=current_config.knowledge.embedding_model,
|
||||
llm_model=current_config.agent.llm_model,
|
||||
endpoint=current_config.provider.endpoint,
|
||||
project_id=current_config.provider.project_id,
|
||||
)
|
||||
logger.info(f"Provider setup validation completed successfully for {provider}")
|
||||
except Exception as e:
|
||||
logger.error(f"Provider validation failed: {str(e)}")
|
||||
return JSONResponse(
|
||||
{"error": str(e)},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Initialize the OpenSearch index now that we have the embedding model configured
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
|
|
@ -694,7 +718,7 @@ async def onboarding(request, flows_service):
|
|||
except Exception as e:
|
||||
logger.error("Failed to update onboarding settings", error=str(e))
|
||||
return JSONResponse(
|
||||
{"error": f"Failed to update onboarding settings: {str(e)}"},
|
||||
{"error": str(e)},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue