diff --git a/frontend/lib/debounce.ts b/frontend/lib/debounce.ts index 9ff4c59a..d2e70441 100644 --- a/frontend/lib/debounce.ts +++ b/frontend/lib/debounce.ts @@ -1,4 +1,4 @@ -import { useCallback, useRef } from "react"; +import { useCallback, useRef, useState, useEffect } from "react"; export function useDebounce void>( callback: T, @@ -21,3 +21,19 @@ export function useDebounce void>( return debouncedCallback; } + +export function useDebouncedValue(value: T, delay: number): T { + const [debouncedValue, setDebouncedValue] = useState(value); + + useEffect(() => { + const handler = setTimeout(() => { + setDebouncedValue(value); + }, delay); + + return () => { + clearTimeout(handler); + }; + }, [value, delay]); + + return debouncedValue; +} diff --git a/frontend/src/app/api/mutations/useOnboardingMutation.ts b/frontend/src/app/api/mutations/useOnboardingMutation.ts new file mode 100644 index 00000000..96acb597 --- /dev/null +++ b/frontend/src/app/api/mutations/useOnboardingMutation.ts @@ -0,0 +1,61 @@ +import { + type UseMutationOptions, + useMutation, + useQueryClient, +} from "@tanstack/react-query"; + +export interface OnboardingVariables { + model_provider: string; + api_key?: string; + endpoint?: string; + project_id?: string; + embedding_model: string; + llm_model: string; + sample_data?: boolean; +} + +interface OnboardingResponse { + message: string; + edited: boolean; +} + +export const useOnboardingMutation = ( + options?: Omit< + UseMutationOptions< + OnboardingResponse, + Error, + OnboardingVariables + >, + "mutationFn" + >, +) => { + const queryClient = useQueryClient(); + + async function submitOnboarding( + variables: OnboardingVariables, + ): Promise { + const response = await fetch("/api/onboarding", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(variables), + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.error || "Failed to complete onboarding"); + } + + return response.json(); + } + + return useMutation({ + mutationFn: submitOnboarding, + onSettled: () => { + // Invalidate settings query to refetch updated data + queryClient.invalidateQueries({ queryKey: ["settings"] }); + }, + ...options, + }); +}; \ No newline at end of file diff --git a/frontend/src/app/api/queries/useGetModelsQuery.ts b/frontend/src/app/api/queries/useGetModelsQuery.ts index e94a752b..17b2336a 100644 --- a/frontend/src/app/api/queries/useGetModelsQuery.ts +++ b/frontend/src/app/api/queries/useGetModelsQuery.ts @@ -15,23 +15,33 @@ export interface ModelsResponse { embedding_models: ModelOption[]; } +export interface OpenAIModelsParams { + apiKey?: string; +} + export interface OllamaModelsParams { endpoint?: string; } export interface IBMModelsParams { - api_key?: string; endpoint?: string; - project_id?: string; + apiKey?: string; + projectId?: string; } export const useGetOpenAIModelsQuery = ( + params?: OpenAIModelsParams, options?: Omit, "queryKey" | "queryFn">, ) => { const queryClient = useQueryClient(); async function getOpenAIModels(): Promise { - const response = await fetch("/api/models/openai"); + const url = new URL("/api/models/openai", window.location.origin); + if (params?.apiKey) { + url.searchParams.set("api_key", params.apiKey); + } + + const response = await fetch(url.toString()); if (response.ok) { return await response.json(); } else { @@ -41,9 +51,12 @@ export const useGetOpenAIModelsQuery = ( const queryResult = useQuery( { - queryKey: ["models", "openai"], + queryKey: ["models", "openai", params], queryFn: getOpenAIModels, - staleTime: 5 * 60 * 1000, // 5 minutes + retry: 2, + enabled: !!params?.apiKey, // Only run if API key is provided + staleTime: 0, // Always fetch fresh data + gcTime: 0, // Don't cache results ...options, }, queryClient, @@ -76,8 +89,10 @@ export const useGetOllamaModelsQuery = ( { queryKey: ["models", "ollama", params], queryFn: getOllamaModels, - staleTime: 5 * 60 * 1000, // 5 minutes + retry: 2, enabled: !!params?.endpoint, // Only run if endpoint is provided + staleTime: 0, // Always fetch fresh data + gcTime: 0, // Don't cache results ...options, }, queryClient, @@ -93,35 +108,22 @@ export const useGetIBMModelsQuery = ( const queryClient = useQueryClient(); async function getIBMModels(): Promise { - const url = "/api/models/ibm"; + const url = new URL("/api/models/ibm", window.location.origin); + if (params?.endpoint) { + url.searchParams.set("endpoint", params.endpoint); + } + if (params?.apiKey) { + url.searchParams.set("api_key", params.apiKey); + } + if (params?.projectId) { + url.searchParams.set("project_id", params.projectId); + } - // If we have credentials, use POST to send them securely - if (params?.api_key || params?.endpoint || params?.project_id) { - const response = await fetch(url, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - api_key: params.api_key, - endpoint: params.endpoint, - project_id: params.project_id, - }), - }); - - if (response.ok) { - return await response.json(); - } else { - throw new Error("Failed to fetch IBM models"); - } + const response = await fetch(url.toString()); + if (response.ok) { + return await response.json(); } else { - // Use GET for default models - const response = await fetch(url); - if (response.ok) { - return await response.json(); - } else { - throw new Error("Failed to fetch IBM models"); - } + throw new Error("Failed to fetch IBM models"); } } @@ -129,8 +131,10 @@ export const useGetIBMModelsQuery = ( { queryKey: ["models", "ibm", params], queryFn: getIBMModels, - staleTime: 5 * 60 * 1000, // 5 minutes - enabled: !!(params?.api_key && params?.endpoint && params?.project_id), // Only run if all credentials are provided + retry: 2, + enabled: !!params?.endpoint && !!params?.apiKey && !!params?.projectId, // Only run if all required params are provided + staleTime: 0, // Always fetch fresh data + gcTime: 0, // Don't cache results ...options, }, queryClient, diff --git a/frontend/src/app/api/queries/useGetSettingsQuery.ts b/frontend/src/app/api/queries/useGetSettingsQuery.ts index 734fe1df..cf1b4ec2 100644 --- a/frontend/src/app/api/queries/useGetSettingsQuery.ts +++ b/frontend/src/app/api/queries/useGetSettingsQuery.ts @@ -63,3 +63,6 @@ export const useGetSettingsQuery = ( return queryResult; }; + + + diff --git a/frontend/src/app/login/page.tsx b/frontend/src/app/login/page.tsx index 9527da11..c2347f1b 100644 --- a/frontend/src/app/login/page.tsx +++ b/frontend/src/app/login/page.tsx @@ -14,8 +14,8 @@ function LoginPageContent() { const router = useRouter(); const searchParams = useSearchParams(); - const { data: settings } = useGetSettingsQuery({ - enabled: isAuthenticated, + const { data: settings, isLoading: isSettingsLoading } = useGetSettingsQuery({ + enabled: isAuthenticated || isNoAuthMode, }); const redirect = @@ -25,12 +25,19 @@ function LoginPageContent() { // Redirect if already authenticated or in no-auth mode useEffect(() => { - if (!isLoading && (isAuthenticated || isNoAuthMode)) { + if (!isLoading && !isSettingsLoading && (isAuthenticated || isNoAuthMode)) { router.push(redirect); } - }, [isLoading, isAuthenticated, isNoAuthMode, router, redirect]); + }, [ + isLoading, + isSettingsLoading, + isAuthenticated, + isNoAuthMode, + router, + redirect, + ]); - if (isLoading) { + if (isLoading || isSettingsLoading) { return (
diff --git a/frontend/src/app/onboarding/advanced.tsx b/frontend/src/app/onboarding/components/advanced.tsx similarity index 92% rename from frontend/src/app/onboarding/advanced.tsx rename to frontend/src/app/onboarding/components/advanced.tsx index 32071ba1..bb0089d5 100644 --- a/frontend/src/app/onboarding/advanced.tsx +++ b/frontend/src/app/onboarding/components/advanced.tsx @@ -32,8 +32,13 @@ export function AdvancedOnboarding({ setSampleDataset: (dataset: boolean) => void; }) { const hasEmbeddingModels = - embeddingModels && embeddingModel && setEmbeddingModel; - const hasLanguageModels = languageModels && languageModel && setLanguageModel; + embeddingModels !== undefined && + embeddingModel !== undefined && + setEmbeddingModel !== undefined; + const hasLanguageModels = + languageModels !== undefined && + languageModel !== undefined && + setLanguageModel !== undefined; return ( diff --git a/frontend/src/app/onboarding/components/ibm-onboarding.tsx b/frontend/src/app/onboarding/components/ibm-onboarding.tsx new file mode 100644 index 00000000..ddf964c6 --- /dev/null +++ b/frontend/src/app/onboarding/components/ibm-onboarding.tsx @@ -0,0 +1,127 @@ +import { useState } from "react"; +import { LabelInput } from "@/components/label-input"; +import IBMLogo from "@/components/logo/ibm-logo"; +import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation"; +import { useGetIBMModelsQuery } from "../../api/queries/useGetModelsQuery"; +import { useModelSelection } from "../hooks/useModelSelection"; +import { useUpdateSettings } from "../hooks/useUpdateSettings"; +import { useDebouncedValue } from "@/lib/debounce"; +import { AdvancedOnboarding } from "./advanced"; + +export function IBMOnboarding({ + setSettings, + sampleDataset, + setSampleDataset, +}: { + setSettings: (settings: OnboardingVariables) => void; + sampleDataset: boolean; + setSampleDataset: (dataset: boolean) => void; +}) { + const [endpoint, setEndpoint] = useState(""); + const [apiKey, setApiKey] = useState(""); + const [projectId, setProjectId] = useState(""); + + const debouncedEndpoint = useDebouncedValue(endpoint, 500); + const debouncedApiKey = useDebouncedValue(apiKey, 500); + const debouncedProjectId = useDebouncedValue(projectId, 500); + + // Fetch models from API when all credentials are provided + const { + data: modelsData, + isLoading: isLoadingModels, + error: modelsError, + } = useGetIBMModelsQuery( + debouncedEndpoint && debouncedApiKey && debouncedProjectId + ? { + endpoint: debouncedEndpoint, + apiKey: debouncedApiKey, + projectId: debouncedProjectId, + } + : undefined, + ); + + // Use custom hook for model selection logic + const { + languageModel, + embeddingModel, + setLanguageModel, + setEmbeddingModel, + languageModels, + embeddingModels, + } = useModelSelection(modelsData); + const handleSampleDatasetChange = (dataset: boolean) => { + setSampleDataset(dataset); + }; + + // Update settings when values change + useUpdateSettings( + "ibm", + { + endpoint, + apiKey, + projectId, + languageModel, + embeddingModel, + }, + setSettings, + ); + return ( + <> +
+ setEndpoint(e.target.value)} + /> + setApiKey(e.target.value)} + /> + setProjectId(e.target.value)} + /> + {isLoadingModels && ( +

+ Validating configuration... +

+ )} + {modelsError && ( +

+ Invalid configuration or connection failed +

+ )} + {modelsData && + (modelsData.language_models?.length > 0 || + modelsData.embedding_models?.length > 0) && ( +

Configuration is valid

+ )} +
+ } + languageModels={languageModels} + embeddingModels={embeddingModels} + languageModel={languageModel} + embeddingModel={embeddingModel} + sampleDataset={sampleDataset} + setLanguageModel={setLanguageModel} + setEmbeddingModel={setEmbeddingModel} + setSampleDataset={handleSampleDatasetChange} + /> + + ); +} diff --git a/frontend/src/app/onboarding/model-selector.tsx b/frontend/src/app/onboarding/components/model-selector.tsx similarity index 90% rename from frontend/src/app/onboarding/model-selector.tsx rename to frontend/src/app/onboarding/components/model-selector.tsx index c6cdebab..7a74bed2 100644 --- a/frontend/src/app/onboarding/model-selector.tsx +++ b/frontend/src/app/onboarding/components/model-selector.tsx @@ -1,5 +1,5 @@ import { CheckIcon, ChevronsUpDownIcon } from "lucide-react"; -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Button } from "@/components/ui/button"; import { Command, @@ -32,6 +32,11 @@ export function ModelSelector({ onValueChange: (value: string) => void; }) { const [open, setOpen] = useState(false); + useEffect(() => { + if (value && !options.find((option) => option.value === value)) { + onValueChange(""); + } + }, [options, value, onValueChange]); return ( @@ -39,6 +44,7 @@ export function ModelSelector({
+ ) : options.length === 0 ? ( + "No models available" ) : ( "Select model..." )} diff --git a/frontend/src/app/onboarding/components/ollama-onboarding.tsx b/frontend/src/app/onboarding/components/ollama-onboarding.tsx new file mode 100644 index 00000000..668fcb81 --- /dev/null +++ b/frontend/src/app/onboarding/components/ollama-onboarding.tsx @@ -0,0 +1,135 @@ +import { useState } from "react"; +import { LabelInput } from "@/components/label-input"; +import { LabelWrapper } from "@/components/label-wrapper"; +import OllamaLogo from "@/components/logo/ollama-logo"; +import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation"; +import { useGetOllamaModelsQuery } from "../../api/queries/useGetModelsQuery"; +import { useModelSelection } from "../hooks/useModelSelection"; +import { useUpdateSettings } from "../hooks/useUpdateSettings"; +import { useDebouncedValue } from "@/lib/debounce"; +import { AdvancedOnboarding } from "./advanced"; +import { ModelSelector } from "./model-selector"; + +export function OllamaOnboarding({ + setSettings, + sampleDataset, + setSampleDataset, +}: { + setSettings: (settings: OnboardingVariables) => void; + sampleDataset: boolean; + setSampleDataset: (dataset: boolean) => void; +}) { + const [endpoint, setEndpoint] = useState(""); + const debouncedEndpoint = useDebouncedValue(endpoint, 500); + + // Fetch models from API when endpoint is provided (debounced) + const { + data: modelsData, + isLoading: isLoadingModels, + error: modelsError, + } = useGetOllamaModelsQuery( + debouncedEndpoint ? { endpoint: debouncedEndpoint } : undefined, + ); + + // Use custom hook for model selection logic + const { + languageModel, + embeddingModel, + setLanguageModel, + setEmbeddingModel, + languageModels, + embeddingModels, + } = useModelSelection(modelsData); + + const handleSampleDatasetChange = (dataset: boolean) => { + setSampleDataset(dataset); + }; + + // Update settings when values change + useUpdateSettings( + "ollama", + { + endpoint, + languageModel, + embeddingModel, + }, + setSettings, + ); + + // Check validation state based on models query + const isConnecting = debouncedEndpoint && isLoadingModels; + const hasConnectionError = debouncedEndpoint && modelsError; + const hasNoModels = + modelsData && + !modelsData.language_models?.length && + !modelsData.embedding_models?.length; + const isValidConnection = + modelsData && + (modelsData.language_models?.length > 0 || + modelsData.embedding_models?.length > 0); + + return ( + <> +
+ setEndpoint(e.target.value)} + /> + {isConnecting && ( +

+ Connecting to Ollama server... +

+ )} + {hasConnectionError && ( +

+ Cannot connect to Ollama server. Please check the endpoint. +

+ )} + {hasNoModels && ( +

+ No models found. Please install some models on your Ollama server. +

+ )} + {isValidConnection && ( +

Connected successfully

+ )} +
+ + } + value={embeddingModel} + onValueChange={setEmbeddingModel} + /> + + + } + value={languageModel} + onValueChange={setLanguageModel} + /> + + + + + ); +} diff --git a/frontend/src/app/onboarding/components/openai-onboarding.tsx b/frontend/src/app/onboarding/components/openai-onboarding.tsx new file mode 100644 index 00000000..a0c2d391 --- /dev/null +++ b/frontend/src/app/onboarding/components/openai-onboarding.tsx @@ -0,0 +1,93 @@ +import { useState } from "react"; +import { LabelInput } from "@/components/label-input"; +import OpenAILogo from "@/components/logo/openai-logo"; +import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation"; +import { useGetOpenAIModelsQuery } from "../../api/queries/useGetModelsQuery"; +import { useModelSelection } from "../hooks/useModelSelection"; +import { useUpdateSettings } from "../hooks/useUpdateSettings"; +import { useDebouncedValue } from "@/lib/debounce"; +import { AdvancedOnboarding } from "./advanced"; + +export function OpenAIOnboarding({ + setSettings, + sampleDataset, + setSampleDataset, +}: { + setSettings: (settings: OnboardingVariables) => void; + sampleDataset: boolean; + setSampleDataset: (dataset: boolean) => void; +}) { + const [apiKey, setApiKey] = useState(""); + const debouncedApiKey = useDebouncedValue(apiKey, 500); + + // Fetch models from API when API key is provided + const { + data: modelsData, + isLoading: isLoadingModels, + error: modelsError, + } = useGetOpenAIModelsQuery( + debouncedApiKey ? { apiKey: debouncedApiKey } : undefined, + ); + // Use custom hook for model selection logic + const { + languageModel, + embeddingModel, + setLanguageModel, + setEmbeddingModel, + languageModels, + embeddingModels, + } = useModelSelection(modelsData); + const handleSampleDatasetChange = (dataset: boolean) => { + setSampleDataset(dataset); + }; + + // Update settings when values change + useUpdateSettings( + "openai", + { + apiKey, + languageModel, + embeddingModel, + }, + setSettings, + ); + return ( + <> +
+ setApiKey(e.target.value)} + /> + {isLoadingModels && ( +

Validating API key...

+ )} + {modelsError && ( +

+ Invalid API key or configuration +

+ )} + {modelsData && + (modelsData.language_models?.length > 0 || + modelsData.embedding_models?.length > 0) && ( +

Configuration is valid

+ )} +
+ } + languageModels={languageModels} + embeddingModels={embeddingModels} + languageModel={languageModel} + embeddingModel={embeddingModel} + sampleDataset={sampleDataset} + setLanguageModel={setLanguageModel} + setSampleDataset={handleSampleDatasetChange} + setEmbeddingModel={setEmbeddingModel} + /> + + ); +} diff --git a/frontend/src/app/onboarding/hooks/useModelSelection.ts b/frontend/src/app/onboarding/hooks/useModelSelection.ts new file mode 100644 index 00000000..ad695753 --- /dev/null +++ b/frontend/src/app/onboarding/hooks/useModelSelection.ts @@ -0,0 +1,46 @@ +import { useState, useEffect } from "react"; +import type { ModelsResponse } from "../../api/queries/useGetModelsQuery"; + +export function useModelSelection(modelsData: ModelsResponse | undefined) { + const [languageModel, setLanguageModel] = useState(""); + const [embeddingModel, setEmbeddingModel] = useState(""); + + // Update default selections when models are loaded + useEffect(() => { + if (modelsData) { + const defaultLangModel = modelsData.language_models.find( + (m) => m.default, + ); + const defaultEmbedModel = modelsData.embedding_models.find( + (m) => m.default, + ); + + // Set language model: prefer default, fallback to first available + if (!languageModel) { + if (defaultLangModel) { + setLanguageModel(defaultLangModel.value); + } else if (modelsData.language_models.length > 0) { + setLanguageModel(modelsData.language_models[0].value); + } + } + + // Set embedding model: prefer default, fallback to first available + if (!embeddingModel) { + if (defaultEmbedModel) { + setEmbeddingModel(defaultEmbedModel.value); + } else if (modelsData.embedding_models.length > 0) { + setEmbeddingModel(modelsData.embedding_models[0].value); + } + } + } + }, [modelsData, languageModel, embeddingModel]); + + return { + languageModel, + embeddingModel, + setLanguageModel, + setEmbeddingModel, + languageModels: modelsData?.language_models || [], + embeddingModels: modelsData?.embedding_models || [], + }; +} diff --git a/frontend/src/app/onboarding/hooks/useUpdateSettings.ts b/frontend/src/app/onboarding/hooks/useUpdateSettings.ts new file mode 100644 index 00000000..46c07b0c --- /dev/null +++ b/frontend/src/app/onboarding/hooks/useUpdateSettings.ts @@ -0,0 +1,58 @@ +import { useEffect } from "react"; +import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation"; + +interface ConfigValues { + apiKey?: string; + endpoint?: string; + projectId?: string; + languageModel?: string; + embeddingModel?: string; +} + +export function useUpdateSettings( + provider: string, + config: ConfigValues, + setSettings: (settings: OnboardingVariables) => void, +) { + useEffect(() => { + const updatedSettings: OnboardingVariables = { + model_provider: provider, + embedding_model: "", + llm_model: "", + }; + + // Set language model if provided + if (config.languageModel) { + updatedSettings.llm_model = config.languageModel; + } + + // Set embedding model if provided + if (config.embeddingModel) { + updatedSettings.embedding_model = config.embeddingModel; + } + + // Set API key if provided + if (config.apiKey) { + updatedSettings.api_key = config.apiKey; + } + + // Set endpoint and project ID if provided + if (config.endpoint) { + updatedSettings.endpoint = config.endpoint; + } + + if (config.projectId) { + updatedSettings.project_id = config.projectId; + } + + setSettings(updatedSettings); + }, [ + provider, + config.apiKey, + config.endpoint, + config.projectId, + config.languageModel, + config.embeddingModel, + setSettings, + ]); +} diff --git a/frontend/src/app/onboarding/ibm-onboarding.tsx b/frontend/src/app/onboarding/ibm-onboarding.tsx deleted file mode 100644 index 26b6adeb..00000000 --- a/frontend/src/app/onboarding/ibm-onboarding.tsx +++ /dev/null @@ -1,110 +0,0 @@ -import { useState, useEffect } from "react"; -import { LabelInput } from "@/components/label-input"; -import IBMLogo from "@/components/logo/ibm-logo"; -import type { Settings } from "../api/queries/useGetSettingsQuery"; -import { useGetIBMModelsQuery } from "../api/queries/useGetModelsQuery"; -import { AdvancedOnboarding } from "./advanced"; - -export function IBMOnboarding({ - settings, - setSettings, - sampleDataset, - setSampleDataset, -}: { - settings: Settings; - setSettings: (settings: Settings) => void; - sampleDataset: boolean; - setSampleDataset: (dataset: boolean) => void; -}) { - const [endpoint, setEndpoint] = useState(""); - const [apiKey, setApiKey] = useState(""); - const [projectId, setProjectId] = useState(""); - const [languageModel, setLanguageModel] = useState("meta-llama/llama-3-1-70b-instruct"); - const [embeddingModel, setEmbeddingModel] = useState("ibm/slate-125m-english-rtrvr"); - - // Fetch models from API when all credentials are provided - const { data: modelsData } = useGetIBMModelsQuery( - (apiKey && endpoint && projectId) ? { api_key: apiKey, endpoint, project_id: projectId } : undefined, - { enabled: !!(apiKey && endpoint && projectId) } - ); - - // Use fetched models or fallback to defaults - const languageModels = modelsData?.language_models || [ - { value: "meta-llama/llama-3-1-70b-instruct", label: "Llama 3.1 70B Instruct", default: true }, - { value: "meta-llama/llama-3-1-8b-instruct", label: "Llama 3.1 8B Instruct" }, - { value: "ibm/granite-13b-chat-v2", label: "Granite 13B Chat v2" }, - { value: "ibm/granite-13b-instruct-v2", label: "Granite 13B Instruct v2" }, - ]; - const embeddingModels = modelsData?.embedding_models || [ - { value: "ibm/slate-125m-english-rtrvr", label: "Slate 125M English Retriever", default: true }, - { value: "sentence-transformers/all-minilm-l12-v2", label: "All-MiniLM L12 v2" }, - ]; - - // Update default selections when models are loaded - useEffect(() => { - if (modelsData) { - const defaultLangModel = modelsData.language_models.find(m => m.default); - const defaultEmbedModel = modelsData.embedding_models.find(m => m.default); - - if (defaultLangModel) { - setLanguageModel(defaultLangModel.value); - } - if (defaultEmbedModel) { - setEmbeddingModel(defaultEmbedModel.value); - } - } - }, [modelsData]); - const handleLanguageModelChange = (model: string) => { - setLanguageModel(model); - }; - - const handleEmbeddingModelChange = (model: string) => { - setEmbeddingModel(model); - }; - - const handleSampleDatasetChange = (dataset: boolean) => { - setSampleDataset(dataset); - }; - return ( - <> - setEndpoint(e.target.value)} - /> - setApiKey(e.target.value)} - /> - setProjectId(e.target.value)} - /> - } - languageModels={languageModels} - embeddingModels={embeddingModels} - languageModel={languageModel} - embeddingModel={embeddingModel} - sampleDataset={sampleDataset} - setLanguageModel={handleLanguageModelChange} - setEmbeddingModel={handleEmbeddingModelChange} - setSampleDataset={handleSampleDatasetChange} - /> - - ); -} diff --git a/frontend/src/app/onboarding/ollama-onboarding.tsx b/frontend/src/app/onboarding/ollama-onboarding.tsx deleted file mode 100644 index 2513a8f5..00000000 --- a/frontend/src/app/onboarding/ollama-onboarding.tsx +++ /dev/null @@ -1,105 +0,0 @@ -import { useState, useEffect } from "react"; -import { LabelInput } from "@/components/label-input"; -import { LabelWrapper } from "@/components/label-wrapper"; -import OllamaLogo from "@/components/logo/ollama-logo"; -import type { Settings } from "../api/queries/useGetSettingsQuery"; -import { useGetOllamaModelsQuery } from "../api/queries/useGetModelsQuery"; -import { AdvancedOnboarding } from "./advanced"; -import { ModelSelector } from "./model-selector"; - -export function OllamaOnboarding({ - settings, - setSettings, - sampleDataset, - setSampleDataset, -}: { - settings: Settings; - setSettings: (settings: Settings) => void; - sampleDataset: boolean; - setSampleDataset: (dataset: boolean) => void; -}) { - const [endpoint, setEndpoint] = useState(""); - const [languageModel, setLanguageModel] = useState("llama3.2"); - const [embeddingModel, setEmbeddingModel] = useState("nomic-embed-text"); - - // Fetch models from API when endpoint is provided - const { data: modelsData } = useGetOllamaModelsQuery( - endpoint ? { endpoint } : undefined, - { enabled: !!endpoint } - ); - - // Use fetched models or fallback to defaults - const languageModels = modelsData?.language_models || [ - { value: "llama3.2", label: "llama3.2", default: true }, - { value: "llama3.1", label: "llama3.1" }, - { value: "llama3", label: "llama3" }, - { value: "mistral", label: "mistral" }, - { value: "codellama", label: "codellama" }, - ]; - const embeddingModels = modelsData?.embedding_models || [ - { value: "nomic-embed-text", label: "nomic-embed-text", default: true }, - ]; - - // Update default selections when models are loaded - useEffect(() => { - if (modelsData) { - const defaultLangModel = modelsData.language_models.find(m => m.default); - const defaultEmbedModel = modelsData.embedding_models.find(m => m.default); - - if (defaultLangModel) { - setLanguageModel(defaultLangModel.value); - } - if (defaultEmbedModel) { - setEmbeddingModel(defaultEmbedModel.value); - } - } - }, [modelsData]); - - const handleSampleDatasetChange = (dataset: boolean) => { - setSampleDataset(dataset); - }; - return ( - <> - setEndpoint(e.target.value)} - /> - - } - value={embeddingModel} - onValueChange={setEmbeddingModel} - /> - - - } - value={languageModel} - onValueChange={setLanguageModel} - /> - - - - - ); -} diff --git a/frontend/src/app/onboarding/openai-onboarding.tsx b/frontend/src/app/onboarding/openai-onboarding.tsx deleted file mode 100644 index 1e3be530..00000000 --- a/frontend/src/app/onboarding/openai-onboarding.tsx +++ /dev/null @@ -1,80 +0,0 @@ -import { useState, useEffect } from "react"; -import { LabelInput } from "@/components/label-input"; -import OpenAILogo from "@/components/logo/openai-logo"; -import type { Settings } from "../api/queries/useGetSettingsQuery"; -import { useGetOpenAIModelsQuery } from "../api/queries/useGetModelsQuery"; -import { AdvancedOnboarding } from "./advanced"; - -export function OpenAIOnboarding({ - settings, - setSettings, - sampleDataset, - setSampleDataset, -}: { - settings: Settings; - setSettings: (settings: Settings) => void; - sampleDataset: boolean; - setSampleDataset: (dataset: boolean) => void; -}) { - const [languageModel, setLanguageModel] = useState("gpt-4o-mini"); - const [embeddingModel, setEmbeddingModel] = useState( - "text-embedding-3-small", - ); - - // Fetch models from API - const { data: modelsData } = useGetOpenAIModelsQuery(); - - // Use fetched models or fallback to defaults - const languageModels = modelsData?.language_models || [{ value: "gpt-4o-mini", label: "gpt-4o-mini", default: true }]; - const embeddingModels = modelsData?.embedding_models || [ - { value: "text-embedding-3-small", label: "text-embedding-3-small", default: true }, - ]; - - // Update default selections when models are loaded - useEffect(() => { - if (modelsData) { - const defaultLangModel = modelsData.language_models.find(m => m.default); - const defaultEmbedModel = modelsData.embedding_models.find(m => m.default); - - if (defaultLangModel && languageModel === "gpt-4o-mini") { - setLanguageModel(defaultLangModel.value); - } - if (defaultEmbedModel && embeddingModel === "text-embedding-3-small") { - setEmbeddingModel(defaultEmbedModel.value); - } - } - }, [modelsData, languageModel, embeddingModel]); - const handleLanguageModelChange = (model: string) => { - setLanguageModel(model); - }; - - const handleEmbeddingModelChange = (model: string) => { - setEmbeddingModel(model); - }; - - const handleSampleDatasetChange = (dataset: boolean) => { - setSampleDataset(dataset); - }; - return ( - <> - - } - languageModels={languageModels} - embeddingModels={embeddingModels} - languageModel={languageModel} - embeddingModel={embeddingModel} - sampleDataset={sampleDataset} - setLanguageModel={handleLanguageModelChange} - setSampleDataset={handleSampleDatasetChange} - setEmbeddingModel={handleEmbeddingModelChange} - /> - - ); -} diff --git a/frontend/src/app/onboarding/page.tsx b/frontend/src/app/onboarding/page.tsx index 5e7ab924..bed6a389 100644 --- a/frontend/src/app/onboarding/page.tsx +++ b/frontend/src/app/onboarding/page.tsx @@ -1,12 +1,11 @@ "use client"; -import { Suspense, useState } from "react"; +import { Suspense, useEffect, useState } from "react"; import { toast } from "sonner"; -import { useUpdateFlowSettingMutation } from "@/app/api/mutations/useUpdateFlowSettingMutation"; import { - type Settings, - useGetSettingsQuery, -} from "@/app/api/queries/useGetSettingsQuery"; + useOnboardingMutation, + type OnboardingVariables, +} from "@/app/api/mutations/useOnboardingMutation"; import IBMLogo from "@/components/logo/ibm-logo"; import OllamaLogo from "@/components/logo/ollama-logo"; import OpenAILogo from "@/components/logo/openai-logo"; @@ -19,44 +18,102 @@ import { CardHeader, } from "@/components/ui/card"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; -import { useAuth } from "@/contexts/auth-context"; -import { IBMOnboarding } from "./ibm-onboarding"; -import { OllamaOnboarding } from "./ollama-onboarding"; -import { OpenAIOnboarding } from "./openai-onboarding"; +import { IBMOnboarding } from "./components/ibm-onboarding"; +import { OllamaOnboarding } from "./components/ollama-onboarding"; +import { OpenAIOnboarding } from "./components/openai-onboarding"; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip"; +import { useGetSettingsQuery } from "../api/queries/useGetSettingsQuery"; +import { useRouter } from "next/navigation"; function OnboardingPage() { - const { isAuthenticated } = useAuth(); + const { data: settingsDb, isLoading: isSettingsLoading } = + useGetSettingsQuery(); + + const redirect = "/"; + + const router = useRouter(); + + // Redirect if already authenticated or in no-auth mode + useEffect(() => { + if (!isSettingsLoading && settingsDb && settingsDb.edited) { + router.push(redirect); + } + }, [isSettingsLoading, redirect]); const [modelProvider, setModelProvider] = useState("openai"); - const [sampleDataset, setSampleDataset] = useState(false); - // Fetch settings using React Query - const { data: settingsDb = {} } = useGetSettingsQuery({ - enabled: isAuthenticated, + const [sampleDataset, setSampleDataset] = useState(true); + + const handleSetModelProvider = (provider: string) => { + setModelProvider(provider); + setSettings({ + model_provider: provider, + embedding_model: "", + llm_model: "", + }); + }; + + const [settings, setSettings] = useState({ + model_provider: modelProvider, + embedding_model: "", + llm_model: "", }); - const [settings, setSettings] = useState(settingsDb); - // Mutations - const updateFlowSettingMutation = useUpdateFlowSettingMutation({ - onSuccess: () => { - console.log("Setting updated successfully"); + const onboardingMutation = useOnboardingMutation({ + onSuccess: (data) => { + toast.success("Onboarding completed successfully!"); + console.log("Onboarding completed successfully", data); }, onError: (error) => { - toast.error("Failed to update settings", { + toast.error("Failed to complete onboarding", { description: error.message, }); }, }); const handleComplete = () => { - updateFlowSettingMutation.mutate({ - llm_model: settings.agent?.llm_model, - embedding_model: settings.knowledge?.embedding_model, - system_prompt: settings.agent?.system_prompt, - }); + if ( + !settings.model_provider || + !settings.llm_model || + !settings.embedding_model + ) { + toast.error("Please complete all required fields"); + return; + } + + // Prepare onboarding data + const onboardingData: OnboardingVariables = { + model_provider: settings.model_provider, + llm_model: settings.llm_model, + embedding_model: settings.embedding_model, + sample_data: sampleDataset, + }; + + // Add API key if available + if (settings.api_key) { + onboardingData.api_key = settings.api_key; + } + + // Add endpoint if available + if (settings.endpoint) { + onboardingData.endpoint = settings.endpoint; + } + + // Add project_id if available + if (settings.project_id) { + onboardingData.project_id = settings.project_id; + } + + onboardingMutation.mutate(onboardingData); }; + const isComplete = !!settings.llm_model && !!settings.embedding_model; + return (
[description of task]

- + @@ -94,7 +154,6 @@ function OnboardingPage() { - + + + + + + {!isComplete ? "Please fill in all required fields" : ""} + +
diff --git a/frontend/src/components/layout-wrapper.tsx b/frontend/src/components/layout-wrapper.tsx index ba8d7797..7a792456 100644 --- a/frontend/src/components/layout-wrapper.tsx +++ b/frontend/src/components/layout-wrapper.tsx @@ -15,15 +15,16 @@ import { useKnowledgeFilter } from "@/contexts/knowledge-filter-context"; // import { GitHubStarButton } from "@/components/github-star-button" // import { DiscordLink } from "@/components/discord-link" import { useTask } from "@/contexts/task-context"; +import Logo from "@/components/logo/logo"; export function LayoutWrapper({ children }: { children: React.ReactNode }) { const pathname = usePathname(); const { tasks, isMenuOpen, toggleMenu } = useTask(); const { selectedFilter, setSelectedFilter, isPanelOpen } = useKnowledgeFilter(); - const { isLoading, isAuthenticated } = useAuth(); + const { isLoading, isAuthenticated, isNoAuthMode } = useAuth(); const { isLoading: isSettingsLoading, data: settings } = useGetSettingsQuery({ - enabled: isAuthenticated, + enabled: isAuthenticated || isNoAuthMode, }); // List of paths that should not show navigation @@ -62,7 +63,7 @@ export function LayoutWrapper({ children }: { children: React.ReactNode }) {
{/* Logo/Title */}
- OpenRAG Logo + OpenRAG
diff --git a/frontend/src/components/protected-route.tsx b/frontend/src/components/protected-route.tsx index bfec26a3..ba195874 100644 --- a/frontend/src/components/protected-route.tsx +++ b/frontend/src/components/protected-route.tsx @@ -14,7 +14,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) { const { isLoading, isAuthenticated, isNoAuthMode } = useAuth(); const { data: settings = {}, isLoading: isSettingsLoading } = useGetSettingsQuery({ - enabled: isAuthenticated, + enabled: isAuthenticated || isNoAuthMode, }); const router = useRouter(); const pathname = usePathname(); @@ -31,12 +31,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) { ); useEffect(() => { - // In no-auth mode, allow access without authentication - if (isNoAuthMode) { - return; - } - - if (!isLoading && !isAuthenticated) { + if (!isLoading && !isSettingsLoading && !isAuthenticated && !isNoAuthMode) { // Redirect to login with current path as redirect parameter const redirectUrl = `/login?redirect=${encodeURIComponent(pathname)}`; router.push(redirectUrl); @@ -48,6 +43,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) { } }, [ isLoading, + isSettingsLoading, isAuthenticated, isNoAuthMode, router, @@ -57,7 +53,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) { ]); // Show loading state while checking authentication - if (isLoading) { + if (isLoading || isSettingsLoading) { return (
diff --git a/src/api/models.py b/src/api/models.py index 0dc78c2b..d6caa161 100644 --- a/src/api/models.py +++ b/src/api/models.py @@ -7,7 +7,17 @@ logger = get_logger(__name__) async def get_openai_models(request, models_service, session_manager): """Get available OpenAI models""" try: - models = await models_service.get_openai_models() + # Get API key from query parameters + query_params = dict(request.query_params) + api_key = query_params.get("api_key") + + if not api_key: + return JSONResponse( + {"error": "OpenAI API key is required as query parameter"}, + status_code=400 + ) + + models = await models_service.get_openai_models(api_key=api_key) return JSONResponse(models) except Exception as e: logger.error(f"Failed to get OpenAI models: {str(e)}") @@ -37,21 +47,15 @@ async def get_ollama_models(request, models_service, session_manager): async def get_ibm_models(request, models_service, session_manager): """Get available IBM Watson models""" try: - # Get credentials from query parameters or request body if provided - if request.method == "POST": - body = await request.json() - api_key = body.get("api_key") - endpoint = body.get("endpoint") - project_id = body.get("project_id") - else: - query_params = dict(request.query_params) - api_key = query_params.get("api_key") - endpoint = query_params.get("endpoint") - project_id = query_params.get("project_id") + # Get parameters from query parameters if provided + query_params = dict(request.query_params) + endpoint = query_params.get("endpoint") + api_key = query_params.get("api_key") + project_id = query_params.get("project_id") models = await models_service.get_ibm_models( - api_key=api_key, endpoint=endpoint, + api_key=api_key, project_id=project_id ) return JSONResponse(models) diff --git a/src/api/settings.py b/src/api/settings.py index 0763aae6..3a577d92 100644 --- a/src/api/settings.py +++ b/src/api/settings.py @@ -69,7 +69,6 @@ def get_docling_tweaks(docling_preset: str = None) -> dict: } } - async def get_settings(request, session_manager): """Get application settings""" try: @@ -117,8 +116,7 @@ async def get_settings(request, session_manager): if LANGFLOW_INGEST_FLOW_ID and openrag_config.edited: try: response = await clients.langflow_request( - "GET", - f"/api/v1/flows/{LANGFLOW_INGEST_FLOW_ID}" + "GET", f"/api/v1/flows/{LANGFLOW_INGEST_FLOW_ID}" ) if response.status_code == 200: flow_data = response.json() @@ -135,31 +133,23 @@ async def get_settings(request, session_manager): if flow_data.get("data", {}).get("nodes"): for node in flow_data["data"]["nodes"]: node_template = ( - node.get("data", {}) - .get("node", {}) - .get("template", {}) + node.get("data", {}).get("node", {}).get("template", {}) ) # Split Text component (SplitText-QIKhg) if node.get("id") == "SplitText-QIKhg": - if node_template.get("chunk_size", {}).get( - "value" - ): - ingestion_defaults["chunkSize"] = ( - node_template["chunk_size"]["value"] - ) - if node_template.get("chunk_overlap", {}).get( - "value" - ): - ingestion_defaults["chunkOverlap"] = ( - node_template["chunk_overlap"]["value"] - ) - if node_template.get("separator", {}).get( - "value" - ): - ingestion_defaults["separator"] = ( - node_template["separator"]["value"] - ) + if node_template.get("chunk_size", {}).get("value"): + ingestion_defaults["chunkSize"] = node_template[ + "chunk_size" + ]["value"] + if node_template.get("chunk_overlap", {}).get("value"): + ingestion_defaults["chunkOverlap"] = node_template[ + "chunk_overlap" + ]["value"] + if node_template.get("separator", {}).get("value"): + ingestion_defaults["separator"] = node_template[ + "separator" + ]["value"] # OpenAI Embeddings component (OpenAIEmbeddings-joRJ6) elif node.get("id") == "OpenAIEmbeddings-joRJ6": @@ -191,89 +181,116 @@ async def update_settings(request, session_manager): try: # Get current configuration current_config = get_openrag_config() - + # Check if config is marked as edited if not current_config.edited: return JSONResponse( - {"error": "Configuration must be marked as edited before updates are allowed"}, - status_code=403 + { + "error": "Configuration must be marked as edited before updates are allowed" + }, + status_code=403, ) - + # Parse request body body = await request.json() - + # Validate allowed fields allowed_fields = { - "llm_model", "system_prompt", "doclingPresets", - "chunk_size", "chunk_overlap" + "llm_model", + "system_prompt", + "ocr", + "picture_descriptions", + "chunk_size", + "chunk_overlap", + "doclingPresets", } - + # Check for invalid fields invalid_fields = set(body.keys()) - allowed_fields if invalid_fields: return JSONResponse( - {"error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}"}, - status_code=400 + { + "error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}" + }, + status_code=400, ) - + # Update configuration config_updated = False - + # Update agent settings if "llm_model" in body: current_config.agent.llm_model = body["llm_model"] config_updated = True - + if "system_prompt" in body: current_config.agent.system_prompt = body["system_prompt"] config_updated = True - + # Update knowledge settings if "doclingPresets" in body: preset_configs = get_docling_preset_configs() valid_presets = list(preset_configs.keys()) if body["doclingPresets"] not in valid_presets: return JSONResponse( - {"error": f"doclingPresets must be one of: {', '.join(valid_presets)}"}, - status_code=400 + { + "error": f"doclingPresets must be one of: {', '.join(valid_presets)}" + }, + status_code=400, ) current_config.knowledge.doclingPresets = body["doclingPresets"] config_updated = True - + + if "ocr" in body: + if not isinstance(body["ocr"], bool): + return JSONResponse( + {"error": "ocr must be a boolean value"}, status_code=400 + ) + current_config.knowledge.ocr = body["ocr"] + config_updated = True + + if "picture_descriptions" in body: + if not isinstance(body["picture_descriptions"], bool): + return JSONResponse( + {"error": "picture_descriptions must be a boolean value"}, + status_code=400, + ) + current_config.knowledge.picture_descriptions = body["picture_descriptions"] + config_updated = True + if "chunk_size" in body: if not isinstance(body["chunk_size"], int) or body["chunk_size"] <= 0: return JSONResponse( - {"error": "chunk_size must be a positive integer"}, - status_code=400 + {"error": "chunk_size must be a positive integer"}, status_code=400 ) current_config.knowledge.chunk_size = body["chunk_size"] config_updated = True - + if "chunk_overlap" in body: if not isinstance(body["chunk_overlap"], int) or body["chunk_overlap"] < 0: return JSONResponse( - {"error": "chunk_overlap must be a non-negative integer"}, - status_code=400 + {"error": "chunk_overlap must be a non-negative integer"}, + status_code=400, ) current_config.knowledge.chunk_overlap = body["chunk_overlap"] config_updated = True - + if not config_updated: return JSONResponse( - {"error": "No valid fields provided for update"}, - status_code=400 + {"error": "No valid fields provided for update"}, status_code=400 ) - + # Save the updated configuration if config_manager.save_config_file(current_config): - logger.info("Configuration updated successfully", updated_fields=list(body.keys())) + logger.info( + "Configuration updated successfully", updated_fields=list(body.keys()) + ) return JSONResponse({"message": "Configuration updated successfully"}) else: return JSONResponse( - {"error": "Failed to save configuration"}, - status_code=500 + {"error": "Failed to save configuration"}, status_code=500 ) - + except Exception as e: logger.error("Failed to update settings", error=str(e)) return JSONResponse( @@ -286,120 +303,168 @@ async def onboarding(request, flows_service): try: # Get current configuration current_config = get_openrag_config() - + # Check if config is NOT marked as edited (only allow onboarding if not yet configured) if current_config.edited: return JSONResponse( - {"error": "Configuration has already been edited. Use /settings endpoint for updates."}, - status_code=403 + { + "error": "Configuration has already been edited. Use /settings endpoint for updates." + }, + status_code=403, ) - + # Parse request body body = await request.json() - + # Validate allowed fields allowed_fields = { - "model_provider", "api_key", "embedding_model", "llm_model", "sample_data" + "model_provider", + "api_key", + "embedding_model", + "llm_model", + "sample_data", + "endpoint", + "project_id", } - + # Check for invalid fields invalid_fields = set(body.keys()) - allowed_fields if invalid_fields: return JSONResponse( - {"error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}"}, - status_code=400 + { + "error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}" + }, + status_code=400, ) - + # Update configuration config_updated = False - + # Update provider settings if "model_provider" in body: - if not isinstance(body["model_provider"], str) or not body["model_provider"].strip(): + if ( + not isinstance(body["model_provider"], str) + or not body["model_provider"].strip() + ): return JSONResponse( - {"error": "model_provider must be a non-empty string"}, - status_code=400 + {"error": "model_provider must be a non-empty string"}, + status_code=400, ) current_config.provider.model_provider = body["model_provider"].strip() config_updated = True - + if "api_key" in body: if not isinstance(body["api_key"], str): return JSONResponse( - {"error": "api_key must be a string"}, - status_code=400 + {"error": "api_key must be a string"}, status_code=400 ) current_config.provider.api_key = body["api_key"] config_updated = True - + # Update knowledge settings if "embedding_model" in body: - if not isinstance(body["embedding_model"], str) or not body["embedding_model"].strip(): + if ( + not isinstance(body["embedding_model"], str) + or not body["embedding_model"].strip() + ): return JSONResponse( - {"error": "embedding_model must be a non-empty string"}, - status_code=400 + {"error": "embedding_model must be a non-empty string"}, + status_code=400, ) current_config.knowledge.embedding_model = body["embedding_model"].strip() config_updated = True - + # Update agent settings if "llm_model" in body: if not isinstance(body["llm_model"], str) or not body["llm_model"].strip(): return JSONResponse( - {"error": "llm_model must be a non-empty string"}, - status_code=400 + {"error": "llm_model must be a non-empty string"}, status_code=400 ) current_config.agent.llm_model = body["llm_model"].strip() config_updated = True - + + if "endpoint" in body: + if not isinstance(body["endpoint"], str) or not body["endpoint"].strip(): + return JSONResponse( + {"error": "endpoint must be a non-empty string"}, status_code=400 + ) + current_config.provider.endpoint = body["endpoint"].strip() + config_updated = True + + if "project_id" in body: + if ( + not isinstance(body["project_id"], str) + or not body["project_id"].strip() + ): + return JSONResponse( + {"error": "project_id must be a non-empty string"}, status_code=400 + ) + current_config.provider.project_id = body["project_id"].strip() + config_updated = True + # Handle sample_data (unused for now but validate) if "sample_data" in body: if not isinstance(body["sample_data"], bool): return JSONResponse( - {"error": "sample_data must be a boolean value"}, - status_code=400 + {"error": "sample_data must be a boolean value"}, status_code=400 ) # Note: sample_data is accepted but not used as requested - + if not config_updated: return JSONResponse( - {"error": "No valid fields provided for update"}, - status_code=400 + {"error": "No valid fields provided for update"}, status_code=400 ) - + # Save the updated configuration (this will mark it as edited) if config_manager.save_config_file(current_config): - updated_fields = [k for k in body.keys() if k != "sample_data"] # Exclude sample_data from log - logger.info("Onboarding configuration updated successfully", updated_fields=updated_fields) - + updated_fields = [ + k for k in body.keys() if k != "sample_data" + ] # Exclude sample_data from log + logger.info( + "Onboarding configuration updated successfully", + updated_fields=updated_fields, + ) + # If model_provider was updated, assign the new provider to flows if "model_provider" in body: provider = body["model_provider"].strip().lower() try: flow_result = await flows_service.assign_model_provider(provider) - + if flow_result.get("success"): - logger.info(f"Successfully assigned {provider} to flows", flow_result=flow_result) + logger.info( + f"Successfully assigned {provider} to flows", + flow_result=flow_result, + ) else: - logger.warning(f"Failed to assign {provider} to flows", flow_result=flow_result) + logger.warning( + f"Failed to assign {provider} to flows", + flow_result=flow_result, + ) # Continue even if flow assignment fails - configuration was still saved - + except Exception as e: - logger.error(f"Error assigning model provider to flows", provider=provider, error=str(e)) + logger.error( + f"Error assigning model provider to flows", + provider=provider, + error=str(e), + ) # Continue even if flow assignment fails - configuration was still saved - - return JSONResponse({ - "message": "Onboarding configuration updated successfully", - "edited": True # Confirm that config is now marked as edited - }) + + return JSONResponse( + { + "message": "Onboarding configuration updated successfully", + "edited": True, # Confirm that config is now marked as edited + } + ) else: return JSONResponse( - {"error": "Failed to save configuration"}, - status_code=500 + {"error": "Failed to save configuration"}, status_code=500 ) - + except Exception as e: logger.error("Failed to update onboarding settings", error=str(e)) return JSONResponse( - {"error": f"Failed to update onboarding settings: {str(e)}"}, status_code=500 + {"error": f"Failed to update onboarding settings: {str(e)}"}, + status_code=500, ) diff --git a/src/services/models_service.py b/src/services/models_service.py index 9855459e..7620461b 100644 --- a/src/services/models_service.py +++ b/src/services/models_service.py @@ -1,5 +1,4 @@ import httpx -import os from typing import Dict, List from utils.logging_config import get_logger @@ -12,14 +11,9 @@ class ModelsService: def __init__(self): self.session_manager = None - async def get_openai_models(self) -> Dict[str, List[Dict[str, str]]]: + async def get_openai_models(self, api_key: str) -> Dict[str, List[Dict[str, str]]]: """Fetch available models from OpenAI API""" try: - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - logger.warning("OPENAI_API_KEY not set, using default models") - return self._get_default_openai_models() - headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", @@ -74,12 +68,14 @@ class ModelsService: "embedding_models": embedding_models, } else: - logger.warning(f"Failed to fetch OpenAI models: {response.status_code}") - return self._get_default_openai_models() + logger.error(f"Failed to fetch OpenAI models: {response.status_code}") + raise Exception( + f"OpenAI API returned status code {response.status_code}" + ) except Exception as e: logger.error(f"Error fetching OpenAI models: {str(e)}") - return self._get_default_openai_models() + raise async def get_ollama_models( self, endpoint: str = None @@ -87,9 +83,7 @@ class ModelsService: """Fetch available models from Ollama API""" try: # Use provided endpoint or default - ollama_url = endpoint or os.getenv( - "OLLAMA_BASE_URL", "http://localhost:11434" - ) + ollama_url = endpoint async with httpx.AsyncClient() as client: response = await client.get(f"{ollama_url}/api/tags", timeout=10.0) @@ -145,32 +139,34 @@ class ModelsService: return { "language_models": language_models, - "embedding_models": embedding_models - if embedding_models - else [ - { - "value": "nomic-embed-text", - "label": "nomic-embed-text", - "default": True, - } - ], + "embedding_models": embedding_models if embedding_models else [], } else: - logger.warning(f"Failed to fetch Ollama models: {response.status_code}") - return self._get_default_ollama_models() + logger.error(f"Failed to fetch Ollama models: {response.status_code}") + raise Exception( + f"Ollama API returned status code {response.status_code}" + ) except Exception as e: logger.error(f"Error fetching Ollama models: {str(e)}") - return self._get_default_ollama_models() + raise async def get_ibm_models( - self, endpoint: str = None + self, endpoint: str = None, api_key: str = None, project_id: str = None ) -> Dict[str, List[Dict[str, str]]]: """Fetch available models from IBM Watson API""" try: # Use provided endpoint or default - watson_endpoint = endpoint or os.getenv("IBM_WATSON_ENDPOINT", "https://us-south.ml.cloud.ibm.com") + watson_endpoint = endpoint + # Prepare headers for authentication + headers = { + "Content-Type": "application/json", + } + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if project_id: + headers["Project-ID"] = project_id # Fetch foundation models using the correct endpoint models_url = f"{watson_endpoint}/ml/v1/foundation_model_specs" @@ -181,9 +177,14 @@ class ModelsService: # Fetch text chat models text_params = { "version": "2024-09-16", - "filters": "function_text_chat,!lifecycle_withdrawn" + "filters": "function_text_chat,!lifecycle_withdrawn", } - text_response = await client.get(models_url, params=text_params, timeout=10.0) + if project_id: + text_params["project_id"] = project_id + + text_response = await client.get( + models_url, params=text_params, headers=headers, timeout=10.0 + ) if text_response.status_code == 200: text_data = text_response.json() @@ -193,18 +194,25 @@ class ModelsService: model_id = model.get("model_id", "") model_name = model.get("name", model_id) - language_models.append({ - "value": model_id, - "label": model_name or model_id, - "default": i == 0 # First model is default - }) + language_models.append( + { + "value": model_id, + "label": model_name or model_id, + "default": i == 0, # First model is default + } + ) # Fetch embedding models embed_params = { "version": "2024-09-16", - "filters": "function_embedding,!lifecycle_withdrawn" + "filters": "function_embedding,!lifecycle_withdrawn", } - embed_response = await client.get(models_url, params=embed_params, timeout=10.0) + if project_id: + embed_params["project_id"] = project_id + + embed_response = await client.get( + models_url, params=embed_params, headers=headers, timeout=10.0 + ) if embed_response.status_code == 200: embed_data = embed_response.json() @@ -214,104 +222,22 @@ class ModelsService: model_id = model.get("model_id", "") model_name = model.get("name", model_id) - embedding_models.append({ - "value": model_id, - "label": model_name or model_id, - "default": i == 0 # First model is default - }) + embedding_models.append( + { + "value": model_id, + "label": model_name or model_id, + "default": i == 0, # First model is default + } + ) + + if not language_models and not embedding_models: + raise Exception("No IBM models retrieved from API") return { - "language_models": language_models if language_models else self._get_default_ibm_models()["language_models"], - "embedding_models": embedding_models if embedding_models else self._get_default_ibm_models()["embedding_models"] + "language_models": language_models, + "embedding_models": embedding_models, } except Exception as e: logger.error(f"Error fetching IBM models: {str(e)}") - return self._get_default_ibm_models() - - def _get_default_openai_models(self) -> Dict[str, List[Dict[str, str]]]: - """Default OpenAI models when API is not available""" - return { - "language_models": [ - {"value": "gpt-4o-mini", "label": "gpt-4o-mini", "default": True}, - {"value": "gpt-4o", "label": "gpt-4o", "default": False}, - {"value": "gpt-4-turbo", "label": "gpt-4-turbo", "default": False}, - {"value": "gpt-3.5-turbo", "label": "gpt-3.5-turbo", "default": False}, - ], - "embedding_models": [ - { - "value": "text-embedding-3-small", - "label": "text-embedding-3-small", - "default": True, - }, - { - "value": "text-embedding-3-large", - "label": "text-embedding-3-large", - "default": False, - }, - { - "value": "text-embedding-ada-002", - "label": "text-embedding-ada-002", - "default": False, - }, - ], - } - - def _get_default_ollama_models(self) -> Dict[str, List[Dict[str, str]]]: - """Default Ollama models when API is not available""" - return { - "language_models": [ - {"value": "llama3.2", "label": "llama3.2", "default": True}, - {"value": "llama3.1", "label": "llama3.1", "default": False}, - {"value": "llama3", "label": "llama3", "default": False}, - {"value": "mistral", "label": "mistral", "default": False}, - {"value": "codellama", "label": "codellama", "default": False}, - ], - "embedding_models": [ - { - "value": "nomic-embed-text", - "label": "nomic-embed-text", - "default": True, - }, - {"value": "all-minilm", "label": "all-minilm", "default": False}, - ], - } - - def _get_default_ibm_models(self) -> Dict[str, List[Dict[str, str]]]: - """Default IBM Watson models when API is not available""" - return { - "language_models": [ - { - "value": "meta-llama/llama-3-1-70b-instruct", - "label": "Llama 3.1 70B Instruct", - "default": True, - }, - { - "value": "meta-llama/llama-3-1-8b-instruct", - "label": "Llama 3.1 8B Instruct", - "default": False, - }, - { - "value": "ibm/granite-13b-chat-v2", - "label": "Granite 13B Chat v2", - "default": False, - }, - { - "value": "ibm/granite-13b-instruct-v2", - "label": "Granite 13B Instruct v2", - "default": False, - }, - ], - "embedding_models": [ - { - "value": "ibm/slate-125m-english-rtrvr", - "label": "Slate 125M English Retriever", - "default": True, - }, - { - "value": "sentence-transformers/all-minilm-l12-v2", - "label": "All-MiniLM L12 v2", - "default": False, - }, - ], - } + raise