fix: update onboarding design, initialize the OpenSearch index after onboarding, and make flow reset apply the chosen provider's models (#100)

* changed tooltip style

* added a start prop to the label wrapper

* changed the switch to a checkbox in the OpenAI onboarding and updated the copy

* made the border red when the API key is invalid

* added embedding configuration after onboarding

* changed the OpenRAG Docling ingest flow to use the same embedding model component as the other flows

* changed the flows service to look up flows by ID instead of by file path

* modified reset_langflow to also set the correct embedding model

* added endpoint and project ID to the provider config

* made resetting replace the models with the chosen provider's models

* moved constants to settings.py

* raise an error when flow_id is not found
Lucas Oliveira, 2025-09-26 12:04:17 -03:00, committed by GitHub
parent 1ce9f2923e
commit e0015f35db
10 changed files with 317 additions and 97 deletions


@@ -95,7 +95,7 @@
"data": {
"sourceHandle": {
"dataType": "EmbeddingModel",
"id": "EmbeddingModel-cxG9r",
"id": "EmbeddingModel-eZ6bT",
"name": "embeddings",
"output_types": [
"Embeddings"
@@ -110,10 +110,10 @@
"type": "other"
}
},
"id": "xy-edge__EmbeddingModel-cxG9r{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-cxG9rœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}-OpenSearchHybrid-XtKoA{œfieldNameœ:œembeddingœ,œidœ:œOpenSearchHybrid-XtKoAœ,œinputTypesœ:[œEmbeddingsœ],œtypeœ:œotherœ}",
"id": "xy-edge__EmbeddingModel-eZ6bT{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-eZ6bTœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}-OpenSearchHybrid-XtKoA{œfieldNameœ:œembeddingœ,œidœ:œOpenSearchHybrid-XtKoAœ,œinputTypesœ:[œEmbeddingsœ],œtypeœ:œotherœ}",
"selected": false,
"source": "EmbeddingModel-cxG9r",
"sourceHandle": "{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-cxG9rœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}",
"source": "EmbeddingModel-eZ6bT",
"sourceHandle": "{œdataTypeœ:œEmbeddingModelœ,œidœ:œEmbeddingModel-eZ6bTœ,œnameœ:œembeddingsœ,œoutput_typesœ:[œEmbeddingsœ]}",
"target": "OpenSearchHybrid-XtKoA",
"targetHandle": "{œfieldNameœ:œembeddingœ,œidœ:œOpenSearchHybrid-XtKoAœ,œinputTypesœ:[œEmbeddingsœ],œtypeœ:œotherœ}"
}
@@ -1631,7 +1631,7 @@
},
{
"data": {
"id": "EmbeddingModel-cxG9r",
"id": "EmbeddingModel-eZ6bT",
"node": {
"base_classes": [
"Embeddings"
@@ -1657,7 +1657,7 @@
],
"frozen": false,
"icon": "binary",
"last_updated": "2025-09-24T16:02:07.998Z",
"last_updated": "2025-09-22T15:54:52.885Z",
"legacy": false,
"metadata": {
"code_hash": "93faf11517da",
@@ -1738,7 +1738,7 @@
"show": true,
"title_case": false,
"type": "str",
"value": ""
"value": "OPENAI_API_KEY"
},
"chunk_size": {
"_input_type": "IntInput",
@@ -1926,16 +1926,16 @@
"type": "EmbeddingModel"
},
"dragging": false,
"id": "EmbeddingModel-cxG9r",
"id": "EmbeddingModel-eZ6bT",
"measured": {
"height": 366,
"height": 369,
"width": 320
},
"position": {
"x": 1743.8608432729177,
"y": 1808.780792406514
"x": 1726.6943524438122,
"y": 1800.5330404375484
},
"selected": false,
"selected": true,
"type": "genericNode"
}
],


@@ -10,18 +10,25 @@ export function LabelWrapper({
id,
required,
flex,
start,
children,
}: {
label: string;
description?: string;
helperText?: string;
helperText?: string | React.ReactNode;
id: string;
required?: boolean;
flex?: boolean;
start?: boolean;
children: React.ReactNode;
}) {
return (
<div className="flex w-full items-center justify-between">
<div
className={cn(
"flex w-full items-center",
start ? "justify-start flex-row-reverse gap-3" : "justify-between",
)}
>
<div
className={cn(
"flex flex-1 flex-col items-start",
@@ -30,7 +37,7 @@ export function LabelWrapper({
>
<Label
htmlFor={id}
className="!text-mmd font-medium flex items-center gap-1"
className="!text-mmd font-medium flex items-center gap-1.5"
>
{label}
{required && <span className="text-red-500">*</span>}
@@ -39,7 +46,7 @@ export function LabelWrapper({
<TooltipTrigger>
<Info className="w-3.5 h-3.5 text-muted-foreground" />
</TooltipTrigger>
<TooltipContent>{helperText}</TooltipContent>
<TooltipContent side="right">{helperText}</TooltipContent>
</Tooltip>
)}
</Label>
@@ -48,7 +55,7 @@ export function LabelWrapper({
<p className="text-mmd text-muted-foreground">{description}</p>
)}
</div>
{flex && <div className="relative">{children}</div>}
{flex && <div className="relative items-center flex">{children}</div>}
</div>
);
}


@@ -19,7 +19,7 @@ const TooltipContent = React.forwardRef<
ref={ref}
sideOffset={sideOffset}
className={cn(
"z-50 overflow-hidden rounded-md border bg-popover px-3 py-1.5 text-sm text-popover-foreground shadow-md animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 origin-[--radix-tooltip-content-transform-origin]",
"z-50 overflow-hidden rounded-md border bg-primary py-1 px-1.5 text-xs font-normal text-primary-foreground shadow-md animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 origin-[--radix-tooltip-content-transform-origin]",
className,
)}
{...props}


@@ -2,7 +2,7 @@ import { useState } from "react";
import { LabelInput } from "@/components/label-input";
import { LabelWrapper } from "@/components/label-wrapper";
import OpenAILogo from "@/components/logo/openai-logo";
import { Switch } from "@/components/ui/switch";
import { Checkbox } from "@/components/ui/checkbox";
import { useDebouncedValue } from "@/lib/debounce";
import type { OnboardingVariables } from "../../api/mutations/useOnboardingMutation";
import { useGetOpenAIModelsQuery } from "../../api/queries/useGetModelsQuery";
@@ -72,11 +72,19 @@
<>
<div className="space-y-5">
<LabelWrapper
label="Get API key from environment variable"
label="Use environment OpenAI API key"
id="get-api-key"
helperText={
<>
Reuse the key from your environment config.
<br />
Uncheck to enter a different key.
</>
}
flex
start
>
<Switch
<Checkbox
checked={getFromEnv}
onCheckedChange={handleGetFromEnvChange}
/>
@@ -86,6 +94,7 @@
<LabelInput
label="OpenAI API key"
helperText="The API key for your OpenAI account."
className={modelsError ? "!border-destructive" : ""}
id="api-key"
type="password"
required
@@ -99,7 +108,7 @@
</p>
)}
{modelsError && (
<p className="text-mmd text-accent-amber-foreground">
<p className="text-mmd text-destructive">
Invalid OpenAI API key. Verify or replace the key.
</p>
)}


@@ -556,6 +556,19 @@ async def onboarding(request, flows_service):
)
# Continue even if setting global variables fails
# Initialize the OpenSearch index now that we have the embedding model configured
try:
# Import here to avoid circular imports
from main import init_index
logger.info("Initializing OpenSearch index after onboarding configuration")
await init_index()
logger.info("OpenSearch index initialization completed successfully")
except Exception as e:
logger.error("Failed to initialize OpenSearch index after onboarding", error=str(e))
# Don't fail the entire onboarding process if index creation fails
# The application can still work, but document operations may fail
# Handle sample data ingestion if requested
if should_ingest_sample_data:
try:


@@ -16,6 +16,8 @@ class ProviderConfig:
model_provider: str = "openai" # openai, anthropic, etc.
api_key: str = ""
endpoint: str = "" # For providers like Watson/IBM that need custom endpoints
project_id: str = "" # For providers like Watson/IBM that need project IDs
@dataclass
@@ -129,6 +131,10 @@ class ConfigManager:
config_data["provider"]["model_provider"] = os.getenv("MODEL_PROVIDER")
if os.getenv("PROVIDER_API_KEY"):
config_data["provider"]["api_key"] = os.getenv("PROVIDER_API_KEY")
if os.getenv("PROVIDER_ENDPOINT"):
config_data["provider"]["endpoint"] = os.getenv("PROVIDER_ENDPOINT")
if os.getenv("PROVIDER_PROJECT_ID"):
config_data["provider"]["project_id"] = os.getenv("PROVIDER_PROJECT_ID")
# Backward compatibility for OpenAI
if os.getenv("OPENAI_API_KEY"):
config_data["provider"]["api_key"] = os.getenv("OPENAI_API_KEY")

src/config/settings.py

@@ -78,6 +78,31 @@ INDEX_NAME = "documents"
VECTOR_DIM = 1536
EMBED_MODEL = "text-embedding-3-small"
OPENAI_EMBEDDING_DIMENSIONS = {
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
"text-embedding-ada-002": 1536,
}
OLLAMA_EMBEDDING_DIMENSIONS = {
"nomic-embed-text": 768,
"all-minilm": 384,
"mxbai-embed-large": 1024,
}
WATSONX_EMBEDDING_DIMENSIONS = {
# IBM Models
"ibm/granite-embedding-107m-multilingual": 384,
"ibm/granite-embedding-278m-multilingual": 1024,
"ibm/slate-125m-english-rtrvr": 768,
"ibm/slate-125m-english-rtrvr-v2": 768,
"ibm/slate-30m-english-rtrvr": 384,
"ibm/slate-30m-english-rtrvr-v2": 384,
# Third Party Models
"intfloat/multilingual-e5-large": 1024,
"sentence-transformers/all-minilm-l6-v2": 384,
}
INDEX_BODY = {
"settings": {
"index": {"knn": True},

src/main.py

@@ -2,6 +2,7 @@
from connectors.langflow_connector_service import LangflowConnectorService
from connectors.service import ConnectorService
from services.flows_service import FlowsService
from utils.embeddings import create_dynamic_index_body
from utils.logging_config import configure_from_env, get_logger
configure_from_env()
@@ -52,11 +53,11 @@ from auth_middleware import optional_auth, require_auth
from config.settings import (
DISABLE_INGEST_WITH_LANGFLOW,
EMBED_MODEL,
INDEX_BODY,
INDEX_NAME,
SESSION_SECRET,
clients,
is_no_auth_mode,
get_openrag_config,
)
from services.auth_service import AuthService
from services.langflow_mcp_service import LangflowMCPService
@@ -81,7 +82,6 @@ logger.info(
cuda_version=torch.version.cuda,
)
async def wait_for_opensearch():
"""Wait for OpenSearch to be ready with retries"""
max_retries = 30
@@ -132,12 +132,19 @@ async def init_index():
"""Initialize OpenSearch index and security roles"""
await wait_for_opensearch()
# Get the configured embedding model from user configuration
config = get_openrag_config()
embedding_model = config.knowledge.embedding_model
# Create dynamic index body based on the configured embedding model
dynamic_index_body = create_dynamic_index_body(embedding_model)
# Create documents index
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
logger.info("Created OpenSearch index", index_name=INDEX_NAME)
await clients.opensearch.indices.create(index=INDEX_NAME, body=dynamic_index_body)
logger.info("Created OpenSearch index", index_name=INDEX_NAME, embedding_model=embedding_model)
else:
logger.info("Index already exists, skipping creation", index_name=INDEX_NAME)
logger.info("Index already exists, skipping creation", index_name=INDEX_NAME, embedding_model=embedding_model)
# Create knowledge filters index
knowledge_filter_index_name = "knowledge_filters"
@@ -391,7 +398,12 @@ async def _ingest_default_documents_openrag(services, file_paths):
async def startup_tasks(services):
"""Startup tasks"""
logger.info("Starting startup tasks")
await init_index()
# Only initialize basic OpenSearch connection, not the index
# Index will be created after onboarding when we know the embedding model
await wait_for_opensearch()
# Configure alerting security
await configure_alerting_security()
async def initialize_services():

src/services/flows_service.py

@@ -1,3 +1,4 @@
import asyncio
from config.settings import (
NUDGES_FLOW_ID,
LANGFLOW_URL,
@@ -19,6 +20,7 @@ from config.settings import (
WATSONX_LLM_COMPONENT_ID,
OLLAMA_EMBEDDING_COMPONENT_ID,
OLLAMA_LLM_COMPONENT_ID,
get_openrag_config,
)
import json
import os
@@ -29,6 +31,74 @@ logger = get_logger(__name__)
class FlowsService:
def __init__(self):
# Cache for flow file mappings to avoid repeated filesystem scans
self._flow_file_cache = {}
def _get_flows_directory(self):
"""Get the flows directory path"""
current_file_dir = os.path.dirname(os.path.abspath(__file__)) # src/services/
src_dir = os.path.dirname(current_file_dir) # src/
project_root = os.path.dirname(src_dir) # project root
return os.path.join(project_root, "flows")
def _find_flow_file_by_id(self, flow_id: str):
"""
Scan the flows directory and find the JSON file that contains the specified flow ID.
Args:
flow_id: The flow ID to search for
Returns:
str: The path to the flow file, or None if not found
"""
if not flow_id:
raise ValueError("flow_id is required")
# Check cache first
if flow_id in self._flow_file_cache:
cached_path = self._flow_file_cache[flow_id]
if os.path.exists(cached_path):
return cached_path
else:
# Remove stale cache entry
del self._flow_file_cache[flow_id]
flows_dir = self._get_flows_directory()
if not os.path.exists(flows_dir):
logger.warning(f"Flows directory not found: {flows_dir}")
return None
# Scan all JSON files in the flows directory
try:
for filename in os.listdir(flows_dir):
if not filename.endswith('.json'):
continue
file_path = os.path.join(flows_dir, filename)
try:
with open(file_path, 'r') as f:
flow_data = json.load(f)
# Check if this file contains the flow we're looking for
if flow_data.get('id') == flow_id:
# Cache the result
self._flow_file_cache[flow_id] = file_path
logger.info(f"Found flow {flow_id} in file: {filename}")
return file_path
except (json.JSONDecodeError, FileNotFoundError) as e:
logger.warning(f"Error reading flow file {filename}: {e}")
continue
except Exception as e:
logger.error(f"Error scanning flows directory: {e}")
return None
logger.warning(f"Flow with ID {flow_id} not found in flows directory")
return None
async def reset_langflow_flow(self, flow_type: str):
"""Reset a Langflow flow by uploading the corresponding JSON file
@@ -41,59 +111,35 @@
if not LANGFLOW_URL:
raise ValueError("LANGFLOW_URL environment variable is required")
# Determine flow file and ID based on type
# Determine flow ID based on type
if flow_type == "nudges":
flow_file = "flows/openrag_nudges.json"
flow_id = NUDGES_FLOW_ID
elif flow_type == "retrieval":
flow_file = "flows/openrag_agent.json"
flow_id = LANGFLOW_CHAT_FLOW_ID
elif flow_type == "ingest":
flow_file = "flows/ingestion_flow.json"
flow_id = LANGFLOW_INGEST_FLOW_ID
else:
raise ValueError(
"flow_type must be either 'nudges', 'retrieval', or 'ingest'"
)
if not flow_id:
raise ValueError(f"Flow ID not configured for flow_type '{flow_type}'")
# Dynamically find the flow file by ID
flow_path = self._find_flow_file_by_id(flow_id)
if not flow_path:
raise FileNotFoundError(f"Flow file not found for flow ID: {flow_id}")
# Load flow JSON file
try:
# Get the project root directory (go up from src/services/ to project root)
# __file__ is src/services/chat_service.py
# os.path.dirname(__file__) is src/services/
# os.path.dirname(os.path.dirname(__file__)) is src/
# os.path.dirname(os.path.dirname(os.path.dirname(__file__))) is project root
current_file_dir = os.path.dirname(
os.path.abspath(__file__)
) # src/services/
src_dir = os.path.dirname(current_file_dir) # src/
project_root = os.path.dirname(src_dir) # project root
flow_path = os.path.join(project_root, flow_file)
if not os.path.exists(flow_path):
# List contents of project root to help debug
try:
contents = os.listdir(project_root)
logger.info(f"Project root contents: {contents}")
flows_dir = os.path.join(project_root, "flows")
if os.path.exists(flows_dir):
flows_contents = os.listdir(flows_dir)
logger.info(f"Flows directory contents: {flows_contents}")
else:
logger.info("Flows directory does not exist")
except Exception as e:
logger.error(f"Error listing directory contents: {e}")
raise FileNotFoundError(f"Flow file not found at: {flow_path}")
with open(flow_path, "r") as f:
flow_data = json.load(f)
logger.info(f"Successfully loaded flow data from {flow_file}")
logger.info(f"Successfully loaded flow data for {flow_type} from {os.path.basename(flow_path)}")
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in flow file {flow_path}: {e}")
except FileNotFoundError:
raise ValueError(f"Flow file not found: {flow_path}")
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in flow file {flow_file}: {e}")
# Make PATCH request to Langflow API to update the flow using shared client
try:
@@ -106,8 +152,54 @@
logger.info(
f"Successfully reset {flow_type} flow",
flow_id=flow_id,
flow_file=flow_file,
flow_file=os.path.basename(flow_path),
)
# Now update the flow with current configuration settings
try:
config = get_openrag_config()
# Check if configuration has been edited (onboarding completed)
if config.edited:
logger.info(f"Updating {flow_type} flow with current configuration settings")
provider = config.provider.model_provider.lower()
# Step 1: Assign model provider (replace components) if not OpenAI
if provider != "openai":
logger.info(f"Assigning {provider} components to {flow_type} flow")
provider_result = await self.assign_model_provider(provider)
if not provider_result.get("success"):
logger.warning(f"Failed to assign {provider} components: {provider_result.get('error', 'Unknown error')}")
# Continue anyway, maybe just value updates will work
# Step 2: Update model values for the specific flow being reset
single_flow_config = [{
"name": flow_type,
"flow_id": flow_id,
}]
logger.info(f"Updating {flow_type} flow model values")
update_result = await self.change_langflow_model_value(
provider=provider,
embedding_model=config.knowledge.embedding_model,
llm_model=config.agent.llm_model,
endpoint=config.provider.endpoint if config.provider.endpoint else None,
flow_configs=single_flow_config
)
if update_result.get("success"):
logger.info(f"Successfully updated {flow_type} flow with current configuration")
else:
logger.warning(f"Failed to update {flow_type} flow with current configuration: {update_result.get('error', 'Unknown error')}")
else:
logger.info(f"Configuration not yet edited (onboarding not completed), skipping model updates for {flow_type} flow")
except Exception as e:
logger.error(f"Error updating {flow_type} flow with current configuration", error=str(e))
# Don't fail the entire reset operation if configuration update fails
return {
"success": True,
"message": f"Successfully reset {flow_type} flow",
@@ -155,11 +247,10 @@
logger.info(f"Assigning {provider} components")
# Define flow configurations
# Define flow configurations (removed hardcoded file paths)
flow_configs = [
{
"name": "nudges",
"file": "flows/openrag_nudges.json",
"flow_id": NUDGES_FLOW_ID,
"embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
"llm_id": OPENAI_LLM_COMPONENT_ID,
@@ -167,7 +258,6 @@
},
{
"name": "retrieval",
"file": "flows/openrag_agent.json",
"flow_id": LANGFLOW_CHAT_FLOW_ID,
"embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
"llm_id": OPENAI_LLM_COMPONENT_ID,
@@ -175,7 +265,6 @@
},
{
"name": "ingest",
"file": "flows/ingestion_flow.json",
"flow_id": LANGFLOW_INGEST_FLOW_ID,
"embedding_id": OPENAI_EMBEDDING_COMPONENT_ID,
"llm_id": None, # Ingestion flow might not have LLM
@@ -272,7 +361,6 @@
async def _update_flow_components(self, config, llm_template, embedding_template, llm_text_template):
"""Update components in a specific flow"""
flow_name = config["name"]
flow_file = config["file"]
flow_id = config["flow_id"]
old_embedding_id = config["embedding_id"]
old_llm_id = config["llm_id"]
@@ -281,14 +369,11 @@
new_llm_id = llm_template["data"]["id"]
new_embedding_id = embedding_template["data"]["id"]
new_llm_text_id = llm_text_template["data"]["id"]
# Get the project root directory
current_file_dir = os.path.dirname(os.path.abspath(__file__))
src_dir = os.path.dirname(current_file_dir)
project_root = os.path.dirname(src_dir)
flow_path = os.path.join(project_root, flow_file)
if not os.path.exists(flow_path):
raise FileNotFoundError(f"Flow file not found at: {flow_path}")
# Dynamically find the flow file by ID
flow_path = self._find_flow_file_by_id(flow_id)
if not flow_path:
raise FileNotFoundError(f"Flow file not found for flow ID: {flow_id}")
# Load flow JSON
with open(flow_path, "r") as f:
@@ -527,16 +612,17 @@
return False
async def change_langflow_model_value(
self, provider: str, embedding_model: str, llm_model: str, endpoint: str = None
self, provider: str, embedding_model: str, llm_model: str, endpoint: str = None, flow_configs: list = None
):
"""
Change dropdown values for provider-specific components across all flows
Change dropdown values for provider-specific components across flows
Args:
provider: The provider ("watsonx", "ollama", "openai")
embedding_model: The embedding model name to set
llm_model: The LLM model name to set
endpoint: The endpoint URL (required for watsonx/ibm provider)
flow_configs: Optional list of specific flow configs to update. If None, updates all flows.
Returns:
dict: Success/error response with details for each flow
@@ -552,24 +638,22 @@
f"Changing dropdown values for provider {provider}, embedding: {embedding_model}, llm: {llm_model}, endpoint: {endpoint}"
)
# Define flow configurations with provider-specific component IDs
flow_configs = [
{
"name": "nudges",
"file": "flows/openrag_nudges.json",
"flow_id": NUDGES_FLOW_ID,
},
{
"name": "retrieval",
"file": "flows/openrag_agent.json",
"flow_id": LANGFLOW_CHAT_FLOW_ID,
},
{
"name": "ingest",
"file": "flows/ingestion_flow.json",
"flow_id": LANGFLOW_INGEST_FLOW_ID,
},
]
# Use provided flow_configs or default to all flows
if flow_configs is None:
flow_configs = [
{
"name": "nudges",
"flow_id": NUDGES_FLOW_ID,
},
{
"name": "retrieval",
"flow_id": LANGFLOW_CHAT_FLOW_ID,
},
{
"name": "ingest",
"flow_id": LANGFLOW_INGEST_FLOW_ID,
},
]
# Determine target component IDs based on provider
target_embedding_id, target_llm_id, target_llm_text_id = self._get_provider_component_ids(
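Taken together, a hedged sketch of the new reset path (method name and return shape as in this diff; assumes src/ is on the import path and Langflow is reachable):

import asyncio
from services.flows_service import FlowsService

async def demo() -> None:
    service = FlowsService()
    # 1. Finds the flow file by scanning flows/ for a matching top-level "id".
    # 2. PATCHes the flow JSON back to Langflow.
    # 3. If onboarding completed (config.edited), assigns the configured
    #    provider's components and pushes the current model values for this
    #    flow only, via change_langflow_model_value(flow_configs=[...]).
    result = await service.reset_langflow_flow("retrieval")
    print(result["message"])  # "Successfully reset retrieval flow"

asyncio.run(demo())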

src/utils/embeddings.py (new file, 64 lines)

@@ -0,0 +1,64 @@
from config.settings import OLLAMA_EMBEDDING_DIMENSIONS, OPENAI_EMBEDDING_DIMENSIONS, VECTOR_DIM, WATSONX_EMBEDDING_DIMENSIONS
from utils.logging_config import get_logger
logger = get_logger(__name__)
def get_embedding_dimensions(model_name: str) -> int:
"""Get the embedding dimensions for a given model name."""
# Check all model dictionaries
all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **OLLAMA_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS}
if model_name in all_models:
dimensions = all_models[model_name]
logger.info(f"Found dimensions for model '{model_name}': {dimensions}")
return dimensions
logger.warning(
f"Unknown embedding model '{model_name}', using default dimensions: {VECTOR_DIM}"
)
return VECTOR_DIM
def create_dynamic_index_body(embedding_model: str) -> dict:
"""Create a dynamic index body configuration based on the embedding model."""
dimensions = get_embedding_dimensions(embedding_model)
return {
"settings": {
"index": {"knn": True},
"number_of_shards": 1,
"number_of_replicas": 1,
},
"mappings": {
"properties": {
"document_id": {"type": "keyword"},
"filename": {"type": "keyword"},
"mimetype": {"type": "keyword"},
"page": {"type": "integer"},
"text": {"type": "text"},
"chunk_embedding": {
"type": "knn_vector",
"dimension": dimensions,
"method": {
"name": "disk_ann",
"engine": "jvector",
"space_type": "l2",
"parameters": {"ef_construction": 100, "m": 16},
},
},
"source_url": {"type": "keyword"},
"connector_type": {"type": "keyword"},
"owner": {"type": "keyword"},
"allowed_users": {"type": "keyword"},
"allowed_groups": {"type": "keyword"},
"user_permissions": {"type": "object"},
"group_permissions": {"type": "object"},
"created_time": {"type": "date"},
"modified_time": {"type": "date"},
"indexed_time": {"type": "date"},
"metadata": {"type": "object"},
}
},
}
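A quick usage sketch of the new helpers (assumes src/ is on the import path; dimension values follow the tables in src/config/settings.py, and unknown models fall back to VECTOR_DIM):

from utils.embeddings import create_dynamic_index_body, get_embedding_dimensions

# Known model: the mapping dimension comes straight from the lookup tables.
body = create_dynamic_index_body("text-embedding-3-large")
assert body["mappings"]["properties"]["chunk_embedding"]["dimension"] == 3072

# Unknown model: falls back to VECTOR_DIM (1536) and logs a warning.
assert get_embedding_dimensions("some-custom-model") == 1536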