From 140d24603d64e76e6fb0086f483039dbc62b60ea Mon Sep 17 00:00:00 2001
From: Copilot <198982749+Copilot@users.noreply.github.com>
Date: Thu, 9 Oct 2025 16:40:58 -0300
Subject: [PATCH] Implement dynamic Ollama embedding dimension resolution with
 server probing (#237)

* Initial plan

* Implement dynamic Ollama embedding dimension resolution with probing

Co-authored-by: phact <1313220+phact@users.noreply.github.com>

* Fix Ollama probing

* raise instead of dims 0

* Show better error

* Run embedding probe before saving settings so that user can update

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: phact <1313220+phact@users.noreply.github.com>
Co-authored-by: Lucas Oliveira
Co-authored-by: phact
---
 src/api/settings.py     | 276 +++++++++++++++++++++-------------------
 src/config/settings.py  |   6 -
 src/main.py             |   7 +-
 src/utils/embeddings.py | 141 +++++++++++++++++++-
 4 files changed, 283 insertions(+), 147 deletions(-)

diff --git a/src/api/settings.py b/src/api/settings.py
index 846c0fda..f60afe69 100644
--- a/src/api/settings.py
+++ b/src/api/settings.py
@@ -506,7 +506,142 @@ async def onboarding(request, flows_service):
                 {"error": "No valid fields provided for update"}, status_code=400
             )
 
+            # Initialize the OpenSearch index now that the embedding model is configured
+            try:
+                # Import here to avoid circular imports
+                from main import init_index
+
+                logger.info(
+                    "Initializing OpenSearch index after onboarding configuration"
+                )
+                await init_index()
+                logger.info("OpenSearch index initialization completed successfully")
+            except ValueError as e:
+                # Probing failed (e.g., an unreachable Ollama server): surface
+                # the error so the user can correct the settings and retry.
+                logger.error(
+                    "Failed to initialize OpenSearch index after onboarding",
+                    error=str(e),
+                )
+                return JSONResponse(
+                    {
+                        "error": str(e),
+                        "edited": True,
+                    },
+                    status_code=400,
+                )
+            except Exception as e:
+                logger.error(
+                    "Failed to initialize OpenSearch index after onboarding",
+                    error=str(e),
+                )
+                # Don't fail the entire onboarding process if index creation fails
+                # The application can still work, but document operations may fail
+
+            # (The updated configuration is saved further below, which marks it as edited.)
+
+            # If model_provider was updated, assign the new provider to flows
+            if "model_provider" in body:
+                provider = body["model_provider"].strip().lower()
+                try:
+                    flow_result = await flows_service.assign_model_provider(provider)
+
+                    if flow_result.get("success"):
+                        logger.info(
+                            f"Successfully assigned {provider} to flows",
+                            flow_result=flow_result,
+                        )
+                    else:
+                        logger.warning(
+                            f"Failed to assign {provider} to flows",
+                            flow_result=flow_result,
+                        )
+                        # Continue even if flow assignment fails - the configuration is still saved below
+
+                except Exception as e:
+                    logger.error(
+                        "Error assigning model provider to flows",
+                        provider=provider,
+                        error=str(e),
+                    )
+                    raise
+
+                # Set Langflow global variables based on provider
+                try:
+                    # Set API key for IBM/Watson providers
+                    if (provider == "watsonx") and "api_key" in body:
+                        api_key = body["api_key"]
+                        await clients._create_langflow_global_variable(
+                            "WATSONX_API_KEY", api_key, modify=True
+                        )
+                        logger.info("Set WATSONX_API_KEY global variable in Langflow")
+
+                    # Set project ID for IBM/Watson providers
+                    if (provider == "watsonx") and "project_id" in body:
+                        project_id = body["project_id"]
+                        await clients._create_langflow_global_variable(
+                            "WATSONX_PROJECT_ID", project_id, modify=True
+                        )
+                        logger.info(
+                            "Set WATSONX_PROJECT_ID global variable in Langflow"
+                        )
+
+                    # Set API key for OpenAI provider
+                    if provider == "openai" and "api_key" in body:
+                        api_key = body["api_key"]
+                        await clients._create_langflow_global_variable(
+                            "OPENAI_API_KEY", api_key, modify=True
+                        )
+                        logger.info("Set OPENAI_API_KEY global variable in Langflow")
+
+                    # Set base URL for Ollama provider
+                    if provider == "ollama" and "endpoint" in body:
+                        endpoint = transform_localhost_url(body["endpoint"])
+
+                        await clients._create_langflow_global_variable(
+                            "OLLAMA_BASE_URL", endpoint, modify=True
+                        )
+                        logger.info("Set OLLAMA_BASE_URL global variable in Langflow")
+
+                    await flows_service.change_langflow_model_value(
+                        provider,
+                        body["embedding_model"],
+                        body["llm_model"],
+                        body["endpoint"],
+                    )
+
+                except Exception as e:
+                    logger.error(
+                        "Failed to set Langflow global variables",
+                        provider=provider,
+                        error=str(e),
+                    )
+                    raise
+
+            # Handle sample data ingestion if requested
+            if should_ingest_sample_data:
+                try:
+                    # Import the function here to avoid circular imports
+                    from main import ingest_default_documents_when_ready
+
+                    # Get services from the current app state
+                    # We need to access the app instance to get services
+                    app = request.scope.get("app")
+                    if app and hasattr(app.state, "services"):
+                        services = app.state.services
+                        logger.info(
+                            "Starting sample data ingestion as requested in onboarding"
+                        )
+                        await ingest_default_documents_when_ready(services)
+                        logger.info("Sample data ingestion completed successfully")
+                    else:
+                        logger.error(
+                            "Could not access services for sample data ingestion"
+                        )
+
+                except Exception as e:
+                    logger.error(
+                        "Failed to complete sample data ingestion", error=str(e)
+                    )
+                    # Don't fail the entire onboarding process if sample data fails
 
             if config_manager.save_config_file(current_config):
                 updated_fields = [
                     k for k in body.keys() if k != "sample_data"
@@ -516,144 +651,19 @@ async def onboarding(request, flows_service):
                     updated_fields=updated_fields,
                 )
-
-                # If model_provider was updated, assign the new provider to flows
-                if "model_provider" in body:
-                    provider = body["model_provider"].strip().lower()
-                    try:
-                        flow_result = await flows_service.assign_model_provider(provider)
-
-                        if flow_result.get("success"):
-                            logger.info(
-                                f"Successfully assigned {provider} to flows",
-                                flow_result=flow_result,
-                            )
-                        else:
-                            logger.warning(
-                                f"Failed to assign {provider} to flows",
-                                flow_result=flow_result,
-                            )
-                            # Continue even if flow assignment fails - configuration was still saved
-
-                    except Exception as e:
-                        logger.error(
-                            "Error assigning model provider to flows",
-                            provider=provider,
-                            error=str(e),
-                        )
-                        # Continue even if flow assignment fails - configuration was still saved
-
-                # Set Langflow global variables based on provider
-                if "model_provider" in body:
-                    provider = body["model_provider"].strip().lower()
-
-                    try:
-                        # Set API key for IBM/Watson providers
-                        if (provider == "watsonx") and "api_key" in body:
-                            api_key = body["api_key"]
-                            await clients._create_langflow_global_variable(
-                                "WATSONX_API_KEY", api_key, modify=True
-                            )
-                            logger.info("Set WATSONX_API_KEY global variable in Langflow")
-
-                        # Set project ID for IBM/Watson providers
-                        if (provider == "watsonx") and "project_id" in body:
-                            project_id = body["project_id"]
-                            await clients._create_langflow_global_variable(
-                                "WATSONX_PROJECT_ID", project_id, modify=True
-                            )
-                            logger.info(
-                                "Set WATSONX_PROJECT_ID global variable in Langflow"
-                            )
-
-                        # Set API key for OpenAI provider
-                        if provider == "openai" and "api_key" in body:
-                            api_key = body["api_key"]
-                            await clients._create_langflow_global_variable(
-                                "OPENAI_API_KEY", api_key, modify=True
-                            )
-                            logger.info("Set OPENAI_API_KEY global variable in Langflow")
logger.info("Set OPENAI_API_KEY global variable in Langflow") - - # Set base URL for Ollama provider - if provider == "ollama" and "endpoint" in body: - endpoint = transform_localhost_url(body["endpoint"]) - - await clients._create_langflow_global_variable( - "OLLAMA_BASE_URL", endpoint, modify=True - ) - logger.info("Set OLLAMA_BASE_URL global variable in Langflow") - - await flows_service.change_langflow_model_value( - provider, - body["embedding_model"], - body["llm_model"], - body["endpoint"], - ) - - except Exception as e: - logger.error( - "Failed to set Langflow global variables", - provider=provider, - error=str(e), - ) - # Continue even if setting global variables fails - - # Initialize the OpenSearch index now that we have the embedding model configured - try: - # Import here to avoid circular imports - from main import init_index - - logger.info( - "Initializing OpenSearch index after onboarding configuration" - ) - await init_index() - logger.info("OpenSearch index initialization completed successfully") - except Exception as e: - logger.error( - "Failed to initialize OpenSearch index after onboarding", - error=str(e), - ) - # Don't fail the entire onboarding process if index creation fails - # The application can still work, but document operations may fail - - # Handle sample data ingestion if requested - if should_ingest_sample_data: - try: - # Import the function here to avoid circular imports - from main import ingest_default_documents_when_ready - - # Get services from the current app state - # We need to access the app instance to get services - app = request.scope.get("app") - if app and hasattr(app.state, "services"): - services = app.state.services - logger.info( - "Starting sample data ingestion as requested in onboarding" - ) - await ingest_default_documents_when_ready(services) - logger.info("Sample data ingestion completed successfully") - else: - logger.error( - "Could not access services for sample data ingestion" - ) - - except Exception as e: - logger.error( - "Failed to complete sample data ingestion", error=str(e) - ) - # Don't fail the entire onboarding process if sample data fails - - return JSONResponse( - { - "message": "Onboarding configuration updated successfully", - "edited": True, # Confirm that config is now marked as edited - "sample_data_ingested": should_ingest_sample_data, - } - ) else: return JSONResponse( {"error": "Failed to save configuration"}, status_code=500 ) + return JSONResponse( + { + "message": "Onboarding configuration updated successfully", + "edited": True, # Confirm that config is now marked as edited + "sample_data_ingested": should_ingest_sample_data, + } + ) + except Exception as e: logger.error("Failed to update onboarding settings", error=str(e)) return JSONResponse( diff --git a/src/config/settings.py b/src/config/settings.py index 0672ad68..415516e8 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -81,12 +81,6 @@ OPENAI_EMBEDDING_DIMENSIONS = { "text-embedding-ada-002": 1536, } -OLLAMA_EMBEDDING_DIMENSIONS = { - "nomic-embed-text": 768, - "all-minilm": 384, - "mxbai-embed-large": 1024, -} - WATSONX_EMBEDDING_DIMENSIONS = { # IBM Models "ibm/granite-embedding-107m-multilingual": 384, diff --git a/src/main.py b/src/main.py index 1094f8b5..b16a50d5 100644 --- a/src/main.py +++ b/src/main.py @@ -168,7 +168,12 @@ async def init_index(): embedding_model = config.knowledge.embedding_model # Create dynamic index body based on the configured embedding model - dynamic_index_body = 
+    # Pass provider and endpoint for dynamic dimension resolution (Ollama probing)
+    dynamic_index_body = await create_dynamic_index_body(
+        embedding_model,
+        provider=config.provider.model_provider,
+        endpoint=config.provider.endpoint,
+    )
 
     # Create documents index
     if not await clients.opensearch.indices.exists(index=INDEX_NAME):
diff --git a/src/utils/embeddings.py b/src/utils/embeddings.py
index b0ec035f..46c53509 100644
--- a/src/utils/embeddings.py
+++ b/src/utils/embeddings.py
@@ -1,14 +1,128 @@
-from config.settings import OLLAMA_EMBEDDING_DIMENSIONS, OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS
+import httpx
+from config.settings import OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS
+from utils.container_utils import transform_localhost_url
 from utils.logging_config import get_logger
 
 logger = get_logger(__name__)
 
-def get_embedding_dimensions(model_name: str) -> int:
+
+async def _probe_ollama_embedding_dimension(endpoint: str, model_name: str) -> int:
+    """Probe the Ollama server to get the embedding dimension for a model.
+
+    Args:
+        endpoint: Ollama server endpoint (e.g., "http://localhost:11434")
+        model_name: Name of the embedding model
+
+    Returns:
+        The embedding dimension.
+
+    Raises:
+        ValueError: If the dimension cannot be determined.
+    """
+    transformed_endpoint = transform_localhost_url(endpoint)
+    url = f"{transformed_endpoint}/api/embeddings"
+    test_input = "test"
+
+    async with httpx.AsyncClient() as client:
+        errors: list[str] = []
+
+        # Try the modern API shape first ("input"); "prompt" is included as
+        # well so that servers expecting the legacy field still respond.
+        modern_payload = {
+            "model": model_name,
+            "input": test_input,
+            "prompt": test_input,
+        }
+
+        try:
+            response = await client.post(url, json=modern_payload, timeout=10.0)
+            response.raise_for_status()
+            data = response.json()
+
+            # Check for an embedding in the response
+            if "embedding" in data:
+                dimension = len(data["embedding"])
+                if dimension > 0:
+                    logger.info(
+                        f"Probed Ollama model '{model_name}': dimension={dimension}"
+                    )
+                    return dimension
+            elif "embeddings" in data and len(data["embeddings"]) > 0:
+                dimension = len(data["embeddings"][0])
+                if dimension > 0:
+                    logger.info(
+                        f"Probed Ollama model '{model_name}': dimension={dimension}"
+                    )
+                    return dimension
+
+            errors.append("response did not include a non-zero embedding vector")
+        except Exception as modern_error:  # noqa: BLE001 - log and fall back to legacy payload
+            logger.debug(
+                "Modern Ollama embeddings API probe failed",
+                model=model_name,
+                endpoint=transformed_endpoint,
+                error=str(modern_error),
+            )
+            errors.append(str(modern_error))
+
+        # Try the legacy API format (prompt parameter only)
+        legacy_payload = {
+            "model": model_name,
+            "prompt": test_input,
+        }
+
+        try:
+            response = await client.post(url, json=legacy_payload, timeout=10.0)
+            response.raise_for_status()
+            data = response.json()
+
+            if "embedding" in data:
+                dimension = len(data["embedding"])
+                if dimension > 0:
+                    logger.info(
+                        f"Probed Ollama model '{model_name}' (legacy): dimension={dimension}"
+                    )
+                    return dimension
+            elif "embeddings" in data and len(data["embeddings"]) > 0:
+                dimension = len(data["embeddings"][0])
+                if dimension > 0:
+                    logger.info(
+                        f"Probed Ollama model '{model_name}' (legacy): dimension={dimension}"
+                    )
+                    return dimension
+
+            errors.append("legacy response did not include a non-zero embedding vector")
+        except Exception as legacy_error:  # noqa: BLE001 - collect and raise a helpful error later
+            logger.warning(
+                "Legacy Ollama embeddings API probe failed",
+                model=model_name,
+                endpoint=transformed_endpoint,
+                error=str(legacy_error),
+            )
+            errors.append(str(legacy_error))
+
+    # Prefer the most specific message: drop one generic httpx "All connection
+    # attempts failed" entry (if present) so that whichever error remains -
+    # either the same connection failure or a more descriptive error from one
+    # of the two attempts - is the one shown to the user.
+    if "All connection attempts failed" in errors:
+        errors.remove("All connection attempts failed")
+
+    raise ValueError(
+        f"Failed to determine embedding dimensions for Ollama model '{model_name}'. "
+        f"Verify the Ollama server at '{endpoint}' is reachable and the model is available. "
+        f"Error: {errors[0]}"
+    )
+
+
+async def get_embedding_dimensions(model_name: str, provider: str | None = None, endpoint: str | None = None) -> int:
     """Get the embedding dimensions for a given model name."""
+    if provider and provider.lower() == "ollama":
+        if not endpoint:
+            raise ValueError(
+                "Ollama endpoint is required to determine embedding dimensions. Please provide a valid endpoint."
+            )
+        return await _probe_ollama_embedding_dimension(endpoint, model_name)
+
     # Check all model dictionaries
-    all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **OLLAMA_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS}
+    all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS}
 
     model_name = model_name.lower().strip().split(":")[0]
@@ -23,9 +137,22 @@ def get_embedding_dimensions(model_name: str) -> int:
 
     return VECTOR_DIM
 
-def create_dynamic_index_body(embedding_model: str) -> dict:
-    """Create a dynamic index body configuration based on the embedding model."""
-    dimensions = get_embedding_dimensions(embedding_model)
+async def create_dynamic_index_body(
+    embedding_model: str,
+    provider: str | None = None,
+    endpoint: str | None = None,
+) -> dict:
+    """Create a dynamic index body configuration based on the embedding model.
+
+    Args:
+        embedding_model: Name of the embedding model
+        provider: Provider name (e.g., "ollama", "openai", "watsonx")
+        endpoint: Endpoint URL for the provider (used for Ollama probing)
+
+    Returns:
+        OpenSearch index body configuration
+    """
+    dimensions = await get_embedding_dimensions(embedding_model, provider, endpoint)
 
     return {
         "settings": {
@@ -63,4 +190,4 @@ def create_dynamic_index_body(embedding_model: str) -> dict:
             "metadata": {"type": "object"},
         }
     },
-    }
\ No newline at end of file
+    }
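
Usage sketch (not part of the patch): a minimal example of how the new async
resolution path might be exercised, assuming the module layout above, a
reachable Ollama server on its default port, and a locally pulled
"nomic-embed-text" model - the endpoint and model names here are illustrative:

    import asyncio

    from utils.embeddings import create_dynamic_index_body, get_embedding_dimensions

    async def main() -> None:
        # Static lookup: no provider given, so the OpenAI/watsonx tables
        # (or the VECTOR_DIM fallback) answer without any network call.
        dims = await get_embedding_dimensions("text-embedding-ada-002")
        print(f"static lookup: {dims}")

        # Ollama path: the dimension is probed from the live server, so the
        # endpoint must be reachable and the model already pulled; a
        # ValueError propagates if the probe fails.
        index_body = await create_dynamic_index_body(
            "nomic-embed-text",
            provider="ollama",
            endpoint="http://localhost:11434",
        )
        print("index body built with probed dimension")

    asyncio.run(main())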