diff --git a/frontend/src/app/api/queries/useGetModelsQuery.ts b/frontend/src/app/api/queries/useGetModelsQuery.ts
new file mode 100644
index 00000000..e94a752b
--- /dev/null
+++ b/frontend/src/app/api/queries/useGetModelsQuery.ts
@@ -0,0 +1,140 @@
+import {
+  type UseQueryOptions,
+  useQuery,
+  useQueryClient,
+} from "@tanstack/react-query";
+
+export interface ModelOption {
+  value: string;
+  label: string;
+  default?: boolean;
+}
+
+export interface ModelsResponse {
+  language_models: ModelOption[];
+  embedding_models: ModelOption[];
+}
+
+export interface OllamaModelsParams {
+  endpoint?: string;
+}
+
+export interface IBMModelsParams {
+  api_key?: string;
+  endpoint?: string;
+  project_id?: string;
+}
+
+export const useGetOpenAIModelsQuery = (
+  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
+) => {
+  const queryClient = useQueryClient();
+
+  async function getOpenAIModels(): Promise<ModelsResponse> {
+    const response = await fetch("/api/models/openai");
+    if (response.ok) {
+      return await response.json();
+    } else {
+      throw new Error("Failed to fetch OpenAI models");
+    }
+  }
+
+  const queryResult = useQuery(
+    {
+      queryKey: ["models", "openai"],
+      queryFn: getOpenAIModels,
+      staleTime: 5 * 60 * 1000, // 5 minutes
+      ...options,
+    },
+    queryClient,
+  );
+
+  return queryResult;
+};
+
+export const useGetOllamaModelsQuery = (
+  params?: OllamaModelsParams,
+  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
+) => {
+  const queryClient = useQueryClient();
+
+  async function getOllamaModels(): Promise<ModelsResponse> {
+    const url = new URL("/api/models/ollama", window.location.origin);
+    if (params?.endpoint) {
+      url.searchParams.set("endpoint", params.endpoint);
+    }
+
+    const response = await fetch(url.toString());
+    if (response.ok) {
+      return await response.json();
+    } else {
+      throw new Error("Failed to fetch Ollama models");
+    }
+  }
+
+  const queryResult = useQuery(
+    {
+      queryKey: ["models", "ollama", params],
+      queryFn: getOllamaModels,
+      staleTime: 5 * 60 * 1000, // 5 minutes
+      enabled: !!params?.endpoint, // Only run if an endpoint is provided
+      ...options,
+    },
+    queryClient,
+  );
+
+  return queryResult;
+};
+
+export const useGetIBMModelsQuery = (
+  params?: IBMModelsParams,
+  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
+) => {
+  const queryClient = useQueryClient();
+
+  async function getIBMModels(): Promise<ModelsResponse> {
+    const url = "/api/models/ibm";
+
+    // If we have credentials, use POST to send them securely
+    if (params?.api_key || params?.endpoint || params?.project_id) {
+      const response = await fetch(url, {
+        method: "POST",
+        headers: {
+          "Content-Type": "application/json",
+        },
+        body: JSON.stringify({
+          api_key: params.api_key,
+          endpoint: params.endpoint,
+          project_id: params.project_id,
+        }),
+      });
+
+      if (response.ok) {
+        return await response.json();
+      } else {
+        throw new Error("Failed to fetch IBM models");
+      }
+    } else {
+      // Use GET for default models
+      const response = await fetch(url);
+      if (response.ok) {
+        return await response.json();
+      } else {
+        throw new Error("Failed to fetch IBM models");
+      }
+    }
+  }
+
+  const queryResult = useQuery(
+    {
+      queryKey: ["models", "ibm", params],
+      queryFn: getIBMModels,
+      staleTime: 5 * 60 * 1000, // 5 minutes
+      enabled: !!(params?.api_key && params?.endpoint && params?.project_id), // Only run if all credentials are provided
+      ...options,
+    },
+    queryClient,
+  );
+
+  return queryResult;
+};
diff --git a/frontend/src/app/onboarding/ibm-onboarding.tsx b/frontend/src/app/onboarding/ibm-onboarding.tsx
index f77e2d99..26b6adeb 100644
--- a/frontend/src/app/onboarding/ibm-onboarding.tsx
+++ b/frontend/src/app/onboarding/ibm-onboarding.tsx
@@ -1,7 +1,8 @@
-import { useState } from "react";
+import { useState, useEffect } from "react";
 import { LabelInput } from "@/components/label-input";
 import IBMLogo from "@/components/logo/ibm-logo";
 import type { Settings } from "../api/queries/useGetSettingsQuery";
+import { useGetIBMModelsQuery } from "../api/queries/useGetModelsQuery";
 import { AdvancedOnboarding } from "./advanced";
 
 export function IBMOnboarding({
@@ -15,21 +16,44 @@ export function IBMOnboarding({
   sampleDataset: boolean;
   setSampleDataset: (dataset: boolean) => void;
 }) {
-  const languageModels = [
-    { value: "gpt-oss", label: "gpt-oss" },
-    { value: "llama3.1", label: "llama3.1" },
-    { value: "llama3.2", label: "llama3.2" },
-    { value: "llama3.3", label: "llama3.3" },
-    { value: "llama3.4", label: "llama3.4" },
-    { value: "llama3.5", label: "llama3.5" },
-  ];
-  const embeddingModels = [
-    { value: "text-embedding-3-small", label: "text-embedding-3-small" },
-  ];
-  const [languageModel, setLanguageModel] = useState("gpt-oss");
-  const [embeddingModel, setEmbeddingModel] = useState(
-    "text-embedding-3-small",
+  const [endpoint, setEndpoint] = useState("");
+  const [apiKey, setApiKey] = useState("");
+  const [projectId, setProjectId] = useState("");
+  const [languageModel, setLanguageModel] = useState("meta-llama/llama-3-1-70b-instruct");
+  const [embeddingModel, setEmbeddingModel] = useState("ibm/slate-125m-english-rtrvr");
+
+  // Fetch models from API when all credentials are provided
+  const { data: modelsData } = useGetIBMModelsQuery(
+    (apiKey && endpoint && projectId) ? { api_key: apiKey, endpoint, project_id: projectId } : undefined,
+    { enabled: !!(apiKey && endpoint && projectId) }
   );
+
+  // Use fetched models or fall back to defaults
+  const languageModels = modelsData?.language_models || [
+    { value: "meta-llama/llama-3-1-70b-instruct", label: "Llama 3.1 70B Instruct", default: true },
+    { value: "meta-llama/llama-3-1-8b-instruct", label: "Llama 3.1 8B Instruct" },
+    { value: "ibm/granite-13b-chat-v2", label: "Granite 13B Chat v2" },
+    { value: "ibm/granite-13b-instruct-v2", label: "Granite 13B Instruct v2" },
+  ];
+  const embeddingModels = modelsData?.embedding_models || [
+    { value: "ibm/slate-125m-english-rtrvr", label: "Slate 125M English Retriever", default: true },
+    { value: "sentence-transformers/all-minilm-l12-v2", label: "All-MiniLM L12 v2" },
+  ];
+
+  // Update default selections when models are loaded
+  useEffect(() => {
+    if (modelsData) {
+      const defaultLangModel = modelsData.language_models.find(m => m.default);
+      const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
+
+      if (defaultLangModel) {
+        setLanguageModel(defaultLangModel.value);
+      }
+      if (defaultEmbedModel) {
+        setEmbeddingModel(defaultEmbedModel.value);
+      }
+    }
+  }, [modelsData]);
   const handleLanguageModelChange = (model: string) => {
     setLanguageModel(model);
   };
@@ -48,21 +72,27 @@ export function IBMOnboarding({
         <LabelInput
           helperText="The API endpoint for your watsonx.ai account."
           id="api-endpoint"
           required
-          placeholder="https://..."
+ placeholder="https://us-south.ml.cloud.ibm.com" + value={endpoint} + onChange={(e) => setEndpoint(e.target.value)} /> setApiKey(e.target.value)} /> setProjectId(e.target.value)} /> } diff --git a/frontend/src/app/onboarding/ollama-onboarding.tsx b/frontend/src/app/onboarding/ollama-onboarding.tsx index 886144f4..2513a8f5 100644 --- a/frontend/src/app/onboarding/ollama-onboarding.tsx +++ b/frontend/src/app/onboarding/ollama-onboarding.tsx @@ -1,8 +1,9 @@ -import { useState } from "react"; +import { useState, useEffect } from "react"; import { LabelInput } from "@/components/label-input"; import { LabelWrapper } from "@/components/label-wrapper"; import OllamaLogo from "@/components/logo/ollama-logo"; import type { Settings } from "../api/queries/useGetSettingsQuery"; +import { useGetOllamaModelsQuery } from "../api/queries/useGetModelsQuery"; import { AdvancedOnboarding } from "./advanced"; import { ModelSelector } from "./model-selector"; @@ -17,24 +18,43 @@ export function OllamaOnboarding({ sampleDataset: boolean; setSampleDataset: (dataset: boolean) => void; }) { - const [open, setOpen] = useState(false); - const [value, setValue] = useState(""); - const [languageModel, setLanguageModel] = useState("gpt-oss"); - const [embeddingModel, setEmbeddingModel] = useState( - "text-embedding-3-small", + const [endpoint, setEndpoint] = useState(""); + const [languageModel, setLanguageModel] = useState("llama3.2"); + const [embeddingModel, setEmbeddingModel] = useState("nomic-embed-text"); + + // Fetch models from API when endpoint is provided + const { data: modelsData } = useGetOllamaModelsQuery( + endpoint ? { endpoint } : undefined, + { enabled: !!endpoint } ); - const languageModels = [ - { value: "gpt-oss", label: "gpt-oss", default: true }, + + // Use fetched models or fallback to defaults + const languageModels = modelsData?.language_models || [ + { value: "llama3.2", label: "llama3.2", default: true }, { value: "llama3.1", label: "llama3.1" }, - { value: "llama3.2", label: "llama3.2" }, - { value: "llama3.3", label: "llama3.3" }, - { value: "llama3.4", label: "llama3.4" }, - { value: "llama3.5", label: "llama3.5" }, + { value: "llama3", label: "llama3" }, + { value: "mistral", label: "mistral" }, + { value: "codellama", label: "codellama" }, ]; - const embeddingModels = [ - { value: "text-embedding-3-small", label: "text-embedding-3-small" }, + const embeddingModels = modelsData?.embedding_models || [ + { value: "nomic-embed-text", label: "nomic-embed-text", default: true }, ]; + // Update default selections when models are loaded + useEffect(() => { + if (modelsData) { + const defaultLangModel = modelsData.language_models.find(m => m.default); + const defaultEmbedModel = modelsData.embedding_models.find(m => m.default); + + if (defaultLangModel) { + setLanguageModel(defaultLangModel.value); + } + if (defaultEmbedModel) { + setEmbeddingModel(defaultEmbedModel.value); + } + } + }, [modelsData]); + const handleSampleDatasetChange = (dataset: boolean) => { setSampleDataset(dataset); }; @@ -45,7 +65,9 @@ export function OllamaOnboarding({ helperText="The endpoint for your Ollama server." id="api-endpoint" required - placeholder="http://..." 
+ placeholder="http://localhost:11434" + value={endpoint} + onChange={(e) => setEndpoint(e.target.value)} /> { + if (modelsData) { + const defaultLangModel = modelsData.language_models.find(m => m.default); + const defaultEmbedModel = modelsData.embedding_models.find(m => m.default); + + if (defaultLangModel && languageModel === "gpt-4o-mini") { + setLanguageModel(defaultLangModel.value); + } + if (defaultEmbedModel && embeddingModel === "text-embedding-3-small") { + setEmbeddingModel(defaultEmbedModel.value); + } + } + }, [modelsData, languageModel, embeddingModel]); const handleLanguageModelChange = (model: string) => { setLanguageModel(model); }; diff --git a/src/api/models.py b/src/api/models.py new file mode 100644 index 00000000..0dc78c2b --- /dev/null +++ b/src/api/models.py @@ -0,0 +1,63 @@ +from starlette.responses import JSONResponse +from utils.logging_config import get_logger + +logger = get_logger(__name__) + + +async def get_openai_models(request, models_service, session_manager): + """Get available OpenAI models""" + try: + models = await models_service.get_openai_models() + return JSONResponse(models) + except Exception as e: + logger.error(f"Failed to get OpenAI models: {str(e)}") + return JSONResponse( + {"error": f"Failed to retrieve OpenAI models: {str(e)}"}, + status_code=500 + ) + + +async def get_ollama_models(request, models_service, session_manager): + """Get available Ollama models""" + try: + # Get endpoint from query parameters if provided + query_params = dict(request.query_params) + endpoint = query_params.get("endpoint") + + models = await models_service.get_ollama_models(endpoint=endpoint) + return JSONResponse(models) + except Exception as e: + logger.error(f"Failed to get Ollama models: {str(e)}") + return JSONResponse( + {"error": f"Failed to retrieve Ollama models: {str(e)}"}, + status_code=500 + ) + + +async def get_ibm_models(request, models_service, session_manager): + """Get available IBM Watson models""" + try: + # Get credentials from query parameters or request body if provided + if request.method == "POST": + body = await request.json() + api_key = body.get("api_key") + endpoint = body.get("endpoint") + project_id = body.get("project_id") + else: + query_params = dict(request.query_params) + api_key = query_params.get("api_key") + endpoint = query_params.get("endpoint") + project_id = query_params.get("project_id") + + models = await models_service.get_ibm_models( + api_key=api_key, + endpoint=endpoint, + project_id=project_id + ) + return JSONResponse(models) + except Exception as e: + logger.error(f"Failed to get IBM models: {str(e)}") + return JSONResponse( + {"error": f"Failed to retrieve IBM models: {str(e)}"}, + status_code=500 + ) \ No newline at end of file diff --git a/src/main.py b/src/main.py index 870d8deb..7a242d1c 100644 --- a/src/main.py +++ b/src/main.py @@ -33,6 +33,7 @@ from api import ( flows, knowledge_filter, langflow_files, + models, nudges, oidc, router, @@ -66,6 +67,7 @@ from services.knowledge_filter_service import KnowledgeFilterService # Configuration and setup # Services from services.langflow_file_service import LangflowFileService +from services.models_service import ModelsService from services.monitor_service import MonitorService from services.search_service import SearchService from services.task_service import TaskService @@ -409,6 +411,7 @@ async def initialize_services(): chat_service = ChatService() flows_service = FlowsService() knowledge_filter_service = KnowledgeFilterService(session_manager) + 
+    models_service = ModelsService()
     monitor_service = MonitorService(session_manager)
 
     # Set process pool for document service
@@ -470,6 +473,7 @@
         "auth_service": auth_service,
         "connector_service": connector_service,
         "knowledge_filter_service": knowledge_filter_service,
+        "models_service": models_service,
         "monitor_service": monitor_service,
         "session_manager": session_manager,
     }
@@ -909,6 +913,40 @@
             ),
             methods=["POST"],
         ),
+        # Models endpoints
+        Route(
+            "/models/openai",
+            require_auth(services["session_manager"])(
+                partial(
+                    models.get_openai_models,
+                    models_service=services["models_service"],
+                    session_manager=services["session_manager"]
+                )
+            ),
+            methods=["GET"],
+        ),
+        Route(
+            "/models/ollama",
+            require_auth(services["session_manager"])(
+                partial(
+                    models.get_ollama_models,
+                    models_service=services["models_service"],
+                    session_manager=services["session_manager"]
+                )
+            ),
+            methods=["GET"],
+        ),
+        Route(
+            "/models/ibm",
+            require_auth(services["session_manager"])(
+                partial(
+                    models.get_ibm_models,
+                    models_service=services["models_service"],
+                    session_manager=services["session_manager"]
+                )
+            ),
+            methods=["GET", "POST"],
+        ),
         # Onboarding endpoint
         Route(
             "/onboarding",
diff --git a/src/services/models_service.py b/src/services/models_service.py
new file mode 100644
index 00000000..9855459e
--- /dev/null
+++ b/src/services/models_service.py
@@ -0,0 +1,317 @@
+import os
+from typing import Dict, List
+
+import httpx
+
+from utils.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+class ModelsService:
+    """Service for fetching available models from different AI providers"""
+
+    def __init__(self):
+        self.session_manager = None
+
+    async def get_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
+        """Fetch available models from OpenAI API"""
+        try:
+            api_key = os.getenv("OPENAI_API_KEY")
+            if not api_key:
+                logger.warning("OPENAI_API_KEY not set, using default models")
+                return self._get_default_openai_models()
+
+            headers = {
+                "Authorization": f"Bearer {api_key}",
+                "Content-Type": "application/json",
+            }
+
+            async with httpx.AsyncClient() as client:
+                response = await client.get(
+                    "https://api.openai.com/v1/models", headers=headers, timeout=10.0
+                )
+
+            if response.status_code == 200:
+                data = response.json()
+                models = data.get("data", [])
+
+                # Filter for relevant models
+                language_models = []
+                embedding_models = []
+
+                for model in models:
+                    model_id = model.get("id", "")
+
+                    # Language models (GPT models)
+                    if any(prefix in model_id for prefix in ["gpt-4", "gpt-3.5"]):
+                        language_models.append(
+                            {
+                                "value": model_id,
+                                "label": model_id,
+                                "default": model_id == "gpt-4o-mini",
+                            }
+                        )
+                    # Embedding models
+                    elif "text-embedding" in model_id:
+                        embedding_models.append(
+                            {
+                                "value": model_id,
+                                "label": model_id,
+                                "default": model_id == "text-embedding-3-small",
+                            }
+                        )
+
+                # Sort by name and ensure defaults are first
+                language_models.sort(
+                    key=lambda x: (not x.get("default", False), x["value"])
+                )
+                embedding_models.sort(
+                    key=lambda x: (not x.get("default", False), x["value"])
+                )
+
+                return {
+                    "language_models": language_models,
+                    "embedding_models": embedding_models,
+                }
+            else:
+                logger.warning(
+                    f"Failed to fetch OpenAI models: {response.status_code}"
+                )
+                return self._get_default_openai_models()
+
+        except Exception as e:
+            logger.error(f"Error fetching OpenAI models: {str(e)}")
+            return self._get_default_openai_models()
+
+    async def get_ollama_models(
+        self, endpoint: str = None
+    ) -> Dict[str, List[Dict[str, str]]]:
+        """Fetch available models from Ollama API"""
+        try:
+            # Use provided endpoint or default
+            ollama_url = endpoint or os.getenv(
+                "OLLAMA_BASE_URL", "http://localhost:11434"
+            )
+
+            async with httpx.AsyncClient() as client:
+                response = await client.get(f"{ollama_url}/api/tags", timeout=10.0)
+
+            if response.status_code == 200:
+                data = response.json()
+                models = data.get("models", [])
+
+                # Extract model names
+                language_models = []
+                embedding_models = []
+
+                for model in models:
+                    # Remove the tag if present, e.g. "llama3:8b" -> "llama3"
+                    model_name = model.get("name", "").split(":")[0]
+
+                    if model_name:
+                        # Most Ollama models can be used as language models
+                        language_models.append(
+                            {
+                                "value": model_name,
+                                "label": model_name,
+                                "default": "llama3" in model_name.lower(),
+                            }
+                        )
+
+                        # Some models are specifically for embeddings
+                        if any(
+                            embed in model_name.lower()
+                            for embed in ["embed", "sentence", "all-minilm"]
+                        ):
+                            embedding_models.append(
+                                {
+                                    "value": model_name,
+                                    "label": model_name,
+                                    "default": False,
+                                }
+                            )
+
+                # Remove duplicates and sort
+                language_models = list(
+                    {m["value"]: m for m in language_models}.values()
+                )
+                embedding_models = list(
+                    {m["value"]: m for m in embedding_models}.values()
+                )
+
+                language_models.sort(
+                    key=lambda x: (not x.get("default", False), x["value"])
+                )
+                embedding_models.sort(key=lambda x: x["value"])
+
+                return {
+                    "language_models": language_models,
+                    "embedding_models": embedding_models
+                    if embedding_models
+                    else [
+                        {
+                            "value": "nomic-embed-text",
+                            "label": "nomic-embed-text",
+                            "default": True,
+                        }
+                    ],
+                }
+            else:
+                logger.warning(
+                    f"Failed to fetch Ollama models: {response.status_code}"
+                )
+                return self._get_default_ollama_models()
+
+        except Exception as e:
+            logger.error(f"Error fetching Ollama models: {str(e)}")
+            return self._get_default_ollama_models()
+
+    async def get_ibm_models(
+        self,
+        api_key: str = None,
+        endpoint: str = None,
+        project_id: str = None,
+    ) -> Dict[str, List[Dict[str, str]]]:
+        """Fetch available models from IBM Watson API"""
+        # api_key and project_id are accepted to match the API route's call;
+        # the foundation_model_specs endpoint itself is queried without them.
+        try:
+            # Use provided endpoint or default
+            watson_endpoint = endpoint or os.getenv(
+                "IBM_WATSON_ENDPOINT", "https://us-south.ml.cloud.ibm.com"
+            )
+
+            # Fetch foundation models using the model specs endpoint
+            models_url = f"{watson_endpoint}/ml/v1/foundation_model_specs"
+
+            language_models = []
+            embedding_models = []
+
+            async with httpx.AsyncClient() as client:
+                # Fetch text chat models
+                text_params = {
+                    "version": "2024-09-16",
+                    "filters": "function_text_chat,!lifecycle_withdrawn",
+                }
+                text_response = await client.get(
+                    models_url, params=text_params, timeout=10.0
+                )
+
+                if text_response.status_code == 200:
+                    text_data = text_response.json()
+                    text_models = text_data.get("resources", [])
+
+                    for i, model in enumerate(text_models):
+                        model_id = model.get("model_id", "")
+                        model_name = model.get("name", model_id)
+
+                        language_models.append(
+                            {
+                                "value": model_id,
+                                "label": model_name or model_id,
+                                "default": i == 0,  # First model is default
+                            }
+                        )
+
+                # Fetch embedding models
+                embed_params = {
+                    "version": "2024-09-16",
+                    "filters": "function_embedding,!lifecycle_withdrawn",
+                }
+                embed_response = await client.get(
+                    models_url, params=embed_params, timeout=10.0
+                )
+
+                if embed_response.status_code == 200:
+                    embed_data = embed_response.json()
+                    embed_models = embed_data.get("resources", [])
+
+                    for i, model in enumerate(embed_models):
+                        model_id = model.get("model_id", "")
+                        model_name = model.get("name", model_id)
+
+                        embedding_models.append(
+                            {
+                                "value": model_id,
+                                "label": model_name or model_id,
+                                "default": i == 0,  # First model is default
+                            }
+                        )
+
+            return {
+                "language_models": language_models
+                if language_models
+                else self._get_default_ibm_models()["language_models"],
+                "embedding_models": embedding_models
+                if embedding_models
+                else self._get_default_ibm_models()["embedding_models"],
+            }
+
+        except Exception as e:
+            logger.error(f"Error fetching IBM models: {str(e)}")
+            return self._get_default_ibm_models()
+
+    def _get_default_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
+        """Default OpenAI models when the API is not available"""
+        return {
+            "language_models": [
+                {"value": "gpt-4o-mini", "label": "gpt-4o-mini", "default": True},
+                {"value": "gpt-4o", "label": "gpt-4o", "default": False},
+                {"value": "gpt-4-turbo", "label": "gpt-4-turbo", "default": False},
+                {"value": "gpt-3.5-turbo", "label": "gpt-3.5-turbo", "default": False},
+            ],
+            "embedding_models": [
+                {
+                    "value": "text-embedding-3-small",
+                    "label": "text-embedding-3-small",
+                    "default": True,
+                },
+                {
+                    "value": "text-embedding-3-large",
+                    "label": "text-embedding-3-large",
+                    "default": False,
+                },
+                {
+                    "value": "text-embedding-ada-002",
+                    "label": "text-embedding-ada-002",
+                    "default": False,
+                },
+            ],
+        }
+
+    def _get_default_ollama_models(self) -> Dict[str, List[Dict[str, str]]]:
+        """Default Ollama models when the API is not available"""
+        return {
+            "language_models": [
+                {"value": "llama3.2", "label": "llama3.2", "default": True},
+                {"value": "llama3.1", "label": "llama3.1", "default": False},
+                {"value": "llama3", "label": "llama3", "default": False},
+                {"value": "mistral", "label": "mistral", "default": False},
+                {"value": "codellama", "label": "codellama", "default": False},
+            ],
+            "embedding_models": [
+                {
+                    "value": "nomic-embed-text",
+                    "label": "nomic-embed-text",
+                    "default": True,
+                },
+                {"value": "all-minilm", "label": "all-minilm", "default": False},
+            ],
+        }
+
+    def _get_default_ibm_models(self) -> Dict[str, List[Dict[str, str]]]:
+        """Default IBM Watson models when the API is not available"""
+        return {
+            "language_models": [
+                {
+                    "value": "meta-llama/llama-3-1-70b-instruct",
+                    "label": "Llama 3.1 70B Instruct",
+                    "default": True,
+                },
+                {
+                    "value": "meta-llama/llama-3-1-8b-instruct",
+                    "label": "Llama 3.1 8B Instruct",
+                    "default": False,
+                },
+                {
+                    "value": "ibm/granite-13b-chat-v2",
+                    "label": "Granite 13B Chat v2",
+                    "default": False,
+                },
+                {
+                    "value": "ibm/granite-13b-instruct-v2",
+                    "label": "Granite 13B Instruct v2",
+                    "default": False,
+                },
+            ],
+            "embedding_models": [
+                {
+                    "value": "ibm/slate-125m-english-rtrvr",
+                    "label": "Slate 125M English Retriever",
+                    "default": True,
+                },
+                {
+                    "value": "sentence-transformers/all-minilm-l12-v2",
+                    "label": "All-MiniLM L12 v2",
+                    "default": False,
+                },
+            ],
+        }
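
Usage note: the three hooks introduced in `useGetModelsQuery.ts` share the same response shape, so a consumer can swap providers without changing its render logic. Below is a minimal sketch of how a component might consume one of them; the `ModelPicker` component, its fallback list, and the `@/app/...` import path are illustrative assumptions, not part of this diff.

```tsx
import { useGetOllamaModelsQuery } from "@/app/api/queries/useGetModelsQuery";

// Hypothetical consumer: the query stays disabled until an endpoint is
// entered, and a static fallback list is shown in the meantime.
export function ModelPicker({ endpoint }: { endpoint: string }) {
  const { data, isError } = useGetOllamaModelsQuery(
    endpoint ? { endpoint } : undefined,
    { enabled: !!endpoint },
  );

  // Fall back to a static default while the fetch is disabled or has failed.
  const languageModels =
    !isError && data?.language_models?.length
      ? data.language_models
      : [{ value: "llama3.2", label: "llama3.2", default: true }];

  return (
    <select defaultValue={languageModels.find((m) => m.default)?.value}>
      {languageModels.map((m) => (
        <option key={m.value} value={m.value}>
          {m.label}
        </option>
      ))}
    </select>
  );
}
```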