Merge branch 'main' of github.com:phact/gendb
This commit is contained in:
commit
b223f183ee
3 changed files with 172 additions and 2 deletions
|
|
@ -57,6 +57,7 @@ from config.settings import (
|
|||
is_no_auth_mode,
|
||||
)
|
||||
from services.auth_service import AuthService
|
||||
from services.langflow_mcp_service import LangflowMCPService
|
||||
from services.chat_service import ChatService
|
||||
|
||||
# Services
|
||||
|
|
@ -437,7 +438,11 @@ async def initialize_services():
|
|||
)
|
||||
|
||||
# Initialize auth service
|
||||
auth_service = AuthService(session_manager, connector_service)
|
||||
auth_service = AuthService(
|
||||
session_manager,
|
||||
connector_service,
|
||||
langflow_mcp_service=LangflowMCPService(),
|
||||
)
|
||||
|
||||
# Load persisted connector connections at startup so webhooks and syncs
|
||||
# can resolve existing subscriptions immediately after server boot
|
||||
|
|
|
|||
|
|
@ -5,9 +5,11 @@ import httpx
|
|||
import aiofiles
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
|
||||
from config.settings import WEBHOOK_BASE_URL, is_no_auth_mode
|
||||
from session_manager import SessionManager
|
||||
from services.langflow_mcp_service import LangflowMCPService
|
||||
from connectors.google_drive.oauth import GoogleDriveOAuth
|
||||
from connectors.onedrive.oauth import OneDriveOAuth
|
||||
from connectors.sharepoint.oauth import SharePointOAuth
|
||||
|
|
@ -17,10 +19,12 @@ from connectors.sharepoint import SharePointConnector
|
|||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, session_manager: SessionManager, connector_service=None):
|
||||
def __init__(self, session_manager: SessionManager, connector_service=None, langflow_mcp_service: LangflowMCPService | None = None):
|
||||
self.session_manager = session_manager
|
||||
self.connector_service = connector_service
|
||||
self.used_auth_codes = set() # Track used authorization codes
|
||||
self.langflow_mcp_service = langflow_mcp_service
|
||||
self._background_tasks = set()
|
||||
|
||||
async def init_oauth(
|
||||
self,
|
||||
|
|
@ -287,6 +291,20 @@ class AuthService:
|
|||
user_info = await self.session_manager.get_user_info_from_token(
|
||||
token_data["access_token"]
|
||||
)
|
||||
|
||||
# Best-effort: update Langflow MCP servers to include user's JWT header
|
||||
try:
|
||||
if self.langflow_mcp_service and isinstance(jwt_token, str) and jwt_token.strip():
|
||||
# Run in background to avoid delaying login flow
|
||||
task = asyncio.create_task(
|
||||
self.langflow_mcp_service.update_mcp_servers_with_jwt(jwt_token)
|
||||
)
|
||||
# Keep reference until done to avoid premature GC
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
except Exception:
|
||||
# Do not block login on MCP update issues
|
||||
pass
|
||||
|
||||
response_data = {
|
||||
"status": "authenticated",
|
||||
|
|
|
|||
147
src/services/langflow_mcp_service.py
Normal file
147
src/services/langflow_mcp_service.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
from typing import List, Dict, Any
|
||||
|
||||
from config.settings import clients
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LangflowMCPService:
|
||||
async def list_mcp_servers(self) -> List[Dict[str, Any]]:
|
||||
"""Fetch list of MCP servers from Langflow (v2 API)."""
|
||||
try:
|
||||
response = await clients.langflow_request(
|
||||
method="GET",
|
||||
endpoint="/api/v2/mcp/servers",
|
||||
params={"action_count": "false"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
logger.warning("Unexpected response format for MCP servers list", data_type=type(data).__name__)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error("Failed to list MCP servers", error=str(e))
|
||||
return []
|
||||
|
||||
async def get_mcp_server(self, server_name: str) -> Dict[str, Any]:
|
||||
"""Get MCP server configuration by name."""
|
||||
response = await clients.langflow_request(
|
||||
method="GET",
|
||||
endpoint=f"/api/v2/mcp/servers/{server_name}",
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _upsert_jwt_header_in_args(self, args: List[str], jwt_token: str) -> List[str]:
|
||||
"""Ensure args contains a header triplet for X-Langflow-Global-Var-JWT with the provided JWT.
|
||||
|
||||
Args are expected in the pattern: [..., "--headers", key, value, ...].
|
||||
If the header exists, update its value; otherwise append the triplet at the end.
|
||||
"""
|
||||
if not isinstance(args, list):
|
||||
return [
|
||||
"mcp-proxy",
|
||||
"--headers",
|
||||
"X-Langflow-Global-Var-JWT",
|
||||
jwt_token,
|
||||
]
|
||||
|
||||
updated_args = list(args)
|
||||
i = 0
|
||||
found_index = -1
|
||||
while i < len(updated_args):
|
||||
token = updated_args[i]
|
||||
if token == "--headers" and i + 2 < len(updated_args):
|
||||
header_key = updated_args[i + 1]
|
||||
if isinstance(header_key, str) and header_key.lower() == "x-langflow-global-var-jwt".lower():
|
||||
found_index = i
|
||||
break
|
||||
i += 3
|
||||
continue
|
||||
i += 1
|
||||
|
||||
if found_index >= 0:
|
||||
# Replace existing value at found_index + 2
|
||||
if found_index + 2 < len(updated_args):
|
||||
updated_args[found_index + 2] = jwt_token
|
||||
else:
|
||||
# Malformed existing header triplet; make sure to append a value
|
||||
updated_args.append(jwt_token)
|
||||
else:
|
||||
updated_args.extend([
|
||||
"--headers",
|
||||
"X-Langflow-Global-Var-JWT",
|
||||
jwt_token,
|
||||
])
|
||||
|
||||
return updated_args
|
||||
|
||||
async def patch_mcp_server_args_with_jwt(self, server_name: str, jwt_token: str) -> bool:
|
||||
"""Patch a single MCP server to include/update the JWT header in args."""
|
||||
try:
|
||||
current = await self.get_mcp_server(server_name)
|
||||
command = current.get("command")
|
||||
args = current.get("args", [])
|
||||
updated_args = self._upsert_jwt_header_in_args(args, jwt_token)
|
||||
|
||||
payload = {"command": command, "args": updated_args}
|
||||
response = await clients.langflow_request(
|
||||
method="PATCH",
|
||||
endpoint=f"/api/v2/mcp/servers/{server_name}",
|
||||
json=payload,
|
||||
)
|
||||
if response.status_code in (200, 201):
|
||||
logger.info(
|
||||
"Patched MCP server with JWT header",
|
||||
server_name=server_name,
|
||||
args_len=len(updated_args),
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to patch MCP server",
|
||||
server_name=server_name,
|
||||
status_code=response.status_code,
|
||||
body=response.text,
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Exception while patching MCP server",
|
||||
server_name=server_name,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
async def update_mcp_servers_with_jwt(self, jwt_token: str) -> Dict[str, Any]:
|
||||
"""Fetch all MCP servers and ensure each includes the JWT header in args.
|
||||
|
||||
Returns a summary dict with counts.
|
||||
"""
|
||||
servers = await self.list_mcp_servers()
|
||||
if not servers:
|
||||
return {"updated": 0, "failed": 0, "total": 0}
|
||||
|
||||
updated = 0
|
||||
failed = 0
|
||||
for server in servers:
|
||||
name = server.get("name") or server.get("server") or server.get("id")
|
||||
if not name:
|
||||
continue
|
||||
ok = await self.patch_mcp_server_args_with_jwt(name, jwt_token)
|
||||
if ok:
|
||||
updated += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
summary = {"updated": updated, "failed": failed, "total": len(servers)}
|
||||
if failed == 0:
|
||||
logger.info("MCP servers updated with JWT header", **summary)
|
||||
else:
|
||||
logger.warning("MCP servers update had failures", **summary)
|
||||
return summary
|
||||
|
||||
|
||||
Loading…
Add table
Reference in a new issue