Added models fetching

parent 4cbdf6ed19
commit d68a787017

7 changed files with 668 additions and 37 deletions
frontend/src/app/api/queries/useGetModelsQuery.ts (new file, 140 lines)
@@ -0,0 +1,140 @@
import {
  type UseQueryOptions,
  useQuery,
  useQueryClient,
} from "@tanstack/react-query";

export interface ModelOption {
  value: string;
  label: string;
  default?: boolean;
}

export interface ModelsResponse {
  language_models: ModelOption[];
  embedding_models: ModelOption[];
}

export interface OllamaModelsParams {
  endpoint?: string;
}

export interface IBMModelsParams {
  api_key?: string;
  endpoint?: string;
  project_id?: string;
}

export const useGetOpenAIModelsQuery = (
  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
) => {
  const queryClient = useQueryClient();

  async function getOpenAIModels(): Promise<ModelsResponse> {
    const response = await fetch("/api/models/openai");
    if (response.ok) {
      return await response.json();
    } else {
      throw new Error("Failed to fetch OpenAI models");
    }
  }

  const queryResult = useQuery(
    {
      queryKey: ["models", "openai"],
      queryFn: getOpenAIModels,
      staleTime: 5 * 60 * 1000, // 5 minutes
      ...options,
    },
    queryClient,
  );

  return queryResult;
};

export const useGetOllamaModelsQuery = (
  params?: OllamaModelsParams,
  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
) => {
  const queryClient = useQueryClient();

  async function getOllamaModels(): Promise<ModelsResponse> {
    const url = new URL("/api/models/ollama", window.location.origin);
    if (params?.endpoint) {
      url.searchParams.set("endpoint", params.endpoint);
    }

    const response = await fetch(url.toString());
    if (response.ok) {
      return await response.json();
    } else {
      throw new Error("Failed to fetch Ollama models");
    }
  }

  const queryResult = useQuery(
    {
      queryKey: ["models", "ollama", params],
      queryFn: getOllamaModels,
      staleTime: 5 * 60 * 1000, // 5 minutes
      enabled: !!params?.endpoint, // Only run if endpoint is provided
      ...options,
    },
    queryClient,
  );

  return queryResult;
};

export const useGetIBMModelsQuery = (
  params?: IBMModelsParams,
  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
) => {
  const queryClient = useQueryClient();

  async function getIBMModels(): Promise<ModelsResponse> {
    const url = "/api/models/ibm";

    // If we have credentials, use POST to send them securely
    if (params?.api_key || params?.endpoint || params?.project_id) {
      const response = await fetch(url, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          api_key: params.api_key,
          endpoint: params.endpoint,
          project_id: params.project_id,
        }),
      });

      if (response.ok) {
        return await response.json();
      } else {
        throw new Error("Failed to fetch IBM models");
      }
    } else {
      // Use GET for default models
      const response = await fetch(url);
      if (response.ok) {
        return await response.json();
      } else {
        throw new Error("Failed to fetch IBM models");
      }
    }
  }

  const queryResult = useQuery(
    {
      queryKey: ["models", "ibm", params],
      queryFn: getIBMModels,
      staleTime: 5 * 60 * 1000, // 5 minutes
      enabled: !!(params?.api_key && params?.endpoint && params?.project_id), // Only run if all credentials are provided
      ...options,
    },
    queryClient,
  );

  return queryResult;
};
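A minimal consumer sketch (not part of this commit) showing how these hooks are meant to be wired up; the component name and markup below are illustrative assumptions, not code from the repository:

// Hypothetical consumer of useGetOllamaModelsQuery. The query stays idle
// until an endpoint is typed, because the hook sets `enabled: !!params?.endpoint`.
import { useState } from "react";
import { useGetOllamaModelsQuery } from "./useGetModelsQuery";

export function OllamaModelPicker() {
  const [endpoint, setEndpoint] = useState("");
  const { data, isLoading, isError } = useGetOllamaModelsQuery(
    endpoint ? { endpoint } : undefined,
  );

  return (
    <div>
      <input
        placeholder="http://localhost:11434"
        value={endpoint}
        onChange={(e) => setEndpoint(e.target.value)}
      />
      {isLoading && <span>Loading models…</span>}
      {isError && <span>Could not reach the Ollama server.</span>}
      <select>
        {(data?.language_models ?? []).map((m) => (
          <option key={m.value} value={m.value}>
            {m.label}
          </option>
        ))}
      </select>
    </div>
  );
}

The onboarding components changed below follow this same pattern: hold the endpoint/credentials in state, pass them to the hook, and fall back to a hardcoded list while no data has arrived.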
IBM onboarding component
@@ -1,7 +1,8 @@
-import { useState } from "react";
+import { useState, useEffect } from "react";
 import { LabelInput } from "@/components/label-input";
 import IBMLogo from "@/components/logo/ibm-logo";
 import type { Settings } from "../api/queries/useGetSettingsQuery";
+import { useGetIBMModelsQuery } from "../api/queries/useGetModelsQuery";
 import { AdvancedOnboarding } from "./advanced";

 export function IBMOnboarding({
@@ -15,21 +16,44 @@ export function IBMOnboarding({
   sampleDataset: boolean;
   setSampleDataset: (dataset: boolean) => void;
 }) {
-  const languageModels = [
-    { value: "gpt-oss", label: "gpt-oss" },
-    { value: "llama3.1", label: "llama3.1" },
-    { value: "llama3.2", label: "llama3.2" },
-    { value: "llama3.3", label: "llama3.3" },
-    { value: "llama3.4", label: "llama3.4" },
-    { value: "llama3.5", label: "llama3.5" },
-  ];
-  const embeddingModels = [
-    { value: "text-embedding-3-small", label: "text-embedding-3-small" },
-  ];
-  const [languageModel, setLanguageModel] = useState("gpt-oss");
-  const [embeddingModel, setEmbeddingModel] = useState(
-    "text-embedding-3-small",
-  );
+  const [endpoint, setEndpoint] = useState("");
+  const [apiKey, setApiKey] = useState("");
+  const [projectId, setProjectId] = useState("");
+  const [languageModel, setLanguageModel] = useState("meta-llama/llama-3-1-70b-instruct");
+  const [embeddingModel, setEmbeddingModel] = useState("ibm/slate-125m-english-rtrvr");
+
+  // Fetch models from API when all credentials are provided
+  const { data: modelsData } = useGetIBMModelsQuery(
+    (apiKey && endpoint && projectId) ? { api_key: apiKey, endpoint, project_id: projectId } : undefined,
+    { enabled: !!(apiKey && endpoint && projectId) }
+  );
+
+  // Use fetched models or fallback to defaults
+  const languageModels = modelsData?.language_models || [
+    { value: "meta-llama/llama-3-1-70b-instruct", label: "Llama 3.1 70B Instruct", default: true },
+    { value: "meta-llama/llama-3-1-8b-instruct", label: "Llama 3.1 8B Instruct" },
+    { value: "ibm/granite-13b-chat-v2", label: "Granite 13B Chat v2" },
+    { value: "ibm/granite-13b-instruct-v2", label: "Granite 13B Instruct v2" },
+  ];
+  const embeddingModels = modelsData?.embedding_models || [
+    { value: "ibm/slate-125m-english-rtrvr", label: "Slate 125M English Retriever", default: true },
+    { value: "sentence-transformers/all-minilm-l12-v2", label: "All-MiniLM L12 v2" },
+  ];
+
+  // Update default selections when models are loaded
+  useEffect(() => {
+    if (modelsData) {
+      const defaultLangModel = modelsData.language_models.find(m => m.default);
+      const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
+
+      if (defaultLangModel) {
+        setLanguageModel(defaultLangModel.value);
+      }
+      if (defaultEmbedModel) {
+        setEmbeddingModel(defaultEmbedModel.value);
+      }
+    }
+  }, [modelsData]);
   const handleLanguageModelChange = (model: string) => {
     setLanguageModel(model);
   };
@@ -48,21 +72,27 @@ export function IBMOnboarding({
           helperText="The API endpoint for your watsonx.ai account."
           id="api-endpoint"
           required
-          placeholder="https://..."
+          placeholder="https://us-south.ml.cloud.ibm.com"
+          value={endpoint}
+          onChange={(e) => setEndpoint(e.target.value)}
         />
         <LabelInput
           label="IBM API key"
           helperText="The API key for your watsonx.ai account."
          id="api-key"
           required
-          placeholder="sk-..."
+          placeholder="your-api-key"
+          value={apiKey}
+          onChange={(e) => setApiKey(e.target.value)}
        />
         <LabelInput
           label="IBM Project ID"
           helperText="The project ID for your watsonx.ai account."
           id="project-id"
           required
-          placeholder="..."
+          placeholder="your-project-id"
+          value={projectId}
+          onChange={(e) => setProjectId(e.target.value)}
         />
         <AdvancedOnboarding
           icon={<IBMLogo className="w-4 h-4" />}
Ollama onboarding component
@@ -1,8 +1,9 @@
-import { useState } from "react";
+import { useState, useEffect } from "react";
 import { LabelInput } from "@/components/label-input";
 import { LabelWrapper } from "@/components/label-wrapper";
 import OllamaLogo from "@/components/logo/ollama-logo";
 import type { Settings } from "../api/queries/useGetSettingsQuery";
+import { useGetOllamaModelsQuery } from "../api/queries/useGetModelsQuery";
 import { AdvancedOnboarding } from "./advanced";
 import { ModelSelector } from "./model-selector";

@@ -17,24 +18,43 @@ export function OllamaOnboarding({
   sampleDataset: boolean;
   setSampleDataset: (dataset: boolean) => void;
 }) {
   const [open, setOpen] = useState(false);
   const [value, setValue] = useState("");
-  const [languageModel, setLanguageModel] = useState("gpt-oss");
-  const [embeddingModel, setEmbeddingModel] = useState(
-    "text-embedding-3-small",
-  );
+  const [endpoint, setEndpoint] = useState("");
+  const [languageModel, setLanguageModel] = useState("llama3.2");
+  const [embeddingModel, setEmbeddingModel] = useState("nomic-embed-text");
+
+  // Fetch models from API when endpoint is provided
+  const { data: modelsData } = useGetOllamaModelsQuery(
+    endpoint ? { endpoint } : undefined,
+    { enabled: !!endpoint }
+  );
-  const languageModels = [
-    { value: "gpt-oss", label: "gpt-oss", default: true },
+
+  // Use fetched models or fallback to defaults
+  const languageModels = modelsData?.language_models || [
+    { value: "llama3.2", label: "llama3.2", default: true },
     { value: "llama3.1", label: "llama3.1" },
-    { value: "llama3.2", label: "llama3.2" },
-    { value: "llama3.3", label: "llama3.3" },
-    { value: "llama3.4", label: "llama3.4" },
-    { value: "llama3.5", label: "llama3.5" },
+    { value: "llama3", label: "llama3" },
+    { value: "mistral", label: "mistral" },
+    { value: "codellama", label: "codellama" },
   ];
-  const embeddingModels = [
-    { value: "text-embedding-3-small", label: "text-embedding-3-small" },
+  const embeddingModels = modelsData?.embedding_models || [
+    { value: "nomic-embed-text", label: "nomic-embed-text", default: true },
   ];
+
+  // Update default selections when models are loaded
+  useEffect(() => {
+    if (modelsData) {
+      const defaultLangModel = modelsData.language_models.find(m => m.default);
+      const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
+
+      if (defaultLangModel) {
+        setLanguageModel(defaultLangModel.value);
+      }
+      if (defaultEmbedModel) {
+        setEmbeddingModel(defaultEmbedModel.value);
+      }
+    }
+  }, [modelsData]);

   const handleSampleDatasetChange = (dataset: boolean) => {
     setSampleDataset(dataset);
   };
@@ -45,7 +65,9 @@ export function OllamaOnboarding({
           helperText="The endpoint for your Ollama server."
           id="api-endpoint"
           required
-          placeholder="http://..."
+          placeholder="http://localhost:11434"
+          value={endpoint}
+          onChange={(e) => setEndpoint(e.target.value)}
         />
         <LabelWrapper
           label="Embedding model"
OpenAI onboarding component
@@ -1,7 +1,8 @@
-import { useState } from "react";
+import { useState, useEffect } from "react";
 import { LabelInput } from "@/components/label-input";
 import OpenAILogo from "@/components/logo/openai-logo";
 import type { Settings } from "../api/queries/useGetSettingsQuery";
+import { useGetOpenAIModelsQuery } from "../api/queries/useGetModelsQuery";
 import { AdvancedOnboarding } from "./advanced";

 export function OpenAIOnboarding({
@@ -19,10 +20,30 @@ export function OpenAIOnboarding({
   const [embeddingModel, setEmbeddingModel] = useState(
     "text-embedding-3-small",
   );
-  const languageModels = [{ value: "gpt-4o-mini", label: "gpt-4o-mini" }];
-  const embeddingModels = [
-    { value: "text-embedding-3-small", label: "text-embedding-3-small" },
+
+  // Fetch models from API
+  const { data: modelsData } = useGetOpenAIModelsQuery();
+
+  // Use fetched models or fallback to defaults
+  const languageModels = modelsData?.language_models || [{ value: "gpt-4o-mini", label: "gpt-4o-mini", default: true }];
+  const embeddingModels = modelsData?.embedding_models || [
+    { value: "text-embedding-3-small", label: "text-embedding-3-small", default: true },
   ];
+
+  // Update default selections when models are loaded
+  useEffect(() => {
+    if (modelsData) {
+      const defaultLangModel = modelsData.language_models.find(m => m.default);
+      const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
+
+      if (defaultLangModel && languageModel === "gpt-4o-mini") {
+        setLanguageModel(defaultLangModel.value);
+      }
+      if (defaultEmbedModel && embeddingModel === "text-embedding-3-small") {
+        setEmbeddingModel(defaultEmbedModel.value);
+      }
+    }
+  }, [modelsData, languageModel, embeddingModel]);
   const handleLanguageModelChange = (model: string) => {
     setLanguageModel(model);
   };
src/api/models.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from starlette.responses import JSONResponse
from utils.logging_config import get_logger

logger = get_logger(__name__)


async def get_openai_models(request, models_service, session_manager):
    """Get available OpenAI models"""
    try:
        models = await models_service.get_openai_models()
        return JSONResponse(models)
    except Exception as e:
        logger.error(f"Failed to get OpenAI models: {str(e)}")
        return JSONResponse(
            {"error": f"Failed to retrieve OpenAI models: {str(e)}"},
            status_code=500
        )


async def get_ollama_models(request, models_service, session_manager):
    """Get available Ollama models"""
    try:
        # Get endpoint from query parameters if provided
        query_params = dict(request.query_params)
        endpoint = query_params.get("endpoint")

        models = await models_service.get_ollama_models(endpoint=endpoint)
        return JSONResponse(models)
    except Exception as e:
        logger.error(f"Failed to get Ollama models: {str(e)}")
        return JSONResponse(
            {"error": f"Failed to retrieve Ollama models: {str(e)}"},
            status_code=500
        )


async def get_ibm_models(request, models_service, session_manager):
    """Get available IBM Watson models"""
    try:
        # Get credentials from query parameters or request body if provided
        if request.method == "POST":
            body = await request.json()
            api_key = body.get("api_key")
            endpoint = body.get("endpoint")
            project_id = body.get("project_id")
        else:
            query_params = dict(request.query_params)
            api_key = query_params.get("api_key")
            endpoint = query_params.get("endpoint")
            project_id = query_params.get("project_id")

        models = await models_service.get_ibm_models(
            api_key=api_key,
            endpoint=endpoint,
            project_id=project_id
        )
        return JSONResponse(models)
    except Exception as e:
        logger.error(f"Failed to get IBM models: {str(e)}")
        return JSONResponse(
            {"error": f"Failed to retrieve IBM models: {str(e)}"},
            status_code=500
        )
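A sketch of exercising these handlers from a client, assuming the routes are mounted under the same /api prefix the frontend hooks use and that the caller already holds an authenticated session; all credential values are placeholders:

// Illustrative requests only; not part of the commit.
type ModelOption = { value: string; label: string; default?: boolean };
type ModelsResponse = {
  language_models: ModelOption[];
  embedding_models: ModelOption[];
};

async function fetchModelLists(): Promise<void> {
  // GET: the server falls back to env-configured credentials/defaults.
  const openai: ModelsResponse = await (await fetch("/api/models/openai")).json();

  // GET with a query parameter: point the Ollama lookup at a specific server.
  const ollama: ModelsResponse = await (
    await fetch("/api/models/ollama?endpoint=" + encodeURIComponent("http://localhost:11434"))
  ).json();

  // POST: watsonx.ai credentials travel in the JSON body, not the URL.
  const ibm: ModelsResponse = await (
    await fetch("/api/models/ibm", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        api_key: "<your-api-key>",
        endpoint: "https://us-south.ml.cloud.ibm.com",
        project_id: "<your-project-id>",
      }),
    })
  ).json();

  console.log(openai.language_models, ollama.language_models, ibm.language_models);
}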
src/main.py (38 changes)
@@ -33,6 +33,7 @@ from api import (
     flows,
     knowledge_filter,
     langflow_files,
+    models,
     nudges,
     oidc,
     router,
@@ -66,6 +67,7 @@ from services.knowledge_filter_service import KnowledgeFilterService
 # Configuration and setup
 # Services
 from services.langflow_file_service import LangflowFileService
+from services.models_service import ModelsService
 from services.monitor_service import MonitorService
 from services.search_service import SearchService
 from services.task_service import TaskService
@@ -409,6 +411,7 @@ async def initialize_services():
     chat_service = ChatService()
     flows_service = FlowsService()
     knowledge_filter_service = KnowledgeFilterService(session_manager)
+    models_service = ModelsService()
     monitor_service = MonitorService(session_manager)

     # Set process pool for document service
@@ -470,6 +473,7 @@ async def initialize_services():
         "auth_service": auth_service,
         "connector_service": connector_service,
         "knowledge_filter_service": knowledge_filter_service,
+        "models_service": models_service,
         "monitor_service": monitor_service,
         "session_manager": session_manager,
     }
@@ -909,6 +913,40 @@ async def create_app():
             ),
             methods=["POST"],
         ),
+        # Models endpoints
+        Route(
+            "/models/openai",
+            require_auth(services["session_manager"])(
+                partial(
+                    models.get_openai_models,
+                    models_service=services["models_service"],
+                    session_manager=services["session_manager"]
+                )
+            ),
+            methods=["GET"],
+        ),
+        Route(
+            "/models/ollama",
+            require_auth(services["session_manager"])(
+                partial(
+                    models.get_ollama_models,
+                    models_service=services["models_service"],
+                    session_manager=services["session_manager"]
+                )
+            ),
+            methods=["GET"],
+        ),
+        Route(
+            "/models/ibm",
+            require_auth(services["session_manager"])(
+                partial(
+                    models.get_ibm_models,
+                    models_service=services["models_service"],
+                    session_manager=services["session_manager"]
+                )
+            ),
+            methods=["GET", "POST"],
+        ),
         # Onboarding endpoint
         Route(
             "/onboarding",
src/services/models_service.py (new file, 317 lines)
@@ -0,0 +1,317 @@
import httpx
import os
from typing import Dict, List, Optional

from utils.logging_config import get_logger

logger = get_logger(__name__)


class ModelsService:
    """Service for fetching available models from different AI providers"""

    def __init__(self):
        self.session_manager = None

    async def get_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Fetch available models from OpenAI API"""
        try:
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                logger.warning("OPENAI_API_KEY not set, using default models")
                return self._get_default_openai_models()

            headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }

            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.openai.com/v1/models", headers=headers, timeout=10.0
                )

                if response.status_code == 200:
                    data = response.json()
                    models = data.get("data", [])

                    # Filter for relevant models
                    language_models = []
                    embedding_models = []

                    for model in models:
                        model_id = model.get("id", "")

                        # Language models (GPT models)
                        if any(prefix in model_id for prefix in ["gpt-4", "gpt-3.5"]):
                            language_models.append(
                                {
                                    "value": model_id,
                                    "label": model_id,
                                    "default": model_id == "gpt-4o-mini",
                                }
                            )
                        # Embedding models
                        elif "text-embedding" in model_id:
                            embedding_models.append(
                                {
                                    "value": model_id,
                                    "label": model_id,
                                    "default": model_id == "text-embedding-3-small",
                                }
                            )

                    # Sort by name and ensure defaults are first
                    language_models.sort(key=lambda x: (not x.get("default", False), x["value"]))
                    embedding_models.sort(key=lambda x: (not x.get("default", False), x["value"]))

                    return {
                        "language_models": language_models,
                        "embedding_models": embedding_models,
                    }
                else:
                    logger.warning(f"Failed to fetch OpenAI models: {response.status_code}")
                    return self._get_default_openai_models()

        except Exception as e:
            logger.error(f"Error fetching OpenAI models: {str(e)}")
            return self._get_default_openai_models()

    async def get_ollama_models(
        self, endpoint: Optional[str] = None
    ) -> Dict[str, List[Dict[str, str]]]:
        """Fetch available models from Ollama API"""
        try:
            # Use provided endpoint or default
            ollama_url = endpoint or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

            async with httpx.AsyncClient() as client:
                response = await client.get(f"{ollama_url}/api/tags", timeout=10.0)

                if response.status_code == 200:
                    data = response.json()
                    models = data.get("models", [])

                    # Extract model names
                    language_models = []
                    embedding_models = []

                    for model in models:
                        model_name = model.get("name", "").split(":")[0]  # Remove tag if present

                        if model_name:
                            # Most Ollama models can be used as language models
                            language_models.append(
                                {
                                    "value": model_name,
                                    "label": model_name,
                                    "default": "llama3" in model_name.lower(),
                                }
                            )

                            # Some models are specifically for embeddings
                            if any(
                                embed in model_name.lower()
                                for embed in ["embed", "sentence", "all-minilm"]
                            ):
                                embedding_models.append(
                                    {
                                        "value": model_name,
                                        "label": model_name,
                                        "default": False,
                                    }
                                )

                    # Remove duplicates and sort
                    language_models = list({m["value"]: m for m in language_models}.values())
                    embedding_models = list({m["value"]: m for m in embedding_models}.values())

                    language_models.sort(key=lambda x: (not x.get("default", False), x["value"]))
                    embedding_models.sort(key=lambda x: x["value"])

                    return {
                        "language_models": language_models,
                        "embedding_models": embedding_models
                        if embedding_models
                        else [
                            {
                                "value": "nomic-embed-text",
                                "label": "nomic-embed-text",
                                "default": True,
                            }
                        ],
                    }
                else:
                    logger.warning(f"Failed to fetch Ollama models: {response.status_code}")
                    return self._get_default_ollama_models()

        except Exception as e:
            logger.error(f"Error fetching Ollama models: {str(e)}")
            return self._get_default_ollama_models()

    async def get_ibm_models(
        self,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        project_id: Optional[str] = None,
    ) -> Dict[str, List[Dict[str, str]]]:
        """Fetch available models from IBM Watson API"""
        # api_key and project_id are accepted to match the keyword arguments
        # passed by api/models.py; the foundation_model_specs lookup below
        # only needs the endpoint.
        try:
            # Use provided endpoint or default
            watson_endpoint = endpoint or os.getenv("IBM_WATSON_ENDPOINT", "https://us-south.ml.cloud.ibm.com")

            # Fetch foundation models using the correct endpoint
            models_url = f"{watson_endpoint}/ml/v1/foundation_model_specs"

            language_models = []
            embedding_models = []

            async with httpx.AsyncClient() as client:
                # Fetch text chat models
                text_params = {
                    "version": "2024-09-16",
                    "filters": "function_text_chat,!lifecycle_withdrawn",
                }
                text_response = await client.get(models_url, params=text_params, timeout=10.0)

                if text_response.status_code == 200:
                    text_data = text_response.json()
                    text_models = text_data.get("resources", [])

                    for i, model in enumerate(text_models):
                        model_id = model.get("model_id", "")
                        model_name = model.get("name", model_id)

                        language_models.append({
                            "value": model_id,
                            "label": model_name or model_id,
                            "default": i == 0,  # First model is default
                        })

                # Fetch embedding models
                embed_params = {
                    "version": "2024-09-16",
                    "filters": "function_embedding,!lifecycle_withdrawn",
                }
                embed_response = await client.get(models_url, params=embed_params, timeout=10.0)

                if embed_response.status_code == 200:
                    embed_data = embed_response.json()
                    embed_models = embed_data.get("resources", [])

                    for i, model in enumerate(embed_models):
                        model_id = model.get("model_id", "")
                        model_name = model.get("name", model_id)

                        embedding_models.append({
                            "value": model_id,
                            "label": model_name or model_id,
                            "default": i == 0,  # First model is default
                        })

            return {
                "language_models": language_models if language_models else self._get_default_ibm_models()["language_models"],
                "embedding_models": embedding_models if embedding_models else self._get_default_ibm_models()["embedding_models"],
            }

        except Exception as e:
            logger.error(f"Error fetching IBM models: {str(e)}")
            return self._get_default_ibm_models()

    def _get_default_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Default OpenAI models when API is not available"""
        return {
            "language_models": [
                {"value": "gpt-4o-mini", "label": "gpt-4o-mini", "default": True},
                {"value": "gpt-4o", "label": "gpt-4o", "default": False},
                {"value": "gpt-4-turbo", "label": "gpt-4-turbo", "default": False},
                {"value": "gpt-3.5-turbo", "label": "gpt-3.5-turbo", "default": False},
            ],
            "embedding_models": [
                {
                    "value": "text-embedding-3-small",
                    "label": "text-embedding-3-small",
                    "default": True,
                },
                {
                    "value": "text-embedding-3-large",
                    "label": "text-embedding-3-large",
                    "default": False,
                },
                {
                    "value": "text-embedding-ada-002",
                    "label": "text-embedding-ada-002",
                    "default": False,
                },
            ],
        }

    def _get_default_ollama_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Default Ollama models when API is not available"""
        return {
            "language_models": [
                {"value": "llama3.2", "label": "llama3.2", "default": True},
                {"value": "llama3.1", "label": "llama3.1", "default": False},
                {"value": "llama3", "label": "llama3", "default": False},
                {"value": "mistral", "label": "mistral", "default": False},
                {"value": "codellama", "label": "codellama", "default": False},
            ],
            "embedding_models": [
                {
                    "value": "nomic-embed-text",
                    "label": "nomic-embed-text",
                    "default": True,
                },
                {"value": "all-minilm", "label": "all-minilm", "default": False},
            ],
        }

    def _get_default_ibm_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Default IBM Watson models when API is not available"""
        return {
            "language_models": [
                {
                    "value": "meta-llama/llama-3-1-70b-instruct",
                    "label": "Llama 3.1 70B Instruct",
                    "default": True,
                },
                {
                    "value": "meta-llama/llama-3-1-8b-instruct",
                    "label": "Llama 3.1 8B Instruct",
                    "default": False,
                },
                {
                    "value": "ibm/granite-13b-chat-v2",
                    "label": "Granite 13B Chat v2",
                    "default": False,
                },
                {
                    "value": "ibm/granite-13b-instruct-v2",
                    "label": "Granite 13B Instruct v2",
                    "default": False,
                },
            ],
            "embedding_models": [
                {
                    "value": "ibm/slate-125m-english-rtrvr",
                    "label": "Slate 125M English Retriever",
                    "default": True,
                },
                {
                    "value": "sentence-transformers/all-minilm-l12-v2",
                    "label": "All-MiniLM L12 v2",
                    "default": False,
                },
            ],
        }
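For reference, the response shape every provider method above resolves to, shown with illustrative values rather than output captured from a live server:

// Sample payload matching the frontend's ModelsResponse interface.
type ModelOption = { value: string; label: string; default?: boolean };
type ModelsResponse = {
  language_models: ModelOption[];
  embedding_models: ModelOption[];
};

const sampleOllamaResponse: ModelsResponse = {
  language_models: [
    { value: "llama3.2", label: "llama3.2", default: true },
    { value: "mistral", label: "mistral", default: false },
  ],
  embedding_models: [
    { value: "nomic-embed-text", label: "nomic-embed-text", default: true },
  ],
};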