Merge branch 'feat/onboarding' of github.com:langflow-ai/openrag into feat/onboarding

This commit is contained in:
Mike Fortman 2025-09-20 11:05:10 -05:00
commit abb1ae0819
21 changed files with 957 additions and 628 deletions

View file

@@ -1,4 +1,4 @@
import { useCallback, useRef } from "react";
import { useCallback, useRef, useState, useEffect } from "react";
export function useDebounce<T extends (...args: never[]) => void>(
callback: T,
@@ -21,3 +21,19 @@ export function useDebounce<T extends (...args: never[]) => void>(
return debouncedCallback;
}
export function useDebouncedValue<T>(value: T, delay: number): T {
const [debouncedValue, setDebouncedValue] = useState<T>(value);
useEffect(() => {
const handler = setTimeout(() => {
setDebouncedValue(value);
}, delay);
return () => {
clearTimeout(handler);
};
}, [value, delay]);
return debouncedValue;
}
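A minimal usage sketch for the new useDebouncedValue hook (the component and field names here are hypothetical):
import { useState } from "react";
import { useDebouncedValue } from "@/lib/debounce";
function EndpointInput() {
  const [endpoint, setEndpoint] = useState("");
  // 500 ms after typing stops, debouncedEndpoint settles on the latest value;
  // the onboarding components below feed it into the model queries.
  const debouncedEndpoint = useDebouncedValue(endpoint, 500);
  return <input value={endpoint} onChange={(e) => setEndpoint(e.target.value)} />;
}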

View file

@@ -0,0 +1,61 @@
import {
type UseMutationOptions,
useMutation,
useQueryClient,
} from "@tanstack/react-query";
export interface OnboardingVariables {
model_provider: string;
api_key?: string;
endpoint?: string;
project_id?: string;
embedding_model: string;
llm_model: string;
sample_data?: boolean;
}
interface OnboardingResponse {
message: string;
edited: boolean;
}
export const useOnboardingMutation = (
options?: Omit<
UseMutationOptions<
OnboardingResponse,
Error,
OnboardingVariables
>,
"mutationFn"
>,
) => {
const queryClient = useQueryClient();
async function submitOnboarding(
variables: OnboardingVariables,
): Promise<OnboardingResponse> {
const response = await fetch("/api/onboarding", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(variables),
});
if (!response.ok) {
const error = await response.json();
throw new Error(error.error || "Failed to complete onboarding");
}
return response.json();
}
return useMutation({
mutationFn: submitOnboarding,
onSettled: () => {
// Invalidate settings query to refetch updated data
queryClient.invalidateQueries({ queryKey: ["settings"] });
},
...options,
});
};
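For reference, a hedged sketch of how this mutation is consumed (it mirrors the onboarding page further down; the concrete values are illustrative):
const onboardingMutation = useOnboardingMutation({
  onSuccess: () => toast.success("Onboarding completed successfully!"),
  onError: (error) =>
    toast.error("Failed to complete onboarding", { description: error.message }),
});
onboardingMutation.mutate({
  model_provider: "openai",
  api_key: "sk-...",
  llm_model: "gpt-4o-mini",
  embedding_model: "text-embedding-3-small",
  sample_data: true,
});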

View file

@@ -15,23 +15,33 @@ export interface ModelsResponse {
embedding_models: ModelOption[];
}
export interface OpenAIModelsParams {
apiKey?: string;
}
export interface OllamaModelsParams {
endpoint?: string;
}
export interface IBMModelsParams {
api_key?: string;
endpoint?: string;
project_id?: string;
apiKey?: string;
projectId?: string;
}
export const useGetOpenAIModelsQuery = (
params?: OpenAIModelsParams,
options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
) => {
const queryClient = useQueryClient();
async function getOpenAIModels(): Promise<ModelsResponse> {
const response = await fetch("/api/models/openai");
const url = new URL("/api/models/openai", window.location.origin);
if (params?.apiKey) {
url.searchParams.set("api_key", params.apiKey);
}
const response = await fetch(url.toString());
if (response.ok) {
return await response.json();
} else {
@@ -41,9 +51,12 @@ export const useGetOpenAIModelsQuery = (
const queryResult = useQuery(
{
queryKey: ["models", "openai"],
queryKey: ["models", "openai", params],
queryFn: getOpenAIModels,
staleTime: 5 * 60 * 1000, // 5 minutes
retry: 2,
enabled: !!params?.apiKey, // Only run if API key is provided
staleTime: 0, // Always fetch fresh data
gcTime: 0, // Don't cache results
...options,
},
queryClient,
@@ -76,8 +89,10 @@ export const useGetOllamaModelsQuery = (
{
queryKey: ["models", "ollama", params],
queryFn: getOllamaModels,
staleTime: 5 * 60 * 1000, // 5 minutes
retry: 2,
enabled: !!params?.endpoint, // Only run if endpoint is provided
staleTime: 0, // Always fetch fresh data
gcTime: 0, // Don't cache results
...options,
},
queryClient,
@@ -93,35 +108,22 @@ export const useGetIBMModelsQuery = (
const queryClient = useQueryClient();
async function getIBMModels(): Promise<ModelsResponse> {
const url = "/api/models/ibm";
const url = new URL("/api/models/ibm", window.location.origin);
if (params?.endpoint) {
url.searchParams.set("endpoint", params.endpoint);
}
if (params?.apiKey) {
url.searchParams.set("api_key", params.apiKey);
}
if (params?.projectId) {
url.searchParams.set("project_id", params.projectId);
}
// If we have credentials, use POST to send them securely
if (params?.api_key || params?.endpoint || params?.project_id) {
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
api_key: params.api_key,
endpoint: params.endpoint,
project_id: params.project_id,
}),
});
if (response.ok) {
return await response.json();
} else {
throw new Error("Failed to fetch IBM models");
}
const response = await fetch(url.toString());
if (response.ok) {
return await response.json();
} else {
// Use GET for default models
const response = await fetch(url);
if (response.ok) {
return await response.json();
} else {
throw new Error("Failed to fetch IBM models");
}
throw new Error("Failed to fetch IBM models");
}
}
@@ -129,8 +131,10 @@ export const useGetIBMModelsQuery = (
{
queryKey: ["models", "ibm", params],
queryFn: getIBMModels,
staleTime: 5 * 60 * 1000, // 5 minutes
enabled: !!(params?.api_key && params?.endpoint && params?.project_id), // Only run if all credentials are provided
retry: 2,
enabled: !!params?.endpoint && !!params?.apiKey && !!params?.projectId, // Only run if all required params are provided
staleTime: 0, // Always fetch fresh data
gcTime: 0, // Don't cache results
...options,
},
queryClient,
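The gating above means the IBM query never fires with partial credentials; a hedged sketch of a call site (state names are illustrative; with staleTime and gcTime at 0, every credential change re-validates against the API):
const { data, isLoading, error } = useGetIBMModelsQuery(
  endpoint && apiKey && projectId
    ? { endpoint, apiKey, projectId }
    : undefined,
);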

View file

@@ -63,3 +63,6 @@ export const useGetSettingsQuery = (
return queryResult;
};

View file

@@ -14,8 +14,8 @@ function LoginPageContent() {
const router = useRouter();
const searchParams = useSearchParams();
const { data: settings } = useGetSettingsQuery({
enabled: isAuthenticated,
const { data: settings, isLoading: isSettingsLoading } = useGetSettingsQuery({
enabled: isAuthenticated || isNoAuthMode,
});
const redirect =
@@ -25,12 +25,19 @@ function LoginPageContent() {
// Redirect if already authenticated or in no-auth mode
useEffect(() => {
if (!isLoading && (isAuthenticated || isNoAuthMode)) {
if (!isLoading && !isSettingsLoading && (isAuthenticated || isNoAuthMode)) {
router.push(redirect);
}
}, [isLoading, isAuthenticated, isNoAuthMode, router, redirect]);
}, [
isLoading,
isSettingsLoading,
isAuthenticated,
isNoAuthMode,
router,
redirect,
]);
if (isLoading) {
if (isLoading || isSettingsLoading) {
return (
<div className="min-h-screen flex items-center justify-center bg-background">
<div className="flex flex-col items-center gap-4">

View file

@@ -32,8 +32,13 @@ export function AdvancedOnboarding({
setSampleDataset: (dataset: boolean) => void;
}) {
const hasEmbeddingModels =
embeddingModels && embeddingModel && setEmbeddingModel;
const hasLanguageModels = languageModels && languageModel && setLanguageModel;
embeddingModels !== undefined &&
embeddingModel !== undefined &&
setEmbeddingModel !== undefined;
const hasLanguageModels =
languageModels !== undefined &&
languageModel !== undefined &&
setLanguageModel !== undefined;
return (
<Accordion type="single" collapsible>
<AccordionItem value="item-1">

View file

@@ -0,0 +1,127 @@
import { useState } from "react";
import { LabelInput } from "@/components/label-input";
import IBMLogo from "@/components/logo/ibm-logo";
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
import { useGetIBMModelsQuery } from "../../api/queries/useGetModelsQuery";
import { useModelSelection } from "../hooks/useModelSelection";
import { useUpdateSettings } from "../hooks/useUpdateSettings";
import { useDebouncedValue } from "@/lib/debounce";
import { AdvancedOnboarding } from "./advanced";
export function IBMOnboarding({
setSettings,
sampleDataset,
setSampleDataset,
}: {
setSettings: (settings: OnboardingVariables) => void;
sampleDataset: boolean;
setSampleDataset: (dataset: boolean) => void;
}) {
const [endpoint, setEndpoint] = useState("");
const [apiKey, setApiKey] = useState("");
const [projectId, setProjectId] = useState("");
const debouncedEndpoint = useDebouncedValue(endpoint, 500);
const debouncedApiKey = useDebouncedValue(apiKey, 500);
const debouncedProjectId = useDebouncedValue(projectId, 500);
// Fetch models from API when all credentials are provided
const {
data: modelsData,
isLoading: isLoadingModels,
error: modelsError,
} = useGetIBMModelsQuery(
debouncedEndpoint && debouncedApiKey && debouncedProjectId
? {
endpoint: debouncedEndpoint,
apiKey: debouncedApiKey,
projectId: debouncedProjectId,
}
: undefined,
);
// Use custom hook for model selection logic
const {
languageModel,
embeddingModel,
setLanguageModel,
setEmbeddingModel,
languageModels,
embeddingModels,
} = useModelSelection(modelsData);
const handleSampleDatasetChange = (dataset: boolean) => {
setSampleDataset(dataset);
};
// Update settings when values change
useUpdateSettings(
"ibm",
{
endpoint,
apiKey,
projectId,
languageModel,
embeddingModel,
},
setSettings,
);
return (
<>
<div className="space-y-4">
<LabelInput
label="watsonx.ai API Endpoint"
helperText="The API endpoint for your watsonx.ai account."
id="api-endpoint"
required
placeholder="https://us-south.ml.cloud.ibm.com"
value={endpoint}
onChange={(e) => setEndpoint(e.target.value)}
/>
<LabelInput
label="IBM API key"
helperText="The API key for your watsonx.ai account."
id="api-key"
required
placeholder="your-api-key"
value={apiKey}
onChange={(e) => setApiKey(e.target.value)}
/>
<LabelInput
label="IBM Project ID"
helperText="The project ID for your watsonx.ai account."
id="project-id"
required
placeholder="your-project-id"
value={projectId}
onChange={(e) => setProjectId(e.target.value)}
/>
{isLoadingModels && (
<p className="text-sm text-muted-foreground">
Validating configuration...
</p>
)}
{modelsError && (
<p className="text-sm text-red-500">
Invalid configuration or connection failed
</p>
)}
{modelsData &&
(modelsData.language_models?.length > 0 ||
modelsData.embedding_models?.length > 0) && (
<p className="text-sm text-green-600">Configuration is valid</p>
)}
</div>
<AdvancedOnboarding
icon={<IBMLogo className="w-4 h-4" />}
languageModels={languageModels}
embeddingModels={embeddingModels}
languageModel={languageModel}
embeddingModel={embeddingModel}
sampleDataset={sampleDataset}
setLanguageModel={setLanguageModel}
setEmbeddingModel={setEmbeddingModel}
setSampleDataset={handleSampleDatasetChange}
/>
</>
);
}

View file

@@ -1,5 +1,5 @@
import { CheckIcon, ChevronsUpDownIcon } from "lucide-react";
import { useState } from "react";
import { useEffect, useState } from "react";
import { Button } from "@/components/ui/button";
import {
Command,
@@ -32,6 +32,11 @@ export function ModelSelector({
onValueChange: (value: string) => void;
}) {
const [open, setOpen] = useState(false);
useEffect(() => {
if (value && !options.find((option) => option.value === value)) {
onValueChange("");
}
}, [options, value, onValueChange]);
return (
<Popover open={open} onOpenChange={setOpen}>
<PopoverTrigger asChild>
@@ -39,6 +44,7 @@
<Button
variant="outline"
role="combobox"
disabled={options.length === 0}
aria-expanded={open}
className="w-full gap-2 justify-between font-normal text-sm"
>
@@ -53,6 +59,8 @@
</span>
)}
</div>
) : options.length === 0 ? (
"No models available"
) : (
"Select model..."
)}

View file

@@ -0,0 +1,135 @@
import { useState } from "react";
import { LabelInput } from "@/components/label-input";
import { LabelWrapper } from "@/components/label-wrapper";
import OllamaLogo from "@/components/logo/ollama-logo";
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
import { useGetOllamaModelsQuery } from "../../api/queries/useGetModelsQuery";
import { useModelSelection } from "../hooks/useModelSelection";
import { useUpdateSettings } from "../hooks/useUpdateSettings";
import { useDebouncedValue } from "@/lib/debounce";
import { AdvancedOnboarding } from "./advanced";
import { ModelSelector } from "./model-selector";
export function OllamaOnboarding({
setSettings,
sampleDataset,
setSampleDataset,
}: {
setSettings: (settings: OnboardingVariables) => void;
sampleDataset: boolean;
setSampleDataset: (dataset: boolean) => void;
}) {
const [endpoint, setEndpoint] = useState("");
const debouncedEndpoint = useDebouncedValue(endpoint, 500);
// Fetch models from API when endpoint is provided (debounced)
const {
data: modelsData,
isLoading: isLoadingModels,
error: modelsError,
} = useGetOllamaModelsQuery(
debouncedEndpoint ? { endpoint: debouncedEndpoint } : undefined,
);
// Use custom hook for model selection logic
const {
languageModel,
embeddingModel,
setLanguageModel,
setEmbeddingModel,
languageModels,
embeddingModels,
} = useModelSelection(modelsData);
const handleSampleDatasetChange = (dataset: boolean) => {
setSampleDataset(dataset);
};
// Update settings when values change
useUpdateSettings(
"ollama",
{
endpoint,
languageModel,
embeddingModel,
},
setSettings,
);
// Check validation state based on models query
const isConnecting = debouncedEndpoint && isLoadingModels;
const hasConnectionError = debouncedEndpoint && modelsError;
const hasNoModels =
modelsData &&
!modelsData.language_models?.length &&
!modelsData.embedding_models?.length;
const isValidConnection =
modelsData &&
(modelsData.language_models?.length > 0 ||
modelsData.embedding_models?.length > 0);
return (
<>
<div className="space-y-1">
<LabelInput
label="Ollama Endpoint"
helperText="The endpoint for your Ollama server."
id="api-endpoint"
required
placeholder="http://localhost:11434"
value={endpoint}
onChange={(e) => setEndpoint(e.target.value)}
/>
{isConnecting && (
<p className="text-sm text-muted-foreground">
Connecting to Ollama server...
</p>
)}
{hasConnectionError && (
<p className="text-sm text-red-500">
Cannot connect to Ollama server. Please check the endpoint.
</p>
)}
{hasNoModels && (
<p className="text-sm text-yellow-600">
No models found. Please install some models on your Ollama server.
</p>
)}
{isValidConnection && (
<p className="text-sm text-green-600">Connected successfully</p>
)}
</div>
<LabelWrapper
label="Embedding model"
helperText="The embedding model for your Ollama server."
id="embedding-model"
required={true}
>
<ModelSelector
options={embeddingModels}
icon={<OllamaLogo className="w-4 h-4" />}
value={embeddingModel}
onValueChange={setEmbeddingModel}
/>
</LabelWrapper>
<LabelWrapper
label="Language model"
helperText="The embedding model for your Ollama server."
id="embedding-model"
required={true}
>
<ModelSelector
options={languageModels}
icon={<OllamaLogo className="w-4 h-4" />}
value={languageModel}
onValueChange={setLanguageModel}
/>
</LabelWrapper>
<AdvancedOnboarding
sampleDataset={sampleDataset}
setSampleDataset={handleSampleDatasetChange}
/>
</>
);
}

View file

@@ -0,0 +1,93 @@
import { useState } from "react";
import { LabelInput } from "@/components/label-input";
import OpenAILogo from "@/components/logo/openai-logo";
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
import { useGetOpenAIModelsQuery } from "../../api/queries/useGetModelsQuery";
import { useModelSelection } from "../hooks/useModelSelection";
import { useUpdateSettings } from "../hooks/useUpdateSettings";
import { useDebouncedValue } from "@/lib/debounce";
import { AdvancedOnboarding } from "./advanced";
export function OpenAIOnboarding({
setSettings,
sampleDataset,
setSampleDataset,
}: {
setSettings: (settings: OnboardingVariables) => void;
sampleDataset: boolean;
setSampleDataset: (dataset: boolean) => void;
}) {
const [apiKey, setApiKey] = useState("");
const debouncedApiKey = useDebouncedValue(apiKey, 500);
// Fetch models from API when API key is provided
const {
data: modelsData,
isLoading: isLoadingModels,
error: modelsError,
} = useGetOpenAIModelsQuery(
debouncedApiKey ? { apiKey: debouncedApiKey } : undefined,
);
// Use custom hook for model selection logic
const {
languageModel,
embeddingModel,
setLanguageModel,
setEmbeddingModel,
languageModels,
embeddingModels,
} = useModelSelection(modelsData);
const handleSampleDatasetChange = (dataset: boolean) => {
setSampleDataset(dataset);
};
// Update settings when values change
useUpdateSettings(
"openai",
{
apiKey,
languageModel,
embeddingModel,
},
setSettings,
);
return (
<>
<div className="space-y-1">
<LabelInput
label="OpenAI API key"
helperText="The API key for your OpenAI account."
id="api-key"
required
placeholder="sk-..."
value={apiKey}
onChange={(e) => setApiKey(e.target.value)}
/>
{isLoadingModels && (
<p className="text-sm text-muted-foreground">Validating API key...</p>
)}
{modelsError && (
<p className="text-sm text-red-500">
Invalid API key or configuration
</p>
)}
{modelsData &&
(modelsData.language_models?.length > 0 ||
modelsData.embedding_models?.length > 0) && (
<p className="text-sm text-green-600">Configuration is valid</p>
)}
</div>
<AdvancedOnboarding
icon={<OpenAILogo className="w-4 h-4" />}
languageModels={languageModels}
embeddingModels={embeddingModels}
languageModel={languageModel}
embeddingModel={embeddingModel}
sampleDataset={sampleDataset}
setLanguageModel={setLanguageModel}
setSampleDataset={handleSampleDatasetChange}
setEmbeddingModel={setEmbeddingModel}
/>
</>
);
}

View file

@@ -0,0 +1,46 @@
import { useState, useEffect } from "react";
import type { ModelsResponse } from "../../api/queries/useGetModelsQuery";
export function useModelSelection(modelsData: ModelsResponse | undefined) {
const [languageModel, setLanguageModel] = useState("");
const [embeddingModel, setEmbeddingModel] = useState("");
// Update default selections when models are loaded
useEffect(() => {
if (modelsData) {
const defaultLangModel = modelsData.language_models.find(
(m) => m.default,
);
const defaultEmbedModel = modelsData.embedding_models.find(
(m) => m.default,
);
// Set language model: prefer default, fallback to first available
if (!languageModel) {
if (defaultLangModel) {
setLanguageModel(defaultLangModel.value);
} else if (modelsData.language_models.length > 0) {
setLanguageModel(modelsData.language_models[0].value);
}
}
// Set embedding model: prefer default, fallback to first available
if (!embeddingModel) {
if (defaultEmbedModel) {
setEmbeddingModel(defaultEmbedModel.value);
} else if (modelsData.embedding_models.length > 0) {
setEmbeddingModel(modelsData.embedding_models[0].value);
}
}
}
}, [modelsData, languageModel, embeddingModel]);
return {
languageModel,
embeddingModel,
setLanguageModel,
setEmbeddingModel,
languageModels: modelsData?.language_models || [],
embeddingModels: modelsData?.embedding_models || [],
};
}
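For clarity, a hedged sketch of the ModelsResponse shape this hook consumes (the value/label/default fields come from useGetModelsQuery above; the concrete models are illustrative):
const modelsData: ModelsResponse = {
  language_models: [
    { value: "gpt-4o-mini", label: "gpt-4o-mini", default: true },
    { value: "gpt-4o", label: "gpt-4o" },
  ],
  embedding_models: [
    { value: "text-embedding-3-small", label: "text-embedding-3-small", default: true },
  ],
};
// useModelSelection(modelsData) would pre-select gpt-4o-mini and text-embedding-3-small.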

View file

@@ -0,0 +1,58 @@
import { useEffect } from "react";
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
interface ConfigValues {
apiKey?: string;
endpoint?: string;
projectId?: string;
languageModel?: string;
embeddingModel?: string;
}
export function useUpdateSettings(
provider: string,
config: ConfigValues,
setSettings: (settings: OnboardingVariables) => void,
) {
useEffect(() => {
const updatedSettings: OnboardingVariables = {
model_provider: provider,
embedding_model: "",
llm_model: "",
};
// Set language model if provided
if (config.languageModel) {
updatedSettings.llm_model = config.languageModel;
}
// Set embedding model if provided
if (config.embeddingModel) {
updatedSettings.embedding_model = config.embeddingModel;
}
// Set API key if provided
if (config.apiKey) {
updatedSettings.api_key = config.apiKey;
}
// Set endpoint and project ID if provided
if (config.endpoint) {
updatedSettings.endpoint = config.endpoint;
}
if (config.projectId) {
updatedSettings.project_id = config.projectId;
}
setSettings(updatedSettings);
}, [
provider,
config.apiKey,
config.endpoint,
config.projectId,
config.languageModel,
config.embeddingModel,
setSettings,
]);
}
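As an example, for a watsonx.ai setup the effect above would push something like the following to the parent (the values are the placeholders used elsewhere in this commit):
setSettings({
  model_provider: "ibm",
  llm_model: "meta-llama/llama-3-1-70b-instruct",
  embedding_model: "ibm/slate-125m-english-rtrvr",
  api_key: "your-api-key",
  endpoint: "https://us-south.ml.cloud.ibm.com",
  project_id: "your-project-id",
});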

View file

@@ -1,110 +0,0 @@
import { useState, useEffect } from "react";
import { LabelInput } from "@/components/label-input";
import IBMLogo from "@/components/logo/ibm-logo";
import type { Settings } from "../api/queries/useGetSettingsQuery";
import { useGetIBMModelsQuery } from "../api/queries/useGetModelsQuery";
import { AdvancedOnboarding } from "./advanced";
export function IBMOnboarding({
settings,
setSettings,
sampleDataset,
setSampleDataset,
}: {
settings: Settings;
setSettings: (settings: Settings) => void;
sampleDataset: boolean;
setSampleDataset: (dataset: boolean) => void;
}) {
const [endpoint, setEndpoint] = useState("");
const [apiKey, setApiKey] = useState("");
const [projectId, setProjectId] = useState("");
const [languageModel, setLanguageModel] = useState("meta-llama/llama-3-1-70b-instruct");
const [embeddingModel, setEmbeddingModel] = useState("ibm/slate-125m-english-rtrvr");
// Fetch models from API when all credentials are provided
const { data: modelsData } = useGetIBMModelsQuery(
(apiKey && endpoint && projectId) ? { api_key: apiKey, endpoint, project_id: projectId } : undefined,
{ enabled: !!(apiKey && endpoint && projectId) }
);
// Use fetched models or fallback to defaults
const languageModels = modelsData?.language_models || [
{ value: "meta-llama/llama-3-1-70b-instruct", label: "Llama 3.1 70B Instruct", default: true },
{ value: "meta-llama/llama-3-1-8b-instruct", label: "Llama 3.1 8B Instruct" },
{ value: "ibm/granite-13b-chat-v2", label: "Granite 13B Chat v2" },
{ value: "ibm/granite-13b-instruct-v2", label: "Granite 13B Instruct v2" },
];
const embeddingModels = modelsData?.embedding_models || [
{ value: "ibm/slate-125m-english-rtrvr", label: "Slate 125M English Retriever", default: true },
{ value: "sentence-transformers/all-minilm-l12-v2", label: "All-MiniLM L12 v2" },
];
// Update default selections when models are loaded
useEffect(() => {
if (modelsData) {
const defaultLangModel = modelsData.language_models.find(m => m.default);
const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
if (defaultLangModel) {
setLanguageModel(defaultLangModel.value);
}
if (defaultEmbedModel) {
setEmbeddingModel(defaultEmbedModel.value);
}
}
}, [modelsData]);
const handleLanguageModelChange = (model: string) => {
setLanguageModel(model);
};
const handleEmbeddingModelChange = (model: string) => {
setEmbeddingModel(model);
};
const handleSampleDatasetChange = (dataset: boolean) => {
setSampleDataset(dataset);
};
return (
<>
<LabelInput
label="watsonx.ai API Endpoint"
helperText="The API endpoint for your watsonx.ai account."
id="api-endpoint"
required
placeholder="https://us-south.ml.cloud.ibm.com"
value={endpoint}
onChange={(e) => setEndpoint(e.target.value)}
/>
<LabelInput
label="IBM API key"
helperText="The API key for your watsonx.ai account."
id="api-key"
required
placeholder="your-api-key"
value={apiKey}
onChange={(e) => setApiKey(e.target.value)}
/>
<LabelInput
label="IBM Project ID"
helperText="The project ID for your watsonx.ai account."
id="project-id"
required
placeholder="your-project-id"
value={projectId}
onChange={(e) => setProjectId(e.target.value)}
/>
<AdvancedOnboarding
icon={<IBMLogo className="w-4 h-4" />}
languageModels={languageModels}
embeddingModels={embeddingModels}
languageModel={languageModel}
embeddingModel={embeddingModel}
sampleDataset={sampleDataset}
setLanguageModel={handleLanguageModelChange}
setEmbeddingModel={handleEmbeddingModelChange}
setSampleDataset={handleSampleDatasetChange}
/>
</>
);
}

View file

@@ -1,105 +0,0 @@
import { useState, useEffect } from "react";
import { LabelInput } from "@/components/label-input";
import { LabelWrapper } from "@/components/label-wrapper";
import OllamaLogo from "@/components/logo/ollama-logo";
import type { Settings } from "../api/queries/useGetSettingsQuery";
import { useGetOllamaModelsQuery } from "../api/queries/useGetModelsQuery";
import { AdvancedOnboarding } from "./advanced";
import { ModelSelector } from "./model-selector";
export function OllamaOnboarding({
settings,
setSettings,
sampleDataset,
setSampleDataset,
}: {
settings: Settings;
setSettings: (settings: Settings) => void;
sampleDataset: boolean;
setSampleDataset: (dataset: boolean) => void;
}) {
const [endpoint, setEndpoint] = useState("");
const [languageModel, setLanguageModel] = useState("llama3.2");
const [embeddingModel, setEmbeddingModel] = useState("nomic-embed-text");
// Fetch models from API when endpoint is provided
const { data: modelsData } = useGetOllamaModelsQuery(
endpoint ? { endpoint } : undefined,
{ enabled: !!endpoint }
);
// Use fetched models or fallback to defaults
const languageModels = modelsData?.language_models || [
{ value: "llama3.2", label: "llama3.2", default: true },
{ value: "llama3.1", label: "llama3.1" },
{ value: "llama3", label: "llama3" },
{ value: "mistral", label: "mistral" },
{ value: "codellama", label: "codellama" },
];
const embeddingModels = modelsData?.embedding_models || [
{ value: "nomic-embed-text", label: "nomic-embed-text", default: true },
];
// Update default selections when models are loaded
useEffect(() => {
if (modelsData) {
const defaultLangModel = modelsData.language_models.find(m => m.default);
const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
if (defaultLangModel) {
setLanguageModel(defaultLangModel.value);
}
if (defaultEmbedModel) {
setEmbeddingModel(defaultEmbedModel.value);
}
}
}, [modelsData]);
const handleSampleDatasetChange = (dataset: boolean) => {
setSampleDataset(dataset);
};
return (
<>
<LabelInput
label="Ollama Endpoint"
helperText="The endpoint for your Ollama server."
id="api-endpoint"
required
placeholder="http://localhost:11434"
value={endpoint}
onChange={(e) => setEndpoint(e.target.value)}
/>
<LabelWrapper
label="Embedding model"
helperText="The embedding model for your Ollama server."
id="embedding-model"
required={true}
>
<ModelSelector
options={embeddingModels}
icon={<OllamaLogo className="w-4 h-4" />}
value={embeddingModel}
onValueChange={setEmbeddingModel}
/>
</LabelWrapper>
<LabelWrapper
label="Language model"
helperText="The embedding model for your Ollama server."
id="embedding-model"
required={true}
>
<ModelSelector
options={languageModels}
icon={<OllamaLogo className="w-4 h-4" />}
value={languageModel}
onValueChange={setLanguageModel}
/>
</LabelWrapper>
<AdvancedOnboarding
sampleDataset={sampleDataset}
setSampleDataset={handleSampleDatasetChange}
/>
</>
);
}

View file

@@ -1,80 +0,0 @@
import { useState, useEffect } from "react";
import { LabelInput } from "@/components/label-input";
import OpenAILogo from "@/components/logo/openai-logo";
import type { Settings } from "../api/queries/useGetSettingsQuery";
import { useGetOpenAIModelsQuery } from "../api/queries/useGetModelsQuery";
import { AdvancedOnboarding } from "./advanced";
export function OpenAIOnboarding({
settings,
setSettings,
sampleDataset,
setSampleDataset,
}: {
settings: Settings;
setSettings: (settings: Settings) => void;
sampleDataset: boolean;
setSampleDataset: (dataset: boolean) => void;
}) {
const [languageModel, setLanguageModel] = useState("gpt-4o-mini");
const [embeddingModel, setEmbeddingModel] = useState(
"text-embedding-3-small",
);
// Fetch models from API
const { data: modelsData } = useGetOpenAIModelsQuery();
// Use fetched models or fallback to defaults
const languageModels = modelsData?.language_models || [{ value: "gpt-4o-mini", label: "gpt-4o-mini", default: true }];
const embeddingModels = modelsData?.embedding_models || [
{ value: "text-embedding-3-small", label: "text-embedding-3-small", default: true },
];
// Update default selections when models are loaded
useEffect(() => {
if (modelsData) {
const defaultLangModel = modelsData.language_models.find(m => m.default);
const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
if (defaultLangModel && languageModel === "gpt-4o-mini") {
setLanguageModel(defaultLangModel.value);
}
if (defaultEmbedModel && embeddingModel === "text-embedding-3-small") {
setEmbeddingModel(defaultEmbedModel.value);
}
}
}, [modelsData, languageModel, embeddingModel]);
const handleLanguageModelChange = (model: string) => {
setLanguageModel(model);
};
const handleEmbeddingModelChange = (model: string) => {
setEmbeddingModel(model);
};
const handleSampleDatasetChange = (dataset: boolean) => {
setSampleDataset(dataset);
};
return (
<>
<LabelInput
label="OpenAI API key"
helperText="The API key for your OpenAI account."
id="api-key"
required
placeholder="sk-..."
/>
<AdvancedOnboarding
icon={<OpenAILogo className="w-4 h-4" />}
languageModels={languageModels}
embeddingModels={embeddingModels}
languageModel={languageModel}
embeddingModel={embeddingModel}
sampleDataset={sampleDataset}
setLanguageModel={handleLanguageModelChange}
setSampleDataset={handleSampleDatasetChange}
setEmbeddingModel={handleEmbeddingModelChange}
/>
</>
);
}

View file

@@ -1,12 +1,11 @@
"use client";
import { Suspense, useState } from "react";
import { Suspense, useEffect, useState } from "react";
import { toast } from "sonner";
import { useUpdateFlowSettingMutation } from "@/app/api/mutations/useUpdateFlowSettingMutation";
import {
type Settings,
useGetSettingsQuery,
} from "@/app/api/queries/useGetSettingsQuery";
useOnboardingMutation,
type OnboardingVariables,
} from "@/app/api/mutations/useOnboardingMutation";
import IBMLogo from "@/components/logo/ibm-logo";
import OllamaLogo from "@/components/logo/ollama-logo";
import OpenAILogo from "@/components/logo/openai-logo";
@@ -19,44 +18,102 @@ import {
CardHeader,
} from "@/components/ui/card";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { useAuth } from "@/contexts/auth-context";
import { IBMOnboarding } from "./ibm-onboarding";
import { OllamaOnboarding } from "./ollama-onboarding";
import { OpenAIOnboarding } from "./openai-onboarding";
import { IBMOnboarding } from "./components/ibm-onboarding";
import { OllamaOnboarding } from "./components/ollama-onboarding";
import { OpenAIOnboarding } from "./components/openai-onboarding";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { useGetSettingsQuery } from "../api/queries/useGetSettingsQuery";
import { useRouter } from "next/navigation";
function OnboardingPage() {
const { isAuthenticated } = useAuth();
const { data: settingsDb, isLoading: isSettingsLoading } =
useGetSettingsQuery();
const redirect = "/";
const router = useRouter();
// Redirect if onboarding has already been completed
useEffect(() => {
if (!isSettingsLoading && settingsDb && settingsDb.edited) {
router.push(redirect);
}
}, [isSettingsLoading, settingsDb, router, redirect]);
const [modelProvider, setModelProvider] = useState<string>("openai");
const [sampleDataset, setSampleDataset] = useState<boolean>(false);
// Fetch settings using React Query
const { data: settingsDb = {} } = useGetSettingsQuery({
enabled: isAuthenticated,
const [sampleDataset, setSampleDataset] = useState<boolean>(true);
const handleSetModelProvider = (provider: string) => {
setModelProvider(provider);
setSettings({
model_provider: provider,
embedding_model: "",
llm_model: "",
});
};
const [settings, setSettings] = useState<OnboardingVariables>({
model_provider: modelProvider,
embedding_model: "",
llm_model: "",
});
const [settings, setSettings] = useState<Settings>(settingsDb);
// Mutations
const updateFlowSettingMutation = useUpdateFlowSettingMutation({
onSuccess: () => {
console.log("Setting updated successfully");
const onboardingMutation = useOnboardingMutation({
onSuccess: (data) => {
toast.success("Onboarding completed successfully!");
console.log("Onboarding completed successfully", data);
},
onError: (error) => {
toast.error("Failed to update settings", {
toast.error("Failed to complete onboarding", {
description: error.message,
});
},
});
const handleComplete = () => {
updateFlowSettingMutation.mutate({
llm_model: settings.agent?.llm_model,
embedding_model: settings.knowledge?.embedding_model,
system_prompt: settings.agent?.system_prompt,
});
if (
!settings.model_provider ||
!settings.llm_model ||
!settings.embedding_model
) {
toast.error("Please complete all required fields");
return;
}
// Prepare onboarding data
const onboardingData: OnboardingVariables = {
model_provider: settings.model_provider,
llm_model: settings.llm_model,
embedding_model: settings.embedding_model,
sample_data: sampleDataset,
};
// Add API key if available
if (settings.api_key) {
onboardingData.api_key = settings.api_key;
}
// Add endpoint if available
if (settings.endpoint) {
onboardingData.endpoint = settings.endpoint;
}
// Add project_id if available
if (settings.project_id) {
onboardingData.project_id = settings.project_id;
}
onboardingMutation.mutate(onboardingData);
};
const isComplete = !!settings.llm_model && !!settings.embedding_model;
return (
<div
className="min-h-dvh w-full flex gap-5 flex-col items-center justify-center bg-background p-4"
@@ -74,7 +131,10 @@ function OnboardingPage() {
<p className="text-sm text-muted-foreground">[description of task]</p>
</div>
<Card className="w-full max-w-[580px]">
<Tabs defaultValue={modelProvider} onValueChange={setModelProvider}>
<Tabs
defaultValue={modelProvider}
onValueChange={handleSetModelProvider}
>
<CardHeader>
<TabsList>
<TabsTrigger value="openai">
@@ -94,7 +154,6 @@ function OnboardingPage() {
<CardContent>
<TabsContent value="openai">
<OpenAIOnboarding
settings={settings}
setSettings={setSettings}
sampleDataset={sampleDataset}
setSampleDataset={setSampleDataset}
@@ -102,7 +161,6 @@ function OnboardingPage() {
</TabsContent>
<TabsContent value="watsonx">
<IBMOnboarding
settings={settings}
setSettings={setSettings}
sampleDataset={sampleDataset}
setSampleDataset={setSampleDataset}
@@ -110,7 +168,6 @@ function OnboardingPage() {
</TabsContent>
<TabsContent value="ollama">
<OllamaOnboarding
settings={settings}
setSettings={setSettings}
sampleDataset={sampleDataset}
setSampleDataset={setSampleDataset}
@@ -119,9 +176,21 @@ function OnboardingPage() {
</CardContent>
</Tabs>
<CardFooter className="flex justify-end">
<Button size="sm" onClick={handleComplete}>
Complete
</Button>
<Tooltip>
<TooltipTrigger asChild>
<Button
size="sm"
onClick={handleComplete}
disabled={!isComplete}
loading={onboardingMutation.isPending}
>
Complete
</Button>
</TooltipTrigger>
<TooltipContent>
{!isComplete ? "Please fill in all required fields" : ""}
</TooltipContent>
</Tooltip>
</CardFooter>
</Card>
</div>

View file

@@ -15,15 +15,16 @@ import { useKnowledgeFilter } from "@/contexts/knowledge-filter-context";
// import { GitHubStarButton } from "@/components/github-star-button"
// import { DiscordLink } from "@/components/discord-link"
import { useTask } from "@/contexts/task-context";
import Logo from "@/components/logo/logo";
export function LayoutWrapper({ children }: { children: React.ReactNode }) {
const pathname = usePathname();
const { tasks, isMenuOpen, toggleMenu } = useTask();
const { selectedFilter, setSelectedFilter, isPanelOpen } =
useKnowledgeFilter();
const { isLoading, isAuthenticated } = useAuth();
const { isLoading, isAuthenticated, isNoAuthMode } = useAuth();
const { isLoading: isSettingsLoading, data: settings } = useGetSettingsQuery({
enabled: isAuthenticated,
enabled: isAuthenticated || isNoAuthMode,
});
// List of paths that should not show navigation
@@ -62,7 +63,7 @@ export function LayoutWrapper({ children }: { children: React.ReactNode }) {
<div className="header-start-display px-4">
{/* Logo/Title */}
<div className="flex items-center gap-2">
<Image src="/logo.svg" alt="OpenRAG Logo" width={24} height={22} />
<Logo className="fill-primary" width={24} height={22} />
<span className="text-lg font-semibold">OpenRAG</span>
</div>
</div>

View file

@@ -14,7 +14,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) {
const { isLoading, isAuthenticated, isNoAuthMode } = useAuth();
const { data: settings = {}, isLoading: isSettingsLoading } =
useGetSettingsQuery({
enabled: isAuthenticated,
enabled: isAuthenticated || isNoAuthMode,
});
const router = useRouter();
const pathname = usePathname();
@@ -31,12 +31,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) {
);
useEffect(() => {
// In no-auth mode, allow access without authentication
if (isNoAuthMode) {
return;
}
if (!isLoading && !isAuthenticated) {
if (!isLoading && !isSettingsLoading && !isAuthenticated && !isNoAuthMode) {
// Redirect to login with current path as redirect parameter
const redirectUrl = `/login?redirect=${encodeURIComponent(pathname)}`;
router.push(redirectUrl);
@@ -48,6 +43,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) {
}
}, [
isLoading,
isSettingsLoading,
isAuthenticated,
isNoAuthMode,
router,
@@ -57,7 +53,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) {
]);
// Show loading state while checking authentication
if (isLoading) {
if (isLoading || isSettingsLoading) {
return (
<div className="flex items-center justify-center h-64">
<div className="flex flex-col items-center gap-4">

View file

@@ -7,7 +7,17 @@ logger = get_logger(__name__)
async def get_openai_models(request, models_service, session_manager):
"""Get available OpenAI models"""
try:
models = await models_service.get_openai_models()
# Get API key from query parameters
query_params = dict(request.query_params)
api_key = query_params.get("api_key")
if not api_key:
return JSONResponse(
{"error": "OpenAI API key is required as query parameter"},
status_code=400
)
models = await models_service.get_openai_models(api_key=api_key)
return JSONResponse(models)
except Exception as e:
logger.error(f"Failed to get OpenAI models: {str(e)}")
@@ -37,21 +47,15 @@ async def get_ollama_models(request, models_service, session_manager):
async def get_ibm_models(request, models_service, session_manager):
"""Get available IBM Watson models"""
try:
# Get credentials from query parameters or request body if provided
if request.method == "POST":
body = await request.json()
api_key = body.get("api_key")
endpoint = body.get("endpoint")
project_id = body.get("project_id")
else:
query_params = dict(request.query_params)
api_key = query_params.get("api_key")
endpoint = query_params.get("endpoint")
project_id = query_params.get("project_id")
# Get parameters from query parameters if provided
query_params = dict(request.query_params)
endpoint = query_params.get("endpoint")
api_key = query_params.get("api_key")
project_id = query_params.get("project_id")
models = await models_service.get_ibm_models(
api_key=api_key,
endpoint=endpoint,
api_key=api_key,
project_id=project_id
)
return JSONResponse(models)
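On the client this maps to a plain GET with the credentials as query parameters (a minimal sketch matching the useGetIBMModelsQuery change above; values illustrative). Note that the removed code used POST specifically to keep credentials out of URLs, so this change trades that off for simplicity:
const url = new URL("/api/models/ibm", window.location.origin);
url.searchParams.set("endpoint", "https://us-south.ml.cloud.ibm.com");
url.searchParams.set("api_key", "your-api-key");
url.searchParams.set("project_id", "your-project-id");
const models = await fetch(url.toString()).then((r) => r.json());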

View file

@@ -69,7 +69,6 @@ def get_docling_tweaks(docling_preset: str = None) -> dict:
}
}
async def get_settings(request, session_manager):
"""Get application settings"""
try:
@@ -117,8 +116,7 @@ async def get_settings(request, session_manager):
if LANGFLOW_INGEST_FLOW_ID and openrag_config.edited:
try:
response = await clients.langflow_request(
"GET",
f"/api/v1/flows/{LANGFLOW_INGEST_FLOW_ID}"
"GET", f"/api/v1/flows/{LANGFLOW_INGEST_FLOW_ID}"
)
if response.status_code == 200:
flow_data = response.json()
@@ -135,31 +133,23 @@ async def get_settings(request, session_manager):
if flow_data.get("data", {}).get("nodes"):
for node in flow_data["data"]["nodes"]:
node_template = (
node.get("data", {})
.get("node", {})
.get("template", {})
node.get("data", {}).get("node", {}).get("template", {})
)
# Split Text component (SplitText-QIKhg)
if node.get("id") == "SplitText-QIKhg":
if node_template.get("chunk_size", {}).get(
"value"
):
ingestion_defaults["chunkSize"] = (
node_template["chunk_size"]["value"]
)
if node_template.get("chunk_overlap", {}).get(
"value"
):
ingestion_defaults["chunkOverlap"] = (
node_template["chunk_overlap"]["value"]
)
if node_template.get("separator", {}).get(
"value"
):
ingestion_defaults["separator"] = (
node_template["separator"]["value"]
)
if node_template.get("chunk_size", {}).get("value"):
ingestion_defaults["chunkSize"] = node_template[
"chunk_size"
]["value"]
if node_template.get("chunk_overlap", {}).get("value"):
ingestion_defaults["chunkOverlap"] = node_template[
"chunk_overlap"
]["value"]
if node_template.get("separator", {}).get("value"):
ingestion_defaults["separator"] = node_template[
"separator"
]["value"]
# OpenAI Embeddings component (OpenAIEmbeddings-joRJ6)
elif node.get("id") == "OpenAIEmbeddings-joRJ6":
@@ -191,89 +181,116 @@ async def update_settings(request, session_manager):
try:
# Get current configuration
current_config = get_openrag_config()
# Check if config is marked as edited
if not current_config.edited:
return JSONResponse(
{"error": "Configuration must be marked as edited before updates are allowed"},
status_code=403
{
"error": "Configuration must be marked as edited before updates are allowed"
},
status_code=403,
)
# Parse request body
body = await request.json()
# Validate allowed fields
allowed_fields = {
"llm_model", "system_prompt", "doclingPresets",
"chunk_size", "chunk_overlap"
"llm_model",
"system_prompt",
"ocr",
"picture_descriptions",
"chunk_size",
"chunk_overlap",
"doclingPresets",
}
# Check for invalid fields
invalid_fields = set(body.keys()) - allowed_fields
if invalid_fields:
return JSONResponse(
{"error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}"},
status_code=400
{
"error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}"
},
status_code=400,
)
# Update configuration
config_updated = False
# Update agent settings
if "llm_model" in body:
current_config.agent.llm_model = body["llm_model"]
config_updated = True
if "system_prompt" in body:
current_config.agent.system_prompt = body["system_prompt"]
config_updated = True
# Update knowledge settings
if "doclingPresets" in body:
preset_configs = get_docling_preset_configs()
valid_presets = list(preset_configs.keys())
if body["doclingPresets"] not in valid_presets:
return JSONResponse(
{"error": f"doclingPresets must be one of: {', '.join(valid_presets)}"},
status_code=400
{
"error": f"doclingPresets must be one of: {', '.join(valid_presets)}"
},
status_code=400,
)
current_config.knowledge.doclingPresets = body["doclingPresets"]
config_updated = True
if "ocr" in body:
if not isinstance(body["ocr"], bool):
return JSONResponse(
{"error": "ocr must be a boolean value"}, status_code=400
)
current_config.knowledge.ocr = body["ocr"]
config_updated = True
if "picture_descriptions" in body:
if not isinstance(body["picture_descriptions"], bool):
return JSONResponse(
{"error": "picture_descriptions must be a boolean value"},
status_code=400,
)
current_config.knowledge.picture_descriptions = body["picture_descriptions"]
config_updated = True
if "chunk_size" in body:
if not isinstance(body["chunk_size"], int) or body["chunk_size"] <= 0:
return JSONResponse(
{"error": "chunk_size must be a positive integer"},
status_code=400
{"error": "chunk_size must be a positive integer"}, status_code=400
)
current_config.knowledge.chunk_size = body["chunk_size"]
config_updated = True
if "chunk_overlap" in body:
if not isinstance(body["chunk_overlap"], int) or body["chunk_overlap"] < 0:
return JSONResponse(
{"error": "chunk_overlap must be a non-negative integer"},
status_code=400
{"error": "chunk_overlap must be a non-negative integer"},
status_code=400,
)
current_config.knowledge.chunk_overlap = body["chunk_overlap"]
config_updated = True
if not config_updated:
return JSONResponse(
{"error": "No valid fields provided for update"},
status_code=400
{"error": "No valid fields provided for update"}, status_code=400
)
# Save the updated configuration
if config_manager.save_config_file(current_config):
logger.info("Configuration updated successfully", updated_fields=list(body.keys()))
logger.info(
"Configuration updated successfully", updated_fields=list(body.keys())
)
return JSONResponse({"message": "Configuration updated successfully"})
else:
return JSONResponse(
{"error": "Failed to save configuration"},
status_code=500
{"error": "Failed to save configuration"}, status_code=500
)
except Exception as e:
logger.error("Failed to update settings", error=str(e))
return JSONResponse(
@@ -286,120 +303,168 @@ async def onboarding(request, flows_service):
try:
# Get current configuration
current_config = get_openrag_config()
# Check if config is NOT marked as edited (only allow onboarding if not yet configured)
if current_config.edited:
return JSONResponse(
{"error": "Configuration has already been edited. Use /settings endpoint for updates."},
status_code=403
{
"error": "Configuration has already been edited. Use /settings endpoint for updates."
},
status_code=403,
)
# Parse request body
body = await request.json()
# Validate allowed fields
allowed_fields = {
"model_provider", "api_key", "embedding_model", "llm_model", "sample_data"
"model_provider",
"api_key",
"embedding_model",
"llm_model",
"sample_data",
"endpoint",
"project_id",
}
# Check for invalid fields
invalid_fields = set(body.keys()) - allowed_fields
if invalid_fields:
return JSONResponse(
{"error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}"},
status_code=400
{
"error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}"
},
status_code=400,
)
# Update configuration
config_updated = False
# Update provider settings
if "model_provider" in body:
if not isinstance(body["model_provider"], str) or not body["model_provider"].strip():
if (
not isinstance(body["model_provider"], str)
or not body["model_provider"].strip()
):
return JSONResponse(
{"error": "model_provider must be a non-empty string"},
status_code=400
{"error": "model_provider must be a non-empty string"},
status_code=400,
)
current_config.provider.model_provider = body["model_provider"].strip()
config_updated = True
if "api_key" in body:
if not isinstance(body["api_key"], str):
return JSONResponse(
{"error": "api_key must be a string"},
status_code=400
{"error": "api_key must be a string"}, status_code=400
)
current_config.provider.api_key = body["api_key"]
config_updated = True
# Update knowledge settings
if "embedding_model" in body:
if not isinstance(body["embedding_model"], str) or not body["embedding_model"].strip():
if (
not isinstance(body["embedding_model"], str)
or not body["embedding_model"].strip()
):
return JSONResponse(
{"error": "embedding_model must be a non-empty string"},
status_code=400
{"error": "embedding_model must be a non-empty string"},
status_code=400,
)
current_config.knowledge.embedding_model = body["embedding_model"].strip()
config_updated = True
# Update agent settings
if "llm_model" in body:
if not isinstance(body["llm_model"], str) or not body["llm_model"].strip():
return JSONResponse(
{"error": "llm_model must be a non-empty string"},
status_code=400
{"error": "llm_model must be a non-empty string"}, status_code=400
)
current_config.agent.llm_model = body["llm_model"].strip()
config_updated = True
if "endpoint" in body:
if not isinstance(body["endpoint"], str) or not body["endpoint"].strip():
return JSONResponse(
{"error": "endpoint must be a non-empty string"}, status_code=400
)
current_config.provider.endpoint = body["endpoint"].strip()
config_updated = True
if "project_id" in body:
if (
not isinstance(body["project_id"], str)
or not body["project_id"].strip()
):
return JSONResponse(
{"error": "project_id must be a non-empty string"}, status_code=400
)
current_config.provider.project_id = body["project_id"].strip()
config_updated = True
# Handle sample_data (unused for now but validate)
if "sample_data" in body:
if not isinstance(body["sample_data"], bool):
return JSONResponse(
{"error": "sample_data must be a boolean value"},
status_code=400
{"error": "sample_data must be a boolean value"}, status_code=400
)
# Note: sample_data is accepted but not used as requested
if not config_updated:
return JSONResponse(
{"error": "No valid fields provided for update"},
status_code=400
{"error": "No valid fields provided for update"}, status_code=400
)
# Save the updated configuration (this will mark it as edited)
if config_manager.save_config_file(current_config):
updated_fields = [k for k in body.keys() if k != "sample_data"] # Exclude sample_data from log
logger.info("Onboarding configuration updated successfully", updated_fields=updated_fields)
updated_fields = [
k for k in body.keys() if k != "sample_data"
] # Exclude sample_data from log
logger.info(
"Onboarding configuration updated successfully",
updated_fields=updated_fields,
)
# If model_provider was updated, assign the new provider to flows
if "model_provider" in body:
provider = body["model_provider"].strip().lower()
try:
flow_result = await flows_service.assign_model_provider(provider)
if flow_result.get("success"):
logger.info(f"Successfully assigned {provider} to flows", flow_result=flow_result)
logger.info(
f"Successfully assigned {provider} to flows",
flow_result=flow_result,
)
else:
logger.warning(f"Failed to assign {provider} to flows", flow_result=flow_result)
logger.warning(
f"Failed to assign {provider} to flows",
flow_result=flow_result,
)
# Continue even if flow assignment fails - configuration was still saved
except Exception as e:
logger.error(f"Error assigning model provider to flows", provider=provider, error=str(e))
logger.error(
f"Error assigning model provider to flows",
provider=provider,
error=str(e),
)
# Continue even if flow assignment fails - configuration was still saved
return JSONResponse({
"message": "Onboarding configuration updated successfully",
"edited": True # Confirm that config is now marked as edited
})
return JSONResponse(
{
"message": "Onboarding configuration updated successfully",
"edited": True, # Confirm that config is now marked as edited
}
)
else:
return JSONResponse(
{"error": "Failed to save configuration"},
status_code=500
{"error": "Failed to save configuration"}, status_code=500
)
except Exception as e:
logger.error("Failed to update onboarding settings", error=str(e))
return JSONResponse(
{"error": f"Failed to update onboarding settings: {str(e)}"}, status_code=500
{"error": f"Failed to update onboarding settings: {str(e)}"},
status_code=500,
)

View file

@@ -1,5 +1,4 @@
import httpx
import os
from typing import Dict, List
from utils.logging_config import get_logger
@@ -12,14 +11,9 @@ class ModelsService:
def __init__(self):
self.session_manager = None
async def get_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
async def get_openai_models(self, api_key: str) -> Dict[str, List[Dict[str, str]]]:
"""Fetch available models from OpenAI API"""
try:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
logger.warning("OPENAI_API_KEY not set, using default models")
return self._get_default_openai_models()
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
@@ -74,12 +68,14 @@ class ModelsService:
"embedding_models": embedding_models,
}
else:
logger.warning(f"Failed to fetch OpenAI models: {response.status_code}")
return self._get_default_openai_models()
logger.error(f"Failed to fetch OpenAI models: {response.status_code}")
raise Exception(
f"OpenAI API returned status code {response.status_code}"
)
except Exception as e:
logger.error(f"Error fetching OpenAI models: {str(e)}")
return self._get_default_openai_models()
raise
async def get_ollama_models(
self, endpoint: str = None
@@ -87,9 +83,7 @@ class ModelsService:
"""Fetch available models from Ollama API"""
try:
# Use provided endpoint or default
ollama_url = endpoint or os.getenv(
"OLLAMA_BASE_URL", "http://localhost:11434"
)
ollama_url = endpoint
async with httpx.AsyncClient() as client:
response = await client.get(f"{ollama_url}/api/tags", timeout=10.0)
@@ -145,32 +139,34 @@ class ModelsService:
return {
"language_models": language_models,
"embedding_models": embedding_models
if embedding_models
else [
{
"value": "nomic-embed-text",
"label": "nomic-embed-text",
"default": True,
}
],
"embedding_models": embedding_models if embedding_models else [],
}
else:
logger.warning(f"Failed to fetch Ollama models: {response.status_code}")
return self._get_default_ollama_models()
logger.error(f"Failed to fetch Ollama models: {response.status_code}")
raise Exception(
f"Ollama API returned status code {response.status_code}"
)
except Exception as e:
logger.error(f"Error fetching Ollama models: {str(e)}")
return self._get_default_ollama_models()
raise
async def get_ibm_models(
self, endpoint: str = None
self, endpoint: str = None, api_key: str = None, project_id: str = None
) -> Dict[str, List[Dict[str, str]]]:
"""Fetch available models from IBM Watson API"""
try:
# Use provided endpoint or default
watson_endpoint = endpoint or os.getenv("IBM_WATSON_ENDPOINT", "https://us-south.ml.cloud.ibm.com")
watson_endpoint = endpoint
# Prepare headers for authentication
headers = {
"Content-Type": "application/json",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
if project_id:
headers["Project-ID"] = project_id
# Fetch foundation models using the correct endpoint
models_url = f"{watson_endpoint}/ml/v1/foundation_model_specs"
@@ -181,9 +177,14 @@
# Fetch text chat models
text_params = {
"version": "2024-09-16",
"filters": "function_text_chat,!lifecycle_withdrawn"
"filters": "function_text_chat,!lifecycle_withdrawn",
}
text_response = await client.get(models_url, params=text_params, timeout=10.0)
if project_id:
text_params["project_id"] = project_id
text_response = await client.get(
models_url, params=text_params, headers=headers, timeout=10.0
)
if text_response.status_code == 200:
text_data = text_response.json()
@@ -193,18 +194,25 @@
model_id = model.get("model_id", "")
model_name = model.get("name", model_id)
language_models.append({
"value": model_id,
"label": model_name or model_id,
"default": i == 0 # First model is default
})
language_models.append(
{
"value": model_id,
"label": model_name or model_id,
"default": i == 0, # First model is default
}
)
# Fetch embedding models
embed_params = {
"version": "2024-09-16",
"filters": "function_embedding,!lifecycle_withdrawn"
"filters": "function_embedding,!lifecycle_withdrawn",
}
embed_response = await client.get(models_url, params=embed_params, timeout=10.0)
if project_id:
embed_params["project_id"] = project_id
embed_response = await client.get(
models_url, params=embed_params, headers=headers, timeout=10.0
)
if embed_response.status_code == 200:
embed_data = embed_response.json()
@@ -214,104 +222,22 @@
model_id = model.get("model_id", "")
model_name = model.get("name", model_id)
embedding_models.append({
"value": model_id,
"label": model_name or model_id,
"default": i == 0 # First model is default
})
embedding_models.append(
{
"value": model_id,
"label": model_name or model_id,
"default": i == 0, # First model is default
}
)
if not language_models and not embedding_models:
raise Exception("No IBM models retrieved from API")
return {
"language_models": language_models if language_models else self._get_default_ibm_models()["language_models"],
"embedding_models": embedding_models if embedding_models else self._get_default_ibm_models()["embedding_models"]
"language_models": language_models,
"embedding_models": embedding_models,
}
except Exception as e:
logger.error(f"Error fetching IBM models: {str(e)}")
return self._get_default_ibm_models()
def _get_default_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
"""Default OpenAI models when API is not available"""
return {
"language_models": [
{"value": "gpt-4o-mini", "label": "gpt-4o-mini", "default": True},
{"value": "gpt-4o", "label": "gpt-4o", "default": False},
{"value": "gpt-4-turbo", "label": "gpt-4-turbo", "default": False},
{"value": "gpt-3.5-turbo", "label": "gpt-3.5-turbo", "default": False},
],
"embedding_models": [
{
"value": "text-embedding-3-small",
"label": "text-embedding-3-small",
"default": True,
},
{
"value": "text-embedding-3-large",
"label": "text-embedding-3-large",
"default": False,
},
{
"value": "text-embedding-ada-002",
"label": "text-embedding-ada-002",
"default": False,
},
],
}
def _get_default_ollama_models(self) -> Dict[str, List[Dict[str, str]]]:
"""Default Ollama models when API is not available"""
return {
"language_models": [
{"value": "llama3.2", "label": "llama3.2", "default": True},
{"value": "llama3.1", "label": "llama3.1", "default": False},
{"value": "llama3", "label": "llama3", "default": False},
{"value": "mistral", "label": "mistral", "default": False},
{"value": "codellama", "label": "codellama", "default": False},
],
"embedding_models": [
{
"value": "nomic-embed-text",
"label": "nomic-embed-text",
"default": True,
},
{"value": "all-minilm", "label": "all-minilm", "default": False},
],
}
def _get_default_ibm_models(self) -> Dict[str, List[Dict[str, str]]]:
"""Default IBM Watson models when API is not available"""
return {
"language_models": [
{
"value": "meta-llama/llama-3-1-70b-instruct",
"label": "Llama 3.1 70B Instruct",
"default": True,
},
{
"value": "meta-llama/llama-3-1-8b-instruct",
"label": "Llama 3.1 8B Instruct",
"default": False,
},
{
"value": "ibm/granite-13b-chat-v2",
"label": "Granite 13B Chat v2",
"default": False,
},
{
"value": "ibm/granite-13b-instruct-v2",
"label": "Granite 13B Instruct v2",
"default": False,
},
],
"embedding_models": [
{
"value": "ibm/slate-125m-english-rtrvr",
"label": "Slate 125M English Retriever",
"default": True,
},
{
"value": "sentence-transformers/all-minilm-l12-v2",
"label": "All-MiniLM L12 v2",
"default": False,
},
],
}
raise