changed models service to check embedding and llm models
parent a747f3712d
commit ce9859a031
1 changed file with 86 additions and 31 deletions
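The new filtering relies on Ollama reporting per-model capabilities from its /api/show endpoint (values such as "completion" and "tools", as used in the diff below). A minimal standalone sketch of that probe, with a placeholder host and model name, assuming an Ollama server version that returns a capabilities list:

import asyncio

import httpx


async def probe_capabilities(endpoint: str, model: str) -> list:
    """Return the capability strings Ollama reports for a model via /api/show."""
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{endpoint}/api/show", json={"model": model}, timeout=10.0
        )
        response.raise_for_status()
        return response.json().get("capabilities", [])


if __name__ == "__main__":
    # Placeholder host and model name; a tool-capable LLM typically reports
    # ["completion", "tools"], while embedding-only models do not report "tools".
    print(asyncio.run(probe_capabilities("http://localhost:11434", "llama3")))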
@@ -8,6 +8,18 @@ logger = get_logger(__name__)
 class ModelsService:
     """Service for fetching available models from different AI providers"""
 
+    OLLAMA_EMBEDDING_MODELS = [
+        "nomic-embed-text",
+        "mxbai-embed-large",
+        "snowflake-arctic-embed",
+        "all-minilm",
+        "bge-m3",
+        "bge-large",
+        "paraphrase-multilingual",
+        "granite-embedding",
+        "jina-embeddings-v2-base-en",
+    ]
+
     def __init__(self):
         self.session_manager = None
 
@@ -80,49 +92,93 @@ class ModelsService:
     async def get_ollama_models(
         self, endpoint: str = None
     ) -> Dict[str, List[Dict[str, str]]]:
-        """Fetch available models from Ollama API"""
+        """Fetch available models from Ollama API with tool calling capabilities for language models"""
         try:
             # Use provided endpoint or default
             ollama_url = endpoint
 
+            # API endpoints
+            tags_url = f"{ollama_url}/api/tags"
+            show_url = f"{ollama_url}/api/show"
+
+            # Constants for JSON parsing
+            JSON_MODELS_KEY = "models"
+            JSON_NAME_KEY = "name"
+            JSON_CAPABILITIES_KEY = "capabilities"
+            DESIRED_CAPABILITY = "completion"
+            TOOL_CALLING_CAPABILITY = "tools"
+
             async with httpx.AsyncClient() as client:
-                response = await client.get(f"{ollama_url}/api/tags", timeout=10.0)
+                # Fetch available models
+                tags_response = await client.get(tags_url, timeout=10.0)
+                tags_response.raise_for_status()
+                models_data = tags_response.json()
 
-                if response.status_code == 200:
-                    data = response.json()
-                    models = data.get("models", [])
+                logger.debug(f"Available models: {models_data}")
 
-                    # Extract model names
+                # Filter models based on capabilities
                 language_models = []
                 embedding_models = []
 
-                    for model in models:
-                        model_name = model.get("name", "").split(":")[
-                            0
-                        ]  # Remove tag if present
+                models = models_data.get(JSON_MODELS_KEY, [])
 
-                        if model_name:
-                            # Most Ollama models can be used as language models
-                            language_models.append(
-                                {
-                                    "value": model_name,
-                                    "label": model_name,
-                                    "default": "llama3" in model_name.lower(),
-                                }
+                for model in models:
+                    model_name = model.get(JSON_NAME_KEY, "")
+                    # Remove tag if present (e.g., "llama3:latest" -> "llama3")
+                    clean_model_name = model_name.split(":")[0] if model_name else ""
+
+                    if not clean_model_name:
+                        continue
+
+                    logger.debug(f"Checking model: {model_name}")
+
+                    # Check model capabilities
+                    payload = {"model": model_name}
+                    try:
+                        show_response = await client.post(
+                            show_url, json=payload, timeout=10.0
+                        )
+                        show_response.raise_for_status()
+                        json_data = show_response.json()
+
+                        capabilities = json_data.get(JSON_CAPABILITIES_KEY, [])
+                        logger.debug(
+                            f"Model: {model_name}, Capabilities: {capabilities}"
+                        )
 
-                            # Some models are specifically for embeddings
-                            if any(
-                                embed in model_name.lower()
-                                for embed in ["embed", "sentence", "all-minilm"]
-                            ):
+                        # Check if model has required capabilities
+                        has_completion = DESIRED_CAPABILITY in capabilities
+                        has_tools = TOOL_CALLING_CAPABILITY in capabilities
+
+                        # Check if it's an embedding model
+                        is_embedding = any(
+                            embed_model in clean_model_name.lower()
+                            for embed_model in self.OLLAMA_EMBEDDING_MODELS
+                        )
+
+                        if is_embedding:
+                            # Embedding models only need completion capability
                             embedding_models.append(
                                 {
-                                    "value": model_name,
-                                    "label": model_name,
+                                    "value": clean_model_name,
+                                    "label": clean_model_name,
                                     "default": False,
                                 }
                             )
+                        elif not is_embedding and has_completion and has_tools:
+                            # Language models need both completion and tool calling
+                            language_models.append(
+                                {
+                                    "value": clean_model_name,
+                                    "label": clean_model_name,
+                                    "default": "llama3" in clean_model_name.lower(),
+                                }
+                            )
+                    except Exception as e:
+                        logger.warning(
+                            f"Failed to check capabilities for model {model_name}: {str(e)}"
+                        )
+                        continue
 
                 # Remove duplicates and sort
                 language_models = list(
@@ -137,15 +193,14 @@ class ModelsService:
                 )
                 embedding_models.sort(key=lambda x: x["value"])
 
                 logger.info(
+                    f"Found {len(language_models)} language models with tool calling and {len(embedding_models)} embedding models"
                 )
 
                 return {
                     "language_models": language_models,
-                    "embedding_models": embedding_models if embedding_models else [],
+                    "embedding_models": embedding_models,
                 }
-                else:
-                    logger.error(f"Failed to fetch Ollama models: {response.status_code}")
-                    raise Exception(
-                        f"Ollama API returned status code {response.status_code}"
-                    )
 
         except Exception as e:
             logger.error(f"Error fetching Ollama models: {str(e)}")
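For reference, a hedged usage sketch of the updated method and the shape it now returns; the import path and endpoint URL are illustrative, not taken from this repository:

import asyncio

from models_service import ModelsService  # hypothetical import path for the class in this diff


async def main() -> None:
    service = ModelsService()
    # The endpoint must point at a running Ollama server (placeholder URL).
    models = await service.get_ollama_models(endpoint="http://localhost:11434")

    # Expected shape after this change:
    # {
    #     "language_models":  [{"value": "llama3", "label": "llama3", "default": True}, ...],
    #     "embedding_models": [{"value": "nomic-embed-text", "label": "nomic-embed-text", "default": False}, ...],
    # }
    for item in models["language_models"]:
        print("LLM:", item["value"])
    for item in models["embedding_models"]:
        print("Embedding:", item["value"])


if __name__ == "__main__":
    asyncio.run(main())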