diff --git a/src/api/settings.py b/src/api/settings.py
index ad08a3f5..85f584a8 100644
--- a/src/api/settings.py
+++ b/src/api/settings.py
@@ -233,7 +233,7 @@ async def update_settings(request, session_manager):
     )
 
 
-async def onboarding(request, session_manager):
+async def onboarding(request, flows_service):
     """Handle onboarding configuration setup"""
     try:
         # Get current configuration
@@ -323,6 +323,25 @@ async def onboarding(request, session_manager):
         if config_manager.save_config_file(current_config):
             updated_fields = [k for k in body.keys() if k != "sample_data"]  # Exclude sample_data from log
             logger.info("Onboarding configuration updated successfully", updated_fields=updated_fields)
+
+            # If model_provider was updated, assign the new provider to flows
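+            # (e.g. a request body of {"model_provider": "watsonx"} swaps the
+            # nudges, retrieval, and ingest flows over to the watsonx components)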
+            if "model_provider" in body:
+                provider = body["model_provider"].strip().lower()
+                try:
+                    flow_result = await flows_service.assign_model_provider(provider)
+
+                    if flow_result.get("success"):
+                        logger.info(f"Successfully assigned {provider} to flows", flow_result=flow_result)
+                    else:
+                        logger.warning(f"Failed to assign {provider} to flows", flow_result=flow_result)
+                        # Continue even if flow assignment fails - configuration was still saved
+
+                except Exception as e:
+                    logger.error("Error assigning model provider to flows", provider=provider, error=str(e))
+                    # Continue even if flow assignment fails - configuration was still saved
+
             return JSONResponse({
                 "message": "Onboarding configuration updated successfully",
                 "edited": True  # Confirm that config is now marked as edited
diff --git a/src/config/settings.py b/src/config/settings.py
index 3e60b45c..35399bcc 100644
--- a/src/config/settings.py
+++ b/src/config/settings.py
@@ -399,6 +399,21 @@ class AppClients:
     )
 
 
+# Component template paths (relative to the project root)
+WATSONX_LLM_COMPONENT_PATH = os.getenv("WATSONX_LLM_COMPONENT_PATH", "flows/components/watsonx_llm.json")
+WATSONX_EMBEDDING_COMPONENT_PATH = os.getenv("WATSONX_EMBEDDING_COMPONENT_PATH", "flows/components/watsonx_embedding.json")
+OLLAMA_LLM_COMPONENT_PATH = os.getenv("OLLAMA_LLM_COMPONENT_PATH", "flows/components/ollama_llm.json")
+OLLAMA_EMBEDDING_COMPONENT_PATH = os.getenv("OLLAMA_EMBEDDING_COMPONENT_PATH", "flows/components/ollama_embedding.json")
+
+# IDs of the model components to replace in each flow
+NUDGES_EMBEDDING_COMPONENT_ID = os.getenv("NUDGES_EMBEDDING_COMPONENT_ID", "EmbeddingModel-eZ6bT")
+NUDGES_LLM_COMPONENT_ID = os.getenv("NUDGES_LLM_COMPONENT_ID", "LanguageModelComponent-0YME7")
+
+AGENT_EMBEDDING_COMPONENT_ID = os.getenv("AGENT_EMBEDDING_COMPONENT_ID", "EmbeddingModel-eZ6bT")
+AGENT_LLM_COMPONENT_ID = os.getenv("AGENT_LLM_COMPONENT_ID", "LanguageModelComponent-0YME7")
+
+INGESTION_EMBEDDING_COMPONENT_ID = os.getenv("INGESTION_EMBEDDING_COMPONENT_ID", "OpenAIEmbeddings-joRJ6")
+
 # Global clients instance
 clients = AppClients()
diff --git a/src/main.py b/src/main.py
index 502ada46..870d8deb 100644
--- a/src/main.py
+++ b/src/main.py
@@ -914,7 +914,8 @@ async def create_app():
         "/onboarding",
         require_auth(services["session_manager"])(
             partial(
-                settings.onboarding, session_manager=services["session_manager"]
+                settings.onboarding,
+                flows_service=services["flows_service"]
             )
         ),
         methods=["POST"],
diff --git a/src/services/flows_service.py b/src/services/flows_service.py
index 2df712b7..5435ffe1 100644
--- a/src/services/flows_service.py
+++ b/src/services/flows_service.py
@@ -1,6 +1,14 @@
-from config.settings import NUDGES_FLOW_ID, LANGFLOW_URL, LANGFLOW_CHAT_FLOW_ID, LANGFLOW_INGEST_FLOW_ID, clients
+from config.settings import (
+    NUDGES_FLOW_ID, LANGFLOW_URL, LANGFLOW_CHAT_FLOW_ID, LANGFLOW_INGEST_FLOW_ID, clients,
+    WATSONX_LLM_COMPONENT_PATH, WATSONX_EMBEDDING_COMPONENT_PATH,
+    OLLAMA_LLM_COMPONENT_PATH, OLLAMA_EMBEDDING_COMPONENT_PATH,
+    NUDGES_EMBEDDING_COMPONENT_ID, NUDGES_LLM_COMPONENT_ID,
+    AGENT_EMBEDDING_COMPONENT_ID, AGENT_LLM_COMPONENT_ID,
+    INGESTION_EMBEDDING_COMPONENT_ID
+)
 import json
 import os
+import re
 from utils.logging_config import get_logger
 
 logger = get_logger(__name__)
@@ -108,3 +116,234 @@ class FlowsService:
             "success": False,
             "error": f"Error: {str(e)}"
         }
+
+    async def assign_model_provider(self, provider: str):
+        """
+        Replace the OpenAI components in all flows with the given provider's components
+
+        Args:
+            provider: "watsonx", "ollama", or "openai"
+
+        Returns:
+            dict: Success/error response with details for each flow
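+
+        Example return value:
+            {"success": True, "provider": "watsonx", "results": [...]}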
+        """
+        if provider not in ["watsonx", "ollama", "openai"]:
+            raise ValueError("provider must be 'watsonx', 'ollama', or 'openai'")
+
+        if provider == "openai":
+            logger.info("Provider is already OpenAI, no changes needed")
+            return {"success": True, "message": "Provider is already OpenAI, no changes needed"}
+
+        try:
+            # Load component templates based on provider
+            llm_template, embedding_template = self._load_component_templates(provider)
+
+            logger.info(f"Assigning {provider} components")
+
+            # Define flow configurations
+            flow_configs = [
+                {
+                    "name": "nudges",
+                    "file": "flows/openrag_nudges.json",
+                    "flow_id": NUDGES_FLOW_ID,
+                    "embedding_id": NUDGES_EMBEDDING_COMPONENT_ID,
+                    "llm_id": NUDGES_LLM_COMPONENT_ID
+                },
+                {
+                    "name": "retrieval",
+                    "file": "flows/openrag_agent.json",
+                    "flow_id": LANGFLOW_CHAT_FLOW_ID,
+                    "embedding_id": AGENT_EMBEDDING_COMPONENT_ID,
+                    "llm_id": AGENT_LLM_COMPONENT_ID
+                },
+                {
+                    "name": "ingest",
+                    "file": "flows/ingestion_flow.json",
+                    "flow_id": LANGFLOW_INGEST_FLOW_ID,
+                    "embedding_id": INGESTION_EMBEDDING_COMPONENT_ID,
+                    "llm_id": None  # Ingestion flow has no LLM component to swap
+                }
+            ]
+
+            results = []
+
+            # Process each flow sequentially
+            for config in flow_configs:
+                try:
+                    result = await self._update_flow_components(
+                        config, llm_template, embedding_template
+                    )
+                    results.append(result)
+                    logger.info(f"Successfully updated {config['name']} flow")
+                except Exception as e:
+                    error_msg = f"Failed to update {config['name']} flow: {str(e)}"
+                    logger.error(error_msg)
+                    results.append({
+                        "flow": config['name'],
+                        "success": False,
+                        "error": error_msg
+                    })
+                    # Continue with other flows even if one fails
+
+            # Check if all flows were successful
+            all_success = all(r.get("success", False) for r in results)
+
+            return {
+                "success": all_success,
+                "message": f"Model provider assignment to {provider} {'completed' if all_success else 'completed with errors'}",
+                "provider": provider,
+                "results": results
+            }
+
+        except Exception as e:
+            logger.error(f"Error assigning model provider {provider}", error=str(e))
+            return {
+                "success": False,
+                "error": f"Failed to assign model provider: {str(e)}"
+            }
+
+    def _load_component_templates(self, provider: str):
+        """Load component templates for the specified provider"""
+        if provider == "watsonx":
+            llm_path = WATSONX_LLM_COMPONENT_PATH
+            embedding_path = WATSONX_EMBEDDING_COMPONENT_PATH
+        elif provider == "ollama":
+            llm_path = OLLAMA_LLM_COMPONENT_PATH
+            embedding_path = OLLAMA_EMBEDDING_COMPONENT_PATH
+        else:
+            raise ValueError(f"Unsupported provider: {provider}")
+
+        # Get the project root directory (same logic as reset_langflow_flow)
+        current_file_dir = os.path.dirname(os.path.abspath(__file__))  # src/services/
+        src_dir = os.path.dirname(current_file_dir)  # src/
+        project_root = os.path.dirname(src_dir)  # project root
+
+        # Load LLM template
+        llm_full_path = os.path.join(project_root, llm_path)
+        if not os.path.exists(llm_full_path):
+            raise FileNotFoundError(f"LLM component template not found at: {llm_full_path}")
+
+        with open(llm_full_path, 'r') as f:
+            llm_template = json.load(f)
+
+        # Load embedding template
+        embedding_full_path = os.path.join(project_root, embedding_path)
+        if not os.path.exists(embedding_full_path):
+            raise FileNotFoundError(f"Embedding component template not found at: {embedding_full_path}")
+
+        with open(embedding_full_path, 'r') as f:
+            embedding_template = json.load(f)
+
+        logger.info(f"Loaded component templates for {provider}")
+        return llm_template, embedding_template
+
+    async def _update_flow_components(self, config, llm_template, embedding_template):
+        """Update components in a specific flow"""
+        flow_name = config["name"]
+        flow_file = config["file"]
+        flow_id = config["flow_id"]
+        old_embedding_id = config["embedding_id"]
+        old_llm_id = config["llm_id"]
+
+        # Extract IDs from templates
+        new_llm_id = llm_template["data"]["id"]
+        new_embedding_id = embedding_template["data"]["id"]
+
+        # Get the project root directory
+        current_file_dir = os.path.dirname(os.path.abspath(__file__))
+        src_dir = os.path.dirname(current_file_dir)
+        project_root = os.path.dirname(src_dir)
+        flow_path = os.path.join(project_root, flow_file)
+
+        if not os.path.exists(flow_path):
+            raise FileNotFoundError(f"Flow file not found at: {flow_path}")
+
+        # Load flow JSON
+        with open(flow_path, 'r') as f:
+            flow_data = json.load(f)
+
+        # Find and replace components
+        components_updated = []
+
+        # Replace embedding component
+        embedding_node = self._find_node_by_id(flow_data, old_embedding_id)
+        if embedding_node:
+            # Preserve position
+            original_position = embedding_node.get("position", {})
+
+            # Replace with new template
+            # (a shallow copy is enough: the node is re-serialized via json.dumps below)
+            new_embedding_node = embedding_template.copy()
+            new_embedding_node["position"] = original_position
+
+            # Replace in flow
+            self._replace_node_in_flow(flow_data, old_embedding_id, new_embedding_node)
+            components_updated.append(f"embedding: {old_embedding_id} -> {new_embedding_id}")
+
+        # Replace LLM component (if it exists in this flow)
+        if old_llm_id:
+            llm_node = self._find_node_by_id(flow_data, old_llm_id)
+            if llm_node:
+                # Preserve position
+                original_position = llm_node.get("position", {})
+
+                # Replace with new template
+                new_llm_node = llm_template.copy()
+                new_llm_node["position"] = original_position
+
+                # Replace in flow
+                self._replace_node_in_flow(flow_data, old_llm_id, new_llm_node)
+                components_updated.append(f"llm: {old_llm_id} -> {new_llm_id}")
+
+        # Update all edge references using regex replacement
+        flow_json_str = json.dumps(flow_data)
+
+        # Replace embedding ID references
+        flow_json_str = re.sub(r'\b' + re.escape(old_embedding_id) + r'\b', new_embedding_id, flow_json_str)
+
+        # Replace LLM ID references (if applicable)
+        if old_llm_id:
+            flow_json_str = re.sub(r'\b' + re.escape(old_llm_id) + r'\b', new_llm_id, flow_json_str)
+
+        # Convert back to JSON
+        flow_data = json.loads(flow_json_str)
+
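+        # The string-level rewrite above updates edge source/target handles and
+        # every other embedded reference to the old ids in one pass, without
+        # walking Langflow's nested edge schema.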
+        # PATCH the updated flow back to Langflow
+        response = await clients.langflow_request(
+            "PATCH",
+            f"/api/v1/flows/{flow_id}",
+            json=flow_data
+        )
+
+        if response.status_code != 200:
+            raise RuntimeError(f"Failed to update flow: HTTP {response.status_code} - {response.text}")
+
+        return {
+            "flow": flow_name,
+            "success": True,
+            "components_updated": components_updated,
+            "flow_id": flow_id
+        }
+
+    def _find_node_by_id(self, flow_data, node_id):
+        """Find a node by ID in the flow data"""
+        nodes = flow_data.get("data", {}).get("nodes", [])
+        for node in nodes:
+            if node.get("id") == node_id:
+                return node
+        return None
+
+    def _replace_node_in_flow(self, flow_data, old_id, new_node):
+        """Replace a node in the flow data; returns True if a node with old_id was found"""
+        nodes = flow_data.get("data", {}).get("nodes", [])
+        for i, node in enumerate(nodes):
+            if node.get("id") == old_id:
+                nodes[i] = new_node
+                return True
+        return False