From 751b12c76a491a3cd969ba4d510057a91862a29a Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Mon, 15 Jan 2024 19:47:25 +0800 Subject: [PATCH] add alot of api (#23) * clean rust version project * clean rust version project * build python version rag-flow * add alot of api --- rag/llm/embedding_model.py | 2 +- rag/nlp/huchunk.py | 9 +- rag/nlp/search.py | 2 +- rag/ppdet/benchmark_utils.py | 3 +- rag/ppdet/det_keypoint_unite_infer.py | 6 +- rag/ppdet/infer.py | 7 +- rag/ppdet/keypoint_infer.py | 23 +- rag/ppdet/keypoint_postprocess.py | 2 + rag/ppdet/mot_centertrack_infer.py | 76 +-- rag/ppdet/mot_jde_infer.py | 77 +-- rag/ppdet/mot_keypoint_unite_infer.py | 32 +- rag/ppdet/mot_sde_infer.py | 522 +++++++++++++++++++++ rag/svr/parse_user_docs.py | 54 ++- rag/utils/__init__.py | 19 + rag/utils/es_conn.py | 1 + web_server/apps/document_app.py | 49 +- web_server/apps/kb_app.py | 16 +- web_server/apps/llm_app.py | 95 ++++ web_server/apps/user_app.py | 38 +- web_server/db/db_models.py | 11 +- web_server/db/services/document_service.py | 40 +- web_server/db/services/kb_service.py | 39 +- web_server/db/services/llm_service.py | 18 + web_server/db/services/user_service.py | 2 +- web_server/utils/file_utils.py | 2 +- 25 files changed, 1001 insertions(+), 144 deletions(-) create mode 100644 rag/ppdet/mot_sde_infer.py create mode 100644 web_server/apps/llm_app.py diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 551f9c60a..e148560c2 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -35,7 +35,7 @@ class Base(ABC): class HuEmbedding(Base): - def __init__(self): + def __init__(self, key="", model_name=""): """ If you have trouble downloading HuggingFace models, -_^ this might help!! 
diff --git a/rag/nlp/huchunk.py b/rag/nlp/huchunk.py index fb28a915c..cc93f5faf 100644 --- a/rag/nlp/huchunk.py +++ b/rag/nlp/huchunk.py @@ -411,9 +411,12 @@ class TextChunker(HuChunker): flds = self.Fields() if self.is_binary_file(fnm): return flds - with open(fnm, "r") as f: - txt = f.read() - flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] + txt = "" + if isinstance(fnm, str): + with open(fnm, "r") as f: + txt = f.read() + else: txt = fnm.decode("utf-8") + flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] flds.table_chunks = [] return flds diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 05ce6276f..d79640b4c 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -8,7 +8,7 @@ from rag.nlp import huqie, query import numpy as np -def index_name(uid): return f"docgpt_{uid}" +def index_name(uid): return f"ragflow_{uid}" class Dealer: diff --git a/rag/ppdet/benchmark_utils.py b/rag/ppdet/benchmark_utils.py index 5894caf3f..118e5efe1 100644 --- a/rag/ppdet/benchmark_utils.py +++ b/rag/ppdet/benchmark_utils.py @@ -158,8 +158,9 @@ class PaddleInferBenchmark(object): return: config_status(dict): dict style config info """ - config_status = {} + if isinstance(config, paddle_infer.Config): + config_status = {} config_status['runtime_device'] = "gpu" if config.use_gpu( ) else "cpu" config_status['ir_optim'] = config.ir_optim() diff --git a/rag/ppdet/det_keypoint_unite_infer.py b/rag/ppdet/det_keypoint_unite_infer.py index 65f573d60..1ed93cd89 100644 --- a/rag/ppdet/det_keypoint_unite_infer.py +++ b/rag/ppdet/det_keypoint_unite_infer.py @@ -22,9 +22,11 @@ import yaml from det_keypoint_unite_utils import argsparser from preprocess import decode_image -from infer import print_arguments, get_test_images, bench_log -from keypoint_infer import KeyPointDetector +from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images, bench_log +from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint 
from visualize import visualize_pose +from benchmark_utils import PaddleInferBenchmark + from utils import get_current_memory_mb from keypoint_postprocess import translate_to_ori_images diff --git a/rag/ppdet/infer.py b/rag/ppdet/infer.py index d132ceb25..fcdc8a2c3 100644 --- a/rag/ppdet/infer.py +++ b/rag/ppdet/infer.py @@ -17,6 +17,8 @@ import yaml import glob import json from pathlib import Path +from functools import reduce + import cv2 import numpy as np @@ -26,12 +28,15 @@ from paddle.inference import Config from paddle.inference import create_predictor import sys +# add deploy path of PaddleDetection to sys.path + parent_path = os.path.abspath(os.path.join(__file__, *(['..']))) sys.path.insert(0, parent_path) from benchmark_utils import PaddleInferBenchmark from picodet_postprocess import PicoDetPostProcess -from preprocess import preprocess +from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize +from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop from clrnet_postprocess import CLRNetPostProcess from visualize import visualize_box_mask, imshow_lanes from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid diff --git a/rag/ppdet/keypoint_infer.py b/rag/ppdet/keypoint_infer.py index eb4dea8ea..2c80ff484 100644 --- a/rag/ppdet/keypoint_infer.py +++ b/rag/ppdet/keypoint_infer.py @@ -11,21 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# + import os -# add deploy path of PaddleDetection to sys.path -parent_path = os.path.abspath(os.path.join(__file__, *(['..']))) -import sys -sys.path.insert(0, parent_path) +import time import yaml +import glob +from functools import reduce + +from PIL import Image + import cv2 import math import numpy as np import paddle -from keypoint_preprocess import expand_crop +import sys +# add deploy path of PaddleDetection to sys.path +parent_path = os.path.abspath(os.path.join(__file__, *(['..']))) +sys.path.insert(0, parent_path) + +from preprocess import preprocess, NormalizeImage, Permute +from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess from visualize import visualize_pose +from paddle.inference import Config +from paddle.inference import create_predictor + from utils import argsparser, Timer, get_current_memory_mb from benchmark_utils import PaddleInferBenchmark from infer import Detector, get_test_images, print_arguments diff --git a/rag/ppdet/keypoint_postprocess.py b/rag/ppdet/keypoint_postprocess.py index 4a930676b..69f1d3fd9 100644 --- a/rag/ppdet/keypoint_postprocess.py +++ b/rag/ppdet/keypoint_postprocess.py @@ -17,6 +17,8 @@ from collections import abc, defaultdict import cv2 import numpy as np import math +import paddle +import paddle.nn as nn from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform diff --git a/rag/ppdet/mot_centertrack_infer.py b/rag/ppdet/mot_centertrack_infer.py index 9caf8bbd3..af28d6403 100644 --- a/rag/ppdet/mot_centertrack_infer.py +++ b/rag/ppdet/mot_centertrack_infer.py @@ -15,16 +15,18 @@ import os import copy import math +import time +import yaml import cv2 import numpy as np from collections import defaultdict import paddle -from rag.ppdet import MOTTimer -from utils import gaussian_radius, draw_umich_gaussian -from preprocess import preprocess, decode_image +from benchmark_utils import PaddleInferBenchmark +from 
utils import gaussian_radius, gaussian2D, draw_umich_gaussian +from preprocess import preprocess, decode_image, WarpAffine, NormalizeImage, Permute from utils import argsparser, Timer, get_current_memory_mb -from infer import Detector, get_test_images, print_arguments, bench_log +from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig from keypoint_preprocess import get_affine_transform # add python path @@ -32,6 +34,9 @@ import sys parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) sys.path.insert(0, parent_path) +from pptracking.python.mot import CenterTracker +from pptracking.python.mot.utils import MOTTimer, write_mot_results +from pptracking.python.mot.visualize import plot_tracking def transform_preds_with_trans(coords, trans): @@ -119,12 +124,12 @@ class CenterTrack(Detector): track_thresh = cfg.get('track_thresh', 0.4) pre_thresh = cfg.get('pre_thresh', 0.5) - # self.tracker = CenterTracker( - # num_classes=self.num_classes, - # min_box_area=min_box_area, - # vertical_ratio=vertical_ratio, - # track_thresh=track_thresh, - # pre_thresh=pre_thresh) + self.tracker = CenterTracker( + num_classes=self.num_classes, + min_box_area=min_box_area, + vertical_ratio=vertical_ratio, + track_thresh=track_thresh, + pre_thresh=pre_thresh) self.pre_image = None @@ -359,20 +364,21 @@ class CenterTrack(Detector): print('Tracking frame {}'.format(frame_id)) frame, _ = decode_image(img_file, {}) - # im = plot_tracking( - # frame, - # online_tlwhs, - # online_ids, - # online_scores, - # frame_id=frame_id, - # ids2names=ids2names) + im = plot_tracking( + frame, + online_tlwhs, + online_ids, + online_scores, + frame_id=frame_id, + ids2names=ids2names) + if seq_name is None: seq_name = image_list[0].split('/')[-2] save_dir = os.path.join(self.output_dir, seq_name) if not os.path.exists(save_dir): os.makedirs(save_dir) - # cv2.imwrite( - # os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im) + cv2.imwrite( + 
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im) mot_results.append([online_tlwhs, online_scores, online_ids]) return mot_results @@ -422,27 +428,27 @@ class CenterTrack(Detector): online_tlwhs, online_scores, online_ids = mot_results[0] results[0].append( (frame_id + 1, online_tlwhs, online_scores, online_ids)) - # im = plot_tracking( - # frame, - # online_tlwhs, - # online_ids, - # online_scores, - # frame_id=frame_id, - # fps=fps, - # ids2names=ids2names) - # - # writer.write(im) - # if camera_id != -1: - # cv2.imshow('Mask Detection', im) - # if cv2.waitKey(1) & 0xFF == ord('q'): - # break + im = plot_tracking( + frame, + online_tlwhs, + online_ids, + online_scores, + frame_id=frame_id, + fps=fps, + ids2names=ids2names) + + writer.write(im) + if camera_id != -1: + cv2.imshow('Mask Detection', im) + if cv2.waitKey(1) & 0xFF == ord('q'): + break + if self.save_mot_txts: result_filename = os.path.join( self.output_dir, video_out_name.split('.')[-2] + '.txt') - #write_mot_results(result_filename, results, data_type, num_classes) - + write_mot_results(result_filename, results, data_type, num_classes) writer.release() diff --git a/rag/ppdet/mot_jde_infer.py b/rag/ppdet/mot_jde_infer.py index 71f76c176..4d1e6fe82 100644 --- a/rag/ppdet/mot_jde_infer.py +++ b/rag/ppdet/mot_jde_infer.py @@ -13,6 +13,8 @@ # limitations under the License. 
import os +import time +import yaml import cv2 import numpy as np from collections import defaultdict @@ -20,7 +22,6 @@ import paddle from benchmark_utils import PaddleInferBenchmark from preprocess import decode_image -from rag.ppdet import MOTTimer from utils import argsparser, Timer, get_current_memory_mb from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig @@ -29,6 +30,9 @@ import sys parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) sys.path.insert(0, parent_path) +from pptracking.python.mot import JDETracker +from pptracking.python.mot.utils import MOTTimer, write_mot_results +from pptracking.python.mot.visualize import plot_tracking_dict # Global dictionary MOT_JDE_SUPPORT_MODELS = { @@ -102,13 +106,14 @@ class JDE_Detector(Detector): tracked_thresh = cfg.get('tracked_thresh', 0.7) metric_type = cfg.get('metric_type', 'euclidean') - # self.tracker = JDETracker( - # num_classes=self.num_classes, - # min_box_area=min_box_area, - # vertical_ratio=vertical_ratio, - # conf_thres=conf_thres, - # tracked_thresh=tracked_thresh, - # metric_type=metric_type) + self.tracker = JDETracker( + num_classes=self.num_classes, + min_box_area=min_box_area, + vertical_ratio=vertical_ratio, + conf_thres=conf_thres, + tracked_thresh=tracked_thresh, + metric_type=metric_type) + def postprocess(self, inputs, result): # postprocess output of predictor @@ -235,21 +240,23 @@ class JDE_Detector(Detector): print('Tracking frame {}'.format(frame_id)) frame, _ = decode_image(img_file, {}) - # im = plot_tracking_dict( - # frame, - # num_classes, - # online_tlwhs, - # online_ids, - # online_scores, - # frame_id=frame_id, - # ids2names=ids2names) + im = plot_tracking_dict( + frame, + num_classes, + online_tlwhs, + online_ids, + online_scores, + frame_id=frame_id, + ids2names=ids2names) + if seq_name is None: seq_name = image_list[0].split('/')[-2] save_dir = os.path.join(self.output_dir, seq_name) if not os.path.exists(save_dir): 
os.makedirs(save_dir) - # cv2.imwrite( - # os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im) + cv2.imwrite( + os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im) + mot_results.append([online_tlwhs, online_scores, online_ids]) return mot_results @@ -301,28 +308,28 @@ class JDE_Detector(Detector): (frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id], online_ids[cls_id])) - #fps = 1. / timer.duration - # im = plot_tracking_dict( - # frame, - # num_classes, - # online_tlwhs, - # online_ids, - # online_scores, - # frame_id=frame_id, - # fps=fps, - # ids2names=ids2names) - # - # writer.write(im) - # if camera_id != -1: - # cv2.imshow('Mask Detection', im) - # if cv2.waitKey(1) & 0xFF == ord('q'): - # break + fps = 1. / timer.duration + im = plot_tracking_dict( + frame, + num_classes, + online_tlwhs, + online_ids, + online_scores, + frame_id=frame_id, + fps=fps, + ids2names=ids2names) + + writer.write(im) + if camera_id != -1: + cv2.imshow('Mask Detection', im) + if cv2.waitKey(1) & 0xFF == ord('q'): + break if self.save_mot_txts: result_filename = os.path.join( self.output_dir, video_out_name.split('.')[-2] + '.txt') - #write_mot_results(result_filename, results, data_type, num_classes) + write_mot_results(result_filename, results, data_type, num_classes) writer.release() diff --git a/rag/ppdet/mot_keypoint_unite_infer.py b/rag/ppdet/mot_keypoint_unite_infer.py index 70b10394d..d129ac73b 100644 --- a/rag/ppdet/mot_keypoint_unite_infer.py +++ b/rag/ppdet/mot_keypoint_unite_infer.py @@ -11,26 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import collections + import os import json -import time - import cv2 +import math + import numpy as np import paddle import yaml import copy +from collections import defaultdict from mot_keypoint_unite_utils import argsparser from preprocess import decode_image from infer import print_arguments, get_test_images, bench_log -from mot_jde_infer import MOT_JDE_SUPPORT_MODELS +from mot_sde_infer import SDE_Detector +from mot_jde_infer import JDE_Detector, MOT_JDE_SUPPORT_MODELS from keypoint_infer import KeyPointDetector, KEYPOINT_SUPPORT_MODELS from det_keypoint_unite_infer import predict_with_given_det -from rag.ppdet import MOTTimer from visualize import visualize_pose +from benchmark_utils import PaddleInferBenchmark from utils import get_current_memory_mb +from keypoint_postprocess import translate_to_ori_images + +# add python path +import sys +parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) +sys.path.insert(0, parent_path) + +from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict +from pptracking.python.mot.utils import MOTTimer as FPSTimer def convert_mot_to_det(tlwhs, scores): @@ -140,7 +151,7 @@ def mot_topdown_unite_predict_video(mot_detector, fourcc = cv2.VideoWriter_fourcc(* 'mp4v') writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) frame_id = 0 - timer_mot, timer_kp, timer_mot_kp = MOTTimer(), MOTTimer(), MOTTimer() + timer_mot, timer_kp, timer_mot_kp = FPSTimer(), FPSTimer(), FPSTimer() num_classes = mot_detector.num_classes assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.' 
@@ -187,6 +198,15 @@ def mot_topdown_unite_predict_video(mot_detector, returnimg=True, ids=online_ids[0]) + im = plot_tracking_dict( + im, + num_classes, + online_tlwhs, + online_ids, + online_scores, + frame_id=frame_id, + fps=mot_kp_fps) + writer.write(im) if camera_id != -1: cv2.imshow('Tracking and keypoint results', im) diff --git a/rag/ppdet/mot_sde_infer.py b/rag/ppdet/mot_sde_infer.py new file mode 100644 index 000000000..acfc940d5 --- /dev/null +++ b/rag/ppdet/mot_sde_infer.py @@ -0,0 +1,522 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import time +import yaml +import cv2 +import numpy as np +from collections import defaultdict +import paddle + +from benchmark_utils import PaddleInferBenchmark +from preprocess import decode_image +from utils import argsparser, Timer, get_current_memory_mb +from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor + +# add python path +import sys +parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) +sys.path.insert(0, parent_path) + +from pptracking.python.mot import JDETracker, DeepSORTTracker +from pptracking.python.mot.utils import MOTTimer, write_mot_results, get_crops, clip_box +from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict + + +class SDE_Detector(Detector): + """ + Args: + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + tracker_config (str): tracker config path + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU + run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) + batch_size (int): size of pre batch in inference + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to open MKLDNN + output_dir (string): The path of output, default as 'output' + threshold (float): Score threshold of the detected bbox, default as 0.5 + save_images (bool): Whether to save visualization image results, default as False + save_mot_txts (bool): Whether to save tracking results (txt), default as False + reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT + """ + + def __init__(self, + model_dir, + tracker_config, + device='CPU', + run_mode='paddle', + 
batch_size=1, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False, + output_dir='output', + threshold=0.5, + save_images=False, + save_mot_txts=False, + reid_model_dir=None): + super(SDE_Detector, self).__init__( + model_dir=model_dir, + device=device, + run_mode=run_mode, + batch_size=batch_size, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn, + output_dir=output_dir, + threshold=threshold, ) + self.save_images = save_images + self.save_mot_txts = save_mot_txts + assert batch_size == 1, "MOT model only supports batch_size=1." + self.det_times = Timer(with_tracker=True) + self.num_classes = len(self.pred_config.labels) + + # reid config + self.use_reid = False if reid_model_dir is None else True + if self.use_reid: + self.reid_pred_config = self.set_config(reid_model_dir) + self.reid_predictor, self.config = load_predictor( + reid_model_dir, + run_mode=run_mode, + batch_size=50, # reid_batch_size + min_subgraph_size=self.reid_pred_config.min_subgraph_size, + device=device, + use_dynamic_shape=self.reid_pred_config.use_dynamic_shape, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn) + else: + self.reid_pred_config = None + self.reid_predictor = None + + assert tracker_config is not None, 'Note that tracker_config should be set.' 
+ self.tracker_config = tracker_config + tracker_cfg = yaml.safe_load(open(self.tracker_config)) + cfg = tracker_cfg[tracker_cfg['type']] + + # tracker config + self.use_deepsort_tracker = True if tracker_cfg[ + 'type'] == 'DeepSORTTracker' else False + if self.use_deepsort_tracker: + # use DeepSORTTracker + if self.reid_pred_config is not None and hasattr( + self.reid_pred_config, 'tracker'): + cfg = self.reid_pred_config.tracker + budget = cfg.get('budget', 100) + max_age = cfg.get('max_age', 30) + max_iou_distance = cfg.get('max_iou_distance', 0.7) + matching_threshold = cfg.get('matching_threshold', 0.2) + min_box_area = cfg.get('min_box_area', 0) + vertical_ratio = cfg.get('vertical_ratio', 0) + + self.tracker = DeepSORTTracker( + budget=budget, + max_age=max_age, + max_iou_distance=max_iou_distance, + matching_threshold=matching_threshold, + min_box_area=min_box_area, + vertical_ratio=vertical_ratio, ) + else: + # use ByteTracker + use_byte = cfg.get('use_byte', False) + det_thresh = cfg.get('det_thresh', 0.3) + min_box_area = cfg.get('min_box_area', 0) + vertical_ratio = cfg.get('vertical_ratio', 0) + match_thres = cfg.get('match_thres', 0.9) + conf_thres = cfg.get('conf_thres', 0.6) + low_conf_thres = cfg.get('low_conf_thres', 0.1) + + self.tracker = JDETracker( + use_byte=use_byte, + det_thresh=det_thresh, + num_classes=self.num_classes, + min_box_area=min_box_area, + vertical_ratio=vertical_ratio, + match_thres=match_thres, + conf_thres=conf_thres, + low_conf_thres=low_conf_thres, ) + + def postprocess(self, inputs, result): + # postprocess output of predictor + np_boxes_num = result['boxes_num'] + if np_boxes_num[0] <= 0: + print('[WARNNING] No object detected.') + result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]} + result = {k: v for k, v in result.items() if v is not None} + return result + + def reidprocess(self, det_results, repeats=1): + pred_dets = det_results['boxes'] + pred_xyxys = pred_dets[:, 2:6] + + ori_image = det_results['ori_image'] + 
ori_image_shape = ori_image.shape[:2] + pred_xyxys, keep_idx = clip_box(pred_xyxys, ori_image_shape) + + if len(keep_idx[0]) == 0: + det_results['boxes'] = np.zeros((1, 6), dtype=np.float32) + det_results['embeddings'] = None + return det_results + + pred_dets = pred_dets[keep_idx[0]] + pred_xyxys = pred_dets[:, 2:6] + + w, h = self.tracker.input_size + crops = get_crops(pred_xyxys, ori_image, w, h) + + # to keep fast speed, only use topk crops + crops = crops[:50] # reid_batch_size + det_results['crops'] = np.array(crops).astype('float32') + det_results['boxes'] = pred_dets[:50] + + input_names = self.reid_predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.reid_predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(det_results[input_names[i]]) + + # model prediction + for i in range(repeats): + self.reid_predictor.run() + output_names = self.reid_predictor.get_output_names() + feature_tensor = self.reid_predictor.get_output_handle(output_names[ + 0]) + pred_embs = feature_tensor.copy_to_cpu() + + det_results['embeddings'] = pred_embs + return det_results + + def tracking(self, det_results): + pred_dets = det_results['boxes'] # 'cls_id, score, x0, y0, x1, y1' + pred_embs = det_results.get('embeddings', None) + + if self.use_deepsort_tracker: + # use DeepSORTTracker, only support singe class + self.tracker.predict() + online_targets = self.tracker.update(pred_dets, pred_embs) + online_tlwhs, online_scores, online_ids = [], [], [] + for t in online_targets: + if not t.is_confirmed() or t.time_since_update > 1: + continue + tlwh = t.to_tlwh() + tscore = t.score + tid = t.track_id + if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[ + 3] > self.tracker.vertical_ratio: + continue + online_tlwhs.append(tlwh) + online_scores.append(tscore) + online_ids.append(tid) + + tracking_outs = { + 'online_tlwhs': online_tlwhs, + 'online_scores': online_scores, + 'online_ids': online_ids, + } + return tracking_outs + else: + # 
use ByteTracker, support multiple class + online_tlwhs = defaultdict(list) + online_scores = defaultdict(list) + online_ids = defaultdict(list) + online_targets_dict = self.tracker.update(pred_dets, pred_embs) + for cls_id in range(self.num_classes): + online_targets = online_targets_dict[cls_id] + for t in online_targets: + tlwh = t.tlwh + tid = t.track_id + tscore = t.score + if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: + continue + if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[ + 3] > self.tracker.vertical_ratio: + continue + online_tlwhs[cls_id].append(tlwh) + online_ids[cls_id].append(tid) + online_scores[cls_id].append(tscore) + + tracking_outs = { + 'online_tlwhs': online_tlwhs, + 'online_scores': online_scores, + 'online_ids': online_ids, + } + return tracking_outs + + def predict_image(self, + image_list, + run_benchmark=False, + repeats=1, + visual=True, + seq_name=None): + num_classes = self.num_classes + image_list.sort() + ids2names = self.pred_config.labels + mot_results = [] + for frame_id, img_file in enumerate(image_list): + batch_image_list = [img_file] # bs=1 in MOT model + frame, _ = decode_image(img_file, {}) + if run_benchmark: + # preprocess + inputs = self.preprocess(batch_image_list) # warmup + self.det_times.preprocess_time_s.start() + inputs = self.preprocess(batch_image_list) + self.det_times.preprocess_time_s.end() + + # model prediction + result_warmup = self.predict(repeats=repeats) # warmup + self.det_times.inference_time_s.start() + result = self.predict(repeats=repeats) + self.det_times.inference_time_s.end(repeats=repeats) + + # postprocess + result_warmup = self.postprocess(inputs, result) # warmup + self.det_times.postprocess_time_s.start() + det_result = self.postprocess(inputs, result) + self.det_times.postprocess_time_s.end() + + # tracking + if self.use_reid: + det_result['frame_id'] = frame_id + det_result['seq_name'] = seq_name + det_result['ori_image'] = frame + det_result = self.reidprocess(det_result) + 
result_warmup = self.tracking(det_result) + self.det_times.tracking_time_s.start() + if self.use_reid: + det_result = self.reidprocess(det_result) + tracking_outs = self.tracking(det_result) + self.det_times.tracking_time_s.end() + self.det_times.img_num += 1 + + cm, gm, gu = get_current_memory_mb() + self.cpu_mem += cm + self.gpu_mem += gm + self.gpu_util += gu + + else: + self.det_times.preprocess_time_s.start() + inputs = self.preprocess(batch_image_list) + self.det_times.preprocess_time_s.end() + + self.det_times.inference_time_s.start() + result = self.predict() + self.det_times.inference_time_s.end() + + self.det_times.postprocess_time_s.start() + det_result = self.postprocess(inputs, result) + self.det_times.postprocess_time_s.end() + + # tracking process + self.det_times.tracking_time_s.start() + if self.use_reid: + det_result['frame_id'] = frame_id + det_result['seq_name'] = seq_name + det_result['ori_image'] = frame + det_result = self.reidprocess(det_result) + tracking_outs = self.tracking(det_result) + self.det_times.tracking_time_s.end() + self.det_times.img_num += 1 + + online_tlwhs = tracking_outs['online_tlwhs'] + online_scores = tracking_outs['online_scores'] + online_ids = tracking_outs['online_ids'] + + mot_results.append([online_tlwhs, online_scores, online_ids]) + + if visual: + if len(image_list) > 1 and frame_id % 10 == 0: + print('Tracking frame {}'.format(frame_id)) + frame, _ = decode_image(img_file, {}) + if isinstance(online_tlwhs, defaultdict): + im = plot_tracking_dict( + frame, + num_classes, + online_tlwhs, + online_ids, + online_scores, + frame_id=frame_id, + ids2names=ids2names) + else: + im = plot_tracking( + frame, + online_tlwhs, + online_ids, + online_scores, + frame_id=frame_id, + ids2names=ids2names) + save_dir = os.path.join(self.output_dir, seq_name) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + cv2.imwrite( + os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im) + + return mot_results + + def 
predict_video(self, video_file, camera_id): + video_out_name = 'output.mp4' + if camera_id != -1: + capture = cv2.VideoCapture(camera_id) + else: + capture = cv2.VideoCapture(video_file) + video_out_name = os.path.split(video_file)[-1] + # Get Video info : resolution, fps, frame count + width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = int(capture.get(cv2.CAP_PROP_FPS)) + frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) + print("fps: %d, frame_count: %d" % (fps, frame_count)) + + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + out_path = os.path.join(self.output_dir, video_out_name) + video_format = 'mp4v' + fourcc = cv2.VideoWriter_fourcc(*video_format) + writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) + + frame_id = 1 + timer = MOTTimer() + results = defaultdict(list) + num_classes = self.num_classes + data_type = 'mcmot' if num_classes > 1 else 'mot' + ids2names = self.pred_config.labels + + while (1): + ret, frame = capture.read() + if not ret: + break + if frame_id % 10 == 0: + print('Tracking frame: %d' % (frame_id)) + frame_id += 1 + + timer.tic() + seq_name = video_out_name.split('.')[0] + mot_results = self.predict_image( + [frame[:, :, ::-1]], visual=False, seq_name=seq_name) + timer.toc() + + # bs=1 in MOT model + online_tlwhs, online_scores, online_ids = mot_results[0] + + fps = 1. 
/ timer.duration + if self.use_deepsort_tracker: + # use DeepSORTTracker, only support singe class + results[0].append( + (frame_id + 1, online_tlwhs, online_scores, online_ids)) + im = plot_tracking( + frame, + online_tlwhs, + online_ids, + online_scores, + frame_id=frame_id, + fps=fps, + ids2names=ids2names) + else: + # use ByteTracker, support multiple class + for cls_id in range(num_classes): + results[cls_id].append( + (frame_id + 1, online_tlwhs[cls_id], + online_scores[cls_id], online_ids[cls_id])) + im = plot_tracking_dict( + frame, + num_classes, + online_tlwhs, + online_ids, + online_scores, + frame_id=frame_id, + fps=fps, + ids2names=ids2names) + + writer.write(im) + if camera_id != -1: + cv2.imshow('Mask Detection', im) + if cv2.waitKey(1) & 0xFF == ord('q'): + break + + if self.save_mot_txts: + result_filename = os.path.join( + self.output_dir, video_out_name.split('.')[-2] + '.txt') + write_mot_results(result_filename, results) + + writer.release() + + +def main(): + deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml') + with open(deploy_file) as f: + yml_conf = yaml.safe_load(f) + arch = yml_conf['arch'] + detector = SDE_Detector( + FLAGS.model_dir, + tracker_config=FLAGS.tracker_config, + device=FLAGS.device, + run_mode=FLAGS.run_mode, + batch_size=1, + trt_min_shape=FLAGS.trt_min_shape, + trt_max_shape=FLAGS.trt_max_shape, + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn, + output_dir=FLAGS.output_dir, + threshold=FLAGS.threshold, + save_images=FLAGS.save_images, + save_mot_txts=FLAGS.save_mot_txts, ) + + # predict from video file or camera video stream + if FLAGS.video_file is not None or FLAGS.camera_id != -1: + detector.predict_video(FLAGS.video_file, FLAGS.camera_id) + else: + # predict from image + if FLAGS.image_dir is None and FLAGS.image_file is not None: + assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models." 
+ img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) + seq_name = FLAGS.image_dir.split('/')[-1] + detector.predict_image( + img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name) + + if not FLAGS.run_benchmark: + detector.det_times.info(average=True) + else: + mode = FLAGS.run_mode + model_dir = FLAGS.model_dir + model_info = { + 'model_name': model_dir.strip('/').split('/')[-1], + 'precision': mode.split('_')[-1] + } + bench_log(detector, img_list, model_info, name='MOT') + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + print_arguments(FLAGS) + FLAGS.device = FLAGS.device.upper() + assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU' + ], "device should be CPU, GPU, NPU or XPU" + + main() diff --git a/rag/svr/parse_user_docs.py b/rag/svr/parse_user_docs.py index 29d2c2876..188662e6b 100644 --- a/rag/svr/parse_user_docs.py +++ b/rag/svr/parse_user_docs.py @@ -14,6 +14,7 @@ # limitations under the License. # import json +import logging import os import hashlib import copy @@ -24,9 +25,10 @@ from timeit import default_timer as timer from rag.llm import EmbeddingModel, CvModel from rag.settings import cron_logger, DOC_MAXIMUM_SIZE -from rag.utils import ELASTICSEARCH, num_tokens_from_string +from rag.utils import ELASTICSEARCH from rag.utils import MINIO -from rag.utils import rmSpace, findMaxDt +from rag.utils import rmSpace, findMaxTm + from rag.nlp import huchunk, huqie, search from io import BytesIO import pandas as pd @@ -47,6 +49,7 @@ from rag.nlp.huchunk import ( from web_server.db import LLMType from web_server.db.services.document_service import DocumentService from web_server.db.services.llm_service import TenantLLMService +from web_server.settings import database_logger from web_server.utils import get_format_time from web_server.utils.file_utils import get_project_base_directory @@ -83,7 +86,7 @@ def collect(comm, mod, tm): if len(docs) == 0: return pd.DataFrame() docs = 
pd.DataFrame(docs) - mtm = str(docs["update_time"].max())[:19] + mtm = docs["update_time"].max() cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) return docs @@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False): cron_logger.error("set_progress:({}), {}".format(docid, str(e))) -def build(row): +def build(row, cvmdl): if row["size"] > DOC_MAXIMUM_SIZE: set_progress(row["id"], -1, "File size exceeds( <= %dMb )" % (int(DOC_MAXIMUM_SIZE / 1024 / 1024))) return [] + res = ELASTICSEARCH.search(Q("term", doc_id=row["id"])) if ELASTICSEARCH.getTotal(res) > 0: ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]), @@ -120,7 +124,8 @@ def build(row): set_progress(row["id"], random.randint(0, 20) / 100., "Finished preparing! Start to slice file!", True) try: - obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"])) + cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"])) + obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl) except Exception as e: if re.search("(No such file|not found)", str(e)): set_progress( @@ -131,6 +136,9 @@ def build(row): row["id"], -1, f"Internal server error: %s" % str(e).replace( "'", "")) + + cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e))) + return [] if not obj.text_chunks and not obj.table_chunks: @@ -144,7 +152,7 @@ def build(row): "Finished slicing files. 
Start to embedding the content.") doc = { - "doc_id": row["did"], + "doc_id": row["id"], "kb_id": [str(row["kb_id"])], "docnm_kwd": os.path.split(row["location"])[-1], "title_tks": huqie.qie(row["name"]), @@ -164,10 +172,10 @@ def build(row): docs.append(d) continue - if isinstance(img, Image): - img.save(output_buffer, format='JPEG') - else: + if isinstance(img, bytes): output_buffer = BytesIO(img) + else: + img.save(output_buffer, format='JPEG') MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) @@ -215,15 +223,16 @@ def embedding(docs, mdl): def model_instance(tenant_id, llm_type): - model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING) - if not model_config:return - model_config = model_config[0] + model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING) + if not model_config: + model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""} + else: model_config = model_config.to_dict() if llm_type == LLMType.EMBEDDING: - if model_config.llm_factory not in EmbeddingModel: return - return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name) + if model_config["llm_factory"] not in EmbeddingModel: return + return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"]) if llm_type == LLMType.IMAGE2TEXT: - if model_config.llm_factory not in CvModel: return - return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name) + if model_config["llm_factory"] not in CvModel: return + return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"]) def main(comm, mod): @@ -231,7 +240,7 @@ def main(comm, mod): from rag.llm import HuEmbedding model = HuEmbedding() tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm") - tm = findMaxDt(tm_fnm) + tm = findMaxTm(tm_fnm) rows = collect(comm, mod, tm) 
if len(rows) == 0: return @@ -247,7 +256,7 @@ def main(comm, mod): st_tm = timer() cks = build(r, cv_mdl) if not cks: - tmf.write(str(r["updated_at"]) + "\n") + tmf.write(str(r["update_time"]) + "\n") continue # TODO: exception handler ## set_progress(r["did"], -1, "ERROR: ") @@ -268,12 +277,19 @@ def main(comm, mod): cron_logger.error(str(es_r)) else: set_progress(r["id"], 1., "Done!") - DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm}) + DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm) + cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks))) + tmf.write(str(r["update_time"]) + "\n") tmf.close() if __name__ == "__main__": + peewee_logger = logging.getLogger('peewee') + peewee_logger.propagate = False + peewee_logger.addHandler(database_logger.handlers[0]) + peewee_logger.setLevel(database_logger.level) + from mpi4py import MPI comm = MPI.COMM_WORLD main(comm.Get_size(), comm.Get_rank()) diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py index d3f163233..9898d19d5 100644 --- a/rag/utils/__init__.py +++ b/rag/utils/__init__.py @@ -40,6 +40,25 @@ def findMaxDt(fnm): print("WARNING: can't find " + fnm) return m + +def findMaxTm(fnm): + m = 0 + try: + with open(fnm, "r") as f: + while True: + l = f.readline() + if not l: + break + l = l.strip("\n") + if l == 'nan': + continue + if int(l) > m: + m = int(l) + except Exception as e: + print("WARNING: can't find " + fnm) + return m + + def num_tokens_from_string(string: str) -> int: """Returns the number of tokens in a text string.""" encoding = tiktoken.get_encoding('cl100k_base') diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index f8337c01d..632b01d6e 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -294,6 +294,7 @@ class HuEs: except Exception as e: es_logger.error("ES updateByQuery deleteByQuery: " + str(e) + "【Q】:" + 
str(query.to_dict())) + if str(e).find("NotFoundError") > 0: return True if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: continue diff --git a/web_server/apps/document_app.py b/web_server/apps/document_app.py index d14d69ab1..9be9cfde9 100644 --- a/web_server/apps/document_app.py +++ b/web_server/apps/document_app.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import base64 import pathlib from elasticsearch_dsl import Q @@ -195,11 +196,15 @@ def rm(): e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(retmsg="Document not found!") + if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)): + return get_json_result(data=False, retmsg='Remove from ES failure!', retcode=RetCode.SERVER_ERROR) + + DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0) if not DocumentService.delete_by_id(req["doc_id"]): return get_data_error_result( retmsg="Database error (Document removal)!") - e, kb = KnowledgebaseService.get_by_id(doc.kb_id) - MINIO.rm(kb.id, doc.location) + + MINIO.rm(doc.kb_id, doc.location) return get_json_result(data=True) except Exception as e: return server_error_response(e) @@ -233,3 +238,43 @@ def rename(): return get_json_result(data=True) except Exception as e: return server_error_response(e) + + +@manager.route('/get', methods=['GET']) +@login_required +def get(): + doc_id = request.args["doc_id"] + try: + e, doc = DocumentService.get_by_id(doc_id) + if not e: + return get_data_error_result(retmsg="Document not found!") + + blob = MINIO.get(doc.kb_id, doc.location) + return get_json_result(data={"base64": base64.b64encode(blob).decode("utf-8")}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/change_parser', methods=['POST']) +@login_required +@validate_request("doc_id", "parser_id") +def change_parser(): + req = request.json + try: 
+ e, doc = DocumentService.get_by_id(req["doc_id"]) + if not e: + return get_data_error_result(retmsg="Document not found!") + if doc.parser_id.lower() == req["parser_id"].lower(): + return get_json_result(data=True) + + e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""}) + if not e: + return get_data_error_result(retmsg="Document not found!") + e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1) + if not e: + return get_data_error_result(retmsg="Document not found!") + + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) + diff --git a/web_server/apps/kb_app.py b/web_server/apps/kb_app.py index c035cb637..054f97e00 100644 --- a/web_server/apps/kb_app.py +++ b/web_server/apps/kb_app.py @@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result @manager.route('/create', methods=['post']) @login_required -@validate_request("name", "description", "permission", "embd_id", "parser_id") +@validate_request("name", "description", "permission", "parser_id") def create(): req = request.json req["name"] = req["name"].strip() @@ -46,7 +46,7 @@ def create(): @manager.route('/update', methods=['post']) @login_required -@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id") +@validate_request("kb_id", "name", "description", "permission", "parser_id") def update(): req = request.json req["name"] = req["name"].strip() @@ -72,6 +72,18 @@ def update(): return server_error_response(e) +@manager.route('/detail', methods=['GET']) +@login_required +def detail(): + kb_id = request.args["kb_id"] + try: + kb = KnowledgebaseService.get_detail(kb_id) + if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!") + return get_json_result(data=kb) + except Exception as e: + return server_error_response(e) + + @manager.route('/list', methods=['GET']) @login_required 
def list(): diff --git a/web_server/apps/llm_app.py b/web_server/apps/llm_app.py new file mode 100644 index 000000000..0877a1977 --- /dev/null +++ b/web_server/apps/llm_app.py @@ -0,0 +1,95 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from flask import request +from flask_login import login_required, current_user + +from web_server.db.services import duplicate_name +from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService +from web_server.db.services.user_service import TenantService, UserTenantService +from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request +from web_server.utils import get_uuid, get_format_time +from web_server.db import StatusEnum, UserTenantRole +from web_server.db.services.kb_service import KnowledgebaseService +from web_server.db.db_models import Knowledgebase, TenantLLM +from web_server.settings import stat_logger, RetCode +from web_server.utils.api_utils import get_json_result + + +@manager.route('/factories', methods=['GET']) +@login_required +def factories(): + try: + fac = LLMFactoriesService.get_all() + return get_json_result(data=fac.to_json()) + except Exception as e: + return server_error_response(e) + + +@manager.route('/set_api_key', methods=['POST']) +@login_required +@validate_request("llm_factory", "api_key") +def set_api_key(): + req = request.json + llm = { + "tenant_id": 
current_user.id, + "llm_factory": req["llm_factory"], + "api_key": req["api_key"] + } + # TODO: Test api_key + for n in ["model_type", "llm_name"]: + if n in req: llm[n] = req[n] + + TenantLLM.insert(**llm).on_conflict("replace").execute() + return get_json_result(data=True) + + +@manager.route('/my_llms', methods=['GET']) +@login_required +def my_llms(): + try: + objs = TenantLLMService.query(tenant_id=current_user.id) + objs = [o.to_dict() for o in objs] + for o in objs: del o["api_key"] + return get_json_result(data=objs) + except Exception as e: + return server_error_response(e) + + +@manager.route('/list', methods=['GET']) +@login_required +def list(): + try: + objs = TenantLLMService.query(tenant_id=current_user.id) + objs = [o.to_dict() for o in objs if o.api_key] + fct = {} + for o in objs: + if o["llm_factory"] not in fct: fct[o["llm_factory"]] = [] + if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"]) + + llms = LLMService.get_all() + llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value] + for m in llms: + m["available"] = False + if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]): + m["available"] = True + res = {} + for m in llms: + if m["fid"] not in res: res[m["fid"]] = [] + res[m["fid"]].append(m) + + return get_json_result(data=res) + except Exception as e: + return server_error_response(e) \ No newline at end of file diff --git a/web_server/apps/user_app.py b/web_server/apps/user_app.py index aa2ba43ce..81946074e 100644 --- a/web_server/apps/user_app.py +++ b/web_server/apps/user_app.py @@ -16,9 +16,12 @@ from flask import request, session, redirect, url_for from werkzeug.security import generate_password_hash, check_password_hash from flask_login import login_required, current_user, login_user, logout_user + +from web_server.db.db_models import TenantLLM +from web_server.db.services.llm_service import TenantLLMService from web_server.utils.api_utils import server_error_response, 
validate_request from web_server.utils import get_uuid, get_format_time, decrypt, download_img -from web_server.db import UserTenantRole +from web_server.db import UserTenantRole, LLMType from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS from web_server.db.services.user_service import UserService, TenantService, UserTenantService from web_server.settings import stat_logger @@ -47,8 +50,9 @@ def login(): avatar = download_img(userinfo["avatar_url"]) except Exception as e: stat_logger.exception(e) + user_id = get_uuid() try: - users = user_register({ + users = user_register(user_id, { "access_token": session["access_token"], "email": userinfo["email"], "avatar": avatar, @@ -63,6 +67,7 @@ def login(): login_user(user) return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") except Exception as e: + rollback_user_registration(user_id) stat_logger.exception(e) return server_error_response(e) elif not request.json: @@ -162,7 +167,24 @@ def user_info(): return get_json_result(data=current_user.to_dict()) -def user_register(user): +def rollback_user_registration(user_id): + try: + TenantService.delete_by_id(user_id) + except Exception as e: + pass + try: + u = UserTenantService.query(tenant_id=user_id) + if u: + UserTenantService.delete_by_id(u[0].id) + except Exception as e: + pass + try: + TenantLLM.delete().where(TenantLLM.tenant_id==user_id).execute() + except Exception as e: + pass + + +def user_register(user_id, user): user["id"] = user_id tenant = { @@ -180,10 +203,12 @@ def user_register(user): "invited_by": user_id, "role": UserTenantRole.OWNER } + tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"} if not UserService.save(**user):return TenantService.save(**tenant) UserTenantService.save(**usr_tenant) + TenantLLMService.save(**tenant_llm) return UserService.query(email=user["email"]) @@ -203,14 +228,17 @@ def 
user_add(): "last_login_time": get_format_time(), "is_superuser": False, } + + user_id = get_uuid() try: - users = user_register(user_dict) + users = user_register(user_id, user_dict) if not users: raise Exception('Register user failure.') if len(users) > 1: raise Exception('Same E-mail exist!') user = users[0] login_user(user) return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!") except Exception as e: + rollback_user_registration(user_id) stat_logger.exception(e) return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR) @@ -220,7 +248,7 @@ def user_add(): @login_required def tenant_info(): try: - tenants = TenantService.get_by_user_id(current_user.id) + tenants = TenantService.get_by_user_id(current_user.id)[0] return get_json_result(data=tenants) except Exception as e: return server_error_response(e) diff --git a/web_server/db/db_models.py b/web_server/db/db_models.py index b67616803..62d92b475 100644 --- a/web_server/db/db_models.py +++ b/web_server/db/db_models.py @@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel): class LLM(DataBaseModel): # defautlt LLMs for every users llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True) + model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") fid = CharField(max_length=128, null=False, help_text="LLM factory id") tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") @@ -442,8 +443,8 @@ class LLM(DataBaseModel): class TenantLLM(DataBaseModel): tenant_id = CharField(max_length=32, null=False) llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name") - model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") - llm_name = 
CharField(max_length=128, null=False, help_text="LLM name") + model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR") + llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="") api_key = CharField(max_length=255, null=True, help_text="API KEY") api_base = CharField(max_length=255, null=True, help_text="API Base") @@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel): class Meta: db_table = "tenant_llm" - primary_key = CompositeKey('tenant_id', 'llm_factory') + primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name') class Knowledgebase(DataBaseModel): @@ -464,7 +465,9 @@ class Knowledgebase(DataBaseModel): permission = CharField(max_length=16, null=False, help_text="me|team") created_by = CharField(max_length=32, null=False) doc_num = IntegerField(default=0) - embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID") + token_num = IntegerField(default=0) + chunk_num = IntegerField(default=0) + parser_id = CharField(max_length=32, null=False, help_text="default parser ID") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") diff --git a/web_server/db/services/document_service.py b/web_server/db/services/document_service.py index e8746a4ac..38b1cd559 100644 --- a/web_server/db/services/document_service.py +++ b/web_server/db/services/document_service.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from peewee import Expression + from web_server.db import TenantPermission, FileType -from web_server.db.db_models import DB, Knowledgebase +from web_server.db.db_models import DB, Knowledgebase, Tenant from web_server.db.db_models import Document from web_server.db.services.common_service import CommonService from web_server.db.services.kb_service import KnowledgebaseService -from web_server.utils import get_uuid, get_format_time from web_server.db.db_utils import StatusEnum @@ -61,15 +62,28 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64): - fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id] - docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where( - cls.model.status == StatusEnum.VALID.value, - cls.model.type != FileType.VIRTUAL, - cls.model.progress == 0, - cls.model.update_time >= tm, - cls.model.create_time % - comm == mod).order_by( - cls.model.update_time.asc()).paginate( - 1, - items_per_page) + fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time] + docs = cls.model.select(*fields) \ + .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \ + .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ + .where( + cls.model.status == StatusEnum.VALID.value, + ~(cls.model.type == FileType.VIRTUAL.value), + cls.model.progress == 0, + cls.model.update_time >= tm, + (Expression(cls.model.create_time, "%%", comm) == mod))\ + .order_by(cls.model.update_time.asc())\ + .paginate(1, items_per_page) return list(docs.dicts()) + + @classmethod + @DB.connection_context() + def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation): + num = cls.model.update(token_num=cls.model.token_num + token_num, 
+ chunk_num=cls.model.chunk_num + chunk_num, + process_duation=cls.model.process_duation+duation).where( + cls.model.id == doc_id).execute() + if num == 0:raise LookupError("Document not found which is supposed to be there") + num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute() + return num + diff --git a/web_server/db/services/kb_service.py b/web_server/db/services/kb_service.py index 84b2e4f93..a8ca96a2a 100644 --- a/web_server/db/services/kb_service.py +++ b/web_server/db/services/kb_service.py @@ -17,7 +17,7 @@ import peewee from werkzeug.security import generate_password_hash, check_password_hash from web_server.db import TenantPermission -from web_server.db.db_models import DB, UserTenant +from web_server.db.db_models import DB, UserTenant, Tenant from web_server.db.db_models import Knowledgebase from web_server.db.services.common_service import CommonService from web_server.utils import get_uuid, get_format_time @@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() - def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc): + def get_by_tenant_ids(cls, joined_tenant_ids, user_id, + page_number, items_per_page, orderby, desc): kbs = cls.model.select().where( - ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) - & (cls.model.status==StatusEnum.VALID.value) + ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == + TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) + & (cls.model.status == StatusEnum.VALID.value) ) - if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) - else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) + if desc: + kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) + else: + kbs = 
kbs.order_by(cls.model.getter_by(orderby).asc()) kbs = kbs.paginate(page_number, items_per_page) return list(kbs.dicts()) + @classmethod + @DB.connection_context() + def get_detail(cls, kb_id): + fields = [ + cls.model.id, + Tenant.embd_id, + cls.model.avatar, + cls.model.name, + cls.model.description, + cls.model.permission, + cls.model.doc_num, + cls.model.token_num, + cls.model.chunk_num, + cls.model.parser_id] + kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where( + (cls.model.id == kb_id), + (cls.model.status == StatusEnum.VALID.value) + ) + if not kbs: + return + d = kbs[0].to_dict() + d["embd_id"] = kbs[0].tenant.embd_id + return d diff --git a/web_server/db/services/llm_service.py b/web_server/db/services/llm_service.py index 7d6b575fe..350106e36 100644 --- a/web_server/db/services/llm_service.py +++ b/web_server/db/services/llm_service.py @@ -33,3 +33,21 @@ class LLMService(CommonService): class TenantLLMService(CommonService): model = TenantLLM + + @classmethod + @DB.connection_context() + def get_api_key(cls, tenant_id, model_type): + objs = cls.query(tenant_id=tenant_id, model_type=model_type) + if objs and len(objs)>0 and objs[0].llm_name: + return objs[0] + + fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key] + objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where( + (cls.model.tenant_id == tenant_id), + (cls.model.model_type == model_type), + (LLM.status == StatusEnum.VALID) + ) + + if not objs:return + return objs[0] + diff --git a/web_server/db/services/user_service.py b/web_server/db/services/user_service.py index 42e0b5c11..f4ed4b58c 100644 --- a/web_server/db/services/user_service.py +++ b/web_server/db/services/user_service.py @@ -79,7 +79,7 @@ class TenantService(CommonService): @classmethod @DB.connection_context() def get_by_user_id(cls, user_id): - fields = [cls.model.id.alias("tenant_id"), cls.model.name, 
cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] + fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role] return list(cls.model.select(*fields)\ .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ .where(cls.model.status == StatusEnum.VALID.value).dicts()) diff --git a/web_server/utils/file_utils.py b/web_server/utils/file_utils.py index 54b1514ec..442ab19bf 100644 --- a/web_server/utils/file_utils.py +++ b/web_server/utils/file_utils.py @@ -143,7 +143,7 @@ def filename_type(filename): if re.match(r".*\.pdf$", filename): return FileType.PDF.value - if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename): + if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): return FileType.DOC.value if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):