changed models service to check embedding and LLM models

Lucas Oliveira 2025-09-22 11:19:10 -03:00
parent a747f3712d
commit ce9859a031


@@ -8,6 +8,18 @@ logger = get_logger(__name__)
 class ModelsService:
     """Service for fetching available models from different AI providers"""
 
+    OLLAMA_EMBEDDING_MODELS = [
+        "nomic-embed-text",
+        "mxbai-embed-large",
+        "snowflake-arctic-embed",
+        "all-minilm",
+        "bge-m3",
+        "bge-large",
+        "paraphrase-multilingual",
+        "granite-embedding",
+        "jina-embeddings-v2-base-en",
+    ]
+
     def __init__(self):
         self.session_manager = None
@@ -80,49 +92,93 @@ class ModelsService:
     async def get_ollama_models(
         self, endpoint: str = None
     ) -> Dict[str, List[Dict[str, str]]]:
-        """Fetch available models from Ollama API"""
+        """Fetch available models from Ollama API with tool calling capabilities for language models"""
         try:
             # Use provided endpoint or default
             ollama_url = endpoint
+            # API endpoints
+            tags_url = f"{ollama_url}/api/tags"
+            show_url = f"{ollama_url}/api/show"
+            # Constants for JSON parsing
+            JSON_MODELS_KEY = "models"
+            JSON_NAME_KEY = "name"
+            JSON_CAPABILITIES_KEY = "capabilities"
+            DESIRED_CAPABILITY = "completion"
+            TOOL_CALLING_CAPABILITY = "tools"
             async with httpx.AsyncClient() as client:
-                response = await client.get(f"{ollama_url}/api/tags", timeout=10.0)
+                # Fetch available models
+                tags_response = await client.get(tags_url, timeout=10.0)
+                tags_response.raise_for_status()
+                models_data = tags_response.json()
-                if response.status_code == 200:
-                    data = response.json()
-                    models = data.get("models", [])
+                logger.debug(f"Available models: {models_data}")
-                    # Extract model names
+                # Filter models based on capabilities
                 language_models = []
                 embedding_models = []
-                    for model in models:
-                        model_name = model.get("name", "").split(":")[
-                            0
-                        ]  # Remove tag if present
+                models = models_data.get(JSON_MODELS_KEY, [])
-                        if model_name:
-                            # Most Ollama models can be used as language models
-                            language_models.append(
-                                {
-                                    "value": model_name,
-                                    "label": model_name,
-                                    "default": "llama3" in model_name.lower(),
-                                }
+                for model in models:
+                    model_name = model.get(JSON_NAME_KEY, "")
+                    # Remove tag if present (e.g., "llama3:latest" -> "llama3")
+                    clean_model_name = model_name.split(":")[0] if model_name else ""
+                    if not clean_model_name:
+                        continue
+                    logger.debug(f"Checking model: {model_name}")
+                    # Check model capabilities
+                    payload = {"model": model_name}
+                    try:
+                        show_response = await client.post(
+                            show_url, json=payload, timeout=10.0
                         )
+                        show_response.raise_for_status()
+                        json_data = show_response.json()
+                        capabilities = json_data.get(JSON_CAPABILITIES_KEY, [])
+                        logger.debug(
+                            f"Model: {model_name}, Capabilities: {capabilities}"
+                        )
-                            # Some models are specifically for embeddings
-                            if any(
-                                embed in model_name.lower()
-                                for embed in ["embed", "sentence", "all-minilm"]
-                            ):
+                        # Check if model has required capabilities
+                        has_completion = DESIRED_CAPABILITY in capabilities
+                        has_tools = TOOL_CALLING_CAPABILITY in capabilities
+                        # Check if it's an embedding model
+                        is_embedding = any(
+                            embed_model in clean_model_name.lower()
+                            for embed_model in self.OLLAMA_EMBEDDING_MODELS
+                        )
+                        if is_embedding:
+                            # Embedding models only need completion capability
                             embedding_models.append(
                                 {
-                                    "value": model_name,
-                                    "label": model_name,
+                                    "value": clean_model_name,
+                                    "label": clean_model_name,
                                     "default": False,
                                 }
                             )
+                        elif not is_embedding and has_completion and has_tools:
+                            # Language models need both completion and tool calling
+                            language_models.append(
+                                {
+                                    "value": clean_model_name,
+                                    "label": clean_model_name,
+                                    "default": "llama3" in clean_model_name.lower(),
+                                }
+                            )
+                    except Exception as e:
+                        logger.warning(
+                            f"Failed to check capabilities for model {model_name}: {str(e)}"
+                        )
+                        continue
                 # Remove duplicates and sort
                 language_models = list(
@@ -137,15 +193,14 @@ class ModelsService:
                 )
                 embedding_models.sort(key=lambda x: x["value"])
                 logger.info(
                     f"Found {len(language_models)} language models with tool calling and {len(embedding_models)} embedding models"
                 )
                 return {
                     "language_models": language_models,
-                    "embedding_models": embedding_models if embedding_models else [],
+                    "embedding_models": embedding_models,
                 }
-                else:
-                    logger.error(f"Failed to fetch Ollama models: {response.status_code}")
-                    raise Exception(
-                        f"Ollama API returned status code {response.status_code}"
-                    )
         except Exception as e:
             logger.error(f"Error fetching Ollama models: {str(e)}")