From 713f90c3c407e06b83e8e5f95e5953ba2461e26d Mon Sep 17 00:00:00 2001 From: Edwin Jose Date: Wed, 17 Sep 2025 02:01:03 -0400 Subject: [PATCH] Integrate Langflow MCP JWT header update in AuthService Added LangflowMCPService to update MCP servers with the user's JWT header after authentication. AuthService now triggers a background update to MCP servers on successful login, ensuring JWT propagation for downstream services. --- src/main.py | 7 +- src/services/auth_service.py | 20 +++- src/services/langflow_mcp_service.py | 147 +++++++++++++++++++++++++++ 3 files changed, 172 insertions(+), 2 deletions(-) create mode 100644 src/services/langflow_mcp_service.py diff --git a/src/main.py b/src/main.py index 1c0dc09f..1b50dfce 100644 --- a/src/main.py +++ b/src/main.py @@ -57,6 +57,7 @@ from config.settings import ( is_no_auth_mode, ) from services.auth_service import AuthService +from services.langflow_mcp_service import LangflowMCPService from services.chat_service import ChatService # Services @@ -437,7 +438,11 @@ async def initialize_services(): ) # Initialize auth service - auth_service = AuthService(session_manager, connector_service) + auth_service = AuthService( + session_manager, + connector_service, + langflow_mcp_service=LangflowMCPService(), + ) # Load persisted connector connections at startup so webhooks and syncs # can resolve existing subscriptions immediately after server boot diff --git a/src/services/auth_service.py b/src/services/auth_service.py index a29c197f..1c5afdac 100644 --- a/src/services/auth_service.py +++ b/src/services/auth_service.py @@ -5,9 +5,11 @@ import httpx import aiofiles from datetime import datetime, timedelta from typing import Optional +import asyncio from config.settings import WEBHOOK_BASE_URL, is_no_auth_mode from session_manager import SessionManager +from services.langflow_mcp_service import LangflowMCPService from connectors.google_drive.oauth import GoogleDriveOAuth from connectors.onedrive.oauth import OneDriveOAuth from connectors.sharepoint.oauth import SharePointOAuth @@ -17,10 +19,12 @@ from connectors.sharepoint import SharePointConnector class AuthService: - def __init__(self, session_manager: SessionManager, connector_service=None): + def __init__(self, session_manager: SessionManager, connector_service=None, langflow_mcp_service: LangflowMCPService | None = None): self.session_manager = session_manager self.connector_service = connector_service self.used_auth_codes = set() # Track used authorization codes + self.langflow_mcp_service = langflow_mcp_service + self._background_tasks = set() async def init_oauth( self, @@ -287,6 +291,20 @@ class AuthService: user_info = await self.session_manager.get_user_info_from_token( token_data["access_token"] ) + + # Best-effort: update Langflow MCP servers to include user's JWT header + try: + if self.langflow_mcp_service and isinstance(jwt_token, str) and jwt_token.strip(): + # Run in background to avoid delaying login flow + task = asyncio.create_task( + self.langflow_mcp_service.update_mcp_servers_with_jwt(jwt_token) + ) + # Keep reference until done to avoid premature GC + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + except Exception: + # Do not block login on MCP update issues + pass response_data = { "status": "authenticated", diff --git a/src/services/langflow_mcp_service.py b/src/services/langflow_mcp_service.py new file mode 100644 index 00000000..3e98a219 --- /dev/null +++ b/src/services/langflow_mcp_service.py @@ -0,0 +1,147 @@ +from typing import List, Dict, Any + +from config.settings import clients +from utils.logging_config import get_logger + + +logger = get_logger(__name__) + + +class LangflowMCPService: + async def list_mcp_servers(self) -> List[Dict[str, Any]]: + """Fetch list of MCP servers from Langflow (v2 API).""" + try: + response = await clients.langflow_request( + method="GET", + endpoint="/api/v2/mcp/servers", + params={"action_count": "false"}, + ) + response.raise_for_status() + data = response.json() + if isinstance(data, list): + return data + logger.warning("Unexpected response format for MCP servers list", data_type=type(data).__name__) + return [] + except Exception as e: + logger.error("Failed to list MCP servers", error=str(e)) + return [] + + async def get_mcp_server(self, server_name: str) -> Dict[str, Any]: + """Get MCP server configuration by name.""" + response = await clients.langflow_request( + method="GET", + endpoint=f"/api/v2/mcp/servers/{server_name}", + ) + response.raise_for_status() + return response.json() + + def _upsert_jwt_header_in_args(self, args: List[str], jwt_token: str) -> List[str]: + """Ensure args contains a header triplet for X-Langflow-Global-Var-JWT with the provided JWT. + + Args are expected in the pattern: [..., "--headers", key, value, ...]. + If the header exists, update its value; otherwise append the triplet at the end. + """ + if not isinstance(args, list): + return [ + "mcp-proxy", + "--headers", + "X-Langflow-Global-Var-JWT", + jwt_token, + ] + + updated_args = list(args) + i = 0 + found_index = -1 + while i < len(updated_args): + token = updated_args[i] + if token == "--headers" and i + 2 < len(updated_args): + header_key = updated_args[i + 1] + if isinstance(header_key, str) and header_key.lower() == "x-langflow-global-var-jwt".lower(): + found_index = i + break + i += 3 + continue + i += 1 + + if found_index >= 0: + # Replace existing value at found_index + 2 + if found_index + 2 < len(updated_args): + updated_args[found_index + 2] = jwt_token + else: + # Malformed existing header triplet; make sure to append a value + updated_args.append(jwt_token) + else: + updated_args.extend([ + "--headers", + "X-Langflow-Global-Var-JWT", + jwt_token, + ]) + + return updated_args + + async def patch_mcp_server_args_with_jwt(self, server_name: str, jwt_token: str) -> bool: + """Patch a single MCP server to include/update the JWT header in args.""" + try: + current = await self.get_mcp_server(server_name) + command = current.get("command") + args = current.get("args", []) + updated_args = self._upsert_jwt_header_in_args(args, jwt_token) + + payload = {"command": command, "args": updated_args} + response = await clients.langflow_request( + method="PATCH", + endpoint=f"/api/v2/mcp/servers/{server_name}", + json=payload, + ) + if response.status_code in (200, 201): + logger.info( + "Patched MCP server with JWT header", + server_name=server_name, + args_len=len(updated_args), + ) + return True + else: + logger.warning( + "Failed to patch MCP server", + server_name=server_name, + status_code=response.status_code, + body=response.text, + ) + return False + except Exception as e: + logger.error( + "Exception while patching MCP server", + server_name=server_name, + error=str(e), + ) + return False + + async def update_mcp_servers_with_jwt(self, jwt_token: str) -> Dict[str, Any]: + """Fetch all MCP servers and ensure each includes the JWT header in args. + + Returns a summary dict with counts. + """ + servers = await self.list_mcp_servers() + if not servers: + return {"updated": 0, "failed": 0, "total": 0} + + updated = 0 + failed = 0 + for server in servers: + name = server.get("name") or server.get("server") or server.get("id") + if not name: + continue + ok = await self.patch_mcp_server_args_with_jwt(name, jwt_token) + if ok: + updated += 1 + else: + failed += 1 + + summary = {"updated": updated, "failed": failed, "total": len(servers)} + if failed == 0: + logger.info("MCP servers updated with JWT header", **summary) + else: + logger.warning("MCP servers update had failures", **summary) + return summary + +