Extend AuthService and LangflowMCPService to support updating MCP servers with multiple global variable headers (JWT, OWNER, OWNER_NAME, OWNER_EMAIL) during login. This enables passing user context to MCP servers via custom headers, improving downstream identification and authorization.
275 lines
10 KiB
Python
275 lines
10 KiB
Python
from typing import List, Dict, Any
|
|
|
|
from config.settings import clients
|
|
from utils.logging_config import get_logger
|
|
|
|
|
|
# Module-level logger shared by every method below; call sites pass their
# context (server name, status code, error text, ...) as keyword arguments.
logger = get_logger(__name__)
|
|
|
|
|
|
class LangflowMCPService:
    """Read and patch MCP server configurations through the Langflow v2 API.

    The patching methods rewrite a server's ``args`` list so that it carries
    ``--headers X-Langflow-Global-Var-<KEY> <value>`` triplets (for example the
    user's JWT), letting user context travel to the MCP server as headers.
    """

    # CLI flag that introduces a ``key value`` header pair inside ``args``.
    _HEADERS_FLAG = "--headers"
    # Prefix shared by every header that carries a global variable.
    _GLOBAL_VAR_HEADER_PREFIX = "X-Langflow-Global-Var-"
    # Global-variable key used for the user's JWT.
    _JWT_VAR_KEY = "JWT"
    # Args to start from when a server has no usable ``args`` list.
    _FALLBACK_ARGS = ("mcp-proxy",)

    async def list_mcp_servers(self) -> List[Dict[str, Any]]:
        """Fetch the list of MCP servers from Langflow (v2 API).

        Best-effort: a failed request, a non-success status or an unparsable
        body is logged and reported as an empty list instead of raising.

        Returns:
            The decoded response when it is a list, otherwise ``[]``.
        """
        try:
            response = await clients.langflow_request(
                method="GET",
                endpoint="/api/v2/mcp/servers",
                params={"action_count": "false"},
            )
            response.raise_for_status()
            data = response.json()
        except Exception as e:
            logger.error("Failed to list MCP servers", error=str(e))
            return []

        if isinstance(data, list):
            return data
        logger.warning(
            "Unexpected response format for MCP servers list",
            data_type=type(data).__name__,
        )
        return []

    async def get_mcp_server(self, server_name: str) -> Dict[str, Any]:
        """Get an MCP server configuration by name.

        Unlike ``list_mcp_servers`` this does not swallow errors: whatever the
        request or ``raise_for_status()`` raises propagates to the caller.
        """
        response = await clients.langflow_request(
            method="GET",
            endpoint=f"/api/v2/mcp/servers/{server_name}",
        )
        response.raise_for_status()
        return response.json()

    def _upsert_header(self, args: List[str], header_name: str, value: str) -> None:
        """Set ``header_name`` to ``value`` inside ``args``, modifying it in place.

        ``args`` is scanned for ``--headers <key> <value>`` triplets. The first
        triplet whose key equals ``header_name`` (case-insensitively) gets its
        value replaced; if there is none, a new triplet is appended.
        """
        wanted = header_name.lower()
        i = 0
        while i < len(args):
            if args[i] == self._HEADERS_FLAG and i + 1 < len(args):
                key = args[i + 1]
                if isinstance(key, str) and key.lower() == wanted:
                    if i + 2 < len(args):
                        args[i + 2] = value
                    else:
                        # Truncated triplet at the very end of args: the key
                        # is there but its value is missing, so supply it
                        # rather than appending a second, conflicting triplet.
                        args.append(value)
                    return
                # Jump over the whole triplet so that a header *value* is
                # never mistaken for a flag or a key.
                i += 3
                continue
            i += 1
        args.extend([self._HEADERS_FLAG, header_name, value])

    def _upsert_jwt_header_in_args(self, args: List[str], jwt_token: str) -> List[str]:
        """Ensure args contains a header triplet for X-Langflow-Global-Var-JWT.

        Args are expected in the pattern: [..., "--headers", key, value, ...].
        If the header exists its value is updated; otherwise the triplet is
        appended at the end. When ``args`` is not a list the result starts
        from ``["mcp-proxy"]``. The input list is never mutated.
        """
        return self._upsert_global_var_headers_in_args(
            args, {self._JWT_VAR_KEY: jwt_token}
        )

    def _upsert_global_var_headers_in_args(
        self, args: List[str], global_vars: Dict[str, str]
    ) -> List[str]:
        """Ensure args contains a X-Langflow-Global-Var-{key} triplet per variable.

        Args are expected in the pattern: [..., "--headers", key, value, ...].
        Existing headers are updated in position; missing ones are appended in
        the iteration order of ``global_vars``. When ``args`` is not a list the
        result starts from ``["mcp-proxy"]``.

        Returns:
            A new list; the ``args`` passed in is left untouched.
        """
        if isinstance(args, list):
            updated_args = list(args)
        else:
            updated_args = list(self._FALLBACK_ARGS)

        for var_key, var_value in global_vars.items():
            header_name = f"{self._GLOBAL_VAR_HEADER_PREFIX}{var_key}"
            self._upsert_header(updated_args, header_name, var_value)
        return updated_args

    def _sanitize_global_vars(self, global_vars: Any) -> Dict[str, str]:
        """Normalize global variables to ``{UPPERCASE_KEY: stripped_value}``.

        Entries whose key or value is ``None`` or blank after stripping are
        dropped. Anything that is not a dict yields an empty mapping.
        """
        if not isinstance(global_vars, dict):
            return {}

        sanitized: Dict[str, str] = {}
        for key, value in global_vars.items():
            if key is None or value is None:
                continue
            # str() keeps a non-string key from failing the whole patch.
            name = str(key).strip().upper()
            text = str(value).strip()
            if name and text:
                sanitized[name] = text
        return sanitized

    async def _patch_global_var_headers(
        self, server_name: str, header_values: Dict[str, str]
    ) -> tuple:
        """Fetch a server, upsert the given headers into its args and PATCH it.

        Returns:
            ``(response, updated_args)``. Judging the response status is left
            to the caller, and exceptions propagate.
        """
        current = await self.get_mcp_server(server_name)
        updated_args = self._upsert_global_var_headers_in_args(
            current.get("args", []), header_values
        )
        # The command is sent back unchanged next to the rewritten args.
        payload = {"command": current.get("command"), "args": updated_args}
        response = await clients.langflow_request(
            method="PATCH",
            endpoint=f"/api/v2/mcp/servers/{server_name}",
            json=payload,
        )
        return response, updated_args

    async def _patch_all_servers(
        self, patch_one: Any, success_message: str, failure_message: str
    ) -> Dict[str, Any]:
        """Run ``patch_one(name)`` for every listed server and summarize.

        ``patch_one`` must return an awaitable resolving to a bool. Servers
        are patched one after another.

        Returns:
            ``{"updated": int, "failed": int, "total": int}`` where ``total``
            is the number of listed entries, including skipped ones.
        """
        servers = await self.list_mcp_servers()
        if not servers:
            return {"updated": 0, "failed": 0, "total": 0}

        updated = 0
        failed = 0
        for server in servers:
            name = None
            if isinstance(server, dict):
                name = server.get("name") or server.get("server") or server.get("id")
            if not name:
                # Log only the type: an entry may carry tokens in its args.
                logger.warning(
                    "Skipping MCP server entry without a name",
                    entry_type=type(server).__name__,
                )
                continue
            if await patch_one(name):
                updated += 1
            else:
                failed += 1

        summary = {"updated": updated, "failed": failed, "total": len(servers)}
        if failed == 0:
            logger.info(success_message, **summary)
        else:
            logger.warning(failure_message, **summary)
        return summary

    async def patch_mcp_server_args_with_global_vars(
        self, server_name: str, global_vars: Dict[str, Any]
    ) -> bool:
        """Patch one MCP server to include/update X-Langflow-Global-Var-* headers.

        Only non-empty values are applied. Keys are uppercased to match
        existing conventions (e.g., JWT).

        Returns:
            ``True`` when there was nothing to apply or the PATCH answered
            200/201; ``False`` on any other status or on an exception, which
            is logged instead of raised.
        """
        try:
            sanitized = self._sanitize_global_vars(global_vars)
            if not sanitized:
                return True  # Nothing to do

            response, updated_args = await self._patch_global_var_headers(
                server_name, sanitized
            )
            if response.status_code in (200, 201):
                logger.info(
                    "Patched MCP server with global-var headers",
                    server_name=server_name,
                    applied_keys=list(sanitized.keys()),
                    args_len=len(updated_args),
                )
                return True

            logger.warning(
                "Failed to patch MCP server with global vars",
                server_name=server_name,
                status_code=response.status_code,
                body=response.text,
            )
            return False
        except Exception as e:
            logger.error(
                "Exception while patching MCP server with global vars",
                server_name=server_name,
                error=str(e),
            )
            return False

    async def update_mcp_servers_with_global_vars(
        self, global_vars: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Ensure every MCP server carries the provided global-var headers.

        Returns a summary dict with ``updated``, ``failed`` and ``total``.
        """
        return await self._patch_all_servers(
            lambda name: self.patch_mcp_server_args_with_global_vars(name, global_vars),
            "MCP servers updated with global-var headers",
            "MCP servers update (global vars) had failures",
        )

    async def patch_mcp_server_args_with_jwt(self, server_name: str, jwt_token: str) -> bool:
        """Patch a single MCP server to include/update the JWT header in args.

        The token is written as given, without stripping or emptiness checks.

        Returns:
            ``True`` when the PATCH answered 200/201; ``False`` on any other
            status or on an exception, which is logged instead of raised.
        """
        try:
            response, updated_args = await self._patch_global_var_headers(
                server_name, {self._JWT_VAR_KEY: jwt_token}
            )
            if response.status_code in (200, 201):
                logger.info(
                    "Patched MCP server with JWT header",
                    server_name=server_name,
                    args_len=len(updated_args),
                )
                return True

            logger.warning(
                "Failed to patch MCP server",
                server_name=server_name,
                status_code=response.status_code,
                body=response.text,
            )
            return False
        except Exception as e:
            logger.error(
                "Exception while patching MCP server",
                server_name=server_name,
                error=str(e),
            )
            return False

    async def update_mcp_servers_with_jwt(self, jwt_token: str) -> Dict[str, Any]:
        """Ensure every MCP server carries the JWT header in its args.

        Returns a summary dict with ``updated``, ``failed`` and ``total``.
        """
        return await self._patch_all_servers(
            lambda name: self.patch_mcp_server_args_with_jwt(name, jwt_token),
            "MCP servers updated with JWT header",
            "MCP servers update had failures",
        )
|
|
|
|
|