changed models service to check embedding and llm models
This commit is contained in:
parent
a747f3712d
commit
ce9859a031
1 changed files with 86 additions and 31 deletions
|
|
@ -8,6 +8,18 @@ logger = get_logger(__name__)
|
||||||
class ModelsService:
|
class ModelsService:
|
||||||
"""Service for fetching available models from different AI providers"""
|
"""Service for fetching available models from different AI providers"""
|
||||||
|
|
||||||
|
OLLAMA_EMBEDDING_MODELS = [
|
||||||
|
"nomic-embed-text",
|
||||||
|
"mxbai-embed-large",
|
||||||
|
"snowflake-arctic-embed",
|
||||||
|
"all-minilm",
|
||||||
|
"bge-m3",
|
||||||
|
"bge-large",
|
||||||
|
"paraphrase-multilingual",
|
||||||
|
"granite-embedding",
|
||||||
|
"jina-embeddings-v2-base-en",
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.session_manager = None
|
self.session_manager = None
|
||||||
|
|
||||||
|
|
@ -80,49 +92,93 @@ class ModelsService:
|
||||||
async def get_ollama_models(
|
async def get_ollama_models(
|
||||||
self, endpoint: str = None
|
self, endpoint: str = None
|
||||||
) -> Dict[str, List[Dict[str, str]]]:
|
) -> Dict[str, List[Dict[str, str]]]:
|
||||||
"""Fetch available models from Ollama API"""
|
"""Fetch available models from Ollama API with tool calling capabilities for language models"""
|
||||||
try:
|
try:
|
||||||
# Use provided endpoint or default
|
# Use provided endpoint or default
|
||||||
ollama_url = endpoint
|
ollama_url = endpoint
|
||||||
|
|
||||||
|
# API endpoints
|
||||||
|
tags_url = f"{ollama_url}/api/tags"
|
||||||
|
show_url = f"{ollama_url}/api/show"
|
||||||
|
|
||||||
|
# Constants for JSON parsing
|
||||||
|
JSON_MODELS_KEY = "models"
|
||||||
|
JSON_NAME_KEY = "name"
|
||||||
|
JSON_CAPABILITIES_KEY = "capabilities"
|
||||||
|
DESIRED_CAPABILITY = "completion"
|
||||||
|
TOOL_CALLING_CAPABILITY = "tools"
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.get(f"{ollama_url}/api/tags", timeout=10.0)
|
# Fetch available models
|
||||||
|
tags_response = await client.get(tags_url, timeout=10.0)
|
||||||
|
tags_response.raise_for_status()
|
||||||
|
models_data = tags_response.json()
|
||||||
|
|
||||||
if response.status_code == 200:
|
logger.debug(f"Available models: {models_data}")
|
||||||
data = response.json()
|
|
||||||
models = data.get("models", [])
|
|
||||||
|
|
||||||
# Extract model names
|
# Filter models based on capabilities
|
||||||
language_models = []
|
language_models = []
|
||||||
embedding_models = []
|
embedding_models = []
|
||||||
|
|
||||||
for model in models:
|
models = models_data.get(JSON_MODELS_KEY, [])
|
||||||
model_name = model.get("name", "").split(":")[
|
|
||||||
0
|
|
||||||
] # Remove tag if present
|
|
||||||
|
|
||||||
if model_name:
|
for model in models:
|
||||||
# Most Ollama models can be used as language models
|
model_name = model.get(JSON_NAME_KEY, "")
|
||||||
language_models.append(
|
# Remove tag if present (e.g., "llama3:latest" -> "llama3")
|
||||||
{
|
clean_model_name = model_name.split(":")[0] if model_name else ""
|
||||||
"value": model_name,
|
|
||||||
"label": model_name,
|
if not clean_model_name:
|
||||||
"default": "llama3" in model_name.lower(),
|
continue
|
||||||
}
|
|
||||||
|
logger.debug(f"Checking model: {model_name}")
|
||||||
|
|
||||||
|
# Check model capabilities
|
||||||
|
payload = {"model": model_name}
|
||||||
|
try:
|
||||||
|
show_response = await client.post(
|
||||||
|
show_url, json=payload, timeout=10.0
|
||||||
|
)
|
||||||
|
show_response.raise_for_status()
|
||||||
|
json_data = show_response.json()
|
||||||
|
|
||||||
|
capabilities = json_data.get(JSON_CAPABILITIES_KEY, [])
|
||||||
|
logger.debug(
|
||||||
|
f"Model: {model_name}, Capabilities: {capabilities}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Some models are specifically for embeddings
|
# Check if model has required capabilities
|
||||||
if any(
|
has_completion = DESIRED_CAPABILITY in capabilities
|
||||||
embed in model_name.lower()
|
has_tools = TOOL_CALLING_CAPABILITY in capabilities
|
||||||
for embed in ["embed", "sentence", "all-minilm"]
|
|
||||||
):
|
# Check if it's an embedding model
|
||||||
|
is_embedding = any(
|
||||||
|
embed_model in clean_model_name.lower()
|
||||||
|
for embed_model in self.OLLAMA_EMBEDDING_MODELS
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_embedding:
|
||||||
|
# Embedding models only need completion capability
|
||||||
embedding_models.append(
|
embedding_models.append(
|
||||||
{
|
{
|
||||||
"value": model_name,
|
"value": clean_model_name,
|
||||||
"label": model_name,
|
"label": clean_model_name,
|
||||||
"default": False,
|
"default": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
elif not is_embedding and has_completion and has_tools:
|
||||||
|
# Language models need both completion and tool calling
|
||||||
|
language_models.append(
|
||||||
|
{
|
||||||
|
"value": clean_model_name,
|
||||||
|
"label": clean_model_name,
|
||||||
|
"default": "llama3" in clean_model_name.lower(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to check capabilities for model {model_name}: {str(e)}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
# Remove duplicates and sort
|
# Remove duplicates and sort
|
||||||
language_models = list(
|
language_models = list(
|
||||||
|
|
@ -137,15 +193,14 @@ class ModelsService:
|
||||||
)
|
)
|
||||||
embedding_models.sort(key=lambda x: x["value"])
|
embedding_models.sort(key=lambda x: x["value"])
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Found {len(language_models)} language models with tool calling and {len(embedding_models)} embedding models"
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"language_models": language_models,
|
"language_models": language_models,
|
||||||
"embedding_models": embedding_models if embedding_models else [],
|
"embedding_models": embedding_models,
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
logger.error(f"Failed to fetch Ollama models: {response.status_code}")
|
|
||||||
raise Exception(
|
|
||||||
f"Ollama API returned status code {response.status_code}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Ollama models: {str(e)}")
|
logger.error(f"Error fetching Ollama models: {str(e)}")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue