diff --git a/admin/server/admin_server.py b/admin/server/admin_server.py index cfc5c4bee..3c0f56ec0 100644 --- a/admin/server/admin_server.py +++ b/admin/server/admin_server.py @@ -21,7 +21,7 @@ import time import threading import traceback from werkzeug.serving import run_simple -from flask import Flask +from quart import Quart from routes import admin_bp from common.log_utils import init_root_logger from common.constants import SERVICE_CONF @@ -30,7 +30,7 @@ from common import settings from config import load_configurations, SERVICE_CONFIGS from auth import init_default_admin, setup_auth from flask_session import Session -from flask_login import LoginManager +from quart_auth import LoginManager from common.versions import get_ragflow_version stop_event = threading.Event() diff --git a/admin/server/auth.py b/admin/server/auth.py index 564c348e3..2c8cb4e90 100644 --- a/admin/server/auth.py +++ b/admin/server/auth.py @@ -19,8 +19,8 @@ import logging import uuid from functools import wraps from datetime import datetime -from flask import request, jsonify -from flask_login import current_user, login_user +from quart import request, jsonify +from api.apps import current_user, login_user from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from api.common.exceptions import AdminException, UserNotFoundError diff --git a/admin/server/responses.py b/admin/server/responses.py index 54f841a83..de6600795 100644 --- a/admin/server/responses.py +++ b/admin/server/responses.py @@ -15,7 +15,7 @@ # -from flask import jsonify +from quart import jsonify def success_response(data=None, message="Success", code=0): diff --git a/admin/server/routes.py b/admin/server/routes.py index 2c70fbd7a..3b3cd2772 100644 --- a/admin/server/routes.py +++ b/admin/server/routes.py @@ -16,8 +16,8 @@ import secrets -from flask import Blueprint, request -from flask_login import current_user, logout_user, login_required +from quart import Blueprint, request +from api.apps import current_user, logout_user, login_required from auth import login_verify, login_admin, check_admin_auth from responses import success_response, error_response @@ -30,12 +30,12 @@ admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin') @admin_bp.route('/login', methods=['POST']) -def login(): - if not request.json: +async def login(): + if not await request.json: return error_response('Authorize admin failed.' ,400) try: - email = request.json.get("email", "") - password = request.json.get("password", "") + email = await request.json.get("email", "") + password = await request.json.get("password", "") return login_admin(email, password) except Exception as e: return error_response(str(e), 500) @@ -76,9 +76,9 @@ def list_users(): @admin_bp.route('/users', methods=['POST']) @login_required @check_admin_auth -def create_user(): +async def create_user(): try: - data = request.get_json() + data = await request.get_json() if not data or 'username' not in data or 'password' not in data: return error_response("Username and password are required", 400) @@ -120,9 +120,9 @@ def delete_user(username): @admin_bp.route('/users//password', methods=['PUT']) @login_required @check_admin_auth -def change_password(username): +async def change_password(username): try: - data = request.get_json() + data = await request.get_json() if not data or 'new_password' not in data: return error_response("New password is required", 400) @@ -139,9 +139,9 @@ def change_password(username): @admin_bp.route('/users//activate', methods=['PUT']) @login_required @check_admin_auth -def alter_user_activate_status(username): +async def alter_user_activate_status(username): try: - data = request.get_json() + data = await request.get_json() if not data or 'activate_status' not in data: return error_response("Activation status is required", 400) activate_status = data['activate_status'] @@ -253,9 +253,9 @@ def restart_service(service_id): @admin_bp.route('/roles', methods=['POST']) @login_required @check_admin_auth -def create_role(): +async def create_role(): try: - data = request.get_json() + data = await request.get_json() if not data or 'role_name' not in data: return error_response("Role name is required", 400) role_name: str = data['role_name'] @@ -269,9 +269,9 @@ def create_role(): @admin_bp.route('/roles/', methods=['PUT']) @login_required @check_admin_auth -def update_role(role_name: str): +async def update_role(role_name: str): try: - data = request.get_json() + data = await request.get_json() if not data or 'description' not in data: return error_response("Role description is required", 400) description: str = data['description'] @@ -317,9 +317,9 @@ def get_role_permission(role_name: str): @admin_bp.route('/roles//permission', methods=['POST']) @login_required @check_admin_auth -def grant_role_permission(role_name: str): +async def grant_role_permission(role_name: str): try: - data = request.get_json() + data = await request.get_json() if not data or 'actions' not in data or 'resource' not in data: return error_response("Permission is required", 400) actions: list = data['actions'] @@ -333,9 +333,9 @@ def grant_role_permission(role_name: str): @admin_bp.route('/roles//permission', methods=['DELETE']) @login_required @check_admin_auth -def revoke_role_permission(role_name: str): +async def revoke_role_permission(role_name: str): try: - data = request.get_json() + data = await request.get_json() if not data or 'actions' not in data or 'resource' not in data: return error_response("Permission is required", 400) actions: list = data['actions'] @@ -349,9 +349,9 @@ def revoke_role_permission(role_name: str): @admin_bp.route('/users//role', methods=['PUT']) @login_required @check_admin_auth -def update_user_role(user_name: str): +async def update_user_role(user_name: str): try: - data = request.get_json() + data = await request.get_json() if not data or 'role_name' not in data: return error_response("Role name is required", 400) role_name: str = data['role_name'] diff --git a/api/apps/__init__.py b/api/apps/__init__.py index f2009db2c..99dae7f5d 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -18,30 +18,33 @@ import sys import logging from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from flask import Blueprint, Flask +from quart import Blueprint, Quart, request, g, current_app, session from werkzeug.wrappers.request import Request -from flask_cors import CORS from flasgger import Swagger from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer - +from quart_cors import cors from common.constants import StatusEnum -from api.db.db_models import close_connection +from api.db.db_models import close_connection, User from api.db.services import UserService from api.utils.json_encode import CustomJSONEncoder from api.utils import commands from flask_mail import Mail -from flask_session import Session -from flask_login import LoginManager +from quart_session import Session +from quart_auth import QuartAuth, Unauthorized from common import settings from api.utils.api_utils import server_error_response from api.constants import API_VERSION +from common.misc_utils import get_uuid + +settings.init_settings() __all__ = ["app"] Request.json = property(lambda self: self.get_json(force=True, silent=True)) -app = Flask(__name__) +app = Quart(__name__) +app = cors(app, allow_origin="*") smtp_mail_server = Mail() # Add this at the beginning of your file to configure Swagger UI @@ -76,7 +79,6 @@ swagger = Swagger( }, ) -CORS(app, supports_credentials=True, max_age=2592000) app.url_map.strict_slashes = False app.json_encoder = CustomJSONEncoder app.errorhandler(Exception)(server_error_response) @@ -84,17 +86,143 @@ app.errorhandler(Exception)(server_error_response) ## convince for dev and debug # app.config["LOGIN_DISABLED"] = True app.config["SESSION_PERMANENT"] = False -app.config["SESSION_TYPE"] = "filesystem" +app.config["SESSION_TYPE"] = "redis" +app.config["SESSION_REDIS"] = settings.decrypt_database_config(name="redis") app.config["MAX_CONTENT_LENGTH"] = int( os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024) ) - -Session(app) -login_manager = LoginManager() -login_manager.init_app(app) - +app.config['SECRET_KEY'] = settings.SECRET_KEY +app.secret_key = settings.SECRET_KEY commands.register_commands(app) +from functools import wraps +from typing import ParamSpec, TypeVar +from collections.abc import Awaitable, Callable +from werkzeug.local import LocalProxy + +T = TypeVar("T") +P = ParamSpec("P") + +def _load_user(): + jwt = Serializer(secret_key=settings.SECRET_KEY) + authorization = request.headers.get("Authorization") + g.user = None + if not authorization: + return + + try: + access_token = str(jwt.loads(authorization)) + + if not access_token or not access_token.strip(): + logging.warning("Authentication attempt with empty access token") + return None + + # Access tokens should be UUIDs (32 hex characters) + if len(access_token.strip()) < 32: + logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars") + return None + + user = UserService.query( + access_token=access_token, status=StatusEnum.VALID.value + ) + if user: + if not user[0].access_token or not user[0].access_token.strip(): + logging.warning(f"User {user[0].email} has empty access_token in database") + return None + g.user = user[0] + return user[0] + except Exception as e: + logging.warning(f"load_user got exception {e}") + + +current_user = LocalProxy(_load_user) + + +def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: + """A decorator to restrict route access to authenticated users. + + This should be used to wrap a route handler (or view function) to + enforce that only authenticated requests can access it. Note that + it is important that this decorator be wrapped by the route + decorator and not vice, versa, as below. + + .. code-block:: python + + @app.route('/') + @login_required + async def index(): + ... + + If the request is not authenticated a + `quart.exceptions.Unauthorized` exception will be raised. + + """ + + @wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + if not current_user or not session.get("_user_id"): + raise Unauthorized() + else: + return await current_app.ensure_async(func)(*args, **kwargs) + + return wrapper + + +def login_user(user, remember=False, duration=None, force=False, fresh=True): + """ + Logs a user in. You should pass the actual user object to this. If the + user's `is_active` property is ``False``, they will not be logged in + unless `force` is ``True``. + + This will return ``True`` if the log in attempt succeeds, and ``False`` if + it fails (i.e. because the user is inactive). + + :param user: The user object to log in. + :type user: object + :param remember: Whether to remember the user after their session expires. + Defaults to ``False``. + :type remember: bool + :param duration: The amount of time before the remember cookie expires. If + ``None`` the value set in the settings is used. Defaults to ``None``. + :type duration: :class:`datetime.timedelta` + :param force: If the user is inactive, setting this to ``True`` will log + them in regardless. Defaults to ``False``. + :type force: bool + :param fresh: setting this to ``False`` will log in the user with a session + marked as not "fresh". Defaults to ``True``. + :type fresh: bool + """ + if not force and not user.is_active: + return False + + session["_user_id"] = user.id + session["_fresh"] = fresh + session["_id"] = get_uuid() + return True + + +def logout_user(): + """ + Logs a user out. (You do not need to pass the actual user.) This will + also clean up the remember me cookie if it exists. + """ + if "_user_id" in session: + session.pop("_user_id") + + if "_fresh" in session: + session.pop("_fresh") + + if "_id" in session: + session.pop("_id") + + COOKIE_NAME = "remember_token" + cookie_name = current_app.config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME) + if cookie_name in request.cookies: + session["_remember"] = "clear" + if "_remember_seconds" in session: + session.pop("_remember_seconds") + + return True def search_pages_path(pages_dir): app_path_list = [ @@ -142,40 +270,6 @@ client_urls_prefix = [ ] -@login_manager.request_loader -def load_user(web_request): - jwt = Serializer(secret_key=settings.SECRET_KEY) - authorization = web_request.headers.get("Authorization") - if authorization: - try: - access_token = str(jwt.loads(authorization)) - - if not access_token or not access_token.strip(): - logging.warning("Authentication attempt with empty access token") - return None - - # Access tokens should be UUIDs (32 hex characters) - if len(access_token.strip()) < 32: - logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars") - return None - - user = UserService.query( - access_token=access_token, status=StatusEnum.VALID.value - ) - if user: - if not user[0].access_token or not user[0].access_token.strip(): - logging.warning(f"User {user[0].email} has empty access_token in database") - return None - return user[0] - else: - return None - except Exception as e: - logging.warning(f"load_user got exception {e}") - return None - else: - return None - - @app.teardown_request def _db_close(exc): close_connection() diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 1c9a78239..fee34cc6e 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -14,20 +14,20 @@ # limitations under the License. # from datetime import datetime, timedelta -from flask import request -from flask_login import login_required, current_user +from quart import request from api.db.db_models import APIToken from api.db.services.api_service import APITokenService, API4ConversationService from api.db.services.user_service import UserTenantService from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \ generate_confirmation_token from common.time_utils import current_timestamp, datetime_format +from api.apps import login_required, current_user @manager.route('/new_token', methods=['POST']) # noqa: F821 @login_required -def new_token(): - req = request.json +async def new_token(): + req = await request.json try: tenants = UserTenantService.query(user_id=current_user.id) if not tenants: @@ -72,8 +72,8 @@ def token_list(): @manager.route('/rm', methods=['POST']) # noqa: F821 @validate_request("tokens", "tenant_id") @login_required -def rm(): - req = request.json +async def rm(): + req = await request.json try: for token in req["tokens"]: APITokenService.filter_delete( diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 0ac2951ae..3c4c72650 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -18,12 +18,8 @@ import logging import re import sys from functools import partial - -import flask import trio -from flask import request, Response -from flask_login import login_required, current_user - +from quart import request, Response, make_response from agent.component import LLM from api.db import CanvasCategory, FileType from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService @@ -46,6 +42,7 @@ from rag.flow.pipeline import Pipeline from rag.nlp import search from rag.utils.redis_conn import REDIS_CONN from common import settings +from api.apps import login_required, current_user @manager.route('/templates', methods=['GET']) # noqa: F821 @@ -57,8 +54,9 @@ def templates(): @manager.route('/rm', methods=['POST']) # noqa: F821 @validate_request("canvas_ids") @login_required -def rm(): - for i in request.json["canvas_ids"]: +async def rm(): + req = await request.json + for i in req["canvas_ids"]: if not UserCanvasService.accessible(i, current_user.id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', @@ -70,8 +68,8 @@ def rm(): @manager.route('/set', methods=['POST']) # noqa: F821 @validate_request("dsl", "title") @login_required -def save(): - req = request.json +async def save(): + req = await request.json if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) req["dsl"] = json.loads(req["dsl"]) @@ -129,8 +127,8 @@ def getsse(canvas_id): @manager.route('/completion', methods=['POST']) # noqa: F821 @validate_request("id") @login_required -def run(): - req = request.json +async def run(): + req = await request.json query = req.get("query", "") files = req.get("files", []) inputs = req.get("inputs", {}) @@ -179,15 +177,15 @@ def run(): resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - resp.call_on_close(lambda: canvas.cancel_task()) + #resp.call_on_close(lambda: canvas.cancel_task()) return resp @manager.route('/rerun', methods=['POST']) # noqa: F821 @validate_request("id", "dsl", "component_id") @login_required -def rerun(): - req = request.json +async def rerun(): + req = await request.json doc = PipelineOperationLogService.get_documents_info(req["id"]) if not doc: return get_data_error_result(message="Document not found.") @@ -224,8 +222,8 @@ def cancel(task_id): @manager.route('/reset', methods=['POST']) # noqa: F821 @validate_request("id") @login_required -def reset(): - req = request.json +async def reset(): + req = await request.json if not UserCanvasService.accessible(req["id"], current_user.id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', @@ -342,8 +340,8 @@ def input_form(): @manager.route('/debug', methods=['POST']) # noqa: F821 @validate_request("id", "component_id", "params") @login_required -def debug(): - req = request.json +async def debug(): + req = await request.json if not UserCanvasService.accessible(req["id"], current_user.id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', @@ -374,8 +372,8 @@ def debug(): @manager.route('/test_db_connect', methods=['POST']) # noqa: F821 @validate_request("db_type", "database", "username", "host", "port", "password") @login_required -def test_db_connect(): - req = request.json +async def test_db_connect(): + req = await request.json try: if req["db_type"] in ["mysql", "mariadb"]: db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], @@ -520,8 +518,8 @@ def list_canvas(): @manager.route('/setting', methods=['POST']) # noqa: F821 @validate_request("id", "title", "permission") @login_required -def setting(): - req = request.json +async def setting(): + req = await request.json req["user_id"] = current_user.id if not UserCanvasService.accessible(req["id"], current_user.id): @@ -602,8 +600,8 @@ def prompts(): @manager.route('/download', methods=['GET']) # noqa: F821 -def download(): +async def download(): id = request.args.get("id") created_by = request.args.get("created_by") blob = FileService.get_blob(created_by, id) - return flask.make_response(blob) + return await make_response(blob) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 78a614ddf..e96548041 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -18,8 +18,7 @@ import json import re import xxhash -from flask import request -from flask_login import current_user, login_required +from quart import request from api.db.services.dialog_service import meta_filter from api.db.services.document_service import DocumentService @@ -35,13 +34,14 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr from common.string_utils import remove_redundant_spaces from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD from common import settings +from api.apps import login_required, current_user @manager.route('/list', methods=['POST']) # noqa: F821 @login_required @validate_request("doc_id") -def list_chunk(): - req = request.json +async def list_chunk(): + req = await request.json doc_id = req["doc_id"] page = int(req.get("page", 1)) size = int(req.get("size", 30)) @@ -121,8 +121,8 @@ def get(): @manager.route('/set', methods=['POST']) # noqa: F821 @login_required @validate_request("doc_id", "chunk_id", "content_with_weight") -def set(): - req = request.json +async def set(): + req = await request.json d = { "id": req["chunk_id"], "content_with_weight": req["content_with_weight"]} @@ -178,8 +178,8 @@ def set(): @manager.route('/switch', methods=['POST']) # noqa: F821 @login_required @validate_request("chunk_ids", "available_int", "doc_id") -def switch(): - req = request.json +async def switch(): + req = await request.json try: e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: @@ -198,8 +198,8 @@ def switch(): @manager.route('/rm', methods=['POST']) # noqa: F821 @login_required @validate_request("chunk_ids", "doc_id") -def rm(): - req = request.json +async def rm(): + req = await request.json try: e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: @@ -222,8 +222,8 @@ def rm(): @manager.route('/create', methods=['POST']) # noqa: F821 @login_required @validate_request("doc_id", "content_with_weight") -def create(): - req = request.json +async def create(): + req = await request.json chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest() d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), "content_with_weight": req["content_with_weight"]} @@ -280,8 +280,8 @@ def create(): @manager.route('/retrieval_test', methods=['POST']) # noqa: F821 @login_required @validate_request("kb_id", "question") -def retrieval_test(): - req = request.json +async def retrieval_test(): + req = await request.json page = int(req.get("page", 1)) size = int(req.get("size", 30)) question = req["question"] diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py index 23965e617..4dc30c8e4 100644 --- a/api/apps/connector_app.py +++ b/api/apps/connector_app.py @@ -20,8 +20,8 @@ import uuid from html import escape from typing import Any -from flask import make_response, request -from flask_login import current_user, login_required +import flask +from quart import make_response, request from google_auth_oauthlib.flow import Flow from api.db import InputType @@ -32,12 +32,13 @@ from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, Docum from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES from common.misc_utils import get_uuid from rag.utils.redis_conn import REDIS_CONN +from api.apps import login_required, current_user @manager.route("/set", methods=["POST"]) # noqa: F821 @login_required -def set_connector(): - req = request.json +async def set_connector(): + req = await request.json if req.get("id"): conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req} ConnectorService.update_by_id(req["id"], conn) @@ -89,8 +90,8 @@ def list_logs(connector_id): @manager.route("//resume", methods=["PUT"]) # noqa: F821 @login_required -def resume(connector_id): - req = request.json +async def resume(connector_id): + req = await request.json if req.get("resume"): ConnectorService.resume(connector_id, TaskStatus.SCHEDULE) else: @@ -101,8 +102,8 @@ def resume(connector_id): @manager.route("//rebuild", methods=["PUT"]) # noqa: F821 @login_required @validate_request("kb_id") -def rebuild(connector_id): - req = request.json +async def rebuild(connector_id): + req = await request.json err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id) if err: return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR) @@ -164,7 +165,7 @@ def _render_web_oauth_popup(flow_id: str, success: bool, message: str): payload_json=payload_json, auto_close=auto_close, ) - response = make_response(html, 200) + response = flask.make_response(html, 200) response.headers["Content-Type"] = "text/html; charset=utf-8" return response @@ -172,14 +173,14 @@ def _render_web_oauth_popup(flow_id: str, success: bool, message: str): @manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821 @login_required @validate_request("credentials") -def start_google_drive_web_oauth(): +async def start_google_drive_web_oauth(): if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI: return get_json_result( code=RetCode.SERVER_ERROR, message="Google Drive OAuth redirect URI is not configured on the server.", ) - req = request.json or {} + req = await request.json or {} raw_credentials = req.get("credentials", "") try: credentials = _load_credentials(raw_credentials) @@ -280,8 +281,8 @@ def google_drive_web_oauth_callback(): @manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821 @login_required @validate_request("flow_id") -def poll_google_drive_web_result(): - req = request.json or {} +async def poll_google_drive_web_result(): + req = await request.json or {} flow_id = req.get("flow_id") cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id)) if not cache_raw: diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 984e57cac..a44d97a1a 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -17,8 +17,8 @@ import json import re import logging from copy import deepcopy -from flask import Response, request -from flask_login import current_user, login_required +from quart import Response, request +from api.apps import current_user, login_required from api.db.db_models import APIToken from api.db.services.conversation_service import ConversationService, structure_answer from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap @@ -34,8 +34,8 @@ from common.constants import RetCode, LLMType @manager.route("/set", methods=["POST"]) # noqa: F821 @login_required -def set_conversation(): - req = request.json +async def set_conversation(): + req = await request.json conv_id = req.get("conversation_id") is_new = req.get("is_new") name = req.get("name", "New conversation") @@ -129,8 +129,9 @@ def getsse(dialog_id): @manager.route("/rm", methods=["POST"]) # noqa: F821 @login_required -def rm(): - conv_ids = request.json["conversation_ids"] +async def rm(): + req = await request.json + conv_ids = req["conversation_ids"] try: for cid in conv_ids: exist, conv = ConversationService.get_by_id(cid) @@ -166,8 +167,8 @@ def list_conversation(): @manager.route("/completion", methods=["POST"]) # noqa: F821 @login_required @validate_request("conversation_id", "messages") -def completion(): - req = request.json +async def completion(): + req = await request.json msg = [] for m in req["messages"]: if m["role"] == "system": @@ -251,8 +252,8 @@ def completion(): @manager.route("/tts", methods=["POST"]) # noqa: F821 @login_required -def tts(): - req = request.json +async def tts(): + req = await request.json text = req["text"] tenants = TenantService.get_info_by(current_user.id) @@ -284,8 +285,8 @@ def tts(): @manager.route("/delete_msg", methods=["POST"]) # noqa: F821 @login_required @validate_request("conversation_id", "message_id") -def delete_msg(): - req = request.json +async def delete_msg(): + req = await request.json e, conv = ConversationService.get_by_id(req["conversation_id"]) if not e: return get_data_error_result(message="Conversation not found!") @@ -307,8 +308,8 @@ def delete_msg(): @manager.route("/thumbup", methods=["POST"]) # noqa: F821 @login_required @validate_request("conversation_id", "message_id") -def thumbup(): - req = request.json +async def thumbup(): + req = await request.json e, conv = ConversationService.get_by_id(req["conversation_id"]) if not e: return get_data_error_result(message="Conversation not found!") @@ -334,8 +335,8 @@ def thumbup(): @manager.route("/ask", methods=["POST"]) # noqa: F821 @login_required @validate_request("question", "kb_ids") -def ask_about(): - req = request.json +async def ask_about(): + req = await request.json uid = current_user.id search_id = req.get("search_id", "") @@ -366,8 +367,8 @@ def ask_about(): @manager.route("/mindmap", methods=["POST"]) # noqa: F821 @login_required @validate_request("question", "kb_ids") -def mindmap(): - req = request.json +async def mindmap(): + req = await request.json search_id = req.get("search_id", "") search_app = SearchService.get_detail(search_id) if search_id else {} search_config = search_app.get("search_config", {}) if search_app else {} @@ -384,8 +385,8 @@ def mindmap(): @manager.route("/related_questions", methods=["POST"]) # noqa: F821 @login_required @validate_request("question") -def related_questions(): - req = request.json +async def related_questions(): + req = await request.json search_id = req.get("search_id", "") search_config = {} diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index 99f700568..74c72f824 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -14,8 +14,7 @@ # limitations under the License. # -from flask import request -from flask_login import login_required, current_user +from quart import request from api.db.services import duplicate_name from api.db.services.dialog_service import DialogService from common.constants import StatusEnum @@ -26,13 +25,14 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va from common.misc_utils import get_uuid from common.constants import RetCode from api.utils.api_utils import get_json_result +from api.apps import login_required, current_user @manager.route('/set', methods=['POST']) # noqa: F821 @validate_request("prompt_config") @login_required -def set_dialog(): - req = request.json +async def set_dialog(): + req = await request.json dialog_id = req.get("dialog_id", "") is_create = not dialog_id name = req.get("name", "New Dialog") @@ -169,18 +169,19 @@ def list_dialogs(): @manager.route('/next', methods=['POST']) # noqa: F821 @login_required -def list_dialogs_next(): - keywords = request.args.get("keywords", "") - page_number = int(request.args.get("page", 0)) - items_per_page = int(request.args.get("page_size", 0)) - parser_id = request.args.get("parser_id") - orderby = request.args.get("orderby", "create_time") - if request.args.get("desc", "true").lower() == "false": +async def list_dialogs_next(): + args = request.args + keywords = args.get("keywords", "") + page_number = int(args.get("page", 0)) + items_per_page = int(args.get("page_size", 0)) + parser_id = args.get("parser_id") + orderby = args.get("orderby", "create_time") + if args.get("desc", "true").lower() == "false": desc = False else: desc = True - req = request.get_json() + req = await request.get_json() owner_ids = req.get("owner_ids", []) try: if not owner_ids: @@ -207,8 +208,8 @@ def list_dialogs_next(): @manager.route('/rm', methods=['POST']) # noqa: F821 @login_required @validate_request("dialog_ids") -def rm(): - req = request.json +async def rm(): + req = await request.json dialog_list=[] tenants = UserTenantService.query(user_id=current_user.id) try: diff --git a/api/apps/document_app.py b/api/apps/document_app.py index c2e37598e..7eea8a919 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -18,10 +18,8 @@ import os.path import pathlib import re from pathlib import Path - -import flask -from flask import request -from flask_login import current_user, login_required +from quart import request, make_response +from api.apps import current_user, login_required from api.common.check_team_permission import check_kb_team_permission from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX @@ -152,8 +150,8 @@ def web_crawl(): @manager.route("/create", methods=["POST"]) # noqa: F821 @login_required @validate_request("name", "kb_id") -def create(): - req = request.json +async def create(): + req = await request.json kb_id = req["kb_id"] if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) @@ -208,7 +206,7 @@ def create(): @manager.route("/list", methods=["POST"]) # noqa: F821 @login_required -def list_docs(): +async def list_docs(): kb_id = request.args.get("kb_id") if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) @@ -230,7 +228,7 @@ def list_docs(): create_time_from = int(request.args.get("create_time_from", 0)) create_time_to = int(request.args.get("create_time_to", 0)) - req = request.get_json() + req = await request.get_json() run_status = req.get("run_status", []) if run_status: @@ -270,8 +268,8 @@ def list_docs(): @manager.route("/filter", methods=["POST"]) # noqa: F821 @login_required -def get_filter(): - req = request.get_json() +async def get_filter(): + req = await request.get_json() kb_id = req.get("kb_id") if not kb_id: @@ -308,8 +306,8 @@ def get_filter(): @manager.route("/infos", methods=["POST"]) # noqa: F821 @login_required -def docinfos(): - req = request.json +async def docinfos(): + req = await request.json doc_ids = req["doc_ids"] for doc_id in doc_ids: if not DocumentService.accessible(doc_id, current_user.id): @@ -340,8 +338,8 @@ def thumbnails(): @manager.route("/change_status", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_ids", "status") -def change_status(): - req = request.get_json() +async def change_status(): + req = await request.get_json() doc_ids = req.get("doc_ids", []) status = str(req.get("status", "")) @@ -380,8 +378,8 @@ def change_status(): @manager.route("/rm", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id") -def rm(): - req = request.json +async def rm(): + req = await request.json doc_ids = req["doc_id"] if isinstance(doc_ids, str): doc_ids = [doc_ids] @@ -401,8 +399,8 @@ def rm(): @manager.route("/run", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_ids", "run") -def run(): - req = request.json +async def run(): + req = await request.json for doc_id in req["doc_ids"]: if not DocumentService.accessible(doc_id, current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) @@ -448,8 +446,8 @@ def run(): @manager.route("/rename", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id", "name") -def rename(): - req = request.json +async def rename(): + req = await request.json if not DocumentService.accessible(req["doc_id"], current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: @@ -495,14 +493,14 @@ def rename(): @manager.route("/get/", methods=["GET"]) # noqa: F821 # @login_required -def get(doc_id): +async def get(doc_id): try: e, doc = DocumentService.get_by_id(doc_id) if not e: return get_data_error_result(message="Document not found!") b, n = File2DocumentService.get_storage_address(doc_id=doc_id) - response = flask.make_response(settings.STORAGE_IMPL.get(b, n)) + response = await make_response(settings.STORAGE_IMPL.get(b, n)) ext = re.search(r"\.([^.]+)$", doc.name.lower()) ext = ext.group(1) if ext else None @@ -520,9 +518,9 @@ def get(doc_id): @manager.route("/change_parser", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id") -def change_parser(): +async def change_parser(): - req = request.json + req = await request.json if not DocumentService.accessible(req["doc_id"], current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) @@ -572,13 +570,13 @@ def change_parser(): @manager.route("/image/", methods=["GET"]) # noqa: F821 # @login_required -def get_image(image_id): +async def get_image(image_id): try: arr = image_id.split("-") if len(arr) != 2: return get_data_error_result(message="Image not found.") bkt, nm = image_id.split("-") - response = flask.make_response(settings.STORAGE_IMPL.get(bkt, nm)) + response = await make_response(settings.STORAGE_IMPL.get(bkt, nm)) response.headers.set("Content-Type", "image/JPEG") return response except Exception as e: @@ -604,8 +602,8 @@ def upload_and_parse(): @manager.route("/parse", methods=["POST"]) # noqa: F821 @login_required -def parse(): - url = request.json.get("url") if request.json else "" +async def parse(): + url = await request.json.get("url") if await request.json else "" if url: if not is_valid_url(url): return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR) @@ -658,8 +656,8 @@ def parse(): @manager.route("/set_meta", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id", "meta") -def set_meta(): - req = request.json +async def set_meta(): + req = await request.json if not DocumentService.accessible(req["doc_id"], current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: diff --git a/api/apps/file2document_app.py b/api/apps/file2document_app.py index ca1e6b096..1f8921e92 100644 --- a/api/apps/file2document_app.py +++ b/api/apps/file2document_app.py @@ -19,8 +19,8 @@ from pathlib import Path from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService -from flask import request -from flask_login import login_required, current_user +from quart import request +from api.apps import login_required, current_user from api.db.services.knowledgebase_service import KnowledgebaseService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from common.misc_utils import get_uuid @@ -33,8 +33,8 @@ from api.utils.api_utils import get_json_result @manager.route('/convert', methods=['POST']) # noqa: F821 @login_required @validate_request("file_ids", "kb_ids") -def convert(): - req = request.json +async def convert(): + req = await request.json kb_ids = req["kb_ids"] file_ids = req["file_ids"] file2documents = [] @@ -103,8 +103,8 @@ def convert(): @manager.route('/rm', methods=['POST']) # noqa: F821 @login_required @validate_request("file_ids") -def rm(): - req = request.json +async def rm(): + req = await request.json file_ids = req["file_ids"] if not file_ids: return get_json_result( diff --git a/api/apps/file_app.py b/api/apps/file_app.py index 279e32525..caad4a767 100644 --- a/api/apps/file_app.py +++ b/api/apps/file_app.py @@ -17,10 +17,8 @@ import logging import os import pathlib import re - -import flask -from flask import request -from flask_login import login_required, current_user +from quart import request, make_response +from api.apps import login_required, current_user from api.common.check_team_permission import check_file_team_permission from api.db.services.document_service import DocumentService @@ -123,10 +121,10 @@ def upload(): @manager.route('/create', methods=['POST']) # noqa: F821 @login_required @validate_request("name") -def create(): - req = request.json - pf_id = request.json.get("parent_id") - input_file_type = request.json.get("type") +async def create(): + req = await request.json + pf_id = await request.json.get("parent_id") + input_file_type = await request.json.get("type") if not pf_id: root_folder = FileService.get_root_folder(current_user.id) pf_id = root_folder["id"] @@ -238,8 +236,8 @@ def get_all_parent_folders(): @manager.route("/rm", methods=["POST"]) # noqa: F821 @login_required @validate_request("file_ids") -def rm(): - req = request.json +async def rm(): + req = await request.json file_ids = req["file_ids"] def _delete_single_file(file): @@ -299,8 +297,8 @@ def rm(): @manager.route('/rename', methods=['POST']) # noqa: F821 @login_required @validate_request("file_id", "name") -def rename(): - req = request.json +async def rename(): + req = await request.json try: e, file = FileService.get_by_id(req["file_id"]) if not e: @@ -338,7 +336,7 @@ def rename(): @manager.route('/get/', methods=['GET']) # noqa: F821 @login_required -def get(file_id): +async def get(file_id): try: e, file = FileService.get_by_id(file_id) if not e: @@ -351,7 +349,7 @@ def get(file_id): b, n = File2DocumentService.get_storage_address(file_id=file_id) blob = settings.STORAGE_IMPL.get(b, n) - response = flask.make_response(blob) + response = await make_response(blob) ext = re.search(r"\.([^.]+)$", file.name.lower()) ext = ext.group(1) if ext else None if ext: @@ -368,8 +366,8 @@ def get(file_id): @manager.route("/mv", methods=["POST"]) # noqa: F821 @login_required @validate_request("src_file_ids", "dest_file_id") -def move(): - req = request.json +async def move(): + req = await request.json try: file_ids = req["src_file_ids"] dest_parent_id = req["dest_file_id"] diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 4546b2586..b77880626 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -17,11 +17,9 @@ import json import logging import random -from flask import request -from flask_login import login_required, current_user +from quart import request import numpy as np - from api.db.services.connector_service import Connector2KbService from api.db.services.llm_service import LLMBundle from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks @@ -41,12 +39,14 @@ from rag.utils.redis_conn import REDIS_CONN from rag.utils.doc_store_conn import OrderByExpr from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD from common import settings +from api.apps import login_required, current_user + @manager.route('/create', methods=['post']) # noqa: F821 @login_required @validate_request("name") -def create(): - req = request.json +async def create(): + req = await request.json req = KnowledgebaseService.create_with_name( name = req.pop("name", None), tenant_id = current_user.id, @@ -66,8 +66,8 @@ def create(): @login_required @validate_request("kb_id", "name", "description", "parser_id") @not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") -def update(): - req = request.json +async def update(): + req = await request.json if not isinstance(req["name"], str): return get_data_error_result(message="Dataset name must be string.") if req["name"].strip() == "": @@ -165,18 +165,19 @@ def detail(): @manager.route('/list', methods=['POST']) # noqa: F821 @login_required -def list_kbs(): - keywords = request.args.get("keywords", "") - page_number = int(request.args.get("page", 0)) - items_per_page = int(request.args.get("page_size", 0)) - parser_id = request.args.get("parser_id") - orderby = request.args.get("orderby", "create_time") - if request.args.get("desc", "true").lower() == "false": +async def list_kbs(): + args = request.args + keywords = args.get("keywords", "") + page_number = int(args.get("page", 0)) + items_per_page = int(args.get("page_size", 0)) + parser_id = args.get("parser_id") + orderby = args.get("orderby", "create_time") + if args.get("desc", "true").lower() == "false": desc = False else: desc = True - req = request.get_json() + req = await request.get_json() owner_ids = req.get("owner_ids", []) try: if not owner_ids: @@ -198,11 +199,12 @@ def list_kbs(): except Exception as e: return server_error_response(e) + @manager.route('/rm', methods=['post']) # noqa: F821 @login_required @validate_request("kb_id") -def rm(): - req = request.json +async def rm(): + req = await request.json if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id): return get_json_result( data=False, @@ -278,8 +280,8 @@ def list_tags_from_kbs(): @manager.route('//rm_tags', methods=['POST']) # noqa: F821 @login_required -def rm_tags(kb_id): - req = request.json +async def rm_tags(kb_id): + req = await request.json if not KnowledgebaseService.accessible(kb_id, current_user.id): return get_json_result( data=False, @@ -298,8 +300,8 @@ def rm_tags(kb_id): @manager.route('//rename_tag', methods=['POST']) # noqa: F821 @login_required -def rename_tags(kb_id): - req = request.json +async def rename_tags(kb_id): + req = await request.json if not KnowledgebaseService.accessible(kb_id, current_user.id): return get_json_result( data=False, @@ -402,7 +404,7 @@ def get_basic_info(): @manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821 @login_required -def list_pipeline_logs(): +async def list_pipeline_logs(): kb_id = request.args.get("kb_id") if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) @@ -421,7 +423,7 @@ def list_pipeline_logs(): if create_date_to > create_date_from: return get_data_error_result(message="Create data filter is abnormal.") - req = request.get_json() + req = await request.get_json() operation_status = req.get("operation_status", []) if operation_status: @@ -446,7 +448,7 @@ def list_pipeline_logs(): @manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821 @login_required -def list_pipeline_dataset_logs(): +async def list_pipeline_dataset_logs(): kb_id = request.args.get("kb_id") if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) @@ -463,7 +465,7 @@ def list_pipeline_dataset_logs(): if create_date_to > create_date_from: return get_data_error_result(message="Create data filter is abnormal.") - req = request.get_json() + req = await request.get_json() operation_status = req.get("operation_status", []) if operation_status: @@ -480,12 +482,12 @@ def list_pipeline_dataset_logs(): @manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821 @login_required -def delete_pipeline_logs(): +async def delete_pipeline_logs(): kb_id = request.args.get("kb_id") if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - req = request.get_json() + req = await request.get_json() log_ids = req.get("log_ids", []) PipelineOperationLogService.delete_by_ids(log_ids) @@ -509,8 +511,8 @@ def pipeline_log_detail(): @manager.route("/run_graphrag", methods=["POST"]) # noqa: F821 @login_required -def run_graphrag(): - req = request.json +async def run_graphrag(): + req = await request.json kb_id = req.get("kb_id", "") if not kb_id: @@ -578,8 +580,8 @@ def trace_graphrag(): @manager.route("/run_raptor", methods=["POST"]) # noqa: F821 @login_required -def run_raptor(): - req = request.json +async def run_raptor(): + req = await request.json kb_id = req.get("kb_id", "") if not kb_id: @@ -647,8 +649,8 @@ def trace_raptor(): @manager.route("/run_mindmap", methods=["POST"]) # noqa: F821 @login_required -def run_mindmap(): - req = request.json +async def run_mindmap(): + req = await request.json kb_id = req.get("kb_id", "") if not kb_id: @@ -761,7 +763,7 @@ def delete_kb_task(): @manager.route("/check_embedding", methods=["post"]) # noqa: F821 @login_required -def check_embedding(): +async def check_embedding(): def _guess_vec_field(src: dict) -> str | None: for k in src or {}: @@ -847,7 +849,7 @@ def check_embedding(): "content_with_weight": full_doc.get("content_with_weight") or "", }) return out - req = request.json + req = await request.json kb_id = req.get("kb_id", "") embd_id = req.get("embd_id", "") n = int(req.get("check_num", 5)) diff --git a/api/apps/langfuse_app.py b/api/apps/langfuse_app.py index 151c40fcd..ffdc6a5fd 100644 --- a/api/apps/langfuse_app.py +++ b/api/apps/langfuse_app.py @@ -15,8 +15,8 @@ # -from flask import request -from flask_login import current_user, login_required +from quart import request +from api.apps import current_user, login_required from langfuse import Langfuse from api.db.db_models import DB @@ -27,8 +27,8 @@ from api.utils.api_utils import get_error_data_result, get_json_result, server_e @manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821 @login_required @validate_request("secret_key", "public_key", "host") -def set_api_key(): - req = request.get_json() +async def set_api_key(): + req = await request.get_json() secret_key = req.get("secret_key", "") public_key = req.get("public_key", "") host = req.get("host", "") diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index c34d71cc0..29da88c4f 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -16,8 +16,9 @@ import logging import json import os -from flask import request -from flask_login import login_required, current_user +from quart import request + +from api.apps import login_required, current_user from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService from api.db.services.llm_service import LLMService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request @@ -52,8 +53,8 @@ def factories(): @manager.route("/set_api_key", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "api_key") -def set_api_key(): - req = request.json +async def set_api_key(): + req = await request.json # test if api key works chat_passed, embd_passed, rerank_passed = False, False, False factory = req["llm_factory"] @@ -122,8 +123,8 @@ def set_api_key(): @manager.route("/add_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory") -def add_llm(): - req = request.json +async def add_llm(): + req = await request.json factory = req["llm_factory"] api_key = req.get("api_key", "x") llm_name = req.get("llm_name") @@ -142,11 +143,11 @@ def add_llm(): elif factory == "Tencent Hunyuan": req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"]) - return set_api_key() + return await set_api_key() elif factory == "Tencent Cloud": req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]) - return set_api_key() + return await set_api_key() elif factory == "Bedrock": # For Bedrock, due to its special authentication method @@ -267,8 +268,8 @@ def add_llm(): @manager.route("/delete_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "llm_name") -def delete_llm(): - req = request.json +async def delete_llm(): + req = await request.json TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]]) return get_json_result(data=True) @@ -276,8 +277,8 @@ def delete_llm(): @manager.route("/enable_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "llm_name") -def enable_llm(): - req = request.json +async def enable_llm(): + req = await request.json TenantLLMService.filter_update( [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))} ) @@ -287,8 +288,8 @@ def enable_llm(): @manager.route("/delete_factory", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory") -def delete_factory(): - req = request.json +async def delete_factory(): + req = await request.json TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]]) return get_json_result(data=True) diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py index 66d447491..ad78735c5 100644 --- a/api/apps/mcp_server_app.py +++ b/api/apps/mcp_server_app.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from flask import Response, request -from flask_login import current_user, login_required +from quart import Response, request +from api.apps import current_user, login_required from api.db.db_models import MCPServer from api.db.services.mcp_server_service import MCPServerService @@ -30,7 +30,7 @@ from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_ @manager.route("/list", methods=["POST"]) # noqa: F821 @login_required -def list_mcp() -> Response: +async def list_mcp() -> Response: keywords = request.args.get("keywords", "") page_number = int(request.args.get("page", 0)) items_per_page = int(request.args.get("page_size", 0)) @@ -40,7 +40,7 @@ def list_mcp() -> Response: else: desc = True - req = request.get_json() + req = await request.get_json() mcp_ids = req.get("mcp_ids", []) try: servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or [] @@ -72,8 +72,8 @@ def detail() -> Response: @manager.route("/create", methods=["POST"]) # noqa: F821 @login_required @validate_request("name", "url", "server_type") -def create() -> Response: - req = request.get_json() +async def create() -> Response: + req = await request.get_json() server_type = req.get("server_type", "") if server_type not in VALID_MCP_SERVER_TYPES: @@ -127,8 +127,8 @@ def create() -> Response: @manager.route("/update", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_id") -def update() -> Response: - req = request.get_json() +async def update() -> Response: + req = await request.get_json() mcp_id = req.get("mcp_id", "") e, mcp_server = MCPServerService.get_by_id(mcp_id) @@ -183,8 +183,8 @@ def update() -> Response: @manager.route("/rm", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_ids") -def rm() -> Response: - req = request.get_json() +async def rm() -> Response: + req = await request.get_json() mcp_ids = req.get("mcp_ids", []) try: @@ -201,8 +201,8 @@ def rm() -> Response: @manager.route("/import", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcpServers") -def import_multiple() -> Response: - req = request.get_json() +async def import_multiple() -> Response: + req = await request.get_json() servers = req.get("mcpServers", {}) if not servers: return get_data_error_result(message="No MCP servers provided.") @@ -268,8 +268,8 @@ def import_multiple() -> Response: @manager.route("/export", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_ids") -def export_multiple() -> Response: - req = request.get_json() +async def export_multiple() -> Response: + req = await request.get_json() mcp_ids = req.get("mcp_ids", []) if not mcp_ids: @@ -300,8 +300,8 @@ def export_multiple() -> Response: @manager.route("/list_tools", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_ids") -def list_tools() -> Response: - req = request.get_json() +async def list_tools() -> Response: + req = await request.get_json() mcp_ids = req.get("mcp_ids", []) if not mcp_ids: return get_data_error_result(message="No MCP server IDs provided.") @@ -347,8 +347,8 @@ def list_tools() -> Response: @manager.route("/test_tool", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_id", "tool_name", "arguments") -def test_tool() -> Response: - req = request.get_json() +async def test_tool() -> Response: + req = await request.get_json() mcp_id = req.get("mcp_id", "") if not mcp_id: return get_data_error_result(message="No MCP server ID provided.") @@ -380,8 +380,8 @@ def test_tool() -> Response: @manager.route("/cache_tools", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_id", "tools") -def cache_tool() -> Response: - req = request.get_json() +async def cache_tool() -> Response: + req = await request.get_json() mcp_id = req.get("mcp_id", "") if not mcp_id: return get_data_error_result(message="No MCP server ID provided.") @@ -403,8 +403,8 @@ def cache_tool() -> Response: @manager.route("/test_mcp", methods=["POST"]) # noqa: F821 @validate_request("url", "server_type") -def test_mcp() -> Response: - req = request.get_json() +async def test_mcp() -> Response: + req = await request.get_json() url = req.get("url", "") if not url: diff --git a/api/apps/plugin_app.py b/api/apps/plugin_app.py index 9ca04416d..6e7a87690 100644 --- a/api/apps/plugin_app.py +++ b/api/apps/plugin_app.py @@ -15,8 +15,8 @@ # -from flask import Response -from flask_login import login_required +from quart import Response +from api.apps import login_required from api.utils.api_utils import get_json_result from plugin import GlobalPluginManager diff --git a/api/apps/sdk/agents.py b/api/apps/sdk/agents.py index 208b7a1be..034915f1f 100644 --- a/api/apps/sdk/agents.py +++ b/api/apps/sdk/agents.py @@ -27,7 +27,7 @@ from common.constants import RetCode from common.misc_utils import get_uuid from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required from api.utils.api_utils import get_result -from flask import request, Response +from quart import request, Response @manager.route('/agents', methods=['GET']) # noqa: F821 @@ -52,8 +52,8 @@ def list_agents(tenant_id): @manager.route("/agents", methods=["POST"]) # noqa: F821 @token_required -def create_agent(tenant_id: str): - req: dict[str, Any] = cast(dict[str, Any], request.json) +async def create_agent(tenant_id: str): + req: dict[str, Any] = cast(dict[str, Any], await request.json) req["user_id"] = tenant_id if req.get("dsl") is not None: @@ -89,8 +89,8 @@ def create_agent(tenant_id: str): @manager.route("/agents/", methods=["PUT"]) # noqa: F821 @token_required -def update_agent(tenant_id: str, agent_id: str): - req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], request.json).items() if v is not None} +async def update_agent(tenant_id: str, agent_id: str): + req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None} req["user_id"] = tenant_id if req.get("dsl") is not None: @@ -135,8 +135,8 @@ def delete_agent(tenant_id: str, agent_id: str): @manager.route('/webhook/', methods=['POST']) # noqa: F821 @token_required -def webhook(tenant_id: str, agent_id: str): - req = request.json +async def webhook(tenant_id: str, agent_id: str): + req = await request.json if not UserCanvasService.accessible(req["id"], tenant_id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index a3f03b448..a2a70d6e2 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -15,7 +15,7 @@ # import logging -from flask import request +from quart import request from api.db.services.dialog_service import DialogService from api.db.services.knowledgebase_service import KnowledgebaseService @@ -28,8 +28,8 @@ from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_ @manager.route("/chats", methods=["POST"]) # noqa: F821 @token_required -def create(tenant_id): - req = request.json +async def create(tenant_id): + req = await request.json ids = [i for i in req.get("dataset_ids", []) if i] for kb_id in ids: kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) @@ -145,10 +145,10 @@ def create(tenant_id): @manager.route("/chats/", methods=["PUT"]) # noqa: F821 @token_required -def update(tenant_id, chat_id): +async def update(tenant_id, chat_id): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): return get_error_data_result(message="You do not own the chat") - req = request.json + req = await request.json ids = req.get("dataset_ids", []) if "show_quotation" in req: req["do_refer"] = req.pop("show_quotation") @@ -228,10 +228,10 @@ def update(tenant_id, chat_id): @manager.route("/chats", methods=["DELETE"]) # noqa: F821 @token_required -def delete(tenant_id): +async def delete(tenant_id): errors = [] success_count = 0 - req = request.json + req = await request.json if not req: ids = None else: diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 8a315ce69..86a6eb314 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -18,7 +18,7 @@ import logging import os import json -from flask import request +from quart import request from peewee import OperationalError from api.db.db_models import File from api.db.services.document_service import DocumentService @@ -53,7 +53,7 @@ from common import settings @manager.route("/datasets", methods=["POST"]) # noqa: F821 @token_required -def create(tenant_id): +async def create(tenant_id): """ Create a new dataset. --- @@ -115,7 +115,7 @@ def create(tenant_id): # | embedding_model| embd_id | # | chunk_method | parser_id | - req, err = validate_and_parse_json_request(request, CreateDatasetReq) + req, err = await validate_and_parse_json_request(request, CreateDatasetReq) if err is not None: return get_error_argument_result(err) @@ -153,7 +153,7 @@ def create(tenant_id): @manager.route("/datasets", methods=["DELETE"]) # noqa: F821 @token_required -def delete(tenant_id): +async def delete(tenant_id): """ Delete datasets. --- @@ -191,7 +191,7 @@ def delete(tenant_id): schema: type: object """ - req, err = validate_and_parse_json_request(request, DeleteDatasetReq) + req, err = await validate_and_parse_json_request(request, DeleteDatasetReq) if err is not None: return get_error_argument_result(err) @@ -251,7 +251,7 @@ def delete(tenant_id): @manager.route("/datasets/", methods=["PUT"]) # noqa: F821 @token_required -def update(tenant_id, dataset_id): +async def update(tenant_id, dataset_id): """ Update a dataset. --- @@ -317,7 +317,7 @@ def update(tenant_id, dataset_id): # | embedding_model| embd_id | # | chunk_method | parser_id | extras = {"dataset_id": dataset_id} - req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True) + req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True) if err is not None: return get_error_argument_result(err) diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index d2c3485a9..dd28b2644 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -15,7 +15,7 @@ # import logging -from flask import request, jsonify +from quart import request, jsonify from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService @@ -29,7 +29,7 @@ from common import settings @manager.route('/dify/retrieval', methods=['POST']) # noqa: F821 @apikey_required @validate_request("knowledge_id", "query") -def retrieval(tenant_id): +async def retrieval(tenant_id): """ Dify-compatible retrieval API --- @@ -113,7 +113,7 @@ def retrieval(tenant_id): 404: description: Knowledge base or document not found """ - req = request.json + req = await request.json question = req["query"] kb_id = req["knowledge_id"] use_kg = req.get("use_kg", False) diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index b54597f89..6e8969ac0 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -20,7 +20,7 @@ import re from io import BytesIO import xxhash -from flask import request, send_file +from quart import request, send_file from peewee import OperationalError from pydantic import BaseModel, Field, validator @@ -179,7 +179,7 @@ def upload(dataset_id, tenant_id): @manager.route("/datasets//documents/", methods=["PUT"]) # noqa: F821 @token_required -def update_doc(tenant_id, dataset_id, document_id): +async def update_doc(tenant_id, dataset_id, document_id): """ Update a document within a dataset. --- @@ -228,7 +228,7 @@ def update_doc(tenant_id, dataset_id, document_id): schema: type: object """ - req = request.json + req = await request.json if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(message="You don't own the dataset.") e, kb = KnowledgebaseService.get_by_id(dataset_id) @@ -589,7 +589,7 @@ def list_docs(dataset_id, tenant_id): @manager.route("/datasets//documents", methods=["DELETE"]) # noqa: F821 @token_required -def delete(tenant_id, dataset_id): +async def delete(tenant_id, dataset_id): """ Delete documents from a dataset. --- @@ -628,7 +628,7 @@ def delete(tenant_id, dataset_id): """ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") - req = request.json + req = await request.json if not req: doc_ids = None else: @@ -699,7 +699,7 @@ def delete(tenant_id, dataset_id): @manager.route("/datasets//chunks", methods=["POST"]) # noqa: F821 @token_required -def parse(tenant_id, dataset_id): +async def parse(tenant_id, dataset_id): """ Start parsing documents into chunks. --- @@ -738,7 +738,7 @@ def parse(tenant_id, dataset_id): """ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - req = request.json + req = await request.json if not req.get("document_ids"): return get_error_data_result("`document_ids` is required") doc_list = req.get("document_ids") @@ -782,7 +782,7 @@ def parse(tenant_id, dataset_id): @manager.route("/datasets//chunks", methods=["DELETE"]) # noqa: F821 @token_required -def stop_parsing(tenant_id, dataset_id): +async def stop_parsing(tenant_id, dataset_id): """ Stop parsing documents into chunks. --- @@ -821,7 +821,7 @@ def stop_parsing(tenant_id, dataset_id): """ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - req = request.json + req = await request.json if not req.get("document_ids"): return get_error_data_result("`document_ids` is required") @@ -1023,7 +1023,7 @@ def list_chunks(tenant_id, dataset_id, document_id): "/datasets//documents//chunks", methods=["POST"] ) @token_required -def add_chunk(tenant_id, dataset_id, document_id): +async def add_chunk(tenant_id, dataset_id, document_id): """ Add a chunk to a document. --- @@ -1093,7 +1093,7 @@ def add_chunk(tenant_id, dataset_id, document_id): if not doc: return get_error_data_result(message=f"You don't own the document {document_id}.") doc = doc[0] - req = request.json + req = await request.json if not str(req.get("content", "")).strip(): return get_error_data_result(message="`content` is required") if "important_keywords" in req: @@ -1152,7 +1152,7 @@ def add_chunk(tenant_id, dataset_id, document_id): "datasets//documents//chunks", methods=["DELETE"] ) @token_required -def rm_chunk(tenant_id, dataset_id, document_id): +async def rm_chunk(tenant_id, dataset_id, document_id): """ Remove chunks from a document. --- @@ -1199,7 +1199,7 @@ def rm_chunk(tenant_id, dataset_id, document_id): docs = DocumentService.get_by_ids([document_id]) if not docs: raise LookupError(f"Can't find the document with ID {document_id}!") - req = request.json + req = await request.json condition = {"doc_id": document_id} if "chunk_ids" in req: unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk") @@ -1223,7 +1223,7 @@ def rm_chunk(tenant_id, dataset_id, document_id): "/datasets//documents//chunks/", methods=["PUT"] ) @token_required -def update_chunk(tenant_id, dataset_id, document_id, chunk_id): +async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): """ Update a chunk within a document. --- @@ -1285,7 +1285,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): if not doc: return get_error_data_result(message=f"You don't own the document {document_id}.") doc = doc[0] - req = request.json + req = await request.json if "content" in req: content = req["content"] else: @@ -1327,7 +1327,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): @manager.route("/retrieval", methods=["POST"]) # noqa: F821 @token_required -def retrieval_test(tenant_id): +async def retrieval_test(tenant_id): """ Retrieve chunks based on a query. --- @@ -1408,7 +1408,7 @@ def retrieval_test(tenant_id): format: float description: Similarity score. """ - req = request.json + req = await request.json if not req.get("dataset_ids"): return get_error_data_result("`dataset_ids` is required.") kb_ids = req["dataset_ids"] diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py index 733c894c3..f079e8b2d 100644 --- a/api/apps/sdk/files.py +++ b/api/apps/sdk/files.py @@ -17,9 +17,7 @@ import pathlib import re - -import flask -from flask import request +from quart import request, make_response from pathlib import Path from api.db.services.document_service import DocumentService @@ -151,7 +149,7 @@ def upload(tenant_id): @manager.route('/file/create', methods=['POST']) # noqa: F821 @token_required -def create(tenant_id): +async def create(tenant_id): """ Create a new file or folder. --- @@ -193,9 +191,9 @@ def create(tenant_id): type: type: string """ - req = request.json - pf_id = request.json.get("parent_id") - input_file_type = request.json.get("type") + req = await request.json + pf_id = await request.json.get("parent_id") + input_file_type = await request.json.get("type") if not pf_id: root_folder = FileService.get_root_folder(tenant_id) pf_id = root_folder["id"] @@ -450,7 +448,7 @@ def get_all_parent_folders(tenant_id): @manager.route('/file/rm', methods=['POST']) # noqa: F821 @token_required -def rm(tenant_id): +async def rm(tenant_id): """ Delete one or multiple files/folders. --- @@ -481,7 +479,7 @@ def rm(tenant_id): type: boolean example: true """ - req = request.json + req = await request.json file_ids = req["file_ids"] try: for file_id in file_ids: @@ -524,7 +522,7 @@ def rm(tenant_id): @manager.route('/file/rename', methods=['POST']) # noqa: F821 @token_required -def rename(tenant_id): +async def rename(tenant_id): """ Rename a file. --- @@ -556,7 +554,7 @@ def rename(tenant_id): type: boolean example: true """ - req = request.json + req = await request.json try: e, file = FileService.get_by_id(req["file_id"]) if not e: @@ -585,7 +583,7 @@ def rename(tenant_id): @manager.route('/file/get/', methods=['GET']) # noqa: F821 @token_required -def get(tenant_id, file_id): +async def get(tenant_id, file_id): """ Download a file. --- @@ -619,7 +617,7 @@ def get(tenant_id, file_id): b, n = File2DocumentService.get_storage_address(file_id=file_id) blob = settings.STORAGE_IMPL.get(b, n) - response = flask.make_response(blob) + response = await make_response(blob) ext = re.search(r"\.([^.]+)$", file.name) if ext: if file.type == FileType.VISUAL.value: @@ -633,7 +631,7 @@ def get(tenant_id, file_id): @manager.route('/file/mv', methods=['POST']) # noqa: F821 @token_required -def move(tenant_id): +async def move(tenant_id): """ Move one or multiple files to another folder. --- @@ -667,7 +665,7 @@ def move(tenant_id): type: boolean example: true """ - req = request.json + req = await request.json try: file_ids = req["src_file_ids"] parent_id = req["dest_file_id"] @@ -693,8 +691,8 @@ def move(tenant_id): @manager.route('/file/convert', methods=['POST']) # noqa: F821 @token_required -def convert(tenant_id): - req = request.json +async def convert(tenant_id): + req = await request.json kb_ids = req["kb_ids"] file_ids = req["file_ids"] file2documents = [] diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 4edb2bb6b..98151a5fe 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -18,7 +18,7 @@ import re import time import tiktoken -from flask import Response, jsonify, request +from quart import Response, jsonify, request from agent.canvas import Canvas from api.db.db_models import APIToken @@ -44,8 +44,8 @@ from common import settings @manager.route("/chats//sessions", methods=["POST"]) # noqa: F821 @token_required -def create(tenant_id, chat_id): - req = request.json +async def create(tenant_id, chat_id): + req = await request.json req["dialog_id"] = chat_id dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) if not dia: @@ -97,8 +97,8 @@ def create_agent_session(tenant_id, agent_id): @manager.route("/chats//sessions/", methods=["PUT"]) # noqa: F821 @token_required -def update(tenant_id, chat_id, session_id): - req = request.json +async def update(tenant_id, chat_id, session_id): + req = await request.json req["dialog_id"] = chat_id conv_id = session_id conv = ConversationService.query(id=conv_id, dialog_id=chat_id) @@ -119,8 +119,8 @@ def update(tenant_id, chat_id, session_id): @manager.route("/chats//completions", methods=["POST"]) # noqa: F821 @token_required -def chat_completion(tenant_id, chat_id): - req = request.json +async def chat_completion(tenant_id, chat_id): + req = await request.json if not req: req = {"question": ""} if not req.get("session_id"): @@ -149,7 +149,7 @@ def chat_completion(tenant_id, chat_id): @manager.route("/chats_openai//chat/completions", methods=["POST"]) # noqa: F821 @validate_request("model", "messages") # noqa: F821 @token_required -def chat_completion_openai_like(tenant_id, chat_id): +async def chat_completion_openai_like(tenant_id, chat_id): """ OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint. @@ -206,7 +206,7 @@ def chat_completion_openai_like(tenant_id, chat_id): if reference: print(completion.choices[0].message.reference) """ - req = request.get_json() + req = await request.get_json() need_reference = bool(req.get("reference", False)) @@ -383,8 +383,8 @@ def chat_completion_openai_like(tenant_id, chat_id): @manager.route("/agents_openai//chat/completions", methods=["POST"]) # noqa: F821 @validate_request("model", "messages") # noqa: F821 @token_required -def agents_completion_openai_compatibility(tenant_id, agent_id): - req = request.json +async def agents_completion_openai_compatibility(tenant_id, agent_id): + req = await request.json tiktokenenc = tiktoken.get_encoding("cl100k_base") messages = req.get("messages", []) if not messages: @@ -443,8 +443,8 @@ def agents_completion_openai_compatibility(tenant_id, agent_id): @manager.route("/agents//completions", methods=["POST"]) # noqa: F821 @token_required -def agent_completions(tenant_id, agent_id): - req = request.json +async def agent_completions(tenant_id, agent_id): + req = await request.json if req.get("stream", True): @@ -610,13 +610,13 @@ def list_agent_session(tenant_id, agent_id): @manager.route("/chats//sessions", methods=["DELETE"]) # noqa: F821 @token_required -def delete(tenant_id, chat_id): +async def delete(tenant_id, chat_id): if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): return get_error_data_result(message="You don't own the chat") errors = [] success_count = 0 - req = request.json + req = await request.json convs = ConversationService.query(dialog_id=chat_id) if not req: ids = None @@ -661,10 +661,10 @@ def delete(tenant_id, chat_id): @manager.route("/agents//sessions", methods=["DELETE"]) # noqa: F821 @token_required -def delete_agent_session(tenant_id, agent_id): +async def delete_agent_session(tenant_id, agent_id): errors = [] success_count = 0 - req = request.json + req = await request.json cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id) if not cvs: return get_error_data_result(f"You don't own the agent {agent_id}") @@ -716,8 +716,8 @@ def delete_agent_session(tenant_id, agent_id): @manager.route("/sessions/ask", methods=["POST"]) # noqa: F821 @token_required -def ask_about(tenant_id): - req = request.json +async def ask_about(tenant_id): + req = await request.json if not req.get("question"): return get_error_data_result("`question` is required.") if not req.get("dataset_ids"): @@ -755,8 +755,8 @@ def ask_about(tenant_id): @manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821 @token_required -def related_questions(tenant_id): - req = request.json +async def related_questions(tenant_id): + req = await request.json if not req.get("question"): return get_error_data_result("`question` is required.") question = req["question"] @@ -806,8 +806,8 @@ Related search terms: @manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821 -def chatbot_completions(dialog_id): - req = request.json +async def chatbot_completions(dialog_id): + req = await request.json token = request.headers.get("Authorization").split() if len(token) != 2: @@ -856,8 +856,8 @@ def chatbots_inputs(dialog_id): @manager.route("/agentbots//completions", methods=["POST"]) # noqa: F821 -def agent_bot_completions(agent_id): - req = request.json +async def agent_bot_completions(agent_id): + req = await request.json token = request.headers.get("Authorization").split() if len(token) != 2: @@ -901,7 +901,7 @@ def begin_inputs(agent_id): @manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821 @validate_request("question", "kb_ids") -def ask_about_embedded(): +async def ask_about_embedded(): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -910,7 +910,7 @@ def ask_about_embedded(): if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - req = request.json + req = await request.json uid = objs[0].tenant_id search_id = req.get("search_id", "") @@ -940,7 +940,7 @@ def ask_about_embedded(): @manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821 @validate_request("kb_id", "question") -def retrieval_test_embedded(): +async def retrieval_test_embedded(): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -949,7 +949,7 @@ def retrieval_test_embedded(): if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - req = request.json + req = await request.json page = int(req.get("page", 1)) size = int(req.get("size", 30)) question = req["question"] @@ -1039,7 +1039,7 @@ def retrieval_test_embedded(): @manager.route("/searchbots/related_questions", methods=["POST"]) # noqa: F821 @validate_request("question") -def related_questions_embedded(): +async def related_questions_embedded(): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -1048,7 +1048,7 @@ def related_questions_embedded(): if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - req = request.json + req = await request.json tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") @@ -1115,7 +1115,7 @@ def detail_share_embedded(): @manager.route("/searchbots/mindmap", methods=["POST"]) # noqa: F821 @validate_request("question", "kb_ids") -def mindmap(): +async def mindmap(): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -1125,7 +1125,7 @@ def mindmap(): return get_error_data_result(message='Authentication error: API key is invalid!"') tenant_id = objs[0].tenant_id - req = request.json + req = await request.json search_id = req.get("search_id", "") search_app = SearchService.get_detail(search_id) if search_id else {} diff --git a/api/apps/search_app.py b/api/apps/search_app.py index 799223371..d350b93c3 100644 --- a/api/apps/search_app.py +++ b/api/apps/search_app.py @@ -14,8 +14,8 @@ # limitations under the License. # -from flask import request -from flask_login import current_user, login_required +from quart import request +from api.apps import current_user, login_required from api.constants import DATASET_NAME_LIMIT from api.db.db_models import DB @@ -30,8 +30,8 @@ from api.utils.api_utils import get_data_error_result, get_json_result, not_allo @manager.route("/create", methods=["post"]) # noqa: F821 @login_required @validate_request("name") -def create(): - req = request.get_json() +async def create(): + req = await request.get_json() search_name = req["name"] description = req.get("description", "") if not isinstance(search_name, str): @@ -65,8 +65,8 @@ def create(): @login_required @validate_request("search_id", "name", "search_config", "tenant_id") @not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") -def update(): - req = request.get_json() +async def update(): + req = await request.get_json() if not isinstance(req["name"], str): return get_data_error_result(message="Search name must be string.") if req["name"].strip() == "": @@ -140,7 +140,7 @@ def detail(): @manager.route("/list", methods=["POST"]) # noqa: F821 @login_required -def list_search_app(): +async def list_search_app(): keywords = request.args.get("keywords", "") page_number = int(request.args.get("page", 0)) items_per_page = int(request.args.get("page_size", 0)) @@ -150,7 +150,7 @@ def list_search_app(): else: desc = True - req = request.get_json() + req = await request.get_json() owner_ids = req.get("owner_ids", []) try: if not owner_ids: @@ -173,8 +173,8 @@ def list_search_app(): @manager.route("/rm", methods=["post"]) # noqa: F821 @login_required @validate_request("search_id") -def rm(): - req = request.get_json() +async def rm(): + req = await request.get_json() search_id = req["search_id"] if not SearchService.accessible4deletion(search_id, current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) diff --git a/api/apps/system_app.py b/api/apps/system_app.py index b63f80a6a..7e646927e 100644 --- a/api/apps/system_app.py +++ b/api/apps/system_app.py @@ -17,7 +17,7 @@ import logging from datetime import datetime import json -from flask_login import login_required, current_user +from api.apps import login_required, current_user from api.db.db_models import APIToken from api.db.services.api_service import APITokenService @@ -34,7 +34,7 @@ from common.time_utils import current_timestamp, datetime_format from timeit import default_timer as timer from rag.utils.redis_conn import REDIS_CONN -from flask import jsonify +from quart import jsonify from api.utils.health_utils import run_health_checks from common import settings diff --git a/api/apps/tenant_app.py b/api/apps/tenant_app.py index abb096faa..380838bcd 100644 --- a/api/apps/tenant_app.py +++ b/api/apps/tenant_app.py @@ -14,10 +14,7 @@ # limitations under the License. # -from flask import request -from flask_login import login_required, current_user - -from api.apps import smtp_mail_server +from quart import request from api.db import UserTenantRole from api.db.db_models import UserTenant from api.db.services.user_service import UserTenantService, UserService @@ -28,6 +25,7 @@ from common.time_utils import delta_seconds from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result from api.utils.web_utils import send_invite_email from common import settings +from api.apps import smtp_mail_server, login_required, current_user @manager.route("//user/list", methods=["GET"]) # noqa: F821 @@ -51,14 +49,14 @@ def user_list(tenant_id): @manager.route('//user', methods=['POST']) # noqa: F821 @login_required @validate_request("email") -def create(tenant_id): +async def create(tenant_id): if current_user.id != tenant_id: return get_json_result( data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR) - req = request.json + req = await request.json invite_user_email = req["email"] invite_users = UserService.query(email=invite_user_email) if not invite_users: diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 06130cce7..c051f7ccc 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -22,8 +22,7 @@ import secrets import time from datetime import datetime -from flask import redirect, request, session, make_response -from flask_login import current_user, login_required, login_user, logout_user +from quart import redirect, request, session, make_response from werkzeug.security import check_password_hash, generate_password_hash from api.apps.auth import get_auth_client @@ -45,7 +44,7 @@ from api.utils.api_utils import ( ) from api.utils.crypt import decrypt from rag.utils.redis_conn import REDIS_CONN -from api.apps import smtp_mail_server +from api.apps import smtp_mail_server, login_required, current_user, login_user, logout_user from api.utils.web_utils import ( send_email_html, OTP_LENGTH, @@ -61,7 +60,7 @@ from common import settings @manager.route("/login", methods=["POST", "GET"]) # noqa: F821 -def login(): +async def login(): """ User login endpoint. --- @@ -91,10 +90,11 @@ def login(): schema: type: object """ - if not request.json: + json_body = await request.json + if not json_body: return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!") - email = request.json.get("email", "") + email = json_body.get("email", "") users = UserService.query(email=email) if not users: return get_json_result( @@ -103,7 +103,7 @@ def login(): message=f"Email: {email} is not registered!", ) - password = request.json.get("password") + password = json_body.get("password") try: password = decrypt(password) except BaseException: @@ -125,7 +125,7 @@ def login(): user.update_date = (datetime_format(datetime.now()),) user.save() msg = "Welcome back!" - return construct_response(data=response_data, auth=user.get_id(), message=msg) + return await construct_response(data=response_data, auth=user.get_id(), message=msg) else: return get_json_result( data=False, @@ -501,7 +501,7 @@ def log_out(): @manager.route("/setting", methods=["POST"]) # noqa: F821 @login_required -def setting_user(): +async def setting_user(): """ Update user settings. --- @@ -530,7 +530,7 @@ def setting_user(): type: object """ update_dict = {} - request_data = request.json + request_data = await request.json if request_data.get("password"): new_password = request_data.get("new_password") if not check_password_hash(current_user.password, decrypt(request_data["password"])): @@ -660,7 +660,7 @@ def user_register(user_id, user): @manager.route("/register", methods=["POST"]) # noqa: F821 @validate_request("nickname", "email", "password") -def user_add(): +async def user_add(): """ Register a new user. --- @@ -697,7 +697,7 @@ def user_add(): code=RetCode.OPERATING_ERROR, ) - req = request.json + req = await request.json email_address = req["email"] # Validate the email address @@ -793,7 +793,7 @@ def tenant_info(): @manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821 @login_required @validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id") -def set_tenant_info(): +async def set_tenant_info(): """ Update tenant information. --- @@ -830,7 +830,7 @@ def set_tenant_info(): schema: type: object """ - req = request.json + req = await request.json try: tid = req.pop("tenant_id") TenantService.update_by_id(tid, req) @@ -840,7 +840,7 @@ def set_tenant_info(): @manager.route("/forget/captcha", methods=["GET"]) # noqa: F821 -def forget_get_captcha(): +async def forget_get_captcha(): """ GET /forget/captcha?email= - Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = OTP_TTL_SECONDS. @@ -862,19 +862,19 @@ def forget_get_captcha(): from captcha.image import ImageCaptcha image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70]) img_bytes = image.generate(captcha_text).read() - response = make_response(img_bytes) + response = await make_response(img_bytes) response.headers.set("Content-Type", "image/JPEG") return response @manager.route("/forget/otp", methods=["POST"]) # noqa: F821 -def forget_send_otp(): +async def forget_send_otp(): """ POST /forget/otp - Verify the image captcha stored at captcha:{email} (case-insensitive). - On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email. """ - req = request.get_json() + req = await request.get_json() email = req.get("email") or "" captcha = (req.get("captcha") or "").strip() @@ -935,12 +935,12 @@ def forget_send_otp(): @manager.route("/forget", methods=["POST"]) # noqa: F821 -def forget(): +async def forget(): """ POST: Verify email + OTP and reset password, then log the user in. Request JSON: { email, otp, new_password, confirm_new_password } """ - req = request.get_json() + req = await request.get_json() email = req.get("email") or "" otp = (req.get("otp") or "").strip() new_pwd = req.get("new_password") diff --git a/api/db/db_models.py b/api/db/db_models.py index 68bf37ce4..05969a433 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -25,7 +25,7 @@ from datetime import datetime, timezone from enum import Enum from functools import wraps -from flask_login import UserMixin +from quart_auth import AuthUser from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate @@ -594,7 +594,7 @@ def fill_db_model_object(model_object, human_model_dict): return model_object -class User(DataBaseModel, UserMixin): +class User(DataBaseModel, AuthUser): id = CharField(max_length=32, primary_key=True) access_token = CharField(max_length=255, null=True, index=True) nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True) diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 2cf4931d0..8d8c7866d 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -18,7 +18,7 @@ import re from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from flask_login import current_user +from api.apps import current_user from peewee import fn from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileType diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 868e054ae..4accbb8a2 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -31,7 +31,6 @@ import traceback import threading import uuid -from werkzeug.serving import run_simple from api.apps import app, smtp_mail_server from api.db.runtime_config import RuntimeConfig from api.db.services.document_service import DocumentService @@ -153,14 +152,7 @@ if __name__ == '__main__': # start http server try: logging.info("RAGFlow HTTP server start...") - run_simple( - hostname=settings.HOST_IP, - port=settings.HOST_PORT, - application=app, - threaded=True, - use_reloader=RuntimeConfig.DEBUG, - use_debugger=RuntimeConfig.DEBUG, - ) + app.run(host=settings.HOST_IP, port=settings.HOST_PORT) except Exception: traceback.print_exc() stop_event.set() diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 4cace9eca..bd35a6c69 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -15,6 +15,7 @@ # import functools +import inspect import json import logging import os @@ -24,14 +25,12 @@ from functools import wraps import requests import trio -from flask import ( +from quart import ( Response, jsonify, + request ) -from flask_login import current_user -from flask import ( - request as flask_request, -) + from peewee import OperationalError from common.constants import ActiveEnum @@ -105,31 +104,37 @@ def server_error_response(e): def validate_request(*args, **kwargs): + def process_args(input_arguments): + no_arguments = [] + error_arguments = [] + for arg in args: + if arg not in input_arguments: + no_arguments.append(arg) + for k, v in kwargs.items(): + config_value = input_arguments.get(k, None) + if config_value is None: + no_arguments.append(k) + elif isinstance(v, (tuple, list)): + if config_value not in v: + error_arguments.append((k, set(v))) + elif config_value != v: + error_arguments.append((k, v)) + if no_arguments or error_arguments: + error_string = "" + if no_arguments: + error_string += "required argument are missing: {}; ".format(",".join(no_arguments)) + if error_arguments: + error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) + return error_string + def wrapper(func): @wraps(func) - def decorated_function(*_args, **_kwargs): - input_arguments = flask_request.json or flask_request.form.to_dict() - no_arguments = [] - error_arguments = [] - for arg in args: - if arg not in input_arguments: - no_arguments.append(arg) - for k, v in kwargs.items(): - config_value = input_arguments.get(k, None) - if config_value is None: - no_arguments.append(k) - elif isinstance(v, (tuple, list)): - if config_value not in v: - error_arguments.append((k, set(v))) - elif config_value != v: - error_arguments.append((k, v)) - if no_arguments or error_arguments: - error_string = "" - if no_arguments: - error_string += "required argument are missing: {}; ".format(",".join(no_arguments)) - if error_arguments: - error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) - return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string) + async def decorated_function(*_args, **_kwargs): + errs = process_args(await request.json or (await request.form).to_dict()) + if errs: + return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs) + if inspect.iscoroutinefunction(func): + return await func(*_args, **_kwargs) return func(*_args, **_kwargs) return decorated_function @@ -138,30 +143,34 @@ def validate_request(*args, **kwargs): def not_allowed_parameters(*params): - def decorator(f): - def wrapper(*args, **kwargs): - input_arguments = flask_request.json or flask_request.form.to_dict() + def decorator(func): + async def wrapper(*args, **kwargs): + input_arguments = await request.json or (await request.form).to_dict() for param in params: if param in input_arguments: return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed") - return f(*args, **kwargs) - + if inspect.iscoroutinefunction(func): + return await func(*args, **kwargs) + return func(*args, **kwargs) return wrapper return decorator -def active_required(f): - @wraps(f) - def wrapper(*args, **kwargs): +def active_required(func): + @wraps(func) + async def wrapper(*args, **kwargs): from api.db.services import UserService + from api.apps import current_user user_id = current_user.id usr = UserService.filter_by_id(user_id) # check is_active if not usr or not usr.is_active == ActiveEnum.ACTIVE.value: return get_json_result(code=RetCode.FORBIDDEN, message="User isn't active, please activate first.") - return f(*args, **kwargs) + if inspect.iscoroutinefunction(func): + return await func(*args, **kwargs) + return func(*args, **kwargs) return wrapper @@ -173,12 +182,15 @@ def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=Non def apikey_required(func): @wraps(func) - def decorated_function(*args, **kwargs): - token = flask_request.headers.get("Authorization").split()[1] + async def decorated_function(*args, **kwargs): + token = request.headers.get("Authorization").split()[1] objs = APIToken.query(token=token) if not objs: return build_error_result(message="API-KEY is invalid!", code=RetCode.FORBIDDEN) kwargs["tenant_id"] = objs[0].tenant_id + if inspect.iscoroutinefunction(func): + return await func(*args, **kwargs) + return func(*args, **kwargs) return decorated_function @@ -200,10 +212,10 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da def token_required(func): @wraps(func) - def decorated_function(*args, **kwargs): + async def decorated_function(*args, **kwargs): if os.environ.get("DISABLE_SDK"): return get_json_result(data=False, message="`Authorization` can't be empty") - authorization_str = flask_request.headers.get("Authorization") + authorization_str = request.headers.get("Authorization") if not authorization_str: return get_json_result(data=False, message="`Authorization` can't be empty") authorization_list = authorization_str.split() @@ -214,6 +226,9 @@ def token_required(func): if not objs: return get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR) kwargs["tenant_id"] = objs[0].tenant_id + + if inspect.iscoroutinefunction(func): + return await func(*args, **kwargs) return func(*args, **kwargs) return decorated_function diff --git a/api/utils/commands.py b/api/utils/commands.py index a1a8d025a..a3df7b507 100644 --- a/api/utils/commands.py +++ b/api/utils/commands.py @@ -18,7 +18,7 @@ import base64 import click import re -from flask import Flask +from quart import Quart from werkzeug.security import generate_password_hash from api.db.services import UserService @@ -73,6 +73,7 @@ def reset_email(email, new_email, email_confirm): UserService.update_user(user[0].id,user_dict) click.echo(click.style('Congratulations!, email has been reset.', fg='green')) -def register_commands(app: Flask): + +def register_commands(app: Quart): app.cli.add_command(reset_password) app.cli.add_command(reset_email) diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index caf3f0924..630b64feb 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -17,7 +17,7 @@ from collections import Counter from typing import Annotated, Any, Literal from uuid import UUID -from flask import Request +from quart import Request from pydantic import ( BaseModel, ConfigDict, @@ -32,7 +32,7 @@ from werkzeug.exceptions import BadRequest, UnsupportedMediaType from api.constants import DATASET_NAME_LIMIT -def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]: +async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]: """ Validates and parses JSON requests through a multi-stage validation pipeline. @@ -81,7 +81,7 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel] from the final output after validation """ try: - payload = request.get_json() or {} + payload = await request.get_json() or {} except UnsupportedMediaType: return None, f"Unsupported content type: Expected application/json, got {request.content_type}" except BadRequest: diff --git a/api/utils/web_utils.py b/api/utils/web_utils.py index e0e47f472..0f2484c56 100644 --- a/api/utils/web_utils.py +++ b/api/utils/web_utils.py @@ -23,7 +23,7 @@ from urllib.parse import urlparse from api.apps import smtp_mail_server from flask_mail import Message -from flask import render_template_string +from quart import render_template_string from api.utils.email_templates import EMAIL_TEMPLATES from selenium import webdriver from selenium.common.exceptions import TimeoutException diff --git a/common/connection_utils.py b/common/connection_utils.py index 618584ae9..0f765d57e 100644 --- a/common/connection_utils.py +++ b/common/connection_utils.py @@ -21,7 +21,7 @@ from typing import Any, Callable, Coroutine, Optional, Type, Union import asyncio import trio from functools import wraps -from flask import make_response, jsonify +from quart import make_response, jsonify from common.constants import RetCode TimeoutException = Union[Type[BaseException], BaseException] @@ -103,7 +103,7 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: return decorator -def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None): +async def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None): result_dict = {"code": code, "message": message, "data": data} response_dict = {} for key, value in result_dict.items(): @@ -111,7 +111,7 @@ def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth= continue else: response_dict[key] = value - response = make_response(jsonify(response_dict)) + response = await make_response(jsonify(response_dict)) if auth: response.headers["Authorization"] = auth response.headers["Access-Control-Allow-Origin"] = "*"