Added models fetching
This commit is contained in:
parent
4cbdf6ed19
commit
d68a787017
7 changed files with 668 additions and 37 deletions
140
frontend/src/app/api/queries/useGetModelsQuery.ts
Normal file
140
frontend/src/app/api/queries/useGetModelsQuery.ts
Normal file
|
|
@ -0,0 +1,140 @@
|
||||||
|
import {
|
||||||
|
type UseQueryOptions,
|
||||||
|
useQuery,
|
||||||
|
useQueryClient,
|
||||||
|
} from "@tanstack/react-query";
|
||||||
|
|
||||||
|
/**
 * One selectable model entry as returned by the `/api/models/*` endpoints.
 */
export interface ModelOption {
  /** Model identifier. */
  value: string;
  /** Text displayed for this entry. */
  label: string;
  /** Marks the entry that consumers look up with `find(m => m.default)`. */
  default?: boolean;
}
|
||||||
|
|
||||||
|
/**
 * Payload shape shared by all model-listing endpoints: one list of language
 * models and one list of embedding models.
 */
export interface ModelsResponse {
  /** Models offered for text generation. */
  language_models: ModelOption[];
  /** Models offered for embeddings. */
  embedding_models: ModelOption[];
}
|
||||||
|
|
||||||
|
/**
 * Parameters for the Ollama model query.
 */
export interface OllamaModelsParams {
  /** Ollama server URL, forwarded as the `endpoint` query-string parameter. */
  endpoint?: string;
}
|
||||||
|
|
||||||
|
/**
 * Credentials for the IBM model query; they are sent in the JSON body of the
 * POST request to `/api/models/ibm`.
 */
export interface IBMModelsParams {
  /** IBM API key. */
  api_key?: string;
  /** watsonx.ai API endpoint URL. */
  endpoint?: string;
  /** watsonx.ai project identifier. */
  project_id?: string;
}
|
||||||
|
|
||||||
|
/**
 * React Query hook that loads the OpenAI model lists from `/api/models/openai`.
 *
 * @param options - Additional query options. They are spread after the
 *   defaults set here, so they override e.g. `staleTime`.
 * @returns The `useQuery` result for the `["models", "openai"]` key.
 */
export const useGetOpenAIModelsQuery = (
  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
) => {
  const queryClient = useQueryClient();

  const fetchOpenAIModels = async (): Promise<ModelsResponse> => {
    const res = await fetch("/api/models/openai");
    // Guard clause: a non-OK status becomes a rejected query.
    if (!res.ok) {
      throw new Error("Failed to fetch OpenAI models");
    }
    return await res.json();
  };

  return useQuery(
    {
      queryKey: ["models", "openai"],
      queryFn: fetchOpenAIModels,
      staleTime: 5 * 60 * 1000, // 5 minutes
      ...options,
    },
    queryClient,
  );
};
|
||||||
|
|
||||||
|
/**
 * React Query hook that loads the Ollama model lists from `/api/models/ollama`.
 *
 * The query is enabled only when `params.endpoint` is truthy, unless `options`
 * overrides `enabled` (options are spread last).
 *
 * @param params - Optional Ollama endpoint; also part of the query key.
 * @param options - Additional query options that override the defaults here.
 * @returns The `useQuery` result for the `["models", "ollama", params]` key.
 */
export const useGetOllamaModelsQuery = (
  params?: OllamaModelsParams,
  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
) => {
  const queryClient = useQueryClient();

  const fetchOllamaModels = async (): Promise<ModelsResponse> => {
    // Build an absolute URL so the endpoint can be attached as a query parameter.
    const target = new URL("/api/models/ollama", window.location.origin);
    const endpoint = params?.endpoint;
    if (endpoint) {
      target.searchParams.set("endpoint", endpoint);
    }

    const res = await fetch(target.toString());
    if (!res.ok) {
      throw new Error("Failed to fetch Ollama models");
    }
    return await res.json();
  };

  return useQuery(
    {
      queryKey: ["models", "ollama", params],
      queryFn: fetchOllamaModels,
      staleTime: 5 * 60 * 1000, // 5 minutes
      enabled: !!params?.endpoint, // Only run if endpoint is provided
      ...options,
    },
    queryClient,
  );
};
|
||||||
|
|
||||||
|
/**
 * React Query hook that loads the IBM model lists from `/api/models/ibm`.
 *
 * When any of the credentials is present the request is a JSON POST carrying
 * all three fields; otherwise a plain GET is issued. By default the query is
 * enabled only when all three credentials are truthy, unless `options`
 * overrides `enabled` (options are spread last).
 *
 * @param params - Optional IBM credentials; also part of the query key.
 * @param options - Additional query options that override the defaults here.
 * @returns The `useQuery` result for the `["models", "ibm", params]` key.
 */
export const useGetIBMModelsQuery = (
  params?: IBMModelsParams,
  options?: Omit<UseQueryOptions<ModelsResponse>, "queryKey" | "queryFn">,
) => {
  const queryClient = useQueryClient();

  const fetchIBMModels = async (): Promise<ModelsResponse> => {
    const url = "/api/models/ibm";
    const hasCredentials = Boolean(
      params?.api_key || params?.endpoint || params?.project_id,
    );

    // Credentials travel in a POST body; without any, fall back to GET for the default models.
    const res = hasCredentials
      ? await fetch(url, {
          method: "POST",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify({
            api_key: params?.api_key,
            endpoint: params?.endpoint,
            project_id: params?.project_id,
          }),
        })
      : await fetch(url);

    if (!res.ok) {
      throw new Error("Failed to fetch IBM models");
    }
    return await res.json();
  };

  return useQuery(
    {
      queryKey: ["models", "ibm", params],
      queryFn: fetchIBMModels,
      staleTime: 5 * 60 * 1000, // 5 minutes
      enabled: !!(params?.api_key && params?.endpoint && params?.project_id), // Only run if all credentials are provided
      ...options,
    },
    queryClient,
  );
};
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import { useState } from "react";
|
import { useState, useEffect } from "react";
|
||||||
import { LabelInput } from "@/components/label-input";
|
import { LabelInput } from "@/components/label-input";
|
||||||
import IBMLogo from "@/components/logo/ibm-logo";
|
import IBMLogo from "@/components/logo/ibm-logo";
|
||||||
import type { Settings } from "../api/queries/useGetSettingsQuery";
|
import type { Settings } from "../api/queries/useGetSettingsQuery";
|
||||||
|
import { useGetIBMModelsQuery } from "../api/queries/useGetModelsQuery";
|
||||||
import { AdvancedOnboarding } from "./advanced";
|
import { AdvancedOnboarding } from "./advanced";
|
||||||
|
|
||||||
export function IBMOnboarding({
|
export function IBMOnboarding({
|
||||||
|
|
@ -15,21 +16,44 @@ export function IBMOnboarding({
|
||||||
sampleDataset: boolean;
|
sampleDataset: boolean;
|
||||||
setSampleDataset: (dataset: boolean) => void;
|
setSampleDataset: (dataset: boolean) => void;
|
||||||
}) {
|
}) {
|
||||||
const languageModels = [
|
const [endpoint, setEndpoint] = useState("");
|
||||||
{ value: "gpt-oss", label: "gpt-oss" },
|
const [apiKey, setApiKey] = useState("");
|
||||||
{ value: "llama3.1", label: "llama3.1" },
|
const [projectId, setProjectId] = useState("");
|
||||||
{ value: "llama3.2", label: "llama3.2" },
|
const [languageModel, setLanguageModel] = useState("meta-llama/llama-3-1-70b-instruct");
|
||||||
{ value: "llama3.3", label: "llama3.3" },
|
const [embeddingModel, setEmbeddingModel] = useState("ibm/slate-125m-english-rtrvr");
|
||||||
{ value: "llama3.4", label: "llama3.4" },
|
|
||||||
{ value: "llama3.5", label: "llama3.5" },
|
// Fetch models from API when all credentials are provided
|
||||||
];
|
const { data: modelsData } = useGetIBMModelsQuery(
|
||||||
const embeddingModels = [
|
(apiKey && endpoint && projectId) ? { api_key: apiKey, endpoint, project_id: projectId } : undefined,
|
||||||
{ value: "text-embedding-3-small", label: "text-embedding-3-small" },
|
{ enabled: !!(apiKey && endpoint && projectId) }
|
||||||
];
|
|
||||||
const [languageModel, setLanguageModel] = useState("gpt-oss");
|
|
||||||
const [embeddingModel, setEmbeddingModel] = useState(
|
|
||||||
"text-embedding-3-small",
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Use fetched models or fallback to defaults
|
||||||
|
const languageModels = modelsData?.language_models || [
|
||||||
|
{ value: "meta-llama/llama-3-1-70b-instruct", label: "Llama 3.1 70B Instruct", default: true },
|
||||||
|
{ value: "meta-llama/llama-3-1-8b-instruct", label: "Llama 3.1 8B Instruct" },
|
||||||
|
{ value: "ibm/granite-13b-chat-v2", label: "Granite 13B Chat v2" },
|
||||||
|
{ value: "ibm/granite-13b-instruct-v2", label: "Granite 13B Instruct v2" },
|
||||||
|
];
|
||||||
|
const embeddingModels = modelsData?.embedding_models || [
|
||||||
|
{ value: "ibm/slate-125m-english-rtrvr", label: "Slate 125M English Retriever", default: true },
|
||||||
|
{ value: "sentence-transformers/all-minilm-l12-v2", label: "All-MiniLM L12 v2" },
|
||||||
|
];
|
||||||
|
|
||||||
|
// Update default selections when models are loaded
|
||||||
|
useEffect(() => {
|
||||||
|
if (modelsData) {
|
||||||
|
const defaultLangModel = modelsData.language_models.find(m => m.default);
|
||||||
|
const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
|
||||||
|
|
||||||
|
if (defaultLangModel) {
|
||||||
|
setLanguageModel(defaultLangModel.value);
|
||||||
|
}
|
||||||
|
if (defaultEmbedModel) {
|
||||||
|
setEmbeddingModel(defaultEmbedModel.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [modelsData]);
|
||||||
const handleLanguageModelChange = (model: string) => {
|
const handleLanguageModelChange = (model: string) => {
|
||||||
setLanguageModel(model);
|
setLanguageModel(model);
|
||||||
};
|
};
|
||||||
|
|
@ -48,21 +72,27 @@ export function IBMOnboarding({
|
||||||
helperText="The API endpoint for your watsonx.ai account."
|
helperText="The API endpoint for your watsonx.ai account."
|
||||||
id="api-endpoint"
|
id="api-endpoint"
|
||||||
required
|
required
|
||||||
placeholder="https://..."
|
placeholder="https://us-south.ml.cloud.ibm.com"
|
||||||
|
value={endpoint}
|
||||||
|
onChange={(e) => setEndpoint(e.target.value)}
|
||||||
/>
|
/>
|
||||||
<LabelInput
|
<LabelInput
|
||||||
label="IBM API key"
|
label="IBM API key"
|
||||||
helperText="The API key for your watsonx.ai account."
|
helperText="The API key for your watsonx.ai account."
|
||||||
id="api-key"
|
id="api-key"
|
||||||
required
|
required
|
||||||
placeholder="sk-..."
|
placeholder="your-api-key"
|
||||||
|
value={apiKey}
|
||||||
|
onChange={(e) => setApiKey(e.target.value)}
|
||||||
/>
|
/>
|
||||||
<LabelInput
|
<LabelInput
|
||||||
label="IBM Project ID"
|
label="IBM Project ID"
|
||||||
helperText="The project ID for your watsonx.ai account."
|
helperText="The project ID for your watsonx.ai account."
|
||||||
id="project-id"
|
id="project-id"
|
||||||
required
|
required
|
||||||
placeholder="..."
|
placeholder="your-project-id"
|
||||||
|
value={projectId}
|
||||||
|
onChange={(e) => setProjectId(e.target.value)}
|
||||||
/>
|
/>
|
||||||
<AdvancedOnboarding
|
<AdvancedOnboarding
|
||||||
icon={<IBMLogo className="w-4 h-4" />}
|
icon={<IBMLogo className="w-4 h-4" />}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
import { useState } from "react";
|
import { useState, useEffect } from "react";
|
||||||
import { LabelInput } from "@/components/label-input";
|
import { LabelInput } from "@/components/label-input";
|
||||||
import { LabelWrapper } from "@/components/label-wrapper";
|
import { LabelWrapper } from "@/components/label-wrapper";
|
||||||
import OllamaLogo from "@/components/logo/ollama-logo";
|
import OllamaLogo from "@/components/logo/ollama-logo";
|
||||||
import type { Settings } from "../api/queries/useGetSettingsQuery";
|
import type { Settings } from "../api/queries/useGetSettingsQuery";
|
||||||
|
import { useGetOllamaModelsQuery } from "../api/queries/useGetModelsQuery";
|
||||||
import { AdvancedOnboarding } from "./advanced";
|
import { AdvancedOnboarding } from "./advanced";
|
||||||
import { ModelSelector } from "./model-selector";
|
import { ModelSelector } from "./model-selector";
|
||||||
|
|
||||||
|
|
@ -17,24 +18,43 @@ export function OllamaOnboarding({
|
||||||
sampleDataset: boolean;
|
sampleDataset: boolean;
|
||||||
setSampleDataset: (dataset: boolean) => void;
|
setSampleDataset: (dataset: boolean) => void;
|
||||||
}) {
|
}) {
|
||||||
const [open, setOpen] = useState(false);
|
const [endpoint, setEndpoint] = useState("");
|
||||||
const [value, setValue] = useState("");
|
const [languageModel, setLanguageModel] = useState("llama3.2");
|
||||||
const [languageModel, setLanguageModel] = useState("gpt-oss");
|
const [embeddingModel, setEmbeddingModel] = useState("nomic-embed-text");
|
||||||
const [embeddingModel, setEmbeddingModel] = useState(
|
|
||||||
"text-embedding-3-small",
|
// Fetch models from API when endpoint is provided
|
||||||
|
const { data: modelsData } = useGetOllamaModelsQuery(
|
||||||
|
endpoint ? { endpoint } : undefined,
|
||||||
|
{ enabled: !!endpoint }
|
||||||
);
|
);
|
||||||
const languageModels = [
|
|
||||||
{ value: "gpt-oss", label: "gpt-oss", default: true },
|
// Use fetched models or fallback to defaults
|
||||||
|
const languageModels = modelsData?.language_models || [
|
||||||
|
{ value: "llama3.2", label: "llama3.2", default: true },
|
||||||
{ value: "llama3.1", label: "llama3.1" },
|
{ value: "llama3.1", label: "llama3.1" },
|
||||||
{ value: "llama3.2", label: "llama3.2" },
|
{ value: "llama3", label: "llama3" },
|
||||||
{ value: "llama3.3", label: "llama3.3" },
|
{ value: "mistral", label: "mistral" },
|
||||||
{ value: "llama3.4", label: "llama3.4" },
|
{ value: "codellama", label: "codellama" },
|
||||||
{ value: "llama3.5", label: "llama3.5" },
|
|
||||||
];
|
];
|
||||||
const embeddingModels = [
|
const embeddingModels = modelsData?.embedding_models || [
|
||||||
{ value: "text-embedding-3-small", label: "text-embedding-3-small" },
|
{ value: "nomic-embed-text", label: "nomic-embed-text", default: true },
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// Update default selections when models are loaded
|
||||||
|
useEffect(() => {
|
||||||
|
if (modelsData) {
|
||||||
|
const defaultLangModel = modelsData.language_models.find(m => m.default);
|
||||||
|
const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
|
||||||
|
|
||||||
|
if (defaultLangModel) {
|
||||||
|
setLanguageModel(defaultLangModel.value);
|
||||||
|
}
|
||||||
|
if (defaultEmbedModel) {
|
||||||
|
setEmbeddingModel(defaultEmbedModel.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [modelsData]);
|
||||||
|
|
||||||
const handleSampleDatasetChange = (dataset: boolean) => {
|
const handleSampleDatasetChange = (dataset: boolean) => {
|
||||||
setSampleDataset(dataset);
|
setSampleDataset(dataset);
|
||||||
};
|
};
|
||||||
|
|
@ -45,7 +65,9 @@ export function OllamaOnboarding({
|
||||||
helperText="The endpoint for your Ollama server."
|
helperText="The endpoint for your Ollama server."
|
||||||
id="api-endpoint"
|
id="api-endpoint"
|
||||||
required
|
required
|
||||||
placeholder="http://..."
|
placeholder="http://localhost:11434"
|
||||||
|
value={endpoint}
|
||||||
|
onChange={(e) => setEndpoint(e.target.value)}
|
||||||
/>
|
/>
|
||||||
<LabelWrapper
|
<LabelWrapper
|
||||||
label="Embedding model"
|
label="Embedding model"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import { useState } from "react";
|
import { useState, useEffect } from "react";
|
||||||
import { LabelInput } from "@/components/label-input";
|
import { LabelInput } from "@/components/label-input";
|
||||||
import OpenAILogo from "@/components/logo/openai-logo";
|
import OpenAILogo from "@/components/logo/openai-logo";
|
||||||
import type { Settings } from "../api/queries/useGetSettingsQuery";
|
import type { Settings } from "../api/queries/useGetSettingsQuery";
|
||||||
|
import { useGetOpenAIModelsQuery } from "../api/queries/useGetModelsQuery";
|
||||||
import { AdvancedOnboarding } from "./advanced";
|
import { AdvancedOnboarding } from "./advanced";
|
||||||
|
|
||||||
export function OpenAIOnboarding({
|
export function OpenAIOnboarding({
|
||||||
|
|
@ -19,10 +20,30 @@ export function OpenAIOnboarding({
|
||||||
const [embeddingModel, setEmbeddingModel] = useState(
|
const [embeddingModel, setEmbeddingModel] = useState(
|
||||||
"text-embedding-3-small",
|
"text-embedding-3-small",
|
||||||
);
|
);
|
||||||
const languageModels = [{ value: "gpt-4o-mini", label: "gpt-4o-mini" }];
|
|
||||||
const embeddingModels = [
|
// Fetch models from API
|
||||||
{ value: "text-embedding-3-small", label: "text-embedding-3-small" },
|
const { data: modelsData } = useGetOpenAIModelsQuery();
|
||||||
|
|
||||||
|
// Use fetched models or fallback to defaults
|
||||||
|
const languageModels = modelsData?.language_models || [{ value: "gpt-4o-mini", label: "gpt-4o-mini", default: true }];
|
||||||
|
const embeddingModels = modelsData?.embedding_models || [
|
||||||
|
{ value: "text-embedding-3-small", label: "text-embedding-3-small", default: true },
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// Update default selections when models are loaded
|
||||||
|
useEffect(() => {
|
||||||
|
if (modelsData) {
|
||||||
|
const defaultLangModel = modelsData.language_models.find(m => m.default);
|
||||||
|
const defaultEmbedModel = modelsData.embedding_models.find(m => m.default);
|
||||||
|
|
||||||
|
if (defaultLangModel && languageModel === "gpt-4o-mini") {
|
||||||
|
setLanguageModel(defaultLangModel.value);
|
||||||
|
}
|
||||||
|
if (defaultEmbedModel && embeddingModel === "text-embedding-3-small") {
|
||||||
|
setEmbeddingModel(defaultEmbedModel.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [modelsData, languageModel, embeddingModel]);
|
||||||
const handleLanguageModelChange = (model: string) => {
|
const handleLanguageModelChange = (model: string) => {
|
||||||
setLanguageModel(model);
|
setLanguageModel(model);
|
||||||
};
|
};
|
||||||
|
|
|
||||||
63
src/api/models.py
Normal file
63
src/api/models.py
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from utils.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_openai_models(request, models_service, session_manager):
    """Return the OpenAI model lists from ``models_service`` as a JSON response.

    Any exception is logged and turned into a 500 JSON response whose
    ``error`` field carries the exception text.
    """
    try:
        return JSONResponse(await models_service.get_openai_models())
    except Exception as exc:
        logger.error(f"Failed to get OpenAI models: {str(exc)}")
        error_body = {"error": f"Failed to retrieve OpenAI models: {str(exc)}"}
        return JSONResponse(error_body, status_code=500)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_ollama_models(request, models_service, session_manager):
    """Return the Ollama model lists from ``models_service`` as a JSON response.

    The optional ``endpoint`` query parameter is forwarded to the service
    (``None`` when absent). Any exception is logged and turned into a 500 JSON
    response whose ``error`` field carries the exception text.
    """
    try:
        # Read the optional endpoint override from the query string.
        endpoint = request.query_params.get("endpoint")
        result = await models_service.get_ollama_models(endpoint=endpoint)
        return JSONResponse(result)
    except Exception as exc:
        logger.error(f"Failed to get Ollama models: {str(exc)}")
        error_body = {"error": f"Failed to retrieve Ollama models: {str(exc)}"}
        return JSONResponse(error_body, status_code=500)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_ibm_models(request, models_service, session_manager):
    """Return the IBM Watson model lists from ``models_service`` as a JSON response.

    ``api_key``, ``endpoint`` and ``project_id`` are read from the JSON body
    for POST requests and from the query string otherwise; missing values are
    passed on as ``None``. Any exception is logged and turned into a 500 JSON
    response whose ``error`` field carries the exception text.
    """
    try:
        # Pick the credential source according to the HTTP method.
        if request.method == "POST":
            source = await request.json()
        else:
            source = dict(request.query_params)

        api_key, endpoint, project_id = (
            source.get(field) for field in ("api_key", "endpoint", "project_id")
        )

        result = await models_service.get_ibm_models(
            api_key=api_key,
            endpoint=endpoint,
            project_id=project_id,
        )
        return JSONResponse(result)
    except Exception as exc:
        logger.error(f"Failed to get IBM models: {str(exc)}")
        error_body = {"error": f"Failed to retrieve IBM models: {str(exc)}"}
        return JSONResponse(error_body, status_code=500)
|
||||||
38
src/main.py
38
src/main.py
|
|
@ -33,6 +33,7 @@ from api import (
|
||||||
flows,
|
flows,
|
||||||
knowledge_filter,
|
knowledge_filter,
|
||||||
langflow_files,
|
langflow_files,
|
||||||
|
models,
|
||||||
nudges,
|
nudges,
|
||||||
oidc,
|
oidc,
|
||||||
router,
|
router,
|
||||||
|
|
@ -66,6 +67,7 @@ from services.knowledge_filter_service import KnowledgeFilterService
|
||||||
# Configuration and setup
|
# Configuration and setup
|
||||||
# Services
|
# Services
|
||||||
from services.langflow_file_service import LangflowFileService
|
from services.langflow_file_service import LangflowFileService
|
||||||
|
from services.models_service import ModelsService
|
||||||
from services.monitor_service import MonitorService
|
from services.monitor_service import MonitorService
|
||||||
from services.search_service import SearchService
|
from services.search_service import SearchService
|
||||||
from services.task_service import TaskService
|
from services.task_service import TaskService
|
||||||
|
|
@ -409,6 +411,7 @@ async def initialize_services():
|
||||||
chat_service = ChatService()
|
chat_service = ChatService()
|
||||||
flows_service = FlowsService()
|
flows_service = FlowsService()
|
||||||
knowledge_filter_service = KnowledgeFilterService(session_manager)
|
knowledge_filter_service = KnowledgeFilterService(session_manager)
|
||||||
|
models_service = ModelsService()
|
||||||
monitor_service = MonitorService(session_manager)
|
monitor_service = MonitorService(session_manager)
|
||||||
|
|
||||||
# Set process pool for document service
|
# Set process pool for document service
|
||||||
|
|
@ -470,6 +473,7 @@ async def initialize_services():
|
||||||
"auth_service": auth_service,
|
"auth_service": auth_service,
|
||||||
"connector_service": connector_service,
|
"connector_service": connector_service,
|
||||||
"knowledge_filter_service": knowledge_filter_service,
|
"knowledge_filter_service": knowledge_filter_service,
|
||||||
|
"models_service": models_service,
|
||||||
"monitor_service": monitor_service,
|
"monitor_service": monitor_service,
|
||||||
"session_manager": session_manager,
|
"session_manager": session_manager,
|
||||||
}
|
}
|
||||||
|
|
@ -909,6 +913,40 @@ async def create_app():
|
||||||
),
|
),
|
||||||
methods=["POST"],
|
methods=["POST"],
|
||||||
),
|
),
|
||||||
|
# Models endpoints
|
||||||
|
Route(
|
||||||
|
"/models/openai",
|
||||||
|
require_auth(services["session_manager"])(
|
||||||
|
partial(
|
||||||
|
models.get_openai_models,
|
||||||
|
models_service=services["models_service"],
|
||||||
|
session_manager=services["session_manager"]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
methods=["GET"],
|
||||||
|
),
|
||||||
|
Route(
|
||||||
|
"/models/ollama",
|
||||||
|
require_auth(services["session_manager"])(
|
||||||
|
partial(
|
||||||
|
models.get_ollama_models,
|
||||||
|
models_service=services["models_service"],
|
||||||
|
session_manager=services["session_manager"]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
methods=["GET"],
|
||||||
|
),
|
||||||
|
Route(
|
||||||
|
"/models/ibm",
|
||||||
|
require_auth(services["session_manager"])(
|
||||||
|
partial(
|
||||||
|
models.get_ibm_models,
|
||||||
|
models_service=services["models_service"],
|
||||||
|
session_manager=services["session_manager"]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
methods=["GET", "POST"],
|
||||||
|
),
|
||||||
# Onboarding endpoint
|
# Onboarding endpoint
|
||||||
Route(
|
Route(
|
||||||
"/onboarding",
|
"/onboarding",
|
||||||
|
|
|
||||||
317
src/services/models_service.py
Normal file
317
src/services/models_service.py
Normal file
|
|
@ -0,0 +1,317 @@
|
||||||
|
import httpx
|
||||||
|
import os
|
||||||
|
from typing import Dict, List
|
||||||
|
from utils.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelsService:
    """Service for fetching available models from different AI providers.

    Each public ``get_*_models`` coroutine returns a dict with the keys
    ``language_models`` and ``embedding_models``; both are lists of
    ``{"value", "label", "default"}`` entries. When a provider cannot be
    queried, the hard-coded lists from the matching ``_get_default_*_models``
    method are returned instead of raising.
    """

    def __init__(self):
        # Not read by any method of this class.
        self.session_manager = None

    @staticmethod
    def _ensure_default(models):
        """Flag the first entry as default when no entry is flagged yet.

        Mutates ``models`` in place and returns it. Empty lists and lists that
        already contain a default are left untouched.
        """
        if models and not any(m.get("default", False) for m in models):
            models[0]["default"] = True
        return models

    async def get_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Fetch available models from OpenAI API.

        The key is read from the ``OPENAI_API_KEY`` environment variable. The
        default lists are returned when the key is unset, when the API answers
        with a non-200 status, or when any exception is raised.
        """
        try:
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                logger.warning("OPENAI_API_KEY not set, using default models")
                return self._get_default_openai_models()

            headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }

            async with httpx.AsyncClient() as client:
                response = await client.get(
                    "https://api.openai.com/v1/models", headers=headers, timeout=10.0
                )

            if response.status_code != 200:
                logger.warning(f"Failed to fetch OpenAI models: {response.status_code}")
                return self._get_default_openai_models()

            data = response.json()
            models = data.get("data", [])

            # Filter for relevant models
            language_models = []
            embedding_models = []

            for model in models:
                model_id = model.get("id", "")

                # Language models (GPT models)
                if any(prefix in model_id for prefix in ["gpt-4", "gpt-3.5"]):
                    language_models.append(
                        {
                            "value": model_id,
                            "label": model_id,
                            "default": model_id == "gpt-4o-mini",
                        }
                    )
                # Embedding models
                elif "text-embedding" in model_id:
                    embedding_models.append(
                        {
                            "value": model_id,
                            "label": model_id,
                            "default": model_id == "text-embedding-3-small",
                        }
                    )

            # Sort by name and ensure defaults are first
            language_models.sort(
                key=lambda x: (not x.get("default", False), x["value"])
            )
            embedding_models.sort(
                key=lambda x: (not x.get("default", False), x["value"])
            )

            # Guarantee a default even when the preferred model is not offered.
            return {
                "language_models": self._ensure_default(language_models),
                "embedding_models": self._ensure_default(embedding_models),
            }

        except Exception as e:
            logger.error(f"Error fetching OpenAI models: {str(e)}")
            return self._get_default_openai_models()

    async def get_ollama_models(
        self, endpoint: str = None
    ) -> Dict[str, List[Dict[str, str]]]:
        """Fetch available models from Ollama API.

        Args:
            endpoint: Base URL of the Ollama server. Falls back to the
                ``OLLAMA_BASE_URL`` environment variable and then to
                ``http://localhost:11434``.

        Returns:
            The model lists built from ``GET {endpoint}/api/tags``, or the
            default lists on a non-200 status or any exception.
        """
        try:
            # Use provided endpoint or default; drop trailing slashes so the
            # path below does not start with "//".
            ollama_url = (
                endpoint or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
            ).rstrip("/")

            async with httpx.AsyncClient() as client:
                response = await client.get(f"{ollama_url}/api/tags", timeout=10.0)

            if response.status_code != 200:
                logger.warning(f"Failed to fetch Ollama models: {response.status_code}")
                return self._get_default_ollama_models()

            data = response.json()
            models = data.get("models", [])

            # Extract model names
            language_models = []
            embedding_models = []

            for model in models:
                # Remove tag if present
                # NOTE(review): dropping the tag means "name:tag" variants collapse
                # into one entry — confirm the bare name is what consumers expect.
                model_name = model.get("name", "").split(":")[0]
                if not model_name:
                    continue

                lowered = model_name.lower()

                # Most Ollama models can be used as language models
                language_models.append(
                    {
                        "value": model_name,
                        "label": model_name,
                        "default": "llama3" in lowered,
                    }
                )

                # Some models are specifically for embeddings
                if any(
                    embed in lowered
                    for embed in ["embed", "sentence", "all-minilm"]
                ):
                    embedding_models.append(
                        {
                            "value": model_name,
                            "label": model_name,
                            # Same preferred embedding model as the fallback list.
                            "default": model_name == "nomic-embed-text",
                        }
                    )

            # Remove duplicates and sort
            language_models = list({m["value"]: m for m in language_models}.values())
            embedding_models = list({m["value"]: m for m in embedding_models}.values())

            language_models.sort(
                key=lambda x: (not x.get("default", False), x["value"])
            )
            embedding_models.sort(
                key=lambda x: (not x.get("default", False), x["value"])
            )

            # Make sure each non-empty list exposes a default the UI can select.
            self._ensure_default(language_models)
            self._ensure_default(embedding_models)

            return {
                "language_models": language_models,
                "embedding_models": embedding_models
                if embedding_models
                else [
                    {
                        "value": "nomic-embed-text",
                        "label": "nomic-embed-text",
                        "default": True,
                    }
                ],
            }

        except Exception as e:
            logger.error(f"Error fetching Ollama models: {str(e)}")
            return self._get_default_ollama_models()

    async def get_ibm_models(
        self, endpoint: str = None, api_key: str = None, project_id: str = None
    ) -> Dict[str, List[Dict[str, str]]]:
        """Fetch available models from IBM Watson API.

        Args:
            endpoint: Base URL of the watsonx.ai service. Falls back to the
                ``IBM_WATSON_ENDPOINT`` environment variable and then to
                ``https://us-south.ml.cloud.ibm.com``.
            api_key: Accepted because the API handler passes it as a keyword
                argument; the requests below are sent without it.
            project_id: Accepted for the same reason as ``api_key``; not used
                by the requests below.

        Returns:
            Language and embedding model lists. A list that could not be
            fetched (non-200 status or empty result) is replaced by its default
            list; on any exception both default lists are returned.
        """
        try:
            # Use provided endpoint or default; drop trailing slashes so the
            # path below does not start with "//".
            watson_endpoint = (
                endpoint
                or os.getenv("IBM_WATSON_ENDPOINT", "https://us-south.ml.cloud.ibm.com")
            ).rstrip("/")

            # Fetch foundation models using the correct endpoint
            models_url = f"{watson_endpoint}/ml/v1/foundation_model_specs"

            language_models = []
            embedding_models = []

            async with httpx.AsyncClient() as client:
                # Fetch text chat models
                text_params = {
                    "version": "2024-09-16",
                    "filters": "function_text_chat,!lifecycle_withdrawn",
                }
                text_response = await client.get(
                    models_url, params=text_params, timeout=10.0
                )

                if text_response.status_code == 200:
                    text_data = text_response.json()
                    text_models = text_data.get("resources", [])

                    for i, model in enumerate(text_models):
                        model_id = model.get("model_id", "")
                        model_name = model.get("name", model_id)

                        language_models.append(
                            {
                                "value": model_id,
                                "label": model_name or model_id,
                                "default": i == 0,  # First model is default
                            }
                        )
                else:
                    logger.warning(
                        f"Failed to fetch IBM language models: {text_response.status_code}"
                    )

                # Fetch embedding models
                embed_params = {
                    "version": "2024-09-16",
                    "filters": "function_embedding,!lifecycle_withdrawn",
                }
                embed_response = await client.get(
                    models_url, params=embed_params, timeout=10.0
                )

                if embed_response.status_code == 200:
                    embed_data = embed_response.json()
                    embed_models = embed_data.get("resources", [])

                    for i, model in enumerate(embed_models):
                        model_id = model.get("model_id", "")
                        model_name = model.get("name", model_id)

                        embedding_models.append(
                            {
                                "value": model_id,
                                "label": model_name or model_id,
                                "default": i == 0,  # First model is default
                            }
                        )
                else:
                    logger.warning(
                        f"Failed to fetch IBM embedding models: {embed_response.status_code}"
                    )

            defaults = self._get_default_ibm_models()
            return {
                "language_models": language_models
                if language_models
                else defaults["language_models"],
                "embedding_models": embedding_models
                if embedding_models
                else defaults["embedding_models"],
            }

        except Exception as e:
            logger.error(f"Error fetching IBM models: {str(e)}")
            return self._get_default_ibm_models()

    def _get_default_openai_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Default OpenAI models when API is not available"""
        return {
            "language_models": [
                {"value": "gpt-4o-mini", "label": "gpt-4o-mini", "default": True},
                {"value": "gpt-4o", "label": "gpt-4o", "default": False},
                {"value": "gpt-4-turbo", "label": "gpt-4-turbo", "default": False},
                {"value": "gpt-3.5-turbo", "label": "gpt-3.5-turbo", "default": False},
            ],
            "embedding_models": [
                {
                    "value": "text-embedding-3-small",
                    "label": "text-embedding-3-small",
                    "default": True,
                },
                {
                    "value": "text-embedding-3-large",
                    "label": "text-embedding-3-large",
                    "default": False,
                },
                {
                    "value": "text-embedding-ada-002",
                    "label": "text-embedding-ada-002",
                    "default": False,
                },
            ],
        }

    def _get_default_ollama_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Default Ollama models when API is not available"""
        return {
            "language_models": [
                {"value": "llama3.2", "label": "llama3.2", "default": True},
                {"value": "llama3.1", "label": "llama3.1", "default": False},
                {"value": "llama3", "label": "llama3", "default": False},
                {"value": "mistral", "label": "mistral", "default": False},
                {"value": "codellama", "label": "codellama", "default": False},
            ],
            "embedding_models": [
                {
                    "value": "nomic-embed-text",
                    "label": "nomic-embed-text",
                    "default": True,
                },
                {"value": "all-minilm", "label": "all-minilm", "default": False},
            ],
        }

    def _get_default_ibm_models(self) -> Dict[str, List[Dict[str, str]]]:
        """Default IBM Watson models when API is not available"""
        return {
            "language_models": [
                {
                    "value": "meta-llama/llama-3-1-70b-instruct",
                    "label": "Llama 3.1 70B Instruct",
                    "default": True,
                },
                {
                    "value": "meta-llama/llama-3-1-8b-instruct",
                    "label": "Llama 3.1 8B Instruct",
                    "default": False,
                },
                {
                    "value": "ibm/granite-13b-chat-v2",
                    "label": "Granite 13B Chat v2",
                    "default": False,
                },
                {
                    "value": "ibm/granite-13b-instruct-v2",
                    "label": "Granite 13B Instruct v2",
                    "default": False,
                },
            ],
            "embedding_models": [
                {
                    "value": "ibm/slate-125m-english-rtrvr",
                    "label": "Slate 125M English Retriever",
                    "default": True,
                },
                {
                    "value": "sentence-transformers/all-minilm-l12-v2",
                    "label": "All-MiniLM L12 v2",
                    "default": False,
                },
            ],
        }
|
||||||
Loading…
Add table
Reference in a new issue