create lightweight health check based on query param
Parent: a778cd76fa
Commit: 9b08f1fcee
4 changed files with 198 additions and 86 deletions
@@ -4,7 +4,7 @@ import httpx
from starlette.responses import JSONResponse
from utils.logging_config import get_logger
from config.settings import get_openrag_config
from api.provider_validation import validate_provider_setup, _test_ollama_lightweight_health
from api.provider_validation import validate_provider_setup

logger = get_logger(__name__)
@@ -16,6 +16,8 @@ async def check_provider_health(request):
Query parameters:
provider (optional): Provider to check ('openai', 'ollama', 'watsonx', 'anthropic').
If not provided, checks the currently configured provider.
test_completion (optional): If 'true', performs full validation with completion/embedding tests (consumes credits).
If 'false' or not provided, performs lightweight validation (no/minimal credits consumed).

Returns:
200: Provider is healthy and validated
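For reference, a minimal sketch of how a client might exercise these query parameters once the route is deployed. The route path (/api/health/provider), host, and port below are assumptions for illustration; only the provider and test_completion parameters come from the docstring above.

import asyncio
import httpx

async def check() -> None:
    # Assumed base URL and route path; adjust to the actual deployment.
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Lightweight check of the currently configured provider (no/minimal credits consumed)
        default_check = await client.get("/api/health/provider")
        # Full check of a specific provider, including completion/embedding tests (consumes credits)
        full_check = await client.get(
            "/api/health/provider",
            params={"provider": "ollama", "test_completion": "true"},
        )
        print(default_check.status_code, full_check.status_code)

asyncio.run(check())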
@@ -26,6 +28,7 @@ async def check_provider_health(request):
# Get optional provider from query params
query_params = dict(request.query_params)
check_provider = query_params.get("provider")
test_completion = query_params.get("test_completion", "false").lower() == "true"

# Get current config
current_config = get_openrag_config()
@@ -100,6 +103,7 @@ async def check_provider_health(request):
llm_model=llm_model,
endpoint=endpoint,
project_id=project_id,
test_completion=test_completion,
)

return JSONResponse(
@@ -124,23 +128,14 @@ async def check_provider_health(request):

# Validate LLM provider
try:
# For Ollama, use lightweight health check that doesn't block on active requests
if provider == "ollama":
try:
await _test_ollama_lightweight_health(endpoint)
except Exception as lightweight_error:
# If lightweight check fails, Ollama is down or misconfigured
llm_error = str(lightweight_error)
logger.error(f"LLM provider ({provider}) lightweight check failed: {llm_error}")
raise
else:
await validate_provider_setup(
provider=provider,
api_key=api_key,
llm_model=llm_model,
endpoint=endpoint,
project_id=project_id,
)
await validate_provider_setup(
provider=provider,
api_key=api_key,
llm_model=llm_model,
endpoint=endpoint,
project_id=project_id,
test_completion=test_completion,
)
except httpx.TimeoutException as e:
# Timeout means provider is busy, not misconfigured
if provider == "ollama":
@@ -155,23 +150,14 @@ async def check_provider_health(request):

# Validate embedding provider
try:
# For Ollama, use lightweight health check first
if embedding_provider == "ollama":
try:
await _test_ollama_lightweight_health(embedding_endpoint)
except Exception as lightweight_error:
# If lightweight check fails, Ollama is down or misconfigured
embedding_error = str(lightweight_error)
logger.error(f"Embedding provider ({embedding_provider}) lightweight check failed: {embedding_error}")
raise
else:
await validate_provider_setup(
provider=embedding_provider,
api_key=embedding_api_key,
embedding_model=embedding_model,
endpoint=embedding_endpoint,
project_id=embedding_project_id,
)
await validate_provider_setup(
provider=embedding_provider,
api_key=embedding_api_key,
embedding_model=embedding_model,
endpoint=embedding_endpoint,
project_id=embedding_project_id,
test_completion=test_completion,
)
except httpx.TimeoutException as e:
# Timeout means provider is busy, not misconfigured
if embedding_provider == "ollama":
@@ -14,17 +14,20 @@ async def validate_provider_setup(
llm_model: str = None,
endpoint: str = None,
project_id: str = None,
test_completion: bool = False,
) -> None:
"""
Validate provider setup by testing completion with tool calling and embedding.

Args:
provider: Provider name ('openai', 'watsonx', 'ollama')
provider: Provider name ('openai', 'watsonx', 'ollama', 'anthropic')
api_key: API key for the provider (optional for ollama)
embedding_model: Embedding model to test
llm_model: LLM model to test
endpoint: Provider endpoint (required for ollama and watsonx)
project_id: Project ID (required for watsonx)
test_completion: If True, performs full validation with completion/embedding tests (consumes credits).
If False, performs lightweight validation (no credits consumed). Default: False.

Raises:
Exception: If validation fails with message "Setup failed, please try again or select a different provider."
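As a usage sketch of the flag documented above: the same call validates either way, and only test_completion changes what runs. The provider, key, and model values here are placeholders, not values from this change.

import asyncio
from api.provider_validation import validate_provider_setup

async def demo() -> None:
    # Lightweight: credential/reachability check only (no or minimal credits consumed)
    await validate_provider_setup(
        provider="openai",
        api_key="sk-...",          # placeholder key
        llm_model="gpt-4o-mini",   # placeholder model name
        test_completion=False,
    )
    # Full: also runs a completion with tool calling against the model (consumes credits)
    await validate_provider_setup(
        provider="openai",
        api_key="sk-...",
        llm_model="gpt-4o-mini",
        test_completion=True,
    )

asyncio.run(demo())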
@@ -32,29 +35,37 @@ async def validate_provider_setup(
provider_lower = provider.lower()

try:
logger.info(f"Starting validation for provider: {provider_lower}")
logger.info(f"Starting validation for provider: {provider_lower} (test_completion={test_completion})")

if embedding_model:
# Test embedding
await test_embedding(
if test_completion:
# Full validation with completion/embedding tests (consumes credits)
if embedding_model:
# Test embedding
await test_embedding(
provider=provider_lower,
api_key=api_key,
embedding_model=embedding_model,
endpoint=endpoint,
project_id=project_id,
)
elif llm_model:
# Test completion with tool calling
await test_completion_with_tools(
provider=provider_lower,
api_key=api_key,
llm_model=llm_model,
endpoint=endpoint,
project_id=project_id,
)
else:
# Lightweight validation (no credits consumed)
await test_lightweight_health(
provider=provider_lower,
api_key=api_key,
embedding_model=embedding_model,
endpoint=endpoint,
project_id=project_id,
)

elif llm_model:
# Test completion with tool calling
await test_completion_with_tools(
provider=provider_lower,
api_key=api_key,
llm_model=llm_model,
endpoint=endpoint,
project_id=project_id,
)

logger.info(f"Validation successful for provider: {provider_lower}")

except Exception as e:
@@ -62,6 +73,26 @@ async def validate_provider_setup(
raise Exception("Setup failed, please try again or select a different provider.")

async def test_lightweight_health(
provider: str,
api_key: str = None,
embedding_model: str = None,  # accepted because validate_provider_setup passes it; not used by the checks below
endpoint: str = None,
project_id: str = None,
) -> None:
"""Test provider health with lightweight check (no credits consumed)."""

if provider == "openai":
await _test_openai_lightweight_health(api_key)
elif provider == "watsonx":
await _test_watsonx_lightweight_health(api_key, endpoint, project_id)
elif provider == "ollama":
await _test_ollama_lightweight_health(endpoint)
elif provider == "anthropic":
await _test_anthropic_lightweight_health(api_key)
else:
raise ValueError(f"Unknown provider: {provider}")


async def test_completion_with_tools(
provider: str,
api_key: str = None,
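A small test sketch for the dispatch above. It assumes pytest with the pytest-asyncio plugin and the api.provider_validation module path used elsewhere in this change; the fake check stands in for the real network call.

import pytest
from api import provider_validation

@pytest.mark.asyncio  # requires pytest-asyncio (assumption)
async def test_unknown_provider_is_rejected():
    with pytest.raises(ValueError):
        await provider_validation.test_lightweight_health(provider="not-a-provider")

@pytest.mark.asyncio
async def test_openai_dispatches_to_lightweight_check(monkeypatch):
    seen = {}

    async def fake_check(api_key):
        seen["api_key"] = api_key

    # Swap out the real OpenAI call so no request is made
    monkeypatch.setattr(provider_validation, "_test_openai_lightweight_health", fake_check)
    await provider_validation.test_lightweight_health(provider="openai", api_key="sk-test")
    assert seen["api_key"] == "sk-test"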
@@ -103,6 +134,40 @@ async def test_embedding(


# OpenAI validation functions
async def _test_openai_lightweight_health(api_key: str) -> None:
"""Test OpenAI API key validity with lightweight check.

Only checks if the API key is valid without consuming credits.
Uses the /v1/models endpoint which doesn't consume credits.
"""
try:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}

async with httpx.AsyncClient() as client:
# Use /v1/models endpoint which validates the key without consuming credits
response = await client.get(
"https://api.openai.com/v1/models",
headers=headers,
timeout=10.0, # Short timeout for lightweight check
)

if response.status_code != 200:
logger.error(f"OpenAI lightweight health check failed: {response.status_code}")
raise Exception(f"OpenAI API key validation failed: {response.status_code}")

logger.info("OpenAI lightweight health check passed")

except httpx.TimeoutException:
logger.error("OpenAI lightweight health check timed out")
raise Exception("OpenAI API request timed out")
except Exception as e:
logger.error(f"OpenAI lightweight health check failed: {str(e)}")
raise


async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> None:
"""Test OpenAI completion with tool calling."""
try:
@@ -213,6 +278,45 @@ async def _test_openai_embedding(api_key: str, embedding_model: str) -> None:


# IBM Watson validation functions
async def _test_watsonx_lightweight_health(
api_key: str, endpoint: str, project_id: str
) -> None:
"""Test WatsonX API key validity with lightweight check.

Only checks if the API key is valid by getting a bearer token.
Does not consume credits by avoiding model inference requests.
"""
try:
# Get bearer token from IBM IAM - this validates the API key without consuming credits
async with httpx.AsyncClient() as client:
token_response = await client.post(
"https://iam.cloud.ibm.com/identity/token",
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
"apikey": api_key,
},
timeout=10.0, # Short timeout for lightweight check
)

if token_response.status_code != 200:
logger.error(f"IBM IAM token request failed: {token_response.status_code}")
raise Exception("Failed to authenticate with IBM Watson - invalid API key")

bearer_token = token_response.json().get("access_token")
if not bearer_token:
raise Exception("No access token received from IBM")

logger.info("WatsonX lightweight health check passed - API key is valid")

except httpx.TimeoutException:
logger.error("WatsonX lightweight health check timed out")
raise Exception("WatsonX API request timed out")
except Exception as e:
logger.error(f"WatsonX lightweight health check failed: {str(e)}")
raise


async def _test_watsonx_completion_with_tools(
api_key: str, llm_model: str, endpoint: str, project_id: str
) -> None:
@@ -483,6 +587,48 @@ async def _test_ollama_embedding(embedding_model: str, endpoint: str) -> None:


# Anthropic validation functions
async def _test_anthropic_lightweight_health(api_key: str) -> None:
"""Test Anthropic API key validity with lightweight check.

Only checks if the API key is valid without consuming credits.
Uses a minimal messages request with max_tokens=1 to validate the key.
"""
try:
headers = {
"x-api-key": api_key,
"anthropic-version": "2023-06-01",
"Content-Type": "application/json",
}

# Minimal validation request - uses cheapest model with minimal tokens
payload = {
"model": "claude-3-5-haiku-latest", # Cheapest model
"max_tokens": 1, # Minimum tokens to validate key
"messages": [{"role": "user", "content": "test"}],
}

async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.anthropic.com/v1/messages",
headers=headers,
json=payload,
timeout=10.0, # Short timeout for lightweight check
)

if response.status_code != 200:
logger.error(f"Anthropic lightweight health check failed: {response.status_code}")
raise Exception(f"Anthropic API key validation failed: {response.status_code}")

logger.info("Anthropic lightweight health check passed")

except httpx.TimeoutException:
logger.error("Anthropic lightweight health check timed out")
raise Exception("Anthropic API request timed out")
except Exception as e:
logger.error(f"Anthropic lightweight health check failed: {str(e)}")
raise


async def _test_anthropic_completion_with_tools(api_key: str, llm_model: str) -> None:
"""Test Anthropic completion with tool calling."""
try:
@@ -897,6 +897,7 @@ async def onboarding(request, flows_service, session_manager=None):
)

# Validate provider setup before initializing OpenSearch index
# Use lightweight validation (test_completion=False) to avoid consuming credits during onboarding
try:
from api.provider_validation import validate_provider_setup
@@ -905,13 +906,14 @@ async def onboarding(request, flows_service, session_manager=None):
llm_provider = current_config.agent.llm_provider.lower()
llm_provider_config = current_config.get_llm_provider_config()

logger.info(f"Validating LLM provider setup for {llm_provider}")
logger.info(f"Validating LLM provider setup for {llm_provider} (lightweight)")
await validate_provider_setup(
provider=llm_provider,
api_key=getattr(llm_provider_config, "api_key", None),
llm_model=current_config.agent.llm_model,
endpoint=getattr(llm_provider_config, "endpoint", None),
project_id=getattr(llm_provider_config, "project_id", None),
test_completion=False, # Lightweight validation - no credits consumed
)
logger.info(f"LLM provider setup validation completed successfully for {llm_provider}")
@@ -920,13 +922,14 @@ async def onboarding(request, flows_service, session_manager=None):
embedding_provider = current_config.knowledge.embedding_provider.lower()
embedding_provider_config = current_config.get_embedding_provider_config()

logger.info(f"Validating embedding provider setup for {embedding_provider}")
logger.info(f"Validating embedding provider setup for {embedding_provider} (lightweight)")
await validate_provider_setup(
provider=embedding_provider,
api_key=getattr(embedding_provider_config, "api_key", None),
embedding_model=current_config.knowledge.embedding_model,
endpoint=getattr(embedding_provider_config, "endpoint", None),
project_id=getattr(embedding_provider_config, "project_id", None),
test_completion=False, # Lightweight validation - no credits consumed
)
logger.info(f"Embedding provider setup validation completed successfully for {embedding_provider}")
except Exception as e:
@@ -50,7 +50,7 @@ class ModelsService:
self.session_manager = None

async def get_openai_models(self, api_key: str) -> Dict[str, List[Dict[str, str]]]:
"""Fetch available models from OpenAI API"""
"""Fetch available models from OpenAI API with lightweight validation"""
try:
headers = {
"Authorization": f"Bearer {api_key}",
@@ -58,6 +58,8 @@ class ModelsService:
}

async with httpx.AsyncClient() as client:
# Lightweight validation: just check if API key is valid
# This doesn't consume credits, only validates the key
response = await client.get(
"https://api.openai.com/v1/models", headers=headers, timeout=10.0
)
@@ -101,6 +103,7 @@ class ModelsService:
key=lambda x: (not x.get("default", False), x["value"])
)

logger.info("OpenAI API key validated successfully without consuming credits")
return {
"language_models": language_models,
"embedding_models": embedding_models,
@@ -389,38 +392,12 @@ class ModelsService:
}
)

# Validate credentials with the first available LLM model
if language_models:
first_llm_model = language_models[0]["value"]

async with httpx.AsyncClient() as client:
validation_url = f"{watson_endpoint}/ml/v1/text/generation"
validation_params = {"version": "2024-09-16"}
validation_payload = {
"input": "test",
"model_id": first_llm_model,
"project_id": project_id,
"parameters": {
"max_new_tokens": 1,
},
}

validation_response = await client.post(
validation_url,
headers=headers,
params=validation_params,
json=validation_payload,
timeout=10.0,
)

if validation_response.status_code != 200:
raise Exception(
f"Invalid credentials or endpoint: {validation_response.status_code} - {validation_response.text}"
)

logger.info(f"IBM Watson credentials validated successfully using model: {first_llm_model}")
# Lightweight validation: API key is already validated by successfully getting bearer token
# No need to make a generation request that consumes credits
if bearer_token:
logger.info("IBM Watson API key validated successfully without consuming credits")
else:
logger.warning("No language models available to validate credentials")
logger.warning("No bearer token available - API key validation may have failed")

if not language_models and not embedding_models:
raise Exception("No IBM models retrieved from API")