add a lot of APIs

This commit is contained in:
kevinhu 2024-01-15 19:19:35 +08:00
parent ecb7d40dcc
commit 76fad8bf99
26 changed files with 992 additions and 154 deletions

View file

@ -35,7 +35,7 @@ class Base(ABC):
class HuEmbedding(Base): class HuEmbedding(Base):
def __init__(self): def __init__(self, key="", model_name=""):
""" """
If you have trouble downloading HuggingFace models, -_^ this might help!! If you have trouble downloading HuggingFace models, -_^ this might help!!

View file

@ -411,8 +411,11 @@ class TextChunker(HuChunker):
flds = self.Fields() flds = self.Fields()
if self.is_binary_file(fnm): if self.is_binary_file(fnm):
return flds return flds
txt = ""
if isinstance(fnm, str):
with open(fnm, "r") as f: with open(fnm, "r") as f:
txt = f.read() txt = f.read()
else: txt = fnm.decode("utf-8")
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
flds.table_chunks = [] flds.table_chunks = []
return flds return flds

View file

@ -8,7 +8,7 @@ from rag.nlp import huqie, query
import numpy as np import numpy as np
def index_name(uid): return f"docgpt_{uid}" def index_name(uid): return f"ragflow_{uid}"
class Dealer: class Dealer:

View file

@ -158,8 +158,8 @@ class PaddleInferBenchmark(object):
return: return:
config_status(dict): dict style config info config_status(dict): dict style config info
""" """
config_status = {}
if isinstance(config, paddle_infer.Config): if isinstance(config, paddle_infer.Config):
config_status = {}
config_status['runtime_device'] = "gpu" if config.use_gpu( config_status['runtime_device'] = "gpu" if config.use_gpu(
) else "cpu" ) else "cpu"
config_status['ir_optim'] = config.ir_optim() config_status['ir_optim'] = config.ir_optim()

View file

@ -15,6 +15,7 @@
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from scipy.special import softmax
from scipy.interpolate import InterpolatedUnivariateSpline from scipy.interpolate import InterpolatedUnivariateSpline

View file

@ -22,9 +22,10 @@ import yaml
from det_keypoint_unite_utils import argsparser from det_keypoint_unite_utils import argsparser
from preprocess import decode_image from preprocess import decode_image
from infer import print_arguments, get_test_images, bench_log from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images, bench_log
from keypoint_infer import KeyPointDetector from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint
from visualize import visualize_pose from visualize import visualize_pose
from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images from keypoint_postprocess import translate_to_ori_images

View file

@ -17,6 +17,7 @@ import yaml
import glob import glob
import json import json
from pathlib import Path from pathlib import Path
from functools import reduce
import cv2 import cv2
import numpy as np import numpy as np
@ -26,12 +27,14 @@ from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
import sys import sys
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..']))) parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path) sys.path.insert(0, parent_path)
from benchmark_utils import PaddleInferBenchmark from benchmark_utils import PaddleInferBenchmark
from picodet_postprocess import PicoDetPostProcess from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from clrnet_postprocess import CLRNetPostProcess from clrnet_postprocess import CLRNetPostProcess
from visualize import visualize_box_mask, imshow_lanes from visualize import visualize_box_mask, imshow_lanes
from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid

View file

@ -11,21 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#
import os import os
# add deploy path of PaddleDetection to sys.path import time
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
import sys
sys.path.insert(0, parent_path)
import yaml import yaml
import glob
from functools import reduce
from PIL import Image
import cv2 import cv2
import math import math
import numpy as np import numpy as np
import paddle import paddle
from keypoint_preprocess import expand_crop import sys
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)
from preprocess import preprocess, NormalizeImage, Permute
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess
from visualize import visualize_pose from visualize import visualize_pose
from paddle.inference import Config
from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb from utils import argsparser, Timer, get_current_memory_mb
from benchmark_utils import PaddleInferBenchmark from benchmark_utils import PaddleInferBenchmark
from infer import Detector, get_test_images, print_arguments from infer import Detector, get_test_images, print_arguments

View file

@ -17,6 +17,8 @@ from collections import abc, defaultdict
import cv2 import cv2
import numpy as np import numpy as np
import math import math
import paddle
import paddle.nn as nn
from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform

View file

@ -15,16 +15,18 @@
import os import os
import copy import copy
import math import math
import time
import yaml
import cv2 import cv2
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
import paddle import paddle
from rag.ppdet import MOTTimer from benchmark_utils import PaddleInferBenchmark
from utils import gaussian_radius, draw_umich_gaussian from utils import gaussian_radius, gaussian2D, draw_umich_gaussian
from preprocess import preprocess, decode_image from preprocess import preprocess, decode_image, WarpAffine, NormalizeImage, Permute
from utils import argsparser, Timer, get_current_memory_mb from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
from keypoint_preprocess import get_affine_transform from keypoint_preprocess import get_affine_transform
# add python path # add python path
@ -32,6 +34,9 @@ import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path) sys.path.insert(0, parent_path)
from pptracking.python.mot import CenterTracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results
from pptracking.python.mot.visualize import plot_tracking
def transform_preds_with_trans(coords, trans): def transform_preds_with_trans(coords, trans):
@ -119,12 +124,12 @@ class CenterTrack(Detector):
track_thresh = cfg.get('track_thresh', 0.4) track_thresh = cfg.get('track_thresh', 0.4)
pre_thresh = cfg.get('pre_thresh', 0.5) pre_thresh = cfg.get('pre_thresh', 0.5)
# self.tracker = CenterTracker( self.tracker = CenterTracker(
# num_classes=self.num_classes, num_classes=self.num_classes,
# min_box_area=min_box_area, min_box_area=min_box_area,
# vertical_ratio=vertical_ratio, vertical_ratio=vertical_ratio,
# track_thresh=track_thresh, track_thresh=track_thresh,
# pre_thresh=pre_thresh) pre_thresh=pre_thresh)
self.pre_image = None self.pre_image = None
@ -359,20 +364,20 @@ class CenterTrack(Detector):
print('Tracking frame {}'.format(frame_id)) print('Tracking frame {}'.format(frame_id))
frame, _ = decode_image(img_file, {}) frame, _ = decode_image(img_file, {})
# im = plot_tracking( im = plot_tracking(
# frame, frame,
# online_tlwhs, online_tlwhs,
# online_ids, online_ids,
# online_scores, online_scores,
# frame_id=frame_id, frame_id=frame_id,
# ids2names=ids2names) ids2names=ids2names)
if seq_name is None: if seq_name is None:
seq_name = image_list[0].split('/')[-2] seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name) save_dir = os.path.join(self.output_dir, seq_name)
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
# cv2.imwrite( cv2.imwrite(
# os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im) os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
mot_results.append([online_tlwhs, online_scores, online_ids]) mot_results.append([online_tlwhs, online_scores, online_ids])
return mot_results return mot_results
@ -422,26 +427,26 @@ class CenterTrack(Detector):
online_tlwhs, online_scores, online_ids = mot_results[0] online_tlwhs, online_scores, online_ids = mot_results[0]
results[0].append( results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids)) (frame_id + 1, online_tlwhs, online_scores, online_ids))
# im = plot_tracking( im = plot_tracking(
# frame, frame,
# online_tlwhs, online_tlwhs,
# online_ids, online_ids,
# online_scores, online_scores,
# frame_id=frame_id, frame_id=frame_id,
# fps=fps, fps=fps,
# ids2names=ids2names) ids2names=ids2names)
#
# writer.write(im) writer.write(im)
# if camera_id != -1: if camera_id != -1:
# cv2.imshow('Mask Detection', im) cv2.imshow('Mask Detection', im)
# if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
# break break
if self.save_mot_txts: if self.save_mot_txts:
result_filename = os.path.join( result_filename = os.path.join(
self.output_dir, video_out_name.split('.')[-2] + '.txt') self.output_dir, video_out_name.split('.')[-2] + '.txt')
#write_mot_results(result_filename, results, data_type, num_classes) write_mot_results(result_filename, results, data_type, num_classes)
writer.release() writer.release()

View file

@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
import os import os
import time
import yaml
import cv2 import cv2
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
@ -20,7 +22,6 @@ import paddle
from benchmark_utils import PaddleInferBenchmark from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image from preprocess import decode_image
from rag.ppdet import MOTTimer
from utils import argsparser, Timer, get_current_memory_mb from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
@ -29,6 +30,9 @@ import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path) sys.path.insert(0, parent_path)
from pptracking.python.mot import JDETracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results
from pptracking.python.mot.visualize import plot_tracking_dict
# Global dictionary # Global dictionary
MOT_JDE_SUPPORT_MODELS = { MOT_JDE_SUPPORT_MODELS = {
@ -102,13 +106,13 @@ class JDE_Detector(Detector):
tracked_thresh = cfg.get('tracked_thresh', 0.7) tracked_thresh = cfg.get('tracked_thresh', 0.7)
metric_type = cfg.get('metric_type', 'euclidean') metric_type = cfg.get('metric_type', 'euclidean')
# self.tracker = JDETracker( self.tracker = JDETracker(
# num_classes=self.num_classes, num_classes=self.num_classes,
# min_box_area=min_box_area, min_box_area=min_box_area,
# vertical_ratio=vertical_ratio, vertical_ratio=vertical_ratio,
# conf_thres=conf_thres, conf_thres=conf_thres,
# tracked_thresh=tracked_thresh, tracked_thresh=tracked_thresh,
# metric_type=metric_type) metric_type=metric_type)
def postprocess(self, inputs, result): def postprocess(self, inputs, result):
# postprocess output of predictor # postprocess output of predictor
@ -235,21 +239,21 @@ class JDE_Detector(Detector):
print('Tracking frame {}'.format(frame_id)) print('Tracking frame {}'.format(frame_id))
frame, _ = decode_image(img_file, {}) frame, _ = decode_image(img_file, {})
# im = plot_tracking_dict( im = plot_tracking_dict(
# frame, frame,
# num_classes, num_classes,
# online_tlwhs, online_tlwhs,
# online_ids, online_ids,
# online_scores, online_scores,
# frame_id=frame_id, frame_id=frame_id,
# ids2names=ids2names) ids2names=ids2names)
if seq_name is None: if seq_name is None:
seq_name = image_list[0].split('/')[-2] seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name) save_dir = os.path.join(self.output_dir, seq_name)
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
# cv2.imwrite( cv2.imwrite(
# os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im) os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
mot_results.append([online_tlwhs, online_scores, online_ids]) mot_results.append([online_tlwhs, online_scores, online_ids])
return mot_results return mot_results
@ -301,28 +305,28 @@ class JDE_Detector(Detector):
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id], (frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
online_ids[cls_id])) online_ids[cls_id]))
#fps = 1. / timer.duration fps = 1. / timer.duration
# im = plot_tracking_dict( im = plot_tracking_dict(
# frame, frame,
# num_classes, num_classes,
# online_tlwhs, online_tlwhs,
# online_ids, online_ids,
# online_scores, online_scores,
# frame_id=frame_id, frame_id=frame_id,
# fps=fps, fps=fps,
# ids2names=ids2names) ids2names=ids2names)
#
# writer.write(im) writer.write(im)
# if camera_id != -1: if camera_id != -1:
# cv2.imshow('Mask Detection', im) cv2.imshow('Mask Detection', im)
# if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
# break break
if self.save_mot_txts: if self.save_mot_txts:
result_filename = os.path.join( result_filename = os.path.join(
self.output_dir, video_out_name.split('.')[-2] + '.txt') self.output_dir, video_out_name.split('.')[-2] + '.txt')
#write_mot_results(result_filename, results, data_type, num_classes) write_mot_results(result_filename, results, data_type, num_classes)
writer.release() writer.release()

View file

@ -11,26 +11,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import os import os
import json import json
import time
import cv2 import cv2
import math
import numpy as np import numpy as np
import paddle import paddle
import yaml import yaml
import copy import copy
from collections import defaultdict
from mot_keypoint_unite_utils import argsparser from mot_keypoint_unite_utils import argsparser
from preprocess import decode_image from preprocess import decode_image
from infer import print_arguments, get_test_images, bench_log from infer import print_arguments, get_test_images, bench_log
from mot_jde_infer import MOT_JDE_SUPPORT_MODELS from mot_sde_infer import SDE_Detector
from mot_jde_infer import JDE_Detector, MOT_JDE_SUPPORT_MODELS
from keypoint_infer import KeyPointDetector, KEYPOINT_SUPPORT_MODELS from keypoint_infer import KeyPointDetector, KEYPOINT_SUPPORT_MODELS
from det_keypoint_unite_infer import predict_with_given_det from det_keypoint_unite_infer import predict_with_given_det
from rag.ppdet import MOTTimer
from visualize import visualize_pose from visualize import visualize_pose
from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images
# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
from pptracking.python.mot.utils import MOTTimer as FPSTimer
def convert_mot_to_det(tlwhs, scores): def convert_mot_to_det(tlwhs, scores):
@ -140,7 +150,7 @@ def mot_topdown_unite_predict_video(mot_detector,
fourcc = cv2.VideoWriter_fourcc(* 'mp4v') fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0 frame_id = 0
timer_mot, timer_kp, timer_mot_kp = MOTTimer(), MOTTimer(), MOTTimer() timer_mot, timer_kp, timer_mot_kp = FPSTimer(), FPSTimer(), FPSTimer()
num_classes = mot_detector.num_classes num_classes = mot_detector.num_classes
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.' assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.'
@ -187,6 +197,15 @@ def mot_topdown_unite_predict_video(mot_detector,
returnimg=True, returnimg=True,
ids=online_ids[0]) ids=online_ids[0])
im = plot_tracking_dict(
im,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=mot_kp_fps)
writer.write(im) writer.write(im)
if camera_id != -1: if camera_id != -1:
cv2.imshow('Tracking and keypoint results', im) cv2.imshow('Tracking and keypoint results', im)

522
rag/ppdet/mot_sde_infer.py Normal file
View file

@ -0,0 +1,522 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import yaml
import cv2
import numpy as np
from collections import defaultdict
import paddle
from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
from pptracking.python.mot import JDETracker, DeepSORTTracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results, get_crops, clip_box
from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
class SDE_Detector(Detector):
    """Separate-detection-and-embedding (SDE) multi-object tracker.

    Wraps a Paddle detection model and pairs it with either a DeepSORT
    tracker (when a ReID model is supplied) or a JDE/ByteTrack tracker.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        tracker_config (str): tracker config path
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
        batch_size (int): size of pre batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
        output_dir (string): The path of output, default as 'output'
        threshold (float): Score threshold of the detected bbox, default as 0.5
        save_images (bool): Whether to save visualization image results, default as False
        save_mot_txts (bool): Whether to save tracking results (txt), default as False
        reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
    """

    def __init__(self,
                 model_dir,
                 tracker_config,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 threshold=0.5,
                 save_images=False,
                 save_mot_txts=False,
                 reid_model_dir=None):
        # Detection-model setup is delegated entirely to the Detector base.
        super(SDE_Detector, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold, )
        self.save_images = save_images
        self.save_mot_txts = save_mot_txts
        assert batch_size == 1, "MOT model only supports batch_size=1."
        self.det_times = Timer(with_tracker=True)
        self.num_classes = len(self.pred_config.labels)

        # reid config: a second predictor is only built when a ReID model
        # directory was provided (DeepSORT path).
        self.use_reid = False if reid_model_dir is None else True
        if self.use_reid:
            self.reid_pred_config = self.set_config(reid_model_dir)
            self.reid_predictor, self.config = load_predictor(
                reid_model_dir,
                run_mode=run_mode,
                batch_size=50,  # reid_batch_size
                min_subgraph_size=self.reid_pred_config.min_subgraph_size,
                device=device,
                use_dynamic_shape=self.reid_pred_config.use_dynamic_shape,
                trt_min_shape=trt_min_shape,
                trt_max_shape=trt_max_shape,
                trt_opt_shape=trt_opt_shape,
                trt_calib_mode=trt_calib_mode,
                cpu_threads=cpu_threads,
                enable_mkldnn=enable_mkldnn)
        else:
            self.reid_pred_config = None
            self.reid_predictor = None

        assert tracker_config is not None, 'Note that tracker_config should be set.'
        self.tracker_config = tracker_config
        # NOTE(review): the file handle from open() is never closed; a
        # with-block would be safer here.
        tracker_cfg = yaml.safe_load(open(self.tracker_config))
        # The YAML is expected to contain a 'type' key naming a section that
        # holds that tracker's parameters.
        cfg = tracker_cfg[tracker_cfg['type']]

        # tracker config
        self.use_deepsort_tracker = True if tracker_cfg[
            'type'] == 'DeepSORTTracker' else False
        if self.use_deepsort_tracker:
            # use DeepSORTTracker
            # A 'tracker' section embedded in the ReID model config, when
            # present, overrides the standalone tracker config file.
            if self.reid_pred_config is not None and hasattr(
                    self.reid_pred_config, 'tracker'):
                cfg = self.reid_pred_config.tracker
            budget = cfg.get('budget', 100)
            max_age = cfg.get('max_age', 30)
            max_iou_distance = cfg.get('max_iou_distance', 0.7)
            matching_threshold = cfg.get('matching_threshold', 0.2)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)
            self.tracker = DeepSORTTracker(
                budget=budget,
                max_age=max_age,
                max_iou_distance=max_iou_distance,
                matching_threshold=matching_threshold,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio, )
        else:
            # use ByteTracker
            use_byte = cfg.get('use_byte', False)
            det_thresh = cfg.get('det_thresh', 0.3)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)
            match_thres = cfg.get('match_thres', 0.9)
            conf_thres = cfg.get('conf_thres', 0.6)
            low_conf_thres = cfg.get('low_conf_thres', 0.1)
            self.tracker = JDETracker(
                use_byte=use_byte,
                det_thresh=det_thresh,
                num_classes=self.num_classes,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio,
                match_thres=match_thres,
                conf_thres=conf_thres,
                low_conf_thres=low_conf_thres, )

    def postprocess(self, inputs, result):
        """Drop empty / None outputs from the raw predictor result.

        Returns a dict with at least 'boxes' and 'boxes_num'; when nothing
        was detected, 'boxes' is an empty (0, 6) array.
        """
        # postprocess output of predictor
        np_boxes_num = result['boxes_num']
        if np_boxes_num[0] <= 0:
            # NOTE(review): 'WARNNING' is a typo in this log string.
            print('[WARNNING] No object detected.')
            result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
        result = {k: v for k, v in result.items() if v is not None}
        return result

    def reidprocess(self, det_results, repeats=1):
        """Run the ReID predictor over crops of the detected boxes.

        Adds an 'embeddings' entry to ``det_results`` (or None when every
        box was clipped away). Boxes are assumed laid out as
        'cls_id, score, x0, y0, x1, y1', so columns 2:6 are the bbox.
        """
        pred_dets = det_results['boxes']
        pred_xyxys = pred_dets[:, 2:6]

        ori_image = det_results['ori_image']
        ori_image_shape = ori_image.shape[:2]
        # Clip boxes to the image and drop degenerate ones.
        pred_xyxys, keep_idx = clip_box(pred_xyxys, ori_image_shape)

        if len(keep_idx[0]) == 0:
            # No usable boxes: emit a single dummy row and no embeddings.
            det_results['boxes'] = np.zeros((1, 6), dtype=np.float32)
            det_results['embeddings'] = None
            return det_results

        pred_dets = pred_dets[keep_idx[0]]
        pred_xyxys = pred_dets[:, 2:6]

        w, h = self.tracker.input_size
        crops = get_crops(pred_xyxys, ori_image, w, h)

        # to keep fast speed, only use topk crops
        crops = crops[:50]  # reid_batch_size
        det_results['crops'] = np.array(crops).astype('float32')
        det_results['boxes'] = pred_dets[:50]

        input_names = self.reid_predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.reid_predictor.get_input_handle(input_names[i])
            input_tensor.copy_from_cpu(det_results[input_names[i]])

        # model prediction (repeats > 1 only for benchmark warmup timing)
        for i in range(repeats):
            self.reid_predictor.run()
            output_names = self.reid_predictor.get_output_names()
            feature_tensor = self.reid_predictor.get_output_handle(output_names[
                0])
            pred_embs = feature_tensor.copy_to_cpu()

        det_results['embeddings'] = pred_embs
        return det_results

    def tracking(self, det_results):
        """Feed detections (and optional embeddings) to the tracker.

        Returns a dict with 'online_tlwhs', 'online_scores', 'online_ids':
        plain lists for DeepSORT (single class), per-class defaultdicts for
        JDE/ByteTrack (multi-class).
        """
        pred_dets = det_results['boxes']  # 'cls_id, score, x0, y0, x1, y1'
        pred_embs = det_results.get('embeddings', None)

        if self.use_deepsort_tracker:
            # use DeepSORTTracker, only support singe class
            self.tracker.predict()
            online_targets = self.tracker.update(pred_dets, pred_embs)
            online_tlwhs, online_scores, online_ids = [], [], []
            for t in online_targets:
                # Skip tentative or stale tracks.
                if not t.is_confirmed() or t.time_since_update > 1:
                    continue
                tlwh = t.to_tlwh()
                tscore = t.score
                tid = t.track_id
                # Filter overly tall/thin boxes (e.g. partial pedestrians)
                # when a vertical_ratio threshold is configured.
                if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
                        3] > self.tracker.vertical_ratio:
                    continue
                online_tlwhs.append(tlwh)
                online_scores.append(tscore)
                online_ids.append(tid)

            tracking_outs = {
                'online_tlwhs': online_tlwhs,
                'online_scores': online_scores,
                'online_ids': online_ids,
            }
            return tracking_outs
        else:
            # use ByteTracker, support multiple class
            online_tlwhs = defaultdict(list)
            online_scores = defaultdict(list)
            online_ids = defaultdict(list)
            online_targets_dict = self.tracker.update(pred_dets, pred_embs)
            for cls_id in range(self.num_classes):
                online_targets = online_targets_dict[cls_id]
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    tscore = t.score
                    # Drop boxes below the configured minimum area.
                    if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
                        continue
                    if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
                            3] > self.tracker.vertical_ratio:
                        continue
                    online_tlwhs[cls_id].append(tlwh)
                    online_ids[cls_id].append(tid)
                    online_scores[cls_id].append(tscore)

            tracking_outs = {
                'online_tlwhs': online_tlwhs,
                'online_scores': online_scores,
                'online_ids': online_ids,
            }
            return tracking_outs

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True,
                      seq_name=None):
        """Track objects over a sorted list of image files (or frames).

        Returns one [online_tlwhs, online_scores, online_ids] triple per
        frame. When ``run_benchmark`` is True each stage is executed twice
        (warmup + timed run) and memory stats are accumulated.
        """
        num_classes = self.num_classes
        image_list.sort()
        ids2names = self.pred_config.labels
        mot_results = []
        for frame_id, img_file in enumerate(image_list):
            batch_image_list = [img_file]  # bs=1 in MOT model
            frame, _ = decode_image(img_file, {})
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                result_warmup = self.predict(repeats=repeats)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)

                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                result_warmup = self.tracking(det_result)  # warmup
                self.det_times.tracking_time_s.start()
                if self.use_reid:
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()

                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking process
                self.det_times.tracking_time_s.start()
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

            online_tlwhs = tracking_outs['online_tlwhs']
            online_scores = tracking_outs['online_scores']
            online_ids = tracking_outs['online_ids']

            mot_results.append([online_tlwhs, online_scores, online_ids])

            if visual:
                if len(image_list) > 1 and frame_id % 10 == 0:
                    print('Tracking frame {}'.format(frame_id))
                frame, _ = decode_image(img_file, {})
                # defaultdict output means the multi-class (ByteTrack) path.
                if isinstance(online_tlwhs, defaultdict):
                    im = plot_tracking_dict(
                        frame,
                        num_classes,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id,
                        ids2names=ids2names)
                else:
                    im = plot_tracking(
                        frame,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id,
                        ids2names=ids2names)
                # NOTE(review): os.path.join raises if seq_name is None —
                # callers must pass seq_name whenever visual=True.
                save_dir = os.path.join(self.output_dir, seq_name)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)

        return mot_results

    def predict_video(self, video_file, camera_id):
        """Track objects in a video file or live camera stream.

        Writes an annotated video to ``self.output_dir`` and, when
        ``self.save_mot_txts`` is set, a MOT-format txt of the results.
        """
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]
        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        video_format = 'mp4v'
        fourcc = cv2.VideoWriter_fourcc(*video_format)
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

        frame_id = 1
        timer = MOTTimer()
        results = defaultdict(list)  # per-class list of per-frame results
        num_classes = self.num_classes
        data_type = 'mcmot' if num_classes > 1 else 'mot'
        ids2names = self.pred_config.labels
        while (1):
            ret, frame = capture.read()
            if not ret:
                break
            if frame_id % 10 == 0:
                print('Tracking frame: %d' % (frame_id))
            frame_id += 1

            timer.tic()
            seq_name = video_out_name.split('.')[0]
            # frame[:, :, ::-1] reverses the channel order of the OpenCV
            # frame (presumably BGR -> RGB for decode_image; confirm).
            mot_results = self.predict_image(
                [frame[:, :, ::-1]], visual=False, seq_name=seq_name)
            timer.toc()

            # bs=1 in MOT model
            online_tlwhs, online_scores, online_ids = mot_results[0]
            fps = 1. / timer.duration
            if self.use_deepsort_tracker:
                # use DeepSORTTracker, only support singe class
                results[0].append(
                    (frame_id + 1, online_tlwhs, online_scores, online_ids))
                im = plot_tracking(
                    frame,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
                    fps=fps,
                    ids2names=ids2names)
            else:
                # use ByteTracker, support multiple class
                for cls_id in range(num_classes):
                    results[cls_id].append(
                        (frame_id + 1, online_tlwhs[cls_id],
                         online_scores[cls_id], online_ids[cls_id]))
                im = plot_tracking_dict(
                    frame,
                    num_classes,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
                    fps=fps,
                    ids2names=ids2names)

            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                # 'q' quits the live preview loop.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        if self.save_mot_txts:
            result_filename = os.path.join(
                self.output_dir, video_out_name.split('.')[-2] + '.txt')
            # NOTE(review): sibling MOT scripts in this commit call
            # write_mot_results(filename, results, data_type, num_classes);
            # confirm this 2-arg call matches the imported signature.
            write_mot_results(result_filename, results)

        writer.release()
def main():
    """Entry point: build an SDE_Detector from CLI FLAGS and run MOT inference.

    Reads the exported model's deploy config (sanity check that ``model_dir``
    really contains an inference model), then either tracks a video/camera
    stream or runs over a directory/list of images, optionally benchmarking.
    Relies on the module-level ``FLAGS`` namespace populated in ``__main__``.
    """
    # Validate that the export directory contains a deploy config; safe_load
    # avoids executing arbitrary YAML tags.
    deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(deploy_file) as f:
        yaml.safe_load(f)

    detector = SDE_Detector(
        FLAGS.model_dir,
        tracker_config=FLAGS.tracker_config,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=1,  # MOT models only support bs=1
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        output_dir=FLAGS.output_dir,
        threshold=FLAGS.threshold,
        save_images=FLAGS.save_images,
        save_mot_txts=FLAGS.save_mot_txts, )

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
    else:
        # predict from image(s)
        if FLAGS.image_dir is None and FLAGS.image_file is not None:
            assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        # Fix: the original unconditionally did FLAGS.image_dir.split('/')[-1],
        # which crashes with AttributeError when only --image_file is given
        # (image_dir is None, explicitly allowed above). Derive the sequence
        # name from whichever source is set; os.path.basename also handles
        # Windows separators.
        if FLAGS.image_dir is not None:
            seq_name = os.path.basename(os.path.normpath(FLAGS.image_dir))
        else:
            seq_name = os.path.splitext(os.path.basename(FLAGS.image_file))[0]
        detector.predict_image(
            img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)

        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            model_dir = FLAGS.model_dir
            model_info = {
                'model_name': model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(detector, img_list, model_info, name='MOT')
if __name__ == '__main__':
    # Static-graph mode is required by the exported inference models.
    paddle.enable_static()
    FLAGS = argsparser().parse_args()
    print_arguments(FLAGS)
    # Normalize the device string before validating it.
    FLAGS.device = FLAGS.device.upper()
    supported_devices = ['CPU', 'GPU', 'XPU', 'NPU']
    assert FLAGS.device in supported_devices, "device should be CPU, GPU, NPU or XPU"
    main()

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import json import json
import logging
import os import os
import hashlib import hashlib
import copy import copy
@ -24,9 +25,9 @@ from timeit import default_timer as timer
from rag.llm import EmbeddingModel, CvModel from rag.llm import EmbeddingModel, CvModel
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH, num_tokens_from_string from rag.utils import ELASTICSEARCH
from rag.utils import MINIO from rag.utils import MINIO
from rag.utils import rmSpace, findMaxDt from rag.utils import rmSpace, findMaxTm
from rag.nlp import huchunk, huqie, search from rag.nlp import huchunk, huqie, search
from io import BytesIO from io import BytesIO
import pandas as pd import pandas as pd
@ -47,6 +48,7 @@ from rag.nlp.huchunk import (
from web_server.db import LLMType from web_server.db import LLMType
from web_server.db.services.document_service import DocumentService from web_server.db.services.document_service import DocumentService
from web_server.db.services.llm_service import TenantLLMService from web_server.db.services.llm_service import TenantLLMService
from web_server.settings import database_logger
from web_server.utils import get_format_time from web_server.utils import get_format_time
from web_server.utils.file_utils import get_project_base_directory from web_server.utils.file_utils import get_project_base_directory
@ -83,7 +85,7 @@ def collect(comm, mod, tm):
if len(docs) == 0: if len(docs) == 0:
return pd.DataFrame() return pd.DataFrame()
docs = pd.DataFrame(docs) docs = pd.DataFrame(docs)
mtm = str(docs["update_time"].max())[:19] mtm = docs["update_time"].max()
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
return docs return docs
@ -99,28 +101,30 @@ def set_progress(docid, prog, msg="Processing...", begin=False):
cron_logger.error("set_progress:({}), {}".format(docid, str(e))) cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
def build(row): def build(row, cvmdl):
if row["size"] > DOC_MAXIMUM_SIZE: if row["size"] > DOC_MAXIMUM_SIZE:
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" % set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024))) (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
return [] return []
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"])) # If just change the kb for doc
if ELASTICSEARCH.getTotal(res) > 0: # res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]), idxnm=search.index_name(row["tenant_id"]))
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]), # if ELASTICSEARCH.getTotal(res) > 0:
scripts=""" # ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
if(!ctx._source.kb_id.contains('%s')) # scripts="""
ctx._source.kb_id.add('%s'); # if(!ctx._source.kb_id.contains('%s'))
""" % (str(row["kb_id"]), str(row["kb_id"])), # ctx._source.kb_id.add('%s');
idxnm=search.index_name(row["tenant_id"]) # """ % (str(row["kb_id"]), str(row["kb_id"])),
) # idxnm=search.index_name(row["tenant_id"])
set_progress(row["id"], 1, "Done") # )
return [] # set_progress(row["id"], 1, "Done")
# return []
random.seed(time.time()) random.seed(time.time())
set_progress(row["id"], random.randint(0, 20) / set_progress(row["id"], random.randint(0, 20) /
100., "Finished preparing! Start to slice file!", True) 100., "Finished preparing! Start to slice file!", True)
try: try:
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"])) cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
except Exception as e: except Exception as e:
if re.search("(No such file|not found)", str(e)): if re.search("(No such file|not found)", str(e)):
set_progress( set_progress(
@ -131,6 +135,7 @@ def build(row):
row["id"], -1, f"Internal server error: %s" % row["id"], -1, f"Internal server error: %s" %
str(e).replace( str(e).replace(
"'", "")) "'", ""))
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
return [] return []
if not obj.text_chunks and not obj.table_chunks: if not obj.text_chunks and not obj.table_chunks:
@ -144,7 +149,7 @@ def build(row):
"Finished slicing files. Start to embedding the content.") "Finished slicing files. Start to embedding the content.")
doc = { doc = {
"doc_id": row["did"], "doc_id": row["id"],
"kb_id": [str(row["kb_id"])], "kb_id": [str(row["kb_id"])],
"docnm_kwd": os.path.split(row["location"])[-1], "docnm_kwd": os.path.split(row["location"])[-1],
"title_tks": huqie.qie(row["name"]), "title_tks": huqie.qie(row["name"]),
@ -164,10 +169,10 @@ def build(row):
docs.append(d) docs.append(d)
continue continue
if isinstance(img, Image): if isinstance(img, bytes):
img.save(output_buffer, format='JPEG')
else:
output_buffer = BytesIO(img) output_buffer = BytesIO(img)
else:
img.save(output_buffer, format='JPEG')
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
@ -215,15 +220,16 @@ def embedding(docs, mdl):
def model_instance(tenant_id, llm_type): def model_instance(tenant_id, llm_type):
model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING) model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
if not model_config:return if not model_config:
model_config = model_config[0] model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
else: model_config = model_config[0].to_dict()
if llm_type == LLMType.EMBEDDING: if llm_type == LLMType.EMBEDDING:
if model_config.llm_factory not in EmbeddingModel: return if model_config["llm_factory"] not in EmbeddingModel: return
return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name) return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.IMAGE2TEXT: if llm_type == LLMType.IMAGE2TEXT:
if model_config.llm_factory not in CvModel: return if model_config["llm_factory"] not in CvModel: return
return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name) return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
def main(comm, mod): def main(comm, mod):
@ -231,7 +237,7 @@ def main(comm, mod):
from rag.llm import HuEmbedding from rag.llm import HuEmbedding
model = HuEmbedding() model = HuEmbedding()
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm") tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
tm = findMaxDt(tm_fnm) tm = findMaxTm(tm_fnm)
rows = collect(comm, mod, tm) rows = collect(comm, mod, tm)
if len(rows) == 0: if len(rows) == 0:
return return
@ -247,7 +253,7 @@ def main(comm, mod):
st_tm = timer() st_tm = timer()
cks = build(r, cv_mdl) cks = build(r, cv_mdl)
if not cks: if not cks:
tmf.write(str(r["updated_at"]) + "\n") tmf.write(str(r["update_time"]) + "\n")
continue continue
# TODO: exception handler # TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ") ## set_progress(r["did"], -1, "ERROR: ")
@ -268,12 +274,19 @@ def main(comm, mod):
cron_logger.error(str(es_r)) cron_logger.error(str(es_r))
else: else:
set_progress(r["id"], 1., "Done!") set_progress(r["id"], 1., "Done!")
DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm}) DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
tmf.write(str(r["update_time"]) + "\n") tmf.write(str(r["update_time"]) + "\n")
tmf.close() tmf.close()
if __name__ == "__main__": if __name__ == "__main__":
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
peewee_logger.addHandler(database_logger.handlers[0])
peewee_logger.setLevel(database_logger.level)
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank()) main(comm.Get_size(), comm.Get_rank())

View file

@ -40,6 +40,24 @@ def findMaxDt(fnm):
print("WARNING: can't find " + fnm) print("WARNING: can't find " + fnm)
return m return m
def findMaxTm(fnm):
m = 0
try:
with open(fnm, "r") as f:
while True:
l = f.readline()
if not l:
break
l = l.strip("\n")
if l == 'nan':
continue
if int(l) > m:
m = int(l)
except Exception as e:
print("WARNING: can't find " + fnm)
return m
def num_tokens_from_string(string: str) -> int: def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string.""" """Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding('cl100k_base') encoding = tiktoken.get_encoding('cl100k_base')

View file

@ -294,6 +294,7 @@ class HuEs:
except Exception as e: except Exception as e:
es_logger.error("ES updateByQuery deleteByQuery: " + es_logger.error("ES updateByQuery deleteByQuery: " +
str(e) + "【Q】" + str(query.to_dict())) str(e) + "【Q】" + str(query.to_dict()))
if str(e).find("NotFoundError") > 0: return True
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue continue

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import base64
import pathlib import pathlib
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
@ -195,11 +196,15 @@ def rm():
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
return get_data_error_result(retmsg="Document not found!") return get_data_error_result(retmsg="Document not found!")
if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
if not DocumentService.delete_by_id(req["doc_id"]): if not DocumentService.delete_by_id(req["doc_id"]):
return get_data_error_result( return get_data_error_result(
retmsg="Database error (Document removal)!") retmsg="Database error (Document removal)!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
MINIO.rm(kb.id, doc.location) MINIO.rm(doc.kb_id, doc.location)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@ -233,3 +238,42 @@ def rename():
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/get', methods=['GET'])
@login_required
def get():
doc_id = request.args["doc_id"]
try:
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(retmsg="Document not found!")
blob = MINIO.get(doc.kb_id, doc.location)
return get_json_result(data={"base64": base64.b64decode(blob)})
except Exception as e:
return server_error_response(e)
@manager.route('/change_parser', methods=['POST'])
@login_required
@validate_request("doc_id", "parser_id")
def change_parser():
req = request.json
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
if doc.parser_id.lower() == req["parser_id"].lower():
return get_json_result(data=True)
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
if not e:
return get_data_error_result(retmsg="Document not found!")
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1)
if not e:
return get_data_error_result(retmsg="Document not found!")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

View file

@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result
@manager.route('/create', methods=['post']) @manager.route('/create', methods=['post'])
@login_required @login_required
@validate_request("name", "description", "permission", "embd_id", "parser_id") @validate_request("name", "description", "permission", "parser_id")
def create(): def create():
req = request.json req = request.json
req["name"] = req["name"].strip() req["name"] = req["name"].strip()
@ -46,7 +46,7 @@ def create():
@manager.route('/update', methods=['post']) @manager.route('/update', methods=['post'])
@login_required @login_required
@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id") @validate_request("kb_id", "name", "description", "permission", "parser_id")
def update(): def update():
req = request.json req = request.json
req["name"] = req["name"].strip() req["name"] = req["name"].strip()
@ -72,6 +72,18 @@ def update():
return server_error_response(e) return server_error_response(e)
@manager.route('/detail', methods=['GET'])
@login_required
def detail():
kb_id = request.args["kb_id"]
try:
kb = KnowledgebaseService.get_detail(kb_id)
if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
return get_json_result(data=kb)
except Exception as e:
return server_error_response(e)
@manager.route('/list', methods=['GET']) @manager.route('/list', methods=['GET'])
@login_required @login_required
def list(): def list():

View file

@ -0,0 +1,95 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from web_server.db.services import duplicate_name
from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from web_server.db.services.user_service import TenantService, UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid, get_format_time
from web_server.db import StatusEnum, UserTenantRole
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.db_models import Knowledgebase, TenantLLM
from web_server.settings import stat_logger, RetCode
from web_server.utils.api_utils import get_json_result
@manager.route('/factories', methods=['GET'])
@login_required
def factories():
try:
fac = LLMFactoriesService.get_all()
return get_json_result(data=fac.to_json())
except Exception as e:
return server_error_response(e)
@manager.route('/set_api_key', methods=['POST'])
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
req = request.json
llm = {
"tenant_id": current_user.id,
"llm_factory": req["llm_factory"],
"api_key": req["api_key"]
}
# TODO: Test api_key
for n in ["model_type", "llm_name"]:
if n in req: llm[n] = req[n]
TenantLLM.insert(**llm).on_conflict("replace").execute()
return get_json_result(data=True)
@manager.route('/my_llms', methods=['GET'])
@login_required
def my_llms():
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
objs = [o.to_dict() for o in objs]
for o in objs: del o["api_key"]
return get_json_result(data=objs)
except Exception as e:
return server_error_response(e)
@manager.route('/list', methods=['GET'])
@login_required
def list():
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
objs = [o.to_dict() for o in objs if o.api_key]
fct = {}
for o in objs:
if o["llm_factory"] not in fct: fct[o["llm_factory"]] = []
if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"])
llms = LLMService.get_all()
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = False
if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]):
m["available"] = True
res = {}
for m in llms:
if m["fid"] not in res: res[m["fid"]] = []
res[m["fid"]].append(m)
return get_json_result(data=res)
except Exception as e:
return server_error_response(e)

View file

@ -16,9 +16,12 @@
from flask import request, session, redirect, url_for from flask import request, session, redirect, url_for
from werkzeug.security import generate_password_hash, check_password_hash from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import login_required, current_user, login_user, logout_user from flask_login import login_required, current_user, login_user, logout_user
from web_server.db.db_models import TenantLLM
from web_server.db.services.llm_service import TenantLLMService
from web_server.utils.api_utils import server_error_response, validate_request from web_server.utils.api_utils import server_error_response, validate_request
from web_server.utils import get_uuid, get_format_time, decrypt, download_img from web_server.utils import get_uuid, get_format_time, decrypt, download_img
from web_server.db import UserTenantRole from web_server.db import UserTenantRole, LLMType
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
from web_server.db.services.user_service import UserService, TenantService, UserTenantService from web_server.db.services.user_service import UserService, TenantService, UserTenantService
from web_server.settings import stat_logger from web_server.settings import stat_logger
@ -47,8 +50,9 @@ def login():
avatar = download_img(userinfo["avatar_url"]) avatar = download_img(userinfo["avatar_url"])
except Exception as e: except Exception as e:
stat_logger.exception(e) stat_logger.exception(e)
user_id = get_uuid()
try: try:
users = user_register({ users = user_register(user_id, {
"access_token": session["access_token"], "access_token": session["access_token"],
"email": userinfo["email"], "email": userinfo["email"],
"avatar": avatar, "avatar": avatar,
@ -63,6 +67,7 @@ def login():
login_user(user) login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
except Exception as e: except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e) stat_logger.exception(e)
return server_error_response(e) return server_error_response(e)
elif not request.json: elif not request.json:
@ -162,7 +167,23 @@ def user_info():
return get_json_result(data=current_user.to_dict()) return get_json_result(data=current_user.to_dict())
def user_register(user): def rollback_user_registration(user_id):
try:
TenantService.delete_by_id(user_id)
except Exception as e:
pass
try:
u = UserTenantService.query(tenant_id=user_id)
if u:
UserTenantService.delete_by_id(u[0].id)
except Exception as e:
pass
try:
TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute()
except Exception as e:
pass
def user_register(user_id, user):
user_id = get_uuid() user_id = get_uuid()
user["id"] = user_id user["id"] = user_id
tenant = { tenant = {
@ -180,10 +201,12 @@ def user_register(user):
"invited_by": user_id, "invited_by": user_id,
"role": UserTenantRole.OWNER "role": UserTenantRole.OWNER
} }
tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
if not UserService.save(**user):return if not UserService.save(**user):return
TenantService.save(**tenant) TenantService.save(**tenant)
UserTenantService.save(**usr_tenant) UserTenantService.save(**usr_tenant)
TenantLLMService.save(**tenant_llm)
return UserService.query(email=user["email"]) return UserService.query(email=user["email"])
@ -203,14 +226,17 @@ def user_add():
"last_login_time": get_format_time(), "last_login_time": get_format_time(),
"is_superuser": False, "is_superuser": False,
} }
user_id = get_uuid()
try: try:
users = user_register(user_dict) users = user_register(user_id, user_dict)
if not users: raise Exception('Register user failure.') if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!') if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0] user = users[0]
login_user(user) login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!") return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
except Exception as e: except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e) stat_logger.exception(e)
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR) return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
@ -220,7 +246,7 @@ def user_add():
@login_required @login_required
def tenant_info(): def tenant_info():
try: try:
tenants = TenantService.get_by_user_id(current_user.id) tenants = TenantService.get_by_user_id(current_user.id)[0]
return get_json_result(data=tenants) return get_json_result(data=tenants)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

View file

@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel):
class LLM(DataBaseModel): class LLM(DataBaseModel):
# defautlt LLMs for every users # defautlt LLMs for every users
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True) llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
fid = CharField(max_length=128, null=False, help_text="LLM factory id") fid = CharField(max_length=128, null=False, help_text="LLM factory id")
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
@ -442,8 +443,8 @@ class LLM(DataBaseModel):
class TenantLLM(DataBaseModel): class TenantLLM(DataBaseModel):
tenant_id = CharField(max_length=32, null=False) tenant_id = CharField(max_length=32, null=False)
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name") llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(max_length=128, null=False, help_text="LLM name") llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
api_key = CharField(max_length=255, null=True, help_text="API KEY") api_key = CharField(max_length=255, null=True, help_text="API KEY")
api_base = CharField(max_length=255, null=True, help_text="API Base") api_base = CharField(max_length=255, null=True, help_text="API Base")
@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel):
class Meta: class Meta:
db_table = "tenant_llm" db_table = "tenant_llm"
primary_key = CompositeKey('tenant_id', 'llm_factory') primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
class Knowledgebase(DataBaseModel): class Knowledgebase(DataBaseModel):
@ -464,7 +465,8 @@ class Knowledgebase(DataBaseModel):
permission = CharField(max_length=16, null=False, help_text="me|team") permission = CharField(max_length=16, null=False, help_text="me|team")
created_by = CharField(max_length=32, null=False) created_by = CharField(max_length=32, null=False)
doc_num = IntegerField(default=0) doc_num = IntegerField(default=0)
embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID") token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID") parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")

View file

@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from peewee import Expression
from web_server.db import TenantPermission, FileType from web_server.db import TenantPermission, FileType
from web_server.db.db_models import DB, Knowledgebase from web_server.db.db_models import DB, Knowledgebase, Tenant
from web_server.db.db_models import Document from web_server.db.db_models import Document
from web_server.db.services.common_service import CommonService from web_server.db.services.common_service import CommonService
from web_server.db.services.kb_service import KnowledgebaseService from web_server.db.services.kb_service import KnowledgebaseService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum from web_server.db.db_utils import StatusEnum
@ -61,15 +62,27 @@ class DocumentService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64): def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id] fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where( docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
.where(
cls.model.status == StatusEnum.VALID.value, cls.model.status == StatusEnum.VALID.value,
cls.model.type != FileType.VIRTUAL, ~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0, cls.model.progress == 0,
cls.model.update_time >= tm, cls.model.update_time >= tm,
cls.model.create_time % (Expression(cls.model.create_time, "%%", comm) == mod))\
comm == mod).order_by( .order_by(cls.model.update_time.asc())\
cls.model.update_time.asc()).paginate( .paginate(1, items_per_page)
1,
items_per_page)
return list(docs.dicts()) return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
num = cls.model.update(token_num=cls.model.token_num + token_num,
chunk_num=cls.model.chunk_num + chunk_num,
process_duation=cls.model.process_duation+duation).where(
cls.model.id == doc_id).execute()
if num == 0:raise LookupError("Document not found which is supposed to be there")
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
return num

View file

@ -17,7 +17,7 @@ import peewee
from werkzeug.security import generate_password_hash, check_password_hash from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import TenantPermission from web_server.db import TenantPermission
from web_server.db.db_models import DB, UserTenant from web_server.db.db_models import DB, UserTenant, Tenant
from web_server.db.db_models import Knowledgebase from web_server.db.db_models import Knowledgebase
from web_server.db.services.common_service import CommonService from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time from web_server.utils import get_uuid, get_format_time
@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc): def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc):
kbs = cls.model.select().where( kbs = cls.model.select().where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
& (cls.model.status==StatusEnum.VALID.value) TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
) )
if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) if desc:
else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
kbs = kbs.paginate(page_number, items_per_page) kbs = kbs.paginate(page_number, items_per_page)
return list(kbs.dicts()) return list(kbs.dicts())
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
fields = [
cls.model.id,
Tenant.embd_id,
cls.model.avatar,
cls.model.name,
cls.model.description,
cls.model.permission,
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.parser_id]
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
(cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value)
)
if not kbs:
return
d = kbs[0].to_dict()
d["embd_id"] = kbs[0].tenant.embd_id
return d

View file

@ -31,5 +31,23 @@ class LLMService(CommonService):
model = LLM model = LLM
class TenantLLMService(CommonService): class TenantLLMService(CommonService):
model = TenantLLM model = TenantLLM
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_type):
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
if objs and len(objs)>0 and objs[0].llm_name:
return objs[0]
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
(cls.model.tenant_id == tenant_id),
(cls.model.model_type == model_type),
(LLM.status == StatusEnum.VALID)
)
if not objs:return
return objs[0]

View file

@ -79,7 +79,7 @@ class TenantService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_by_user_id(cls, user_id): def get_by_user_id(cls, user_id):
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
return list(cls.model.select(*fields)\ return list(cls.model.select(*fields)\
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts()) .where(cls.model.status == StatusEnum.VALID.value).dicts())

View file

@ -143,7 +143,7 @@ def filename_type(filename):
if re.match(r".*\.pdf$", filename): if re.match(r".*\.pdf$", filename):
return FileType.PDF.value return FileType.PDF.value
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename): if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
return FileType.DOC.value return FileType.DOC.value
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):