update model selectors

This commit is contained in:
Mike Fortman 2025-09-18 17:03:51 -05:00
parent d68a787017
commit 6a11aec195
3 changed files with 170 additions and 26 deletions

View file

@ -0,0 +1,87 @@
import OpenAILogo from "@/components/logo/openai-logo";
import OllamaLogo from "@/components/logo/ollama-logo";
import IBMLogo from "@/components/logo/ibm-logo";
export type ModelProvider = 'openai' | 'ollama' | 'ibm';
export interface ModelOption {
value: string;
label: string;
}
// Helper function to get model logo based on provider or model name
export function getModelLogo(modelValue: string, provider?: ModelProvider) {
// First check by provider
if (provider === 'openai') {
return <OpenAILogo className="w-4 h-4" />;
} else if (provider === 'ollama') {
return <OllamaLogo className="w-4 h-4" />;
} else if (provider === 'ibm') {
return <IBMLogo className="w-4 h-4" />;
}
// Fallback to model name analysis
if (modelValue.includes('gpt') || modelValue.includes('text-embedding')) {
return <OpenAILogo className="w-4 h-4" />;
} else if (modelValue.includes('llama') || modelValue.includes('ollama')) {
return <OllamaLogo className="w-4 h-4" />;
} else if (modelValue.includes('granite') || modelValue.includes('slate') || modelValue.includes('ibm')) {
return <IBMLogo className="w-4 h-4" />;
}
return <OpenAILogo className="w-4 h-4" />; // Default to OpenAI logo
}
// Helper function to get fallback models by provider
export function getFallbackModels(provider: ModelProvider) {
switch (provider) {
case 'openai':
return {
language: [
{ value: 'gpt-4', label: 'GPT-4' },
{ value: 'gpt-4-turbo', label: 'GPT-4 Turbo' },
{ value: 'gpt-3.5-turbo', label: 'GPT-3.5 Turbo' },
],
embedding: [
{ value: 'text-embedding-ada-002', label: 'text-embedding-ada-002' },
{ value: 'text-embedding-3-small', label: 'text-embedding-3-small' },
{ value: 'text-embedding-3-large', label: 'text-embedding-3-large' },
],
};
case 'ollama':
return {
language: [
{ value: 'llama2', label: 'Llama 2' },
{ value: 'llama2:13b', label: 'Llama 2 13B' },
{ value: 'codellama', label: 'Code Llama' },
],
embedding: [
{ value: 'mxbai-embed-large', label: 'MxBai Embed Large' },
{ value: 'nomic-embed-text', label: 'Nomic Embed Text' },
],
};
case 'ibm':
return {
language: [
{ value: 'meta-llama/llama-3-1-70b-instruct', label: 'Llama 3.1 70B Instruct' },
{ value: 'ibm/granite-13b-chat-v2', label: 'Granite 13B Chat v2' },
],
embedding: [
{ value: 'ibm/slate-125m-english-rtrvr', label: 'Slate 125M English Retriever' },
],
};
default:
return {
language: [
{ value: 'gpt-4', label: 'GPT-4' },
{ value: 'gpt-4-turbo', label: 'GPT-4 Turbo' },
{ value: 'gpt-3.5-turbo', label: 'GPT-3.5 Turbo' },
],
embedding: [
{ value: 'text-embedding-ada-002', label: 'text-embedding-ada-002' },
{ value: 'text-embedding-3-small', label: 'text-embedding-3-small' },
{ value: 'text-embedding-3-large', label: 'text-embedding-3-large' },
],
};
}
}

View file

@ -0,0 +1,36 @@
import { SelectItem } from "@/components/ui/select";
import { getModelLogo, type ModelProvider, type ModelOption } from "./model-helpers";
interface ModelSelectItemProps {
model: ModelOption;
provider?: ModelProvider;
}
export function ModelSelectItem({ model, provider }: ModelSelectItemProps) {
return (
<SelectItem value={model.value}>
<div className="flex items-center gap-2">
{getModelLogo(model.value, provider)}
<span>{model.label}</span>
</div>
</SelectItem>
);
}
interface ModelSelectItemsProps {
models?: ModelOption[];
fallbackModels: ModelOption[];
provider: ModelProvider;
}
export function ModelSelectItems({ models, fallbackModels, provider }: ModelSelectItemsProps) {
const modelsToRender = models || fallbackModels;
return (
<>
{modelsToRender.map((model) => (
<ModelSelectItem key={model.value} model={model} provider={provider} />
))}
</>
);
}

View file

@ -5,7 +5,10 @@ import { useSearchParams } from "next/navigation";
import { Suspense, useCallback, useEffect, useState } from "react";
import { useUpdateFlowSettingMutation } from "@/app/api/mutations/useUpdateFlowSettingMutation";
import { useGetSettingsQuery } from "@/app/api/queries/useGetSettingsQuery";
import { useGetOpenAIModelsQuery, useGetOllamaModelsQuery, useGetIBMModelsQuery } from "@/app/api/queries/useGetModelsQuery";
import { ConfirmationDialog } from "@/components/confirmation-dialog";
import { ModelSelectItems } from "./helpers/model-select-item";
import { getFallbackModels, type ModelProvider } from "./helpers/model-helpers";
import { ProtectedRoute } from "@/components/protected-route";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
@ -22,7 +25,6 @@ import { Label } from "@/components/ui/label";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
@ -104,6 +106,37 @@ function KnowledgeSourcesPage() {
enabled: isAuthenticated,
});
// Get the current provider from settings
const currentProvider = (settings.provider?.model_provider || 'openai') as ModelProvider;
// Fetch available models based on provider
const { data: openaiModelsData } = useGetOpenAIModelsQuery({
enabled: isAuthenticated && currentProvider === 'openai',
});
const { data: ollamaModelsData } = useGetOllamaModelsQuery(
undefined, // No params for now, could be extended later
{
enabled: isAuthenticated && currentProvider === 'ollama',
}
);
const { data: ibmModelsData } = useGetIBMModelsQuery(
undefined, // No params for now, could be extended later
{
enabled: isAuthenticated && currentProvider === 'ibm',
}
);
// Select the appropriate models data based on provider
const modelsData = currentProvider === 'openai'
? openaiModelsData
: currentProvider === 'ollama'
? ollamaModelsData
: currentProvider === 'ibm'
? ibmModelsData
: openaiModelsData; // fallback to openai
// Mutations
const updateFlowSettingMutation = useUpdateFlowSettingMutation({
onSuccess: () => {
@ -171,6 +204,7 @@ function KnowledgeSourcesPage() {
debouncedUpdate({ chunk_overlap: numValue });
};
// Helper function to get connector icon
const getConnectorIcon = useCallback((iconName: string) => {
const iconMap: { [key: string]: React.ReactElement } = {
@ -559,21 +593,18 @@ function KnowledgeSourcesPage() {
Language Model
</Label>
<Select
value={settings.agent?.llm_model || "gpt-4"}
value={settings.agent?.llm_model || modelsData?.language_models?.find(m => m.default)?.value || "gpt-4"}
onValueChange={handleModelChange}
>
<SelectTrigger id="model-select">
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
<SelectItem value="gpt-4">GPT-4</SelectItem>
<SelectItem value="gpt-4-turbo">GPT-4 Turbo</SelectItem>
<SelectItem value="gpt-3.5-turbo">GPT-3.5 Turbo</SelectItem>
<SelectItem value="claude-3-opus">Claude 3 Opus</SelectItem>
<SelectItem value="claude-3-sonnet">
Claude 3 Sonnet
</SelectItem>
<SelectItem value="claude-3-haiku">Claude 3 Haiku</SelectItem>
<ModelSelectItems
models={modelsData?.language_models}
fallbackModels={getFallbackModels(currentProvider).language}
provider={currentProvider}
/>
</SelectContent>
</Select>
</div>
@ -685,7 +716,7 @@ function KnowledgeSourcesPage() {
</Label>
<Select
value={
settings.knowledge?.embedding_model || "text-embedding-ada-002"
settings.knowledge?.embedding_model || modelsData?.embedding_models?.find(m => m.default)?.value || "text-embedding-ada-002"
}
onValueChange={handleEmbeddingModelChange}
>
@ -693,21 +724,11 @@ function KnowledgeSourcesPage() {
<SelectValue placeholder="Select an embedding model" />
</SelectTrigger>
<SelectContent>
<SelectItem value="text-embedding-ada-002">
text-embedding-ada-002
</SelectItem>
<SelectItem value="text-embedding-3-small">
text-embedding-3-small
</SelectItem>
<SelectItem value="text-embedding-3-large">
text-embedding-3-large
</SelectItem>
<SelectItem value="all-MiniLM-L6-v2">
all-MiniLM-L6-v2
</SelectItem>
<SelectItem value="all-mpnet-base-v2">
all-mpnet-base-v2
</SelectItem>
<ModelSelectItems
models={modelsData?.embedding_models}
fallbackModels={getFallbackModels(currentProvider).embedding}
provider={currentProvider}
/>
</SelectContent>
</Select>
</div>