Implement dynamic Ollama embedding dimension resolution with server probing (#237)

* Initial plan

* Implement dynamic Ollama embedding dimension resolution with probing

Co-authored-by: phact <1313220+phact@users.noreply.github.com>

* Fix Ollama probing

* Raise an error instead of returning 0 dimensions

* Show a better error message

* Run the embedding probe before saving settings so that the user can update them

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: phact <1313220+phact@users.noreply.github.com>
Co-authored-by: Lucas Oliveira <lucas.edu.oli@hotmail.com>
Co-authored-by: phact <estevezsebastian@gmail.com>
4 changed files with 283 additions and 147 deletions
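At its core, the probe sends a tiny embedding request to the Ollama server and measures the vector that comes back. A minimal standalone sketch of the idea, assuming a reachable server and a pulled model (this distills the helper added in the utils file below rather than reproducing it):

import asyncio
import httpx

async def probe_dimension(endpoint: str, model: str) -> int:
    # Embed a throwaway string and count the components of the result.
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{endpoint}/api/embeddings",
            json={"model": model, "prompt": "test"},
            timeout=10.0,
        )
        resp.raise_for_status()
        return len(resp.json()["embedding"])

# e.g. asyncio.run(probe_dimension("http://localhost:11434", "nomic-embed-text"))  # -> 768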


@@ -506,7 +506,142 @@ async def onboarding(request, flows_service):
{"error": "No valid fields provided for update"}, status_code=400
)
# Initialize the OpenSearch index now that we have the embedding model configured
try:
# Import here to avoid circular imports
from main import init_index
logger.info(
"Initializing OpenSearch index after onboarding configuration"
)
await init_index()
logger.info("OpenSearch index initialization completed successfully")
except Exception as e:
if isinstance(e, ValueError):
logger.error(
"Failed to initialize OpenSearch index after onboarding",
error=str(e),
)
return JSONResponse(
{
"error": str(e),
"edited": True,
},
status_code=400,
)
logger.error(
"Failed to initialize OpenSearch index after onboarding",
error=str(e),
)
# Don't fail the entire onboarding process if index creation fails
# The application can still work, but document operations may fail
# Save the updated configuration (this will mark it as edited)
# If model_provider was updated, assign the new provider to flows
if "model_provider" in body:
provider = body["model_provider"].strip().lower()
try:
flow_result = await flows_service.assign_model_provider(provider)
if flow_result.get("success"):
logger.info(
f"Successfully assigned {provider} to flows",
flow_result=flow_result,
)
else:
logger.warning(
f"Failed to assign {provider} to flows",
flow_result=flow_result,
)
# Continue even if flow assignment fails - configuration was still saved
except Exception as e:
logger.error(
"Error assigning model provider to flows",
provider=provider,
error=str(e),
)
raise
# Set Langflow global variables based on provider
try:
# Set API key for IBM/Watson providers
if (provider == "watsonx") and "api_key" in body:
api_key = body["api_key"]
await clients._create_langflow_global_variable(
"WATSONX_API_KEY", api_key, modify=True
)
logger.info("Set WATSONX_API_KEY global variable in Langflow")
# Set project ID for IBM/Watson providers
if (provider == "watsonx") and "project_id" in body:
project_id = body["project_id"]
await clients._create_langflow_global_variable(
"WATSONX_PROJECT_ID", project_id, modify=True
)
logger.info(
"Set WATSONX_PROJECT_ID global variable in Langflow"
)
# Set API key for OpenAI provider
if provider == "openai" and "api_key" in body:
api_key = body["api_key"]
await clients._create_langflow_global_variable(
"OPENAI_API_KEY", api_key, modify=True
)
logger.info("Set OPENAI_API_KEY global variable in Langflow")
# Set base URL for Ollama provider
if provider == "ollama" and "endpoint" in body:
endpoint = transform_localhost_url(body["endpoint"])
await clients._create_langflow_global_variable(
"OLLAMA_BASE_URL", endpoint, modify=True
)
logger.info("Set OLLAMA_BASE_URL global variable in Langflow")
await flows_service.change_langflow_model_value(
provider,
body["embedding_model"],
body["llm_model"],
body["endpoint"],
)
except Exception as e:
logger.error(
"Failed to set Langflow global variables",
provider=provider,
error=str(e),
)
raise
# Handle sample data ingestion if requested
if should_ingest_sample_data:
try:
# Import the function here to avoid circular imports
from main import ingest_default_documents_when_ready
# Get services from the current app state
# We need to access the app instance to get services
app = request.scope.get("app")
if app and hasattr(app.state, "services"):
services = app.state.services
logger.info(
"Starting sample data ingestion as requested in onboarding"
)
await ingest_default_documents_when_ready(services)
logger.info("Sample data ingestion completed successfully")
else:
logger.error(
"Could not access services for sample data ingestion"
)
except Exception as e:
logger.error(
"Failed to complete sample data ingestion", error=str(e)
)
# Don't fail the entire onboarding process if sample data fails
if config_manager.save_config_file(current_config):
updated_fields = [
k for k in body.keys() if k != "sample_data"
@@ -516,144 +651,19 @@ async def onboarding(request, flows_service):
updated_fields=updated_fields,
)
# If model_provider was updated, assign the new provider to flows
if "model_provider" in body:
provider = body["model_provider"].strip().lower()
try:
flow_result = await flows_service.assign_model_provider(provider)
if flow_result.get("success"):
logger.info(
f"Successfully assigned {provider} to flows",
flow_result=flow_result,
)
else:
logger.warning(
f"Failed to assign {provider} to flows",
flow_result=flow_result,
)
# Continue even if flow assignment fails - configuration was still saved
except Exception as e:
logger.error(
"Error assigning model provider to flows",
provider=provider,
error=str(e),
)
# Continue even if flow assignment fails - configuration was still saved
# Set Langflow global variables based on provider
if "model_provider" in body:
provider = body["model_provider"].strip().lower()
try:
# Set API key for IBM/Watson providers
if (provider == "watsonx") and "api_key" in body:
api_key = body["api_key"]
await clients._create_langflow_global_variable(
"WATSONX_API_KEY", api_key, modify=True
)
logger.info("Set WATSONX_API_KEY global variable in Langflow")
# Set project ID for IBM/Watson providers
if (provider == "watsonx") and "project_id" in body:
project_id = body["project_id"]
await clients._create_langflow_global_variable(
"WATSONX_PROJECT_ID", project_id, modify=True
)
logger.info(
"Set WATSONX_PROJECT_ID global variable in Langflow"
)
# Set API key for OpenAI provider
if provider == "openai" and "api_key" in body:
api_key = body["api_key"]
await clients._create_langflow_global_variable(
"OPENAI_API_KEY", api_key, modify=True
)
logger.info("Set OPENAI_API_KEY global variable in Langflow")
# Set base URL for Ollama provider
if provider == "ollama" and "endpoint" in body:
endpoint = transform_localhost_url(body["endpoint"])
await clients._create_langflow_global_variable(
"OLLAMA_BASE_URL", endpoint, modify=True
)
logger.info("Set OLLAMA_BASE_URL global variable in Langflow")
await flows_service.change_langflow_model_value(
provider,
body["embedding_model"],
body["llm_model"],
body["endpoint"],
)
except Exception as e:
logger.error(
"Failed to set Langflow global variables",
provider=provider,
error=str(e),
)
# Continue even if setting global variables fails
# Initialize the OpenSearch index now that we have the embedding model configured
try:
# Import here to avoid circular imports
from main import init_index
logger.info(
"Initializing OpenSearch index after onboarding configuration"
)
await init_index()
logger.info("OpenSearch index initialization completed successfully")
except Exception as e:
logger.error(
"Failed to initialize OpenSearch index after onboarding",
error=str(e),
)
# Don't fail the entire onboarding process if index creation fails
# The application can still work, but document operations may fail
# Handle sample data ingestion if requested
if should_ingest_sample_data:
try:
# Import the function here to avoid circular imports
from main import ingest_default_documents_when_ready
# Get services from the current app state
# We need to access the app instance to get services
app = request.scope.get("app")
if app and hasattr(app.state, "services"):
services = app.state.services
logger.info(
"Starting sample data ingestion as requested in onboarding"
)
await ingest_default_documents_when_ready(services)
logger.info("Sample data ingestion completed successfully")
else:
logger.error(
"Could not access services for sample data ingestion"
)
except Exception as e:
logger.error(
"Failed to complete sample data ingestion", error=str(e)
)
# Don't fail the entire onboarding process if sample data fails
return JSONResponse(
{
"message": "Onboarding configuration updated successfully",
"edited": True, # Confirm that config is now marked as edited
"sample_data_ingested": should_ingest_sample_data,
}
)
else:
return JSONResponse(
{"error": "Failed to save configuration"}, status_code=500
)
return JSONResponse(
{
"message": "Onboarding configuration updated successfully",
"edited": True, # Confirm that config is now marked as edited
"sample_data_ingested": should_ingest_sample_data,
}
)
except Exception as e:
logger.error("Failed to update onboarding settings", error=str(e))
return JSONResponse(
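The net effect of moving these blocks is a validate-before-save flow: index initialization (and with it the embedding probe) runs first, and the configuration is only persisted once it succeeds. Schematically, with illustrative names rather than the actual API:

async def apply_onboarding(settings, probe, save):
    # Probe first; a ValueError is returned to the user so they can
    # correct the settings. Nothing has been written at this point.
    try:
        await probe(settings)
    except ValueError as exc:
        return {"error": str(exc), "edited": True}
    save(settings)  # persist only after the probe succeeded
    return {"message": "saved", "edited": True}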


@@ -81,12 +81,6 @@ OPENAI_EMBEDDING_DIMENSIONS = {
"text-embedding-ada-002": 1536,
}
OLLAMA_EMBEDDING_DIMENSIONS = {
"nomic-embed-text": 768,
"all-minilm": 384,
"mxbai-embed-large": 1024,
}
WATSONX_EMBEDDING_DIMENSIONS = {
# IBM Models
"ibm/granite-embedding-107m-multilingual": 384,


@@ -168,7 +168,12 @@ async def init_index():
embedding_model = config.knowledge.embedding_model
# Create dynamic index body based on the configured embedding model
dynamic_index_body = create_dynamic_index_body(embedding_model)
# Pass provider and endpoint for dynamic dimension resolution (Ollama probing)
dynamic_index_body = await create_dynamic_index_body(
embedding_model,
provider=config.provider.model_provider,
endpoint=config.provider.endpoint
)
# Create documents index
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
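Since create_dynamic_index_body may now hit the network to probe Ollama, it is a coroutine and every call site must await it; forgetting the await would hand OpenSearch a coroutine object instead of a mapping. A sketch of a call site in the same shape as the one above (config attribute names taken from this hunk, the rest illustrative):

async def ensure_index(cfg) -> None:
    body = await create_dynamic_index_body(
        cfg.knowledge.embedding_model,
        provider=cfg.provider.model_provider,
        endpoint=cfg.provider.endpoint,
    )
    if not await clients.opensearch.indices.exists(index=INDEX_NAME):
        await clients.opensearch.indices.create(index=INDEX_NAME, body=body)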


@@ -1,14 +1,128 @@
from config.settings import OLLAMA_EMBEDDING_DIMENSIONS, OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS
import httpx
from config.settings import OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS
from utils.container_utils import transform_localhost_url
from utils.logging_config import get_logger
logger = get_logger(__name__)
def get_embedding_dimensions(model_name: str) -> int:
async def _probe_ollama_embedding_dimension(endpoint: str, model_name: str) -> int:
"""Probe Ollama server to get embedding dimension for a model.
Args:
endpoint: Ollama server endpoint (e.g., "http://localhost:11434")
model_name: Name of the embedding model
Returns:
The embedding dimension.
Raises:
ValueError: If the dimension cannot be determined.
"""
transformed_endpoint = transform_localhost_url(endpoint)
url = f"{transformed_endpoint}/api/embeddings"
test_input = "test"
async with httpx.AsyncClient() as client:
errors: list[str] = []
# Try the modern API format first ("input"); "prompt" is included as well as a fallback for servers that only read it
modern_payload = {
"model": model_name,
"input": test_input,
"prompt": test_input,
}
try:
response = await client.post(url, json=modern_payload, timeout=10.0)
response.raise_for_status()
data = response.json()
# Check for embedding in response
if "embedding" in data:
dimension = len(data["embedding"])
if dimension > 0:
logger.info(
f"Probed Ollama model '{model_name}': dimension={dimension}"
)
return dimension
elif "embeddings" in data and len(data["embeddings"]) > 0:
dimension = len(data["embeddings"][0])
if dimension > 0:
logger.info(
f"Probed Ollama model '{model_name}': dimension={dimension}"
)
return dimension
errors.append("response did not include non-zero embedding vector")
except Exception as modern_error: # noqa: BLE001 - log and fall back to legacy payload
logger.debug(
"Modern Ollama embeddings API probe failed",
model=model_name,
endpoint=transformed_endpoint,
error=str(modern_error),
)
errors.append(str(modern_error))
# Try legacy API format (prompt parameter)
legacy_payload = {
"model": model_name,
"prompt": test_input,
}
try:
response = await client.post(url, json=legacy_payload, timeout=10.0)
response.raise_for_status()
data = response.json()
if "embedding" in data:
dimension = len(data["embedding"])
if dimension > 0:
logger.info(
f"Probed Ollama model '{model_name}' (legacy): dimension={dimension}"
)
return dimension
elif "embeddings" in data and len(data["embeddings"]) > 0:
dimension = len(data["embeddings"][0])
if dimension > 0:
logger.info(
f"Probed Ollama model '{model_name}' (legacy): dimension={dimension}"
)
return dimension
errors.append("legacy response did not include non-zero embedding vector")
except Exception as legacy_error: # noqa: BLE001 - collect and raise a helpful error later
logger.warning(
"Legacy Ollama embeddings API probe failed",
model=model_name,
endpoint=transformed_endpoint,
error=str(legacy_error),
)
errors.append(str(legacy_error))
# Both probes failed. httpx raises a generic "All connection attempts failed"
# for connection errors; drop it when present so the more specific message
# from either attempt is the one surfaced below. (An unconditional
# errors.remove() would itself raise ValueError when the string is absent.)
if "All connection attempts failed" in errors:
    errors.remove("All connection attempts failed")
raise ValueError(
f"Failed to determine embedding dimensions for Ollama model '{model_name}'. "
f"Verify the Ollama server at '{endpoint}' is reachable and the model is available. "
f"Error: {errors[0]}"
)
async def get_embedding_dimensions(model_name: str, provider: str | None = None, endpoint: str | None = None) -> int:
"""Get the embedding dimensions for a given model name."""
if provider and provider.lower() == "ollama":
if not endpoint:
raise ValueError(
"Ollama endpoint is required to determine embedding dimensions. Please provide a valid endpoint."
)
return await _probe_ollama_embedding_dimension(endpoint, model_name)
# Check all model dictionaries
all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **OLLAMA_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS}
all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS}
model_name = model_name.lower().strip().split(":")[0]
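The normalization on the last line above strips case, whitespace, and Ollama-style version tags before the dictionary lookup, e.g.:

"  Nomic-Embed-Text:latest  ".lower().strip().split(":")[0]  # -> "nomic-embed-text"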
@@ -23,9 +137,22 @@ def get_embedding_dimensions(model_name: str) -> int:
return VECTOR_DIM
def create_dynamic_index_body(embedding_model: str) -> dict:
"""Create a dynamic index body configuration based on the embedding model."""
dimensions = get_embedding_dimensions(embedding_model)
async def create_dynamic_index_body(
embedding_model: str,
provider: str | None = None,
endpoint: str | None = None,
) -> dict:
"""Create a dynamic index body configuration based on the embedding model.
Args:
embedding_model: Name of the embedding model
provider: Provider name (e.g., "ollama", "openai", "watsonx")
endpoint: Endpoint URL for the provider (used for Ollama probing)
Returns:
OpenSearch index body configuration
"""
dimensions = await get_embedding_dimensions(embedding_model, provider, endpoint)
return {
"settings": {
@@ -63,4 +190,4 @@ def create_dynamic_index_body(embedding_model: str) -> dict:
"metadata": {"type": "object"},
}
},
}
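Taken together, dimension resolution now looks like this from a caller's perspective (the endpoint and models are placeholders, and the Ollama path assumes a reachable server):

import asyncio

async def demo() -> None:
    # Static table lookup; no network involved.
    print(await get_embedding_dimensions("text-embedding-ada-002"))  # 1536
    # Live probe; raises ValueError if the server or model is unavailable.
    print(await get_embedding_dimensions(
        "nomic-embed-text", provider="ollama", endpoint="http://localhost:11434"
    ))

asyncio.run(demo())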