Added models fetching

Lucas Oliveira 2025-09-18 18:25:00 -03:00
parent 4cbdf6ed19
commit d68a787017
7 changed files with 668 additions and 37 deletions

View file

@@ -0,0 +1,140 @@
import {
  type UseQueryOptions,
  useQuery,
  useQueryClient,
} from "@tanstack/react-query";

export interface ModelOption {
  value: string;
  label: string;
  default?: boolean;
}

export interface ModelsResponse {
  language_models: ModelOption[];
  embedding_models: ModelOption[];
}

export interface OllamaModelsParams {
  endpoint?: string;
}

export interface IBMModelsParams {
  api_key?: string;
  endpoint?: string;
  project_id?: string;
}

export const useGetOpenAIModelsQuery = (
  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
) => {
  const queryClient = useQueryClient();

  async function getOpenAIModels(): Promise<ModelsResponse> {
    const response = await fetch("/api/models/openai");
    if (response.ok) {
      return await response.json();
    } else {
      throw new Error("Failed to fetch OpenAI models");
    }
  }

  const queryResult = useQuery(
    {
      queryKey: ["models", "openai"],
      queryFn: getOpenAIModels,
      staleTime: 5 * 60 * 1000, // 5 minutes
      ...options,
    },
    queryClient,
  );

  return queryResult;
};

export const useGetOllamaModelsQuery = (
  params?: OllamaModelsParams,
  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
) => {
  const queryClient = useQueryClient();

  async function getOllamaModels(): Promise<ModelsResponse> {
    const url = new URL("/api/models/ollama", window.location.origin);
    if (params?.endpoint) {
      url.searchParams.set("endpoint", params.endpoint);
    }
    const response = await fetch(url.toString());
    if (response.ok) {
      return await response.json();
    } else {
      throw new Error("Failed to fetch Ollama models");
    }
  }

  const queryResult = useQuery(
    {
      queryKey: ["models", "ollama", params],
      queryFn: getOllamaModels,
      staleTime: 5 * 60 * 1000, // 5 minutes
      enabled: !!params?.endpoint, // Only run if endpoint is provided
      ...options,
    },
    queryClient,
  );

  return queryResult;
};

export const useGetIBMModelsQuery = (
  params?: IBMModelsParams,
  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
) => {
  const queryClient = useQueryClient();

  async function getIBMModels(): Promise<ModelsResponse> {
    const url = "/api/models/ibm";
    // If we have credentials, use POST to send them securely
    if (params?.api_key || params?.endpoint || params?.project_id) {
      const response = await fetch(url, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          api_key: params.api_key,
          endpoint: params.endpoint,
          project_id: params.project_id,
        }),
      });
      if (response.ok) {
        return await response.json();
      } else {
        throw new Error("Failed to fetch IBM models");
      }
    } else {
      // Use GET for default models
      const response = await fetch(url);
      if (response.ok) {
        return await response.json();
      } else {
        throw new Error("Failed to fetch IBM models");
      }
    }
  }

  const queryResult = useQuery(
    {
      queryKey: ["models", "ibm", params],
      queryFn: getIBMModels,
      staleTime: 5 * 60 * 1000, // 5 minutes
      enabled: !!(params?.api_key && params?.endpoint && params?.project_id), // Only run if all credentials are provided
      ...options,
    },
    queryClient,
  );

  return queryResult;
};
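Editor's note: for orientation, a minimal sketch of how one of these hooks might be consumed. The component and its markup are illustrative only, not part of this commit; only the hook and the ModelsResponse shape come from the file above.

```tsx
import { useGetOllamaModelsQuery } from "./useGetModelsQuery";

// Hypothetical consumer: list Ollama language models once an endpoint is known.
export function OllamaModelPicker({ endpoint }: { endpoint: string }) {
  const { data, isLoading, isError } = useGetOllamaModelsQuery(
    endpoint ? { endpoint } : undefined,
    { enabled: !!endpoint }, // query stays idle until an endpoint exists
  );

  if (!endpoint) return <p>Enter an endpoint to list models.</p>;
  if (isLoading) return <p>Loading models...</p>;
  if (isError) return <p>Could not reach the Ollama server.</p>;

  return (
    <select>
      {data?.language_models.map((m) => (
        <option key={m.value} value={m.value}>
          {m.label}
        </option>
      ))}
    </select>
  );
}
```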

View file

@@ -1,7 +1,8 @@
-import { useState } from "react";
+import { useState, useEffect } from "react";
import { LabelInput } from "@/components/label-input";
import IBMLogo from "@/components/logo/ibm-logo";
import type { Settings } from "../api/queries/useGetSettingsQuery";
+import { useGetIBMModelsQuery } from "../api/queries/useGetModelsQuery";
import { AdvancedOnboarding } from "./advanced";
export function IBMOnboarding({
@@ -15,21 +16,44 @@ export function IBMOnboarding({
  sampleDataset: boolean;
  setSampleDataset: (dataset: boolean) => void;
}) {
-  const languageModels = [
-    { value: "gpt-oss", label: "gpt-oss" },
-    { value: "llama3.1", label: "llama3.1" },
-    { value: "llama3.2", label: "llama3.2" },
-    { value: "llama3.3", label: "llama3.3" },
-    { value: "llama3.4", label: "llama3.4" },
-    { value: "llama3.5", label: "llama3.5" },
-  ];
-  const embeddingModels = [
-    { value: "text-embedding-3-small", label: "text-embedding-3-small" },
-  ];
-  const [languageModel, setLanguageModel] = useState("gpt-oss");
-  const [embeddingModel, setEmbeddingModel] = useState(
-    "text-embedding-3-small",
-  );
+  const [endpoint, setEndpoint] = useState("");
+  const [apiKey, setApiKey] = useState("");
+  const [projectId, setProjectId] = useState("");
+  const [languageModel, setLanguageModel] = useState("meta-llama/llama-3-1-70b-instruct");
+  const [embeddingModel, setEmbeddingModel] = useState("ibm/slate-125m-english-rtrvr");
+
+  // Fetch models from API when all credentials are provided
+  const { data: modelsData } = useGetIBMModelsQuery(
+    (apiKey && endpoint && projectId) ? { api_key: apiKey, endpoint, project_id: projectId } : undefined,
+    { enabled: !!(apiKey && endpoint && projectId) }
+  );
+
+  // Use fetched models or fallback to defaults
+  const languageModels = modelsData?.language_models || [
+    { value: "meta-llama/llama-3-1-70b-instruct", label: "Llama 3.1 70B Instruct", default: true },
+    { value: "meta-llama/llama-3-1-8b-instruct", label: "Llama 3.1 8B Instruct" },
+    { value: "ibm/granite-13b-chat-v2", label: "Granite 13B Chat v2" },
+    { value: "ibm/granite-13b-instruct-v2", label: "Granite 13B Instruct v2" },
+  ];
+  const embeddingModels = modelsData?.embedding_models || [
+    { value: "ibm/slate-125m-english-rtrvr", label: "Slate 125M English Retriever", default: true },
+    { value: "sentence-transformers/all-minilm-l12-v2", label: "All-MiniLM L12 v2" },
+  ];
+
+  // Update default selections when models are loaded
+  useEffect(() => {
+    if (modelsData) {
+      const defaultLangModel = modelsData.language_models.find(m => m.default);
+      const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
+      if (defaultLangModel) {
+        setLanguageModel(defaultLangModel.value);
+      }
+      if (defaultEmbedModel) {
+        setEmbeddingModel(defaultEmbedModel.value);
+      }
+    }
+  }, [modelsData]);
+
  const handleLanguageModelChange = (model: string) => {
    setLanguageModel(model);
  };
@@ -48,21 +72,27 @@ export function IBMOnboarding({
        helperText="The API endpoint for your watsonx.ai account."
        id="api-endpoint"
        required
-        placeholder="https://..."
+        placeholder="https://us-south.ml.cloud.ibm.com"
+        value={endpoint}
+        onChange={(e) => setEndpoint(e.target.value)}
      />
      <LabelInput
        label="IBM API key"
        helperText="The API key for your watsonx.ai account."
        id="api-key"
        required
-        placeholder="sk-..."
+        placeholder="your-api-key"
+        value={apiKey}
+        onChange={(e) => setApiKey(e.target.value)}
      />
      <LabelInput
        label="IBM Project ID"
        helperText="The project ID for your watsonx.ai account."
        id="project-id"
        required
-        placeholder="..."
+        placeholder="your-project-id"
+        value={projectId}
+        onChange={(e) => setProjectId(e.target.value)}
      />
      <AdvancedOnboarding
        icon={<IBMLogo className="w-4 h-4" />}

View file

@@ -1,8 +1,9 @@
-import { useState } from "react";
+import { useState, useEffect } from "react";
import { LabelInput } from "@/components/label-input";
import { LabelWrapper } from "@/components/label-wrapper";
import OllamaLogo from "@/components/logo/ollama-logo";
import type { Settings } from "../api/queries/useGetSettingsQuery";
+import { useGetOllamaModelsQuery } from "../api/queries/useGetModelsQuery";
import { AdvancedOnboarding } from "./advanced";
import { ModelSelector } from "./model-selector";
@@ -17,24 +18,43 @@ export function OllamaOnboarding({
  sampleDataset: boolean;
  setSampleDataset: (dataset: boolean) => void;
}) {
-  const [open, setOpen] = useState(false);
-  const [value, setValue] = useState("");
-  const [languageModel, setLanguageModel] = useState("gpt-oss");
-  const [embeddingModel, setEmbeddingModel] = useState(
-    "text-embedding-3-small",
-  );
+  const [endpoint, setEndpoint] = useState("");
+  const [languageModel, setLanguageModel] = useState("llama3.2");
+  const [embeddingModel, setEmbeddingModel] = useState("nomic-embed-text");
+
+  // Fetch models from API when endpoint is provided
+  const { data: modelsData } = useGetOllamaModelsQuery(
+    endpoint ? { endpoint } : undefined,
+    { enabled: !!endpoint }
+  );
+
-  const languageModels = [
-    { value: "gpt-oss", label: "gpt-oss", default: true },
+  // Use fetched models or fallback to defaults
+  const languageModels = modelsData?.language_models || [
+    { value: "llama3.2", label: "llama3.2", default: true },
    { value: "llama3.1", label: "llama3.1" },
-    { value: "llama3.2", label: "llama3.2" },
-    { value: "llama3.3", label: "llama3.3" },
-    { value: "llama3.4", label: "llama3.4" },
-    { value: "llama3.5", label: "llama3.5" },
+    { value: "llama3", label: "llama3" },
+    { value: "mistral", label: "mistral" },
+    { value: "codellama", label: "codellama" },
  ];
-  const embeddingModels = [
-    { value: "text-embedding-3-small", label: "text-embedding-3-small" },
+  const embeddingModels = modelsData?.embedding_models || [
+    { value: "nomic-embed-text", label: "nomic-embed-text", default: true },
  ];
+
+  // Update default selections when models are loaded
+  useEffect(() => {
+    if (modelsData) {
+      const defaultLangModel = modelsData.language_models.find(m => m.default);
+      const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
+      if (defaultLangModel) {
+        setLanguageModel(defaultLangModel.value);
+      }
+      if (defaultEmbedModel) {
+        setEmbeddingModel(defaultEmbedModel.value);
+      }
+    }
+  }, [modelsData]);
+
  const handleSampleDatasetChange = (dataset: boolean) => {
    setSampleDataset(dataset);
  };
@@ -45,7 +65,9 @@ export function OllamaOnboarding({
        helperText="The endpoint for your Ollama server."
        id="api-endpoint"
        required
-        placeholder="http://..."
+        placeholder="http://localhost:11434"
+        value={endpoint}
+        onChange={(e) => setEndpoint(e.target.value)}
      />
      <LabelWrapper
        label="Embedding model"

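Editor's note: because the Ollama query above is keyed on the raw endpoint string and enabled as soon as it is non-empty, every keystroke in the endpoint field can trigger a new request. A minimal debounce sketch; the useDebouncedValue helper is hypothetical and not part of this commit:

```tsx
import { useEffect, useState } from "react";

// Hypothetical helper: trail the raw input by delayMs so queries keyed on
// the value only re-run once typing pauses.
export function useDebouncedValue<T>(value: T, delayMs = 500): T {
  const [debounced, setDebounced] = useState(value);
  useEffect(() => {
    const id = setTimeout(() => setDebounced(value), delayMs);
    return () => clearTimeout(id); // cancel the pending update on each change
  }, [value, delayMs]);
  return debounced;
}

// Possible usage inside OllamaOnboarding:
//   const debouncedEndpoint = useDebouncedValue(endpoint);
//   const { data: modelsData } = useGetOllamaModelsQuery(
//     debouncedEndpoint ? { endpoint: debouncedEndpoint } : undefined,
//     { enabled: !!debouncedEndpoint },
//   );
```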
View file

@@ -1,7 +1,8 @@
-import { useState } from "react";
+import { useState, useEffect } from "react";
import { LabelInput } from "@/components/label-input";
import OpenAILogo from "@/components/logo/openai-logo";
import type { Settings } from "../api/queries/useGetSettingsQuery";
+import { useGetOpenAIModelsQuery } from "../api/queries/useGetModelsQuery";
import { AdvancedOnboarding } from "./advanced";
export function OpenAIOnboarding({
@@ -19,10 +20,30 @@ export function OpenAIOnboarding({
  const [embeddingModel, setEmbeddingModel] = useState(
    "text-embedding-3-small",
  );
-  const languageModels = [{ value: "gpt-4o-mini", label: "gpt-4o-mini" }];
-  const embeddingModels = [
-    { value: "text-embedding-3-small", label: "text-embedding-3-small" },
+
+  // Fetch models from API
+  const { data: modelsData } = useGetOpenAIModelsQuery();
+
+  // Use fetched models or fallback to defaults
+  const languageModels = modelsData?.language_models || [{ value: "gpt-4o-mini", label: "gpt-4o-mini", default: true }];
+  const embeddingModels = modelsData?.embedding_models || [
+    { value: "text-embedding-3-small", label: "text-embedding-3-small", default: true },
  ];
+
+  // Update default selections when models are loaded
+  useEffect(() => {
+    if (modelsData) {
+      const defaultLangModel = modelsData.language_models.find(m => m.default);
+      const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
+      if (defaultLangModel && languageModel === "gpt-4o-mini") {
+        setLanguageModel(defaultLangModel.value);
+      }
+      if (defaultEmbedModel && embeddingModel === "text-embedding-3-small") {
+        setEmbeddingModel(defaultEmbedModel.value);
+      }
+    }
+  }, [modelsData, languageModel, embeddingModel]);
+
  const handleLanguageModelChange = (model: string) => {
    setLanguageModel(model);
  };

src/api/models.py (new file, 63 lines)
View file

@@ -0,0 +1,63 @@
from starlette.responses import JSONResponse

from utils.logging_config import get_logger

logger = get_logger(__name__)


async def get_openai_models(request, models_service, session_manager):
    """Get available OpenAI models"""
    try:
        models = await models_service.get_openai_models()
        return JSONResponse(models)
    except Exception as e:
        logger.error(f"Failed to get OpenAI models: {str(e)}")
        return JSONResponse(
            {"error": f"Failed to retrieve OpenAI models: {str(e)}"},
            status_code=500
        )


async def get_ollama_models(request, models_service, session_manager):
    """Get available Ollama models"""
    try:
        # Get endpoint from query parameters if provided
        query_params = dict(request.query_params)
        endpoint = query_params.get("endpoint")
        models = await models_service.get_ollama_models(endpoint=endpoint)
        return JSONResponse(models)
    except Exception as e:
        logger.error(f"Failed to get Ollama models: {str(e)}")
        return JSONResponse(
            {"error": f"Failed to retrieve Ollama models: {str(e)}"},
            status_code=500
        )


async def get_ibm_models(request, models_service, session_manager):
    """Get available IBM Watson models"""
    try:
        # Get credentials from query parameters or request body if provided
        if request.method == "POST":
            body = await request.json()
            api_key = body.get("api_key")
            endpoint = body.get("endpoint")
            project_id = body.get("project_id")
        else:
            query_params = dict(request.query_params)
            api_key = query_params.get("api_key")
            endpoint = query_params.get("endpoint")
            project_id = query_params.get("project_id")
        models = await models_service.get_ibm_models(
            api_key=api_key,
            endpoint=endpoint,
            project_id=project_id
        )
        return JSONResponse(models)
    except Exception as e:
        logger.error(f"Failed to get IBM models: {str(e)}")
        return JSONResponse(
            {"error": f"Failed to retrieve IBM models: {str(e)}"},
            status_code=500
        )
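Editor's note: these handlers back the `/models/*` routes registered in the next file, which the frontend reaches under its `/api` prefix (matching the paths in useGetModelsQuery.ts). A hedged sketch of the request/response contract from a TypeScript client; the credential values are placeholders, not defaults:

```ts
interface ModelOption {
  value: string;
  label: string;
  default?: boolean;
}
interface ModelsResponse {
  language_models: ModelOption[];
  embedding_models: ModelOption[];
}

async function fetchAllProviderModels(): Promise<void> {
  // GET with no parameters: the server falls back to env vars or its
  // bundled default lists.
  const openai: ModelsResponse = await (await fetch("/api/models/openai")).json();

  // GET with a user-supplied Ollama endpoint as a query parameter.
  const ollamaUrl =
    "/api/models/ollama?endpoint=" + encodeURIComponent("http://localhost:11434");
  const ollama: ModelsResponse = await (await fetch(ollamaUrl)).json();

  // POST so watsonx credentials travel in the body rather than the URL.
  const ibm: ModelsResponse = await (
    await fetch("/api/models/ibm", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        api_key: "your-api-key",
        endpoint: "https://us-south.ml.cloud.ibm.com",
        project_id: "your-project-id",
      }),
    })
  ).json();

  console.log(openai.language_models, ollama.language_models, ibm.language_models);
}
```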

View file

@@ -33,6 +33,7 @@ from api import (
    flows,
    knowledge_filter,
    langflow_files,
+    models,
    nudges,
    oidc,
    router,
@@ -66,6 +67,7 @@ from services.knowledge_filter_service import KnowledgeFilterService
# Configuration and setup
# Services
from services.langflow_file_service import LangflowFileService
+from services.models_service import ModelsService
from services.monitor_service import MonitorService
from services.search_service import SearchService
from services.task_service import TaskService
@@ -409,6 +411,7 @@ async def initialize_services():
    chat_service = ChatService()
    flows_service = FlowsService()
    knowledge_filter_service = KnowledgeFilterService(session_manager)
+    models_service = ModelsService()
    monitor_service = MonitorService(session_manager)

    # Set process pool for document service
@@ -470,6 +473,7 @@
        "auth_service": auth_service,
        "connector_service": connector_service,
        "knowledge_filter_service": knowledge_filter_service,
+        "models_service": models_service,
        "monitor_service": monitor_service,
        "session_manager": session_manager,
    }
@@ -909,6 +913,40 @@ async def create_app():
            ),
            methods=["POST"],
        ),
+        # Models endpoints
+        Route(
+            "/models/openai",
+            require_auth(services["session_manager"])(
+                partial(
+                    models.get_openai_models,
+                    models_service=services["models_service"],
+                    session_manager=services["session_manager"]
+                )
+            ),
+            methods=["GET"],
+        ),
+        Route(
+            "/models/ollama",
+            require_auth(services["session_manager"])(
+                partial(
+                    models.get_ollama_models,
+                    models_service=services["models_service"],
+                    session_manager=services["session_manager"]
+                )
+            ),
+            methods=["GET"],
+        ),
+        Route(
+            "/models/ibm",
+            require_auth(services["session_manager"])(
+                partial(
+                    models.get_ibm_models,
+                    models_service=services["models_service"],
+                    session_manager=services["session_manager"]
+                )
+            ),
+            methods=["GET", "POST"],
+        ),
        # Onboarding endpoint
        Route(
            "/onboarding",
View file

@@ -0,0 +1,317 @@
import httpx
import os
from typing import Dict, List, Optional

from utils.logging_config import get_logger

logger = get_logger(__name__)
class ModelsService:
    """Service for fetching available models from different AI providers"""

    def __init__(self):
        self.session_manager = None

    async def get_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Fetch available models from OpenAI API"""
        try:
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                logger.warning("OPENAI_API_KEY not set, using default models")
                return self._get_default_openai_models()
            headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.openai.com/v1/models", headers=headers, timeout=10.0
                )
                if response.status_code == 200:
                    data = response.json()
                    models = data.get("data", [])
                    # Filter for relevant models
                    language_models = []
                    embedding_models = []
                    for model in models:
                        model_id = model.get("id", "")
                        # Language models (GPT models)
                        if any(prefix in model_id for prefix in ["gpt-4", "gpt-3.5"]):
                            language_models.append(
                                {
                                    "value": model_id,
                                    "label": model_id,
                                    "default": model_id == "gpt-4o-mini",
                                }
                            )
                        # Embedding models
                        elif "text-embedding" in model_id:
                            embedding_models.append(
                                {
                                    "value": model_id,
                                    "label": model_id,
                                    "default": model_id == "text-embedding-3-small",
                                }
                            )
                    # Sort by name and ensure defaults are first
                    language_models.sort(
                        key=lambda x: (not x.get("default", False), x["value"])
                    )
                    embedding_models.sort(
                        key=lambda x: (not x.get("default", False), x["value"])
                    )
                    return {
                        "language_models": language_models,
                        "embedding_models": embedding_models,
                    }
                else:
                    logger.warning(f"Failed to fetch OpenAI models: {response.status_code}")
                    return self._get_default_openai_models()
        except Exception as e:
            logger.error(f"Error fetching OpenAI models: {str(e)}")
            return self._get_default_openai_models()
    async def get_ollama_models(
        self, endpoint: Optional[str] = None
    ) -> Dict[str, List[Dict[str, str]]]:
        """Fetch available models from Ollama API"""
        try:
            # Use provided endpoint or default
            ollama_url = endpoint or os.getenv(
                "OLLAMA_BASE_URL", "http://localhost:11434"
            )
            async with httpx.AsyncClient() as client:
                response = await client.get(f"{ollama_url}/api/tags", timeout=10.0)
                if response.status_code == 200:
                    data = response.json()
                    models = data.get("models", [])
                    # Extract model names
                    language_models = []
                    embedding_models = []
                    for model in models:
                        model_name = model.get("name", "").split(":")[
                            0
                        ]  # Remove tag if present
                        if model_name:
                            # Most Ollama models can be used as language models
                            language_models.append(
                                {
                                    "value": model_name,
                                    "label": model_name,
                                    "default": "llama3" in model_name.lower(),
                                }
                            )
                            # Some models are specifically for embeddings
                            if any(
                                embed in model_name.lower()
                                for embed in ["embed", "sentence", "all-minilm"]
                            ):
                                embedding_models.append(
                                    {
                                        "value": model_name,
                                        "label": model_name,
                                        "default": False,
                                    }
                                )
                    # Remove duplicates and sort
                    language_models = list(
                        {m["value"]: m for m in language_models}.values()
                    )
                    embedding_models = list(
                        {m["value"]: m for m in embedding_models}.values()
                    )
                    language_models.sort(
                        key=lambda x: (not x.get("default", False), x["value"])
                    )
                    embedding_models.sort(key=lambda x: x["value"])
                    return {
                        "language_models": language_models,
                        "embedding_models": embedding_models
                        if embedding_models
                        else [
                            {
                                "value": "nomic-embed-text",
                                "label": "nomic-embed-text",
                                "default": True,
                            }
                        ],
                    }
                else:
                    logger.warning(f"Failed to fetch Ollama models: {response.status_code}")
                    return self._get_default_ollama_models()
        except Exception as e:
            logger.error(f"Error fetching Ollama models: {str(e)}")
            return self._get_default_ollama_models()
    async def get_ibm_models(
        self,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        project_id: Optional[str] = None,
    ) -> Dict[str, List[Dict[str, str]]]:
        """Fetch available models from IBM Watson API"""
        # api_key and project_id are accepted so the keyword arguments passed
        # by api/models.py do not raise a TypeError; only the endpoint is
        # needed to query the foundation_model_specs listing here.
        try:
            # Use provided endpoint or default
            watson_endpoint = endpoint or os.getenv(
                "IBM_WATSON_ENDPOINT", "https://us-south.ml.cloud.ibm.com"
            )
            # Fetch foundation models using the correct endpoint
            models_url = f"{watson_endpoint}/ml/v1/foundation_model_specs"
            language_models = []
            embedding_models = []
            async with httpx.AsyncClient() as client:
                # Fetch text chat models
                text_params = {
                    "version": "2024-09-16",
                    "filters": "function_text_chat,!lifecycle_withdrawn",
                }
                text_response = await client.get(
                    models_url, params=text_params, timeout=10.0
                )
                if text_response.status_code == 200:
                    text_data = text_response.json()
                    text_models = text_data.get("resources", [])
                    for i, model in enumerate(text_models):
                        model_id = model.get("model_id", "")
                        model_name = model.get("name", model_id)
                        language_models.append(
                            {
                                "value": model_id,
                                "label": model_name or model_id,
                                "default": i == 0,  # First model is default
                            }
                        )
                # Fetch embedding models
                embed_params = {
                    "version": "2024-09-16",
                    "filters": "function_embedding,!lifecycle_withdrawn",
                }
                embed_response = await client.get(
                    models_url, params=embed_params, timeout=10.0
                )
                if embed_response.status_code == 200:
                    embed_data = embed_response.json()
                    embed_models = embed_data.get("resources", [])
                    for i, model in enumerate(embed_models):
                        model_id = model.get("model_id", "")
                        model_name = model.get("name", model_id)
                        embedding_models.append(
                            {
                                "value": model_id,
                                "label": model_name or model_id,
                                "default": i == 0,  # First model is default
                            }
                        )
            return {
                "language_models": language_models
                if language_models
                else self._get_default_ibm_models()["language_models"],
                "embedding_models": embedding_models
                if embedding_models
                else self._get_default_ibm_models()["embedding_models"],
            }
        except Exception as e:
            logger.error(f"Error fetching IBM models: {str(e)}")
            return self._get_default_ibm_models()
    def _get_default_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Default OpenAI models when API is not available"""
        return {
            "language_models": [
                {"value": "gpt-4o-mini", "label": "gpt-4o-mini", "default": True},
                {"value": "gpt-4o", "label": "gpt-4o", "default": False},
                {"value": "gpt-4-turbo", "label": "gpt-4-turbo", "default": False},
                {"value": "gpt-3.5-turbo", "label": "gpt-3.5-turbo", "default": False},
            ],
            "embedding_models": [
                {
                    "value": "text-embedding-3-small",
                    "label": "text-embedding-3-small",
                    "default": True,
                },
                {
                    "value": "text-embedding-3-large",
                    "label": "text-embedding-3-large",
                    "default": False,
                },
                {
                    "value": "text-embedding-ada-002",
                    "label": "text-embedding-ada-002",
                    "default": False,
                },
            ],
        }

    def _get_default_ollama_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Default Ollama models when API is not available"""
        return {
            "language_models": [
                {"value": "llama3.2", "label": "llama3.2", "default": True},
                {"value": "llama3.1", "label": "llama3.1", "default": False},
                {"value": "llama3", "label": "llama3", "default": False},
                {"value": "mistral", "label": "mistral", "default": False},
                {"value": "codellama", "label": "codellama", "default": False},
            ],
            "embedding_models": [
                {
                    "value": "nomic-embed-text",
                    "label": "nomic-embed-text",
                    "default": True,
                },
                {"value": "all-minilm", "label": "all-minilm", "default": False},
            ],
        }

    def _get_default_ibm_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Default IBM Watson models when API is not available"""
        return {
            "language_models": [
                {
                    "value": "meta-llama/llama-3-1-70b-instruct",
                    "label": "Llama 3.1 70B Instruct",
                    "default": True,
                },
                {
                    "value": "meta-llama/llama-3-1-8b-instruct",
                    "label": "Llama 3.1 8B Instruct",
                    "default": False,
                },
                {
                    "value": "ibm/granite-13b-chat-v2",
                    "label": "Granite 13B Chat v2",
                    "default": False,
                },
                {
                    "value": "ibm/granite-13b-instruct-v2",
                    "label": "Granite 13B Instruct v2",
                    "default": False,
                },
            ],
            "embedding_models": [
                {
                    "value": "ibm/slate-125m-english-rtrvr",
                    "label": "Slate 125M English Retriever",
                    "default": True,
                },
                {
                    "value": "sentence-transformers/all-minilm-l12-v2",
                    "label": "All-MiniLM L12 v2",
                    "default": False,
                },
            ],
        }
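Editor's note: each provider method sorts entries so that `default: True` comes first (via `key=lambda x: (not x.get("default", False), x["value"])`) and falls back to a static list on any failure, so the frontend can always render something. A small TypeScript sketch of the selection logic the onboarding components rely on; `pickDefault` is illustrative, not part of this commit:

```ts
interface ModelOption {
  value: string;
  label: string;
  default?: boolean;
}

// Pick the flagged default, otherwise the first entry (the Python side
// already sorts defaults to the front).
function pickDefault(models: ModelOption[]): string | undefined {
  return (models.find((m) => m.default) ?? models[0])?.value;
}

// Example with the fallback list from _get_default_openai_models:
const sample: ModelOption[] = [
  { value: "gpt-4o-mini", label: "gpt-4o-mini", default: true },
  { value: "gpt-4o", label: "gpt-4o" },
];
console.log(pickDefault(sample)); // "gpt-4o-mini"
```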