Implement dynamic Ollama embedding dimension resolution with server probing (#237)
* Initial plan
* Implement dynamic Ollama embedding dimension resolution with probing
  Co-authored-by: phact <1313220+phact@users.noreply.github.com>
* Fix Ollama probing
* Raise instead of dims 0
* Show better error
* Run embedding probe before saving settings so that the user can update

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: phact <1313220+phact@users.noreply.github.com>
Co-authored-by: Lucas Oliveira <lucas.edu.oli@hotmail.com>
Co-authored-by: phact <estevezsebastian@gmail.com>
parent aee0a20302
commit 140d24603d

4 changed files with 283 additions and 147 deletions
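The core of the change: instead of looking up Ollama models in a hardcoded dimension table, the server is asked for a test embedding and the length of the returned vector is used as the dimension. A condensed, self-contained sketch of that probing idea (the endpoint and model name below are placeholders; the full implementation in the diff also tries a legacy payload and aggregates errors):

import asyncio
import httpx

async def probe_dimension(endpoint: str, model: str) -> int:
    # POST a throwaway string to Ollama's embeddings API and measure
    # the vector that comes back.
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{endpoint}/api/embeddings",
            json={"model": model, "prompt": "test"},
            timeout=10.0,
        )
        resp.raise_for_status()
        return len(resp.json()["embedding"])

# Placeholder usage:
# print(asyncio.run(probe_dimension("http://localhost:11434", "nomic-embed-text")))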
@@ -506,7 +506,142 @@ async def onboarding(request, flows_service):
                {"error": "No valid fields provided for update"}, status_code=400
            )

        # Initialize the OpenSearch index now that we have the embedding model configured
        try:
            # Import here to avoid circular imports
            from main import init_index

            logger.info(
                "Initializing OpenSearch index after onboarding configuration"
            )
            await init_index()
            logger.info("OpenSearch index initialization completed successfully")
        except Exception as e:
            if isinstance(e, ValueError):
                logger.error(
                    "Failed to initialize OpenSearch index after onboarding",
                    error=str(e),
                )
                return JSONResponse(
                    {
                        "error": str(e),
                        "edited": True,
                    },
                    status_code=400,
                )
            logger.error(
                "Failed to initialize OpenSearch index after onboarding",
                error=str(e),
            )
            # Don't fail the entire onboarding process if index creation fails
            # The application can still work, but document operations may fail

        # Save the updated configuration (this will mark it as edited)

        # If model_provider was updated, assign the new provider to flows
        if "model_provider" in body:
            provider = body["model_provider"].strip().lower()
            try:
                flow_result = await flows_service.assign_model_provider(provider)

                if flow_result.get("success"):
                    logger.info(
                        f"Successfully assigned {provider} to flows",
                        flow_result=flow_result,
                    )
                else:
                    logger.warning(
                        f"Failed to assign {provider} to flows",
                        flow_result=flow_result,
                    )
                    # Continue even if flow assignment fails - configuration was still saved

            except Exception as e:
                logger.error(
                    "Error assigning model provider to flows",
                    provider=provider,
                    error=str(e),
                )
                raise

            # Set Langflow global variables based on provider
            try:
                # Set API key for IBM/Watson providers
                if (provider == "watsonx") and "api_key" in body:
                    api_key = body["api_key"]
                    await clients._create_langflow_global_variable(
                        "WATSONX_API_KEY", api_key, modify=True
                    )
                    logger.info("Set WATSONX_API_KEY global variable in Langflow")

                # Set project ID for IBM/Watson providers
                if (provider == "watsonx") and "project_id" in body:
                    project_id = body["project_id"]
                    await clients._create_langflow_global_variable(
                        "WATSONX_PROJECT_ID", project_id, modify=True
                    )
                    logger.info(
                        "Set WATSONX_PROJECT_ID global variable in Langflow"
                    )

                # Set API key for OpenAI provider
                if provider == "openai" and "api_key" in body:
                    api_key = body["api_key"]
                    await clients._create_langflow_global_variable(
                        "OPENAI_API_KEY", api_key, modify=True
                    )
                    logger.info("Set OPENAI_API_KEY global variable in Langflow")

                # Set base URL for Ollama provider
                if provider == "ollama" and "endpoint" in body:
                    endpoint = transform_localhost_url(body["endpoint"])

                    await clients._create_langflow_global_variable(
                        "OLLAMA_BASE_URL", endpoint, modify=True
                    )
                    logger.info("Set OLLAMA_BASE_URL global variable in Langflow")

                await flows_service.change_langflow_model_value(
                    provider,
                    body["embedding_model"],
                    body["llm_model"],
                    body["endpoint"],
                )

            except Exception as e:
                logger.error(
                    "Failed to set Langflow global variables",
                    provider=provider,
                    error=str(e),
                )
                raise

        # Handle sample data ingestion if requested
        if should_ingest_sample_data:
            try:
                # Import the function here to avoid circular imports
                from main import ingest_default_documents_when_ready

                # Get services from the current app state
                # We need to access the app instance to get services
                app = request.scope.get("app")
                if app and hasattr(app.state, "services"):
                    services = app.state.services
                    logger.info(
                        "Starting sample data ingestion as requested in onboarding"
                    )
                    await ingest_default_documents_when_ready(services)
                    logger.info("Sample data ingestion completed successfully")
                else:
                    logger.error(
                        "Could not access services for sample data ingestion"
                    )

            except Exception as e:
                logger.error(
                    "Failed to complete sample data ingestion", error=str(e)
                )
                # Don't fail the entire onboarding process if sample data fails
        if config_manager.save_config_file(current_config):
            updated_fields = [
                k for k in body.keys() if k != "sample_data"
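Because init_index() now runs before the configuration is persisted, a failed probe is reported back to the onboarding caller as a 400 whose body carries the probe's error, and the user can correct the settings and retry. A hedged test-style sketch of that contract (the route path and client fixture are hypothetical, not taken from this repo):

async def test_onboarding_rejects_unreachable_ollama(client):
    # Hypothetical route path and fixture; the asserted behavior mirrors
    # the ValueError branch in the hunk above.
    resp = await client.post(
        "/api/onboarding",
        json={
            "model_provider": "ollama",
            "embedding_model": "nomic-embed-text",
            "llm_model": "llama3",
            "endpoint": "http://localhost:1",  # nothing listening here
        },
    )
    assert resp.status_code == 400
    assert "Failed to determine embedding dimensions" in resp.json()["error"]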

@@ -516,144 +651,19 @@ async def onboarding(request, flows_service):
                updated_fields=updated_fields,
            )

            # If model_provider was updated, assign the new provider to flows
            if "model_provider" in body:
                provider = body["model_provider"].strip().lower()
                try:
                    flow_result = await flows_service.assign_model_provider(provider)

                    if flow_result.get("success"):
                        logger.info(
                            f"Successfully assigned {provider} to flows",
                            flow_result=flow_result,
                        )
                    else:
                        logger.warning(
                            f"Failed to assign {provider} to flows",
                            flow_result=flow_result,
                        )
                        # Continue even if flow assignment fails - configuration was still saved

                except Exception as e:
                    logger.error(
                        "Error assigning model provider to flows",
                        provider=provider,
                        error=str(e),
                    )
                    # Continue even if flow assignment fails - configuration was still saved

            # Set Langflow global variables based on provider
            if "model_provider" in body:
                provider = body["model_provider"].strip().lower()

                try:
                    # Set API key for IBM/Watson providers
                    if (provider == "watsonx") and "api_key" in body:
                        api_key = body["api_key"]
                        await clients._create_langflow_global_variable(
                            "WATSONX_API_KEY", api_key, modify=True
                        )
                        logger.info("Set WATSONX_API_KEY global variable in Langflow")

                    # Set project ID for IBM/Watson providers
                    if (provider == "watsonx") and "project_id" in body:
                        project_id = body["project_id"]
                        await clients._create_langflow_global_variable(
                            "WATSONX_PROJECT_ID", project_id, modify=True
                        )
                        logger.info(
                            "Set WATSONX_PROJECT_ID global variable in Langflow"
                        )

                    # Set API key for OpenAI provider
                    if provider == "openai" and "api_key" in body:
                        api_key = body["api_key"]
                        await clients._create_langflow_global_variable(
                            "OPENAI_API_KEY", api_key, modify=True
                        )
                        logger.info("Set OPENAI_API_KEY global variable in Langflow")

                    # Set base URL for Ollama provider
                    if provider == "ollama" and "endpoint" in body:
                        endpoint = transform_localhost_url(body["endpoint"])

                        await clients._create_langflow_global_variable(
                            "OLLAMA_BASE_URL", endpoint, modify=True
                        )
                        logger.info("Set OLLAMA_BASE_URL global variable in Langflow")

                    await flows_service.change_langflow_model_value(
                        provider,
                        body["embedding_model"],
                        body["llm_model"],
                        body["endpoint"],
                    )

                except Exception as e:
                    logger.error(
                        "Failed to set Langflow global variables",
                        provider=provider,
                        error=str(e),
                    )
                    # Continue even if setting global variables fails

            # Initialize the OpenSearch index now that we have the embedding model configured
            try:
                # Import here to avoid circular imports
                from main import init_index

                logger.info(
                    "Initializing OpenSearch index after onboarding configuration"
                )
                await init_index()
                logger.info("OpenSearch index initialization completed successfully")
            except Exception as e:
                logger.error(
                    "Failed to initialize OpenSearch index after onboarding",
                    error=str(e),
                )
                # Don't fail the entire onboarding process if index creation fails
                # The application can still work, but document operations may fail

            # Handle sample data ingestion if requested
            if should_ingest_sample_data:
                try:
                    # Import the function here to avoid circular imports
                    from main import ingest_default_documents_when_ready

                    # Get services from the current app state
                    # We need to access the app instance to get services
                    app = request.scope.get("app")
                    if app and hasattr(app.state, "services"):
                        services = app.state.services
                        logger.info(
                            "Starting sample data ingestion as requested in onboarding"
                        )
                        await ingest_default_documents_when_ready(services)
                        logger.info("Sample data ingestion completed successfully")
                    else:
                        logger.error(
                            "Could not access services for sample data ingestion"
                        )

                except Exception as e:
                    logger.error(
                        "Failed to complete sample data ingestion", error=str(e)
                    )
                    # Don't fail the entire onboarding process if sample data fails

            return JSONResponse(
                {
                    "message": "Onboarding configuration updated successfully",
                    "edited": True,  # Confirm that config is now marked as edited
                    "sample_data_ingested": should_ingest_sample_data,
                }
            )
        else:
            return JSONResponse(
                {"error": "Failed to save configuration"}, status_code=500
            )

            return JSONResponse(
                {
                    "message": "Onboarding configuration updated successfully",
                    "edited": True,  # Confirm that config is now marked as edited
                    "sample_data_ingested": should_ingest_sample_data,
                }
            )

    except Exception as e:
        logger.error("Failed to update onboarding settings", error=str(e))
        return JSONResponse(
@@ -81,12 +81,6 @@ OPENAI_EMBEDDING_DIMENSIONS = {
    "text-embedding-ada-002": 1536,
}

OLLAMA_EMBEDDING_DIMENSIONS = {
    "nomic-embed-text": 768,
    "all-minilm": 384,
    "mxbai-embed-large": 1024,
}

WATSONX_EMBEDDING_DIMENSIONS = {
    # IBM Models
    "ibm/granite-embedding-107m-multilingual": 384,
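With the hardcoded Ollama table removed, static lookup only covers the OpenAI and watsonx tables; Ollama dimensions always come from a live probe. A minimal sketch of what the remaining static path does, assuming the normalization used later in this diff (split(":")[0] to drop a "model:tag" suffix):

OPENAI_EMBEDDING_DIMENSIONS = {"text-embedding-ada-002": 1536}

def static_lookup(model_name: str) -> int | None:
    # Normalize "model:tag" names before the dict lookup, as the
    # resolver in this commit does.
    key = model_name.lower().strip().split(":")[0]
    return OPENAI_EMBEDDING_DIMENSIONS.get(key)

assert static_lookup("text-embedding-ada-002") == 1536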
@@ -168,7 +168,12 @@ async def init_index():
    embedding_model = config.knowledge.embedding_model

    # Create dynamic index body based on the configured embedding model
    dynamic_index_body = create_dynamic_index_body(embedding_model)
    # Pass provider and endpoint for dynamic dimension resolution (Ollama probing)
    dynamic_index_body = await create_dynamic_index_body(
        embedding_model,
        provider=config.provider.model_provider,
        endpoint=config.provider.endpoint
    )

    # Create documents index
    if not await clients.opensearch.indices.exists(index=INDEX_NAME):
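Since an Ollama probe may hit the network, the index-body builder is now a coroutine and its call site must await it. A plausible assembly of the "after" state of this call site, assuming the opensearch-py async client used in the hunk above:

dynamic_index_body = await create_dynamic_index_body(
    embedding_model,
    provider=config.provider.model_provider,
    endpoint=config.provider.endpoint,
)
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
    # Create the index with the dimension that was just resolved.
    await clients.opensearch.indices.create(index=INDEX_NAME, body=dynamic_index_body)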
@@ -1,14 +1,128 @@
from config.settings import OLLAMA_EMBEDDING_DIMENSIONS, OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS
import httpx
from config.settings import OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS
from utils.container_utils import transform_localhost_url
from utils.logging_config import get_logger


logger = get_logger(__name__)

def get_embedding_dimensions(model_name: str) -> int:

async def _probe_ollama_embedding_dimension(endpoint: str, model_name: str) -> int:
    """Probe the Ollama server to get the embedding dimension for a model.

    Args:
        endpoint: Ollama server endpoint (e.g., "http://localhost:11434")
        model_name: Name of the embedding model

    Returns:
        The embedding dimension.

    Raises:
        ValueError: If the dimension cannot be determined.
    """
    transformed_endpoint = transform_localhost_url(endpoint)
    url = f"{transformed_endpoint}/api/embeddings"
    test_input = "test"

    async with httpx.AsyncClient() as client:
        errors: list[str] = []

        # Try modern API format first (input parameter)
        modern_payload = {
            "model": model_name,
            "input": test_input,
            "prompt": test_input,
        }

        try:
            response = await client.post(url, json=modern_payload, timeout=10.0)
            response.raise_for_status()
            data = response.json()

            # Check for embedding in response
            if "embedding" in data:
                dimension = len(data["embedding"])
                if dimension > 0:
                    logger.info(
                        f"Probed Ollama model '{model_name}': dimension={dimension}"
                    )
                    return dimension
            elif "embeddings" in data and len(data["embeddings"]) > 0:
                dimension = len(data["embeddings"][0])
                if dimension > 0:
                    logger.info(
                        f"Probed Ollama model '{model_name}': dimension={dimension}"
                    )
                    return dimension

            errors.append("response did not include non-zero embedding vector")
        except Exception as modern_error:  # noqa: BLE001 - log and fall back to legacy payload
            logger.debug(
                "Modern Ollama embeddings API probe failed",
                model=model_name,
                endpoint=transformed_endpoint,
                error=str(modern_error),
            )
            errors.append(str(modern_error))

        # Try legacy API format (prompt parameter)
        legacy_payload = {
            "model": model_name,
            "prompt": test_input,
        }

        try:
            response = await client.post(url, json=legacy_payload, timeout=10.0)
            response.raise_for_status()
            data = response.json()

            if "embedding" in data:
                dimension = len(data["embedding"])
                if dimension > 0:
                    logger.info(
                        f"Probed Ollama model '{model_name}' (legacy): dimension={dimension}"
                    )
                    return dimension
            elif "embeddings" in data and len(data["embeddings"]) > 0:
                dimension = len(data["embeddings"][0])
                if dimension > 0:
                    logger.info(
                        f"Probed Ollama model '{model_name}' (legacy): dimension={dimension}"
                    )
                    return dimension

            errors.append("legacy response did not include non-zero embedding vector")
        except Exception as legacy_error:  # noqa: BLE001 - collect and raise a helpful error later
            logger.warning(
                "Legacy Ollama embeddings API probe failed",
                model=model_name,
                endpoint=transformed_endpoint,
                error=str(legacy_error),
            )
            errors.append(str(legacy_error))

        # Drop the first instance of the generic connection error so that either
        # it or the more specific error from one of the two attempts is surfaced
        errors.remove("All connection attempts failed")

        raise ValueError(
            f"Failed to determine embedding dimensions for Ollama model '{model_name}'. "
            f"Verify the Ollama server at '{endpoint}' is reachable and the model is available. "
            f"Error: {errors[0]}"
        )


async def get_embedding_dimensions(model_name: str, provider: str = None, endpoint: str = None) -> int:
    """Get the embedding dimensions for a given model name."""

    if provider and provider.lower() == "ollama":
        if not endpoint:
            raise ValueError(
                "Ollama endpoint is required to determine embedding dimensions. Please provide a valid endpoint."
            )
        return await _probe_ollama_embedding_dimension(endpoint, model_name)

    # Check all model dictionaries
    all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **OLLAMA_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS}
    all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS}

    model_name = model_name.lower().strip().split(":")[0]
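The resolver is now async and provider-aware: Ollama goes through the probe, everything else falls back to the static tables. A quick usage sketch (the endpoint is a placeholder):

import asyncio

# Ollama: requires an endpoint; the dimension comes from a live probe.
dim = asyncio.run(
    get_embedding_dimensions(
        "nomic-embed-text", provider="ollama", endpoint="http://localhost:11434"
    )
)

# OpenAI / watsonx: resolved from the static tables; no endpoint needed.
dim = asyncio.run(get_embedding_dimensions("text-embedding-ada-002"))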
@@ -23,9 +137,22 @@ def get_embedding_dimensions(model_name: str) -> int:
    return VECTOR_DIM


def create_dynamic_index_body(embedding_model: str) -> dict:
    """Create a dynamic index body configuration based on the embedding model."""
    dimensions = get_embedding_dimensions(embedding_model)
async def create_dynamic_index_body(
    embedding_model: str,
    provider: str = None,
    endpoint: str = None
) -> dict:
    """Create a dynamic index body configuration based on the embedding model.

    Args:
        embedding_model: Name of the embedding model
        provider: Provider name (e.g., "ollama", "openai", "watsonx")
        endpoint: Endpoint URL for the provider (used for Ollama probing)

    Returns:
        OpenSearch index body configuration
    """
    dimensions = await get_embedding_dimensions(embedding_model, provider, endpoint)

    return {
        "settings": {
@@ -63,4 +190,4 @@ def create_dynamic_index_body(embedding_model: str) -> dict:
                "metadata": {"type": "object"},
            }
        },
    }
}
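For reference, the resolved dimension lands in the knn_vector mapping of the returned index body. A trimmed, illustrative shape only (the field and setting names here are assumptions; the real body in this file carries more settings and fields, e.g. the "metadata" object above):

example_body = {
    "settings": {"index": {"knn": True}},
    "mappings": {
        "properties": {
            "embedding": {"type": "knn_vector", "dimension": 768},  # probed value
            "metadata": {"type": "object"},
        }
    },
}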