Merge pull request #598 from langflow-ai/fix/health_check_models
fix: make health check lightweight in the best-case scenario
This commit is contained in:
commit
32d928074a
11 changed files with 303 additions and 94 deletions
|
|
@ -3,6 +3,7 @@ import {
|
||||||
useQuery,
|
useQuery,
|
||||||
useQueryClient,
|
useQueryClient,
|
||||||
} from "@tanstack/react-query";
|
} from "@tanstack/react-query";
|
||||||
|
import { useChat } from "@/contexts/chat-context";
|
||||||
import { useGetSettingsQuery } from "./useGetSettingsQuery";
|
import { useGetSettingsQuery } from "./useGetSettingsQuery";
|
||||||
|
|
||||||
export interface ProviderHealthDetails {
|
export interface ProviderHealthDetails {
|
||||||
|
|
@ -24,6 +25,7 @@ export interface ProviderHealthResponse {
|
||||||
|
|
||||||
/**
 * Optional parameters accepted by the provider health query hook.
 */
export interface ProviderHealthParams {
  /**
   * Provider to check. When set, it is sent as the `provider` query
   * parameter of the health request. Includes "anthropic" so the type
   * matches the provider names the backend health endpoint documents.
   */
  provider?: "openai" | "ollama" | "watsonx" | "anthropic";
  /**
   * When true, `test_completion=true` is added to the health request.
   * When left undefined, the hook falls back to the chat-error flag.
   */
  test_completion?: boolean;
}
|
||||||
|
|
||||||
// Track consecutive failures for exponential backoff
|
// Track consecutive failures for exponential backoff
|
||||||
|
|
@ -38,6 +40,9 @@ export const useProviderHealthQuery = (
|
||||||
) => {
|
) => {
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
|
|
||||||
|
// Get chat error state from context (ChatProvider wraps the entire app in layout.tsx)
|
||||||
|
const { hasChatError, setChatError } = useChat();
|
||||||
|
|
||||||
const { data: settings = {} } = useGetSettingsQuery();
|
const { data: settings = {} } = useGetSettingsQuery();
|
||||||
|
|
||||||
async function checkProviderHealth(): Promise<ProviderHealthResponse> {
|
async function checkProviderHealth(): Promise<ProviderHealthResponse> {
|
||||||
|
|
@ -49,6 +54,12 @@ export const useProviderHealthQuery = (
|
||||||
url.searchParams.set("provider", params.provider);
|
url.searchParams.set("provider", params.provider);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add test_completion query param if specified or if chat error exists
|
||||||
|
const testCompletion = params?.test_completion ?? hasChatError;
|
||||||
|
if (testCompletion) {
|
||||||
|
url.searchParams.set("test_completion", "true");
|
||||||
|
}
|
||||||
|
|
||||||
const response = await fetch(url.toString());
|
const response = await fetch(url.toString());
|
||||||
|
|
||||||
if (response.ok) {
|
if (response.ok) {
|
||||||
|
|
@ -90,7 +101,7 @@ export const useProviderHealthQuery = (
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const queryKey = ["provider", "health"];
|
const queryKey = ["provider", "health", params?.test_completion];
|
||||||
const failureCountKey = queryKey.join("-");
|
const failureCountKey = queryKey.join("-");
|
||||||
|
|
||||||
const queryResult = useQuery(
|
const queryResult = useQuery(
|
||||||
|
|
@ -103,8 +114,13 @@ export const useProviderHealthQuery = (
|
||||||
const status = data?.status;
|
const status = data?.status;
|
||||||
|
|
||||||
// If healthy, reset failure count and check every 30 seconds
|
// If healthy, reset failure count and check every 30 seconds
|
||||||
|
// Also reset chat error flag if we're using test_completion=true and it succeeded
|
||||||
if (status === "healthy") {
|
if (status === "healthy") {
|
||||||
failureCountMap.set(failureCountKey, 0);
|
failureCountMap.set(failureCountKey, 0);
|
||||||
|
// If we were checking with test_completion=true due to chat errors, reset the flag
|
||||||
|
if (hasChatError && setChatError) {
|
||||||
|
setChatError(false);
|
||||||
|
}
|
||||||
return 30000;
|
return 30000;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -119,7 +135,8 @@ export const useProviderHealthQuery = (
|
||||||
|
|
||||||
// Exponential backoff: 5s, 10s, 20s, then 30s
|
// Exponential backoff: 5s, 10s, 20s, then 30s
|
||||||
const backoffDelays = [5000, 10000, 20000, 30000];
|
const backoffDelays = [5000, 10000, 20000, 30000];
|
||||||
const delay = backoffDelays[Math.min(currentFailures, backoffDelays.length - 1)];
|
const delay =
|
||||||
|
backoffDelays[Math.min(currentFailures, backoffDelays.length - 1)];
|
||||||
|
|
||||||
return delay;
|
return delay;
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@ function ChatPage() {
|
||||||
]);
|
]);
|
||||||
const [input, setInput] = useState("");
|
const [input, setInput] = useState("");
|
||||||
const { loading, setLoading } = useLoadingStore();
|
const { loading, setLoading } = useLoadingStore();
|
||||||
|
const { setChatError } = useChat();
|
||||||
const [asyncMode, setAsyncMode] = useState(true);
|
const [asyncMode, setAsyncMode] = useState(true);
|
||||||
const [expandedFunctionCalls, setExpandedFunctionCalls] = useState<
|
const [expandedFunctionCalls, setExpandedFunctionCalls] = useState<
|
||||||
Set<string>
|
Set<string>
|
||||||
|
|
@ -123,6 +124,8 @@ function ChatPage() {
|
||||||
console.error("Streaming error:", error);
|
console.error("Streaming error:", error);
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
setWaitingTooLong(false);
|
setWaitingTooLong(false);
|
||||||
|
// Set chat error flag to trigger test_completion=true on health checks
|
||||||
|
setChatError(true);
|
||||||
const errorMessage: Message = {
|
const errorMessage: Message = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content:
|
content:
|
||||||
|
|
@ -197,6 +200,11 @@ function ChatPage() {
|
||||||
const result = await response.json();
|
const result = await response.json();
|
||||||
console.log("Upload result:", result);
|
console.log("Upload result:", result);
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
// Set chat error flag if upload fails
|
||||||
|
setChatError(true);
|
||||||
|
}
|
||||||
|
|
||||||
if (response.status === 201) {
|
if (response.status === 201) {
|
||||||
// New flow: Got task ID, start tracking with centralized system
|
// New flow: Got task ID, start tracking with centralized system
|
||||||
const taskId = result.task_id || result.id;
|
const taskId = result.task_id || result.id;
|
||||||
|
|
@ -255,6 +263,8 @@ function ChatPage() {
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Upload failed:", error);
|
console.error("Upload failed:", error);
|
||||||
|
// Set chat error flag to trigger test_completion=true on health checks
|
||||||
|
setChatError(true);
|
||||||
const errorMessage: Message = {
|
const errorMessage: Message = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: `❌ Failed to process document. Please try again.`,
|
content: `❌ Failed to process document. Please try again.`,
|
||||||
|
|
@ -858,6 +868,8 @@ function ChatPage() {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
console.error("Chat failed:", result.error);
|
console.error("Chat failed:", result.error);
|
||||||
|
// Set chat error flag to trigger test_completion=true on health checks
|
||||||
|
setChatError(true);
|
||||||
const errorMessage: Message = {
|
const errorMessage: Message = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: "Sorry, I encountered an error. Please try again.",
|
content: "Sorry, I encountered an error. Please try again.",
|
||||||
|
|
@ -867,6 +879,8 @@ function ChatPage() {
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Chat error:", error);
|
console.error("Chat error:", error);
|
||||||
|
// Set chat error flag to trigger test_completion=true on health checks
|
||||||
|
setChatError(true);
|
||||||
const errorMessage: Message = {
|
const errorMessage: Message = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content:
|
content:
|
||||||
|
|
|
||||||
|
|
@ -158,6 +158,16 @@ const OnboardingUpload = ({ onComplete }: OnboardingUploadProps) => {
|
||||||
const errorMessage = error instanceof Error ? error.message : "Upload failed";
|
const errorMessage = error instanceof Error ? error.message : "Upload failed";
|
||||||
console.error("Upload failed", errorMessage);
|
console.error("Upload failed", errorMessage);
|
||||||
|
|
||||||
|
// Dispatch event that chat context can listen to
|
||||||
|
// This avoids circular dependency issues
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
window.dispatchEvent(
|
||||||
|
new CustomEvent("ingestionFailed", {
|
||||||
|
detail: { source: "onboarding" },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Show error toast notification
|
// Show error toast notification
|
||||||
toast.error("Document upload failed", {
|
toast.error("Document upload failed", {
|
||||||
description: errorMessage,
|
description: errorMessage,
|
||||||
|
|
|
||||||
|
|
@ -238,6 +238,15 @@ export function KnowledgeDropdown() {
|
||||||
await uploadFileUtil(file, replace);
|
await uploadFileUtil(file, replace);
|
||||||
refetchTasks();
|
refetchTasks();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
// Dispatch event that chat context can listen to
|
||||||
|
// This avoids circular dependency issues
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
window.dispatchEvent(
|
||||||
|
new CustomEvent("ingestionFailed", {
|
||||||
|
detail: { source: "knowledge-dropdown" },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
toast.error("Upload failed", {
|
toast.error("Upload failed", {
|
||||||
description: error instanceof Error ? error.message : "Unknown error",
|
description: error instanceof Error ? error.message : "Unknown error",
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import { useProviderHealthQuery } from "@/app/api/queries/useProviderHealthQuery
|
||||||
import type { ModelProvider } from "@/app/settings/_helpers/model-helpers";
|
import type { ModelProvider } from "@/app/settings/_helpers/model-helpers";
|
||||||
import { Banner, BannerIcon, BannerTitle } from "@/components/ui/banner";
|
import { Banner, BannerIcon, BannerTitle } from "@/components/ui/banner";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
import { useChat } from "@/contexts/chat-context";
|
||||||
import { Button } from "./ui/button";
|
import { Button } from "./ui/button";
|
||||||
|
|
||||||
interface ProviderHealthBannerProps {
|
interface ProviderHealthBannerProps {
|
||||||
|
|
@ -14,13 +15,16 @@ interface ProviderHealthBannerProps {
|
||||||
|
|
||||||
// Custom hook to check provider health status
|
// Custom hook to check provider health status
|
||||||
export function useProviderHealth() {
|
export function useProviderHealth() {
|
||||||
|
const { hasChatError } = useChat();
|
||||||
const {
|
const {
|
||||||
data: health,
|
data: health,
|
||||||
isLoading,
|
isLoading,
|
||||||
isFetching,
|
isFetching,
|
||||||
error,
|
error,
|
||||||
isError,
|
isError,
|
||||||
} = useProviderHealthQuery();
|
} = useProviderHealthQuery({
|
||||||
|
test_completion: hasChatError, // Use test_completion=true when chat errors occur
|
||||||
|
});
|
||||||
|
|
||||||
const isHealthy = health?.status === "healthy" && !isError;
|
const isHealthy = health?.status === "healthy" && !isError;
|
||||||
// Only consider unhealthy if backend is up but provider validation failed
|
// Only consider unhealthy if backend is up but provider validation failed
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,8 @@ interface ChatContextType {
|
||||||
conversationFilter: KnowledgeFilter | null;
|
conversationFilter: KnowledgeFilter | null;
|
||||||
// responseId: undefined = use currentConversationId, null = don't save to localStorage
|
// responseId: undefined = use currentConversationId, null = don't save to localStorage
|
||||||
setConversationFilter: (filter: KnowledgeFilter | null, responseId?: string | null) => void;
|
setConversationFilter: (filter: KnowledgeFilter | null, responseId?: string | null) => void;
|
||||||
|
hasChatError: boolean;
|
||||||
|
setChatError: (hasError: boolean) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
const ChatContext = createContext<ChatContextType | undefined>(undefined);
|
const ChatContext = createContext<ChatContextType | undefined>(undefined);
|
||||||
|
|
@ -108,6 +110,19 @@ export function ChatProvider({ children }: ChatProviderProps) {
|
||||||
const [conversationLoaded, setConversationLoaded] = useState(false);
|
const [conversationLoaded, setConversationLoaded] = useState(false);
|
||||||
const [conversationFilter, setConversationFilterState] =
|
const [conversationFilter, setConversationFilterState] =
|
||||||
useState<KnowledgeFilter | null>(null);
|
useState<KnowledgeFilter | null>(null);
|
||||||
|
const [hasChatError, setChatError] = useState(false);
|
||||||
|
|
||||||
|
// Listen for ingestion failures and set chat error flag
|
||||||
|
useEffect(() => {
|
||||||
|
const handleIngestionFailed = () => {
|
||||||
|
setChatError(true);
|
||||||
|
};
|
||||||
|
|
||||||
|
window.addEventListener("ingestionFailed", handleIngestionFailed);
|
||||||
|
return () => {
|
||||||
|
window.removeEventListener("ingestionFailed", handleIngestionFailed);
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
// Debounce refresh requests to prevent excessive reloads
|
// Debounce refresh requests to prevent excessive reloads
|
||||||
const refreshTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
const refreshTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
||||||
|
|
@ -358,6 +373,8 @@ export function ChatProvider({ children }: ChatProviderProps) {
|
||||||
setConversationLoaded,
|
setConversationLoaded,
|
||||||
conversationFilter,
|
conversationFilter,
|
||||||
setConversationFilter,
|
setConversationFilter,
|
||||||
|
hasChatError,
|
||||||
|
setChatError,
|
||||||
}),
|
}),
|
||||||
[
|
[
|
||||||
endpoint,
|
endpoint,
|
||||||
|
|
@ -378,6 +395,7 @@ export function ChatProvider({ children }: ChatProviderProps) {
|
||||||
conversationLoaded,
|
conversationLoaded,
|
||||||
conversationFilter,
|
conversationFilter,
|
||||||
setConversationFilter,
|
setConversationFilter,
|
||||||
|
hasChatError,
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -323,6 +323,20 @@ export function TaskProvider({ children }: { children: React.ReactNode }) {
|
||||||
currentTask.error || "Unknown error"
|
currentTask.error || "Unknown error"
|
||||||
}`,
|
}`,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Set chat error flag to trigger test_completion=true on health checks
|
||||||
|
// Only for ingestion-related tasks (tasks with files are ingestion tasks)
|
||||||
|
if (currentTask.files && Object.keys(currentTask.files).length > 0) {
|
||||||
|
// Dispatch event that chat context can listen to
|
||||||
|
// This avoids circular dependency issues
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
window.dispatchEvent(
|
||||||
|
new CustomEvent("ingestionFailed", {
|
||||||
|
detail: { taskId: currentTask.task_id },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
"""Provider health check endpoint."""
|
"""Provider health check endpoint."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import httpx
|
import httpx
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from utils.logging_config import get_logger
|
from utils.logging_config import get_logger
|
||||||
from config.settings import get_openrag_config
|
from config.settings import get_openrag_config
|
||||||
from api.provider_validation import validate_provider_setup, _test_ollama_lightweight_health
|
from api.provider_validation import validate_provider_setup
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
@ -16,6 +17,8 @@ async def check_provider_health(request):
|
||||||
Query parameters:
|
Query parameters:
|
||||||
provider (optional): Provider to check ('openai', 'ollama', 'watsonx', 'anthropic').
|
provider (optional): Provider to check ('openai', 'ollama', 'watsonx', 'anthropic').
|
||||||
If not provided, checks the currently configured provider.
|
If not provided, checks the currently configured provider.
|
||||||
|
test_completion (optional): If 'true', performs full validation with completion/embedding tests (consumes credits).
|
||||||
|
If 'false' or not provided, performs lightweight validation (no/minimal credits consumed).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
200: Provider is healthy and validated
|
200: Provider is healthy and validated
|
||||||
|
|
@ -26,6 +29,7 @@ async def check_provider_health(request):
|
||||||
# Get optional provider from query params
|
# Get optional provider from query params
|
||||||
query_params = dict(request.query_params)
|
query_params = dict(request.query_params)
|
||||||
check_provider = query_params.get("provider")
|
check_provider = query_params.get("provider")
|
||||||
|
test_completion = query_params.get("test_completion", "false").lower() == "true"
|
||||||
|
|
||||||
# Get current config
|
# Get current config
|
||||||
current_config = get_openrag_config()
|
current_config = get_openrag_config()
|
||||||
|
|
@ -100,6 +104,7 @@ async def check_provider_health(request):
|
||||||
llm_model=llm_model,
|
llm_model=llm_model,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
|
test_completion=test_completion,
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|
@ -124,23 +129,14 @@ async def check_provider_health(request):
|
||||||
|
|
||||||
# Validate LLM provider
|
# Validate LLM provider
|
||||||
try:
|
try:
|
||||||
# For Ollama, use lightweight health check that doesn't block on active requests
|
await validate_provider_setup(
|
||||||
if provider == "ollama":
|
provider=provider,
|
||||||
try:
|
api_key=api_key,
|
||||||
await _test_ollama_lightweight_health(endpoint)
|
llm_model=llm_model,
|
||||||
except Exception as lightweight_error:
|
endpoint=endpoint,
|
||||||
# If lightweight check fails, Ollama is down or misconfigured
|
project_id=project_id,
|
||||||
llm_error = str(lightweight_error)
|
test_completion=test_completion,
|
||||||
logger.error(f"LLM provider ({provider}) lightweight check failed: {llm_error}")
|
)
|
||||||
raise
|
|
||||||
else:
|
|
||||||
await validate_provider_setup(
|
|
||||||
provider=provider,
|
|
||||||
api_key=api_key,
|
|
||||||
llm_model=llm_model,
|
|
||||||
endpoint=endpoint,
|
|
||||||
project_id=project_id,
|
|
||||||
)
|
|
||||||
except httpx.TimeoutException as e:
|
except httpx.TimeoutException as e:
|
||||||
# Timeout means provider is busy, not misconfigured
|
# Timeout means provider is busy, not misconfigured
|
||||||
if provider == "ollama":
|
if provider == "ollama":
|
||||||
|
|
@ -154,24 +150,25 @@ async def check_provider_health(request):
|
||||||
logger.error(f"LLM provider ({provider}) validation failed: {llm_error}")
|
logger.error(f"LLM provider ({provider}) validation failed: {llm_error}")
|
||||||
|
|
||||||
# Validate embedding provider
|
# Validate embedding provider
|
||||||
|
# For WatsonX with test_completion=True, wait 2 seconds between completion and embedding tests
|
||||||
|
if (
|
||||||
|
test_completion
|
||||||
|
and provider == "watsonx"
|
||||||
|
and embedding_provider == "watsonx"
|
||||||
|
and llm_error is None
|
||||||
|
):
|
||||||
|
logger.info("Waiting 2 seconds before WatsonX embedding test (after completion test)")
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# For Ollama, use lightweight health check first
|
await validate_provider_setup(
|
||||||
if embedding_provider == "ollama":
|
provider=embedding_provider,
|
||||||
try:
|
api_key=embedding_api_key,
|
||||||
await _test_ollama_lightweight_health(embedding_endpoint)
|
embedding_model=embedding_model,
|
||||||
except Exception as lightweight_error:
|
endpoint=embedding_endpoint,
|
||||||
# If lightweight check fails, Ollama is down or misconfigured
|
project_id=embedding_project_id,
|
||||||
embedding_error = str(lightweight_error)
|
test_completion=test_completion,
|
||||||
logger.error(f"Embedding provider ({embedding_provider}) lightweight check failed: {embedding_error}")
|
)
|
||||||
raise
|
|
||||||
else:
|
|
||||||
await validate_provider_setup(
|
|
||||||
provider=embedding_provider,
|
|
||||||
api_key=embedding_api_key,
|
|
||||||
embedding_model=embedding_model,
|
|
||||||
endpoint=embedding_endpoint,
|
|
||||||
project_id=embedding_project_id,
|
|
||||||
)
|
|
||||||
except httpx.TimeoutException as e:
|
except httpx.TimeoutException as e:
|
||||||
# Timeout means provider is busy, not misconfigured
|
# Timeout means provider is busy, not misconfigured
|
||||||
if embedding_provider == "ollama":
|
if embedding_provider == "ollama":
|
||||||
|
|
|
||||||
|
|
@ -14,17 +14,20 @@ async def validate_provider_setup(
|
||||||
llm_model: str = None,
|
llm_model: str = None,
|
||||||
endpoint: str = None,
|
endpoint: str = None,
|
||||||
project_id: str = None,
|
project_id: str = None,
|
||||||
|
test_completion: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Validate provider setup by testing completion with tool calling and embedding.
|
Validate provider setup by testing completion with tool calling and embedding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: Provider name ('openai', 'watsonx', 'ollama')
|
provider: Provider name ('openai', 'watsonx', 'ollama', 'anthropic')
|
||||||
api_key: API key for the provider (optional for ollama)
|
api_key: API key for the provider (optional for ollama)
|
||||||
embedding_model: Embedding model to test
|
embedding_model: Embedding model to test
|
||||||
llm_model: LLM model to test
|
llm_model: LLM model to test
|
||||||
endpoint: Provider endpoint (required for ollama and watsonx)
|
endpoint: Provider endpoint (required for ollama and watsonx)
|
||||||
project_id: Project ID (required for watsonx)
|
project_id: Project ID (required for watsonx)
|
||||||
|
test_completion: If True, performs full validation with completion/embedding tests (consumes credits).
|
||||||
|
If False, performs lightweight validation (no credits consumed). Default: False.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If validation fails with message "Setup failed, please try again or select a different provider."
|
Exception: If validation fails with message "Setup failed, please try again or select a different provider."
|
||||||
|
|
@ -32,29 +35,37 @@ async def validate_provider_setup(
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting validation for provider: {provider_lower}")
|
logger.info(f"Starting validation for provider: {provider_lower} (test_completion={test_completion})")
|
||||||
|
|
||||||
if embedding_model:
|
if test_completion:
|
||||||
# Test embedding
|
# Full validation with completion/embedding tests (consumes credits)
|
||||||
await test_embedding(
|
if embedding_model:
|
||||||
|
# Test embedding
|
||||||
|
await test_embedding(
|
||||||
|
provider=provider_lower,
|
||||||
|
api_key=api_key,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
endpoint=endpoint,
|
||||||
|
project_id=project_id,
|
||||||
|
)
|
||||||
|
elif llm_model:
|
||||||
|
# Test completion with tool calling
|
||||||
|
await test_completion_with_tools(
|
||||||
|
provider=provider_lower,
|
||||||
|
api_key=api_key,
|
||||||
|
llm_model=llm_model,
|
||||||
|
endpoint=endpoint,
|
||||||
|
project_id=project_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Lightweight validation (no credits consumed)
|
||||||
|
await test_lightweight_health(
|
||||||
provider=provider_lower,
|
provider=provider_lower,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
embedding_model=embedding_model,
|
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif llm_model:
|
|
||||||
# Test completion with tool calling
|
|
||||||
await test_completion_with_tools(
|
|
||||||
provider=provider_lower,
|
|
||||||
api_key=api_key,
|
|
||||||
llm_model=llm_model,
|
|
||||||
endpoint=endpoint,
|
|
||||||
project_id=project_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
logger.info(f"Validation successful for provider: {provider_lower}")
|
logger.info(f"Validation successful for provider: {provider_lower}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -62,6 +73,26 @@ async def validate_provider_setup(
|
||||||
raise Exception("Setup failed, please try again or select a different provider.")
|
raise Exception("Setup failed, please try again or select a different provider.")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_lightweight_health(
|
||||||
|
provider: str,
|
||||||
|
api_key: str = None,
|
||||||
|
endpoint: str = None,
|
||||||
|
project_id: str = None,
|
||||||
|
) -> None:
|
||||||
|
"""Test provider health with lightweight check (no credits consumed)."""
|
||||||
|
|
||||||
|
if provider == "openai":
|
||||||
|
await _test_openai_lightweight_health(api_key)
|
||||||
|
elif provider == "watsonx":
|
||||||
|
await _test_watsonx_lightweight_health(api_key, endpoint, project_id)
|
||||||
|
elif provider == "ollama":
|
||||||
|
await _test_ollama_lightweight_health(endpoint)
|
||||||
|
elif provider == "anthropic":
|
||||||
|
await _test_anthropic_lightweight_health(api_key)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
|
||||||
async def test_completion_with_tools(
|
async def test_completion_with_tools(
|
||||||
provider: str,
|
provider: str,
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
|
|
@ -103,6 +134,40 @@ async def test_embedding(
|
||||||
|
|
||||||
|
|
||||||
# OpenAI validation functions
|
# OpenAI validation functions
|
||||||
|
async def _test_openai_lightweight_health(api_key: str) -> None:
|
||||||
|
"""Test OpenAI API key validity with lightweight check.
|
||||||
|
|
||||||
|
Only checks if the API key is valid without consuming credits.
|
||||||
|
Uses the /v1/models endpoint which doesn't consume credits.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# Use /v1/models endpoint which validates the key without consuming credits
|
||||||
|
response = await client.get(
|
||||||
|
"https://api.openai.com/v1/models",
|
||||||
|
headers=headers,
|
||||||
|
timeout=10.0, # Short timeout for lightweight check
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"OpenAI lightweight health check failed: {response.status_code}")
|
||||||
|
raise Exception(f"OpenAI API key validation failed: {response.status_code}")
|
||||||
|
|
||||||
|
logger.info("OpenAI lightweight health check passed")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("OpenAI lightweight health check timed out")
|
||||||
|
raise Exception("OpenAI API request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OpenAI lightweight health check failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> None:
|
async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> None:
|
||||||
"""Test OpenAI completion with tool calling."""
|
"""Test OpenAI completion with tool calling."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -213,6 +278,45 @@ async def _test_openai_embedding(api_key: str, embedding_model: str) -> None:
|
||||||
|
|
||||||
|
|
||||||
# IBM Watson validation functions
|
# IBM Watson validation functions
|
||||||
|
async def _test_watsonx_lightweight_health(
|
||||||
|
api_key: str, endpoint: str, project_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Test WatsonX API key validity with lightweight check.
|
||||||
|
|
||||||
|
Only checks if the API key is valid by getting a bearer token.
|
||||||
|
Does not consume credits by avoiding model inference requests.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get bearer token from IBM IAM - this validates the API key without consuming credits
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
token_response = await client.post(
|
||||||
|
"https://iam.cloud.ibm.com/identity/token",
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
data={
|
||||||
|
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
|
||||||
|
"apikey": api_key,
|
||||||
|
},
|
||||||
|
timeout=10.0, # Short timeout for lightweight check
|
||||||
|
)
|
||||||
|
|
||||||
|
if token_response.status_code != 200:
|
||||||
|
logger.error(f"IBM IAM token request failed: {token_response.status_code}")
|
||||||
|
raise Exception("Failed to authenticate with IBM Watson - invalid API key")
|
||||||
|
|
||||||
|
bearer_token = token_response.json().get("access_token")
|
||||||
|
if not bearer_token:
|
||||||
|
raise Exception("No access token received from IBM")
|
||||||
|
|
||||||
|
logger.info("WatsonX lightweight health check passed - API key is valid")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("WatsonX lightweight health check timed out")
|
||||||
|
raise Exception("WatsonX API request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WatsonX lightweight health check failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def _test_watsonx_completion_with_tools(
|
async def _test_watsonx_completion_with_tools(
|
||||||
api_key: str, llm_model: str, endpoint: str, project_id: str
|
api_key: str, llm_model: str, endpoint: str, project_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -483,6 +587,48 @@ async def _test_ollama_embedding(embedding_model: str, endpoint: str) -> None:
|
||||||
|
|
||||||
|
|
||||||
# Anthropic validation functions
|
# Anthropic validation functions
|
||||||
|
async def _test_anthropic_lightweight_health(api_key: str) -> None:
|
||||||
|
"""Test Anthropic API key validity with lightweight check.
|
||||||
|
|
||||||
|
Only checks if the API key is valid without consuming credits.
|
||||||
|
Uses a minimal messages request with max_tokens=1 to validate the key.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"x-api-key": api_key,
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Minimal validation request - uses cheapest model with minimal tokens
|
||||||
|
payload = {
|
||||||
|
"model": "claude-3-5-haiku-latest", # Cheapest model
|
||||||
|
"max_tokens": 1, # Minimum tokens to validate key
|
||||||
|
"messages": [{"role": "user", "content": "test"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
"https://api.anthropic.com/v1/messages",
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=10.0, # Short timeout for lightweight check
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"Anthropic lightweight health check failed: {response.status_code}")
|
||||||
|
raise Exception(f"Anthropic API key validation failed: {response.status_code}")
|
||||||
|
|
||||||
|
logger.info("Anthropic lightweight health check passed")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("Anthropic lightweight health check timed out")
|
||||||
|
raise Exception("Anthropic API request timed out")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Anthropic lightweight health check failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def _test_anthropic_completion_with_tools(api_key: str, llm_model: str) -> None:
|
async def _test_anthropic_completion_with_tools(api_key: str, llm_model: str) -> None:
|
||||||
"""Test Anthropic completion with tool calling."""
|
"""Test Anthropic completion with tool calling."""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -897,6 +897,7 @@ async def onboarding(request, flows_service, session_manager=None):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate provider setup before initializing OpenSearch index
|
# Validate provider setup before initializing OpenSearch index
|
||||||
|
# Use lightweight validation (test_completion=False) to avoid consuming credits during onboarding
|
||||||
try:
|
try:
|
||||||
from api.provider_validation import validate_provider_setup
|
from api.provider_validation import validate_provider_setup
|
||||||
|
|
||||||
|
|
@ -905,13 +906,14 @@ async def onboarding(request, flows_service, session_manager=None):
|
||||||
llm_provider = current_config.agent.llm_provider.lower()
|
llm_provider = current_config.agent.llm_provider.lower()
|
||||||
llm_provider_config = current_config.get_llm_provider_config()
|
llm_provider_config = current_config.get_llm_provider_config()
|
||||||
|
|
||||||
logger.info(f"Validating LLM provider setup for {llm_provider}")
|
logger.info(f"Validating LLM provider setup for {llm_provider} (lightweight)")
|
||||||
await validate_provider_setup(
|
await validate_provider_setup(
|
||||||
provider=llm_provider,
|
provider=llm_provider,
|
||||||
api_key=getattr(llm_provider_config, "api_key", None),
|
api_key=getattr(llm_provider_config, "api_key", None),
|
||||||
llm_model=current_config.agent.llm_model,
|
llm_model=current_config.agent.llm_model,
|
||||||
endpoint=getattr(llm_provider_config, "endpoint", None),
|
endpoint=getattr(llm_provider_config, "endpoint", None),
|
||||||
project_id=getattr(llm_provider_config, "project_id", None),
|
project_id=getattr(llm_provider_config, "project_id", None),
|
||||||
|
test_completion=False, # Lightweight validation - no credits consumed
|
||||||
)
|
)
|
||||||
logger.info(f"LLM provider setup validation completed successfully for {llm_provider}")
|
logger.info(f"LLM provider setup validation completed successfully for {llm_provider}")
|
||||||
|
|
||||||
|
|
@ -920,13 +922,14 @@ async def onboarding(request, flows_service, session_manager=None):
|
||||||
embedding_provider = current_config.knowledge.embedding_provider.lower()
|
embedding_provider = current_config.knowledge.embedding_provider.lower()
|
||||||
embedding_provider_config = current_config.get_embedding_provider_config()
|
embedding_provider_config = current_config.get_embedding_provider_config()
|
||||||
|
|
||||||
logger.info(f"Validating embedding provider setup for {embedding_provider}")
|
logger.info(f"Validating embedding provider setup for {embedding_provider} (lightweight)")
|
||||||
await validate_provider_setup(
|
await validate_provider_setup(
|
||||||
provider=embedding_provider,
|
provider=embedding_provider,
|
||||||
api_key=getattr(embedding_provider_config, "api_key", None),
|
api_key=getattr(embedding_provider_config, "api_key", None),
|
||||||
embedding_model=current_config.knowledge.embedding_model,
|
embedding_model=current_config.knowledge.embedding_model,
|
||||||
endpoint=getattr(embedding_provider_config, "endpoint", None),
|
endpoint=getattr(embedding_provider_config, "endpoint", None),
|
||||||
project_id=getattr(embedding_provider_config, "project_id", None),
|
project_id=getattr(embedding_provider_config, "project_id", None),
|
||||||
|
test_completion=False, # Lightweight validation - no credits consumed
|
||||||
)
|
)
|
||||||
logger.info(f"Embedding provider setup validation completed successfully for {embedding_provider}")
|
logger.info(f"Embedding provider setup validation completed successfully for {embedding_provider}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ class ModelsService:
|
||||||
self.session_manager = None
|
self.session_manager = None
|
||||||
|
|
||||||
async def get_openai_models(self, api_key: str) -> Dict[str, List[Dict[str, str]]]:
|
async def get_openai_models(self, api_key: str) -> Dict[str, List[Dict[str, str]]]:
|
||||||
"""Fetch available models from OpenAI API"""
|
"""Fetch available models from OpenAI API with lightweight validation"""
|
||||||
try:
|
try:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {api_key}",
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
|
@ -58,6 +58,8 @@ class ModelsService:
|
||||||
}
|
}
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
|
# Lightweight validation: just check if API key is valid
|
||||||
|
# This doesn't consume credits, only validates the key
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
"https://api.openai.com/v1/models", headers=headers, timeout=10.0
|
"https://api.openai.com/v1/models", headers=headers, timeout=10.0
|
||||||
)
|
)
|
||||||
|
|
@ -101,6 +103,7 @@ class ModelsService:
|
||||||
key=lambda x: (not x.get("default", False), x["value"])
|
key=lambda x: (not x.get("default", False), x["value"])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info("OpenAI API key validated successfully without consuming credits")
|
||||||
return {
|
return {
|
||||||
"language_models": language_models,
|
"language_models": language_models,
|
||||||
"embedding_models": embedding_models,
|
"embedding_models": embedding_models,
|
||||||
|
|
@ -389,38 +392,12 @@ class ModelsService:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate credentials with the first available LLM model
|
# Lightweight validation: API key is already validated by successfully getting bearer token
|
||||||
if language_models:
|
# No need to make a generation request that consumes credits
|
||||||
first_llm_model = language_models[0]["value"]
|
if bearer_token:
|
||||||
|
logger.info("IBM Watson API key validated successfully without consuming credits")
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
validation_url = f"{watson_endpoint}/ml/v1/text/generation"
|
|
||||||
validation_params = {"version": "2024-09-16"}
|
|
||||||
validation_payload = {
|
|
||||||
"input": "test",
|
|
||||||
"model_id": first_llm_model,
|
|
||||||
"project_id": project_id,
|
|
||||||
"parameters": {
|
|
||||||
"max_new_tokens": 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
validation_response = await client.post(
|
|
||||||
validation_url,
|
|
||||||
headers=headers,
|
|
||||||
params=validation_params,
|
|
||||||
json=validation_payload,
|
|
||||||
timeout=10.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if validation_response.status_code != 200:
|
|
||||||
raise Exception(
|
|
||||||
f"Invalid credentials or endpoint: {validation_response.status_code} - {validation_response.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"IBM Watson credentials validated successfully using model: {first_llm_model}")
|
|
||||||
else:
|
else:
|
||||||
logger.warning("No language models available to validate credentials")
|
logger.warning("No bearer token available - API key validation may have failed")
|
||||||
|
|
||||||
if not language_models and not embedding_models:
|
if not language_models and not embedding_models:
|
||||||
raise Exception("No IBM models retrieved from API")
|
raise Exception("No IBM models retrieved from API")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue