Implement dynamic Ollama embedding dimension resolution with server probing (#237)

* Initial plan

* Implement dynamic Ollama embedding dimension resolution with probing

Co-authored-by: phact <1313220+phact@users.noreply.github.com>

* Fix Ollama probing

* raise instead of dims 0

* Show better error

* Run embedding probe before saving settings so that user can update

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: phact <1313220+phact@users.noreply.github.com>
Co-authored-by: Lucas Oliveira <lucas.edu.oli@hotmail.com>
Co-authored-by: phact <estevezsebastian@gmail.com>
This commit is contained in:
Copilot 2025-10-09 16:40:58 -03:00 committed by GitHub
parent aee0a20302
commit 140d24603d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 283 additions and 147 deletions

View file

@ -506,15 +506,37 @@ async def onboarding(request, flows_service):
{"error": "No valid fields provided for update"}, status_code=400 {"error": "No valid fields provided for update"}, status_code=400
) )
# Save the updated configuration (this will mark it as edited) # Initialize the OpenSearch index now that we have the embedding model configured
if config_manager.save_config_file(current_config): try:
updated_fields = [ # Import here to avoid circular imports
k for k in body.keys() if k != "sample_data" from main import init_index
] # Exclude sample_data from log
logger.info( logger.info(
"Onboarding configuration updated successfully", "Initializing OpenSearch index after onboarding configuration"
updated_fields=updated_fields,
) )
await init_index()
logger.info("OpenSearch index initialization completed successfully")
except Exception as e:
if isinstance(e, ValueError):
logger.error(
"Failed to initialize OpenSearch index after onboarding",
error=str(e),
)
return JSONResponse(
{
"error": str(e),
"edited": True,
},
status_code=400,
)
logger.error(
"Failed to initialize OpenSearch index after onboarding",
error=str(e),
)
# Don't fail the entire onboarding process if index creation fails
# The application can still work, but document operations may fail
# Save the updated configuration (this will mark it as edited)
# If model_provider was updated, assign the new provider to flows # If model_provider was updated, assign the new provider to flows
if "model_provider" in body: if "model_provider" in body:
@ -540,12 +562,9 @@ async def onboarding(request, flows_service):
provider=provider, provider=provider,
error=str(e), error=str(e),
) )
# Continue even if flow assignment fails - configuration was still saved raise
# Set Langflow global variables based on provider # Set Langflow global variables based on provider
if "model_provider" in body:
provider = body["model_provider"].strip().lower()
try: try:
# Set API key for IBM/Watson providers # Set API key for IBM/Watson providers
if (provider == "watsonx") and "api_key" in body: if (provider == "watsonx") and "api_key" in body:
@ -595,25 +614,7 @@ async def onboarding(request, flows_service):
provider=provider, provider=provider,
error=str(e), error=str(e),
) )
# Continue even if setting global variables fails raise
# Initialize the OpenSearch index now that we have the embedding model configured
try:
# Import here to avoid circular imports
from main import init_index
logger.info(
"Initializing OpenSearch index after onboarding configuration"
)
await init_index()
logger.info("OpenSearch index initialization completed successfully")
except Exception as e:
logger.error(
"Failed to initialize OpenSearch index after onboarding",
error=str(e),
)
# Don't fail the entire onboarding process if index creation fails
# The application can still work, but document operations may fail
# Handle sample data ingestion if requested # Handle sample data ingestion if requested
if should_ingest_sample_data: if should_ingest_sample_data:
@ -641,6 +642,19 @@ async def onboarding(request, flows_service):
"Failed to complete sample data ingestion", error=str(e) "Failed to complete sample data ingestion", error=str(e)
) )
# Don't fail the entire onboarding process if sample data fails # Don't fail the entire onboarding process if sample data fails
if config_manager.save_config_file(current_config):
updated_fields = [
k for k in body.keys() if k != "sample_data"
] # Exclude sample_data from log
logger.info(
"Onboarding configuration updated successfully",
updated_fields=updated_fields,
)
else:
return JSONResponse(
{"error": "Failed to save configuration"}, status_code=500
)
return JSONResponse( return JSONResponse(
{ {
@ -649,10 +663,6 @@ async def onboarding(request, flows_service):
"sample_data_ingested": should_ingest_sample_data, "sample_data_ingested": should_ingest_sample_data,
} }
) )
else:
return JSONResponse(
{"error": "Failed to save configuration"}, status_code=500
)
except Exception as e: except Exception as e:
logger.error("Failed to update onboarding settings", error=str(e)) logger.error("Failed to update onboarding settings", error=str(e))

View file

@ -81,12 +81,6 @@ OPENAI_EMBEDDING_DIMENSIONS = {
"text-embedding-ada-002": 1536, "text-embedding-ada-002": 1536,
} }
OLLAMA_EMBEDDING_DIMENSIONS = {
"nomic-embed-text": 768,
"all-minilm": 384,
"mxbai-embed-large": 1024,
}
WATSONX_EMBEDDING_DIMENSIONS = { WATSONX_EMBEDDING_DIMENSIONS = {
# IBM Models # IBM Models
"ibm/granite-embedding-107m-multilingual": 384, "ibm/granite-embedding-107m-multilingual": 384,

View file

@ -168,7 +168,12 @@ async def init_index():
embedding_model = config.knowledge.embedding_model embedding_model = config.knowledge.embedding_model
# Create dynamic index body based on the configured embedding model # Create dynamic index body based on the configured embedding model
dynamic_index_body = create_dynamic_index_body(embedding_model) # Pass provider and endpoint for dynamic dimension resolution (Ollama probing)
dynamic_index_body = await create_dynamic_index_body(
embedding_model,
provider=config.provider.model_provider,
endpoint=config.provider.endpoint
)
# Create documents index # Create documents index
if not await clients.opensearch.indices.exists(index=INDEX_NAME): if not await clients.opensearch.indices.exists(index=INDEX_NAME):

View file

@ -1,14 +1,128 @@
from config.settings import OLLAMA_EMBEDDING_DIMENSIONS, OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS import httpx
from config.settings import OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS
from utils.container_utils import transform_localhost_url
from utils.logging_config import get_logger from utils.logging_config import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
def get_embedding_dimensions(model_name: str) -> int:
async def _probe_ollama_embedding_dimension(endpoint: str, model_name: str) -> int:
    """Probe an Ollama server to determine a model's embedding dimension.

    Sends a tiny embedding request and measures the length of the returned
    vector. The modern request shape ("input" parameter) is tried first, then
    the legacy shape ("prompt" only), so both new and old Ollama servers work.

    Args:
        endpoint: Ollama server endpoint (e.g., "http://localhost:11434")
        model_name: Name of the embedding model

    Returns:
        The embedding dimension (always > 0).

    Raises:
        ValueError: If the dimension cannot be determined from either API shape.
    """
    transformed_endpoint = transform_localhost_url(endpoint)
    url = f"{transformed_endpoint}/api/embeddings"
    test_input = "test"

    def _vector_length(data: dict) -> int:
        # Ollama responses carry either a single "embedding" vector or a
        # batched "embeddings" list; return 0 when neither yields a vector.
        # Note: "embedding" takes precedence even if empty (matches original
        # if/elif ordering).
        if "embedding" in data:
            return len(data["embedding"])
        if "embeddings" in data and len(data["embeddings"]) > 0:
            return len(data["embeddings"][0])
        return 0

    async with httpx.AsyncClient() as client:
        errors: list[str] = []

        # Try modern API format first ("input" parameter). "prompt" is sent
        # alongside it so servers that only read one of the keys still respond.
        modern_payload = {
            "model": model_name,
            "input": test_input,
            "prompt": test_input,
        }
        try:
            response = await client.post(url, json=modern_payload, timeout=10.0)
            response.raise_for_status()
            dimension = _vector_length(response.json())
            if dimension > 0:
                logger.info(
                    f"Probed Ollama model '{model_name}': dimension={dimension}"
                )
                return dimension
            errors.append("response did not include non-zero embedding vector")
        except Exception as modern_error:  # noqa: BLE001 - log and fall back to legacy payload
            logger.debug(
                "Modern Ollama embeddings API probe failed",
                model=model_name,
                endpoint=transformed_endpoint,
                error=str(modern_error),
            )
            errors.append(str(modern_error))

        # Fall back to the legacy API format ("prompt" parameter only).
        legacy_payload = {
            "model": model_name,
            "prompt": test_input,
        }
        try:
            response = await client.post(url, json=legacy_payload, timeout=10.0)
            response.raise_for_status()
            dimension = _vector_length(response.json())
            if dimension > 0:
                logger.info(
                    f"Probed Ollama model '{model_name}' (legacy): dimension={dimension}"
                )
                return dimension
            errors.append("legacy response did not include non-zero embedding vector")
        except Exception as legacy_error:  # noqa: BLE001 - collect and raise a helpful error later
            logger.warning(
                "Legacy Ollama embeddings API probe failed",
                model=model_name,
                endpoint=transformed_endpoint,
                error=str(legacy_error),
            )
            errors.append(str(legacy_error))

    # Prefer a specific probe error over httpx's generic connection-failure
    # message. BUG FIX: the original called errors.remove(...) unconditionally,
    # which itself raised ValueError("list.remove(x): x not in list") whenever
    # neither attempt failed with that exact message, masking the real error.
    generic_connection_error = "All connection attempts failed"
    if generic_connection_error in errors and len(errors) > 1:
        errors.remove(generic_connection_error)
    raise ValueError(
        f"Failed to determine embedding dimensions for Ollama model '{model_name}'. "
        f"Verify the Ollama server at '{endpoint}' is reachable and the model is available. "
        f"Error: {errors[0]}"
    )
async def get_embedding_dimensions(model_name: str, provider: str = None, endpoint: str = None) -> int:
"""Get the embedding dimensions for a given model name.""" """Get the embedding dimensions for a given model name."""
if provider and provider.lower() == "ollama":
if not endpoint:
raise ValueError(
"Ollama endpoint is required to determine embedding dimensions. Please provide a valid endpoint."
)
return await _probe_ollama_embedding_dimension(endpoint, model_name)
# Check all model dictionaries # Check all model dictionaries
all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **OLLAMA_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS} all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS}
model_name = model_name.lower().strip().split(":")[0] model_name = model_name.lower().strip().split(":")[0]
@ -23,9 +137,22 @@ def get_embedding_dimensions(model_name: str) -> int:
return VECTOR_DIM return VECTOR_DIM
def create_dynamic_index_body(embedding_model: str) -> dict: async def create_dynamic_index_body(
"""Create a dynamic index body configuration based on the embedding model.""" embedding_model: str,
dimensions = get_embedding_dimensions(embedding_model) provider: str = None,
endpoint: str = None
) -> dict:
"""Create a dynamic index body configuration based on the embedding model.
Args:
embedding_model: Name of the embedding model
provider: Provider name (e.g., "ollama", "openai", "watsonx")
endpoint: Endpoint URL for the provider (used for Ollama probing)
Returns:
OpenSearch index body configuration
"""
dimensions = await get_embedding_dimensions(embedding_model, provider, endpoint)
return { return {
"settings": { "settings": {