From be01ad5519f7c69d04f45876ee00acd8e5d69149 Mon Sep 17 00:00:00 2001 From: kevinhu Date: Wed, 17 Jan 2024 01:41:46 +0000 Subject: [PATCH] change name to web_server --- web_server/__init__.py | 0 web_server/apps/__init__.py | 147 ----- web_server/apps/document_app.py | 280 -------- web_server/apps/kb_app.py | 114 ---- web_server/apps/llm_app.py | 95 --- web_server/apps/user_app.py | 254 ------- web_server/db/__init__.py | 54 -- web_server/db/db_models.py | 619 ------------------ web_server/db/db_services.py | 157 ----- web_server/db/db_utils.py | 131 ---- web_server/db/init_data.py | 141 ---- web_server/db/operatioins.py | 21 - web_server/db/reload_config_base.py | 27 - web_server/db/runtime_config.py | 54 -- web_server/db/service_registry.py | 164 ----- web_server/db/services/__init__.py | 38 -- web_server/db/services/common_service.py | 153 ----- web_server/db/services/dialog_service.py | 35 - web_server/db/services/document_service.py | 89 --- web_server/db/services/kb_service.py | 70 -- .../db/services/knowledgebase_service.py | 31 - web_server/db/services/llm_service.py | 53 -- web_server/db/services/user_service.py | 105 --- web_server/errors/__init__.py | 10 - web_server/errors/error_services.py | 13 - web_server/errors/general_error.py | 21 - .../2029240f6d1128be89ddc32729463129 | Bin 9 -> 0 bytes web_server/hook/__init__.py | 57 -- web_server/hook/api/client_authentication.py | 29 - web_server/hook/api/permission.py | 25 - web_server/hook/api/site_authentication.py | 49 -- web_server/hook/common/parameters.py | 56 -- web_server/ragflow_server.py | 80 --- web_server/settings.py | 156 ----- web_server/utils/__init__.py | 321 --------- web_server/utils/api_utils.py | 212 ------ web_server/utils/file_utils.py | 153 ----- web_server/utils/log_utils.py | 299 --------- web_server/utils/t_crypt.py | 18 - web_server/versions.py | 30 - 40 files changed, 4361 deletions(-) delete mode 100644 web_server/__init__.py delete mode 100644 web_server/apps/__init__.py delete mode 100644 web_server/apps/document_app.py delete mode 100644 web_server/apps/kb_app.py delete mode 100644 web_server/apps/llm_app.py delete mode 100644 web_server/apps/user_app.py delete mode 100644 web_server/db/__init__.py delete mode 100644 web_server/db/db_models.py delete mode 100644 web_server/db/db_services.py delete mode 100644 web_server/db/db_utils.py delete mode 100644 web_server/db/init_data.py delete mode 100644 web_server/db/operatioins.py delete mode 100644 web_server/db/reload_config_base.py delete mode 100644 web_server/db/runtime_config.py delete mode 100644 web_server/db/service_registry.py delete mode 100644 web_server/db/services/__init__.py delete mode 100644 web_server/db/services/common_service.py delete mode 100644 web_server/db/services/dialog_service.py delete mode 100644 web_server/db/services/document_service.py delete mode 100644 web_server/db/services/kb_service.py delete mode 100644 web_server/db/services/knowledgebase_service.py delete mode 100644 web_server/db/services/llm_service.py delete mode 100644 web_server/db/services/user_service.py delete mode 100644 web_server/errors/__init__.py delete mode 100644 web_server/errors/error_services.py delete mode 100644 web_server/errors/general_error.py delete mode 100644 web_server/flask_session/2029240f6d1128be89ddc32729463129 delete mode 100644 web_server/hook/__init__.py delete mode 100644 web_server/hook/api/client_authentication.py delete mode 100644 web_server/hook/api/permission.py delete mode 100644 web_server/hook/api/site_authentication.py delete mode 100644 web_server/hook/common/parameters.py delete mode 100644 web_server/ragflow_server.py delete mode 100644 web_server/settings.py delete mode 100644 web_server/utils/__init__.py delete mode 100644 web_server/utils/api_utils.py delete mode 100644 web_server/utils/file_utils.py delete mode 100644 web_server/utils/log_utils.py delete mode 100644 web_server/utils/t_crypt.py delete mode 100644 web_server/versions.py diff --git a/web_server/__init__.py b/web_server/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/web_server/apps/__init__.py b/web_server/apps/__init__.py deleted file mode 100644 index 6a7cc9d30..000000000 --- a/web_server/apps/__init__.py +++ /dev/null @@ -1,147 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import logging -import sys -from importlib.util import module_from_spec, spec_from_file_location -from pathlib import Path -from flask import Blueprint, Flask, request -from werkzeug.wrappers.request import Request -from flask_cors import CORS - -from web_server.db import StatusEnum -from web_server.db.services import UserService -from web_server.utils import CustomJSONEncoder - -from flask_session import Session -from flask_login import LoginManager -from web_server.settings import RetCode, SECRET_KEY, stat_logger -from web_server.hook import HookManager -from web_server.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters -from web_server.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger -from web_server.utils.api_utils import get_json_result, server_error_response -from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer - -__all__ = ['app'] - - -logger = logging.getLogger('flask.app') -for h in access_logger.handlers: - logger.addHandler(h) - -Request.json = property(lambda self: self.get_json(force=True, silent=True)) - -app = Flask(__name__) -CORS(app, supports_credentials=True,max_age = 2592000) -app.url_map.strict_slashes = False -app.json_encoder = CustomJSONEncoder -app.errorhandler(Exception)(server_error_response) - - -## convince for dev and debug -#app.config["LOGIN_DISABLED"] = True -app.config["SESSION_PERMANENT"] = False -app.config["SESSION_TYPE"] = "filesystem" -app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024 - -Session(app) -login_manager = LoginManager() -login_manager.init_app(app) - - - -def search_pages_path(pages_dir): - return [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')] - - -def register_page(page_path): - page_name = page_path.stem.rstrip('_app') - module_name = '.'.join(page_path.parts[page_path.parts.index('web_server'):-1] + (page_name, )) - - spec = spec_from_file_location(module_name, page_path) - page = module_from_spec(spec) - page.app = app - page.manager = Blueprint(page_name, module_name) - sys.modules[module_name] = page - spec.loader.exec_module(page) - - page_name = getattr(page, 'page_name', page_name) - url_prefix = f'/{API_VERSION}/{page_name}' - - app.register_blueprint(page.manager, url_prefix=url_prefix) - return url_prefix - - -pages_dir = [ - Path(__file__).parent, - Path(__file__).parent.parent / 'web_server' / 'apps', -] - -client_urls_prefix = [ - register_page(path) - for dir in pages_dir - for path in search_pages_path(dir) -] - - -def client_authentication_before_request(): - result = HookManager.client_authentication(ClientAuthenticationParameters( - request.full_path, request.headers, - request.form, request.data, request.json, - )) - - if result.code != RetCode.SUCCESS: - return get_json_result(result.code, result.message) - - -def site_authentication_before_request(): - for url_prefix in client_urls_prefix: - if request.path.startswith(url_prefix): - return - - result = HookManager.site_authentication(AuthenticationParameters( - request.headers.get('site_signature'), - request.json, - )) - - if result.code != RetCode.SUCCESS: - return get_json_result(result.code, result.message) - - -@app.before_request -def authentication_before_request(): - if CLIENT_AUTHENTICATION: - return client_authentication_before_request() - - if SITE_AUTHENTICATION: - return site_authentication_before_request() - -@login_manager.request_loader -def load_user(web_request): - jwt = Serializer(secret_key=SECRET_KEY) - authorization = web_request.headers.get("Authorization") - if authorization: - try: - access_token = str(jwt.loads(authorization)) - user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value) - if user: - return user[0] - else: - return None - except Exception as e: - stat_logger.exception(e) - return None - else: - return None \ No newline at end of file diff --git a/web_server/apps/document_app.py b/web_server/apps/document_app.py deleted file mode 100644 index 9be9cfde9..000000000 --- a/web_server/apps/document_app.py +++ /dev/null @@ -1,280 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import base64 -import pathlib - -from elasticsearch_dsl import Q -from flask import request -from flask_login import login_required, current_user - -from rag.nlp import search -from rag.utils import ELASTICSEARCH -from web_server.db.services import duplicate_name -from web_server.db.services.kb_service import KnowledgebaseService -from web_server.db.services.user_service import TenantService -from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request -from web_server.utils import get_uuid, get_format_time -from web_server.db import StatusEnum, FileType -from web_server.db.services.document_service import DocumentService -from web_server.settings import RetCode -from web_server.utils.api_utils import get_json_result -from rag.utils.minio_conn import MINIO -from web_server.utils.file_utils import filename_type - - -@manager.route('/upload', methods=['POST']) -@login_required -@validate_request("kb_id") -def upload(): - kb_id = request.form.get("kb_id") - if not kb_id: - return get_json_result( - data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR) - if 'file' not in request.files: - return get_json_result( - data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR) - file = request.files['file'] - if file.filename == '': - return get_json_result( - data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) - - try: - e, kb = KnowledgebaseService.get_by_id(kb_id) - if not e: - return get_data_error_result( - retmsg="Can't find this knowledgebase!") - - filename = duplicate_name( - DocumentService.query, - name=file.filename, - kb_id=kb.id) - location = filename - while MINIO.obj_exist(kb_id, location): - location += "_" - blob = request.files['file'].read() - MINIO.put(kb_id, filename, blob) - doc = DocumentService.insert({ - "id": get_uuid(), - "kb_id": kb.id, - "parser_id": kb.parser_id, - "created_by": current_user.id, - "type": filename_type(filename), - "name": filename, - "location": location, - "size": len(blob) - }) - return get_json_result(data=doc.to_json()) - except Exception as e: - return server_error_response(e) - - -@manager.route('/create', methods=['POST']) -@login_required -@validate_request("name", "kb_id") -def create(): - req = request.json - kb_id = req["kb_id"] - if not kb_id: - return get_json_result( - data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR) - - try: - e, kb = KnowledgebaseService.get_by_id(kb_id) - if not e: - return get_data_error_result( - retmsg="Can't find this knowledgebase!") - - if DocumentService.query(name=req["name"], kb_id=kb_id): - return get_data_error_result( - retmsg="Duplicated document name in the same knowledgebase.") - - doc = DocumentService.insert({ - "id": get_uuid(), - "kb_id": kb.id, - "parser_id": kb.parser_id, - "created_by": current_user.id, - "type": FileType.VIRTUAL, - "name": req["name"], - "location": "", - "size": 0 - }) - return get_json_result(data=doc.to_json()) - except Exception as e: - return server_error_response(e) - - -@manager.route('/list', methods=['GET']) -@login_required -def list(): - kb_id = request.args.get("kb_id") - if not kb_id: - return get_json_result( - data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR) - keywords = request.args.get("keywords", "") - - page_number = request.args.get("page", 1) - items_per_page = request.args.get("page_size", 15) - orderby = request.args.get("orderby", "create_time") - desc = request.args.get("desc", True) - try: - docs = DocumentService.get_by_kb_id( - kb_id, page_number, items_per_page, orderby, desc, keywords) - return get_json_result(data=docs) - except Exception as e: - return server_error_response(e) - - -@manager.route('/change_status', methods=['POST']) -@login_required -@validate_request("doc_id", "status") -def change_status(): - req = request.json - if str(req["status"]) not in ["0", "1"]: - get_json_result( - data=False, - retmsg='"Status" must be either 0 or 1!', - retcode=RetCode.ARGUMENT_ERROR) - - try: - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(retmsg="Document not found!") - e, kb = KnowledgebaseService.get_by_id(doc.kb_id) - if not e: - return get_data_error_result( - retmsg="Can't find this knowledgebase!") - - if not DocumentService.update_by_id( - req["doc_id"], {"status": str(req["status"])}): - return get_data_error_result( - retmsg="Database error (Document update)!") - - if str(req["status"]) == "0": - ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), - scripts=""" - if(ctx._source.kb_id.contains('%s')) - ctx._source.kb_id.remove( - ctx._source.kb_id.indexOf('%s') - ); - """ % (doc.kb_id, doc.kb_id), - idxnm=search.index_name( - kb.tenant_id) - ) - else: - ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), - scripts=""" - if(!ctx._source.kb_id.contains('%s')) - ctx._source.kb_id.add('%s'); - """ % (doc.kb_id, doc.kb_id), - idxnm=search.index_name( - kb.tenant_id) - ) - return get_json_result(data=True) - except Exception as e: - return server_error_response(e) - - -@manager.route('/rm', methods=['POST']) -@login_required -@validate_request("doc_id") -def rm(): - req = request.json - try: - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(retmsg="Document not found!") - if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)): - return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR) - - DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0) - if not DocumentService.delete_by_id(req["doc_id"]): - return get_data_error_result( - retmsg="Database error (Document removal)!") - - MINIO.rm(doc.kb_id, doc.location) - return get_json_result(data=True) - except Exception as e: - return server_error_response(e) - - -@manager.route('/rename', methods=['POST']) -@login_required -@validate_request("doc_id", "name", "old_name") -def rename(): - req = request.json - if pathlib.Path(req["name"].lower()).suffix != pathlib.Path( - req["old_name"].lower()).suffix: - get_json_result( - data=False, - retmsg="The extension of file can't be changed", - retcode=RetCode.ARGUMENT_ERROR) - - try: - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(retmsg="Document not found!") - if DocumentService.query(name=req["name"], kb_id=doc.kb_id): - return get_data_error_result( - retmsg="Duplicated document name in the same knowledgebase.") - - if not DocumentService.update_by_id( - req["doc_id"], {"name": req["name"]}): - return get_data_error_result( - retmsg="Database error (Document rename)!") - - return get_json_result(data=True) - except Exception as e: - return server_error_response(e) - - -@manager.route('/get', methods=['GET']) -@login_required -def get(): - doc_id = request.args["doc_id"] - try: - e, doc = DocumentService.get_by_id(doc_id) - if not e: - return get_data_error_result(retmsg="Document not found!") - - blob = MINIO.get(doc.kb_id, doc.location) - return get_json_result(data={"base64": base64.b64decode(blob)}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/change_parser', methods=['POST']) -@login_required -@validate_request("doc_id", "parser_id") -def change_parser(): - req = request.json - try: - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(retmsg="Document not found!") - if doc.parser_id.lower() == req["parser_id"].lower(): - return get_json_result(data=True) - - e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""}) - if not e: - return get_data_error_result(retmsg="Document not found!") - e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1) - if not e: - return get_data_error_result(retmsg="Document not found!") - - return get_json_result(data=True) - except Exception as e: - return server_error_response(e) - diff --git a/web_server/apps/kb_app.py b/web_server/apps/kb_app.py deleted file mode 100644 index 054f97e00..000000000 --- a/web_server/apps/kb_app.py +++ /dev/null @@ -1,114 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from flask import request -from flask_login import login_required, current_user - -from web_server.db.services import duplicate_name -from web_server.db.services.user_service import TenantService, UserTenantService -from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request -from web_server.utils import get_uuid, get_format_time -from web_server.db import StatusEnum, UserTenantRole -from web_server.db.services.kb_service import KnowledgebaseService -from web_server.db.db_models import Knowledgebase -from web_server.settings import stat_logger, RetCode -from web_server.utils.api_utils import get_json_result - - -@manager.route('/create', methods=['post']) -@login_required -@validate_request("name", "description", "permission", "parser_id") -def create(): - req = request.json - req["name"] = req["name"].strip() - req["name"] = duplicate_name(KnowledgebaseService.query, name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value) - try: - req["id"] = get_uuid() - req["tenant_id"] = current_user.id - req["created_by"] = current_user.id - if not KnowledgebaseService.save(**req): return get_data_error_result() - return get_json_result(data={"kb_id": req["id"]}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/update', methods=['post']) -@login_required -@validate_request("kb_id", "name", "description", "permission", "parser_id") -def update(): - req = request.json - req["name"] = req["name"].strip() - try: - if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]): - return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) - - e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) - if not e: return get_data_error_result(retmsg="Can't find this knowledgebase!") - - if req["name"].lower() != kb.name.lower() \ - and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value))>1: - return get_data_error_result(retmsg="Duplicated knowledgebase name.") - - del req["kb_id"] - if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result() - - e, kb = KnowledgebaseService.get_by_id(kb.id) - if not e: return get_data_error_result(retmsg="Database error (Knowledgebase rename)!") - - return get_json_result(data=kb.to_json()) - except Exception as e: - return server_error_response(e) - - -@manager.route('/detail', methods=['GET']) -@login_required -def detail(): - kb_id = request.args["kb_id"] - try: - kb = KnowledgebaseService.get_detail(kb_id) - if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!") - return get_json_result(data=kb) - except Exception as e: - return server_error_response(e) - - -@manager.route('/list', methods=['GET']) -@login_required -def list(): - page_number = request.args.get("page", 1) - items_per_page = request.args.get("page_size", 15) - orderby = request.args.get("orderby", "create_time") - desc = request.args.get("desc", True) - try: - tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) - kbs = KnowledgebaseService.get_by_tenant_ids([m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc) - return get_json_result(data=kbs) - except Exception as e: - return server_error_response(e) - - -@manager.route('/rm', methods=['post']) -@login_required -@validate_request("kb_id") -def rm(): - req = request.json - try: - if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]): - return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) - - if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.IN_VALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!") - return get_json_result(data=True) - except Exception as e: - return server_error_response(e) \ No newline at end of file diff --git a/web_server/apps/llm_app.py b/web_server/apps/llm_app.py deleted file mode 100644 index 0877a1977..000000000 --- a/web_server/apps/llm_app.py +++ /dev/null @@ -1,95 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from flask import request -from flask_login import login_required, current_user - -from web_server.db.services import duplicate_name -from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService -from web_server.db.services.user_service import TenantService, UserTenantService -from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request -from web_server.utils import get_uuid, get_format_time -from web_server.db import StatusEnum, UserTenantRole -from web_server.db.services.kb_service import KnowledgebaseService -from web_server.db.db_models import Knowledgebase, TenantLLM -from web_server.settings import stat_logger, RetCode -from web_server.utils.api_utils import get_json_result - - -@manager.route('/factories', methods=['GET']) -@login_required -def factories(): - try: - fac = LLMFactoriesService.get_all() - return get_json_result(data=fac.to_json()) - except Exception as e: - return server_error_response(e) - - -@manager.route('/set_api_key', methods=['POST']) -@login_required -@validate_request("llm_factory", "api_key") -def set_api_key(): - req = request.json - llm = { - "tenant_id": current_user.id, - "llm_factory": req["llm_factory"], - "api_key": req["api_key"] - } - # TODO: Test api_key - for n in ["model_type", "llm_name"]: - if n in req: llm[n] = req[n] - - TenantLLM.insert(**llm).on_conflict("replace").execute() - return get_json_result(data=True) - - -@manager.route('/my_llms', methods=['GET']) -@login_required -def my_llms(): - try: - objs = TenantLLMService.query(tenant_id=current_user.id) - objs = [o.to_dict() for o in objs] - for o in objs: del o["api_key"] - return get_json_result(data=objs) - except Exception as e: - return server_error_response(e) - - -@manager.route('/list', methods=['GET']) -@login_required -def list(): - try: - objs = TenantLLMService.query(tenant_id=current_user.id) - objs = [o.to_dict() for o in objs if o.api_key] - fct = {} - for o in objs: - if o["llm_factory"] not in fct: fct[o["llm_factory"]] = [] - if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"]) - - llms = LLMService.get_all() - llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value] - for m in llms: - m["available"] = False - if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]): - m["available"] = True - res = {} - for m in llms: - if m["fid"] not in res: res[m["fid"]] = [] - res[m["fid"]].append(m) - - return get_json_result(data=res) - except Exception as e: - return server_error_response(e) \ No newline at end of file diff --git a/web_server/apps/user_app.py b/web_server/apps/user_app.py deleted file mode 100644 index 81946074e..000000000 --- a/web_server/apps/user_app.py +++ /dev/null @@ -1,254 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from flask import request, session, redirect, url_for -from werkzeug.security import generate_password_hash, check_password_hash -from flask_login import login_required, current_user, login_user, logout_user - -from web_server.db.db_models import TenantLLM -from web_server.db.services.llm_service import TenantLLMService -from web_server.utils.api_utils import server_error_response, validate_request -from web_server.utils import get_uuid, get_format_time, decrypt, download_img -from web_server.db import UserTenantRole, LLMType -from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS -from web_server.db.services.user_service import UserService, TenantService, UserTenantService -from web_server.settings import stat_logger -from web_server.utils.api_utils import get_json_result, cors_reponse - - -@manager.route('/login', methods=['POST', 'GET']) -def login(): - userinfo = None - login_channel = "password" - if session.get("access_token"): - login_channel = session["access_token_from"] - if session["access_token_from"] == "github": - userinfo = user_info_from_github(session["access_token"]) - elif not request.json: - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, - retmsg='Unautherized!') - - email = request.json.get('email') if not userinfo else userinfo["email"] - users = UserService.query(email=email) - if not users: - if request.json is not None: - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!') - avatar = "" - try: - avatar = download_img(userinfo["avatar_url"]) - except Exception as e: - stat_logger.exception(e) - user_id = get_uuid() - try: - users = user_register(user_id, { - "access_token": session["access_token"], - "email": userinfo["email"], - "avatar": avatar, - "nickname": userinfo["login"], - "login_channel": login_channel, - "last_login_time": get_format_time(), - "is_superuser": False, - }) - if not users: raise Exception('Register user failure.') - if len(users) > 1: raise Exception('Same E-mail exist!') - user = users[0] - login_user(user) - return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") - except Exception as e: - rollback_user_registration(user_id) - stat_logger.exception(e) - return server_error_response(e) - elif not request.json: - login_user(users[0]) - return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!") - - password = request.json.get('password') - try: - password = decrypt(password) - except: - return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password') - - user = UserService.query_user(email, password) - if user: - response_data = user.to_json() - user.access_token = get_uuid() - login_user(user) - user.save() - msg = "Welcome back!" - return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg) - else: - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!') - - -@manager.route('/github_callback', methods=['GET']) -def github_callback(): - try: - import requests - res = requests.post(GITHUB_OAUTH.get("url"), data={ - "client_id": GITHUB_OAUTH.get("client_id"), - "client_secret": GITHUB_OAUTH.get("secret_key"), - "code": request.args.get('code') - },headers={"Accept": "application/json"}) - res = res.json() - if "error" in res: - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, - retmsg=res["error_description"]) - - if "user:email" not in res["scope"].split(","): - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope') - - session["access_token"] = res["access_token"] - session["access_token_from"] = "github" - return redirect(url_for("user.login"), code=307) - - except Exception as e: - stat_logger.exception(e) - return server_error_response(e) - - -def user_info_from_github(access_token): - import requests - headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"} - res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers) - user_info = res.json() - email_info = requests.get(f"https://api.github.com/user/emails?access_token={access_token}", headers=headers).json() - user_info["email"] = next((email for email in email_info if email['primary'] == True), None)["email"] - return user_info - - -@manager.route("/logout", methods=['GET']) -@login_required -def log_out(): - current_user.access_token = "" - current_user.save() - logout_user() - return get_json_result(data=True) - - -@manager.route("/setting", methods=["POST"]) -@login_required -def setting_user(): - update_dict = {} - request_data = request.json - if request_data.get("password"): - new_password = request_data.get("new_password") - if not check_password_hash(current_user.password, decrypt(request_data["password"])): - return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!') - - if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password)) - - for k in request_data.keys(): - if k in ["password", "new_password"]:continue - update_dict[k] = request_data[k] - - try: - UserService.update_by_id(current_user.id, update_dict) - return get_json_result(data=True) - except Exception as e: - stat_logger.exception(e) - return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR) - - -@manager.route("/info", methods=["GET"]) -@login_required -def user_info(): - return get_json_result(data=current_user.to_dict()) - - -def rollback_user_registration(user_id): - try: - TenantService.delete_by_id(user_id) - except Exception as e: - pass - try: - u = UserTenantService.query(tenant_id=user_id) - if u: - UserTenantService.delete_by_id(u[0].id) - except Exception as e: - pass - try: - TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute() - except Exception as e: - pass - - -def user_register(user_id, user): - - user_id = get_uuid() - user["id"] = user_id - tenant = { - "id": user_id, - "name": user["nickname"] + "‘s Kingdom", - "llm_id": CHAT_MDL, - "embd_id": EMBEDDING_MDL, - "asr_id": ASR_MDL, - "parser_ids": PARSERS, - "img2txt_id": IMAGE2TEXT_MDL - } - usr_tenant = { - "tenant_id": user_id, - "user_id": user_id, - "invited_by": user_id, - "role": UserTenantRole.OWNER - } - tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"} - - if not UserService.save(**user):return - TenantService.save(**tenant) - UserTenantService.save(**usr_tenant) - TenantLLMService.save(**tenant_llm) - return UserService.query(email=user["email"]) - - -@manager.route("/register", methods=["POST"]) -@validate_request("nickname", "email", "password") -def user_add(): - req = request.json - if UserService.query(email=req["email"]): - return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR) - - user_dict = { - "access_token": get_uuid(), - "email": req["email"], - "nickname": req["nickname"], - "password": decrypt(req["password"]), - "login_channel": "password", - "last_login_time": get_format_time(), - "is_superuser": False, - } - - user_id = get_uuid() - try: - users = user_register(user_id, user_dict) - if not users: raise Exception('Register user failure.') - if len(users) > 1: raise Exception('Same E-mail exist!') - user = users[0] - login_user(user) - return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!") - except Exception as e: - rollback_user_registration(user_id) - stat_logger.exception(e) - return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR) - - - -@manager.route("/tenant_info", methods=["GET"]) -@login_required -def tenant_info(): - try: - tenants = TenantService.get_by_user_id(current_user.id)[0] - return get_json_result(data=tenants) - except Exception as e: - return server_error_response(e) diff --git a/web_server/db/__init__.py b/web_server/db/__init__.py deleted file mode 100644 index 9984299c1..000000000 --- a/web_server/db/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from enum import Enum -from enum import IntEnum -from strenum import StrEnum - - -class StatusEnum(Enum): - VALID = "1" - IN_VALID = "0" - - -class UserTenantRole(StrEnum): - OWNER = 'owner' - ADMIN = 'admin' - NORMAL = 'normal' - - -class TenantPermission(StrEnum): - ME = 'me' - TEAM = 'team' - - -class SerializedType(IntEnum): - PICKLE = 1 - JSON = 2 - - -class FileType(StrEnum): - PDF = 'pdf' - DOC = 'doc' - VISUAL = 'visual' - AURAL = 'aural' - VIRTUAL = 'virtual' - - -class LLMType(StrEnum): - CHAT = 'chat' - EMBEDDING = 'embedding' - SPEECH2TEXT = 'speech2text' - IMAGE2TEXT = 'image2text' \ No newline at end of file diff --git a/web_server/db/db_models.py b/web_server/db/db_models.py deleted file mode 100644 index 62d92b475..000000000 --- a/web_server/db/db_models.py +++ /dev/null @@ -1,619 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import inspect -import os -import sys -import typing -import operator -from functools import wraps -from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer -from flask_login import UserMixin - -from peewee import ( - BigAutoField, BigIntegerField, BooleanField, CharField, - CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField, - Field, Model, Metadata -) -from playhouse.pool import PooledMySQLDatabase - -from web_server.db import SerializedType -from web_server.settings import DATABASE, stat_logger, SECRET_KEY -from web_server.utils.log_utils import getLogger -from web_server import utils - -LOGGER = getLogger() - - -def singleton(cls, *args, **kw): - instances = {} - - def _singleton(): - key = str(cls) + str(os.getpid()) - if key not in instances: - instances[key] = cls(*args, **kw) - return instances[key] - - return _singleton - - -CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField} -AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"} - - -class LongTextField(TextField): - field_type = 'LONGTEXT' - - -class JSONField(LongTextField): - default_value = {} - - def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs): - self._object_hook = object_hook - self._object_pairs_hook = object_pairs_hook - super().__init__(**kwargs) - - def db_value(self, value): - if value is None: - value = self.default_value - return utils.json_dumps(value) - - def python_value(self, value): - if not value: - return self.default_value - return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) - - -class ListField(JSONField): - default_value = [] - - -class SerializedField(LongTextField): - def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs): - self._serialized_type = serialized_type - self._object_hook = object_hook - self._object_pairs_hook = object_pairs_hook - super().__init__(**kwargs) - - def db_value(self, value): - if self._serialized_type == SerializedType.PICKLE: - return utils.serialize_b64(value, to_str=True) - elif self._serialized_type == SerializedType.JSON: - if value is None: - return None - return utils.json_dumps(value, with_type=True) - else: - raise ValueError(f"the serialized type {self._serialized_type} is not supported") - - def python_value(self, value): - if self._serialized_type == SerializedType.PICKLE: - return utils.deserialize_b64(value) - elif self._serialized_type == SerializedType.JSON: - if value is None: - return {} - return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook) - else: - raise ValueError(f"the serialized type {self._serialized_type} is not supported") - - -def is_continuous_field(cls: typing.Type) -> bool: - if cls in CONTINUOUS_FIELD_TYPE: - return True - for p in cls.__bases__: - if p in CONTINUOUS_FIELD_TYPE: - return True - elif p != Field and p != object: - if is_continuous_field(p): - return True - else: - return False - - -def auto_date_timestamp_field(): - return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX} - - -def auto_date_timestamp_db_field(): - return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX} - - -def remove_field_name_prefix(field_name): - return field_name[2:] if field_name.startswith('f_') else field_name - - -class BaseModel(Model): - create_time = BigIntegerField(null=True) - create_date = DateTimeField(null=True) - update_time = BigIntegerField(null=True) - update_date = DateTimeField(null=True) - - def to_json(self): - # This function is obsolete - return self.to_dict() - - def to_dict(self): - return self.__dict__['__data__'] - - def to_human_model_dict(self, only_primary_with: list = None): - model_dict = self.__dict__['__data__'] - - if not only_primary_with: - return {remove_field_name_prefix(k): v for k, v in model_dict.items()} - - human_model_dict = {} - for k in self._meta.primary_key.field_names: - human_model_dict[remove_field_name_prefix(k)] = model_dict[k] - for k in only_primary_with: - human_model_dict[k] = model_dict[f'f_{k}'] - return human_model_dict - - @property - def meta(self) -> Metadata: - return self._meta - - @classmethod - def get_primary_keys_name(cls): - return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [ - cls._meta.primary_key.name] - - @classmethod - def getter_by(cls, attr): - return operator.attrgetter(attr)(cls) - - @classmethod - def query(cls, reverse=None, order_by=None, **kwargs): - filters = [] - for f_n, f_v in kwargs.items(): - attr_name = '%s' % f_n - if not hasattr(cls, attr_name) or f_v is None: - continue - if type(f_v) in {list, set}: - f_v = list(f_v) - if is_continuous_field(type(getattr(cls, attr_name))): - if len(f_v) == 2: - for i, v in enumerate(f_v): - if isinstance(v, str) and f_n in auto_date_timestamp_field(): - # time type: %Y-%m-%d %H:%M:%S - f_v[i] = utils.date_string_to_timestamp(v) - lt_value = f_v[0] - gt_value = f_v[1] - if lt_value is not None and gt_value is not None: - filters.append(cls.getter_by(attr_name).between(lt_value, gt_value)) - elif lt_value is not None: - filters.append(operator.attrgetter(attr_name)(cls) >= lt_value) - elif gt_value is not None: - filters.append(operator.attrgetter(attr_name)(cls) <= gt_value) - else: - filters.append(operator.attrgetter(attr_name)(cls) << f_v) - else: - filters.append(operator.attrgetter(attr_name)(cls) == f_v) - if filters: - query_records = cls.select().where(*filters) - if reverse is not None: - if not order_by or not hasattr(cls, f"{order_by}"): - order_by = "create_time" - if reverse is True: - query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc()) - elif reverse is False: - query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc()) - return [query_record for query_record in query_records] - else: - return [] - - @classmethod - def insert(cls, __data=None, **insert): - if isinstance(__data, dict) and __data: - __data[cls._meta.combined["create_time"]] = utils.current_timestamp() - if insert: - insert["create_time"] = utils.current_timestamp() - - return super().insert(__data, **insert) - - # update and insert will call this method - @classmethod - def _normalize_data(cls, data, kwargs): - normalized = super()._normalize_data(data, kwargs) - if not normalized: - return {} - - normalized[cls._meta.combined["update_time"]] = utils.current_timestamp() - - for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX: - if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \ - cls._meta.combined[f"{f_n}_time"] in normalized and \ - normalized[cls._meta.combined[f"{f_n}_time"]] is not None: - normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date( - normalized[cls._meta.combined[f"{f_n}_time"]]) - - return normalized - - -class JsonSerializedField(SerializedField): - def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs): - super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook, - object_pairs_hook=object_pairs_hook, **kwargs) - - -@singleton -class BaseDataBase: - def __init__(self): - database_config = DATABASE.copy() - db_name = database_config.pop("name") - self.database_connection = PooledMySQLDatabase(db_name, **database_config) - stat_logger.info('init mysql database on cluster mode successfully') - - -class DatabaseLock: - def __init__(self, lock_name, timeout=10, db=None): - self.lock_name = lock_name - self.timeout = int(timeout) - self.db = db if db else DB - - def lock(self): - # SQL parameters only support %s format placeholders - cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout)) - ret = cursor.fetchone() - if ret[0] == 0: - raise Exception(f'acquire mysql lock {self.lock_name} timeout') - elif ret[0] == 1: - return True - else: - raise Exception(f'failed to acquire lock {self.lock_name}') - - def unlock(self): - cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,)) - ret = cursor.fetchone() - if ret[0] == 0: - raise Exception(f'mysql lock {self.lock_name} was not established by this thread') - elif ret[0] == 1: - return True - else: - raise Exception(f'mysql lock {self.lock_name} does not exist') - - def __enter__(self): - if isinstance(self.db, PooledMySQLDatabase): - self.lock() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if isinstance(self.db, PooledMySQLDatabase): - self.unlock() - - def __call__(self, func): - @wraps(func) - def magic(*args, **kwargs): - with self: - return func(*args, **kwargs) - - return magic - - -DB = BaseDataBase().database_connection -DB.lock = DatabaseLock - - -def close_connection(): - try: - if DB: - DB.close() - except Exception as e: - LOGGER.exception(e) - - -class DataBaseModel(BaseModel): - class Meta: - database = DB - - -@DB.connection_context() -def init_database_tables(): - members = inspect.getmembers(sys.modules[__name__], inspect.isclass) - table_objs = [] - create_failed_list = [] - for name, obj in members: - if obj != DataBaseModel and issubclass(obj, DataBaseModel): - table_objs.append(obj) - LOGGER.info(f"start create table {obj.__name__}") - try: - obj.create_table() - LOGGER.info(f"create table success: {obj.__name__}") - except Exception as e: - LOGGER.exception(e) - create_failed_list.append(obj.__name__) - if create_failed_list: - LOGGER.info(f"create tables failed: {create_failed_list}") - raise Exception(f"create tables failed: {create_failed_list}") - - -def fill_db_model_object(model_object, human_model_dict): - for k, v in human_model_dict.items(): - attr_name = '%s' % k - if hasattr(model_object.__class__, attr_name): - setattr(model_object, attr_name, v) - return model_object - - -class User(DataBaseModel, UserMixin): - id = CharField(max_length=32, primary_key=True) - access_token = CharField(max_length=255, null=True) - nickname = CharField(max_length=100, null=False, help_text="nicky name") - password = CharField(max_length=255, null=True, help_text="password") - email = CharField(max_length=255, null=False, help_text="email", index=True) - avatar = TextField(null=True, help_text="avatar base64 string") - language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese") - color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Dark") - last_login_time = DateTimeField(null=True) - is_authenticated = CharField(max_length=1, null=False, default="1") - is_active = CharField(max_length=1, null=False, default="1") - is_anonymous = CharField(max_length=1, null=False, default="0") - login_channel = CharField(null=True, help_text="from which user login") - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") - is_superuser = BooleanField(null=True, help_text="is root", default=False) - - def __str__(self): - return self.email - - def get_id(self): - jwt = Serializer(secret_key=SECRET_KEY) - return jwt.dumps(str(self.access_token)) - - class Meta: - db_table = "user" - - -class Tenant(DataBaseModel): - id = CharField(max_length=32, primary_key=True) - name = CharField(max_length=100, null=True, help_text="Tenant name") - public_key = CharField(max_length=255, null=True) - llm_id = CharField(max_length=128, null=False, help_text="default llm ID") - embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID") - asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID") - img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID") - parser_ids = CharField(max_length=128, null=False, help_text="default image to text model ID") - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") - - class Meta: - db_table = "tenant" - - -class UserTenant(DataBaseModel): - id = CharField(max_length=32, primary_key=True) - user_id = CharField(max_length=32, null=False) - tenant_id = CharField(max_length=32, null=False) - role = CharField(max_length=32, null=False, help_text="UserTenantRole") - invited_by = CharField(max_length=32, null=False) - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") - - class Meta: - db_table = "user_tenant" - - -class InvitationCode(DataBaseModel): - id = CharField(max_length=32, primary_key=True) - code = CharField(max_length=32, null=False) - visit_time = DateTimeField(null=True) - user_id = CharField(max_length=32, null=True) - tenant_id = CharField(max_length=32, null=True) - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") - - class Meta: - db_table = "invitation_code" - - -class LLMFactories(DataBaseModel): - name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True) - logo = TextField(null=True, help_text="llm logo base64") - tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") - - def __str__(self): - return self.name - - class Meta: - db_table = "llm_factories" - - -class LLM(DataBaseModel): - # defautlt LLMs for every users - llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True) - model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") - fid = CharField(max_length=128, null=False, help_text="LLM factory id") - tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") - - def __str__(self): - return self.llm_name - - class Meta: - db_table = "llm" - - -class TenantLLM(DataBaseModel): - tenant_id = CharField(max_length=32, null=False) - llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name") - model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR") - llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="") - api_key = CharField(max_length=255, null=True, help_text="API KEY") - api_base = CharField(max_length=255, null=True, help_text="API Base") - - def __str__(self): - return self.llm_name - - class Meta: - db_table = "tenant_llm" - primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name') - - -class Knowledgebase(DataBaseModel): - id = CharField(max_length=32, primary_key=True) - avatar = TextField(null=True, help_text="avatar base64 string") - tenant_id = CharField(max_length=32, null=False) - name = CharField(max_length=128, null=False, help_text="KB name", index=True) - description = TextField(null=True, help_text="KB description") - permission = CharField(max_length=16, null=False, help_text="me|team") - created_by = CharField(max_length=32, null=False) - doc_num = IntegerField(default=0) - token_num = IntegerField(default=0) - chunk_num = IntegerField(default=0) - - parser_id = CharField(max_length=32, null=False, help_text="default parser ID") - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") - - def __str__(self): - return self.name - - class Meta: - db_table = "knowledgebase" - - -class Document(DataBaseModel): - id = CharField(max_length=32, primary_key=True) - thumbnail = TextField(null=True, help_text="thumbnail base64 string") - kb_id = CharField(max_length=256, null=False, index=True) - parser_id = CharField(max_length=32, null=False, help_text="default parser ID") - source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from") - type = CharField(max_length=32, null=False, help_text="file extension") - created_by = CharField(max_length=32, null=False, help_text="who created it") - name = CharField(max_length=255, null=True, help_text="file name", index=True) - location = CharField(max_length=255, null=True, help_text="where dose it store") - size = IntegerField(default=0) - token_num = IntegerField(default=0) - chunk_num = IntegerField(default=0) - progress = FloatField(default=0) - progress_msg = CharField(max_length=255, null=True, help_text="process message", default="") - process_begin_at = DateTimeField(null=True) - process_duation = FloatField(default=0) - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") - - class Meta: - db_table = "document" - - -class Dialog(DataBaseModel): - id = CharField(max_length=32, primary_key=True) - tenant_id = CharField(max_length=32, null=False) - name = CharField(max_length=255, null=True, help_text="dialog application name") - description = TextField(null=True, help_text="Dialog description") - icon = CharField(max_length=16, null=False, help_text="dialog icon") - language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese") - llm_id = CharField(max_length=32, null=False, help_text="default llm ID") - llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom", - default="Creative") - llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, - "presence_penalty": 0.4, "max_tokens": 215}) - prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") - prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?", - "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"}) - status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") - - class Meta: - db_table = "dialog" - - -class DialogKb(DataBaseModel): - dialog_id = CharField(max_length=32, null=False, index=True) - kb_id = CharField(max_length=32, null=False) - - class Meta: - db_table = "dialog_kb" - primary_key = CompositeKey('dialog_id', 'kb_id') - - -class Conversation(DataBaseModel): - id = CharField(max_length=32, primary_key=True) - dialog_id = CharField(max_length=32, null=False, index=True) - name = CharField(max_length=255, null=True, help_text="converastion name") - message = JSONField(null=True) - - class Meta: - db_table = "conversation" - - -""" -class Job(DataBaseModel): - # multi-party common configuration - f_user_id = CharField(max_length=25, null=True) - f_job_id = CharField(max_length=25, index=True) - f_name = CharField(max_length=500, null=True, default='') - f_description = TextField(null=True, default='') - f_tag = CharField(max_length=50, null=True, default='') - f_dsl = JSONField() - f_runtime_conf = JSONField() - f_runtime_conf_on_party = JSONField() - f_train_runtime_conf = JSONField(null=True) - f_roles = JSONField() - f_initiator_role = CharField(max_length=50) - f_initiator_party_id = CharField(max_length=50) - f_status = CharField(max_length=50) - f_status_code = IntegerField(null=True) - f_user = JSONField() - # this party configuration - f_role = CharField(max_length=50, index=True) - f_party_id = CharField(max_length=10, index=True) - f_is_initiator = BooleanField(null=True, default=False) - f_progress = IntegerField(null=True, default=0) - f_ready_signal = BooleanField(default=False) - f_ready_time = BigIntegerField(null=True) - f_cancel_signal = BooleanField(default=False) - f_cancel_time = BigIntegerField(null=True) - f_rerun_signal = BooleanField(default=False) - f_end_scheduling_updates = IntegerField(null=True, default=0) - - f_engine_name = CharField(max_length=50, null=True) - f_engine_type = CharField(max_length=10, null=True) - f_cores = IntegerField(default=0) - f_memory = IntegerField(default=0) # MB - f_remaining_cores = IntegerField(default=0) - f_remaining_memory = IntegerField(default=0) # MB - f_resource_in_use = BooleanField(default=False) - f_apply_resource_time = BigIntegerField(null=True) - f_return_resource_time = BigIntegerField(null=True) - - f_inheritance_info = JSONField(null=True) - f_inheritance_status = CharField(max_length=50, null=True) - - f_start_time = BigIntegerField(null=True) - f_start_date = DateTimeField(null=True) - f_end_time = BigIntegerField(null=True) - f_end_date = DateTimeField(null=True) - f_elapsed = BigIntegerField(null=True) - - class Meta: - db_table = "t_job" - primary_key = CompositeKey('f_job_id', 'f_role', 'f_party_id') - - - -class PipelineComponentMeta(DataBaseModel): - f_model_id = CharField(max_length=100, index=True) - f_model_version = CharField(max_length=100, index=True) - f_role = CharField(max_length=50, index=True) - f_party_id = CharField(max_length=10, index=True) - f_component_name = CharField(max_length=100, index=True) - f_component_module_name = CharField(max_length=100) - f_model_alias = CharField(max_length=100, index=True) - f_model_proto_index = JSONField(null=True) - f_run_parameters = JSONField(null=True) - f_archive_sha256 = CharField(max_length=100, null=True) - f_archive_from_ip = CharField(max_length=100, null=True) - - class Meta: - db_table = 't_pipeline_component_meta' - indexes = ( - (('f_model_id', 'f_model_version', 'f_role', 'f_party_id', 'f_component_name'), True), - ) - - -""" diff --git a/web_server/db/db_services.py b/web_server/db/db_services.py deleted file mode 100644 index 9f8a0a02a..000000000 --- a/web_server/db/db_services.py +++ /dev/null @@ -1,157 +0,0 @@ -# -# Copyright 2021 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import abc -import json -import time -from functools import wraps -from shortuuid import ShortUUID - -from web_server.versions import get_fate_version - -from web_server.errors.error_services import * -from web_server.settings import ( - GRPC_PORT, HOST, HTTP_PORT, - RANDOM_INSTANCE_ID, stat_logger, -) - - -instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}' -server_instance = ( - f'{HOST}:{GRPC_PORT}', - json.dumps({ - 'instance_id': instance_id, - 'timestamp': round(time.time() * 1000), - 'version': get_fate_version() or '', - 'host': HOST, - 'grpc_port': GRPC_PORT, - 'http_port': HTTP_PORT, - }), -) - - -def check_service_supported(method): - """Decorator to check if `service_name` is supported. - The attribute `supported_services` MUST be defined in class. - The first and second arguments of `method` MUST be `self` and `service_name`. - - :param Callable method: The class method. - :return: The inner wrapper function. - :rtype: Callable - """ - @wraps(method) - def magic(self, service_name, *args, **kwargs): - if service_name not in self.supported_services: - raise ServiceNotSupported(service_name=service_name) - return method(self, service_name, *args, **kwargs) - return magic - - -class ServicesDB(abc.ABC): - """Database for storage service urls. - Abstract base class for the real backends. - - """ - @property - @abc.abstractmethod - def supported_services(self): - """The names of supported services. - The returned list SHOULD contain `fateflow` (model download) and `servings` (FATE-Serving). - - :return: The service names. - :rtype: list - """ - pass - - @abc.abstractmethod - def _get_serving(self): - pass - - def get_serving(self): - - try: - return self._get_serving() - except ServicesError as e: - stat_logger.exception(e) - return [] - - @abc.abstractmethod - def _insert(self, service_name, service_url, value=''): - pass - - @check_service_supported - def insert(self, service_name, service_url, value=''): - """Insert a service url to database. - - :param str service_name: The service name. - :param str service_url: The service url. - :return: None - """ - try: - self._insert(service_name, service_url, value) - except ServicesError as e: - stat_logger.exception(e) - - @abc.abstractmethod - def _delete(self, service_name, service_url): - pass - - @check_service_supported - def delete(self, service_name, service_url): - """Delete a service url from database. - - :param str service_name: The service name. - :param str service_url: The service url. - :return: None - """ - try: - self._delete(service_name, service_url) - except ServicesError as e: - stat_logger.exception(e) - - def register_flow(self): - """Call `self.insert` for insert the flow server address to databae. - - :return: None - """ - self.insert('flow-server', *server_instance) - - def unregister_flow(self): - """Call `self.delete` for delete the flow server address from databae. - - :return: None - """ - self.delete('flow-server', server_instance[0]) - - @abc.abstractmethod - def _get_urls(self, service_name, with_values=False): - pass - - @check_service_supported - def get_urls(self, service_name, with_values=False): - """Query service urls from database. The urls may belong to other nodes. - Currently, only `fateflow` (model download) urls and `servings` (FATE-Serving) urls are supported. - `fateflow` is a url containing scheme, host, port and path, - while `servings` only contains host and port. - - :param str service_name: The service name. - :return: The service urls. - :rtype: list - """ - try: - return self._get_urls(service_name, with_values) - except ServicesError as e: - stat_logger.exception(e) - return [] diff --git a/web_server/db/db_utils.py b/web_server/db/db_utils.py deleted file mode 100644 index df00ecb48..000000000 --- a/web_server/db/db_utils.py +++ /dev/null @@ -1,131 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import operator -from functools import reduce -from typing import Dict, Type, Union - -from web_server.utils import current_timestamp, timestamp_to_date - -from web_server.db.db_models import DB, DataBaseModel -from web_server.db.runtime_config import RuntimeConfig -from web_server.utils.log_utils import getLogger -from enum import Enum - - -LOGGER = getLogger() - - -@DB.connection_context() -def bulk_insert_into_db(model, data_source, replace_on_conflict=False): - DB.create_tables([model]) - - current_time = current_timestamp() - current_date = timestamp_to_date(current_time) - - for data in data_source: - if 'f_create_time' not in data: - data['f_create_time'] = current_time - data['f_create_date'] = timestamp_to_date(data['f_create_time']) - data['f_update_time'] = current_time - data['f_update_date'] = current_date - - preserve = tuple(data_source[0].keys() - {'f_create_time', 'f_create_date'}) - - batch_size = 50 if RuntimeConfig.USE_LOCAL_DATABASE else 1000 - - for i in range(0, len(data_source), batch_size): - with DB.atomic(): - query = model.insert_many(data_source[i:i + batch_size]) - if replace_on_conflict: - query = query.on_conflict(preserve=preserve) - query.execute() - - -def get_dynamic_db_model(base, job_id): - return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id))) - - -def get_dynamic_tracking_table_index(job_id): - return job_id[:8] - - -def fill_db_model_object(model_object, human_model_dict): - for k, v in human_model_dict.items(): - attr_name = 'f_%s' % k - if hasattr(model_object.__class__, attr_name): - setattr(model_object, attr_name, v) - return model_object - - -# https://docs.peewee-orm.com/en/latest/peewee/query_operators.html -supported_operators = { - '==': operator.eq, - '<': operator.lt, - '<=': operator.le, - '>': operator.gt, - '>=': operator.ge, - '!=': operator.ne, - '<<': operator.lshift, - '>>': operator.rshift, - '%': operator.mod, - '**': operator.pow, - '^': operator.xor, - '~': operator.inv, -} - -def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]): - expression = [] - - for field, value in query.items(): - if not isinstance(value, (list, tuple)): - value = ('==', value) - op, *val = value - - field = getattr(model, f'f_{field}') - value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val) - expression.append(value) - - return reduce(operator.iand, expression) - - -def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0, - query: dict = None, order_by: Union[str, list, tuple] = None): - data = model.select() - if query: - data = data.where(query_dict2expression(model, query)) - count = data.count() - - if not order_by: - order_by = 'create_time' - if not isinstance(order_by, (list, tuple)): - order_by = (order_by, 'asc') - order_by, order = order_by - order_by = getattr(model, f'f_{order_by}') - order_by = getattr(order_by, order)() - data = data.order_by(order_by) - - if limit > 0: - data = data.limit(limit) - if offset > 0: - data = data.offset(offset) - - return list(data), count - - -class StatusEnum(Enum): - # 样本可用状态 - VALID = "1" - IN_VALID = "0" \ No newline at end of file diff --git a/web_server/db/init_data.py b/web_server/db/init_data.py deleted file mode 100644 index 882b62ef9..000000000 --- a/web_server/db/init_data.py +++ /dev/null @@ -1,141 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import time -import uuid - -from web_server.db import LLMType -from web_server.db.db_models import init_database_tables as init_web_db -from web_server.db.services import UserService -from web_server.db.services.llm_service import LLMFactoriesService, LLMService - - -def init_superuser(): - user_info = { - "id": uuid.uuid1().hex, - "password": "admin", - "nickname": "admin", - "is_superuser": True, - "email": "kai.hu@infiniflow.org", - "creator": "system", - "status": "1", - } - UserService.save(**user_info) - - -def init_llm_factory(): - factory_infos = [{ - "name": "OpenAI", - "logo": "", - "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", - "status": "1", - },{ - "name": "通义千问", - "logo": "", - "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", - "status": "1", - },{ - "name": "智普AI", - "logo": "", - "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", - "status": "1", - },{ - "name": "文心一言", - "logo": "", - "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", - "status": "1", - }, - ] - llm_infos = [{ - "fid": factory_infos[0]["name"], - "llm_name": "gpt-3.5-turbo", - "tags": "LLM,CHAT,4K", - "model_type": LLMType.CHAT.value - },{ - "fid": factory_infos[0]["name"], - "llm_name": "gpt-3.5-turbo-16k-0613", - "tags": "LLM,CHAT,16k", - "model_type": LLMType.CHAT.value - },{ - "fid": factory_infos[0]["name"], - "llm_name": "text-embedding-ada-002", - "tags": "TEXT EMBEDDING,8K", - "model_type": LLMType.EMBEDDING.value - },{ - "fid": factory_infos[0]["name"], - "llm_name": "whisper-1", - "tags": "SPEECH2TEXT", - "model_type": LLMType.SPEECH2TEXT.value - },{ - "fid": factory_infos[0]["name"], - "llm_name": "gpt-4", - "tags": "LLM,CHAT,8K", - "model_type": LLMType.CHAT.value - },{ - "fid": factory_infos[0]["name"], - "llm_name": "gpt-4-32k", - "tags": "LLM,CHAT,32K", - "model_type": LLMType.CHAT.value - },{ - "fid": factory_infos[0]["name"], - "llm_name": "gpt-4-vision-preview", - "tags": "LLM,CHAT,IMAGE2TEXT", - "model_type": LLMType.IMAGE2TEXT.value - },{ - "fid": factory_infos[1]["name"], - "llm_name": "qwen-turbo", - "tags": "LLM,CHAT,8K", - "model_type": LLMType.CHAT.value - },{ - "fid": factory_infos[1]["name"], - "llm_name": "qwen-plus", - "tags": "LLM,CHAT,32K", - "model_type": LLMType.CHAT.value - },{ - "fid": factory_infos[1]["name"], - "llm_name": "text-embedding-v2", - "tags": "TEXT EMBEDDING,2K", - "model_type": LLMType.EMBEDDING.value - },{ - "fid": factory_infos[1]["name"], - "llm_name": "paraformer-realtime-8k-v1", - "tags": "SPEECH2TEXT", - "model_type": LLMType.SPEECH2TEXT.value - },{ - "fid": factory_infos[1]["name"], - "llm_name": "qwen_vl_chat_v1", - "tags": "LLM,CHAT,IMAGE2TEXT", - "model_type": LLMType.IMAGE2TEXT.value - }, - ] - for info in factory_infos: - LLMFactoriesService.save(**info) - for info in llm_infos: - LLMService.save(**info) - - -def init_web_data(): - start_time = time.time() - if not UserService.get_all().count(): - init_superuser() - - if not LLMService.get_all().count():init_llm_factory() - - print("init web data success:{}".format(time.time() - start_time)) - - -if __name__ == '__main__': - init_web_db() - init_web_data() \ No newline at end of file diff --git a/web_server/db/operatioins.py b/web_server/db/operatioins.py deleted file mode 100644 index b8e1596fa..000000000 --- a/web_server/db/operatioins.py +++ /dev/null @@ -1,21 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import operator -import time -import typing -from web_server.utils.log_utils import sql_logger -import peewee \ No newline at end of file diff --git a/web_server/db/reload_config_base.py b/web_server/db/reload_config_base.py deleted file mode 100644 index 049acbf76..000000000 --- a/web_server/db/reload_config_base.py +++ /dev/null @@ -1,27 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -class ReloadConfigBase: - @classmethod - def get_all(cls): - configs = {} - for k, v in cls.__dict__.items(): - if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"): - configs[k] = v - return configs - - @classmethod - def get(cls, config_name): - return getattr(cls, config_name) if hasattr(cls, config_name) else None \ No newline at end of file diff --git a/web_server/db/runtime_config.py b/web_server/db/runtime_config.py deleted file mode 100644 index 095559089..000000000 --- a/web_server/db/runtime_config.py +++ /dev/null @@ -1,54 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from web_server.versions import get_versions -from .reload_config_base import ReloadConfigBase - - -class RuntimeConfig(ReloadConfigBase): - DEBUG = None - WORK_MODE = None - HTTP_PORT = None - JOB_SERVER_HOST = None - JOB_SERVER_VIP = None - ENV = dict() - SERVICE_DB = None - LOAD_CONFIG_MANAGER = False - - @classmethod - def init_config(cls, **kwargs): - for k, v in kwargs.items(): - if hasattr(cls, k): - setattr(cls, k, v) - - @classmethod - def init_env(cls): - cls.ENV.update(get_versions()) - - @classmethod - def load_config_manager(cls): - cls.LOAD_CONFIG_MANAGER = True - - @classmethod - def get_env(cls, key): - return cls.ENV.get(key, None) - - @classmethod - def get_all_env(cls): - return cls.ENV - - @classmethod - def set_service_db(cls, service_db): - cls.SERVICE_DB = service_db \ No newline at end of file diff --git a/web_server/db/service_registry.py b/web_server/db/service_registry.py deleted file mode 100644 index dc704be93..000000000 --- a/web_server/db/service_registry.py +++ /dev/null @@ -1,164 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import socket -from pathlib import Path -from web_server import utils -from .db_models import DB, ServiceRegistryInfo, ServerRegistryInfo -from .reload_config_base import ReloadConfigBase - - -class ServiceRegistry(ReloadConfigBase): - @classmethod - @DB.connection_context() - def load_service(cls, **kwargs) -> [ServiceRegistryInfo]: - service_registry_list = ServiceRegistryInfo.query(**kwargs) - return [service for service in service_registry_list] - - @classmethod - @DB.connection_context() - def save_service_info(cls, server_name, service_name, uri, method="POST", server_info=None, params=None, data=None, headers=None, protocol="http"): - if not server_info: - server_list = ServerRegistry.query_server_info_from_db(server_name=server_name) - if not server_list: - raise Exception(f"no found server {server_name}") - server_info = server_list[0] - url = f"{server_info.f_protocol}://{server_info.f_host}:{server_info.f_port}{uri}" - else: - url = f"{server_info.get('protocol', protocol)}://{server_info.get('host')}:{server_info.get('port')}{uri}" - service_info = { - "f_server_name": server_name, - "f_service_name": service_name, - "f_url": url, - "f_method": method, - "f_params": params if params else {}, - "f_data": data if data else {}, - "f_headers": headers if headers else {} - } - entity_model, status = ServiceRegistryInfo.get_or_create( - f_server_name=server_name, - f_service_name=service_name, - defaults=service_info) - if status is False: - for key in service_info: - setattr(entity_model, key, service_info[key]) - entity_model.save(force_insert=False) - - -class ServerRegistry(ReloadConfigBase): - FATEBOARD = None - FATE_ON_STANDALONE = None - FATE_ON_EGGROLL = None - FATE_ON_SPARK = None - MODEL_STORE_ADDRESS = None - SERVINGS = None - FATEMANAGER = None - STUDIO = None - - @classmethod - def load(cls): - cls.load_server_info_from_conf() - cls.load_server_info_from_db() - - @classmethod - def load_server_info_from_conf(cls): - path = Path(utils.file_utils.get_project_base_directory()) / 'conf' / utils.SERVICE_CONF - conf = utils.file_utils.load_yaml_conf(path) - if not isinstance(conf, dict): - raise ValueError('invalid config file') - - local_path = path.with_name(f'local.{utils.SERVICE_CONF}') - if local_path.exists(): - local_conf = utils.file_utils.load_yaml_conf(local_path) - if not isinstance(local_conf, dict): - raise ValueError('invalid local config file') - conf.update(local_conf) - for k, v in conf.items(): - if isinstance(v, dict): - setattr(cls, k.upper(), v) - - @classmethod - def register(cls, server_name, server_info): - cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol=server_info.get("protocol", "http")) - setattr(cls, server_name, server_info) - - @classmethod - def save(cls, service_config): - update_server = {} - for server_name, server_info in service_config.items(): - cls.parameter_check(server_info) - api_info = server_info.pop("api", {}) - for service_name, info in api_info.items(): - ServiceRegistry.save_service_info(server_name, service_name, uri=info.get('uri'), method=info.get('method', 'POST'), server_info=server_info) - cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol="http") - setattr(cls, server_name.upper(), server_info) - return update_server - - @classmethod - def parameter_check(cls, service_info): - if "host" in service_info and "port" in service_info: - cls.connection_test(service_info.get("host"), service_info.get("port")) - - @classmethod - def connection_test(cls, ip, port): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - result = s.connect_ex((ip, port)) - if result != 0: - raise ConnectionRefusedError(f"connection refused: host {ip}, port {port}") - - @classmethod - def query(cls, service_name, default=None): - service_info = getattr(cls, service_name, default) - if not service_info: - service_info = utils.get_base_config(service_name, default) - return service_info - - @classmethod - @DB.connection_context() - def query_server_info_from_db(cls, server_name=None) -> [ServerRegistryInfo]: - if server_name: - server_list = ServerRegistryInfo.select().where(ServerRegistryInfo.f_server_name==server_name.upper()) - else: - server_list = ServerRegistryInfo.select() - return [server for server in server_list] - - @classmethod - @DB.connection_context() - def load_server_info_from_db(cls): - for server in cls.query_server_info_from_db(): - server_info = { - "host": server.f_host, - "port": server.f_port, - "protocol": server.f_protocol - } - setattr(cls, server.f_server_name.upper(), server_info) - - - @classmethod - @DB.connection_context() - def save_server_info_to_db(cls, server_name, host, port, protocol="http"): - server_info = { - "f_server_name": server_name, - "f_host": host, - "f_port": port, - "f_protocol": protocol - } - entity_model, status = ServerRegistryInfo.get_or_create( - f_server_name=server_name, - defaults=server_info) - if status is False: - for key in server_info: - setattr(entity_model, key, server_info[key]) - entity_model.save(force_insert=False) \ No newline at end of file diff --git a/web_server/db/services/__init__.py b/web_server/db/services/__init__.py deleted file mode 100644 index 9c9314bcc..000000000 --- a/web_server/db/services/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import pathlib -import re -from .user_service import UserService - - -def duplicate_name(query_func, **kwargs): - fnm = kwargs["name"] - objs = query_func(**kwargs) - if not objs: return fnm - ext = pathlib.Path(fnm).suffix #.jpg - nm = re.sub(r"%s$"%ext, "", fnm) - r = re.search(r"\([0-9]+\)$", nm) - c = 0 - if r: - c = int(r.group(1)) - nm = re.sub(r"\([0-9]+\)$", "", nm) - c += 1 - nm = f"{nm}({c})" - if ext: nm += f"{ext}" - - kwargs["name"] = nm - return duplicate_name(query_func, **kwargs) - diff --git a/web_server/db/services/common_service.py b/web_server/db/services/common_service.py deleted file mode 100644 index 027f6f282..000000000 --- a/web_server/db/services/common_service.py +++ /dev/null @@ -1,153 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from datetime import datetime - -import peewee - -from web_server.db.db_models import DB -from web_server.utils import datetime_format - - -class CommonService: - model = None - - @classmethod - @DB.connection_context() - def query(cls, cols=None, reverse=None, order_by=None, **kwargs): - return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs) - - @classmethod - @DB.connection_context() - def get_all(cls, cols=None, reverse=None, order_by=None): - if cols: - query_records = cls.model.select(*cols) - else: - query_records = cls.model.select() - if reverse is not None: - if not order_by or not hasattr(cls, order_by): - order_by = "create_time" - if reverse is True: - query_records = query_records.order_by(cls.model.getter_by(order_by).desc()) - elif reverse is False: - query_records = query_records.order_by(cls.model.getter_by(order_by).asc()) - return query_records - - @classmethod - @DB.connection_context() - def get(cls, **kwargs): - return cls.model.get(**kwargs) - - @classmethod - @DB.connection_context() - def get_or_none(cls, **kwargs): - try: - return cls.model.get(**kwargs) - except peewee.DoesNotExist: - return None - - @classmethod - @DB.connection_context() - def save(cls, **kwargs): - #if "id" not in kwargs: - # kwargs["id"] = get_uuid() - sample_obj = cls.model(**kwargs).save(force_insert=True) - return sample_obj - - @classmethod - @DB.connection_context() - def insert_many(cls, data_list, batch_size=100): - with DB.atomic(): - for i in range(0, len(data_list), batch_size): - cls.model.insert_many(data_list[i:i + batch_size]).execute() - - @classmethod - @DB.connection_context() - def update_many_by_id(cls, data_list): - cur = datetime_format(datetime.now()) - with DB.atomic(): - for data in data_list: - data["update_time"] = cur - cls.model.update(data).where(cls.model.id == data["id"]).execute() - - @classmethod - @DB.connection_context() - def update_by_id(cls, pid, data): - data["update_time"] = datetime_format(datetime.now()) - num = cls.model.update(data).where(cls.model.id == pid).execute() - return num - - @classmethod - @DB.connection_context() - def get_by_id(cls, pid): - try: - obj = cls.model.query(id=pid)[0] - return True, obj - except Exception as e: - return False, None - - @classmethod - @DB.connection_context() - def get_by_ids(cls, pids, cols=None): - if cols: - objs = cls.model.select(*cols) - else: - objs = cls.model.select() - return objs.where(cls.model.id.in_(pids)) - - @classmethod - @DB.connection_context() - def delete_by_id(cls, pid): - return cls.model.delete().where(cls.model.id == pid).execute() - - - @classmethod - @DB.connection_context() - def filter_delete(cls, filters): - with DB.atomic(): - num = cls.model.delete().where(*filters).execute() - return num - - @classmethod - @DB.connection_context() - def filter_update(cls, filters, update_data): - with DB.atomic(): - cls.model.update(update_data).where(*filters).execute() - - @staticmethod - def cut_list(tar_list, n): - length = len(tar_list) - arr = range(length) - result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]] - return result - - @classmethod - @DB.connection_context() - def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None): - in_filters_tuple_list = cls.cut_list(in_filters_list, 20) - if not filters: - filters = [] - res_list = [] - if cols: - for i in in_filters_tuple_list: - query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters) - if query_records: - res_list.extend([query_record for query_record in query_records]) - else: - for i in in_filters_tuple_list: - query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters) - if query_records: - res_list.extend([query_record for query_record in query_records]) - return res_list \ No newline at end of file diff --git a/web_server/db/services/dialog_service.py b/web_server/db/services/dialog_service.py deleted file mode 100644 index e73217a7d..000000000 --- a/web_server/db/services/dialog_service.py +++ /dev/null @@ -1,35 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import peewee -from werkzeug.security import generate_password_hash, check_password_hash - -from web_server.db.db_models import DB, UserTenant -from web_server.db.db_models import Dialog, Conversation, DialogKb -from web_server.db.services.common_service import CommonService -from web_server.utils import get_uuid, get_format_time -from web_server.db.db_utils import StatusEnum - - -class DialogService(CommonService): - model = Dialog - - -class ConversationService(CommonService): - model = Conversation - - -class DialogKbService(CommonService): - model = DialogKb diff --git a/web_server/db/services/document_service.py b/web_server/db/services/document_service.py deleted file mode 100644 index 38b1cd559..000000000 --- a/web_server/db/services/document_service.py +++ /dev/null @@ -1,89 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from peewee import Expression - -from web_server.db import TenantPermission, FileType -from web_server.db.db_models import DB, Knowledgebase, Tenant -from web_server.db.db_models import Document -from web_server.db.services.common_service import CommonService -from web_server.db.services.kb_service import KnowledgebaseService -from web_server.db.db_utils import StatusEnum - - -class DocumentService(CommonService): - model = Document - - @classmethod - @DB.connection_context() - def get_by_kb_id(cls, kb_id, page_number, items_per_page, - orderby, desc, keywords): - if keywords: - docs = cls.model.select().where( - cls.model.kb_id == kb_id, - cls.model.name.like(f"%%{keywords}%%")) - else: - docs = cls.model.select().where(cls.model.kb_id == kb_id) - if desc: - docs = docs.order_by(cls.model.getter_by(orderby).desc()) - else: - docs = docs.order_by(cls.model.getter_by(orderby).asc()) - - docs = docs.paginate(page_number, items_per_page) - - return list(docs.dicts()) - - @classmethod - @DB.connection_context() - def insert(cls, doc): - if not cls.save(**doc): - raise RuntimeError("Database error (Document)!") - e, doc = cls.get_by_id(doc["id"]) - if not e: - raise RuntimeError("Database error (Document retrieval)!") - e, kb = KnowledgebaseService.get_by_id(doc.kb_id) - if not KnowledgebaseService.update_by_id( - kb.id, {"doc_num": kb.doc_num + 1}): - raise RuntimeError("Database error (Knowledgebase)!") - return doc - - @classmethod - @DB.connection_context() - def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64): - fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time] - docs = cls.model.select(*fields) \ - .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \ - .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ - .where( - cls.model.status == StatusEnum.VALID.value, - ~(cls.model.type == FileType.VIRTUAL.value), - cls.model.progress == 0, - cls.model.update_time >= tm, - (Expression(cls.model.create_time, "%%", comm) == mod))\ - .order_by(cls.model.update_time.asc())\ - .paginate(1, items_per_page) - return list(docs.dicts()) - - @classmethod - @DB.connection_context() - def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation): - num = cls.model.update(token_num=cls.model.token_num + token_num, - chunk_num=cls.model.chunk_num + chunk_num, - process_duation=cls.model.process_duation+duation).where( - cls.model.id == doc_id).execute() - if num == 0:raise LookupError("Document not found which is supposed to be there") - num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute() - return num - diff --git a/web_server/db/services/kb_service.py b/web_server/db/services/kb_service.py deleted file mode 100644 index a8ca96a2a..000000000 --- a/web_server/db/services/kb_service.py +++ /dev/null @@ -1,70 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import peewee -from werkzeug.security import generate_password_hash, check_password_hash - -from web_server.db import TenantPermission -from web_server.db.db_models import DB, UserTenant, Tenant -from web_server.db.db_models import Knowledgebase -from web_server.db.services.common_service import CommonService -from web_server.utils import get_uuid, get_format_time -from web_server.db.db_utils import StatusEnum - - -class KnowledgebaseService(CommonService): - model = Knowledgebase - - @classmethod - @DB.connection_context() - def get_by_tenant_ids(cls, joined_tenant_ids, user_id, - page_number, items_per_page, orderby, desc): - kbs = cls.model.select().where( - ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == - TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) - & (cls.model.status == StatusEnum.VALID.value) - ) - if desc: - kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) - else: - kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) - - kbs = kbs.paginate(page_number, items_per_page) - - return list(kbs.dicts()) - - @classmethod - @DB.connection_context() - def get_detail(cls, kb_id): - fields = [ - cls.model.id, - Tenant.embd_id, - cls.model.avatar, - cls.model.name, - cls.model.description, - cls.model.permission, - cls.model.doc_num, - cls.model.token_num, - cls.model.chunk_num, - cls.model.parser_id] - kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where( - (cls.model.id == kb_id), - (cls.model.status == StatusEnum.VALID.value) - ) - if not kbs: - return - d = kbs[0].to_dict() - d["embd_id"] = kbs[0].tenant.embd_id - return d diff --git a/web_server/db/services/knowledgebase_service.py b/web_server/db/services/knowledgebase_service.py deleted file mode 100644 index d5e8b34fe..000000000 --- a/web_server/db/services/knowledgebase_service.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import peewee -from werkzeug.security import generate_password_hash, check_password_hash - -from web_server.db.db_models import DB, UserTenant -from web_server.db.db_models import Knowledgebase, Document -from web_server.db.services.common_service import CommonService -from web_server.utils import get_uuid, get_format_time -from web_server.db.db_utils import StatusEnum - - -class KnowledgebaseService(CommonService): - model = Knowledgebase - - -class DocumentService(CommonService): - model = Document diff --git a/web_server/db/services/llm_service.py b/web_server/db/services/llm_service.py deleted file mode 100644 index 350106e36..000000000 --- a/web_server/db/services/llm_service.py +++ /dev/null @@ -1,53 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import peewee -from werkzeug.security import generate_password_hash, check_password_hash - -from web_server.db.db_models import DB, UserTenant -from web_server.db.db_models import LLMFactories, LLM, TenantLLM -from web_server.db.services.common_service import CommonService -from web_server.utils import get_uuid, get_format_time -from web_server.db.db_utils import StatusEnum - - -class LLMFactoriesService(CommonService): - model = LLMFactories - - -class LLMService(CommonService): - model = LLM - - -class TenantLLMService(CommonService): - model = TenantLLM - - @classmethod - @DB.connection_context() - def get_api_key(cls, tenant_id, model_type): - objs = cls.query(tenant_id=tenant_id, model_type=model_type) - if objs and len(objs)>0 and objs[0].llm_name: - return objs[0] - - fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key] - objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where( - (cls.model.tenant_id == tenant_id), - (cls.model.model_type == model_type), - (LLM.status == StatusEnum.VALID) - ) - - if not objs:return - return objs[0] - diff --git a/web_server/db/services/user_service.py b/web_server/db/services/user_service.py deleted file mode 100644 index f4ed4b58c..000000000 --- a/web_server/db/services/user_service.py +++ /dev/null @@ -1,105 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import peewee -from werkzeug.security import generate_password_hash, check_password_hash - -from web_server.db import UserTenantRole -from web_server.db.db_models import DB, UserTenant -from web_server.db.db_models import User, Tenant -from web_server.db.services.common_service import CommonService -from web_server.utils import get_uuid, get_format_time -from web_server.db.db_utils import StatusEnum - - -class UserService(CommonService): - model = User - - @classmethod - @DB.connection_context() - def filter_by_id(cls, user_id): - try: - user = cls.model.select().where(cls.model.id == user_id).get() - return user - except peewee.DoesNotExist: - return None - - @classmethod - @DB.connection_context() - def query_user(cls, email, password): - user = cls.model.select().where((cls.model.email == email), - (cls.model.status == StatusEnum.VALID.value)).first() - if user and check_password_hash(str(user.password), password): - return user - else: - return None - - @classmethod - @DB.connection_context() - def save(cls, **kwargs): - if "id" not in kwargs: - kwargs["id"] = get_uuid() - if "password" in kwargs: - kwargs["password"] = generate_password_hash(str(kwargs["password"])) - obj = cls.model(**kwargs).save(force_insert=True) - return obj - - - @classmethod - @DB.connection_context() - def delete_user(cls, user_ids, update_user_dict): - with DB.atomic(): - cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute() - - @classmethod - @DB.connection_context() - def update_user(cls, user_id, user_dict): - date_time = get_format_time() - with DB.atomic(): - if user_dict: - user_dict["update_time"] = date_time - cls.model.update(user_dict).where(cls.model.id == user_id).execute() - - -class TenantService(CommonService): - model = Tenant - - @classmethod - @DB.connection_context() - def get_by_user_id(cls, user_id): - fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role] - return list(cls.model.select(*fields)\ - .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ - .where(cls.model.status == StatusEnum.VALID.value).dicts()) - - @classmethod - @DB.connection_context() - def get_joined_tenants_by_user_id(cls, user_id): - fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] - return list(cls.model.select(*fields)\ - .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\ - .where(cls.model.status == StatusEnum.VALID.value).dicts()) - - -class UserTenantService(CommonService): - model = UserTenant - - @classmethod - @DB.connection_context() - def save(cls, **kwargs): - if "id" not in kwargs: - kwargs["id"] = get_uuid() - obj = cls.model(**kwargs).save(force_insert=True) - return obj diff --git a/web_server/errors/__init__.py b/web_server/errors/__init__.py deleted file mode 100644 index 358eb012b..000000000 --- a/web_server/errors/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .general_error import * - - -class FateFlowError(Exception): - message = 'Unknown Fate Flow Error' - - def __init__(self, message=None, *args, **kwargs): - message = str(message) if message is not None else self.message - message = message.format(*args, **kwargs) - super().__init__(message) \ No newline at end of file diff --git a/web_server/errors/error_services.py b/web_server/errors/error_services.py deleted file mode 100644 index f391a9188..000000000 --- a/web_server/errors/error_services.py +++ /dev/null @@ -1,13 +0,0 @@ -from web_server.errors import FateFlowError - -__all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured', - 'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError'] - - -class ServicesError(FateFlowError): - message = 'Unknown services error' - - -class ServiceNotSupported(ServicesError): - message = 'The service {service_name} is not supported' - diff --git a/web_server/errors/general_error.py b/web_server/errors/general_error.py deleted file mode 100644 index f4fd3fb88..000000000 --- a/web_server/errors/general_error.py +++ /dev/null @@ -1,21 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -class ParameterError(Exception): - pass - - -class PassError(Exception): - pass \ No newline at end of file diff --git a/web_server/flask_session/2029240f6d1128be89ddc32729463129 b/web_server/flask_session/2029240f6d1128be89ddc32729463129 deleted file mode 100644 index 60b84f8bf0af235343c89653c31a85c904ebfc66..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 9 QcmZQzU|?uq^=8lm00XQ5{{R30 diff --git a/web_server/hook/__init__.py b/web_server/hook/__init__.py deleted file mode 100644 index 3c21c0718..000000000 --- a/web_server/hook/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -import importlib - -from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, \ - SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters -from web_server.settings import HOOK_MODULE, stat_logger,RetCode - - -class HookManager: - SITE_SIGNATURE = [] - SITE_AUTHENTICATION = [] - CLIENT_AUTHENTICATION = [] - PERMISSION_CHECK = [] - - @staticmethod - def init(): - if HOOK_MODULE is not None: - for modules in HOOK_MODULE.values(): - for module in modules.split(";"): - try: - importlib.import_module(module) - except Exception as e: - stat_logger.exception(e) - - @staticmethod - def register_site_signature_hook(func): - HookManager.SITE_SIGNATURE.append(func) - - @staticmethod - def register_site_authentication_hook(func): - HookManager.SITE_AUTHENTICATION.append(func) - - @staticmethod - def register_client_authentication_hook(func): - HookManager.CLIENT_AUTHENTICATION.append(func) - - @staticmethod - def register_permission_check_hook(func): - HookManager.PERMISSION_CHECK.append(func) - - @staticmethod - def client_authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn: - if HookManager.CLIENT_AUTHENTICATION: - return HookManager.CLIENT_AUTHENTICATION[0](parm) - return ClientAuthenticationReturn() - - @staticmethod - def site_signature(parm: SignatureParameters) -> SignatureReturn: - if HookManager.SITE_SIGNATURE: - return HookManager.SITE_SIGNATURE[0](parm) - return SignatureReturn() - - @staticmethod - def site_authentication(parm: AuthenticationParameters) -> AuthenticationReturn: - if HookManager.SITE_AUTHENTICATION: - return HookManager.SITE_AUTHENTICATION[0](parm) - return AuthenticationReturn() - diff --git a/web_server/hook/api/client_authentication.py b/web_server/hook/api/client_authentication.py deleted file mode 100644 index 99e93892d..000000000 --- a/web_server/hook/api/client_authentication.py +++ /dev/null @@ -1,29 +0,0 @@ -import requests - -from web_server.db.service_registry import ServiceRegistry -from web_server.settings import RegistryServiceName -from web_server.hook import HookManager -from web_server.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn -from web_server.settings import HOOK_SERVER_NAME - - -@HookManager.register_client_authentication_hook -def authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn: - service_list = ServiceRegistry.load_service( - server_name=HOOK_SERVER_NAME, - service_name=RegistryServiceName.CLIENT_AUTHENTICATION.value - ) - if not service_list: - raise Exception(f"client authentication error: no found server" - f" {HOOK_SERVER_NAME} service client_authentication") - service = service_list[0] - response = getattr(requests, service.f_method.lower(), None)( - url=service.f_url, - json=parm.to_dict() - ) - if response.status_code != 200: - raise Exception( - f"client authentication error: request authentication url failed, status code {response.status_code}") - elif response.json().get("code") != 0: - return ClientAuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg")) - return ClientAuthenticationReturn() \ No newline at end of file diff --git a/web_server/hook/api/permission.py b/web_server/hook/api/permission.py deleted file mode 100644 index 318173d0e..000000000 --- a/web_server/hook/api/permission.py +++ /dev/null @@ -1,25 +0,0 @@ -import requests - -from web_server.db.service_registry import ServiceRegistry -from web_server.settings import RegistryServiceName -from web_server.hook import HookManager -from web_server.hook.common.parameters import PermissionCheckParameters, PermissionReturn -from web_server.settings import HOOK_SERVER_NAME - - -@HookManager.register_permission_check_hook -def permission(parm: PermissionCheckParameters) -> PermissionReturn: - service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.PERMISSION_CHECK.value) - if not service_list: - raise Exception(f"permission check error: no found server {HOOK_SERVER_NAME} service permission") - service = service_list[0] - response = getattr(requests, service.f_method.lower(), None)( - url=service.f_url, - json=parm.to_dict() - ) - if response.status_code != 200: - raise Exception( - f"permission check error: request permission url failed, status code {response.status_code}") - elif response.json().get("code") != 0: - return PermissionReturn(code=response.json().get("code"), message=response.json().get("msg")) - return PermissionReturn() diff --git a/web_server/hook/api/site_authentication.py b/web_server/hook/api/site_authentication.py deleted file mode 100644 index bea3b7788..000000000 --- a/web_server/hook/api/site_authentication.py +++ /dev/null @@ -1,49 +0,0 @@ -import requests - -from web_server.db.service_registry import ServiceRegistry -from web_server.settings import RegistryServiceName -from web_server.hook import HookManager -from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\ - SignatureReturn -from web_server.settings import HOOK_SERVER_NAME, PARTY_ID - - -@HookManager.register_site_signature_hook -def signature(parm: SignatureParameters) -> SignatureReturn: - service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.SIGNATURE.value) - if not service_list: - raise Exception(f"signature error: no found server {HOOK_SERVER_NAME} service signature") - service = service_list[0] - response = getattr(requests, service.f_method.lower(), None)( - url=service.f_url, - json=parm.to_dict() - ) - if response.status_code == 200: - if response.json().get("code") == 0: - return SignatureReturn(site_signature=response.json().get("data")) - else: - raise Exception(f"signature error: request signature url failed, result: {response.json()}") - else: - raise Exception(f"signature error: request signature url failed, status code {response.status_code}") - - -@HookManager.register_site_authentication_hook -def authentication(parm: AuthenticationParameters) -> AuthenticationReturn: - if not parm.src_party_id or str(parm.src_party_id) == "0": - parm.src_party_id = PARTY_ID - service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, - service_name=RegistryServiceName.SITE_AUTHENTICATION.value) - if not service_list: - raise Exception( - f"site authentication error: no found server {HOOK_SERVER_NAME} service site_authentication") - service = service_list[0] - response = getattr(requests, service.f_method.lower(), None)( - url=service.f_url, - json=parm.to_dict() - ) - if response.status_code != 200: - raise Exception( - f"site authentication error: request site_authentication url failed, status code {response.status_code}") - elif response.json().get("code") != 0: - return AuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg")) - return AuthenticationReturn() \ No newline at end of file diff --git a/web_server/hook/common/parameters.py b/web_server/hook/common/parameters.py deleted file mode 100644 index 40ce4ef19..000000000 --- a/web_server/hook/common/parameters.py +++ /dev/null @@ -1,56 +0,0 @@ -from web_server.settings import RetCode - - -class ParametersBase: - def to_dict(self): - d = {} - for k, v in self.__dict__.items(): - d[k] = v - return d - - -class ClientAuthenticationParameters(ParametersBase): - def __init__(self, full_path, headers, form, data, json): - self.full_path = full_path - self.headers = headers - self.form = form - self.data = data - self.json = json - - -class ClientAuthenticationReturn(ParametersBase): - def __init__(self, code=RetCode.SUCCESS, message="success"): - self.code = code - self.message = message - - -class SignatureParameters(ParametersBase): - def __init__(self, party_id, body): - self.party_id = party_id - self.body = body - - -class SignatureReturn(ParametersBase): - def __init__(self, code=RetCode.SUCCESS, site_signature=None): - self.code = code - self.site_signature = site_signature - - -class AuthenticationParameters(ParametersBase): - def __init__(self, site_signature, body): - self.site_signature = site_signature - self.body = body - - -class AuthenticationReturn(ParametersBase): - def __init__(self, code=RetCode.SUCCESS, message="success"): - self.code = code - self.message = message - - -class PermissionReturn(ParametersBase): - def __init__(self, code=RetCode.SUCCESS, message="success"): - self.code = code - self.message = message - - diff --git a/web_server/ragflow_server.py b/web_server/ragflow_server.py deleted file mode 100644 index 3bd9181d6..000000000 --- a/web_server/ragflow_server.py +++ /dev/null @@ -1,80 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# init env. must be the first import - -import logging -import os -import signal -import sys -import traceback - -from werkzeug.serving import run_simple - -from web_server.apps import app -from web_server.db.runtime_config import RuntimeConfig -from web_server.hook import HookManager -from web_server.settings import ( - HOST, HTTP_PORT, access_logger, database_logger, stat_logger, -) -from web_server import utils - -from web_server.db.db_models import init_database_tables as init_web_db -from web_server.db.init_data import init_web_data -from web_server.versions import get_versions - -if __name__ == '__main__': - stat_logger.info( - f'project base: {utils.file_utils.get_project_base_directory()}' - ) - - # init db - init_web_db() - init_web_data() - # init runtime config - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('--version', default=False, help="fate flow version", action='store_true') - parser.add_argument('--debug', default=False, help="debug mode", action='store_true') - args = parser.parse_args() - if args.version: - print(get_versions()) - sys.exit(0) - - RuntimeConfig.DEBUG = args.debug - if RuntimeConfig.DEBUG: - stat_logger.info("run on debug mode") - - RuntimeConfig.init_env() - RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT) - - HookManager.init() - - peewee_logger = logging.getLogger('peewee') - peewee_logger.propagate = False - # fate_arch.common.log.ROpenHandler - peewee_logger.addHandler(database_logger.handlers[0]) - peewee_logger.setLevel(database_logger.level) - - # start http server - try: - stat_logger.info("FATE Flow http server start...") - werkzeug_logger = logging.getLogger("werkzeug") - for h in access_logger.handlers: - werkzeug_logger.addHandler(h) - run_simple(hostname=HOST, port=HTTP_PORT, application=app, threaded=True, use_reloader=RuntimeConfig.DEBUG, use_debugger=RuntimeConfig.DEBUG) - except Exception: - traceback.print_exc() - os.kill(os.getpid(), signal.SIGKILL) \ No newline at end of file diff --git a/web_server/settings.py b/web_server/settings.py deleted file mode 100644 index a93efaa93..000000000 --- a/web_server/settings.py +++ /dev/null @@ -1,156 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import os - -from enum import IntEnum, Enum - -from web_server.utils import get_base_config,decrypt_database_config -from web_server.utils.file_utils import get_project_base_directory -from web_server.utils.log_utils import LoggerFactory, getLogger - - -# Server -API_VERSION = "v1" -FATE_FLOW_SERVICE_NAME = "ragflow" -SERVER_MODULE = "rag_flow_server.py" -TEMP_DIRECTORY = os.path.join(get_project_base_directory(), "temp") -FATE_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf") - -SUBPROCESS_STD_LOG_NAME = "std.log" - -ERROR_REPORT = True -ERROR_REPORT_WITH_PATH = False - -MAX_TIMESTAMP_INTERVAL = 60 -SESSION_VALID_PERIOD = 7 * 24 * 60 * 60 * 1000 - -REQUEST_TRY_TIMES = 3 -REQUEST_WAIT_SEC = 2 -REQUEST_MAX_WAIT_SEC = 300 - -USE_REGISTRY = get_base_config("use_registry") - -LLM = get_base_config("llm", {}) -CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo") -EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002") -ASR_MDL = LLM.get("asr_model", "whisper-1") -PARSERS = LLM.get("parsers", "General,Resume,Laws,Product Instructions,Books,Paper,Q&A,Programming Code,Power Point,Research Report") -IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview") - -# distribution -DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False) -FATE_FLOW_UPDATE_CHECK = False - -HOST = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") -HTTP_PORT = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("http_port") - -SECRET_KEY = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow") -TOKEN_EXPIRE_IN = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600) - -NGINX_HOST = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST -NGINX_HTTP_PORT = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT - -RANDOM_INSTANCE_ID = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("random_instance_id", False) - -PROXY = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("proxy") -PROXY_PROTOCOL = get_base_config(FATE_FLOW_SERVICE_NAME, {}).get("protocol") - -DATABASE = decrypt_database_config() - -# Logger -LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "web_server")) -# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} -LoggerFactory.LEVEL = 10 - -stat_logger = getLogger("stat") -access_logger = getLogger("access") -database_logger = getLogger("database") - -# Switch -# upload -UPLOAD_DATA_FROM_CLIENT = True - -# authentication -AUTHENTICATION_CONF = get_base_config("authentication", {}) - -# client -CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False) -HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") -GITHUB_OAUTH = get_base_config("oauth", {}).get("github") -WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat") - -# site -SITE_AUTHENTICATION = AUTHENTICATION_CONF.get("site", {}).get("switch", False) - -# permission -PERMISSION_CONF = get_base_config("permission", {}) -PERMISSION_SWITCH = PERMISSION_CONF.get("switch") -COMPONENT_PERMISSION = PERMISSION_CONF.get("component") -DATASET_PERMISSION = PERMISSION_CONF.get("dataset") - -HOOK_MODULE = get_base_config("hook_module") -HOOK_SERVER_NAME = get_base_config("hook_server_name") - -ENABLE_MODEL_STORE = get_base_config('enable_model_store', False) -# authentication -USE_AUTHENTICATION = False -USE_DATA_AUTHENTICATION = False -AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True -USE_DEFAULT_TIMEOUT = False -AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s -PRIVILEGE_COMMAND_WHITELIST = [] -CHECK_NODES_IDENTITY = False - -class CustomEnum(Enum): - @classmethod - def valid(cls, value): - try: - cls(value) - return True - except: - return False - - @classmethod - def values(cls): - return [member.value for member in cls.__members__.values()] - - @classmethod - def names(cls): - return [member.name for member in cls.__members__.values()] - - -class PythonDependenceName(CustomEnum): - Fate_Source_Code = "python" - Python_Env = "miniconda" - - -class ModelStorage(CustomEnum): - REDIS = "redis" - MYSQL = "mysql" - - -class RetCode(IntEnum, CustomEnum): - SUCCESS = 0 - NOT_EFFECTIVE = 10 - EXCEPTION_ERROR = 100 - ARGUMENT_ERROR = 101 - DATA_ERROR = 102 - OPERATING_ERROR = 103 - CONNECTION_ERROR = 105 - RUNNING = 106 - PERMISSION_ERROR = 108 - AUTHENTICATION_ERROR = 109 - SERVER_ERROR = 500 diff --git a/web_server/utils/__init__.py b/web_server/utils/__init__.py deleted file mode 100644 index 57c11ba1c..000000000 --- a/web_server/utils/__init__.py +++ /dev/null @@ -1,321 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import base64 -from datetime import datetime -import io -import json -import os -import pickle -import socket -import time -import uuid -import requests -from enum import Enum, IntEnum -import importlib -from Cryptodome.PublicKey import RSA -from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 - -from filelock import FileLock - -from . import file_utils - -SERVICE_CONF = "service_conf.yaml" - -def conf_realpath(conf_name): - conf_path = f"conf/{conf_name}" - return os.path.join(file_utils.get_project_base_directory(), conf_path) - -def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict: - local_config = {} - local_path = conf_realpath(f'local.{conf_name}') - if default is None: - default = os.environ.get(key.upper()) - - if os.path.exists(local_path): - local_config = file_utils.load_yaml_conf(local_path) - if not isinstance(local_config, dict): - raise ValueError(f'Invalid config file: "{local_path}".') - - if key is not None and key in local_config: - return local_config[key] - - config_path = conf_realpath(conf_name) - config = file_utils.load_yaml_conf(config_path) - - if not isinstance(config, dict): - raise ValueError(f'Invalid config file: "{config_path}".') - - config.update(local_config) - return config.get(key, default) if key is not None else config - - -use_deserialize_safe_module = get_base_config('use_deserialize_safe_module', False) - - -class CoordinationCommunicationProtocol(object): - HTTP = "http" - GRPC = "grpc" - - -class BaseType: - def to_dict(self): - return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()]) - - def to_dict_with_type(self): - def _dict(obj): - module = None - if issubclass(obj.__class__, BaseType): - data = {} - for attr, v in obj.__dict__.items(): - k = attr.lstrip("_") - data[k] = _dict(v) - module = obj.__module__ - elif isinstance(obj, (list, tuple)): - data = [] - for i, vv in enumerate(obj): - data.append(_dict(vv)) - elif isinstance(obj, dict): - data = {} - for _k, vv in obj.items(): - data[_k] = _dict(vv) - else: - data = obj - return {"type": obj.__class__.__name__, "data": data, "module": module} - return _dict(self) - - -class CustomJSONEncoder(json.JSONEncoder): - def __init__(self, **kwargs): - self._with_type = kwargs.pop("with_type", False) - super().__init__(**kwargs) - - def default(self, obj): - if isinstance(obj, datetime.datetime): - return obj.strftime('%Y-%m-%d %H:%M:%S') - elif isinstance(obj, datetime.date): - return obj.strftime('%Y-%m-%d') - elif isinstance(obj, datetime.timedelta): - return str(obj) - elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum): - return obj.value - elif isinstance(obj, set): - return list(obj) - elif issubclass(type(obj), BaseType): - if not self._with_type: - return obj.to_dict() - else: - return obj.to_dict_with_type() - elif isinstance(obj, type): - return obj.__name__ - else: - return json.JSONEncoder.default(self, obj) - - -def rag_uuid(): - return uuid.uuid1().hex - - -def string_to_bytes(string): - return string if isinstance(string, bytes) else string.encode(encoding="utf-8") - - -def bytes_to_string(byte): - return byte.decode(encoding="utf-8") - - -def json_dumps(src, byte=False, indent=None, with_type=False): - dest = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type) - if byte: - dest = string_to_bytes(dest) - return dest - - -def json_loads(src, object_hook=None, object_pairs_hook=None): - if isinstance(src, bytes): - src = bytes_to_string(src) - return json.loads(src, object_hook=object_hook, object_pairs_hook=object_pairs_hook) - - -def current_timestamp(): - return int(time.time() * 1000) - - -def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"): - if not timestamp: - timestamp = time.time() - timestamp = int(timestamp) / 1000 - time_array = time.localtime(timestamp) - str_date = time.strftime(format_string, time_array) - return str_date - - -def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"): - time_array = time.strptime(time_str, format_string) - time_stamp = int(time.mktime(time_array) * 1000) - return time_stamp - - -def serialize_b64(src, to_str=False): - dest = base64.b64encode(pickle.dumps(src)) - if not to_str: - return dest - else: - return bytes_to_string(dest) - - -def deserialize_b64(src): - src = base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src) - if use_deserialize_safe_module: - return restricted_loads(src) - return pickle.loads(src) - - -safe_module = { - 'numpy', - 'fate_flow' -} - - -class RestrictedUnpickler(pickle.Unpickler): - def find_class(self, module, name): - import importlib - if module.split('.')[0] in safe_module: - _module = importlib.import_module(module) - return getattr(_module, name) - # Forbid everything else. - raise pickle.UnpicklingError("global '%s.%s' is forbidden" % - (module, name)) - - -def restricted_loads(src): - """Helper function analogous to pickle.loads().""" - return RestrictedUnpickler(io.BytesIO(src)).load() - - -def get_lan_ip(): - if os.name != "nt": - import fcntl - import struct - - def get_interface_ip(ifname): - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - return socket.inet_ntoa( - fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24]) - - ip = socket.gethostbyname(socket.getfqdn()) - if ip.startswith("127.") and os.name != "nt": - interfaces = [ - "bond1", - "eth0", - "eth1", - "eth2", - "wlan0", - "wlan1", - "wifi0", - "ath0", - "ath1", - "ppp0", - ] - for ifname in interfaces: - try: - ip = get_interface_ip(ifname) - break - except IOError as e: - pass - return ip or '' - -def from_dict_hook(in_dict: dict): - if "type" in in_dict and "data" in in_dict: - if in_dict["module"] is None: - return in_dict["data"] - else: - return getattr(importlib.import_module(in_dict["module"]), in_dict["type"])(**in_dict["data"]) - else: - return in_dict - - -def decrypt_database_password(password): - encrypt_password = get_base_config("encrypt_password", False) - encrypt_module = get_base_config("encrypt_module", False) - private_key = get_base_config("private_key", None) - - if not password or not encrypt_password: - return password - - if not private_key: - raise ValueError("No private key") - - module_fun = encrypt_module.split("#") - pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1]) - - return pwdecrypt_fun(private_key, password) - - -def decrypt_database_config(database=None, passwd_key="passwd", name="database"): - if not database: - database = get_base_config(name, {}) - - database[passwd_key] = decrypt_database_password(database[passwd_key]) - return database - - -def update_config(key, value, conf_name=SERVICE_CONF): - conf_path = conf_realpath(conf_name=conf_name) - if not os.path.isabs(conf_path): - conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path) - - with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): - config = file_utils.load_yaml_conf(conf_path=conf_path) or {} - config[key] = value - file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config) - - -def get_uuid(): - return uuid.uuid1().hex - - -def datetime_format(date_time: datetime) -> datetime: - return datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second) - - -def get_format_time() -> datetime: - return datetime_format(datetime.now()) - - -def str2date(date_time: str): - return datetime.strptime(date_time, '%Y-%m-%d') - - -def elapsed2time(elapsed): - seconds = elapsed / 1000 - minuter, second = divmod(seconds, 60) - hour, minuter = divmod(minuter, 60) - return '%02d:%02d:%02d' % (hour, minuter, second) - - -def decrypt(line): - file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem") - rsa_key = RSA.importKey(open(file_path).read(), "Welcome") - cipher = Cipher_pkcs1_v1_5.new(rsa_key) - return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8') - - -def download_img(url): - if not url: return "" - response = requests.get(url) - return "data:" + \ - response.headers.get('Content-Type', 'image/jpg') + ";" + \ - "base64," + base64.b64encode(response.content).decode("utf-8") diff --git a/web_server/utils/api_utils.py b/web_server/utils/api_utils.py deleted file mode 100644 index 2933a0621..000000000 --- a/web_server/utils/api_utils.py +++ /dev/null @@ -1,212 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import json -import random -import time -from functools import wraps -from io import BytesIO -from flask import ( - Response, jsonify, send_file,make_response, - request as flask_request, -) -from werkzeug.http import HTTP_STATUS_CODES - -from web_server.utils import json_dumps -from web_server.versions import get_fate_version -from web_server.settings import RetCode -from web_server.settings import ( - REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, - stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY -) -import requests -import functools -from web_server.utils import CustomJSONEncoder -from uuid import uuid1 -from base64 import b64encode -from hmac import HMAC -from urllib.parse import quote, urlencode - - -requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) - - -def request(**kwargs): - sess = requests.Session() - stream = kwargs.pop('stream', sess.stream) - timeout = kwargs.pop('timeout', None) - kwargs['headers'] = {k.replace('_', '-').upper(): v for k, v in kwargs.get('headers', {}).items()} - prepped = requests.Request(**kwargs).prepare() - - if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY: - timestamp = str(round(time() * 1000)) - nonce = str(uuid1()) - signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([ - timestamp.encode('ascii'), - nonce.encode('ascii'), - HTTP_APP_KEY.encode('ascii'), - prepped.path_url.encode('ascii'), - prepped.body if kwargs.get('json') else b'', - urlencode(sorted(kwargs['data'].items()), quote_via=quote, safe='-._~').encode('ascii') - if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'', - ]), 'sha1').digest()).decode('ascii') - - prepped.headers.update({ - 'TIMESTAMP': timestamp, - 'NONCE': nonce, - 'APP-KEY': HTTP_APP_KEY, - 'SIGNATURE': signature, - }) - - return sess.send(prepped, stream=stream, timeout=timeout) - - -fate_version = get_fate_version() or '' - - -def get_exponential_backoff_interval(retries, full_jitter=False): - """Calculate the exponential backoff wait time.""" - # Will be zero if factor equals 0 - countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries)) - # Full jitter according to - # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ - if full_jitter: - countdown = random.randrange(countdown + 1) - # Adjust according to maximum wait time and account for negative values. - return max(0, countdown) - - -def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None): - import re - result_dict = { - "retcode": retcode, - "retmsg":retmsg, - # "retmsg": re.sub(r"fate", "seceum", retmsg, flags=re.IGNORECASE), - "data": data, - "jobId": job_id, - "meta": meta, - } - - response = {} - for key, value in result_dict.items(): - if value is None and key != "retcode": - continue - else: - response[key] = value - return jsonify(response) - -def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'): - import re - result_dict = {"retcode": retcode, "retmsg": re.sub(r"fate", "seceum", retmsg, flags=re.IGNORECASE)} - response = {} - for key, value in result_dict.items(): - if value is None and key != "retcode": - continue - else: - response[key] = value - return jsonify(response) - -def server_error_response(e): - stat_logger.exception(e) - try: - if e.code==401: - return get_json_result(retcode=401, retmsg=repr(e)) - except: - pass - if len(e.args) > 1: - return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1]) - return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e)) - - -def error_response(response_code, retmsg=None): - if retmsg is None: - retmsg = HTTP_STATUS_CODES.get(response_code, 'Unknown Error') - - return Response(json.dumps({ - 'retmsg': retmsg, - 'retcode': response_code, - }), status=response_code, mimetype='application/json') - - -def validate_request(*args, **kwargs): - def wrapper(func): - @wraps(func) - def decorated_function(*_args, **_kwargs): - input_arguments = flask_request.json or flask_request.form.to_dict() - no_arguments = [] - error_arguments = [] - for arg in args: - if arg not in input_arguments: - no_arguments.append(arg) - for k, v in kwargs.items(): - config_value = input_arguments.get(k, None) - if config_value is None: - no_arguments.append(k) - elif isinstance(v, (tuple, list)): - if config_value not in v: - error_arguments.append((k, set(v))) - elif config_value != v: - error_arguments.append((k, v)) - if no_arguments or error_arguments: - error_string = "" - if no_arguments: - error_string += "required argument are missing: {}; ".format(",".join(no_arguments)) - if error_arguments: - error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) - return get_json_result(retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string) - return func(*_args, **_kwargs) - return decorated_function - return wrapper - - -def is_localhost(ip): - return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'} - - -def send_file_in_mem(data, filename): - if not isinstance(data, (str, bytes)): - data = json_dumps(data) - if isinstance(data, str): - data = data.encode('utf-8') - - f = BytesIO() - f.write(data) - f.seek(0) - - return send_file(f, as_attachment=True, attachment_filename=filename) - - -def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None): - response = {"retcode": retcode, "retmsg": retmsg, "data": data} - return jsonify(response) - - -def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None): - result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data} - response_dict = {} - for key, value in result_dict.items(): - if value is None and key != "retcode": - continue - else: - response_dict[key] = value - response = make_response(jsonify(response_dict)) - if auth: - response.headers["Authorization"] = auth - response.headers["Access-Control-Allow-Origin"] = "*" - response.headers["Access-Control-Allow-Method"] = "*" - response.headers["Access-Control-Allow-Headers"] = "*" - response.headers["Access-Control-Allow-Headers"] = "*" - response.headers["Access-Control-Expose-Headers"] = "Authorization" - return response \ No newline at end of file diff --git a/web_server/utils/file_utils.py b/web_server/utils/file_utils.py deleted file mode 100644 index 442ab19bf..000000000 --- a/web_server/utils/file_utils.py +++ /dev/null @@ -1,153 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import json -import os -import re - -from cachetools import LRUCache, cached -from ruamel.yaml import YAML - -from web_server.db import FileType - -PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") -FATE_BASE = os.getenv("RAG_BASE") - -def get_project_base_directory(*args): - global PROJECT_BASE - if PROJECT_BASE is None: - PROJECT_BASE = os.path.abspath( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - os.pardir, - os.pardir, - ) - ) - - if args: - return os.path.join(PROJECT_BASE, *args) - return PROJECT_BASE - - -def get_fate_directory(*args): - global FATE_BASE - if FATE_BASE is None: - FATE_BASE = os.path.abspath( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - os.pardir, - os.pardir, - os.pardir, - ) - ) - if args: - return os.path.join(FATE_BASE, *args) - return FATE_BASE - - -def get_fate_python_directory(*args): - return get_fate_directory("python", *args) - - - -@cached(cache=LRUCache(maxsize=10)) -def load_json_conf(conf_path): - if os.path.isabs(conf_path): - json_conf_path = conf_path - else: - json_conf_path = os.path.join(get_project_base_directory(), conf_path) - try: - with open(json_conf_path) as f: - return json.load(f) - except BaseException: - raise EnvironmentError( - "loading json file config from '{}' failed!".format(json_conf_path) - ) - - -def dump_json_conf(config_data, conf_path): - if os.path.isabs(conf_path): - json_conf_path = conf_path - else: - json_conf_path = os.path.join(get_project_base_directory(), conf_path) - try: - with open(json_conf_path, "w") as f: - json.dump(config_data, f, indent=4) - except BaseException: - raise EnvironmentError( - "loading json file config from '{}' failed!".format(json_conf_path) - ) - - -def load_json_conf_real_time(conf_path): - if os.path.isabs(conf_path): - json_conf_path = conf_path - else: - json_conf_path = os.path.join(get_project_base_directory(), conf_path) - try: - with open(json_conf_path) as f: - return json.load(f) - except BaseException: - raise EnvironmentError( - "loading json file config from '{}' failed!".format(json_conf_path) - ) - - -def load_yaml_conf(conf_path): - if not os.path.isabs(conf_path): - conf_path = os.path.join(get_project_base_directory(), conf_path) - try: - with open(conf_path) as f: - yaml = YAML(typ='safe', pure=True) - return yaml.load(f) - except Exception as e: - raise EnvironmentError( - "loading yaml file config from {} failed:".format(conf_path), e - ) - - -def rewrite_yaml_conf(conf_path, config): - if not os.path.isabs(conf_path): - conf_path = os.path.join(get_project_base_directory(), conf_path) - try: - with open(conf_path, "w") as f: - yaml = YAML(typ="safe") - yaml.dump(config, f) - except Exception as e: - raise EnvironmentError( - "rewrite yaml file config {} failed:".format(conf_path), e - ) - - -def rewrite_json_file(filepath, json_data): - with open(filepath, "w") as f: - json.dump(json_data, f, indent=4, separators=(",", ": ")) - f.close() - - -def filename_type(filename): - filename = filename.lower() - if re.match(r".*\.pdf$", filename): - return FileType.PDF.value - - if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): - return FileType.DOC.value - - if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): - return FileType.AURAL.value - - if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename): - return FileType.VISUAL \ No newline at end of file diff --git a/web_server/utils/log_utils.py b/web_server/utils/log_utils.py deleted file mode 100644 index 5efe3c817..000000000 --- a/web_server/utils/log_utils.py +++ /dev/null @@ -1,299 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import os -import typing -import traceback -import logging -import inspect -from logging.handlers import TimedRotatingFileHandler -from threading import RLock - -from web_server.utils import file_utils - -class LoggerFactory(object): - TYPE = "FILE" - LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s" - LEVEL = logging.DEBUG - logger_dict = {} - global_handler_dict = {} - - LOG_DIR = None - PARENT_LOG_DIR = None - log_share = True - - append_to_parent_log = None - - lock = RLock() - # CRITICAL = 50 - # FATAL = CRITICAL - # ERROR = 40 - # WARNING = 30 - # WARN = WARNING - # INFO = 20 - # DEBUG = 10 - # NOTSET = 0 - levels = (10, 20, 30, 40) - schedule_logger_dict = {} - - @staticmethod - def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False): - if parent_log_dir: - LoggerFactory.PARENT_LOG_DIR = parent_log_dir - if append_to_parent_log: - LoggerFactory.append_to_parent_log = append_to_parent_log - with LoggerFactory.lock: - if not directory: - directory = file_utils.get_project_base_directory("logs") - if not LoggerFactory.LOG_DIR or force: - LoggerFactory.LOG_DIR = directory - if LoggerFactory.log_share: - oldmask = os.umask(000) - os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True) - os.umask(oldmask) - else: - os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True) - for loggerName, ghandler in LoggerFactory.global_handler_dict.items(): - for className, (logger, handler) in LoggerFactory.logger_dict.items(): - logger.removeHandler(ghandler) - ghandler.close() - LoggerFactory.global_handler_dict = {} - for className, (logger, handler) in LoggerFactory.logger_dict.items(): - logger.removeHandler(handler) - _handler = None - if handler: - handler.close() - if className != "default": - _handler = LoggerFactory.get_handler(className) - logger.addHandler(_handler) - LoggerFactory.assemble_global_handler(logger) - LoggerFactory.logger_dict[className] = logger, _handler - - @staticmethod - def new_logger(name): - logger = logging.getLogger(name) - logger.propagate = False - logger.setLevel(LoggerFactory.LEVEL) - return logger - - @staticmethod - def get_logger(class_name=None): - with LoggerFactory.lock: - if class_name in LoggerFactory.logger_dict.keys(): - logger, handler = LoggerFactory.logger_dict[class_name] - if not logger: - logger, handler = LoggerFactory.init_logger(class_name) - else: - logger, handler = LoggerFactory.init_logger(class_name) - return logger - - @staticmethod - def get_global_handler(logger_name, level=None, log_dir=None): - if not LoggerFactory.LOG_DIR: - return logging.StreamHandler() - if log_dir: - logger_name_key = logger_name + "_" + log_dir - else: - logger_name_key = logger_name + "_" + LoggerFactory.LOG_DIR - # if loggerName not in LoggerFactory.globalHandlerDict: - if logger_name_key not in LoggerFactory.global_handler_dict: - with LoggerFactory.lock: - if logger_name_key not in LoggerFactory.global_handler_dict: - handler = LoggerFactory.get_handler(logger_name, level, log_dir) - LoggerFactory.global_handler_dict[logger_name_key] = handler - return LoggerFactory.global_handler_dict[logger_name_key] - - @staticmethod - def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None): - if not log_type: - if not LoggerFactory.LOG_DIR or not class_name: - return logging.StreamHandler() - # return Diy_StreamHandler() - - if not log_dir: - log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name)) - else: - log_file = os.path.join(log_dir, "{}.log".format(class_name)) - else: - log_file = os.path.join(log_dir, "fate_flow_{}.log".format( - log_type) if level == LoggerFactory.LEVEL else 'fate_flow_{}_error.log'.format(log_type)) - job_id = job_id or os.getenv("FATE_JOB_ID") - if job_id: - formatter = logging.Formatter(LoggerFactory.LOG_FORMAT.replace("jobId", job_id)) - else: - formatter = logging.Formatter(LoggerFactory.LOG_FORMAT.replace("jobId", "Server")) - os.makedirs(os.path.dirname(log_file), exist_ok=True) - if LoggerFactory.log_share: - handler = ROpenHandler(log_file, - when='D', - interval=1, - backupCount=14, - delay=True) - else: - handler = TimedRotatingFileHandler(log_file, - when='D', - interval=1, - backupCount=14, - delay=True) - if level: - handler.level = level - - handler.setFormatter(formatter) - return handler - - @staticmethod - def init_logger(class_name): - with LoggerFactory.lock: - logger = LoggerFactory.new_logger(class_name) - handler = None - if class_name: - handler = LoggerFactory.get_handler(class_name) - logger.addHandler(handler) - LoggerFactory.logger_dict[class_name] = logger, handler - - else: - LoggerFactory.logger_dict["default"] = logger, handler - - LoggerFactory.assemble_global_handler(logger) - return logger, handler - - @staticmethod - def assemble_global_handler(logger): - if LoggerFactory.LOG_DIR: - for level in LoggerFactory.levels: - if level >= LoggerFactory.LEVEL: - level_logger_name = logging._levelToName[level] - logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level)) - if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR: - for level in LoggerFactory.levels: - if level >= LoggerFactory.LEVEL: - level_logger_name = logging._levelToName[level] - logger.addHandler( - LoggerFactory.get_global_handler(level_logger_name, level, LoggerFactory.PARENT_LOG_DIR)) - - -def setDirectory(directory=None): - LoggerFactory.set_directory(directory) - - -def setLevel(level): - LoggerFactory.LEVEL = level - - -def getLogger(className=None, useLevelFile=False): - if className is None: - frame = inspect.stack()[1] - module = inspect.getmodule(frame[0]) - className = 'stat' - return LoggerFactory.get_logger(className) - - -def exception_to_trace_string(ex): - return "".join(traceback.TracebackException.from_exception(ex).format()) - - -class ROpenHandler(TimedRotatingFileHandler): - def _open(self): - prevumask = os.umask(000) - rtv = TimedRotatingFileHandler._open(self) - os.umask(prevumask) - return rtv - - -def sql_logger(job_id='', log_type='sql'): - key = job_id + log_type - if key in LoggerFactory.schedule_logger_dict.keys(): - return LoggerFactory.schedule_logger_dict[key] - return get_job_logger(job_id=job_id, log_type=log_type) - - -def ready_log(msg, job=None, task=None, role=None, party_id=None, detail=None): - prefix, suffix = base_msg(job, task, role, party_id, detail) - return f"{prefix}{msg} ready{suffix}" - - -def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None): - prefix, suffix = base_msg(job, task, role, party_id, detail) - return f"{prefix}start to {msg}{suffix}" - - -def successful_log(msg, job=None, task=None, role=None, party_id=None, detail=None): - prefix, suffix = base_msg(job, task, role, party_id, detail) - return f"{prefix}{msg} successfully{suffix}" - - -def warning_log(msg, job=None, task=None, role=None, party_id=None, detail=None): - prefix, suffix = base_msg(job, task, role, party_id, detail) - return f"{prefix}{msg} is not effective{suffix}" - - -def failed_log(msg, job=None, task=None, role=None, party_id=None, detail=None): - prefix, suffix = base_msg(job, task, role, party_id, detail) - return f"{prefix}failed to {msg}{suffix}" - - -def base_msg(job=None, task=None, role: str = None, party_id: typing.Union[str, int] = None, detail=None): - if detail: - detail_msg = f" detail: \n{detail}" - else: - detail_msg = "" - if task is not None: - return f"task {task.f_task_id} {task.f_task_version} ", f" on {task.f_role} {task.f_party_id}{detail_msg}" - elif job is not None: - return "", f" on {job.f_role} {job.f_party_id}{detail_msg}" - elif role and party_id: - return "", f" on {role} {party_id}{detail_msg}" - else: - return "", f"{detail_msg}" - - -def exception_to_trace_string(ex): - return "".join(traceback.TracebackException.from_exception(ex).format()) - - -def get_logger_base_dir(): - job_log_dir = file_utils.get_fate_flow_directory('logs') - return job_log_dir - - -def get_job_logger(job_id, log_type): - fate_flow_log_dir = file_utils.get_fate_flow_directory('logs', 'fate_flow') - job_log_dir = file_utils.get_fate_flow_directory('logs', job_id) - if not job_id: - log_dirs = [fate_flow_log_dir] - else: - if log_type == 'audit': - log_dirs = [job_log_dir, fate_flow_log_dir] - else: - log_dirs = [job_log_dir] - if LoggerFactory.log_share: - oldmask = os.umask(000) - os.makedirs(job_log_dir, exist_ok=True) - os.makedirs(fate_flow_log_dir, exist_ok=True) - os.umask(oldmask) - else: - os.makedirs(job_log_dir, exist_ok=True) - os.makedirs(fate_flow_log_dir, exist_ok=True) - logger = LoggerFactory.new_logger(f"{job_id}_{log_type}") - for job_log_dir in log_dirs: - handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL, - log_dir=job_log_dir, log_type=log_type, job_id=job_id) - error_handler = LoggerFactory.get_handler(class_name=None, level=logging.ERROR, log_dir=job_log_dir, log_type=log_type, job_id=job_id) - logger.addHandler(handler) - logger.addHandler(error_handler) - with LoggerFactory.lock: - LoggerFactory.schedule_logger_dict[job_id + log_type] = logger - return logger - diff --git a/web_server/utils/t_crypt.py b/web_server/utils/t_crypt.py deleted file mode 100644 index 1d007f49c..000000000 --- a/web_server/utils/t_crypt.py +++ /dev/null @@ -1,18 +0,0 @@ -import base64, os, sys -from Cryptodome.PublicKey import RSA -from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 -from web_server.utils import decrypt, file_utils - -def crypt(line): - file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem") - rsa_key = RSA.importKey(open(file_path).read()) - cipher = Cipher_pkcs1_v1_5.new(rsa_key) - return base64.b64encode(cipher.encrypt(line.encode('utf-8'))).decode("utf-8") - - - -if __name__ == "__main__": - pswd = crypt(sys.argv[1]) - print(pswd) - print(decrypt(pswd)) - diff --git a/web_server/versions.py b/web_server/versions.py deleted file mode 100644 index e6e3cc135..000000000 --- a/web_server/versions.py +++ /dev/null @@ -1,30 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import os - -import dotenv -import typing - -from web_server.utils.file_utils import get_project_base_directory - - -def get_versions() -> typing.Mapping[str, typing.Any]: - return dotenv.dotenv_values( - dotenv_path=os.path.join(get_project_base_directory(), "rag.env") - ) - -def get_fate_version() -> typing.Optional[str]: - return get_versions().get("RAG") \ No newline at end of file