diff --git a/docker-compose-cpu.yml b/docker-compose-cpu.yml index 50e118b7..fad301e6 100644 --- a/docker-compose-cpu.yml +++ b/docker-compose-cpu.yml @@ -129,7 +129,8 @@ services: - FILENAME=None - MIMETYPE=None - FILESIZE=0 - - LANGFLOW_VARIABLES_TO_GET_FROM_ENVIRONMENT=JWT,OPENRAG-QUERY-FILTER,OPENSEARCH_PASSWORD,OWNER,OWNER_NAME,OWNER_EMAIL,CONNECTOR_TYPE,FILENAME,MIMETYPE,FILESIZE + - SELECTED_EMBEDDING_MODEL=text-embedding-3-small + - LANGFLOW_VARIABLES_TO_GET_FROM_ENVIRONMENT=JWT,OPENRAG-QUERY-FILTER,OPENSEARCH_PASSWORD,OWNER,OWNER_NAME,OWNER_EMAIL,CONNECTOR_TYPE,FILENAME,MIMETYPE,FILESIZE,SELECTED_EMBEDDING_MODEL - LANGFLOW_LOG_LEVEL=DEBUG - LANGFLOW_AUTO_LOGIN=${LANGFLOW_AUTO_LOGIN} - LANGFLOW_SUPERUSER=${LANGFLOW_SUPERUSER} diff --git a/docker-compose.yml b/docker-compose.yml index 7ba0cea8..dae748eb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -130,8 +130,9 @@ services: - FILENAME=None - MIMETYPE=None - FILESIZE=0 + - SELECTED_EMBEDDING_MODEL=text-embedding-3-small - OPENSEARCH_PASSWORD=${OPENSEARCH_PASSWORD} - - LANGFLOW_VARIABLES_TO_GET_FROM_ENVIRONMENT=JWT,OPENRAG-QUERY-FILTER,OPENSEARCH_PASSWORD,OWNER,OWNER_NAME,OWNER_EMAIL,CONNECTOR_TYPE,FILENAME,MIMETYPE,FILESIZE + - LANGFLOW_VARIABLES_TO_GET_FROM_ENVIRONMENT=JWT,OPENRAG-QUERY-FILTER,OPENSEARCH_PASSWORD,OWNER,OWNER_NAME,OWNER_EMAIL,CONNECTOR_TYPE,FILENAME,MIMETYPE,FILESIZE,SELECTED_EMBEDDING_MODEL - LANGFLOW_LOG_LEVEL=DEBUG - LANGFLOW_AUTO_LOGIN=${LANGFLOW_AUTO_LOGIN} - LANGFLOW_SUPERUSER=${LANGFLOW_SUPERUSER} diff --git a/src/api/settings.py b/src/api/settings.py index 5b1eb22a..8cde2c3b 100644 --- a/src/api/settings.py +++ b/src/api/settings.py @@ -614,7 +614,7 @@ async def update_settings(request, session_manager): ) logger.info("Set OLLAMA_BASE_URL global variable in Langflow") - # Update model values across flows if provider or model changed + # Update LLM model values across flows if provider or model changed if "llm_provider" in body or "llm_model" in body: flows_service = _get_flows_service() llm_provider = current_config.agent.llm_provider.lower() @@ -629,18 +629,13 @@ async def update_settings(request, session_manager): f"Successfully updated Langflow flows for LLM provider {llm_provider}" ) + # Update SELECTED_EMBEDDING_MODEL global variable (no flow updates needed) if "embedding_provider" in body or "embedding_model" in body: - flows_service = _get_flows_service() - embedding_provider = current_config.knowledge.embedding_provider.lower() - embedding_provider_config = current_config.get_embedding_provider_config() - embedding_endpoint = getattr(embedding_provider_config, "endpoint", None) - await flows_service.change_langflow_model_value( - embedding_provider, - embedding_model=current_config.knowledge.embedding_model, - endpoint=embedding_endpoint, + await clients._create_langflow_global_variable( + "SELECTED_EMBEDDING_MODEL", current_config.knowledge.embedding_model, modify=True ) logger.info( - f"Successfully updated Langflow flows for embedding provider {embedding_provider}" + f"Set SELECTED_EMBEDDING_MODEL global variable to {current_config.knowledge.embedding_model}" ) except Exception as e: @@ -928,7 +923,7 @@ async def onboarding(request, flows_service): ) logger.info("Set OLLAMA_BASE_URL global variable in Langflow") - # Update flows with model values + # Update flows with LLM model values if "llm_provider" in body or "llm_model" in body: llm_provider = current_config.agent.llm_provider.lower() llm_provider_config = current_config.get_llm_provider_config() @@ -940,16 +935,14 @@ async def onboarding(request, flows_service): ) logger.info(f"Updated Langflow flows for LLM provider {llm_provider}") + # Set SELECTED_EMBEDDING_MODEL global variable (no flow updates needed) if "embedding_provider" in body or "embedding_model" in body: - embedding_provider = current_config.knowledge.embedding_provider.lower() - embedding_provider_config = current_config.get_embedding_provider_config() - embedding_endpoint = getattr(embedding_provider_config, "endpoint", None) - await flows_service.change_langflow_model_value( - provider=embedding_provider, - embedding_model=current_config.knowledge.embedding_model, - endpoint=embedding_endpoint, + await clients._create_langflow_global_variable( + "SELECTED_EMBEDDING_MODEL", current_config.knowledge.embedding_model, modify=True + ) + logger.info( + f"Set SELECTED_EMBEDDING_MODEL global variable to {current_config.knowledge.embedding_model}" ) - logger.info(f"Updated Langflow flows for embedding provider {embedding_provider}") except Exception as e: logger.error( diff --git a/src/services/chat_service.py b/src/services/chat_service.py index 63a1415b..d8cb320d 100644 --- a/src/services/chat_service.py +++ b/src/services/chat_service.py @@ -60,11 +60,17 @@ class ChatService: "LANGFLOW_URL and LANGFLOW_CHAT_FLOW_ID environment variables are required" ) - # Prepare extra headers for JWT authentication + # Prepare extra headers for JWT authentication and embedding model extra_headers = {} if jwt_token: extra_headers["X-LANGFLOW-GLOBAL-VAR-JWT"] = jwt_token + # Pass the selected embedding model as a global variable + from config.config_manager import get_openrag_config + config = get_openrag_config() + embedding_model = config.knowledge.embedding_model + extra_headers["X-LANGFLOW-GLOBAL-VAR-SELECTED_EMBEDDING_MODEL"] = embedding_model + # Get context variables for filters, limit, and threshold from auth_context import ( get_score_threshold, @@ -169,11 +175,17 @@ class ChatService: "LANGFLOW_URL and NUDGES_FLOW_ID environment variables are required" ) - # Prepare extra headers for JWT authentication + # Prepare extra headers for JWT authentication and embedding model extra_headers = {} if jwt_token: extra_headers["X-LANGFLOW-GLOBAL-VAR-JWT"] = jwt_token + # Pass the selected embedding model as a global variable + from config.config_manager import get_openrag_config + config = get_openrag_config() + embedding_model = config.knowledge.embedding_model + extra_headers["X-LANGFLOW-GLOBAL-VAR-SELECTED_EMBEDDING_MODEL"] = embedding_model + # Build the complete filter expression like the chat service does filter_expression = {} has_user_filters = False @@ -287,10 +299,16 @@ class ChatService: document_prompt = f"I'm uploading a document called '{filename}'. Here is its content:\n\n{document_content}\n\nPlease confirm you've received this document and are ready to answer questions about it." if endpoint == "langflow": - # Prepare extra headers for JWT authentication + # Prepare extra headers for JWT authentication and embedding model extra_headers = {} if jwt_token: extra_headers["X-LANGFLOW-GLOBAL-VAR-JWT"] = jwt_token + + # Pass the selected embedding model as a global variable + from config.config_manager import get_openrag_config + config = get_openrag_config() + embedding_model = config.knowledge.embedding_model + extra_headers["X-LANGFLOW-GLOBAL-VAR-SELECTED_EMBEDDING_MODEL"] = embedding_model # Ensure the Langflow client exists; try lazy init if needed langflow_client = await clients.ensure_langflow_client() if not langflow_client: diff --git a/src/services/langflow_file_service.py b/src/services/langflow_file_service.py index 017431bf..103716e1 100644 --- a/src/services/langflow_file_service.py +++ b/src/services/langflow_file_service.py @@ -140,6 +140,11 @@ class LangflowFileService: filename = str(file_tuples[0][0]) if file_tuples and len(file_tuples) > 0 else "" mimetype = str(file_tuples[0][2]) if file_tuples and len(file_tuples) > 0 and len(file_tuples[0]) > 2 else "" + # Get the current embedding model from config + from config.config_manager import get_openrag_config + config = get_openrag_config() + embedding_model = config.knowledge.embedding_model + headers={ "X-Langflow-Global-Var-JWT": str(jwt_token), "X-Langflow-Global-Var-OWNER": str(owner), @@ -149,6 +154,7 @@ class LangflowFileService: "X-Langflow-Global-Var-FILENAME": filename, "X-Langflow-Global-Var-MIMETYPE": mimetype, "X-Langflow-Global-Var-FILESIZE": str(file_size_bytes), + "X-Langflow-Global-Var-SELECTED_EMBEDDING_MODEL": str(embedding_model), } logger.info(f"[LF] Headers {headers}") logger.info(f"[LF] Payload {payload}")