diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py index 188756167..185788c81 100644 --- a/api/apps/mcp_server_app.py +++ b/api/apps/mcp_server_app.py @@ -1,39 +1,44 @@ from flask import Response, request from flask_login import current_user, login_required + +from api.db import VALID_MCP_SERVER_TYPES from api.db.db_models import MCPServer from api.db.services.mcp_server_service import MCPServerService from api.db.services.user_service import TenantService from api.settings import RetCode from api.utils import get_uuid from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request +from api.utils.web_utils import safe_json_parse -@manager.route("/list", methods=["GET"]) # noqa: F821 +@manager.route("/list", methods=["POST"]) # noqa: F821 @login_required -def get_list() -> Response: +def list_mcp() -> Response: + keywords = request.args.get("keywords", "") + page_number = int(request.args.get("page", 0)) + items_per_page = int(request.args.get("page_size", 0)) + orderby = request.args.get("orderby", "create_time") + if request.args.get("desc", "true").lower() == "false": + desc = False + else: + desc = True + + req = request.get_json() + mcp_ids = req.get("mcp_ids", []) try: - return get_json_result(data=MCPServerService.get_servers(current_user.id) or []) + servers = MCPServerService.get_servers(current_user.id, mcp_ids, page_number, items_per_page, orderby, desc, keywords) or [] + + return get_json_result(data={"mcp_servers": servers, "total": len(servers)}) except Exception as e: return server_error_response(e) -@manager.route("/get_multiple", methods=["POST"]) # noqa: F821 +@manager.route("/detail", methods=["GET"]) # noqa: F821 @login_required -@validate_request("id_list") -def get_multiple() -> Response: - req = request.json - +def detail() -> Response: + mcp_id = request.args["mcp_id"] try: - return get_json_result(data=MCPServerService.get_servers(current_user.id, id_list=req["id_list"]) or []) - except Exception as e: - return server_error_response(e) - - -@manager.route("/get/", methods=["GET"]) # noqa: F821 -@login_required -def get(ms_id: str) -> Response: - try: - mcp_server = MCPServerService.get_or_none(id=ms_id, tenant_id=current_user.id) + mcp_server = MCPServerService.get_or_none(id=mcp_id, tenant_id=current_user.id) if mcp_server is None: return get_json_result(code=RetCode.NOT_FOUND, data=None) @@ -47,7 +52,18 @@ def get(ms_id: str) -> Response: @login_required @validate_request("name", "url", "server_type") def create() -> Response: - req = request.json + req = request.get_json() + + server_type = req.get("server_type", "") + if server_type not in VALID_MCP_SERVER_TYPES: + return get_data_error_result(message="Unsupported MCP server type.") + + server_name = req.get("name", "") + if not server_name or len(server_name.encode("utf-8")) > 255: + return get_data_error_result(message=f"Invaild MCP name or length is {len(server_name)} which is large than 255.") + + req["headers"] = safe_json_parse(req.get("headers", {})) + req["variables"] = safe_json_parse(req.get("variables", {})) try: req["id"] = get_uuid() @@ -58,9 +74,6 @@ def create() -> Response: if not e: return get_data_error_result(message="Tenant not found.") - if not req.get("headers"): - req["headers"] = {} - if not MCPServerService.insert(**req): return get_data_error_result() @@ -71,37 +84,131 @@ def create() -> Response: @manager.route("/update", methods=["POST"]) # noqa: F821 @login_required -@validate_request("id", "name", "url", "server_type") +@validate_request("id") def update() -> Response: - req = request.json + req = request.get_json() - if not req.get("headers"): - req["headers"] = {} + server_type = req.get("server_type", "") + if server_type and server_type not in VALID_MCP_SERVER_TYPES: + return get_data_error_result(message="Unsupported MCP server type.") + server_name = req.get("name", "") + if server_name and len(server_name.encode("utf-8")) > 255: + return get_data_error_result(message=f"Invaild MCP name or length is {len(server_name)} which is large than 255.") + + req["headers"] = safe_json_parse(req.get("headers", {})) + req["variables"] = safe_json_parse(req.get("variables", {})) try: req["tenant_id"] = current_user.id if not MCPServerService.filter_update([MCPServer.id == req["id"], MCPServer.tenant_id == req["tenant_id"]], req): - return get_data_error_result() + return get_data_error_result(message="Failed to updated MCP server.") - return get_json_result(data={"id": req["id"]}) + e, updated_mcp = MCPServerService.get_by_id(req["id"]) + if not e: + return get_data_error_result(message="Failed to fetch updated MCP server.") + + return get_json_result(data=updated_mcp.to_dict()) except Exception as e: return server_error_response(e) @manager.route("/rm", methods=["POST"]) # noqa: F821 @login_required -@validate_request("id") +@validate_request("mcp_ids") def rm() -> Response: - req = request.json - ms_id = req["id"] + req = request.get_json() + mcp_ids = req.get("mcp_ids", []) try: req["tenant_id"] = current_user.id - if not MCPServerService.filter_delete([MCPServer.id == ms_id, MCPServer.tenant_id == req["tenant_id"]]): - return get_data_error_result() + if not MCPServerService.delete_by_ids(mcp_ids): + return get_data_error_result(message=f"Failed to delete MCP servers {mcp_ids}") - return get_json_result(data={"id": req["id"]}) + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) + + +@manager.route("/import", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("mcpServers") +def import_multiple() -> Response: + req = request.get_json() + servers = req.get("mcpServers", {}) + + if not servers: + return get_data_error_result(message="No MCP servers provided.") + + results = [] + try: + for server_name, config in servers.items(): + if not all(key in config for key in ["type", "url"]): + results.append({"server": server_name, "success": False, "message": "Missing required fields (type or url)"}) + continue + + base_name = server_name + new_name = base_name + counter = 0 + + while True: + e, _ = MCPServerService.get_by_name_and_tenant(name=new_name, tenant_id=current_user.id) + if not e: + break + new_name = f"{base_name}_{counter}" + counter += 1 + + create_data = { + "id": get_uuid(), + "tenant_id": current_user.id, + "name": new_name, + "url": config["url"], + "server_type": config["type"], + "variables": {"authorization_token": config.get("authorization_token", ""), "tool_configuration": config.get("tool_configuration", {})}, + } + + if MCPServerService.insert(**create_data): + result = {"server": server_name, "success": True, "action": "created", "id": create_data["id"], "new_name": new_name} + if new_name != base_name: + result["message"] = f"Renamed from '{base_name}' to avoid duplication" + + results.append(result) + else: + results.append({"server": server_name, "success": False, "message": "Failed to create MCP server."}) + + return get_json_result(data={"results": results}) + except Exception as e: + return server_error_response(e) + + +@manager.route("/export", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("mcp_ids") +def export_multiple() -> Response: + req = request.get_json() + mcp_ids = req.get("mcp_ids", []) + + if not mcp_ids: + return get_data_error_result(message="No MCP server IDs provided.") + + try: + exported_servers = {} + + for mcp_id in mcp_ids: + e, mcp_server = MCPServerService.get_by_id(mcp_id) + + if e and mcp_server.tenant_id == current_user.id: + server_key = mcp_server.name + + exported_servers[server_key] = { + "type": mcp_server.server_type, + "url": mcp_server.url, + "name": mcp_server.name, + "authorization_token": mcp_server.variables.get("authorization_token", ""), + "tool_configuration": mcp_server.variables.get("tool_configuration", {}), + } + + return get_json_result(data={"mcpServers": exported_servers}) except Exception as e: return server_error_response(e) diff --git a/api/apps/search_app.py b/api/apps/search_app.py index 083e63083..25d11381c 100644 --- a/api/apps/search_app.py +++ b/api/apps/search_app.py @@ -40,8 +40,8 @@ def create(): return get_data_error_result(message="Search name must be string.") if search_name.strip() == "": return get_data_error_result(message="Search name can't be empty.") - if len(search_name.encode("utf-8")) > DATASET_NAME_LIMIT: - return get_data_error_result(message=f"Search name length is {len(search_name)} which is large than {DATASET_NAME_LIMIT}") + if len(search_name.encode("utf-8")) > 255: + return get_data_error_result(message=f"Search name length is {len(search_name)} which is large than 255.") e, _ = TenantService.get_by_id(current_user.id) if not e: return get_data_error_result(message="Authorizationd identity.") diff --git a/api/db/__init__.py b/api/db/__init__.py index 54cc98533..a7fb046f8 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -107,6 +107,8 @@ class CanvasType(StrEnum): class MCPServerType(StrEnum): SSE = "sse" - StreamableHttp = "streamable-http" + STREAMABLE_HTTP = "streamable-http" + +VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP} KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase" diff --git a/api/db/db_models.py b/api/db/db_models.py index 839203a5f..5eba272d9 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -806,13 +806,13 @@ class MCPServer(DataBaseModel): url = CharField(max_length=2048, null=False, help_text="MCP Server URL") server_type = CharField(max_length=32, null=False, help_text="MCP Server type") description = TextField(null=True, help_text="MCP Server description") - variables = JSONField(null=True, default=[], help_text="MCP Server variables") - headers = JSONField(null=True, default={}, help_text="MCP Server additional request headers") + variables = JSONField(null=True, default=dict, help_text="MCP Server variables") + headers = JSONField(null=True, default=dict, help_text="MCP Server additional request headers") class Meta: db_table = "mcp_server" - + class Search(DataBaseModel): id = CharField(max_length=32, primary_key=True) avatar = TextField(null=True, help_text="avatar base64 string") @@ -949,6 +949,6 @@ def migrate_db(): except Exception: pass try: - migrate(migrator.add_column("mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=[]))) + migrate(migrator.add_column("mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict))) except Exception: pass diff --git a/api/db/services/mcp_server_service.py b/api/db/services/mcp_server_service.py index 43bc75f6c..869350094 100644 --- a/api/db/services/mcp_server_service.py +++ b/api/db/services/mcp_server_service.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from peewee import fn + from api.db.db_models import DB, MCPServer from api.db.services.common_service import CommonService @@ -31,7 +33,7 @@ class MCPServerService(CommonService): @classmethod @DB.connection_context() - def get_servers(cls, tenant_id: str, id_list: list[str] | None = None): + def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc, keywords): """Retrieve all MCP servers associated with a tenant. This method fetches all MCP servers for a given tenant, ordered by creation time. @@ -46,16 +48,39 @@ class MCPServerService(CommonService): Returns None if no MCP servers are found. """ fields = [ - cls.model.id, cls.model.name, cls.model.server_type, cls.model.url, cls.model.description, - cls.model.variables, cls.model.update_date + cls.model.id, + cls.model.name, + cls.model.server_type, + cls.model.url, + cls.model.description, + cls.model.variables, + cls.model.create_date, + cls.model.update_date, ] - servers = cls.model.select(*fields).order_by(cls.model.create_time.desc()).where(cls.model.tenant_id == tenant_id) + query = cls.model.select(*fields).order_by(cls.model.create_time.desc()).where(cls.model.tenant_id == tenant_id) - if id_list is not None: - servers = servers.where(cls.model.id.in_(id_list)) + if id_list: + query = query.where(cls.model.id.in_(id_list)) + if keywords: + query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower())) + if desc: + query = query.order_by(cls.model.getter_by(orderby).desc()) + else: + query = query.order_by(cls.model.getter_by(orderby).asc()) + if page_number and items_per_page: + query = query.paginate(page_number, items_per_page) - servers = list(servers.dicts()) + servers = list(query.dicts()) if not servers: return None return servers + + @classmethod + @DB.connection_context() + def get_by_name_and_tenant(cls, name: str, tenant_id: str): + try: + mcp_server = cls.model.query(name=name, tenant_id=tenant_id) + return bool(mcp_server), mcp_server + except Exception: + return False, None diff --git a/api/utils/web_utils.py b/api/utils/web_utils.py index 084b7a6f7..de3d692dd 100644 --- a/api/utils/web_utils.py +++ b/api/utils/web_utils.py @@ -116,4 +116,14 @@ def is_valid_url(url: str) -> bool: return False except socket.gaierror: return False - return True \ No newline at end of file + return True + + +def safe_json_parse(data: str | dict) -> dict: + if isinstance(data, dict): + return data + try: + return json.loads(data) if data else {} + except (json.JSONDecodeError, TypeError): + return {} + diff --git a/mcp/server/simple_tools_server.py b/mcp/server/simple_tools_server.py deleted file mode 100644 index f5f9a5257..000000000 --- a/mcp/server/simple_tools_server.py +++ /dev/null @@ -1,23 +0,0 @@ -from mcp.server import FastMCP - - -app = FastMCP("simple-tools", port=8080) - - -@app.tool() -async def bad_calculator(a: int, b: int) -> str: - """ - A calculator to sum up two numbers (will give wrong answer) - - Args: - a: The first number - b: The second number - - Returns: - Sum of a and b - """ - return str(a + b + 200) - - -if __name__ == "__main__": - app.run(transport="sse")