diff --git a/src/services/models_service.py b/src/services/models_service.py index 7620461b..e38c4e8b 100644 --- a/src/services/models_service.py +++ b/src/services/models_service.py @@ -8,6 +8,18 @@ logger = get_logger(__name__) class ModelsService: """Service for fetching available models from different AI providers""" + OLLAMA_EMBEDDING_MODELS = [ + "nomic-embed-text", + "mxbai-embed-large", + "snowflake-arctic-embed", + "all-minilm", + "bge-m3", + "bge-large", + "paraphrase-multilingual", + "granite-embedding", + "jina-embeddings-v2-base-en", + ] + def __init__(self): self.session_manager = None @@ -80,49 +92,93 @@ class ModelsService: async def get_ollama_models( self, endpoint: str = None ) -> Dict[str, List[Dict[str, str]]]: - """Fetch available models from Ollama API""" + """Fetch available models from Ollama API with tool calling capabilities for language models""" try: # Use provided endpoint or default ollama_url = endpoint + # API endpoints + tags_url = f"{ollama_url}/api/tags" + show_url = f"{ollama_url}/api/show" + + # Constants for JSON parsing + JSON_MODELS_KEY = "models" + JSON_NAME_KEY = "name" + JSON_CAPABILITIES_KEY = "capabilities" + DESIRED_CAPABILITY = "completion" + TOOL_CALLING_CAPABILITY = "tools" + async with httpx.AsyncClient() as client: - response = await client.get(f"{ollama_url}/api/tags", timeout=10.0) + # Fetch available models + tags_response = await client.get(tags_url, timeout=10.0) + tags_response.raise_for_status() + models_data = tags_response.json() - if response.status_code == 200: - data = response.json() - models = data.get("models", []) + logger.debug(f"Available models: {models_data}") - # Extract model names + # Filter models based on capabilities language_models = [] embedding_models = [] - for model in models: - model_name = model.get("name", "").split(":")[ - 0 - ] # Remove tag if present + models = models_data.get(JSON_MODELS_KEY, []) - if model_name: - # Most Ollama models can be used as language models - language_models.append( - { - "value": model_name, - "label": model_name, - "default": "llama3" in model_name.lower(), - } + for model in models: + model_name = model.get(JSON_NAME_KEY, "") + # Remove tag if present (e.g., "llama3:latest" -> "llama3") + clean_model_name = model_name.split(":")[0] if model_name else "" + + if not clean_model_name: + continue + + logger.debug(f"Checking model: {model_name}") + + # Check model capabilities + payload = {"model": model_name} + try: + show_response = await client.post( + show_url, json=payload, timeout=10.0 + ) + show_response.raise_for_status() + json_data = show_response.json() + + capabilities = json_data.get(JSON_CAPABILITIES_KEY, []) + logger.debug( + f"Model: {model_name}, Capabilities: {capabilities}" ) - # Some models are specifically for embeddings - if any( - embed in model_name.lower() - for embed in ["embed", "sentence", "all-minilm"] - ): + # Check if model has required capabilities + has_completion = DESIRED_CAPABILITY in capabilities + has_tools = TOOL_CALLING_CAPABILITY in capabilities + + # Check if it's an embedding model + is_embedding = any( + embed_model in clean_model_name.lower() + for embed_model in self.OLLAMA_EMBEDDING_MODELS + ) + + if is_embedding: + # Embedding models only need completion capability embedding_models.append( { - "value": model_name, - "label": model_name, + "value": clean_model_name, + "label": clean_model_name, "default": False, } ) + elif not is_embedding and has_completion and has_tools: + # Language models need both completion and tool calling + language_models.append( + { + "value": clean_model_name, + "label": clean_model_name, + "default": "llama3" in clean_model_name.lower(), + } + ) + except Exception as e: + logger.warning( + f"Failed to check capabilities for model {model_name}: {str(e)}" + ) + continue # Remove duplicates and sort language_models = list( @@ -137,15 +193,14 @@ class ModelsService: ) embedding_models.sort(key=lambda x: x["value"]) + logger.info( + f"Found {len(language_models)} language models with tool calling and {len(embedding_models)} embedding models" + ) + return { "language_models": language_models, - "embedding_models": embedding_models if embedding_models else [], + "embedding_models": embedding_models, } - else: - logger.error(f"Failed to fetch Ollama models: {response.status_code}") - raise Exception( - f"Ollama API returned status code {response.status_code}" - ) except Exception as e: logger.error(f"Error fetching Ollama models: {str(e)}")