update model selectors
This commit is contained in:
parent
d68a787017
commit
6a11aec195
3 changed files with 170 additions and 26 deletions
87
frontend/src/app/settings/helpers/model-helpers.tsx
Normal file
87
frontend/src/app/settings/helpers/model-helpers.tsx
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
import OpenAILogo from "@/components/logo/openai-logo";
|
||||
import OllamaLogo from "@/components/logo/ollama-logo";
|
||||
import IBMLogo from "@/components/logo/ibm-logo";
|
||||
|
||||
export type ModelProvider = 'openai' | 'ollama' | 'ibm';
|
||||
|
||||
export interface ModelOption {
|
||||
value: string;
|
||||
label: string;
|
||||
}
|
||||
|
||||
// Helper function to get model logo based on provider or model name
|
||||
export function getModelLogo(modelValue: string, provider?: ModelProvider) {
|
||||
// First check by provider
|
||||
if (provider === 'openai') {
|
||||
return <OpenAILogo className="w-4 h-4" />;
|
||||
} else if (provider === 'ollama') {
|
||||
return <OllamaLogo className="w-4 h-4" />;
|
||||
} else if (provider === 'ibm') {
|
||||
return <IBMLogo className="w-4 h-4" />;
|
||||
}
|
||||
|
||||
// Fallback to model name analysis
|
||||
if (modelValue.includes('gpt') || modelValue.includes('text-embedding')) {
|
||||
return <OpenAILogo className="w-4 h-4" />;
|
||||
} else if (modelValue.includes('llama') || modelValue.includes('ollama')) {
|
||||
return <OllamaLogo className="w-4 h-4" />;
|
||||
} else if (modelValue.includes('granite') || modelValue.includes('slate') || modelValue.includes('ibm')) {
|
||||
return <IBMLogo className="w-4 h-4" />;
|
||||
}
|
||||
|
||||
return <OpenAILogo className="w-4 h-4" />; // Default to OpenAI logo
|
||||
}
|
||||
|
||||
// Helper function to get fallback models by provider
|
||||
export function getFallbackModels(provider: ModelProvider) {
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
return {
|
||||
language: [
|
||||
{ value: 'gpt-4', label: 'GPT-4' },
|
||||
{ value: 'gpt-4-turbo', label: 'GPT-4 Turbo' },
|
||||
{ value: 'gpt-3.5-turbo', label: 'GPT-3.5 Turbo' },
|
||||
],
|
||||
embedding: [
|
||||
{ value: 'text-embedding-ada-002', label: 'text-embedding-ada-002' },
|
||||
{ value: 'text-embedding-3-small', label: 'text-embedding-3-small' },
|
||||
{ value: 'text-embedding-3-large', label: 'text-embedding-3-large' },
|
||||
],
|
||||
};
|
||||
case 'ollama':
|
||||
return {
|
||||
language: [
|
||||
{ value: 'llama2', label: 'Llama 2' },
|
||||
{ value: 'llama2:13b', label: 'Llama 2 13B' },
|
||||
{ value: 'codellama', label: 'Code Llama' },
|
||||
],
|
||||
embedding: [
|
||||
{ value: 'mxbai-embed-large', label: 'MxBai Embed Large' },
|
||||
{ value: 'nomic-embed-text', label: 'Nomic Embed Text' },
|
||||
],
|
||||
};
|
||||
case 'ibm':
|
||||
return {
|
||||
language: [
|
||||
{ value: 'meta-llama/llama-3-1-70b-instruct', label: 'Llama 3.1 70B Instruct' },
|
||||
{ value: 'ibm/granite-13b-chat-v2', label: 'Granite 13B Chat v2' },
|
||||
],
|
||||
embedding: [
|
||||
{ value: 'ibm/slate-125m-english-rtrvr', label: 'Slate 125M English Retriever' },
|
||||
],
|
||||
};
|
||||
default:
|
||||
return {
|
||||
language: [
|
||||
{ value: 'gpt-4', label: 'GPT-4' },
|
||||
{ value: 'gpt-4-turbo', label: 'GPT-4 Turbo' },
|
||||
{ value: 'gpt-3.5-turbo', label: 'GPT-3.5 Turbo' },
|
||||
],
|
||||
embedding: [
|
||||
{ value: 'text-embedding-ada-002', label: 'text-embedding-ada-002' },
|
||||
{ value: 'text-embedding-3-small', label: 'text-embedding-3-small' },
|
||||
{ value: 'text-embedding-3-large', label: 'text-embedding-3-large' },
|
||||
],
|
||||
};
|
||||
}
|
||||
}
|
||||
36
frontend/src/app/settings/helpers/model-select-item.tsx
Normal file
36
frontend/src/app/settings/helpers/model-select-item.tsx
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import { SelectItem } from "@/components/ui/select";
|
||||
import { getModelLogo, type ModelProvider, type ModelOption } from "./model-helpers";
|
||||
|
||||
interface ModelSelectItemProps {
|
||||
model: ModelOption;
|
||||
provider?: ModelProvider;
|
||||
}
|
||||
|
||||
export function ModelSelectItem({ model, provider }: ModelSelectItemProps) {
|
||||
return (
|
||||
<SelectItem value={model.value}>
|
||||
<div className="flex items-center gap-2">
|
||||
{getModelLogo(model.value, provider)}
|
||||
<span>{model.label}</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
);
|
||||
}
|
||||
|
||||
interface ModelSelectItemsProps {
|
||||
models?: ModelOption[];
|
||||
fallbackModels: ModelOption[];
|
||||
provider: ModelProvider;
|
||||
}
|
||||
|
||||
export function ModelSelectItems({ models, fallbackModels, provider }: ModelSelectItemsProps) {
|
||||
const modelsToRender = models || fallbackModels;
|
||||
|
||||
return (
|
||||
<>
|
||||
{modelsToRender.map((model) => (
|
||||
<ModelSelectItem key={model.value} model={model} provider={provider} />
|
||||
))}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
|
@ -5,7 +5,10 @@ import { useSearchParams } from "next/navigation";
|
|||
import { Suspense, useCallback, useEffect, useState } from "react";
|
||||
import { useUpdateFlowSettingMutation } from "@/app/api/mutations/useUpdateFlowSettingMutation";
|
||||
import { useGetSettingsQuery } from "@/app/api/queries/useGetSettingsQuery";
|
||||
import { useGetOpenAIModelsQuery, useGetOllamaModelsQuery, useGetIBMModelsQuery } from "@/app/api/queries/useGetModelsQuery";
|
||||
import { ConfirmationDialog } from "@/components/confirmation-dialog";
|
||||
import { ModelSelectItems } from "./helpers/model-select-item";
|
||||
import { getFallbackModels, type ModelProvider } from "./helpers/model-helpers";
|
||||
import { ProtectedRoute } from "@/components/protected-route";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
|
@ -22,7 +25,6 @@ import { Label } from "@/components/ui/label";
|
|||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
|
|
@ -104,6 +106,37 @@ function KnowledgeSourcesPage() {
|
|||
enabled: isAuthenticated,
|
||||
});
|
||||
|
||||
// Get the current provider from settings
|
||||
const currentProvider = (settings.provider?.model_provider || 'openai') as ModelProvider;
|
||||
|
||||
// Fetch available models based on provider
|
||||
const { data: openaiModelsData } = useGetOpenAIModelsQuery({
|
||||
enabled: isAuthenticated && currentProvider === 'openai',
|
||||
});
|
||||
|
||||
const { data: ollamaModelsData } = useGetOllamaModelsQuery(
|
||||
undefined, // No params for now, could be extended later
|
||||
{
|
||||
enabled: isAuthenticated && currentProvider === 'ollama',
|
||||
}
|
||||
);
|
||||
|
||||
const { data: ibmModelsData } = useGetIBMModelsQuery(
|
||||
undefined, // No params for now, could be extended later
|
||||
{
|
||||
enabled: isAuthenticated && currentProvider === 'ibm',
|
||||
}
|
||||
);
|
||||
|
||||
// Select the appropriate models data based on provider
|
||||
const modelsData = currentProvider === 'openai'
|
||||
? openaiModelsData
|
||||
: currentProvider === 'ollama'
|
||||
? ollamaModelsData
|
||||
: currentProvider === 'ibm'
|
||||
? ibmModelsData
|
||||
: openaiModelsData; // fallback to openai
|
||||
|
||||
// Mutations
|
||||
const updateFlowSettingMutation = useUpdateFlowSettingMutation({
|
||||
onSuccess: () => {
|
||||
|
|
@ -171,6 +204,7 @@ function KnowledgeSourcesPage() {
|
|||
debouncedUpdate({ chunk_overlap: numValue });
|
||||
};
|
||||
|
||||
|
||||
// Helper function to get connector icon
|
||||
const getConnectorIcon = useCallback((iconName: string) => {
|
||||
const iconMap: { [key: string]: React.ReactElement } = {
|
||||
|
|
@ -559,21 +593,18 @@ function KnowledgeSourcesPage() {
|
|||
Language Model
|
||||
</Label>
|
||||
<Select
|
||||
value={settings.agent?.llm_model || "gpt-4"}
|
||||
value={settings.agent?.llm_model || modelsData?.language_models?.find(m => m.default)?.value || "gpt-4"}
|
||||
onValueChange={handleModelChange}
|
||||
>
|
||||
<SelectTrigger id="model-select">
|
||||
<SelectValue placeholder="Select a model" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="gpt-4">GPT-4</SelectItem>
|
||||
<SelectItem value="gpt-4-turbo">GPT-4 Turbo</SelectItem>
|
||||
<SelectItem value="gpt-3.5-turbo">GPT-3.5 Turbo</SelectItem>
|
||||
<SelectItem value="claude-3-opus">Claude 3 Opus</SelectItem>
|
||||
<SelectItem value="claude-3-sonnet">
|
||||
Claude 3 Sonnet
|
||||
</SelectItem>
|
||||
<SelectItem value="claude-3-haiku">Claude 3 Haiku</SelectItem>
|
||||
<ModelSelectItems
|
||||
models={modelsData?.language_models}
|
||||
fallbackModels={getFallbackModels(currentProvider).language}
|
||||
provider={currentProvider}
|
||||
/>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
|
@ -685,7 +716,7 @@ function KnowledgeSourcesPage() {
|
|||
</Label>
|
||||
<Select
|
||||
value={
|
||||
settings.knowledge?.embedding_model || "text-embedding-ada-002"
|
||||
settings.knowledge?.embedding_model || modelsData?.embedding_models?.find(m => m.default)?.value || "text-embedding-ada-002"
|
||||
}
|
||||
onValueChange={handleEmbeddingModelChange}
|
||||
>
|
||||
|
|
@ -693,21 +724,11 @@ function KnowledgeSourcesPage() {
|
|||
<SelectValue placeholder="Select an embedding model" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="text-embedding-ada-002">
|
||||
text-embedding-ada-002
|
||||
</SelectItem>
|
||||
<SelectItem value="text-embedding-3-small">
|
||||
text-embedding-3-small
|
||||
</SelectItem>
|
||||
<SelectItem value="text-embedding-3-large">
|
||||
text-embedding-3-large
|
||||
</SelectItem>
|
||||
<SelectItem value="all-MiniLM-L6-v2">
|
||||
all-MiniLM-L6-v2
|
||||
</SelectItem>
|
||||
<SelectItem value="all-mpnet-base-v2">
|
||||
all-mpnet-base-v2
|
||||
</SelectItem>
|
||||
<ModelSelectItems
|
||||
models={modelsData?.embedding_models}
|
||||
fallbackModels={getFallbackModels(currentProvider).embedding}
|
||||
provider={currentProvider}
|
||||
/>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue