add a lot of api (#23)

* clean rust version project

* clean rust version project

* build python version rag-flow

* add alot of api
This commit is contained in:
KevinHuSh 2024-01-15 19:47:25 +08:00 committed by GitHub
parent 038b36a525
commit 751b12c76a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 1001 additions and 144 deletions

View file

@ -35,7 +35,7 @@ class Base(ABC):
class HuEmbedding(Base):
def __init__(self):
def __init__(self, key="", model_name=""):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!

View file

@ -411,9 +411,12 @@ class TextChunker(HuChunker):
flds = self.Fields()
if self.is_binary_file(fnm):
return flds
with open(fnm, "r") as f:
txt = f.read()
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
txt = ""
if isinstance(fnm, str):
with open(fnm, "r") as f:
txt = f.read()
else: txt = fnm.decode("utf-8")
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
flds.table_chunks = []
return flds

View file

@ -8,7 +8,7 @@ from rag.nlp import huqie, query
import numpy as np
def index_name(uid): return f"docgpt_{uid}"
def index_name(uid): return f"ragflow_{uid}"
class Dealer:

View file

@ -158,8 +158,9 @@ class PaddleInferBenchmark(object):
return:
config_status(dict): dict style config info
"""
config_status = {}
if isinstance(config, paddle_infer.Config):
config_status = {}
config_status['runtime_device'] = "gpu" if config.use_gpu(
) else "cpu"
config_status['ir_optim'] = config.ir_optim()

View file

@ -22,9 +22,11 @@ import yaml
from det_keypoint_unite_utils import argsparser
from preprocess import decode_image
from infer import print_arguments, get_test_images, bench_log
from keypoint_infer import KeyPointDetector
from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images, bench_log
from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint
from visualize import visualize_pose
from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images

View file

@ -17,6 +17,8 @@ import yaml
import glob
import json
from pathlib import Path
from functools import reduce
import cv2
import numpy as np
@ -26,12 +28,15 @@ from paddle.inference import Config
from paddle.inference import create_predictor
import sys
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)
from benchmark_utils import PaddleInferBenchmark
from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from clrnet_postprocess import CLRNetPostProcess
from visualize import visualize_box_mask, imshow_lanes
from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid

View file

@ -11,21 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
import sys
sys.path.insert(0, parent_path)
import time
import yaml
import glob
from functools import reduce
from PIL import Image
import cv2
import math
import numpy as np
import paddle
from keypoint_preprocess import expand_crop
import sys
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)
from preprocess import preprocess, NormalizeImage, Permute
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess
from visualize import visualize_pose
from paddle.inference import Config
from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb
from benchmark_utils import PaddleInferBenchmark
from infer import Detector, get_test_images, print_arguments

View file

@ -17,6 +17,8 @@ from collections import abc, defaultdict
import cv2
import numpy as np
import math
import paddle
import paddle.nn as nn
from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform

View file

@ -15,16 +15,18 @@
import os
import copy
import math
import time
import yaml
import cv2
import numpy as np
from collections import defaultdict
import paddle
from rag.ppdet import MOTTimer
from utils import gaussian_radius, draw_umich_gaussian
from preprocess import preprocess, decode_image
from benchmark_utils import PaddleInferBenchmark
from utils import gaussian_radius, gaussian2D, draw_umich_gaussian
from preprocess import preprocess, decode_image, WarpAffine, NormalizeImage, Permute
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
from keypoint_preprocess import get_affine_transform
# add python path
@ -32,6 +34,9 @@ import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
from pptracking.python.mot import CenterTracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results
from pptracking.python.mot.visualize import plot_tracking
def transform_preds_with_trans(coords, trans):
@ -119,12 +124,12 @@ class CenterTrack(Detector):
track_thresh = cfg.get('track_thresh', 0.4)
pre_thresh = cfg.get('pre_thresh', 0.5)
# self.tracker = CenterTracker(
# num_classes=self.num_classes,
# min_box_area=min_box_area,
# vertical_ratio=vertical_ratio,
# track_thresh=track_thresh,
# pre_thresh=pre_thresh)
self.tracker = CenterTracker(
num_classes=self.num_classes,
min_box_area=min_box_area,
vertical_ratio=vertical_ratio,
track_thresh=track_thresh,
pre_thresh=pre_thresh)
self.pre_image = None
@ -359,20 +364,21 @@ class CenterTrack(Detector):
print('Tracking frame {}'.format(frame_id))
frame, _ = decode_image(img_file, {})
# im = plot_tracking(
# frame,
# online_tlwhs,
# online_ids,
# online_scores,
# frame_id=frame_id,
# ids2names=ids2names)
im = plot_tracking(
frame,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
ids2names=ids2names)
if seq_name is None:
seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# cv2.imwrite(
# os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
mot_results.append([online_tlwhs, online_scores, online_ids])
return mot_results
@ -422,27 +428,27 @@ class CenterTrack(Detector):
online_tlwhs, online_scores, online_ids = mot_results[0]
results[0].append(
(frame_id + 1, online_tlwhs, online_scores, online_ids))
# im = plot_tracking(
# frame,
# online_tlwhs,
# online_ids,
# online_scores,
# frame_id=frame_id,
# fps=fps,
# ids2names=ids2names)
#
# writer.write(im)
# if camera_id != -1:
# cv2.imshow('Mask Detection', im)
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
im = plot_tracking(
frame,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=fps,
ids2names=ids2names)
writer.write(im)
if camera_id != -1:
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if self.save_mot_txts:
result_filename = os.path.join(
self.output_dir, video_out_name.split('.')[-2] + '.txt')
#write_mot_results(result_filename, results, data_type, num_classes)
write_mot_results(result_filename, results, data_type, num_classes)
writer.release()

View file

@ -13,6 +13,8 @@
# limitations under the License.
import os
import time
import yaml
import cv2
import numpy as np
from collections import defaultdict
@ -20,7 +22,6 @@ import paddle
from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image
from rag.ppdet import MOTTimer
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
@ -29,6 +30,9 @@ import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
from pptracking.python.mot import JDETracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results
from pptracking.python.mot.visualize import plot_tracking_dict
# Global dictionary
MOT_JDE_SUPPORT_MODELS = {
@ -102,13 +106,14 @@ class JDE_Detector(Detector):
tracked_thresh = cfg.get('tracked_thresh', 0.7)
metric_type = cfg.get('metric_type', 'euclidean')
# self.tracker = JDETracker(
# num_classes=self.num_classes,
# min_box_area=min_box_area,
# vertical_ratio=vertical_ratio,
# conf_thres=conf_thres,
# tracked_thresh=tracked_thresh,
# metric_type=metric_type)
self.tracker = JDETracker(
num_classes=self.num_classes,
min_box_area=min_box_area,
vertical_ratio=vertical_ratio,
conf_thres=conf_thres,
tracked_thresh=tracked_thresh,
metric_type=metric_type)
def postprocess(self, inputs, result):
# postprocess output of predictor
@ -235,21 +240,23 @@ class JDE_Detector(Detector):
print('Tracking frame {}'.format(frame_id))
frame, _ = decode_image(img_file, {})
# im = plot_tracking_dict(
# frame,
# num_classes,
# online_tlwhs,
# online_ids,
# online_scores,
# frame_id=frame_id,
# ids2names=ids2names)
im = plot_tracking_dict(
frame,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
ids2names=ids2names)
if seq_name is None:
seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# cv2.imwrite(
# os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
mot_results.append([online_tlwhs, online_scores, online_ids])
return mot_results
@ -301,28 +308,28 @@ class JDE_Detector(Detector):
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
online_ids[cls_id]))
#fps = 1. / timer.duration
# im = plot_tracking_dict(
# frame,
# num_classes,
# online_tlwhs,
# online_ids,
# online_scores,
# frame_id=frame_id,
# fps=fps,
# ids2names=ids2names)
#
# writer.write(im)
# if camera_id != -1:
# cv2.imshow('Mask Detection', im)
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
fps = 1. / timer.duration
im = plot_tracking_dict(
frame,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=fps,
ids2names=ids2names)
writer.write(im)
if camera_id != -1:
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if self.save_mot_txts:
result_filename = os.path.join(
self.output_dir, video_out_name.split('.')[-2] + '.txt')
#write_mot_results(result_filename, results, data_type, num_classes)
write_mot_results(result_filename, results, data_type, num_classes)
writer.release()

View file

@ -11,26 +11,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import json
import time
import cv2
import math
import numpy as np
import paddle
import yaml
import copy
from collections import defaultdict
from mot_keypoint_unite_utils import argsparser
from preprocess import decode_image
from infer import print_arguments, get_test_images, bench_log
from mot_jde_infer import MOT_JDE_SUPPORT_MODELS
from mot_sde_infer import SDE_Detector
from mot_jde_infer import JDE_Detector, MOT_JDE_SUPPORT_MODELS
from keypoint_infer import KeyPointDetector, KEYPOINT_SUPPORT_MODELS
from det_keypoint_unite_infer import predict_with_given_det
from rag.ppdet import MOTTimer
from visualize import visualize_pose
from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images
# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
from pptracking.python.mot.utils import MOTTimer as FPSTimer
def convert_mot_to_det(tlwhs, scores):
@ -140,7 +151,7 @@ def mot_topdown_unite_predict_video(mot_detector,
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer_mot, timer_kp, timer_mot_kp = MOTTimer(), MOTTimer(), MOTTimer()
timer_mot, timer_kp, timer_mot_kp = FPSTimer(), FPSTimer(), FPSTimer()
num_classes = mot_detector.num_classes
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.'
@ -187,6 +198,15 @@ def mot_topdown_unite_predict_video(mot_detector,
returnimg=True,
ids=online_ids[0])
im = plot_tracking_dict(
im,
num_classes,
online_tlwhs,
online_ids,
online_scores,
frame_id=frame_id,
fps=mot_kp_fps)
writer.write(im)
if camera_id != -1:
cv2.imshow('Tracking and keypoint results', im)

522
rag/ppdet/mot_sde_infer.py Normal file
View file

@ -0,0 +1,522 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import yaml
import cv2
import numpy as np
from collections import defaultdict
import paddle
from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
from pptracking.python.mot import JDETracker, DeepSORTTracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results, get_crops, clip_box
from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
class SDE_Detector(Detector):
    """Separate-Detection-and-Embedding multi-object tracker.

    Wraps a Paddle detection predictor and feeds its per-frame boxes into
    either a DeepSORTTracker (when the tracker config's `type` is
    'DeepSORTTracker', typically together with a reid model) or a
    JDETracker (ByteTrack mode) for association across frames.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        tracker_config (str): tracker config path
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU
        run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
        batch_size (int): size of pre batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode need to set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to open MKLDNN
        output_dir (string): The path of output, default as 'output'
        threshold (float): Score threshold of the detected bbox, default as 0.5
        save_images (bool): Whether to save visualization image results, default as False
        save_mot_txts (bool): Whether to save tracking results (txt), default as False
        reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
    """

    def __init__(self,
                 model_dir,
                 tracker_config,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 threshold=0.5,
                 save_images=False,
                 save_mot_txts=False,
                 reid_model_dir=None):
        super(SDE_Detector, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold, )
        self.save_images = save_images
        self.save_mot_txts = save_mot_txts
        assert batch_size == 1, "MOT model only supports batch_size=1."
        self.det_times = Timer(with_tracker=True)
        self.num_classes = len(self.pred_config.labels)

        # reid config: only loaded when a reid model dir is supplied (DeepSORT).
        self.use_reid = False if reid_model_dir is None else True
        if self.use_reid:
            self.reid_pred_config = self.set_config(reid_model_dir)
            self.reid_predictor, self.config = load_predictor(
                reid_model_dir,
                run_mode=run_mode,
                batch_size=50,  # reid_batch_size
                min_subgraph_size=self.reid_pred_config.min_subgraph_size,
                device=device,
                use_dynamic_shape=self.reid_pred_config.use_dynamic_shape,
                trt_min_shape=trt_min_shape,
                trt_max_shape=trt_max_shape,
                trt_opt_shape=trt_opt_shape,
                trt_calib_mode=trt_calib_mode,
                cpu_threads=cpu_threads,
                enable_mkldnn=enable_mkldnn)
        else:
            self.reid_pred_config = None
            self.reid_predictor = None

        assert tracker_config is not None, 'Note that tracker_config should be set.'
        self.tracker_config = tracker_config
        # NOTE(review): open() without a with-block leaks the file handle.
        tracker_cfg = yaml.safe_load(open(self.tracker_config))
        cfg = tracker_cfg[tracker_cfg['type']]

        # tracker config: pick the tracker implementation declared in the YAML.
        self.use_deepsort_tracker = True if tracker_cfg[
            'type'] == 'DeepSORTTracker' else False
        if self.use_deepsort_tracker:
            # use DeepSORTTracker; a tracker section in the reid config,
            # when present, overrides the standalone tracker YAML.
            if self.reid_pred_config is not None and hasattr(
                    self.reid_pred_config, 'tracker'):
                cfg = self.reid_pred_config.tracker
            budget = cfg.get('budget', 100)
            max_age = cfg.get('max_age', 30)
            max_iou_distance = cfg.get('max_iou_distance', 0.7)
            matching_threshold = cfg.get('matching_threshold', 0.2)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)

            self.tracker = DeepSORTTracker(
                budget=budget,
                max_age=max_age,
                max_iou_distance=max_iou_distance,
                matching_threshold=matching_threshold,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio, )
        else:
            # use ByteTracker
            use_byte = cfg.get('use_byte', False)
            det_thresh = cfg.get('det_thresh', 0.3)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)
            match_thres = cfg.get('match_thres', 0.9)
            conf_thres = cfg.get('conf_thres', 0.6)
            low_conf_thres = cfg.get('low_conf_thres', 0.1)

            self.tracker = JDETracker(
                use_byte=use_byte,
                det_thresh=det_thresh,
                num_classes=self.num_classes,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio,
                match_thres=match_thres,
                conf_thres=conf_thres,
                low_conf_thres=low_conf_thres, )

    def postprocess(self, inputs, result):
        """Drop empty/None outputs; substitute a zero-box result when the
        predictor reports no detections (keeps downstream shapes valid)."""
        # postprocess output of predictor
        np_boxes_num = result['boxes_num']
        if np_boxes_num[0] <= 0:
            print('[WARNNING] No object detected.')
            result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
        result = {k: v for k, v in result.items() if v is not None}
        return result

    def reidprocess(self, det_results, repeats=1):
        """Compute reid embeddings for the detected boxes.

        Clips boxes to the image, crops them, runs the reid predictor and
        stores the embeddings under ``det_results['embeddings']``.
        `repeats` > 1 is only meaningful for benchmark warmup runs.
        """
        # assumes 'boxes' rows are (cls_id, score, x0, y0, x1, y1) — matches
        # the comment in tracking() below.
        pred_dets = det_results['boxes']
        pred_xyxys = pred_dets[:, 2:6]

        ori_image = det_results['ori_image']
        ori_image_shape = ori_image.shape[:2]
        pred_xyxys, keep_idx = clip_box(pred_xyxys, ori_image_shape)

        if len(keep_idx[0]) == 0:
            # nothing survived clipping: return a single dummy box, no embeddings
            det_results['boxes'] = np.zeros((1, 6), dtype=np.float32)
            det_results['embeddings'] = None
            return det_results

        pred_dets = pred_dets[keep_idx[0]]
        pred_xyxys = pred_dets[:, 2:6]

        w, h = self.tracker.input_size
        crops = get_crops(pred_xyxys, ori_image, w, h)

        # to keep fast speed, only use topk crops
        crops = crops[:50]  # reid_batch_size
        det_results['crops'] = np.array(crops).astype('float32')
        det_results['boxes'] = pred_dets[:50]

        input_names = self.reid_predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.reid_predictor.get_input_handle(input_names[i])
            input_tensor.copy_from_cpu(det_results[input_names[i]])

        # model prediction
        for i in range(repeats):
            self.reid_predictor.run()
            output_names = self.reid_predictor.get_output_names()
            feature_tensor = self.reid_predictor.get_output_handle(output_names[
                0])
            pred_embs = feature_tensor.copy_to_cpu()

        det_results['embeddings'] = pred_embs
        return det_results

    def tracking(self, det_results):
        """Associate current detections with tracks.

        Returns a dict with 'online_tlwhs', 'online_scores', 'online_ids':
        plain lists for DeepSORT (single class), per-class defaultdicts for
        ByteTrack (multi-class).
        """
        pred_dets = det_results['boxes']  # 'cls_id, score, x0, y0, x1, y1'
        pred_embs = det_results.get('embeddings', None)

        if self.use_deepsort_tracker:
            # use DeepSORTTracker, only support singe class
            self.tracker.predict()
            online_targets = self.tracker.update(pred_dets, pred_embs)
            online_tlwhs, online_scores, online_ids = [], [], []
            for t in online_targets:
                if not t.is_confirmed() or t.time_since_update > 1:
                    continue
                tlwh = t.to_tlwh()
                tscore = t.score
                tid = t.track_id
                # skip overly tall/narrow boxes (likely false positives)
                if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
                        3] > self.tracker.vertical_ratio:
                    continue
                online_tlwhs.append(tlwh)
                online_scores.append(tscore)
                online_ids.append(tid)

            tracking_outs = {
                'online_tlwhs': online_tlwhs,
                'online_scores': online_scores,
                'online_ids': online_ids,
            }
            return tracking_outs
        else:
            # use ByteTracker, support multiple class
            online_tlwhs = defaultdict(list)
            online_scores = defaultdict(list)
            online_ids = defaultdict(list)
            online_targets_dict = self.tracker.update(pred_dets, pred_embs)
            for cls_id in range(self.num_classes):
                online_targets = online_targets_dict[cls_id]
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    tscore = t.score
                    if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
                        continue
                    if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
                            3] > self.tracker.vertical_ratio:
                        continue
                    online_tlwhs[cls_id].append(tlwh)
                    online_ids[cls_id].append(tid)
                    online_scores[cls_id].append(tscore)

            tracking_outs = {
                'online_tlwhs': online_tlwhs,
                'online_scores': online_scores,
                'online_ids': online_ids,
            }
            return tracking_outs

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True,
                      seq_name=None):
        """Run detection + tracking over a list of images (bs=1 per frame).

        When run_benchmark is True each stage is executed once as warmup and
        then timed; memory stats are accumulated on self. Returns a list of
        [online_tlwhs, online_scores, online_ids] per frame.
        """
        num_classes = self.num_classes
        image_list.sort()
        ids2names = self.pred_config.labels
        mot_results = []
        for frame_id, img_file in enumerate(image_list):
            batch_image_list = [img_file]  # bs=1 in MOT model
            frame, _ = decode_image(img_file, {})
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                result_warmup = self.predict(repeats=repeats)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)

                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                result_warmup = self.tracking(det_result)  # warmup
                self.det_times.tracking_time_s.start()
                if self.use_reid:
                    # NOTE(review): reidprocess runs a second time here in
                    # benchmark mode — presumably intentional for timing.
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()

                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking process
                self.det_times.tracking_time_s.start()
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

            online_tlwhs = tracking_outs['online_tlwhs']
            online_scores = tracking_outs['online_scores']
            online_ids = tracking_outs['online_ids']

            mot_results.append([online_tlwhs, online_scores, online_ids])

            if visual:
                if len(image_list) > 1 and frame_id % 10 == 0:
                    print('Tracking frame {}'.format(frame_id))
                frame, _ = decode_image(img_file, {})
                # defaultdict output means ByteTrack multi-class results
                if isinstance(online_tlwhs, defaultdict):
                    im = plot_tracking_dict(
                        frame,
                        num_classes,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id,
                        ids2names=ids2names)
                else:
                    im = plot_tracking(
                        frame,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id,
                        ids2names=ids2names)
                save_dir = os.path.join(self.output_dir, seq_name)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)

        return mot_results

    def predict_video(self, video_file, camera_id):
        """Track objects in a video file or camera stream (camera_id != -1),
        writing an annotated video to self.output_dir and, optionally,
        MOT-format txt results."""
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]

        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        video_format = 'mp4v'
        fourcc = cv2.VideoWriter_fourcc(*video_format)
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

        frame_id = 1
        timer = MOTTimer()
        results = defaultdict(list)  # per-class list of frame results
        num_classes = self.num_classes
        data_type = 'mcmot' if num_classes > 1 else 'mot'
        ids2names = self.pred_config.labels
        while (1):
            ret, frame = capture.read()
            if not ret:
                break
            if frame_id % 10 == 0:
                print('Tracking frame: %d' % (frame_id))
            frame_id += 1

            timer.tic()
            seq_name = video_out_name.split('.')[0]
            # BGR -> RGB via [:, :, ::-1] before handing the frame over
            mot_results = self.predict_image(
                [frame[:, :, ::-1]], visual=False, seq_name=seq_name)
            timer.toc()

            # bs=1 in MOT model
            online_tlwhs, online_scores, online_ids = mot_results[0]

            fps = 1. / timer.duration
            if self.use_deepsort_tracker:
                # use DeepSORTTracker, only support singe class
                results[0].append(
                    (frame_id + 1, online_tlwhs, online_scores, online_ids))
                im = plot_tracking(
                    frame,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
                    fps=fps,
                    ids2names=ids2names)
            else:
                # use ByteTracker, support multiple class
                for cls_id in range(num_classes):
                    results[cls_id].append(
                        (frame_id + 1, online_tlwhs[cls_id],
                         online_scores[cls_id], online_ids[cls_id]))
                im = plot_tracking_dict(
                    frame,
                    num_classes,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
                    fps=fps,
                    ids2names=ids2names)

            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        if self.save_mot_txts:
            result_filename = os.path.join(
                self.output_dir, video_out_name.split('.')[-2] + '.txt')
            write_mot_results(result_filename, results)

        writer.release()
def main():
    """Entry point: build an SDE_Detector from FLAGS and run tracking.

    Uses the video/camera path when a video source is configured; otherwise
    runs over an image list, optionally collecting benchmark statistics.
    """
    cfg_path = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(cfg_path) as cfg_file:
        deploy_conf = yaml.safe_load(cfg_file)
    # Read 'arch' so a malformed deploy config fails fast with a KeyError.
    arch = deploy_conf['arch']

    detector = SDE_Detector(
        FLAGS.model_dir,
        tracker_config=FLAGS.tracker_config,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=1,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        output_dir=FLAGS.output_dir,
        threshold=FLAGS.threshold,
        save_images=FLAGS.save_images,
        save_mot_txts=FLAGS.save_mot_txts, )

    has_video_source = FLAGS.video_file is not None or FLAGS.camera_id != -1
    if has_video_source:
        # predict from video file or camera video stream
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
        return

    # predict from image list
    if FLAGS.image_dir is None and FLAGS.image_file is not None:
        assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
    seq_name = FLAGS.image_dir.split('/')[-1]
    detector.predict_image(
        img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)

    if FLAGS.run_benchmark:
        mode = FLAGS.run_mode
        model_dir = FLAGS.model_dir
        model_info = {
            'model_name': model_dir.strip('/').split('/')[-1],
            'precision': mode.split('_')[-1]
        }
        bench_log(detector, img_list, model_info, name='MOT')
    else:
        detector.det_times.info(average=True)
if __name__ == '__main__':
    # Paddle inference predictors used above require static-graph mode.
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    # Normalize the device string so lower-case CLI values pass the check.
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU'
                            ], "device should be CPU, GPU, NPU or XPU"

    main()

View file

@ -14,6 +14,7 @@
# limitations under the License.
#
import json
import logging
import os
import hashlib
import copy
@ -24,9 +25,10 @@ from timeit import default_timer as timer
from rag.llm import EmbeddingModel, CvModel
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH, num_tokens_from_string
from rag.utils import ELASTICSEARCH
from rag.utils import MINIO
from rag.utils import rmSpace, findMaxDt
from rag.utils import rmSpace, findMaxTm
from rag.nlp import huchunk, huqie, search
from io import BytesIO
import pandas as pd
@ -47,6 +49,7 @@ from rag.nlp.huchunk import (
from web_server.db import LLMType
from web_server.db.services.document_service import DocumentService
from web_server.db.services.llm_service import TenantLLMService
from web_server.settings import database_logger
from web_server.utils import get_format_time
from web_server.utils.file_utils import get_project_base_directory
@ -83,7 +86,7 @@ def collect(comm, mod, tm):
if len(docs) == 0:
return pd.DataFrame()
docs = pd.DataFrame(docs)
mtm = str(docs["update_time"].max())[:19]
mtm = docs["update_time"].max()
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
return docs
@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False):
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
def build(row):
def build(row, cvmdl):
if row["size"] > DOC_MAXIMUM_SIZE:
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
return []
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
if ELASTICSEARCH.getTotal(res) > 0:
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
@ -120,7 +124,8 @@ def build(row):
set_progress(row["id"], random.randint(0, 20) /
100., "Finished preparing! Start to slice file!", True)
try:
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]))
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
except Exception as e:
if re.search("(No such file|not found)", str(e)):
set_progress(
@ -131,6 +136,9 @@ def build(row):
row["id"], -1, f"Internal server error: %s" %
str(e).replace(
"'", ""))
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
return []
if not obj.text_chunks and not obj.table_chunks:
@ -144,7 +152,7 @@ def build(row):
"Finished slicing files. Start to embedding the content.")
doc = {
"doc_id": row["did"],
"doc_id": row["id"],
"kb_id": [str(row["kb_id"])],
"docnm_kwd": os.path.split(row["location"])[-1],
"title_tks": huqie.qie(row["name"]),
@ -164,10 +172,10 @@ def build(row):
docs.append(d)
continue
if isinstance(img, Image):
img.save(output_buffer, format='JPEG')
else:
if isinstance(img, bytes):
output_buffer = BytesIO(img)
else:
img.save(output_buffer, format='JPEG')
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
@ -215,15 +223,16 @@ def embedding(docs, mdl):
def model_instance(tenant_id, llm_type):
model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING)
if not model_config:return
model_config = model_config[0]
model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
if not model_config:
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
else: model_config = model_config[0].to_dict()
if llm_type == LLMType.EMBEDDING:
if model_config.llm_factory not in EmbeddingModel: return
return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
if model_config["llm_factory"] not in EmbeddingModel: return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.IMAGE2TEXT:
if model_config.llm_factory not in CvModel: return
return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
if model_config["llm_factory"] not in CvModel: return
return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
def main(comm, mod):
@ -231,7 +240,7 @@ def main(comm, mod):
from rag.llm import HuEmbedding
model = HuEmbedding()
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
tm = findMaxDt(tm_fnm)
tm = findMaxTm(tm_fnm)
rows = collect(comm, mod, tm)
if len(rows) == 0:
return
@ -247,7 +256,7 @@ def main(comm, mod):
st_tm = timer()
cks = build(r, cv_mdl)
if not cks:
tmf.write(str(r["updated_at"]) + "\n")
tmf.write(str(r["update_time"]) + "\n")
continue
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
@ -268,12 +277,19 @@ def main(comm, mod):
cron_logger.error(str(es_r))
else:
set_progress(r["id"], 1., "Done!")
DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm})
DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
tmf.write(str(r["update_time"]) + "\n")
tmf.close()
if __name__ == "__main__":
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
peewee_logger.addHandler(database_logger.handlers[0])
peewee_logger.setLevel(database_logger.level)
from mpi4py import MPI
comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank())

View file

@ -40,6 +40,25 @@ def findMaxDt(fnm):
print("WARNING: can't find " + fnm)
return m
def findMaxTm(fnm):
m = 0
try:
with open(fnm, "r") as f:
while True:
l = f.readline()
if not l:
break
l = l.strip("\n")
if l == 'nan':
continue
if int(l) > m:
m = int(l)
except Exception as e:
print("WARNING: can't find " + fnm)
return m
def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding('cl100k_base')

View file

@ -294,6 +294,7 @@ class HuEs:
except Exception as e:
es_logger.error("ES updateByQuery deleteByQuery: " +
str(e) + "【Q】" + str(query.to_dict()))
if str(e).find("NotFoundError") > 0: return True
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import pathlib
from elasticsearch_dsl import Q
@ -195,11 +196,15 @@ def rm():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
if not DocumentService.delete_by_id(req["doc_id"]):
return get_data_error_result(
retmsg="Database error (Document removal)!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
MINIO.rm(kb.id, doc.location)
MINIO.rm(doc.kb_id, doc.location)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -233,3 +238,43 @@ def rename():
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/get', methods=['GET'])
@login_required
def get():
doc_id = request.args["doc_id"]
try:
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(retmsg="Document not found!")
blob = MINIO.get(doc.kb_id, doc.location)
return get_json_result(data={"base64": base64.b64decode(blob)})
except Exception as e:
return server_error_response(e)
@manager.route('/change_parser', methods=['POST'])
@login_required
@validate_request("doc_id", "parser_id")
def change_parser():
req = request.json
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")
if doc.parser_id.lower() == req["parser_id"].lower():
return get_json_result(data=True)
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
if not e:
return get_data_error_result(retmsg="Document not found!")
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1)
if not e:
return get_data_error_result(retmsg="Document not found!")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

View file

@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result
@manager.route('/create', methods=['post'])
@login_required
@validate_request("name", "description", "permission", "embd_id", "parser_id")
@validate_request("name", "description", "permission", "parser_id")
def create():
req = request.json
req["name"] = req["name"].strip()
@ -46,7 +46,7 @@ def create():
@manager.route('/update', methods=['post'])
@login_required
@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id")
@validate_request("kb_id", "name", "description", "permission", "parser_id")
def update():
req = request.json
req["name"] = req["name"].strip()
@ -72,6 +72,18 @@ def update():
return server_error_response(e)
@manager.route('/detail', methods=['GET'])
@login_required
def detail():
kb_id = request.args["kb_id"]
try:
kb = KnowledgebaseService.get_detail(kb_id)
if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
return get_json_result(data=kb)
except Exception as e:
return server_error_response(e)
@manager.route('/list', methods=['GET'])
@login_required
def list():

View file

@ -0,0 +1,95 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from web_server.db.services import duplicate_name
from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from web_server.db.services.user_service import TenantService, UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid, get_format_time
from web_server.db import StatusEnum, UserTenantRole
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.db_models import Knowledgebase, TenantLLM
from web_server.settings import stat_logger, RetCode
from web_server.utils.api_utils import get_json_result
@manager.route('/factories', methods=['GET'])
@login_required
def factories():
try:
fac = LLMFactoriesService.get_all()
return get_json_result(data=fac.to_json())
except Exception as e:
return server_error_response(e)
@manager.route('/set_api_key', methods=['POST'])
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
req = request.json
llm = {
"tenant_id": current_user.id,
"llm_factory": req["llm_factory"],
"api_key": req["api_key"]
}
# TODO: Test api_key
for n in ["model_type", "llm_name"]:
if n in req: llm[n] = req[n]
TenantLLM.insert(**llm).on_conflict("replace").execute()
return get_json_result(data=True)
@manager.route('/my_llms', methods=['GET'])
@login_required
def my_llms():
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
objs = [o.to_dict() for o in objs]
for o in objs: del o["api_key"]
return get_json_result(data=objs)
except Exception as e:
return server_error_response(e)
@manager.route('/list', methods=['GET'])
@login_required
def list():
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
objs = [o.to_dict() for o in objs if o.api_key]
fct = {}
for o in objs:
if o["llm_factory"] not in fct: fct[o["llm_factory"]] = []
if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"])
llms = LLMService.get_all()
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = False
if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]):
m["available"] = True
res = {}
for m in llms:
if m["fid"] not in res: res[m["fid"]] = []
res[m["fid"]].append(m)
return get_json_result(data=res)
except Exception as e:
return server_error_response(e)

View file

@ -16,9 +16,12 @@
from flask import request, session, redirect, url_for
from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import login_required, current_user, login_user, logout_user
from web_server.db.db_models import TenantLLM
from web_server.db.services.llm_service import TenantLLMService
from web_server.utils.api_utils import server_error_response, validate_request
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
from web_server.db import UserTenantRole
from web_server.db import UserTenantRole, LLMType
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
from web_server.settings import stat_logger
@ -47,8 +50,9 @@ def login():
avatar = download_img(userinfo["avatar_url"])
except Exception as e:
stat_logger.exception(e)
user_id = get_uuid()
try:
users = user_register({
users = user_register(user_id, {
"access_token": session["access_token"],
"email": userinfo["email"],
"avatar": avatar,
@ -63,6 +67,7 @@ def login():
login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e)
return server_error_response(e)
elif not request.json:
@ -162,7 +167,25 @@ def user_info():
return get_json_result(data=current_user.to_dict())
def user_register(user):
def rollback_user_registration(user_id):
try:
TenantService.delete_by_id(user_id)
except Exception as e:
pass
try:
u = UserTenantService.query(tenant_id=user_id)
if u:
UserTenantService.delete_by_id(u[0].id)
except Exception as e:
pass
try:
TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute()
except Exception as e:
pass
def user_register(user_id, user):
user_id = get_uuid()
user["id"] = user_id
tenant = {
@ -180,10 +203,12 @@ def user_register(user):
"invited_by": user_id,
"role": UserTenantRole.OWNER
}
tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
if not UserService.save(**user):return
TenantService.save(**tenant)
UserTenantService.save(**usr_tenant)
TenantLLMService.save(**tenant_llm)
return UserService.query(email=user["email"])
@ -203,14 +228,17 @@ def user_add():
"last_login_time": get_format_time(),
"is_superuser": False,
}
user_id = get_uuid()
try:
users = user_register(user_dict)
users = user_register(user_id, user_dict)
if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0]
login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e)
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
@ -220,7 +248,7 @@ def user_add():
@login_required
def tenant_info():
try:
tenants = TenantService.get_by_user_id(current_user.id)
tenants = TenantService.get_by_user_id(current_user.id)[0]
return get_json_result(data=tenants)
except Exception as e:
return server_error_response(e)

View file

@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel):
class LLM(DataBaseModel):
# defautlt LLMs for every users
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
@ -442,8 +443,8 @@ class LLM(DataBaseModel):
class TenantLLM(DataBaseModel):
tenant_id = CharField(max_length=32, null=False)
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(max_length=128, null=False, help_text="LLM name")
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
api_key = CharField(max_length=255, null=True, help_text="API KEY")
api_base = CharField(max_length=255, null=True, help_text="API Base")
@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel):
class Meta:
db_table = "tenant_llm"
primary_key = CompositeKey('tenant_id', 'llm_factory')
primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
class Knowledgebase(DataBaseModel):
@ -464,7 +465,9 @@ class Knowledgebase(DataBaseModel):
permission = CharField(max_length=16, null=False, help_text="me|team")
created_by = CharField(max_length=32, null=False)
doc_num = IntegerField(default=0)
embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID")
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")

View file

@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from peewee import Expression
from web_server.db import TenantPermission, FileType
from web_server.db.db_models import DB, Knowledgebase
from web_server.db.db_models import DB, Knowledgebase, Tenant
from web_server.db.db_models import Document
from web_server.db.services.common_service import CommonService
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
@ -61,15 +62,28 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id]
docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where(
cls.model.status == StatusEnum.VALID.value,
cls.model.type != FileType.VIRTUAL,
cls.model.progress == 0,
cls.model.update_time >= tm,
cls.model.create_time %
comm == mod).order_by(
cls.model.update_time.asc()).paginate(
1,
items_per_page)
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
.where(
cls.model.status == StatusEnum.VALID.value,
~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0,
cls.model.update_time >= tm,
(Expression(cls.model.create_time, "%%", comm) == mod))\
.order_by(cls.model.update_time.asc())\
.paginate(1, items_per_page)
return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
num = cls.model.update(token_num=cls.model.token_num + token_num,
chunk_num=cls.model.chunk_num + chunk_num,
process_duation=cls.model.process_duation+duation).where(
cls.model.id == doc_id).execute()
if num == 0:raise LookupError("Document not found which is supposed to be there")
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
return num

View file

@ -17,7 +17,7 @@ import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import TenantPermission
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import DB, UserTenant, Tenant
from web_server.db.db_models import Knowledgebase
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService):
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc):
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc):
kbs = cls.model.select().where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
& (cls.model.status==StatusEnum.VALID.value)
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value)
)
if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
if desc:
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
kbs = kbs.paginate(page_number, items_per_page)
return list(kbs.dicts())
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
fields = [
cls.model.id,
Tenant.embd_id,
cls.model.avatar,
cls.model.name,
cls.model.description,
cls.model.permission,
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.parser_id]
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
(cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value)
)
if not kbs:
return
d = kbs[0].to_dict()
d["embd_id"] = kbs[0].tenant.embd_id
return d

View file

@ -33,3 +33,21 @@ class LLMService(CommonService):
class TenantLLMService(CommonService):
model = TenantLLM
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_type):
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
if objs and len(objs)>0 and objs[0].llm_name:
return objs[0]
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
(cls.model.tenant_id == tenant_id),
(cls.model.model_type == model_type),
(LLM.status == StatusEnum.VALID)
)
if not objs:return
return objs[0]

View file

@ -79,7 +79,7 @@ class TenantService(CommonService):
@classmethod
@DB.connection_context()
def get_by_user_id(cls, user_id):
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
return list(cls.model.select(*fields)\
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts())

View file

@ -143,7 +143,7 @@ def filename_type(filename):
if re.match(r".*\.pdf$", filename):
return FileType.PDF.value
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename):
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
return FileType.DOC.value
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):