Merge branch 'feat/onboarding' of github.com:langflow-ai/openrag into feat/onboarding
This commit is contained in:
commit
abb1ae0819
21 changed files with 957 additions and 628 deletions
|
|
@ -1,4 +1,4 @@
|
|||
import { useCallback, useRef } from "react";
|
||||
import { useCallback, useRef, useState, useEffect } from "react";
|
||||
|
||||
export function useDebounce<T extends (...args: never[]) => void>(
|
||||
callback: T,
|
||||
|
|
@ -21,3 +21,19 @@ export function useDebounce<T extends (...args: never[]) => void>(
|
|||
|
||||
return debouncedCallback;
|
||||
}
|
||||
|
||||
export function useDebouncedValue<T>(value: T, delay: number): T {
|
||||
const [debouncedValue, setDebouncedValue] = useState<T>(value);
|
||||
|
||||
useEffect(() => {
|
||||
const handler = setTimeout(() => {
|
||||
setDebouncedValue(value);
|
||||
}, delay);
|
||||
|
||||
return () => {
|
||||
clearTimeout(handler);
|
||||
};
|
||||
}, [value, delay]);
|
||||
|
||||
return debouncedValue;
|
||||
}
|
||||
|
|
|
|||
61
frontend/src/app/api/mutations/useOnboardingMutation.ts
Normal file
61
frontend/src/app/api/mutations/useOnboardingMutation.ts
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import {
|
||||
type UseMutationOptions,
|
||||
useMutation,
|
||||
useQueryClient,
|
||||
} from "@tanstack/react-query";
|
||||
|
||||
export interface OnboardingVariables {
|
||||
model_provider: string;
|
||||
api_key?: string;
|
||||
endpoint?: string;
|
||||
project_id?: string;
|
||||
embedding_model: string;
|
||||
llm_model: string;
|
||||
sample_data?: boolean;
|
||||
}
|
||||
|
||||
interface OnboardingResponse {
|
||||
message: string;
|
||||
edited: boolean;
|
||||
}
|
||||
|
||||
export const useOnboardingMutation = (
|
||||
options?: Omit<
|
||||
UseMutationOptions<
|
||||
OnboardingResponse,
|
||||
Error,
|
||||
OnboardingVariables
|
||||
>,
|
||||
"mutationFn"
|
||||
>,
|
||||
) => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
async function submitOnboarding(
|
||||
variables: OnboardingVariables,
|
||||
): Promise<OnboardingResponse> {
|
||||
const response = await fetch("/api/onboarding", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(variables),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.json();
|
||||
throw new Error(error.error || "Failed to complete onboarding");
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
return useMutation({
|
||||
mutationFn: submitOnboarding,
|
||||
onSettled: () => {
|
||||
// Invalidate settings query to refetch updated data
|
||||
queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
},
|
||||
...options,
|
||||
});
|
||||
};
|
||||
|
|
@ -15,23 +15,33 @@ export interface ModelsResponse {
|
|||
embedding_models: ModelOption[];
|
||||
}
|
||||
|
||||
export interface OpenAIModelsParams {
|
||||
apiKey?: string;
|
||||
}
|
||||
|
||||
export interface OllamaModelsParams {
|
||||
endpoint?: string;
|
||||
}
|
||||
|
||||
export interface IBMModelsParams {
|
||||
api_key?: string;
|
||||
endpoint?: string;
|
||||
project_id?: string;
|
||||
apiKey?: string;
|
||||
projectId?: string;
|
||||
}
|
||||
|
||||
export const useGetOpenAIModelsQuery = (
|
||||
params?: OpenAIModelsParams,
|
||||
options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
|
||||
) => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
async function getOpenAIModels(): Promise<ModelsResponse> {
|
||||
const response = await fetch("/api/models/openai");
|
||||
const url = new URL("/api/models/openai", window.location.origin);
|
||||
if (params?.apiKey) {
|
||||
url.searchParams.set("api_key", params.apiKey);
|
||||
}
|
||||
|
||||
const response = await fetch(url.toString());
|
||||
if (response.ok) {
|
||||
return await response.json();
|
||||
} else {
|
||||
|
|
@ -41,9 +51,12 @@ export const useGetOpenAIModelsQuery = (
|
|||
|
||||
const queryResult = useQuery(
|
||||
{
|
||||
queryKey: ["models", "openai"],
|
||||
queryKey: ["models", "openai", params],
|
||||
queryFn: getOpenAIModels,
|
||||
staleTime: 5 * 60 * 1000, // 5 minutes
|
||||
retry: 2,
|
||||
enabled: !!params?.apiKey, // Only run if API key is provided
|
||||
staleTime: 0, // Always fetch fresh data
|
||||
gcTime: 0, // Don't cache results
|
||||
...options,
|
||||
},
|
||||
queryClient,
|
||||
|
|
@ -76,8 +89,10 @@ export const useGetOllamaModelsQuery = (
|
|||
{
|
||||
queryKey: ["models", "ollama", params],
|
||||
queryFn: getOllamaModels,
|
||||
staleTime: 5 * 60 * 1000, // 5 minutes
|
||||
retry: 2,
|
||||
enabled: !!params?.endpoint, // Only run if endpoint is provided
|
||||
staleTime: 0, // Always fetch fresh data
|
||||
gcTime: 0, // Don't cache results
|
||||
...options,
|
||||
},
|
||||
queryClient,
|
||||
|
|
@ -93,35 +108,22 @@ export const useGetIBMModelsQuery = (
|
|||
const queryClient = useQueryClient();
|
||||
|
||||
async function getIBMModels(): Promise<ModelsResponse> {
|
||||
const url = "/api/models/ibm";
|
||||
const url = new URL("/api/models/ibm", window.location.origin);
|
||||
if (params?.endpoint) {
|
||||
url.searchParams.set("endpoint", params.endpoint);
|
||||
}
|
||||
if (params?.apiKey) {
|
||||
url.searchParams.set("api_key", params.apiKey);
|
||||
}
|
||||
if (params?.projectId) {
|
||||
url.searchParams.set("project_id", params.projectId);
|
||||
}
|
||||
|
||||
// If we have credentials, use POST to send them securely
|
||||
if (params?.api_key || params?.endpoint || params?.project_id) {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_key: params.api_key,
|
||||
endpoint: params.endpoint,
|
||||
project_id: params.project_id,
|
||||
}),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
return await response.json();
|
||||
} else {
|
||||
throw new Error("Failed to fetch IBM models");
|
||||
}
|
||||
const response = await fetch(url.toString());
|
||||
if (response.ok) {
|
||||
return await response.json();
|
||||
} else {
|
||||
// Use GET for default models
|
||||
const response = await fetch(url);
|
||||
if (response.ok) {
|
||||
return await response.json();
|
||||
} else {
|
||||
throw new Error("Failed to fetch IBM models");
|
||||
}
|
||||
throw new Error("Failed to fetch IBM models");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -129,8 +131,10 @@ export const useGetIBMModelsQuery = (
|
|||
{
|
||||
queryKey: ["models", "ibm", params],
|
||||
queryFn: getIBMModels,
|
||||
staleTime: 5 * 60 * 1000, // 5 minutes
|
||||
enabled: !!(params?.api_key && params?.endpoint && params?.project_id), // Only run if all credentials are provided
|
||||
retry: 2,
|
||||
enabled: !!params?.endpoint && !!params?.apiKey && !!params?.projectId, // Only run if all required params are provided
|
||||
staleTime: 0, // Always fetch fresh data
|
||||
gcTime: 0, // Don't cache results
|
||||
...options,
|
||||
},
|
||||
queryClient,
|
||||
|
|
|
|||
|
|
@ -63,3 +63,6 @@ export const useGetSettingsQuery = (
|
|||
|
||||
return queryResult;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ function LoginPageContent() {
|
|||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
const { data: settings } = useGetSettingsQuery({
|
||||
enabled: isAuthenticated,
|
||||
const { data: settings, isLoading: isSettingsLoading } = useGetSettingsQuery({
|
||||
enabled: isAuthenticated || isNoAuthMode,
|
||||
});
|
||||
|
||||
const redirect =
|
||||
|
|
@ -25,12 +25,19 @@ function LoginPageContent() {
|
|||
|
||||
// Redirect if already authenticated or in no-auth mode
|
||||
useEffect(() => {
|
||||
if (!isLoading && (isAuthenticated || isNoAuthMode)) {
|
||||
if (!isLoading && !isSettingsLoading && (isAuthenticated || isNoAuthMode)) {
|
||||
router.push(redirect);
|
||||
}
|
||||
}, [isLoading, isAuthenticated, isNoAuthMode, router, redirect]);
|
||||
}, [
|
||||
isLoading,
|
||||
isSettingsLoading,
|
||||
isAuthenticated,
|
||||
isNoAuthMode,
|
||||
router,
|
||||
redirect,
|
||||
]);
|
||||
|
||||
if (isLoading) {
|
||||
if (isLoading || isSettingsLoading) {
|
||||
return (
|
||||
<div className="min-h-screen flex items-center justify-center bg-background">
|
||||
<div className="flex flex-col items-center gap-4">
|
||||
|
|
|
|||
|
|
@ -32,8 +32,13 @@ export function AdvancedOnboarding({
|
|||
setSampleDataset: (dataset: boolean) => void;
|
||||
}) {
|
||||
const hasEmbeddingModels =
|
||||
embeddingModels && embeddingModel && setEmbeddingModel;
|
||||
const hasLanguageModels = languageModels && languageModel && setLanguageModel;
|
||||
embeddingModels !== undefined &&
|
||||
embeddingModel !== undefined &&
|
||||
setEmbeddingModel !== undefined;
|
||||
const hasLanguageModels =
|
||||
languageModels !== undefined &&
|
||||
languageModel !== undefined &&
|
||||
setLanguageModel !== undefined;
|
||||
return (
|
||||
<Accordion type="single" collapsible>
|
||||
<AccordionItem value="item-1">
|
||||
127
frontend/src/app/onboarding/components/ibm-onboarding.tsx
Normal file
127
frontend/src/app/onboarding/components/ibm-onboarding.tsx
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
import { useState } from "react";
|
||||
import { LabelInput } from "@/components/label-input";
|
||||
import IBMLogo from "@/components/logo/ibm-logo";
|
||||
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
|
||||
import { useGetIBMModelsQuery } from "../../api/queries/useGetModelsQuery";
|
||||
import { useModelSelection } from "../hooks/useModelSelection";
|
||||
import { useUpdateSettings } from "../hooks/useUpdateSettings";
|
||||
import { useDebouncedValue } from "@/lib/debounce";
|
||||
import { AdvancedOnboarding } from "./advanced";
|
||||
|
||||
export function IBMOnboarding({
|
||||
setSettings,
|
||||
sampleDataset,
|
||||
setSampleDataset,
|
||||
}: {
|
||||
setSettings: (settings: OnboardingVariables) => void;
|
||||
sampleDataset: boolean;
|
||||
setSampleDataset: (dataset: boolean) => void;
|
||||
}) {
|
||||
const [endpoint, setEndpoint] = useState("");
|
||||
const [apiKey, setApiKey] = useState("");
|
||||
const [projectId, setProjectId] = useState("");
|
||||
|
||||
const debouncedEndpoint = useDebouncedValue(endpoint, 500);
|
||||
const debouncedApiKey = useDebouncedValue(apiKey, 500);
|
||||
const debouncedProjectId = useDebouncedValue(projectId, 500);
|
||||
|
||||
// Fetch models from API when all credentials are provided
|
||||
const {
|
||||
data: modelsData,
|
||||
isLoading: isLoadingModels,
|
||||
error: modelsError,
|
||||
} = useGetIBMModelsQuery(
|
||||
debouncedEndpoint && debouncedApiKey && debouncedProjectId
|
||||
? {
|
||||
endpoint: debouncedEndpoint,
|
||||
apiKey: debouncedApiKey,
|
||||
projectId: debouncedProjectId,
|
||||
}
|
||||
: undefined,
|
||||
);
|
||||
|
||||
// Use custom hook for model selection logic
|
||||
const {
|
||||
languageModel,
|
||||
embeddingModel,
|
||||
setLanguageModel,
|
||||
setEmbeddingModel,
|
||||
languageModels,
|
||||
embeddingModels,
|
||||
} = useModelSelection(modelsData);
|
||||
const handleSampleDatasetChange = (dataset: boolean) => {
|
||||
setSampleDataset(dataset);
|
||||
};
|
||||
|
||||
// Update settings when values change
|
||||
useUpdateSettings(
|
||||
"ibm",
|
||||
{
|
||||
endpoint,
|
||||
apiKey,
|
||||
projectId,
|
||||
languageModel,
|
||||
embeddingModel,
|
||||
},
|
||||
setSettings,
|
||||
);
|
||||
return (
|
||||
<>
|
||||
<div className="space-y-4">
|
||||
<LabelInput
|
||||
label="watsonx.ai API Endpoint"
|
||||
helperText="The API endpoint for your watsonx.ai account."
|
||||
id="api-endpoint"
|
||||
required
|
||||
placeholder="https://us-south.ml.cloud.ibm.com"
|
||||
value={endpoint}
|
||||
onChange={(e) => setEndpoint(e.target.value)}
|
||||
/>
|
||||
<LabelInput
|
||||
label="IBM API key"
|
||||
helperText="The API key for your watsonx.ai account."
|
||||
id="api-key"
|
||||
required
|
||||
placeholder="your-api-key"
|
||||
value={apiKey}
|
||||
onChange={(e) => setApiKey(e.target.value)}
|
||||
/>
|
||||
<LabelInput
|
||||
label="IBM Project ID"
|
||||
helperText="The project ID for your watsonx.ai account."
|
||||
id="project-id"
|
||||
required
|
||||
placeholder="your-project-id"
|
||||
value={projectId}
|
||||
onChange={(e) => setProjectId(e.target.value)}
|
||||
/>
|
||||
{isLoadingModels && (
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Validating configuration...
|
||||
</p>
|
||||
)}
|
||||
{modelsError && (
|
||||
<p className="text-sm text-red-500">
|
||||
Invalid configuration or connection failed
|
||||
</p>
|
||||
)}
|
||||
{modelsData &&
|
||||
(modelsData.language_models?.length > 0 ||
|
||||
modelsData.embedding_models?.length > 0) && (
|
||||
<p className="text-sm text-green-600">Configuration is valid</p>
|
||||
)}
|
||||
</div>
|
||||
<AdvancedOnboarding
|
||||
icon={<IBMLogo className="w-4 h-4" />}
|
||||
languageModels={languageModels}
|
||||
embeddingModels={embeddingModels}
|
||||
languageModel={languageModel}
|
||||
embeddingModel={embeddingModel}
|
||||
sampleDataset={sampleDataset}
|
||||
setLanguageModel={setLanguageModel}
|
||||
setEmbeddingModel={setEmbeddingModel}
|
||||
setSampleDataset={handleSampleDatasetChange}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
import { CheckIcon, ChevronsUpDownIcon } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Command,
|
||||
|
|
@ -32,6 +32,11 @@ export function ModelSelector({
|
|||
onValueChange: (value: string) => void;
|
||||
}) {
|
||||
const [open, setOpen] = useState(false);
|
||||
useEffect(() => {
|
||||
if (value && !options.find((option) => option.value === value)) {
|
||||
onValueChange("");
|
||||
}
|
||||
}, [options, value, onValueChange]);
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<PopoverTrigger asChild>
|
||||
|
|
@ -39,6 +44,7 @@ export function ModelSelector({
|
|||
<Button
|
||||
variant="outline"
|
||||
role="combobox"
|
||||
disabled={options.length === 0}
|
||||
aria-expanded={open}
|
||||
className="w-full gap-2 justify-between font-normal text-sm"
|
||||
>
|
||||
|
|
@ -53,6 +59,8 @@ export function ModelSelector({
|
|||
</span>
|
||||
)}
|
||||
</div>
|
||||
) : options.length === 0 ? (
|
||||
"No models available"
|
||||
) : (
|
||||
"Select model..."
|
||||
)}
|
||||
135
frontend/src/app/onboarding/components/ollama-onboarding.tsx
Normal file
135
frontend/src/app/onboarding/components/ollama-onboarding.tsx
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
import { useState } from "react";
|
||||
import { LabelInput } from "@/components/label-input";
|
||||
import { LabelWrapper } from "@/components/label-wrapper";
|
||||
import OllamaLogo from "@/components/logo/ollama-logo";
|
||||
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
|
||||
import { useGetOllamaModelsQuery } from "../../api/queries/useGetModelsQuery";
|
||||
import { useModelSelection } from "../hooks/useModelSelection";
|
||||
import { useUpdateSettings } from "../hooks/useUpdateSettings";
|
||||
import { useDebouncedValue } from "@/lib/debounce";
|
||||
import { AdvancedOnboarding } from "./advanced";
|
||||
import { ModelSelector } from "./model-selector";
|
||||
|
||||
export function OllamaOnboarding({
|
||||
setSettings,
|
||||
sampleDataset,
|
||||
setSampleDataset,
|
||||
}: {
|
||||
setSettings: (settings: OnboardingVariables) => void;
|
||||
sampleDataset: boolean;
|
||||
setSampleDataset: (dataset: boolean) => void;
|
||||
}) {
|
||||
const [endpoint, setEndpoint] = useState("");
|
||||
const debouncedEndpoint = useDebouncedValue(endpoint, 500);
|
||||
|
||||
// Fetch models from API when endpoint is provided (debounced)
|
||||
const {
|
||||
data: modelsData,
|
||||
isLoading: isLoadingModels,
|
||||
error: modelsError,
|
||||
} = useGetOllamaModelsQuery(
|
||||
debouncedEndpoint ? { endpoint: debouncedEndpoint } : undefined,
|
||||
);
|
||||
|
||||
// Use custom hook for model selection logic
|
||||
const {
|
||||
languageModel,
|
||||
embeddingModel,
|
||||
setLanguageModel,
|
||||
setEmbeddingModel,
|
||||
languageModels,
|
||||
embeddingModels,
|
||||
} = useModelSelection(modelsData);
|
||||
|
||||
const handleSampleDatasetChange = (dataset: boolean) => {
|
||||
setSampleDataset(dataset);
|
||||
};
|
||||
|
||||
// Update settings when values change
|
||||
useUpdateSettings(
|
||||
"ollama",
|
||||
{
|
||||
endpoint,
|
||||
languageModel,
|
||||
embeddingModel,
|
||||
},
|
||||
setSettings,
|
||||
);
|
||||
|
||||
// Check validation state based on models query
|
||||
const isConnecting = debouncedEndpoint && isLoadingModels;
|
||||
const hasConnectionError = debouncedEndpoint && modelsError;
|
||||
const hasNoModels =
|
||||
modelsData &&
|
||||
!modelsData.language_models?.length &&
|
||||
!modelsData.embedding_models?.length;
|
||||
const isValidConnection =
|
||||
modelsData &&
|
||||
(modelsData.language_models?.length > 0 ||
|
||||
modelsData.embedding_models?.length > 0);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="space-y-1">
|
||||
<LabelInput
|
||||
label="Ollama Endpoint"
|
||||
helperText="The endpoint for your Ollama server."
|
||||
id="api-endpoint"
|
||||
required
|
||||
placeholder="http://localhost:11434"
|
||||
value={endpoint}
|
||||
onChange={(e) => setEndpoint(e.target.value)}
|
||||
/>
|
||||
{isConnecting && (
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Connecting to Ollama server...
|
||||
</p>
|
||||
)}
|
||||
{hasConnectionError && (
|
||||
<p className="text-sm text-red-500">
|
||||
Cannot connect to Ollama server. Please check the endpoint.
|
||||
</p>
|
||||
)}
|
||||
{hasNoModels && (
|
||||
<p className="text-sm text-yellow-600">
|
||||
No models found. Please install some models on your Ollama server.
|
||||
</p>
|
||||
)}
|
||||
{isValidConnection && (
|
||||
<p className="text-sm text-green-600">Connected successfully</p>
|
||||
)}
|
||||
</div>
|
||||
<LabelWrapper
|
||||
label="Embedding model"
|
||||
helperText="The embedding model for your Ollama server."
|
||||
id="embedding-model"
|
||||
required={true}
|
||||
>
|
||||
<ModelSelector
|
||||
options={embeddingModels}
|
||||
icon={<OllamaLogo className="w-4 h-4" />}
|
||||
value={embeddingModel}
|
||||
onValueChange={setEmbeddingModel}
|
||||
/>
|
||||
</LabelWrapper>
|
||||
<LabelWrapper
|
||||
label="Language model"
|
||||
helperText="The embedding model for your Ollama server."
|
||||
id="embedding-model"
|
||||
required={true}
|
||||
>
|
||||
<ModelSelector
|
||||
options={languageModels}
|
||||
icon={<OllamaLogo className="w-4 h-4" />}
|
||||
value={languageModel}
|
||||
onValueChange={setLanguageModel}
|
||||
/>
|
||||
</LabelWrapper>
|
||||
|
||||
<AdvancedOnboarding
|
||||
sampleDataset={sampleDataset}
|
||||
setSampleDataset={handleSampleDatasetChange}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
93
frontend/src/app/onboarding/components/openai-onboarding.tsx
Normal file
93
frontend/src/app/onboarding/components/openai-onboarding.tsx
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
import { useState } from "react";
|
||||
import { LabelInput } from "@/components/label-input";
|
||||
import OpenAILogo from "@/components/logo/openai-logo";
|
||||
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
|
||||
import { useGetOpenAIModelsQuery } from "../../api/queries/useGetModelsQuery";
|
||||
import { useModelSelection } from "../hooks/useModelSelection";
|
||||
import { useUpdateSettings } from "../hooks/useUpdateSettings";
|
||||
import { useDebouncedValue } from "@/lib/debounce";
|
||||
import { AdvancedOnboarding } from "./advanced";
|
||||
|
||||
export function OpenAIOnboarding({
|
||||
setSettings,
|
||||
sampleDataset,
|
||||
setSampleDataset,
|
||||
}: {
|
||||
setSettings: (settings: OnboardingVariables) => void;
|
||||
sampleDataset: boolean;
|
||||
setSampleDataset: (dataset: boolean) => void;
|
||||
}) {
|
||||
const [apiKey, setApiKey] = useState("");
|
||||
const debouncedApiKey = useDebouncedValue(apiKey, 500);
|
||||
|
||||
// Fetch models from API when API key is provided
|
||||
const {
|
||||
data: modelsData,
|
||||
isLoading: isLoadingModels,
|
||||
error: modelsError,
|
||||
} = useGetOpenAIModelsQuery(
|
||||
debouncedApiKey ? { apiKey: debouncedApiKey } : undefined,
|
||||
);
|
||||
// Use custom hook for model selection logic
|
||||
const {
|
||||
languageModel,
|
||||
embeddingModel,
|
||||
setLanguageModel,
|
||||
setEmbeddingModel,
|
||||
languageModels,
|
||||
embeddingModels,
|
||||
} = useModelSelection(modelsData);
|
||||
const handleSampleDatasetChange = (dataset: boolean) => {
|
||||
setSampleDataset(dataset);
|
||||
};
|
||||
|
||||
// Update settings when values change
|
||||
useUpdateSettings(
|
||||
"openai",
|
||||
{
|
||||
apiKey,
|
||||
languageModel,
|
||||
embeddingModel,
|
||||
},
|
||||
setSettings,
|
||||
);
|
||||
return (
|
||||
<>
|
||||
<div className="space-y-1">
|
||||
<LabelInput
|
||||
label="OpenAI API key"
|
||||
helperText="The API key for your OpenAI account."
|
||||
id="api-key"
|
||||
required
|
||||
placeholder="sk-..."
|
||||
value={apiKey}
|
||||
onChange={(e) => setApiKey(e.target.value)}
|
||||
/>
|
||||
{isLoadingModels && (
|
||||
<p className="text-sm text-muted-foreground">Validating API key...</p>
|
||||
)}
|
||||
{modelsError && (
|
||||
<p className="text-sm text-red-500">
|
||||
Invalid API key or configuration
|
||||
</p>
|
||||
)}
|
||||
{modelsData &&
|
||||
(modelsData.language_models?.length > 0 ||
|
||||
modelsData.embedding_models?.length > 0) && (
|
||||
<p className="text-sm text-green-600">Configuration is valid</p>
|
||||
)}
|
||||
</div>
|
||||
<AdvancedOnboarding
|
||||
icon={<OpenAILogo className="w-4 h-4" />}
|
||||
languageModels={languageModels}
|
||||
embeddingModels={embeddingModels}
|
||||
languageModel={languageModel}
|
||||
embeddingModel={embeddingModel}
|
||||
sampleDataset={sampleDataset}
|
||||
setLanguageModel={setLanguageModel}
|
||||
setSampleDataset={handleSampleDatasetChange}
|
||||
setEmbeddingModel={setEmbeddingModel}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
46
frontend/src/app/onboarding/hooks/useModelSelection.ts
Normal file
46
frontend/src/app/onboarding/hooks/useModelSelection.ts
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import { useState, useEffect } from "react";
|
||||
import type { ModelsResponse } from "../../api/queries/useGetModelsQuery";
|
||||
|
||||
export function useModelSelection(modelsData: ModelsResponse | undefined) {
|
||||
const [languageModel, setLanguageModel] = useState("");
|
||||
const [embeddingModel, setEmbeddingModel] = useState("");
|
||||
|
||||
// Update default selections when models are loaded
|
||||
useEffect(() => {
|
||||
if (modelsData) {
|
||||
const defaultLangModel = modelsData.language_models.find(
|
||||
(m) => m.default,
|
||||
);
|
||||
const defaultEmbedModel = modelsData.embedding_models.find(
|
||||
(m) => m.default,
|
||||
);
|
||||
|
||||
// Set language model: prefer default, fallback to first available
|
||||
if (!languageModel) {
|
||||
if (defaultLangModel) {
|
||||
setLanguageModel(defaultLangModel.value);
|
||||
} else if (modelsData.language_models.length > 0) {
|
||||
setLanguageModel(modelsData.language_models[0].value);
|
||||
}
|
||||
}
|
||||
|
||||
// Set embedding model: prefer default, fallback to first available
|
||||
if (!embeddingModel) {
|
||||
if (defaultEmbedModel) {
|
||||
setEmbeddingModel(defaultEmbedModel.value);
|
||||
} else if (modelsData.embedding_models.length > 0) {
|
||||
setEmbeddingModel(modelsData.embedding_models[0].value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [modelsData, languageModel, embeddingModel]);
|
||||
|
||||
return {
|
||||
languageModel,
|
||||
embeddingModel,
|
||||
setLanguageModel,
|
||||
setEmbeddingModel,
|
||||
languageModels: modelsData?.language_models || [],
|
||||
embeddingModels: modelsData?.embedding_models || [],
|
||||
};
|
||||
}
|
||||
58
frontend/src/app/onboarding/hooks/useUpdateSettings.ts
Normal file
58
frontend/src/app/onboarding/hooks/useUpdateSettings.ts
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
import { useEffect } from "react";
|
||||
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
|
||||
|
||||
interface ConfigValues {
|
||||
apiKey?: string;
|
||||
endpoint?: string;
|
||||
projectId?: string;
|
||||
languageModel?: string;
|
||||
embeddingModel?: string;
|
||||
}
|
||||
|
||||
export function useUpdateSettings(
|
||||
provider: string,
|
||||
config: ConfigValues,
|
||||
setSettings: (settings: OnboardingVariables) => void,
|
||||
) {
|
||||
useEffect(() => {
|
||||
const updatedSettings: OnboardingVariables = {
|
||||
model_provider: provider,
|
||||
embedding_model: "",
|
||||
llm_model: "",
|
||||
};
|
||||
|
||||
// Set language model if provided
|
||||
if (config.languageModel) {
|
||||
updatedSettings.llm_model = config.languageModel;
|
||||
}
|
||||
|
||||
// Set embedding model if provided
|
||||
if (config.embeddingModel) {
|
||||
updatedSettings.embedding_model = config.embeddingModel;
|
||||
}
|
||||
|
||||
// Set API key if provided
|
||||
if (config.apiKey) {
|
||||
updatedSettings.api_key = config.apiKey;
|
||||
}
|
||||
|
||||
// Set endpoint and project ID if provided
|
||||
if (config.endpoint) {
|
||||
updatedSettings.endpoint = config.endpoint;
|
||||
}
|
||||
|
||||
if (config.projectId) {
|
||||
updatedSettings.project_id = config.projectId;
|
||||
}
|
||||
|
||||
setSettings(updatedSettings);
|
||||
}, [
|
||||
provider,
|
||||
config.apiKey,
|
||||
config.endpoint,
|
||||
config.projectId,
|
||||
config.languageModel,
|
||||
config.embeddingModel,
|
||||
setSettings,
|
||||
]);
|
||||
}
|
||||
|
|
@ -1,110 +0,0 @@
|
|||
import { useState, useEffect } from "react";
|
||||
import { LabelInput } from "@/components/label-input";
|
||||
import IBMLogo from "@/components/logo/ibm-logo";
|
||||
import type { Settings } from "../api/queries/useGetSettingsQuery";
|
||||
import { useGetIBMModelsQuery } from "../api/queries/useGetModelsQuery";
|
||||
import { AdvancedOnboarding } from "./advanced";
|
||||
|
||||
export function IBMOnboarding({
|
||||
settings,
|
||||
setSettings,
|
||||
sampleDataset,
|
||||
setSampleDataset,
|
||||
}: {
|
||||
settings: Settings;
|
||||
setSettings: (settings: Settings) => void;
|
||||
sampleDataset: boolean;
|
||||
setSampleDataset: (dataset: boolean) => void;
|
||||
}) {
|
||||
const [endpoint, setEndpoint] = useState("");
|
||||
const [apiKey, setApiKey] = useState("");
|
||||
const [projectId, setProjectId] = useState("");
|
||||
const [languageModel, setLanguageModel] = useState("meta-llama/llama-3-1-70b-instruct");
|
||||
const [embeddingModel, setEmbeddingModel] = useState("ibm/slate-125m-english-rtrvr");
|
||||
|
||||
// Fetch models from API when all credentials are provided
|
||||
const { data: modelsData } = useGetIBMModelsQuery(
|
||||
(apiKey && endpoint && projectId) ? { api_key: apiKey, endpoint, project_id: projectId } : undefined,
|
||||
{ enabled: !!(apiKey && endpoint && projectId) }
|
||||
);
|
||||
|
||||
// Use fetched models or fallback to defaults
|
||||
const languageModels = modelsData?.language_models || [
|
||||
{ value: "meta-llama/llama-3-1-70b-instruct", label: "Llama 3.1 70B Instruct", default: true },
|
||||
{ value: "meta-llama/llama-3-1-8b-instruct", label: "Llama 3.1 8B Instruct" },
|
||||
{ value: "ibm/granite-13b-chat-v2", label: "Granite 13B Chat v2" },
|
||||
{ value: "ibm/granite-13b-instruct-v2", label: "Granite 13B Instruct v2" },
|
||||
];
|
||||
const embeddingModels = modelsData?.embedding_models || [
|
||||
{ value: "ibm/slate-125m-english-rtrvr", label: "Slate 125M English Retriever", default: true },
|
||||
{ value: "sentence-transformers/all-minilm-l12-v2", label: "All-MiniLM L12 v2" },
|
||||
];
|
||||
|
||||
// Update default selections when models are loaded
|
||||
useEffect(() => {
|
||||
if (modelsData) {
|
||||
const defaultLangModel = modelsData.language_models.find(m => m.default);
|
||||
const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
|
||||
|
||||
if (defaultLangModel) {
|
||||
setLanguageModel(defaultLangModel.value);
|
||||
}
|
||||
if (defaultEmbedModel) {
|
||||
setEmbeddingModel(defaultEmbedModel.value);
|
||||
}
|
||||
}
|
||||
}, [modelsData]);
|
||||
const handleLanguageModelChange = (model: string) => {
|
||||
setLanguageModel(model);
|
||||
};
|
||||
|
||||
const handleEmbeddingModelChange = (model: string) => {
|
||||
setEmbeddingModel(model);
|
||||
};
|
||||
|
||||
const handleSampleDatasetChange = (dataset: boolean) => {
|
||||
setSampleDataset(dataset);
|
||||
};
|
||||
return (
|
||||
<>
|
||||
<LabelInput
|
||||
label="watsonx.ai API Endpoint"
|
||||
helperText="The API endpoint for your watsonx.ai account."
|
||||
id="api-endpoint"
|
||||
required
|
||||
placeholder="https://us-south.ml.cloud.ibm.com"
|
||||
value={endpoint}
|
||||
onChange={(e) => setEndpoint(e.target.value)}
|
||||
/>
|
||||
<LabelInput
|
||||
label="IBM API key"
|
||||
helperText="The API key for your watsonx.ai account."
|
||||
id="api-key"
|
||||
required
|
||||
placeholder="your-api-key"
|
||||
value={apiKey}
|
||||
onChange={(e) => setApiKey(e.target.value)}
|
||||
/>
|
||||
<LabelInput
|
||||
label="IBM Project ID"
|
||||
helperText="The project ID for your watsonx.ai account."
|
||||
id="project-id"
|
||||
required
|
||||
placeholder="your-project-id"
|
||||
value={projectId}
|
||||
onChange={(e) => setProjectId(e.target.value)}
|
||||
/>
|
||||
<AdvancedOnboarding
|
||||
icon={<IBMLogo className="w-4 h-4" />}
|
||||
languageModels={languageModels}
|
||||
embeddingModels={embeddingModels}
|
||||
languageModel={languageModel}
|
||||
embeddingModel={embeddingModel}
|
||||
sampleDataset={sampleDataset}
|
||||
setLanguageModel={handleLanguageModelChange}
|
||||
setEmbeddingModel={handleEmbeddingModelChange}
|
||||
setSampleDataset={handleSampleDatasetChange}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,105 +0,0 @@
|
|||
import { useState, useEffect } from "react";
|
||||
import { LabelInput } from "@/components/label-input";
|
||||
import { LabelWrapper } from "@/components/label-wrapper";
|
||||
import OllamaLogo from "@/components/logo/ollama-logo";
|
||||
import type { Settings } from "../api/queries/useGetSettingsQuery";
|
||||
import { useGetOllamaModelsQuery } from "../api/queries/useGetModelsQuery";
|
||||
import { AdvancedOnboarding } from "./advanced";
|
||||
import { ModelSelector } from "./model-selector";
|
||||
|
||||
export function OllamaOnboarding({
|
||||
settings,
|
||||
setSettings,
|
||||
sampleDataset,
|
||||
setSampleDataset,
|
||||
}: {
|
||||
settings: Settings;
|
||||
setSettings: (settings: Settings) => void;
|
||||
sampleDataset: boolean;
|
||||
setSampleDataset: (dataset: boolean) => void;
|
||||
}) {
|
||||
const [endpoint, setEndpoint] = useState("");
|
||||
const [languageModel, setLanguageModel] = useState("llama3.2");
|
||||
const [embeddingModel, setEmbeddingModel] = useState("nomic-embed-text");
|
||||
|
||||
// Fetch models from API when endpoint is provided
|
||||
const { data: modelsData } = useGetOllamaModelsQuery(
|
||||
endpoint ? { endpoint } : undefined,
|
||||
{ enabled: !!endpoint }
|
||||
);
|
||||
|
||||
// Use fetched models or fallback to defaults
|
||||
const languageModels = modelsData?.language_models || [
|
||||
{ value: "llama3.2", label: "llama3.2", default: true },
|
||||
{ value: "llama3.1", label: "llama3.1" },
|
||||
{ value: "llama3", label: "llama3" },
|
||||
{ value: "mistral", label: "mistral" },
|
||||
{ value: "codellama", label: "codellama" },
|
||||
];
|
||||
const embeddingModels = modelsData?.embedding_models || [
|
||||
{ value: "nomic-embed-text", label: "nomic-embed-text", default: true },
|
||||
];
|
||||
|
||||
// Update default selections when models are loaded
|
||||
useEffect(() => {
|
||||
if (modelsData) {
|
||||
const defaultLangModel = modelsData.language_models.find(m => m.default);
|
||||
const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
|
||||
|
||||
if (defaultLangModel) {
|
||||
setLanguageModel(defaultLangModel.value);
|
||||
}
|
||||
if (defaultEmbedModel) {
|
||||
setEmbeddingModel(defaultEmbedModel.value);
|
||||
}
|
||||
}
|
||||
}, [modelsData]);
|
||||
|
||||
const handleSampleDatasetChange = (dataset: boolean) => {
|
||||
setSampleDataset(dataset);
|
||||
};
|
||||
return (
|
||||
<>
|
||||
<LabelInput
|
||||
label="Ollama Endpoint"
|
||||
helperText="The endpoint for your Ollama server."
|
||||
id="api-endpoint"
|
||||
required
|
||||
placeholder="http://localhost:11434"
|
||||
value={endpoint}
|
||||
onChange={(e) => setEndpoint(e.target.value)}
|
||||
/>
|
||||
<LabelWrapper
|
||||
label="Embedding model"
|
||||
helperText="The embedding model for your Ollama server."
|
||||
id="embedding-model"
|
||||
required={true}
|
||||
>
|
||||
<ModelSelector
|
||||
options={embeddingModels}
|
||||
icon={<OllamaLogo className="w-4 h-4" />}
|
||||
value={embeddingModel}
|
||||
onValueChange={setEmbeddingModel}
|
||||
/>
|
||||
</LabelWrapper>
|
||||
<LabelWrapper
|
||||
label="Language model"
|
||||
helperText="The embedding model for your Ollama server."
|
||||
id="embedding-model"
|
||||
required={true}
|
||||
>
|
||||
<ModelSelector
|
||||
options={languageModels}
|
||||
icon={<OllamaLogo className="w-4 h-4" />}
|
||||
value={languageModel}
|
||||
onValueChange={setLanguageModel}
|
||||
/>
|
||||
</LabelWrapper>
|
||||
|
||||
<AdvancedOnboarding
|
||||
sampleDataset={sampleDataset}
|
||||
setSampleDataset={handleSampleDatasetChange}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
import { useState, useEffect } from "react";
|
||||
import { LabelInput } from "@/components/label-input";
|
||||
import OpenAILogo from "@/components/logo/openai-logo";
|
||||
import type { Settings } from "../api/queries/useGetSettingsQuery";
|
||||
import { useGetOpenAIModelsQuery } from "../api/queries/useGetModelsQuery";
|
||||
import { AdvancedOnboarding } from "./advanced";
|
||||
|
||||
export function OpenAIOnboarding({
|
||||
settings,
|
||||
setSettings,
|
||||
sampleDataset,
|
||||
setSampleDataset,
|
||||
}: {
|
||||
settings: Settings;
|
||||
setSettings: (settings: Settings) => void;
|
||||
sampleDataset: boolean;
|
||||
setSampleDataset: (dataset: boolean) => void;
|
||||
}) {
|
||||
const [languageModel, setLanguageModel] = useState("gpt-4o-mini");
|
||||
const [embeddingModel, setEmbeddingModel] = useState(
|
||||
"text-embedding-3-small",
|
||||
);
|
||||
|
||||
// Fetch models from API
|
||||
const { data: modelsData } = useGetOpenAIModelsQuery();
|
||||
|
||||
// Use fetched models or fallback to defaults
|
||||
const languageModels = modelsData?.language_models || [{ value: "gpt-4o-mini", label: "gpt-4o-mini", default: true }];
|
||||
const embeddingModels = modelsData?.embedding_models || [
|
||||
{ value: "text-embedding-3-small", label: "text-embedding-3-small", default: true },
|
||||
];
|
||||
|
||||
// Update default selections when models are loaded
|
||||
useEffect(() => {
|
||||
if (modelsData) {
|
||||
const defaultLangModel = modelsData.language_models.find(m => m.default);
|
||||
const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
|
||||
|
||||
if (defaultLangModel && languageModel === "gpt-4o-mini") {
|
||||
setLanguageModel(defaultLangModel.value);
|
||||
}
|
||||
if (defaultEmbedModel && embeddingModel === "text-embedding-3-small") {
|
||||
setEmbeddingModel(defaultEmbedModel.value);
|
||||
}
|
||||
}
|
||||
}, [modelsData, languageModel, embeddingModel]);
|
||||
const handleLanguageModelChange = (model: string) => {
|
||||
setLanguageModel(model);
|
||||
};
|
||||
|
||||
const handleEmbeddingModelChange = (model: string) => {
|
||||
setEmbeddingModel(model);
|
||||
};
|
||||
|
||||
const handleSampleDatasetChange = (dataset: boolean) => {
|
||||
setSampleDataset(dataset);
|
||||
};
|
||||
return (
|
||||
<>
|
||||
<LabelInput
|
||||
label="OpenAI API key"
|
||||
helperText="The API key for your OpenAI account."
|
||||
id="api-key"
|
||||
required
|
||||
placeholder="sk-..."
|
||||
/>
|
||||
<AdvancedOnboarding
|
||||
icon={<OpenAILogo className="w-4 h-4" />}
|
||||
languageModels={languageModels}
|
||||
embeddingModels={embeddingModels}
|
||||
languageModel={languageModel}
|
||||
embeddingModel={embeddingModel}
|
||||
sampleDataset={sampleDataset}
|
||||
setLanguageModel={handleLanguageModelChange}
|
||||
setSampleDataset={handleSampleDatasetChange}
|
||||
setEmbeddingModel={handleEmbeddingModelChange}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
"use client";
|
||||
|
||||
import { Suspense, useState } from "react";
|
||||
import { Suspense, useEffect, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { useUpdateFlowSettingMutation } from "@/app/api/mutations/useUpdateFlowSettingMutation";
|
||||
import {
|
||||
type Settings,
|
||||
useGetSettingsQuery,
|
||||
} from "@/app/api/queries/useGetSettingsQuery";
|
||||
useOnboardingMutation,
|
||||
type OnboardingVariables,
|
||||
} from "@/app/api/mutations/useOnboardingMutation";
|
||||
import IBMLogo from "@/components/logo/ibm-logo";
|
||||
import OllamaLogo from "@/components/logo/ollama-logo";
|
||||
import OpenAILogo from "@/components/logo/openai-logo";
|
||||
|
|
@ -19,44 +18,102 @@ import {
|
|||
CardHeader,
|
||||
} from "@/components/ui/card";
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { useAuth } from "@/contexts/auth-context";
|
||||
import { IBMOnboarding } from "./ibm-onboarding";
|
||||
import { OllamaOnboarding } from "./ollama-onboarding";
|
||||
import { OpenAIOnboarding } from "./openai-onboarding";
|
||||
import { IBMOnboarding } from "./components/ibm-onboarding";
|
||||
import { OllamaOnboarding } from "./components/ollama-onboarding";
|
||||
import { OpenAIOnboarding } from "./components/openai-onboarding";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { useGetSettingsQuery } from "../api/queries/useGetSettingsQuery";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
function OnboardingPage() {
|
||||
const { isAuthenticated } = useAuth();
|
||||
const { data: settingsDb, isLoading: isSettingsLoading } =
|
||||
useGetSettingsQuery();
|
||||
|
||||
const redirect = "/";
|
||||
|
||||
const router = useRouter();
|
||||
|
||||
// Redirect if already authenticated or in no-auth mode
|
||||
useEffect(() => {
|
||||
if (!isSettingsLoading && settingsDb && settingsDb.edited) {
|
||||
router.push(redirect);
|
||||
}
|
||||
}, [isSettingsLoading, redirect]);
|
||||
|
||||
const [modelProvider, setModelProvider] = useState<string>("openai");
|
||||
|
||||
const [sampleDataset, setSampleDataset] = useState<boolean>(false);
|
||||
// Fetch settings using React Query
|
||||
const { data: settingsDb = {} } = useGetSettingsQuery({
|
||||
enabled: isAuthenticated,
|
||||
const [sampleDataset, setSampleDataset] = useState<boolean>(true);
|
||||
|
||||
const handleSetModelProvider = (provider: string) => {
|
||||
setModelProvider(provider);
|
||||
setSettings({
|
||||
model_provider: provider,
|
||||
embedding_model: "",
|
||||
llm_model: "",
|
||||
});
|
||||
};
|
||||
|
||||
const [settings, setSettings] = useState<OnboardingVariables>({
|
||||
model_provider: modelProvider,
|
||||
embedding_model: "",
|
||||
llm_model: "",
|
||||
});
|
||||
|
||||
const [settings, setSettings] = useState<Settings>(settingsDb);
|
||||
|
||||
// Mutations
|
||||
const updateFlowSettingMutation = useUpdateFlowSettingMutation({
|
||||
onSuccess: () => {
|
||||
console.log("Setting updated successfully");
|
||||
const onboardingMutation = useOnboardingMutation({
|
||||
onSuccess: (data) => {
|
||||
toast.success("Onboarding completed successfully!");
|
||||
console.log("Onboarding completed successfully", data);
|
||||
},
|
||||
onError: (error) => {
|
||||
toast.error("Failed to update settings", {
|
||||
toast.error("Failed to complete onboarding", {
|
||||
description: error.message,
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const handleComplete = () => {
|
||||
updateFlowSettingMutation.mutate({
|
||||
llm_model: settings.agent?.llm_model,
|
||||
embedding_model: settings.knowledge?.embedding_model,
|
||||
system_prompt: settings.agent?.system_prompt,
|
||||
});
|
||||
if (
|
||||
!settings.model_provider ||
|
||||
!settings.llm_model ||
|
||||
!settings.embedding_model
|
||||
) {
|
||||
toast.error("Please complete all required fields");
|
||||
return;
|
||||
}
|
||||
|
||||
// Prepare onboarding data
|
||||
const onboardingData: OnboardingVariables = {
|
||||
model_provider: settings.model_provider,
|
||||
llm_model: settings.llm_model,
|
||||
embedding_model: settings.embedding_model,
|
||||
sample_data: sampleDataset,
|
||||
};
|
||||
|
||||
// Add API key if available
|
||||
if (settings.api_key) {
|
||||
onboardingData.api_key = settings.api_key;
|
||||
}
|
||||
|
||||
// Add endpoint if available
|
||||
if (settings.endpoint) {
|
||||
onboardingData.endpoint = settings.endpoint;
|
||||
}
|
||||
|
||||
// Add project_id if available
|
||||
if (settings.project_id) {
|
||||
onboardingData.project_id = settings.project_id;
|
||||
}
|
||||
|
||||
onboardingMutation.mutate(onboardingData);
|
||||
};
|
||||
|
||||
const isComplete = !!settings.llm_model && !!settings.embedding_model;
|
||||
|
||||
return (
|
||||
<div
|
||||
className="min-h-dvh w-full flex gap-5 flex-col items-center justify-center bg-background p-4"
|
||||
|
|
@ -74,7 +131,10 @@ function OnboardingPage() {
|
|||
<p className="text-sm text-muted-foreground">[description of task]</p>
|
||||
</div>
|
||||
<Card className="w-full max-w-[580px]">
|
||||
<Tabs defaultValue={modelProvider} onValueChange={setModelProvider}>
|
||||
<Tabs
|
||||
defaultValue={modelProvider}
|
||||
onValueChange={handleSetModelProvider}
|
||||
>
|
||||
<CardHeader>
|
||||
<TabsList>
|
||||
<TabsTrigger value="openai">
|
||||
|
|
@ -94,7 +154,6 @@ function OnboardingPage() {
|
|||
<CardContent>
|
||||
<TabsContent value="openai">
|
||||
<OpenAIOnboarding
|
||||
settings={settings}
|
||||
setSettings={setSettings}
|
||||
sampleDataset={sampleDataset}
|
||||
setSampleDataset={setSampleDataset}
|
||||
|
|
@ -102,7 +161,6 @@ function OnboardingPage() {
|
|||
</TabsContent>
|
||||
<TabsContent value="watsonx">
|
||||
<IBMOnboarding
|
||||
settings={settings}
|
||||
setSettings={setSettings}
|
||||
sampleDataset={sampleDataset}
|
||||
setSampleDataset={setSampleDataset}
|
||||
|
|
@ -110,7 +168,6 @@ function OnboardingPage() {
|
|||
</TabsContent>
|
||||
<TabsContent value="ollama">
|
||||
<OllamaOnboarding
|
||||
settings={settings}
|
||||
setSettings={setSettings}
|
||||
sampleDataset={sampleDataset}
|
||||
setSampleDataset={setSampleDataset}
|
||||
|
|
@ -119,9 +176,21 @@ function OnboardingPage() {
|
|||
</CardContent>
|
||||
</Tabs>
|
||||
<CardFooter className="flex justify-end">
|
||||
<Button size="sm" onClick={handleComplete}>
|
||||
Complete
|
||||
</Button>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={handleComplete}
|
||||
disabled={!isComplete}
|
||||
loading={onboardingMutation.isPending}
|
||||
>
|
||||
Complete
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{!isComplete ? "Please fill in all required fields" : ""}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</CardFooter>
|
||||
</Card>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -15,15 +15,16 @@ import { useKnowledgeFilter } from "@/contexts/knowledge-filter-context";
|
|||
// import { GitHubStarButton } from "@/components/github-star-button"
|
||||
// import { DiscordLink } from "@/components/discord-link"
|
||||
import { useTask } from "@/contexts/task-context";
|
||||
import Logo from "@/components/logo/logo";
|
||||
|
||||
export function LayoutWrapper({ children }: { children: React.ReactNode }) {
|
||||
const pathname = usePathname();
|
||||
const { tasks, isMenuOpen, toggleMenu } = useTask();
|
||||
const { selectedFilter, setSelectedFilter, isPanelOpen } =
|
||||
useKnowledgeFilter();
|
||||
const { isLoading, isAuthenticated } = useAuth();
|
||||
const { isLoading, isAuthenticated, isNoAuthMode } = useAuth();
|
||||
const { isLoading: isSettingsLoading, data: settings } = useGetSettingsQuery({
|
||||
enabled: isAuthenticated,
|
||||
enabled: isAuthenticated || isNoAuthMode,
|
||||
});
|
||||
|
||||
// List of paths that should not show navigation
|
||||
|
|
@ -62,7 +63,7 @@ export function LayoutWrapper({ children }: { children: React.ReactNode }) {
|
|||
<div className="header-start-display px-4">
|
||||
{/* Logo/Title */}
|
||||
<div className="flex items-center gap-2">
|
||||
<Image src="/logo.svg" alt="OpenRAG Logo" width={24} height={22} />
|
||||
<Logo className="fill-primary" width={24} height={22} />
|
||||
<span className="text-lg font-semibold">OpenRAG</span>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) {
|
|||
const { isLoading, isAuthenticated, isNoAuthMode } = useAuth();
|
||||
const { data: settings = {}, isLoading: isSettingsLoading } =
|
||||
useGetSettingsQuery({
|
||||
enabled: isAuthenticated,
|
||||
enabled: isAuthenticated || isNoAuthMode,
|
||||
});
|
||||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
|
|
@ -31,12 +31,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) {
|
|||
);
|
||||
|
||||
useEffect(() => {
|
||||
// In no-auth mode, allow access without authentication
|
||||
if (isNoAuthMode) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isLoading && !isAuthenticated) {
|
||||
if (!isLoading && !isSettingsLoading && !isAuthenticated && !isNoAuthMode) {
|
||||
// Redirect to login with current path as redirect parameter
|
||||
const redirectUrl = `/login?redirect=${encodeURIComponent(pathname)}`;
|
||||
router.push(redirectUrl);
|
||||
|
|
@ -48,6 +43,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) {
|
|||
}
|
||||
}, [
|
||||
isLoading,
|
||||
isSettingsLoading,
|
||||
isAuthenticated,
|
||||
isNoAuthMode,
|
||||
router,
|
||||
|
|
@ -57,7 +53,7 @@ export function ProtectedRoute({ children }: ProtectedRouteProps) {
|
|||
]);
|
||||
|
||||
// Show loading state while checking authentication
|
||||
if (isLoading) {
|
||||
if (isLoading || isSettingsLoading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center h-64">
|
||||
<div className="flex flex-col items-center gap-4">
|
||||
|
|
|
|||
|
|
@ -7,7 +7,17 @@ logger = get_logger(__name__)
|
|||
async def get_openai_models(request, models_service, session_manager):
|
||||
"""Get available OpenAI models"""
|
||||
try:
|
||||
models = await models_service.get_openai_models()
|
||||
# Get API key from query parameters
|
||||
query_params = dict(request.query_params)
|
||||
api_key = query_params.get("api_key")
|
||||
|
||||
if not api_key:
|
||||
return JSONResponse(
|
||||
{"error": "OpenAI API key is required as query parameter"},
|
||||
status_code=400
|
||||
)
|
||||
|
||||
models = await models_service.get_openai_models(api_key=api_key)
|
||||
return JSONResponse(models)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get OpenAI models: {str(e)}")
|
||||
|
|
@ -37,21 +47,15 @@ async def get_ollama_models(request, models_service, session_manager):
|
|||
async def get_ibm_models(request, models_service, session_manager):
|
||||
"""Get available IBM Watson models"""
|
||||
try:
|
||||
# Get credentials from query parameters or request body if provided
|
||||
if request.method == "POST":
|
||||
body = await request.json()
|
||||
api_key = body.get("api_key")
|
||||
endpoint = body.get("endpoint")
|
||||
project_id = body.get("project_id")
|
||||
else:
|
||||
query_params = dict(request.query_params)
|
||||
api_key = query_params.get("api_key")
|
||||
endpoint = query_params.get("endpoint")
|
||||
project_id = query_params.get("project_id")
|
||||
# Get parameters from query parameters if provided
|
||||
query_params = dict(request.query_params)
|
||||
endpoint = query_params.get("endpoint")
|
||||
api_key = query_params.get("api_key")
|
||||
project_id = query_params.get("project_id")
|
||||
|
||||
models = await models_service.get_ibm_models(
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
api_key=api_key,
|
||||
project_id=project_id
|
||||
)
|
||||
return JSONResponse(models)
|
||||
|
|
|
|||
|
|
@ -69,7 +69,6 @@ def get_docling_tweaks(docling_preset: str = None) -> dict:
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
async def get_settings(request, session_manager):
|
||||
"""Get application settings"""
|
||||
try:
|
||||
|
|
@ -117,8 +116,7 @@ async def get_settings(request, session_manager):
|
|||
if LANGFLOW_INGEST_FLOW_ID and openrag_config.edited:
|
||||
try:
|
||||
response = await clients.langflow_request(
|
||||
"GET",
|
||||
f"/api/v1/flows/{LANGFLOW_INGEST_FLOW_ID}"
|
||||
"GET", f"/api/v1/flows/{LANGFLOW_INGEST_FLOW_ID}"
|
||||
)
|
||||
if response.status_code == 200:
|
||||
flow_data = response.json()
|
||||
|
|
@ -135,31 +133,23 @@ async def get_settings(request, session_manager):
|
|||
if flow_data.get("data", {}).get("nodes"):
|
||||
for node in flow_data["data"]["nodes"]:
|
||||
node_template = (
|
||||
node.get("data", {})
|
||||
.get("node", {})
|
||||
.get("template", {})
|
||||
node.get("data", {}).get("node", {}).get("template", {})
|
||||
)
|
||||
|
||||
# Split Text component (SplitText-QIKhg)
|
||||
if node.get("id") == "SplitText-QIKhg":
|
||||
if node_template.get("chunk_size", {}).get(
|
||||
"value"
|
||||
):
|
||||
ingestion_defaults["chunkSize"] = (
|
||||
node_template["chunk_size"]["value"]
|
||||
)
|
||||
if node_template.get("chunk_overlap", {}).get(
|
||||
"value"
|
||||
):
|
||||
ingestion_defaults["chunkOverlap"] = (
|
||||
node_template["chunk_overlap"]["value"]
|
||||
)
|
||||
if node_template.get("separator", {}).get(
|
||||
"value"
|
||||
):
|
||||
ingestion_defaults["separator"] = (
|
||||
node_template["separator"]["value"]
|
||||
)
|
||||
if node_template.get("chunk_size", {}).get("value"):
|
||||
ingestion_defaults["chunkSize"] = node_template[
|
||||
"chunk_size"
|
||||
]["value"]
|
||||
if node_template.get("chunk_overlap", {}).get("value"):
|
||||
ingestion_defaults["chunkOverlap"] = node_template[
|
||||
"chunk_overlap"
|
||||
]["value"]
|
||||
if node_template.get("separator", {}).get("value"):
|
||||
ingestion_defaults["separator"] = node_template[
|
||||
"separator"
|
||||
]["value"]
|
||||
|
||||
# OpenAI Embeddings component (OpenAIEmbeddings-joRJ6)
|
||||
elif node.get("id") == "OpenAIEmbeddings-joRJ6":
|
||||
|
|
@ -191,89 +181,116 @@ async def update_settings(request, session_manager):
|
|||
try:
|
||||
# Get current configuration
|
||||
current_config = get_openrag_config()
|
||||
|
||||
|
||||
# Check if config is marked as edited
|
||||
if not current_config.edited:
|
||||
return JSONResponse(
|
||||
{"error": "Configuration must be marked as edited before updates are allowed"},
|
||||
status_code=403
|
||||
{
|
||||
"error": "Configuration must be marked as edited before updates are allowed"
|
||||
},
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
|
||||
# Parse request body
|
||||
body = await request.json()
|
||||
|
||||
|
||||
# Validate allowed fields
|
||||
allowed_fields = {
|
||||
"llm_model", "system_prompt", "doclingPresets",
|
||||
"chunk_size", "chunk_overlap"
|
||||
"llm_model",
|
||||
"system_prompt",
|
||||
"ocr",
|
||||
"picture_descriptions",
|
||||
"chunk_size",
|
||||
"chunk_overlap",
|
||||
"doclingPresets",
|
||||
}
|
||||
|
||||
|
||||
# Check for invalid fields
|
||||
invalid_fields = set(body.keys()) - allowed_fields
|
||||
if invalid_fields:
|
||||
return JSONResponse(
|
||||
{"error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}"},
|
||||
status_code=400
|
||||
{
|
||||
"error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}"
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
|
||||
# Update configuration
|
||||
config_updated = False
|
||||
|
||||
|
||||
# Update agent settings
|
||||
if "llm_model" in body:
|
||||
current_config.agent.llm_model = body["llm_model"]
|
||||
config_updated = True
|
||||
|
||||
|
||||
if "system_prompt" in body:
|
||||
current_config.agent.system_prompt = body["system_prompt"]
|
||||
config_updated = True
|
||||
|
||||
|
||||
# Update knowledge settings
|
||||
if "doclingPresets" in body:
|
||||
preset_configs = get_docling_preset_configs()
|
||||
valid_presets = list(preset_configs.keys())
|
||||
if body["doclingPresets"] not in valid_presets:
|
||||
return JSONResponse(
|
||||
{"error": f"doclingPresets must be one of: {', '.join(valid_presets)}"},
|
||||
status_code=400
|
||||
{
|
||||
"error": f"doclingPresets must be one of: {', '.join(valid_presets)}"
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
current_config.knowledge.doclingPresets = body["doclingPresets"]
|
||||
config_updated = True
|
||||
|
||||
|
||||
if "ocr" in body:
|
||||
if not isinstance(body["ocr"], bool):
|
||||
return JSONResponse(
|
||||
{"error": "ocr must be a boolean value"}, status_code=400
|
||||
)
|
||||
current_config.knowledge.ocr = body["ocr"]
|
||||
config_updated = True
|
||||
|
||||
if "picture_descriptions" in body:
|
||||
if not isinstance(body["picture_descriptions"], bool):
|
||||
return JSONResponse(
|
||||
{"error": "picture_descriptions must be a boolean value"},
|
||||
status_code=400,
|
||||
)
|
||||
current_config.knowledge.picture_descriptions = body["picture_descriptions"]
|
||||
config_updated = True
|
||||
|
||||
if "chunk_size" in body:
|
||||
if not isinstance(body["chunk_size"], int) or body["chunk_size"] <= 0:
|
||||
return JSONResponse(
|
||||
{"error": "chunk_size must be a positive integer"},
|
||||
status_code=400
|
||||
{"error": "chunk_size must be a positive integer"}, status_code=400
|
||||
)
|
||||
current_config.knowledge.chunk_size = body["chunk_size"]
|
||||
config_updated = True
|
||||
|
||||
|
||||
if "chunk_overlap" in body:
|
||||
if not isinstance(body["chunk_overlap"], int) or body["chunk_overlap"] < 0:
|
||||
return JSONResponse(
|
||||
{"error": "chunk_overlap must be a non-negative integer"},
|
||||
status_code=400
|
||||
{"error": "chunk_overlap must be a non-negative integer"},
|
||||
status_code=400,
|
||||
)
|
||||
current_config.knowledge.chunk_overlap = body["chunk_overlap"]
|
||||
config_updated = True
|
||||
|
||||
|
||||
if not config_updated:
|
||||
return JSONResponse(
|
||||
{"error": "No valid fields provided for update"},
|
||||
status_code=400
|
||||
{"error": "No valid fields provided for update"}, status_code=400
|
||||
)
|
||||
|
||||
|
||||
# Save the updated configuration
|
||||
if config_manager.save_config_file(current_config):
|
||||
logger.info("Configuration updated successfully", updated_fields=list(body.keys()))
|
||||
logger.info(
|
||||
"Configuration updated successfully", updated_fields=list(body.keys())
|
||||
)
|
||||
return JSONResponse({"message": "Configuration updated successfully"})
|
||||
else:
|
||||
return JSONResponse(
|
||||
{"error": "Failed to save configuration"},
|
||||
status_code=500
|
||||
{"error": "Failed to save configuration"}, status_code=500
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update settings", error=str(e))
|
||||
return JSONResponse(
|
||||
|
|
@ -286,120 +303,168 @@ async def onboarding(request, flows_service):
|
|||
try:
|
||||
# Get current configuration
|
||||
current_config = get_openrag_config()
|
||||
|
||||
|
||||
# Check if config is NOT marked as edited (only allow onboarding if not yet configured)
|
||||
if current_config.edited:
|
||||
return JSONResponse(
|
||||
{"error": "Configuration has already been edited. Use /settings endpoint for updates."},
|
||||
status_code=403
|
||||
{
|
||||
"error": "Configuration has already been edited. Use /settings endpoint for updates."
|
||||
},
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
|
||||
# Parse request body
|
||||
body = await request.json()
|
||||
|
||||
|
||||
# Validate allowed fields
|
||||
allowed_fields = {
|
||||
"model_provider", "api_key", "embedding_model", "llm_model", "sample_data"
|
||||
"model_provider",
|
||||
"api_key",
|
||||
"embedding_model",
|
||||
"llm_model",
|
||||
"sample_data",
|
||||
"endpoint",
|
||||
"project_id",
|
||||
}
|
||||
|
||||
|
||||
# Check for invalid fields
|
||||
invalid_fields = set(body.keys()) - allowed_fields
|
||||
if invalid_fields:
|
||||
return JSONResponse(
|
||||
{"error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}"},
|
||||
status_code=400
|
||||
{
|
||||
"error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}"
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
|
||||
# Update configuration
|
||||
config_updated = False
|
||||
|
||||
|
||||
# Update provider settings
|
||||
if "model_provider" in body:
|
||||
if not isinstance(body["model_provider"], str) or not body["model_provider"].strip():
|
||||
if (
|
||||
not isinstance(body["model_provider"], str)
|
||||
or not body["model_provider"].strip()
|
||||
):
|
||||
return JSONResponse(
|
||||
{"error": "model_provider must be a non-empty string"},
|
||||
status_code=400
|
||||
{"error": "model_provider must be a non-empty string"},
|
||||
status_code=400,
|
||||
)
|
||||
current_config.provider.model_provider = body["model_provider"].strip()
|
||||
config_updated = True
|
||||
|
||||
|
||||
if "api_key" in body:
|
||||
if not isinstance(body["api_key"], str):
|
||||
return JSONResponse(
|
||||
{"error": "api_key must be a string"},
|
||||
status_code=400
|
||||
{"error": "api_key must be a string"}, status_code=400
|
||||
)
|
||||
current_config.provider.api_key = body["api_key"]
|
||||
config_updated = True
|
||||
|
||||
|
||||
# Update knowledge settings
|
||||
if "embedding_model" in body:
|
||||
if not isinstance(body["embedding_model"], str) or not body["embedding_model"].strip():
|
||||
if (
|
||||
not isinstance(body["embedding_model"], str)
|
||||
or not body["embedding_model"].strip()
|
||||
):
|
||||
return JSONResponse(
|
||||
{"error": "embedding_model must be a non-empty string"},
|
||||
status_code=400
|
||||
{"error": "embedding_model must be a non-empty string"},
|
||||
status_code=400,
|
||||
)
|
||||
current_config.knowledge.embedding_model = body["embedding_model"].strip()
|
||||
config_updated = True
|
||||
|
||||
|
||||
# Update agent settings
|
||||
if "llm_model" in body:
|
||||
if not isinstance(body["llm_model"], str) or not body["llm_model"].strip():
|
||||
return JSONResponse(
|
||||
{"error": "llm_model must be a non-empty string"},
|
||||
status_code=400
|
||||
{"error": "llm_model must be a non-empty string"}, status_code=400
|
||||
)
|
||||
current_config.agent.llm_model = body["llm_model"].strip()
|
||||
config_updated = True
|
||||
|
||||
|
||||
if "endpoint" in body:
|
||||
if not isinstance(body["endpoint"], str) or not body["endpoint"].strip():
|
||||
return JSONResponse(
|
||||
{"error": "endpoint must be a non-empty string"}, status_code=400
|
||||
)
|
||||
current_config.provider.endpoint = body["endpoint"].strip()
|
||||
config_updated = True
|
||||
|
||||
if "project_id" in body:
|
||||
if (
|
||||
not isinstance(body["project_id"], str)
|
||||
or not body["project_id"].strip()
|
||||
):
|
||||
return JSONResponse(
|
||||
{"error": "project_id must be a non-empty string"}, status_code=400
|
||||
)
|
||||
current_config.provider.project_id = body["project_id"].strip()
|
||||
config_updated = True
|
||||
|
||||
# Handle sample_data (unused for now but validate)
|
||||
if "sample_data" in body:
|
||||
if not isinstance(body["sample_data"], bool):
|
||||
return JSONResponse(
|
||||
{"error": "sample_data must be a boolean value"},
|
||||
status_code=400
|
||||
{"error": "sample_data must be a boolean value"}, status_code=400
|
||||
)
|
||||
# Note: sample_data is accepted but not used as requested
|
||||
|
||||
|
||||
if not config_updated:
|
||||
return JSONResponse(
|
||||
{"error": "No valid fields provided for update"},
|
||||
status_code=400
|
||||
{"error": "No valid fields provided for update"}, status_code=400
|
||||
)
|
||||
|
||||
|
||||
# Save the updated configuration (this will mark it as edited)
|
||||
if config_manager.save_config_file(current_config):
|
||||
updated_fields = [k for k in body.keys() if k != "sample_data"] # Exclude sample_data from log
|
||||
logger.info("Onboarding configuration updated successfully", updated_fields=updated_fields)
|
||||
|
||||
updated_fields = [
|
||||
k for k in body.keys() if k != "sample_data"
|
||||
] # Exclude sample_data from log
|
||||
logger.info(
|
||||
"Onboarding configuration updated successfully",
|
||||
updated_fields=updated_fields,
|
||||
)
|
||||
|
||||
# If model_provider was updated, assign the new provider to flows
|
||||
if "model_provider" in body:
|
||||
provider = body["model_provider"].strip().lower()
|
||||
try:
|
||||
flow_result = await flows_service.assign_model_provider(provider)
|
||||
|
||||
|
||||
if flow_result.get("success"):
|
||||
logger.info(f"Successfully assigned {provider} to flows", flow_result=flow_result)
|
||||
logger.info(
|
||||
f"Successfully assigned {provider} to flows",
|
||||
flow_result=flow_result,
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Failed to assign {provider} to flows", flow_result=flow_result)
|
||||
logger.warning(
|
||||
f"Failed to assign {provider} to flows",
|
||||
flow_result=flow_result,
|
||||
)
|
||||
# Continue even if flow assignment fails - configuration was still saved
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error assigning model provider to flows", provider=provider, error=str(e))
|
||||
logger.error(
|
||||
f"Error assigning model provider to flows",
|
||||
provider=provider,
|
||||
error=str(e),
|
||||
)
|
||||
# Continue even if flow assignment fails - configuration was still saved
|
||||
|
||||
return JSONResponse({
|
||||
"message": "Onboarding configuration updated successfully",
|
||||
"edited": True # Confirm that config is now marked as edited
|
||||
})
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"message": "Onboarding configuration updated successfully",
|
||||
"edited": True, # Confirm that config is now marked as edited
|
||||
}
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
{"error": "Failed to save configuration"},
|
||||
status_code=500
|
||||
{"error": "Failed to save configuration"}, status_code=500
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update onboarding settings", error=str(e))
|
||||
return JSONResponse(
|
||||
{"error": f"Failed to update onboarding settings: {str(e)}"}, status_code=500
|
||||
{"error": f"Failed to update onboarding settings: {str(e)}"},
|
||||
status_code=500,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import httpx
|
||||
import os
|
||||
from typing import Dict, List
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
|
|
@ -12,14 +11,9 @@ class ModelsService:
|
|||
def __init__(self):
|
||||
self.session_manager = None
|
||||
|
||||
async def get_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
|
||||
async def get_openai_models(self, api_key: str) -> Dict[str, List[Dict[str, str]]]:
|
||||
"""Fetch available models from OpenAI API"""
|
||||
try:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("OPENAI_API_KEY not set, using default models")
|
||||
return self._get_default_openai_models()
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
|
|
@ -74,12 +68,14 @@ class ModelsService:
|
|||
"embedding_models": embedding_models,
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Failed to fetch OpenAI models: {response.status_code}")
|
||||
return self._get_default_openai_models()
|
||||
logger.error(f"Failed to fetch OpenAI models: {response.status_code}")
|
||||
raise Exception(
|
||||
f"OpenAI API returned status code {response.status_code}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching OpenAI models: {str(e)}")
|
||||
return self._get_default_openai_models()
|
||||
raise
|
||||
|
||||
async def get_ollama_models(
|
||||
self, endpoint: str = None
|
||||
|
|
@ -87,9 +83,7 @@ class ModelsService:
|
|||
"""Fetch available models from Ollama API"""
|
||||
try:
|
||||
# Use provided endpoint or default
|
||||
ollama_url = endpoint or os.getenv(
|
||||
"OLLAMA_BASE_URL", "http://localhost:11434"
|
||||
)
|
||||
ollama_url = endpoint
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{ollama_url}/api/tags", timeout=10.0)
|
||||
|
|
@ -145,32 +139,34 @@ class ModelsService:
|
|||
|
||||
return {
|
||||
"language_models": language_models,
|
||||
"embedding_models": embedding_models
|
||||
if embedding_models
|
||||
else [
|
||||
{
|
||||
"value": "nomic-embed-text",
|
||||
"label": "nomic-embed-text",
|
||||
"default": True,
|
||||
}
|
||||
],
|
||||
"embedding_models": embedding_models if embedding_models else [],
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Failed to fetch Ollama models: {response.status_code}")
|
||||
return self._get_default_ollama_models()
|
||||
logger.error(f"Failed to fetch Ollama models: {response.status_code}")
|
||||
raise Exception(
|
||||
f"Ollama API returned status code {response.status_code}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Ollama models: {str(e)}")
|
||||
return self._get_default_ollama_models()
|
||||
raise
|
||||
|
||||
async def get_ibm_models(
|
||||
self, endpoint: str = None
|
||||
self, endpoint: str = None, api_key: str = None, project_id: str = None
|
||||
) -> Dict[str, List[Dict[str, str]]]:
|
||||
"""Fetch available models from IBM Watson API"""
|
||||
try:
|
||||
# Use provided endpoint or default
|
||||
watson_endpoint = endpoint or os.getenv("IBM_WATSON_ENDPOINT", "https://us-south.ml.cloud.ibm.com")
|
||||
watson_endpoint = endpoint
|
||||
|
||||
# Prepare headers for authentication
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
if project_id:
|
||||
headers["Project-ID"] = project_id
|
||||
# Fetch foundation models using the correct endpoint
|
||||
models_url = f"{watson_endpoint}/ml/v1/foundation_model_specs"
|
||||
|
||||
|
|
@ -181,9 +177,14 @@ class ModelsService:
|
|||
# Fetch text chat models
|
||||
text_params = {
|
||||
"version": "2024-09-16",
|
||||
"filters": "function_text_chat,!lifecycle_withdrawn"
|
||||
"filters": "function_text_chat,!lifecycle_withdrawn",
|
||||
}
|
||||
text_response = await client.get(models_url, params=text_params, timeout=10.0)
|
||||
if project_id:
|
||||
text_params["project_id"] = project_id
|
||||
|
||||
text_response = await client.get(
|
||||
models_url, params=text_params, headers=headers, timeout=10.0
|
||||
)
|
||||
|
||||
if text_response.status_code == 200:
|
||||
text_data = text_response.json()
|
||||
|
|
@ -193,18 +194,25 @@ class ModelsService:
|
|||
model_id = model.get("model_id", "")
|
||||
model_name = model.get("name", model_id)
|
||||
|
||||
language_models.append({
|
||||
"value": model_id,
|
||||
"label": model_name or model_id,
|
||||
"default": i == 0 # First model is default
|
||||
})
|
||||
language_models.append(
|
||||
{
|
||||
"value": model_id,
|
||||
"label": model_name or model_id,
|
||||
"default": i == 0, # First model is default
|
||||
}
|
||||
)
|
||||
|
||||
# Fetch embedding models
|
||||
embed_params = {
|
||||
"version": "2024-09-16",
|
||||
"filters": "function_embedding,!lifecycle_withdrawn"
|
||||
"filters": "function_embedding,!lifecycle_withdrawn",
|
||||
}
|
||||
embed_response = await client.get(models_url, params=embed_params, timeout=10.0)
|
||||
if project_id:
|
||||
embed_params["project_id"] = project_id
|
||||
|
||||
embed_response = await client.get(
|
||||
models_url, params=embed_params, headers=headers, timeout=10.0
|
||||
)
|
||||
|
||||
if embed_response.status_code == 200:
|
||||
embed_data = embed_response.json()
|
||||
|
|
@ -214,104 +222,22 @@ class ModelsService:
|
|||
model_id = model.get("model_id", "")
|
||||
model_name = model.get("name", model_id)
|
||||
|
||||
embedding_models.append({
|
||||
"value": model_id,
|
||||
"label": model_name or model_id,
|
||||
"default": i == 0 # First model is default
|
||||
})
|
||||
embedding_models.append(
|
||||
{
|
||||
"value": model_id,
|
||||
"label": model_name or model_id,
|
||||
"default": i == 0, # First model is default
|
||||
}
|
||||
)
|
||||
|
||||
if not language_models and not embedding_models:
|
||||
raise Exception("No IBM models retrieved from API")
|
||||
|
||||
return {
|
||||
"language_models": language_models if language_models else self._get_default_ibm_models()["language_models"],
|
||||
"embedding_models": embedding_models if embedding_models else self._get_default_ibm_models()["embedding_models"]
|
||||
"language_models": language_models,
|
||||
"embedding_models": embedding_models,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching IBM models: {str(e)}")
|
||||
return self._get_default_ibm_models()
|
||||
|
||||
def _get_default_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
|
||||
"""Default OpenAI models when API is not available"""
|
||||
return {
|
||||
"language_models": [
|
||||
{"value": "gpt-4o-mini", "label": "gpt-4o-mini", "default": True},
|
||||
{"value": "gpt-4o", "label": "gpt-4o", "default": False},
|
||||
{"value": "gpt-4-turbo", "label": "gpt-4-turbo", "default": False},
|
||||
{"value": "gpt-3.5-turbo", "label": "gpt-3.5-turbo", "default": False},
|
||||
],
|
||||
"embedding_models": [
|
||||
{
|
||||
"value": "text-embedding-3-small",
|
||||
"label": "text-embedding-3-small",
|
||||
"default": True,
|
||||
},
|
||||
{
|
||||
"value": "text-embedding-3-large",
|
||||
"label": "text-embedding-3-large",
|
||||
"default": False,
|
||||
},
|
||||
{
|
||||
"value": "text-embedding-ada-002",
|
||||
"label": "text-embedding-ada-002",
|
||||
"default": False,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
def _get_default_ollama_models(self) -> Dict[str, List[Dict[str, str]]]:
|
||||
"""Default Ollama models when API is not available"""
|
||||
return {
|
||||
"language_models": [
|
||||
{"value": "llama3.2", "label": "llama3.2", "default": True},
|
||||
{"value": "llama3.1", "label": "llama3.1", "default": False},
|
||||
{"value": "llama3", "label": "llama3", "default": False},
|
||||
{"value": "mistral", "label": "mistral", "default": False},
|
||||
{"value": "codellama", "label": "codellama", "default": False},
|
||||
],
|
||||
"embedding_models": [
|
||||
{
|
||||
"value": "nomic-embed-text",
|
||||
"label": "nomic-embed-text",
|
||||
"default": True,
|
||||
},
|
||||
{"value": "all-minilm", "label": "all-minilm", "default": False},
|
||||
],
|
||||
}
|
||||
|
||||
def _get_default_ibm_models(self) -> Dict[str, List[Dict[str, str]]]:
|
||||
"""Default IBM Watson models when API is not available"""
|
||||
return {
|
||||
"language_models": [
|
||||
{
|
||||
"value": "meta-llama/llama-3-1-70b-instruct",
|
||||
"label": "Llama 3.1 70B Instruct",
|
||||
"default": True,
|
||||
},
|
||||
{
|
||||
"value": "meta-llama/llama-3-1-8b-instruct",
|
||||
"label": "Llama 3.1 8B Instruct",
|
||||
"default": False,
|
||||
},
|
||||
{
|
||||
"value": "ibm/granite-13b-chat-v2",
|
||||
"label": "Granite 13B Chat v2",
|
||||
"default": False,
|
||||
},
|
||||
{
|
||||
"value": "ibm/granite-13b-instruct-v2",
|
||||
"label": "Granite 13B Instruct v2",
|
||||
"default": False,
|
||||
},
|
||||
],
|
||||
"embedding_models": [
|
||||
{
|
||||
"value": "ibm/slate-125m-english-rtrvr",
|
||||
"label": "Slate 125M English Retriever",
|
||||
"default": True,
|
||||
},
|
||||
{
|
||||
"value": "sentence-transformers/all-minilm-l12-v2",
|
||||
"label": "All-MiniLM L12 v2",
|
||||
"default": False,
|
||||
},
|
||||
],
|
||||
}
|
||||
raise
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue