add alot of api
This commit is contained in:
parent
ecb7d40dcc
commit
76fad8bf99
26 changed files with 992 additions and 154 deletions
|
|
@ -35,7 +35,7 @@ class Base(ABC):
|
||||||
|
|
||||||
|
|
||||||
class HuEmbedding(Base):
|
class HuEmbedding(Base):
|
||||||
def __init__(self):
|
def __init__(self, key="", model_name=""):
|
||||||
"""
|
"""
|
||||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -411,9 +411,12 @@ class TextChunker(HuChunker):
|
||||||
flds = self.Fields()
|
flds = self.Fields()
|
||||||
if self.is_binary_file(fnm):
|
if self.is_binary_file(fnm):
|
||||||
return flds
|
return flds
|
||||||
with open(fnm, "r") as f:
|
txt = ""
|
||||||
txt = f.read()
|
if isinstance(fnm, str):
|
||||||
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
|
with open(fnm, "r") as f:
|
||||||
|
txt = f.read()
|
||||||
|
else: txt = fnm.decode("utf-8")
|
||||||
|
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
|
||||||
flds.table_chunks = []
|
flds.table_chunks = []
|
||||||
return flds
|
return flds
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from rag.nlp import huqie, query
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def index_name(uid): return f"docgpt_{uid}"
|
def index_name(uid): return f"ragflow_{uid}"
|
||||||
|
|
||||||
|
|
||||||
class Dealer:
|
class Dealer:
|
||||||
|
|
|
||||||
|
|
@ -158,8 +158,8 @@ class PaddleInferBenchmark(object):
|
||||||
return:
|
return:
|
||||||
config_status(dict): dict style config info
|
config_status(dict): dict style config info
|
||||||
"""
|
"""
|
||||||
config_status = {}
|
|
||||||
if isinstance(config, paddle_infer.Config):
|
if isinstance(config, paddle_infer.Config):
|
||||||
|
config_status = {}
|
||||||
config_status['runtime_device'] = "gpu" if config.use_gpu(
|
config_status['runtime_device'] = "gpu" if config.use_gpu(
|
||||||
) else "cpu"
|
) else "cpu"
|
||||||
config_status['ir_optim'] = config.ir_optim()
|
config_status['ir_optim'] = config.ir_optim()
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
import paddle.nn as nn
|
import paddle.nn as nn
|
||||||
|
from scipy.special import softmax
|
||||||
from scipy.interpolate import InterpolatedUnivariateSpline
|
from scipy.interpolate import InterpolatedUnivariateSpline
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,9 +22,10 @@ import yaml
|
||||||
|
|
||||||
from det_keypoint_unite_utils import argsparser
|
from det_keypoint_unite_utils import argsparser
|
||||||
from preprocess import decode_image
|
from preprocess import decode_image
|
||||||
from infer import print_arguments, get_test_images, bench_log
|
from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images, bench_log
|
||||||
from keypoint_infer import KeyPointDetector
|
from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint
|
||||||
from visualize import visualize_pose
|
from visualize import visualize_pose
|
||||||
|
from benchmark_utils import PaddleInferBenchmark
|
||||||
from utils import get_current_memory_mb
|
from utils import get_current_memory_mb
|
||||||
from keypoint_postprocess import translate_to_ori_images
|
from keypoint_postprocess import translate_to_ori_images
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ import yaml
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -26,12 +27,14 @@ from paddle.inference import Config
|
||||||
from paddle.inference import create_predictor
|
from paddle.inference import create_predictor
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
# add deploy path of PaddleDetection to sys.path
|
||||||
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
|
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
|
||||||
sys.path.insert(0, parent_path)
|
sys.path.insert(0, parent_path)
|
||||||
|
|
||||||
from benchmark_utils import PaddleInferBenchmark
|
from benchmark_utils import PaddleInferBenchmark
|
||||||
from picodet_postprocess import PicoDetPostProcess
|
from picodet_postprocess import PicoDetPostProcess
|
||||||
from preprocess import preprocess
|
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize
|
||||||
|
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
|
||||||
from clrnet_postprocess import CLRNetPostProcess
|
from clrnet_postprocess import CLRNetPostProcess
|
||||||
from visualize import visualize_box_mask, imshow_lanes
|
from visualize import visualize_box_mask, imshow_lanes
|
||||||
from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid
|
from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid
|
||||||
|
|
|
||||||
|
|
@ -11,21 +11,30 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
|
||||||
import os
|
import os
|
||||||
# add deploy path of PaddleDetection to sys.path
|
import time
|
||||||
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
|
|
||||||
import sys
|
|
||||||
sys.path.insert(0, parent_path)
|
|
||||||
import yaml
|
import yaml
|
||||||
|
import glob
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
import cv2
|
import cv2
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
from keypoint_preprocess import expand_crop
|
import sys
|
||||||
|
# add deploy path of PaddleDetection to sys.path
|
||||||
|
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
|
||||||
|
sys.path.insert(0, parent_path)
|
||||||
|
|
||||||
|
from preprocess import preprocess, NormalizeImage, Permute
|
||||||
|
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
|
||||||
from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess
|
from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess
|
||||||
from visualize import visualize_pose
|
from visualize import visualize_pose
|
||||||
|
from paddle.inference import Config
|
||||||
|
from paddle.inference import create_predictor
|
||||||
from utils import argsparser, Timer, get_current_memory_mb
|
from utils import argsparser, Timer, get_current_memory_mb
|
||||||
from benchmark_utils import PaddleInferBenchmark
|
from benchmark_utils import PaddleInferBenchmark
|
||||||
from infer import Detector, get_test_images, print_arguments
|
from infer import Detector, get_test_images, print_arguments
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,8 @@ from collections import abc, defaultdict
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform
|
from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,16 +15,18 @@
|
||||||
import os
|
import os
|
||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
|
import time
|
||||||
|
import yaml
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
from rag.ppdet import MOTTimer
|
from benchmark_utils import PaddleInferBenchmark
|
||||||
from utils import gaussian_radius, draw_umich_gaussian
|
from utils import gaussian_radius, gaussian2D, draw_umich_gaussian
|
||||||
from preprocess import preprocess, decode_image
|
from preprocess import preprocess, decode_image, WarpAffine, NormalizeImage, Permute
|
||||||
from utils import argsparser, Timer, get_current_memory_mb
|
from utils import argsparser, Timer, get_current_memory_mb
|
||||||
from infer import Detector, get_test_images, print_arguments, bench_log
|
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
|
||||||
from keypoint_preprocess import get_affine_transform
|
from keypoint_preprocess import get_affine_transform
|
||||||
|
|
||||||
# add python path
|
# add python path
|
||||||
|
|
@ -32,6 +34,9 @@ import sys
|
||||||
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
|
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
|
||||||
sys.path.insert(0, parent_path)
|
sys.path.insert(0, parent_path)
|
||||||
|
|
||||||
|
from pptracking.python.mot import CenterTracker
|
||||||
|
from pptracking.python.mot.utils import MOTTimer, write_mot_results
|
||||||
|
from pptracking.python.mot.visualize import plot_tracking
|
||||||
|
|
||||||
|
|
||||||
def transform_preds_with_trans(coords, trans):
|
def transform_preds_with_trans(coords, trans):
|
||||||
|
|
@ -119,12 +124,12 @@ class CenterTrack(Detector):
|
||||||
track_thresh = cfg.get('track_thresh', 0.4)
|
track_thresh = cfg.get('track_thresh', 0.4)
|
||||||
pre_thresh = cfg.get('pre_thresh', 0.5)
|
pre_thresh = cfg.get('pre_thresh', 0.5)
|
||||||
|
|
||||||
# self.tracker = CenterTracker(
|
self.tracker = CenterTracker(
|
||||||
# num_classes=self.num_classes,
|
num_classes=self.num_classes,
|
||||||
# min_box_area=min_box_area,
|
min_box_area=min_box_area,
|
||||||
# vertical_ratio=vertical_ratio,
|
vertical_ratio=vertical_ratio,
|
||||||
# track_thresh=track_thresh,
|
track_thresh=track_thresh,
|
||||||
# pre_thresh=pre_thresh)
|
pre_thresh=pre_thresh)
|
||||||
|
|
||||||
self.pre_image = None
|
self.pre_image = None
|
||||||
|
|
||||||
|
|
@ -359,20 +364,20 @@ class CenterTrack(Detector):
|
||||||
print('Tracking frame {}'.format(frame_id))
|
print('Tracking frame {}'.format(frame_id))
|
||||||
frame, _ = decode_image(img_file, {})
|
frame, _ = decode_image(img_file, {})
|
||||||
|
|
||||||
# im = plot_tracking(
|
im = plot_tracking(
|
||||||
# frame,
|
frame,
|
||||||
# online_tlwhs,
|
online_tlwhs,
|
||||||
# online_ids,
|
online_ids,
|
||||||
# online_scores,
|
online_scores,
|
||||||
# frame_id=frame_id,
|
frame_id=frame_id,
|
||||||
# ids2names=ids2names)
|
ids2names=ids2names)
|
||||||
if seq_name is None:
|
if seq_name is None:
|
||||||
seq_name = image_list[0].split('/')[-2]
|
seq_name = image_list[0].split('/')[-2]
|
||||||
save_dir = os.path.join(self.output_dir, seq_name)
|
save_dir = os.path.join(self.output_dir, seq_name)
|
||||||
if not os.path.exists(save_dir):
|
if not os.path.exists(save_dir):
|
||||||
os.makedirs(save_dir)
|
os.makedirs(save_dir)
|
||||||
# cv2.imwrite(
|
cv2.imwrite(
|
||||||
# os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
|
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
|
||||||
|
|
||||||
mot_results.append([online_tlwhs, online_scores, online_ids])
|
mot_results.append([online_tlwhs, online_scores, online_ids])
|
||||||
return mot_results
|
return mot_results
|
||||||
|
|
@ -422,26 +427,26 @@ class CenterTrack(Detector):
|
||||||
online_tlwhs, online_scores, online_ids = mot_results[0]
|
online_tlwhs, online_scores, online_ids = mot_results[0]
|
||||||
results[0].append(
|
results[0].append(
|
||||||
(frame_id + 1, online_tlwhs, online_scores, online_ids))
|
(frame_id + 1, online_tlwhs, online_scores, online_ids))
|
||||||
# im = plot_tracking(
|
im = plot_tracking(
|
||||||
# frame,
|
frame,
|
||||||
# online_tlwhs,
|
online_tlwhs,
|
||||||
# online_ids,
|
online_ids,
|
||||||
# online_scores,
|
online_scores,
|
||||||
# frame_id=frame_id,
|
frame_id=frame_id,
|
||||||
# fps=fps,
|
fps=fps,
|
||||||
# ids2names=ids2names)
|
ids2names=ids2names)
|
||||||
#
|
|
||||||
# writer.write(im)
|
writer.write(im)
|
||||||
# if camera_id != -1:
|
if camera_id != -1:
|
||||||
# cv2.imshow('Mask Detection', im)
|
cv2.imshow('Mask Detection', im)
|
||||||
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
# break
|
break
|
||||||
|
|
||||||
if self.save_mot_txts:
|
if self.save_mot_txts:
|
||||||
result_filename = os.path.join(
|
result_filename = os.path.join(
|
||||||
self.output_dir, video_out_name.split('.')[-2] + '.txt')
|
self.output_dir, video_out_name.split('.')[-2] + '.txt')
|
||||||
|
|
||||||
#write_mot_results(result_filename, results, data_type, num_classes)
|
write_mot_results(result_filename, results, data_type, num_classes)
|
||||||
|
|
||||||
writer.release()
|
writer.release()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
import yaml
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
@ -20,7 +22,6 @@ import paddle
|
||||||
|
|
||||||
from benchmark_utils import PaddleInferBenchmark
|
from benchmark_utils import PaddleInferBenchmark
|
||||||
from preprocess import decode_image
|
from preprocess import decode_image
|
||||||
from rag.ppdet import MOTTimer
|
|
||||||
from utils import argsparser, Timer, get_current_memory_mb
|
from utils import argsparser, Timer, get_current_memory_mb
|
||||||
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
|
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
|
||||||
|
|
||||||
|
|
@ -29,6 +30,9 @@ import sys
|
||||||
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
|
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
|
||||||
sys.path.insert(0, parent_path)
|
sys.path.insert(0, parent_path)
|
||||||
|
|
||||||
|
from pptracking.python.mot import JDETracker
|
||||||
|
from pptracking.python.mot.utils import MOTTimer, write_mot_results
|
||||||
|
from pptracking.python.mot.visualize import plot_tracking_dict
|
||||||
|
|
||||||
# Global dictionary
|
# Global dictionary
|
||||||
MOT_JDE_SUPPORT_MODELS = {
|
MOT_JDE_SUPPORT_MODELS = {
|
||||||
|
|
@ -102,13 +106,13 @@ class JDE_Detector(Detector):
|
||||||
tracked_thresh = cfg.get('tracked_thresh', 0.7)
|
tracked_thresh = cfg.get('tracked_thresh', 0.7)
|
||||||
metric_type = cfg.get('metric_type', 'euclidean')
|
metric_type = cfg.get('metric_type', 'euclidean')
|
||||||
|
|
||||||
# self.tracker = JDETracker(
|
self.tracker = JDETracker(
|
||||||
# num_classes=self.num_classes,
|
num_classes=self.num_classes,
|
||||||
# min_box_area=min_box_area,
|
min_box_area=min_box_area,
|
||||||
# vertical_ratio=vertical_ratio,
|
vertical_ratio=vertical_ratio,
|
||||||
# conf_thres=conf_thres,
|
conf_thres=conf_thres,
|
||||||
# tracked_thresh=tracked_thresh,
|
tracked_thresh=tracked_thresh,
|
||||||
# metric_type=metric_type)
|
metric_type=metric_type)
|
||||||
|
|
||||||
def postprocess(self, inputs, result):
|
def postprocess(self, inputs, result):
|
||||||
# postprocess output of predictor
|
# postprocess output of predictor
|
||||||
|
|
@ -235,21 +239,21 @@ class JDE_Detector(Detector):
|
||||||
print('Tracking frame {}'.format(frame_id))
|
print('Tracking frame {}'.format(frame_id))
|
||||||
frame, _ = decode_image(img_file, {})
|
frame, _ = decode_image(img_file, {})
|
||||||
|
|
||||||
# im = plot_tracking_dict(
|
im = plot_tracking_dict(
|
||||||
# frame,
|
frame,
|
||||||
# num_classes,
|
num_classes,
|
||||||
# online_tlwhs,
|
online_tlwhs,
|
||||||
# online_ids,
|
online_ids,
|
||||||
# online_scores,
|
online_scores,
|
||||||
# frame_id=frame_id,
|
frame_id=frame_id,
|
||||||
# ids2names=ids2names)
|
ids2names=ids2names)
|
||||||
if seq_name is None:
|
if seq_name is None:
|
||||||
seq_name = image_list[0].split('/')[-2]
|
seq_name = image_list[0].split('/')[-2]
|
||||||
save_dir = os.path.join(self.output_dir, seq_name)
|
save_dir = os.path.join(self.output_dir, seq_name)
|
||||||
if not os.path.exists(save_dir):
|
if not os.path.exists(save_dir):
|
||||||
os.makedirs(save_dir)
|
os.makedirs(save_dir)
|
||||||
# cv2.imwrite(
|
cv2.imwrite(
|
||||||
# os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
|
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
|
||||||
|
|
||||||
mot_results.append([online_tlwhs, online_scores, online_ids])
|
mot_results.append([online_tlwhs, online_scores, online_ids])
|
||||||
return mot_results
|
return mot_results
|
||||||
|
|
@ -301,28 +305,28 @@ class JDE_Detector(Detector):
|
||||||
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
|
(frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
|
||||||
online_ids[cls_id]))
|
online_ids[cls_id]))
|
||||||
|
|
||||||
#fps = 1. / timer.duration
|
fps = 1. / timer.duration
|
||||||
# im = plot_tracking_dict(
|
im = plot_tracking_dict(
|
||||||
# frame,
|
frame,
|
||||||
# num_classes,
|
num_classes,
|
||||||
# online_tlwhs,
|
online_tlwhs,
|
||||||
# online_ids,
|
online_ids,
|
||||||
# online_scores,
|
online_scores,
|
||||||
# frame_id=frame_id,
|
frame_id=frame_id,
|
||||||
# fps=fps,
|
fps=fps,
|
||||||
# ids2names=ids2names)
|
ids2names=ids2names)
|
||||||
#
|
|
||||||
# writer.write(im)
|
writer.write(im)
|
||||||
# if camera_id != -1:
|
if camera_id != -1:
|
||||||
# cv2.imshow('Mask Detection', im)
|
cv2.imshow('Mask Detection', im)
|
||||||
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
# break
|
break
|
||||||
|
|
||||||
if self.save_mot_txts:
|
if self.save_mot_txts:
|
||||||
result_filename = os.path.join(
|
result_filename = os.path.join(
|
||||||
self.output_dir, video_out_name.split('.')[-2] + '.txt')
|
self.output_dir, video_out_name.split('.')[-2] + '.txt')
|
||||||
|
|
||||||
#write_mot_results(result_filename, results, data_type, num_classes)
|
write_mot_results(result_filename, results, data_type, num_classes)
|
||||||
|
|
||||||
writer.release()
|
writer.release()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,26 +11,36 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import collections
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
import yaml
|
import yaml
|
||||||
import copy
|
import copy
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
from mot_keypoint_unite_utils import argsparser
|
from mot_keypoint_unite_utils import argsparser
|
||||||
from preprocess import decode_image
|
from preprocess import decode_image
|
||||||
from infer import print_arguments, get_test_images, bench_log
|
from infer import print_arguments, get_test_images, bench_log
|
||||||
from mot_jde_infer import MOT_JDE_SUPPORT_MODELS
|
from mot_sde_infer import SDE_Detector
|
||||||
|
from mot_jde_infer import JDE_Detector, MOT_JDE_SUPPORT_MODELS
|
||||||
from keypoint_infer import KeyPointDetector, KEYPOINT_SUPPORT_MODELS
|
from keypoint_infer import KeyPointDetector, KEYPOINT_SUPPORT_MODELS
|
||||||
from det_keypoint_unite_infer import predict_with_given_det
|
from det_keypoint_unite_infer import predict_with_given_det
|
||||||
from rag.ppdet import MOTTimer
|
|
||||||
from visualize import visualize_pose
|
from visualize import visualize_pose
|
||||||
|
from benchmark_utils import PaddleInferBenchmark
|
||||||
from utils import get_current_memory_mb
|
from utils import get_current_memory_mb
|
||||||
|
from keypoint_postprocess import translate_to_ori_images
|
||||||
|
|
||||||
|
# add python path
|
||||||
|
import sys
|
||||||
|
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
|
||||||
|
sys.path.insert(0, parent_path)
|
||||||
|
|
||||||
|
from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
|
||||||
|
from pptracking.python.mot.utils import MOTTimer as FPSTimer
|
||||||
|
|
||||||
|
|
||||||
def convert_mot_to_det(tlwhs, scores):
|
def convert_mot_to_det(tlwhs, scores):
|
||||||
|
|
@ -140,7 +150,7 @@ def mot_topdown_unite_predict_video(mot_detector,
|
||||||
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
|
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
|
||||||
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
|
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
|
||||||
frame_id = 0
|
frame_id = 0
|
||||||
timer_mot, timer_kp, timer_mot_kp = MOTTimer(), MOTTimer(), MOTTimer()
|
timer_mot, timer_kp, timer_mot_kp = FPSTimer(), FPSTimer(), FPSTimer()
|
||||||
|
|
||||||
num_classes = mot_detector.num_classes
|
num_classes = mot_detector.num_classes
|
||||||
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.'
|
assert num_classes == 1, 'Only one category mot model supported for uniting keypoint deploy.'
|
||||||
|
|
@ -187,6 +197,15 @@ def mot_topdown_unite_predict_video(mot_detector,
|
||||||
returnimg=True,
|
returnimg=True,
|
||||||
ids=online_ids[0])
|
ids=online_ids[0])
|
||||||
|
|
||||||
|
im = plot_tracking_dict(
|
||||||
|
im,
|
||||||
|
num_classes,
|
||||||
|
online_tlwhs,
|
||||||
|
online_ids,
|
||||||
|
online_scores,
|
||||||
|
frame_id=frame_id,
|
||||||
|
fps=mot_kp_fps)
|
||||||
|
|
||||||
writer.write(im)
|
writer.write(im)
|
||||||
if camera_id != -1:
|
if camera_id != -1:
|
||||||
cv2.imshow('Tracking and keypoint results', im)
|
cv2.imshow('Tracking and keypoint results', im)
|
||||||
|
|
|
||||||
522
rag/ppdet/mot_sde_infer.py
Normal file
522
rag/ppdet/mot_sde_infer.py
Normal file
|
|
@ -0,0 +1,522 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import yaml
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from collections import defaultdict
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from benchmark_utils import PaddleInferBenchmark
|
||||||
|
from preprocess import decode_image
|
||||||
|
from utils import argsparser, Timer, get_current_memory_mb
|
||||||
|
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
|
||||||
|
|
||||||
|
# add python path
|
||||||
|
import sys
|
||||||
|
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
|
||||||
|
sys.path.insert(0, parent_path)
|
||||||
|
|
||||||
|
from pptracking.python.mot import JDETracker, DeepSORTTracker
|
||||||
|
from pptracking.python.mot.utils import MOTTimer, write_mot_results, get_crops, clip_box
|
||||||
|
from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
|
||||||
|
|
||||||
|
|
||||||
|
class SDE_Detector(Detector):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
|
||||||
|
tracker_config (str): tracker config path
|
||||||
|
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU
|
||||||
|
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
|
||||||
|
batch_size (int): size of pre batch in inference
|
||||||
|
trt_min_shape (int): min shape for dynamic shape in trt
|
||||||
|
trt_max_shape (int): max shape for dynamic shape in trt
|
||||||
|
trt_opt_shape (int): opt shape for dynamic shape in trt
|
||||||
|
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
|
||||||
|
calibration, trt_calib_mode need to set True
|
||||||
|
cpu_threads (int): cpu threads
|
||||||
|
enable_mkldnn (bool): whether to open MKLDNN
|
||||||
|
output_dir (string): The path of output, default as 'output'
|
||||||
|
threshold (float): Score threshold of the detected bbox, default as 0.5
|
||||||
|
save_images (bool): Whether to save visualization image results, default as False
|
||||||
|
save_mot_txts (bool): Whether to save tracking results (txt), default as False
|
||||||
|
reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_dir,
|
||||||
|
tracker_config,
|
||||||
|
device='CPU',
|
||||||
|
run_mode='paddle',
|
||||||
|
batch_size=1,
|
||||||
|
trt_min_shape=1,
|
||||||
|
trt_max_shape=1280,
|
||||||
|
trt_opt_shape=640,
|
||||||
|
trt_calib_mode=False,
|
||||||
|
cpu_threads=1,
|
||||||
|
enable_mkldnn=False,
|
||||||
|
output_dir='output',
|
||||||
|
threshold=0.5,
|
||||||
|
save_images=False,
|
||||||
|
save_mot_txts=False,
|
||||||
|
reid_model_dir=None):
|
||||||
|
super(SDE_Detector, self).__init__(
|
||||||
|
model_dir=model_dir,
|
||||||
|
device=device,
|
||||||
|
run_mode=run_mode,
|
||||||
|
batch_size=batch_size,
|
||||||
|
trt_min_shape=trt_min_shape,
|
||||||
|
trt_max_shape=trt_max_shape,
|
||||||
|
trt_opt_shape=trt_opt_shape,
|
||||||
|
trt_calib_mode=trt_calib_mode,
|
||||||
|
cpu_threads=cpu_threads,
|
||||||
|
enable_mkldnn=enable_mkldnn,
|
||||||
|
output_dir=output_dir,
|
||||||
|
threshold=threshold, )
|
||||||
|
self.save_images = save_images
|
||||||
|
self.save_mot_txts = save_mot_txts
|
||||||
|
assert batch_size == 1, "MOT model only supports batch_size=1."
|
||||||
|
self.det_times = Timer(with_tracker=True)
|
||||||
|
self.num_classes = len(self.pred_config.labels)
|
||||||
|
|
||||||
|
# reid config
|
||||||
|
self.use_reid = False if reid_model_dir is None else True
|
||||||
|
if self.use_reid:
|
||||||
|
self.reid_pred_config = self.set_config(reid_model_dir)
|
||||||
|
self.reid_predictor, self.config = load_predictor(
|
||||||
|
reid_model_dir,
|
||||||
|
run_mode=run_mode,
|
||||||
|
batch_size=50, # reid_batch_size
|
||||||
|
min_subgraph_size=self.reid_pred_config.min_subgraph_size,
|
||||||
|
device=device,
|
||||||
|
use_dynamic_shape=self.reid_pred_config.use_dynamic_shape,
|
||||||
|
trt_min_shape=trt_min_shape,
|
||||||
|
trt_max_shape=trt_max_shape,
|
||||||
|
trt_opt_shape=trt_opt_shape,
|
||||||
|
trt_calib_mode=trt_calib_mode,
|
||||||
|
cpu_threads=cpu_threads,
|
||||||
|
enable_mkldnn=enable_mkldnn)
|
||||||
|
else:
|
||||||
|
self.reid_pred_config = None
|
||||||
|
self.reid_predictor = None
|
||||||
|
|
||||||
|
assert tracker_config is not None, 'Note that tracker_config should be set.'
|
||||||
|
self.tracker_config = tracker_config
|
||||||
|
tracker_cfg = yaml.safe_load(open(self.tracker_config))
|
||||||
|
cfg = tracker_cfg[tracker_cfg['type']]
|
||||||
|
|
||||||
|
# tracker config
|
||||||
|
self.use_deepsort_tracker = True if tracker_cfg[
|
||||||
|
'type'] == 'DeepSORTTracker' else False
|
||||||
|
if self.use_deepsort_tracker:
|
||||||
|
# use DeepSORTTracker
|
||||||
|
if self.reid_pred_config is not None and hasattr(
|
||||||
|
self.reid_pred_config, 'tracker'):
|
||||||
|
cfg = self.reid_pred_config.tracker
|
||||||
|
budget = cfg.get('budget', 100)
|
||||||
|
max_age = cfg.get('max_age', 30)
|
||||||
|
max_iou_distance = cfg.get('max_iou_distance', 0.7)
|
||||||
|
matching_threshold = cfg.get('matching_threshold', 0.2)
|
||||||
|
min_box_area = cfg.get('min_box_area', 0)
|
||||||
|
vertical_ratio = cfg.get('vertical_ratio', 0)
|
||||||
|
|
||||||
|
self.tracker = DeepSORTTracker(
|
||||||
|
budget=budget,
|
||||||
|
max_age=max_age,
|
||||||
|
max_iou_distance=max_iou_distance,
|
||||||
|
matching_threshold=matching_threshold,
|
||||||
|
min_box_area=min_box_area,
|
||||||
|
vertical_ratio=vertical_ratio, )
|
||||||
|
else:
|
||||||
|
# use ByteTracker
|
||||||
|
use_byte = cfg.get('use_byte', False)
|
||||||
|
det_thresh = cfg.get('det_thresh', 0.3)
|
||||||
|
min_box_area = cfg.get('min_box_area', 0)
|
||||||
|
vertical_ratio = cfg.get('vertical_ratio', 0)
|
||||||
|
match_thres = cfg.get('match_thres', 0.9)
|
||||||
|
conf_thres = cfg.get('conf_thres', 0.6)
|
||||||
|
low_conf_thres = cfg.get('low_conf_thres', 0.1)
|
||||||
|
|
||||||
|
self.tracker = JDETracker(
|
||||||
|
use_byte=use_byte,
|
||||||
|
det_thresh=det_thresh,
|
||||||
|
num_classes=self.num_classes,
|
||||||
|
min_box_area=min_box_area,
|
||||||
|
vertical_ratio=vertical_ratio,
|
||||||
|
match_thres=match_thres,
|
||||||
|
conf_thres=conf_thres,
|
||||||
|
low_conf_thres=low_conf_thres, )
|
||||||
|
|
||||||
|
def postprocess(self, inputs, result):
|
||||||
|
# postprocess output of predictor
|
||||||
|
np_boxes_num = result['boxes_num']
|
||||||
|
if np_boxes_num[0] <= 0:
|
||||||
|
print('[WARNNING] No object detected.')
|
||||||
|
result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
|
||||||
|
result = {k: v for k, v in result.items() if v is not None}
|
||||||
|
return result
|
||||||
|
|
||||||
|
def reidprocess(self, det_results, repeats=1):
|
||||||
|
pred_dets = det_results['boxes']
|
||||||
|
pred_xyxys = pred_dets[:, 2:6]
|
||||||
|
|
||||||
|
ori_image = det_results['ori_image']
|
||||||
|
ori_image_shape = ori_image.shape[:2]
|
||||||
|
pred_xyxys, keep_idx = clip_box(pred_xyxys, ori_image_shape)
|
||||||
|
|
||||||
|
if len(keep_idx[0]) == 0:
|
||||||
|
det_results['boxes'] = np.zeros((1, 6), dtype=np.float32)
|
||||||
|
det_results['embeddings'] = None
|
||||||
|
return det_results
|
||||||
|
|
||||||
|
pred_dets = pred_dets[keep_idx[0]]
|
||||||
|
pred_xyxys = pred_dets[:, 2:6]
|
||||||
|
|
||||||
|
w, h = self.tracker.input_size
|
||||||
|
crops = get_crops(pred_xyxys, ori_image, w, h)
|
||||||
|
|
||||||
|
# to keep fast speed, only use topk crops
|
||||||
|
crops = crops[:50] # reid_batch_size
|
||||||
|
det_results['crops'] = np.array(crops).astype('float32')
|
||||||
|
det_results['boxes'] = pred_dets[:50]
|
||||||
|
|
||||||
|
input_names = self.reid_predictor.get_input_names()
|
||||||
|
for i in range(len(input_names)):
|
||||||
|
input_tensor = self.reid_predictor.get_input_handle(input_names[i])
|
||||||
|
input_tensor.copy_from_cpu(det_results[input_names[i]])
|
||||||
|
|
||||||
|
# model prediction
|
||||||
|
for i in range(repeats):
|
||||||
|
self.reid_predictor.run()
|
||||||
|
output_names = self.reid_predictor.get_output_names()
|
||||||
|
feature_tensor = self.reid_predictor.get_output_handle(output_names[
|
||||||
|
0])
|
||||||
|
pred_embs = feature_tensor.copy_to_cpu()
|
||||||
|
|
||||||
|
det_results['embeddings'] = pred_embs
|
||||||
|
return det_results
|
||||||
|
|
||||||
|
def tracking(self, det_results):
|
||||||
|
pred_dets = det_results['boxes'] # 'cls_id, score, x0, y0, x1, y1'
|
||||||
|
pred_embs = det_results.get('embeddings', None)
|
||||||
|
|
||||||
|
if self.use_deepsort_tracker:
|
||||||
|
# use DeepSORTTracker, only support singe class
|
||||||
|
self.tracker.predict()
|
||||||
|
online_targets = self.tracker.update(pred_dets, pred_embs)
|
||||||
|
online_tlwhs, online_scores, online_ids = [], [], []
|
||||||
|
for t in online_targets:
|
||||||
|
if not t.is_confirmed() or t.time_since_update > 1:
|
||||||
|
continue
|
||||||
|
tlwh = t.to_tlwh()
|
||||||
|
tscore = t.score
|
||||||
|
tid = t.track_id
|
||||||
|
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
|
||||||
|
3] > self.tracker.vertical_ratio:
|
||||||
|
continue
|
||||||
|
online_tlwhs.append(tlwh)
|
||||||
|
online_scores.append(tscore)
|
||||||
|
online_ids.append(tid)
|
||||||
|
|
||||||
|
tracking_outs = {
|
||||||
|
'online_tlwhs': online_tlwhs,
|
||||||
|
'online_scores': online_scores,
|
||||||
|
'online_ids': online_ids,
|
||||||
|
}
|
||||||
|
return tracking_outs
|
||||||
|
else:
|
||||||
|
# use ByteTracker, support multiple class
|
||||||
|
online_tlwhs = defaultdict(list)
|
||||||
|
online_scores = defaultdict(list)
|
||||||
|
online_ids = defaultdict(list)
|
||||||
|
online_targets_dict = self.tracker.update(pred_dets, pred_embs)
|
||||||
|
for cls_id in range(self.num_classes):
|
||||||
|
online_targets = online_targets_dict[cls_id]
|
||||||
|
for t in online_targets:
|
||||||
|
tlwh = t.tlwh
|
||||||
|
tid = t.track_id
|
||||||
|
tscore = t.score
|
||||||
|
if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
|
||||||
|
continue
|
||||||
|
if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
|
||||||
|
3] > self.tracker.vertical_ratio:
|
||||||
|
continue
|
||||||
|
online_tlwhs[cls_id].append(tlwh)
|
||||||
|
online_ids[cls_id].append(tid)
|
||||||
|
online_scores[cls_id].append(tscore)
|
||||||
|
|
||||||
|
tracking_outs = {
|
||||||
|
'online_tlwhs': online_tlwhs,
|
||||||
|
'online_scores': online_scores,
|
||||||
|
'online_ids': online_ids,
|
||||||
|
}
|
||||||
|
return tracking_outs
|
||||||
|
|
||||||
|
def predict_image(self,
|
||||||
|
image_list,
|
||||||
|
run_benchmark=False,
|
||||||
|
repeats=1,
|
||||||
|
visual=True,
|
||||||
|
seq_name=None):
|
||||||
|
num_classes = self.num_classes
|
||||||
|
image_list.sort()
|
||||||
|
ids2names = self.pred_config.labels
|
||||||
|
mot_results = []
|
||||||
|
for frame_id, img_file in enumerate(image_list):
|
||||||
|
batch_image_list = [img_file] # bs=1 in MOT model
|
||||||
|
frame, _ = decode_image(img_file, {})
|
||||||
|
if run_benchmark:
|
||||||
|
# preprocess
|
||||||
|
inputs = self.preprocess(batch_image_list) # warmup
|
||||||
|
self.det_times.preprocess_time_s.start()
|
||||||
|
inputs = self.preprocess(batch_image_list)
|
||||||
|
self.det_times.preprocess_time_s.end()
|
||||||
|
|
||||||
|
# model prediction
|
||||||
|
result_warmup = self.predict(repeats=repeats) # warmup
|
||||||
|
self.det_times.inference_time_s.start()
|
||||||
|
result = self.predict(repeats=repeats)
|
||||||
|
self.det_times.inference_time_s.end(repeats=repeats)
|
||||||
|
|
||||||
|
# postprocess
|
||||||
|
result_warmup = self.postprocess(inputs, result) # warmup
|
||||||
|
self.det_times.postprocess_time_s.start()
|
||||||
|
det_result = self.postprocess(inputs, result)
|
||||||
|
self.det_times.postprocess_time_s.end()
|
||||||
|
|
||||||
|
# tracking
|
||||||
|
if self.use_reid:
|
||||||
|
det_result['frame_id'] = frame_id
|
||||||
|
det_result['seq_name'] = seq_name
|
||||||
|
det_result['ori_image'] = frame
|
||||||
|
det_result = self.reidprocess(det_result)
|
||||||
|
result_warmup = self.tracking(det_result)
|
||||||
|
self.det_times.tracking_time_s.start()
|
||||||
|
if self.use_reid:
|
||||||
|
det_result = self.reidprocess(det_result)
|
||||||
|
tracking_outs = self.tracking(det_result)
|
||||||
|
self.det_times.tracking_time_s.end()
|
||||||
|
self.det_times.img_num += 1
|
||||||
|
|
||||||
|
cm, gm, gu = get_current_memory_mb()
|
||||||
|
self.cpu_mem += cm
|
||||||
|
self.gpu_mem += gm
|
||||||
|
self.gpu_util += gu
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.det_times.preprocess_time_s.start()
|
||||||
|
inputs = self.preprocess(batch_image_list)
|
||||||
|
self.det_times.preprocess_time_s.end()
|
||||||
|
|
||||||
|
self.det_times.inference_time_s.start()
|
||||||
|
result = self.predict()
|
||||||
|
self.det_times.inference_time_s.end()
|
||||||
|
|
||||||
|
self.det_times.postprocess_time_s.start()
|
||||||
|
det_result = self.postprocess(inputs, result)
|
||||||
|
self.det_times.postprocess_time_s.end()
|
||||||
|
|
||||||
|
# tracking process
|
||||||
|
self.det_times.tracking_time_s.start()
|
||||||
|
if self.use_reid:
|
||||||
|
det_result['frame_id'] = frame_id
|
||||||
|
det_result['seq_name'] = seq_name
|
||||||
|
det_result['ori_image'] = frame
|
||||||
|
det_result = self.reidprocess(det_result)
|
||||||
|
tracking_outs = self.tracking(det_result)
|
||||||
|
self.det_times.tracking_time_s.end()
|
||||||
|
self.det_times.img_num += 1
|
||||||
|
|
||||||
|
online_tlwhs = tracking_outs['online_tlwhs']
|
||||||
|
online_scores = tracking_outs['online_scores']
|
||||||
|
online_ids = tracking_outs['online_ids']
|
||||||
|
|
||||||
|
mot_results.append([online_tlwhs, online_scores, online_ids])
|
||||||
|
|
||||||
|
if visual:
|
||||||
|
if len(image_list) > 1 and frame_id % 10 == 0:
|
||||||
|
print('Tracking frame {}'.format(frame_id))
|
||||||
|
frame, _ = decode_image(img_file, {})
|
||||||
|
if isinstance(online_tlwhs, defaultdict):
|
||||||
|
im = plot_tracking_dict(
|
||||||
|
frame,
|
||||||
|
num_classes,
|
||||||
|
online_tlwhs,
|
||||||
|
online_ids,
|
||||||
|
online_scores,
|
||||||
|
frame_id=frame_id,
|
||||||
|
ids2names=ids2names)
|
||||||
|
else:
|
||||||
|
im = plot_tracking(
|
||||||
|
frame,
|
||||||
|
online_tlwhs,
|
||||||
|
online_ids,
|
||||||
|
online_scores,
|
||||||
|
frame_id=frame_id,
|
||||||
|
ids2names=ids2names)
|
||||||
|
save_dir = os.path.join(self.output_dir, seq_name)
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
os.makedirs(save_dir)
|
||||||
|
cv2.imwrite(
|
||||||
|
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
|
||||||
|
|
||||||
|
return mot_results
|
||||||
|
|
||||||
|
def predict_video(self, video_file, camera_id):
|
||||||
|
video_out_name = 'output.mp4'
|
||||||
|
if camera_id != -1:
|
||||||
|
capture = cv2.VideoCapture(camera_id)
|
||||||
|
else:
|
||||||
|
capture = cv2.VideoCapture(video_file)
|
||||||
|
video_out_name = os.path.split(video_file)[-1]
|
||||||
|
# Get Video info : resolution, fps, frame count
|
||||||
|
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
|
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
|
fps = int(capture.get(cv2.CAP_PROP_FPS))
|
||||||
|
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
print("fps: %d, frame_count: %d" % (fps, frame_count))
|
||||||
|
|
||||||
|
if not os.path.exists(self.output_dir):
|
||||||
|
os.makedirs(self.output_dir)
|
||||||
|
out_path = os.path.join(self.output_dir, video_out_name)
|
||||||
|
video_format = 'mp4v'
|
||||||
|
fourcc = cv2.VideoWriter_fourcc(*video_format)
|
||||||
|
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
|
||||||
|
|
||||||
|
frame_id = 1
|
||||||
|
timer = MOTTimer()
|
||||||
|
results = defaultdict(list)
|
||||||
|
num_classes = self.num_classes
|
||||||
|
data_type = 'mcmot' if num_classes > 1 else 'mot'
|
||||||
|
ids2names = self.pred_config.labels
|
||||||
|
|
||||||
|
while (1):
|
||||||
|
ret, frame = capture.read()
|
||||||
|
if not ret:
|
||||||
|
break
|
||||||
|
if frame_id % 10 == 0:
|
||||||
|
print('Tracking frame: %d' % (frame_id))
|
||||||
|
frame_id += 1
|
||||||
|
|
||||||
|
timer.tic()
|
||||||
|
seq_name = video_out_name.split('.')[0]
|
||||||
|
mot_results = self.predict_image(
|
||||||
|
[frame[:, :, ::-1]], visual=False, seq_name=seq_name)
|
||||||
|
timer.toc()
|
||||||
|
|
||||||
|
# bs=1 in MOT model
|
||||||
|
online_tlwhs, online_scores, online_ids = mot_results[0]
|
||||||
|
|
||||||
|
fps = 1. / timer.duration
|
||||||
|
if self.use_deepsort_tracker:
|
||||||
|
# use DeepSORTTracker, only support singe class
|
||||||
|
results[0].append(
|
||||||
|
(frame_id + 1, online_tlwhs, online_scores, online_ids))
|
||||||
|
im = plot_tracking(
|
||||||
|
frame,
|
||||||
|
online_tlwhs,
|
||||||
|
online_ids,
|
||||||
|
online_scores,
|
||||||
|
frame_id=frame_id,
|
||||||
|
fps=fps,
|
||||||
|
ids2names=ids2names)
|
||||||
|
else:
|
||||||
|
# use ByteTracker, support multiple class
|
||||||
|
for cls_id in range(num_classes):
|
||||||
|
results[cls_id].append(
|
||||||
|
(frame_id + 1, online_tlwhs[cls_id],
|
||||||
|
online_scores[cls_id], online_ids[cls_id]))
|
||||||
|
im = plot_tracking_dict(
|
||||||
|
frame,
|
||||||
|
num_classes,
|
||||||
|
online_tlwhs,
|
||||||
|
online_ids,
|
||||||
|
online_scores,
|
||||||
|
frame_id=frame_id,
|
||||||
|
fps=fps,
|
||||||
|
ids2names=ids2names)
|
||||||
|
|
||||||
|
writer.write(im)
|
||||||
|
if camera_id != -1:
|
||||||
|
cv2.imshow('Mask Detection', im)
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.save_mot_txts:
|
||||||
|
result_filename = os.path.join(
|
||||||
|
self.output_dir, video_out_name.split('.')[-2] + '.txt')
|
||||||
|
write_mot_results(result_filename, results)
|
||||||
|
|
||||||
|
writer.release()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
|
||||||
|
with open(deploy_file) as f:
|
||||||
|
yml_conf = yaml.safe_load(f)
|
||||||
|
arch = yml_conf['arch']
|
||||||
|
detector = SDE_Detector(
|
||||||
|
FLAGS.model_dir,
|
||||||
|
tracker_config=FLAGS.tracker_config,
|
||||||
|
device=FLAGS.device,
|
||||||
|
run_mode=FLAGS.run_mode,
|
||||||
|
batch_size=1,
|
||||||
|
trt_min_shape=FLAGS.trt_min_shape,
|
||||||
|
trt_max_shape=FLAGS.trt_max_shape,
|
||||||
|
trt_opt_shape=FLAGS.trt_opt_shape,
|
||||||
|
trt_calib_mode=FLAGS.trt_calib_mode,
|
||||||
|
cpu_threads=FLAGS.cpu_threads,
|
||||||
|
enable_mkldnn=FLAGS.enable_mkldnn,
|
||||||
|
output_dir=FLAGS.output_dir,
|
||||||
|
threshold=FLAGS.threshold,
|
||||||
|
save_images=FLAGS.save_images,
|
||||||
|
save_mot_txts=FLAGS.save_mot_txts, )
|
||||||
|
|
||||||
|
# predict from video file or camera video stream
|
||||||
|
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
|
||||||
|
detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
|
||||||
|
else:
|
||||||
|
# predict from image
|
||||||
|
if FLAGS.image_dir is None and FLAGS.image_file is not None:
|
||||||
|
assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
|
||||||
|
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
|
||||||
|
seq_name = FLAGS.image_dir.split('/')[-1]
|
||||||
|
detector.predict_image(
|
||||||
|
img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)
|
||||||
|
|
||||||
|
if not FLAGS.run_benchmark:
|
||||||
|
detector.det_times.info(average=True)
|
||||||
|
else:
|
||||||
|
mode = FLAGS.run_mode
|
||||||
|
model_dir = FLAGS.model_dir
|
||||||
|
model_info = {
|
||||||
|
'model_name': model_dir.strip('/').split('/')[-1],
|
||||||
|
'precision': mode.split('_')[-1]
|
||||||
|
}
|
||||||
|
bench_log(detector, img_list, model_info, name='MOT')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
paddle.enable_static()
|
||||||
|
parser = argsparser()
|
||||||
|
FLAGS = parser.parse_args()
|
||||||
|
print_arguments(FLAGS)
|
||||||
|
FLAGS.device = FLAGS.device.upper()
|
||||||
|
assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU'
|
||||||
|
], "device should be CPU, GPU, NPU or XPU"
|
||||||
|
|
||||||
|
main()
|
||||||
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import hashlib
|
import hashlib
|
||||||
import copy
|
import copy
|
||||||
|
|
@ -24,9 +25,9 @@ from timeit import default_timer as timer
|
||||||
|
|
||||||
from rag.llm import EmbeddingModel, CvModel
|
from rag.llm import EmbeddingModel, CvModel
|
||||||
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
||||||
from rag.utils import ELASTICSEARCH, num_tokens_from_string
|
from rag.utils import ELASTICSEARCH
|
||||||
from rag.utils import MINIO
|
from rag.utils import MINIO
|
||||||
from rag.utils import rmSpace, findMaxDt
|
from rag.utils import rmSpace, findMaxTm
|
||||||
from rag.nlp import huchunk, huqie, search
|
from rag.nlp import huchunk, huqie, search
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
@ -47,6 +48,7 @@ from rag.nlp.huchunk import (
|
||||||
from web_server.db import LLMType
|
from web_server.db import LLMType
|
||||||
from web_server.db.services.document_service import DocumentService
|
from web_server.db.services.document_service import DocumentService
|
||||||
from web_server.db.services.llm_service import TenantLLMService
|
from web_server.db.services.llm_service import TenantLLMService
|
||||||
|
from web_server.settings import database_logger
|
||||||
from web_server.utils import get_format_time
|
from web_server.utils import get_format_time
|
||||||
from web_server.utils.file_utils import get_project_base_directory
|
from web_server.utils.file_utils import get_project_base_directory
|
||||||
|
|
||||||
|
|
@ -83,7 +85,7 @@ def collect(comm, mod, tm):
|
||||||
if len(docs) == 0:
|
if len(docs) == 0:
|
||||||
return pd.DataFrame()
|
return pd.DataFrame()
|
||||||
docs = pd.DataFrame(docs)
|
docs = pd.DataFrame(docs)
|
||||||
mtm = str(docs["update_time"].max())[:19]
|
mtm = docs["update_time"].max()
|
||||||
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
|
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
|
@ -99,28 +101,30 @@ def set_progress(docid, prog, msg="Processing...", begin=False):
|
||||||
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
|
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
|
||||||
|
|
||||||
|
|
||||||
def build(row):
|
def build(row, cvmdl):
|
||||||
if row["size"] > DOC_MAXIMUM_SIZE:
|
if row["size"] > DOC_MAXIMUM_SIZE:
|
||||||
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
|
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
|
||||||
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
||||||
return []
|
return []
|
||||||
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
|
# If just change the kb for doc
|
||||||
if ELASTICSEARCH.getTotal(res) > 0:
|
# res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]), idxnm=search.index_name(row["tenant_id"]))
|
||||||
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
|
# if ELASTICSEARCH.getTotal(res) > 0:
|
||||||
scripts="""
|
# ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
|
||||||
if(!ctx._source.kb_id.contains('%s'))
|
# scripts="""
|
||||||
ctx._source.kb_id.add('%s');
|
# if(!ctx._source.kb_id.contains('%s'))
|
||||||
""" % (str(row["kb_id"]), str(row["kb_id"])),
|
# ctx._source.kb_id.add('%s');
|
||||||
idxnm=search.index_name(row["tenant_id"])
|
# """ % (str(row["kb_id"]), str(row["kb_id"])),
|
||||||
)
|
# idxnm=search.index_name(row["tenant_id"])
|
||||||
set_progress(row["id"], 1, "Done")
|
# )
|
||||||
return []
|
# set_progress(row["id"], 1, "Done")
|
||||||
|
# return []
|
||||||
|
|
||||||
random.seed(time.time())
|
random.seed(time.time())
|
||||||
set_progress(row["id"], random.randint(0, 20) /
|
set_progress(row["id"], random.randint(0, 20) /
|
||||||
100., "Finished preparing! Start to slice file!", True)
|
100., "Finished preparing! Start to slice file!", True)
|
||||||
try:
|
try:
|
||||||
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]))
|
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
|
||||||
|
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if re.search("(No such file|not found)", str(e)):
|
if re.search("(No such file|not found)", str(e)):
|
||||||
set_progress(
|
set_progress(
|
||||||
|
|
@ -131,6 +135,7 @@ def build(row):
|
||||||
row["id"], -1, f"Internal server error: %s" %
|
row["id"], -1, f"Internal server error: %s" %
|
||||||
str(e).replace(
|
str(e).replace(
|
||||||
"'", ""))
|
"'", ""))
|
||||||
|
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if not obj.text_chunks and not obj.table_chunks:
|
if not obj.text_chunks and not obj.table_chunks:
|
||||||
|
|
@ -144,7 +149,7 @@ def build(row):
|
||||||
"Finished slicing files. Start to embedding the content.")
|
"Finished slicing files. Start to embedding the content.")
|
||||||
|
|
||||||
doc = {
|
doc = {
|
||||||
"doc_id": row["did"],
|
"doc_id": row["id"],
|
||||||
"kb_id": [str(row["kb_id"])],
|
"kb_id": [str(row["kb_id"])],
|
||||||
"docnm_kwd": os.path.split(row["location"])[-1],
|
"docnm_kwd": os.path.split(row["location"])[-1],
|
||||||
"title_tks": huqie.qie(row["name"]),
|
"title_tks": huqie.qie(row["name"]),
|
||||||
|
|
@ -164,10 +169,10 @@ def build(row):
|
||||||
docs.append(d)
|
docs.append(d)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(img, Image):
|
if isinstance(img, bytes):
|
||||||
img.save(output_buffer, format='JPEG')
|
|
||||||
else:
|
|
||||||
output_buffer = BytesIO(img)
|
output_buffer = BytesIO(img)
|
||||||
|
else:
|
||||||
|
img.save(output_buffer, format='JPEG')
|
||||||
|
|
||||||
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
||||||
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
||||||
|
|
@ -215,15 +220,16 @@ def embedding(docs, mdl):
|
||||||
|
|
||||||
|
|
||||||
def model_instance(tenant_id, llm_type):
|
def model_instance(tenant_id, llm_type):
|
||||||
model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING)
|
model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
|
||||||
if not model_config:return
|
if not model_config:
|
||||||
model_config = model_config[0]
|
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
|
||||||
|
else: model_config = model_config[0].to_dict()
|
||||||
if llm_type == LLMType.EMBEDDING:
|
if llm_type == LLMType.EMBEDDING:
|
||||||
if model_config.llm_factory not in EmbeddingModel: return
|
if model_config["llm_factory"] not in EmbeddingModel: return
|
||||||
return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
|
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
|
||||||
if llm_type == LLMType.IMAGE2TEXT:
|
if llm_type == LLMType.IMAGE2TEXT:
|
||||||
if model_config.llm_factory not in CvModel: return
|
if model_config["llm_factory"] not in CvModel: return
|
||||||
return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
|
return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
|
||||||
|
|
||||||
|
|
||||||
def main(comm, mod):
|
def main(comm, mod):
|
||||||
|
|
@ -231,7 +237,7 @@ def main(comm, mod):
|
||||||
from rag.llm import HuEmbedding
|
from rag.llm import HuEmbedding
|
||||||
model = HuEmbedding()
|
model = HuEmbedding()
|
||||||
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
|
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
|
||||||
tm = findMaxDt(tm_fnm)
|
tm = findMaxTm(tm_fnm)
|
||||||
rows = collect(comm, mod, tm)
|
rows = collect(comm, mod, tm)
|
||||||
if len(rows) == 0:
|
if len(rows) == 0:
|
||||||
return
|
return
|
||||||
|
|
@ -247,7 +253,7 @@ def main(comm, mod):
|
||||||
st_tm = timer()
|
st_tm = timer()
|
||||||
cks = build(r, cv_mdl)
|
cks = build(r, cv_mdl)
|
||||||
if not cks:
|
if not cks:
|
||||||
tmf.write(str(r["updated_at"]) + "\n")
|
tmf.write(str(r["update_time"]) + "\n")
|
||||||
continue
|
continue
|
||||||
# TODO: exception handler
|
# TODO: exception handler
|
||||||
## set_progress(r["did"], -1, "ERROR: ")
|
## set_progress(r["did"], -1, "ERROR: ")
|
||||||
|
|
@ -268,12 +274,19 @@ def main(comm, mod):
|
||||||
cron_logger.error(str(es_r))
|
cron_logger.error(str(es_r))
|
||||||
else:
|
else:
|
||||||
set_progress(r["id"], 1., "Done!")
|
set_progress(r["id"], 1., "Done!")
|
||||||
DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm})
|
DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
|
||||||
|
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
|
||||||
|
|
||||||
tmf.write(str(r["update_time"]) + "\n")
|
tmf.write(str(r["update_time"]) + "\n")
|
||||||
tmf.close()
|
tmf.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
peewee_logger = logging.getLogger('peewee')
|
||||||
|
peewee_logger.propagate = False
|
||||||
|
peewee_logger.addHandler(database_logger.handlers[0])
|
||||||
|
peewee_logger.setLevel(database_logger.level)
|
||||||
|
|
||||||
from mpi4py import MPI
|
from mpi4py import MPI
|
||||||
comm = MPI.COMM_WORLD
|
comm = MPI.COMM_WORLD
|
||||||
main(comm.Get_size(), comm.Get_rank())
|
main(comm.Get_size(), comm.Get_rank())
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,24 @@ def findMaxDt(fnm):
|
||||||
print("WARNING: can't find " + fnm)
|
print("WARNING: can't find " + fnm)
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def findMaxTm(fnm):
|
||||||
|
m = 0
|
||||||
|
try:
|
||||||
|
with open(fnm, "r") as f:
|
||||||
|
while True:
|
||||||
|
l = f.readline()
|
||||||
|
if not l:
|
||||||
|
break
|
||||||
|
l = l.strip("\n")
|
||||||
|
if l == 'nan':
|
||||||
|
continue
|
||||||
|
if int(l) > m:
|
||||||
|
m = int(l)
|
||||||
|
except Exception as e:
|
||||||
|
print("WARNING: can't find " + fnm)
|
||||||
|
return m
|
||||||
|
|
||||||
def num_tokens_from_string(string: str) -> int:
|
def num_tokens_from_string(string: str) -> int:
|
||||||
"""Returns the number of tokens in a text string."""
|
"""Returns the number of tokens in a text string."""
|
||||||
encoding = tiktoken.get_encoding('cl100k_base')
|
encoding = tiktoken.get_encoding('cl100k_base')
|
||||||
|
|
|
||||||
|
|
@ -294,6 +294,7 @@ class HuEs:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
es_logger.error("ES updateByQuery deleteByQuery: " +
|
es_logger.error("ES updateByQuery deleteByQuery: " +
|
||||||
str(e) + "【Q】:" + str(query.to_dict()))
|
str(e) + "【Q】:" + str(query.to_dict()))
|
||||||
|
if str(e).find("NotFoundError") > 0: return True
|
||||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import base64
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from elasticsearch_dsl import Q
|
from elasticsearch_dsl import Q
|
||||||
|
|
@ -195,11 +196,15 @@ def rm():
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(retmsg="Document not found!")
|
return get_data_error_result(retmsg="Document not found!")
|
||||||
|
if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
|
||||||
|
return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)
|
||||||
|
|
||||||
|
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
|
||||||
if not DocumentService.delete_by_id(req["doc_id"]):
|
if not DocumentService.delete_by_id(req["doc_id"]):
|
||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
retmsg="Database error (Document removal)!")
|
retmsg="Database error (Document removal)!")
|
||||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
|
||||||
MINIO.rm(kb.id, doc.location)
|
MINIO.rm(doc.kb_id, doc.location)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
@ -233,3 +238,42 @@ def rename():
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/get', methods=['GET'])
|
||||||
|
@login_required
|
||||||
|
def get():
|
||||||
|
doc_id = request.args["doc_id"]
|
||||||
|
try:
|
||||||
|
e, doc = DocumentService.get_by_id(doc_id)
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(retmsg="Document not found!")
|
||||||
|
|
||||||
|
blob = MINIO.get(doc.kb_id, doc.location)
|
||||||
|
return get_json_result(data={"base64": base64.b64decode(blob)})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/change_parser', methods=['POST'])
|
||||||
|
@login_required
|
||||||
|
@validate_request("doc_id", "parser_id")
|
||||||
|
def change_parser():
|
||||||
|
req = request.json
|
||||||
|
try:
|
||||||
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(retmsg="Document not found!")
|
||||||
|
if doc.parser_id.lower() == req["parser_id"].lower():
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(retmsg="Document not found!")
|
||||||
|
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1)
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(retmsg="Document not found!")
|
||||||
|
|
||||||
|
return get_json_result(data=True)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result
|
||||||
|
|
||||||
@manager.route('/create', methods=['post'])
|
@manager.route('/create', methods=['post'])
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("name", "description", "permission", "embd_id", "parser_id")
|
@validate_request("name", "description", "permission", "parser_id")
|
||||||
def create():
|
def create():
|
||||||
req = request.json
|
req = request.json
|
||||||
req["name"] = req["name"].strip()
|
req["name"] = req["name"].strip()
|
||||||
|
|
@ -46,7 +46,7 @@ def create():
|
||||||
|
|
||||||
@manager.route('/update', methods=['post'])
|
@manager.route('/update', methods=['post'])
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id")
|
@validate_request("kb_id", "name", "description", "permission", "parser_id")
|
||||||
def update():
|
def update():
|
||||||
req = request.json
|
req = request.json
|
||||||
req["name"] = req["name"].strip()
|
req["name"] = req["name"].strip()
|
||||||
|
|
@ -72,6 +72,18 @@ def update():
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/detail', methods=['GET'])
|
||||||
|
@login_required
|
||||||
|
def detail():
|
||||||
|
kb_id = request.args["kb_id"]
|
||||||
|
try:
|
||||||
|
kb = KnowledgebaseService.get_detail(kb_id)
|
||||||
|
if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
||||||
|
return get_json_result(data=kb)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/list', methods=['GET'])
|
@manager.route('/list', methods=['GET'])
|
||||||
@login_required
|
@login_required
|
||||||
def list():
|
def list():
|
||||||
|
|
|
||||||
95
web_server/apps/llm_app.py
Normal file
95
web_server/apps/llm_app.py
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
#
|
||||||
|
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
from flask import request
|
||||||
|
from flask_login import login_required, current_user
|
||||||
|
|
||||||
|
from web_server.db.services import duplicate_name
|
||||||
|
from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
||||||
|
from web_server.db.services.user_service import TenantService, UserTenantService
|
||||||
|
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
|
from web_server.utils import get_uuid, get_format_time
|
||||||
|
from web_server.db import StatusEnum, UserTenantRole
|
||||||
|
from web_server.db.services.kb_service import KnowledgebaseService
|
||||||
|
from web_server.db.db_models import Knowledgebase, TenantLLM
|
||||||
|
from web_server.settings import stat_logger, RetCode
|
||||||
|
from web_server.utils.api_utils import get_json_result
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/factories', methods=['GET'])
|
||||||
|
@login_required
|
||||||
|
def factories():
|
||||||
|
try:
|
||||||
|
fac = LLMFactoriesService.get_all()
|
||||||
|
return get_json_result(data=fac.to_json())
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/set_api_key', methods=['POST'])
|
||||||
|
@login_required
|
||||||
|
@validate_request("llm_factory", "api_key")
|
||||||
|
def set_api_key():
|
||||||
|
req = request.json
|
||||||
|
llm = {
|
||||||
|
"tenant_id": current_user.id,
|
||||||
|
"llm_factory": req["llm_factory"],
|
||||||
|
"api_key": req["api_key"]
|
||||||
|
}
|
||||||
|
# TODO: Test api_key
|
||||||
|
for n in ["model_type", "llm_name"]:
|
||||||
|
if n in req: llm[n] = req[n]
|
||||||
|
|
||||||
|
TenantLLM.insert(**llm).on_conflict("replace").execute()
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/my_llms', methods=['GET'])
|
||||||
|
@login_required
|
||||||
|
def my_llms():
|
||||||
|
try:
|
||||||
|
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||||
|
objs = [o.to_dict() for o in objs]
|
||||||
|
for o in objs: del o["api_key"]
|
||||||
|
return get_json_result(data=objs)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/list', methods=['GET'])
|
||||||
|
@login_required
|
||||||
|
def list():
|
||||||
|
try:
|
||||||
|
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||||
|
objs = [o.to_dict() for o in objs if o.api_key]
|
||||||
|
fct = {}
|
||||||
|
for o in objs:
|
||||||
|
if o["llm_factory"] not in fct: fct[o["llm_factory"]] = []
|
||||||
|
if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"])
|
||||||
|
|
||||||
|
llms = LLMService.get_all()
|
||||||
|
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
|
||||||
|
for m in llms:
|
||||||
|
m["available"] = False
|
||||||
|
if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]):
|
||||||
|
m["available"] = True
|
||||||
|
res = {}
|
||||||
|
for m in llms:
|
||||||
|
if m["fid"] not in res: res[m["fid"]] = []
|
||||||
|
res[m["fid"]].append(m)
|
||||||
|
|
||||||
|
return get_json_result(data=res)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
@ -16,9 +16,12 @@
|
||||||
from flask import request, session, redirect, url_for
|
from flask import request, session, redirect, url_for
|
||||||
from werkzeug.security import generate_password_hash, check_password_hash
|
from werkzeug.security import generate_password_hash, check_password_hash
|
||||||
from flask_login import login_required, current_user, login_user, logout_user
|
from flask_login import login_required, current_user, login_user, logout_user
|
||||||
|
|
||||||
|
from web_server.db.db_models import TenantLLM
|
||||||
|
from web_server.db.services.llm_service import TenantLLMService
|
||||||
from web_server.utils.api_utils import server_error_response, validate_request
|
from web_server.utils.api_utils import server_error_response, validate_request
|
||||||
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
|
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
|
||||||
from web_server.db import UserTenantRole
|
from web_server.db import UserTenantRole, LLMType
|
||||||
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
|
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
|
||||||
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
|
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
|
||||||
from web_server.settings import stat_logger
|
from web_server.settings import stat_logger
|
||||||
|
|
@ -47,8 +50,9 @@ def login():
|
||||||
avatar = download_img(userinfo["avatar_url"])
|
avatar = download_img(userinfo["avatar_url"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
stat_logger.exception(e)
|
stat_logger.exception(e)
|
||||||
|
user_id = get_uuid()
|
||||||
try:
|
try:
|
||||||
users = user_register({
|
users = user_register(user_id, {
|
||||||
"access_token": session["access_token"],
|
"access_token": session["access_token"],
|
||||||
"email": userinfo["email"],
|
"email": userinfo["email"],
|
||||||
"avatar": avatar,
|
"avatar": avatar,
|
||||||
|
|
@ -63,6 +67,7 @@ def login():
|
||||||
login_user(user)
|
login_user(user)
|
||||||
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
rollback_user_registration(user_id)
|
||||||
stat_logger.exception(e)
|
stat_logger.exception(e)
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
elif not request.json:
|
elif not request.json:
|
||||||
|
|
@ -162,7 +167,23 @@ def user_info():
|
||||||
return get_json_result(data=current_user.to_dict())
|
return get_json_result(data=current_user.to_dict())
|
||||||
|
|
||||||
|
|
||||||
def user_register(user):
|
def rollback_user_registration(user_id):
|
||||||
|
try:
|
||||||
|
TenantService.delete_by_id(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
u = UserTenantService.query(tenant_id=user_id)
|
||||||
|
if u:
|
||||||
|
UserTenantService.delete_by_id(u[0].id)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute()
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def user_register(user_id, user):
|
||||||
user_id = get_uuid()
|
user_id = get_uuid()
|
||||||
user["id"] = user_id
|
user["id"] = user_id
|
||||||
tenant = {
|
tenant = {
|
||||||
|
|
@ -180,10 +201,12 @@ def user_register(user):
|
||||||
"invited_by": user_id,
|
"invited_by": user_id,
|
||||||
"role": UserTenantRole.OWNER
|
"role": UserTenantRole.OWNER
|
||||||
}
|
}
|
||||||
|
tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
|
||||||
|
|
||||||
if not UserService.save(**user):return
|
if not UserService.save(**user):return
|
||||||
TenantService.save(**tenant)
|
TenantService.save(**tenant)
|
||||||
UserTenantService.save(**usr_tenant)
|
UserTenantService.save(**usr_tenant)
|
||||||
|
TenantLLMService.save(**tenant_llm)
|
||||||
return UserService.query(email=user["email"])
|
return UserService.query(email=user["email"])
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -203,14 +226,17 @@ def user_add():
|
||||||
"last_login_time": get_format_time(),
|
"last_login_time": get_format_time(),
|
||||||
"is_superuser": False,
|
"is_superuser": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user_id = get_uuid()
|
||||||
try:
|
try:
|
||||||
users = user_register(user_dict)
|
users = user_register(user_id, user_dict)
|
||||||
if not users: raise Exception('Register user failure.')
|
if not users: raise Exception('Register user failure.')
|
||||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||||
user = users[0]
|
user = users[0]
|
||||||
login_user(user)
|
login_user(user)
|
||||||
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
rollback_user_registration(user_id)
|
||||||
stat_logger.exception(e)
|
stat_logger.exception(e)
|
||||||
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||||
|
|
||||||
|
|
@ -220,7 +246,7 @@ def user_add():
|
||||||
@login_required
|
@login_required
|
||||||
def tenant_info():
|
def tenant_info():
|
||||||
try:
|
try:
|
||||||
tenants = TenantService.get_by_user_id(current_user.id)
|
tenants = TenantService.get_by_user_id(current_user.id)[0]
|
||||||
return get_json_result(data=tenants)
|
return get_json_result(data=tenants)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
|
||||||
|
|
@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel):
|
||||||
class LLM(DataBaseModel):
|
class LLM(DataBaseModel):
|
||||||
# defautlt LLMs for every users
|
# defautlt LLMs for every users
|
||||||
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
|
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
|
||||||
|
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||||
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
||||||
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
||||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
|
@ -442,8 +443,8 @@ class LLM(DataBaseModel):
|
||||||
class TenantLLM(DataBaseModel):
|
class TenantLLM(DataBaseModel):
|
||||||
tenant_id = CharField(max_length=32, null=False)
|
tenant_id = CharField(max_length=32, null=False)
|
||||||
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
|
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
|
||||||
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||||
llm_name = CharField(max_length=128, null=False, help_text="LLM name")
|
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
|
||||||
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
||||||
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
||||||
|
|
||||||
|
|
@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel):
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "tenant_llm"
|
db_table = "tenant_llm"
|
||||||
primary_key = CompositeKey('tenant_id', 'llm_factory')
|
primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
|
||||||
|
|
||||||
|
|
||||||
class Knowledgebase(DataBaseModel):
|
class Knowledgebase(DataBaseModel):
|
||||||
|
|
@ -464,7 +465,8 @@ class Knowledgebase(DataBaseModel):
|
||||||
permission = CharField(max_length=16, null=False, help_text="me|team")
|
permission = CharField(max_length=16, null=False, help_text="me|team")
|
||||||
created_by = CharField(max_length=32, null=False)
|
created_by = CharField(max_length=32, null=False)
|
||||||
doc_num = IntegerField(default=0)
|
doc_num = IntegerField(default=0)
|
||||||
embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID")
|
token_num = IntegerField(default=0)
|
||||||
|
chunk_num = IntegerField(default=0)
|
||||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
||||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,12 +13,13 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
from peewee import Expression
|
||||||
|
|
||||||
from web_server.db import TenantPermission, FileType
|
from web_server.db import TenantPermission, FileType
|
||||||
from web_server.db.db_models import DB, Knowledgebase
|
from web_server.db.db_models import DB, Knowledgebase, Tenant
|
||||||
from web_server.db.db_models import Document
|
from web_server.db.db_models import Document
|
||||||
from web_server.db.services.common_service import CommonService
|
from web_server.db.services.common_service import CommonService
|
||||||
from web_server.db.services.kb_service import KnowledgebaseService
|
from web_server.db.services.kb_service import KnowledgebaseService
|
||||||
from web_server.utils import get_uuid, get_format_time
|
|
||||||
from web_server.db.db_utils import StatusEnum
|
from web_server.db.db_utils import StatusEnum
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -61,15 +62,27 @@ class DocumentService(CommonService):
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
|
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
|
||||||
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id]
|
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
|
||||||
docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where(
|
docs = cls.model.select(*fields) \
|
||||||
cls.model.status == StatusEnum.VALID.value,
|
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
|
||||||
cls.model.type != FileType.VIRTUAL,
|
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
|
||||||
cls.model.progress == 0,
|
.where(
|
||||||
cls.model.update_time >= tm,
|
cls.model.status == StatusEnum.VALID.value,
|
||||||
cls.model.create_time %
|
~(cls.model.type == FileType.VIRTUAL.value),
|
||||||
comm == mod).order_by(
|
cls.model.progress == 0,
|
||||||
cls.model.update_time.asc()).paginate(
|
cls.model.update_time >= tm,
|
||||||
1,
|
(Expression(cls.model.create_time, "%%", comm) == mod))\
|
||||||
items_per_page)
|
.order_by(cls.model.update_time.asc())\
|
||||||
|
.paginate(1, items_per_page)
|
||||||
return list(docs.dicts())
|
return list(docs.dicts())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
|
||||||
|
num = cls.model.update(token_num=cls.model.token_num + token_num,
|
||||||
|
chunk_num=cls.model.chunk_num + chunk_num,
|
||||||
|
process_duation=cls.model.process_duation+duation).where(
|
||||||
|
cls.model.id == doc_id).execute()
|
||||||
|
if num == 0:raise LookupError("Document not found which is supposed to be there")
|
||||||
|
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
|
||||||
|
return num
|
||||||
|
|
@ -17,7 +17,7 @@ import peewee
|
||||||
from werkzeug.security import generate_password_hash, check_password_hash
|
from werkzeug.security import generate_password_hash, check_password_hash
|
||||||
|
|
||||||
from web_server.db import TenantPermission
|
from web_server.db import TenantPermission
|
||||||
from web_server.db.db_models import DB, UserTenant
|
from web_server.db.db_models import DB, UserTenant, Tenant
|
||||||
from web_server.db.db_models import Knowledgebase
|
from web_server.db.db_models import Knowledgebase
|
||||||
from web_server.db.services.common_service import CommonService
|
from web_server.db.services.common_service import CommonService
|
||||||
from web_server.utils import get_uuid, get_format_time
|
from web_server.utils import get_uuid, get_format_time
|
||||||
|
|
@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc):
|
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||||
|
page_number, items_per_page, orderby, desc):
|
||||||
kbs = cls.model.select().where(
|
kbs = cls.model.select().where(
|
||||||
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
|
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
||||||
& (cls.model.status==StatusEnum.VALID.value)
|
TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
|
||||||
|
& (cls.model.status == StatusEnum.VALID.value)
|
||||||
)
|
)
|
||||||
if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
if desc:
|
||||||
else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
||||||
|
else:
|
||||||
|
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
||||||
|
|
||||||
kbs = kbs.paginate(page_number, items_per_page)
|
kbs = kbs.paginate(page_number, items_per_page)
|
||||||
|
|
||||||
return list(kbs.dicts())
|
return list(kbs.dicts())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_detail(cls, kb_id):
|
||||||
|
fields = [
|
||||||
|
cls.model.id,
|
||||||
|
Tenant.embd_id,
|
||||||
|
cls.model.avatar,
|
||||||
|
cls.model.name,
|
||||||
|
cls.model.description,
|
||||||
|
cls.model.permission,
|
||||||
|
cls.model.doc_num,
|
||||||
|
cls.model.token_num,
|
||||||
|
cls.model.chunk_num,
|
||||||
|
cls.model.parser_id]
|
||||||
|
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
|
||||||
|
(cls.model.id == kb_id),
|
||||||
|
(cls.model.status == StatusEnum.VALID.value)
|
||||||
|
)
|
||||||
|
if not kbs:
|
||||||
|
return
|
||||||
|
d = kbs[0].to_dict()
|
||||||
|
d["embd_id"] = kbs[0].tenant.embd_id
|
||||||
|
return d
|
||||||
|
|
|
||||||
|
|
@ -31,5 +31,23 @@ class LLMService(CommonService):
|
||||||
model = LLM
|
model = LLM
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TenantLLMService(CommonService):
|
class TenantLLMService(CommonService):
|
||||||
model = TenantLLM
|
model = TenantLLM
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_api_key(cls, tenant_id, model_type):
|
||||||
|
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
|
||||||
|
if objs and len(objs)>0 and objs[0].llm_name:
|
||||||
|
return objs[0]
|
||||||
|
|
||||||
|
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
|
||||||
|
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
|
||||||
|
(cls.model.tenant_id == tenant_id),
|
||||||
|
(cls.model.model_type == model_type),
|
||||||
|
(LLM.status == StatusEnum.VALID)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not objs:return
|
||||||
|
return objs[0]
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,7 @@ class TenantService(CommonService):
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_user_id(cls, user_id):
|
def get_by_user_id(cls, user_id):
|
||||||
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
|
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
|
||||||
return list(cls.model.select(*fields)\
|
return list(cls.model.select(*fields)\
|
||||||
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
|
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
|
||||||
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
||||||
|
|
|
||||||
|
|
@ -143,7 +143,7 @@ def filename_type(filename):
|
||||||
if re.match(r".*\.pdf$", filename):
|
if re.match(r".*\.pdf$", filename):
|
||||||
return FileType.PDF.value
|
return FileType.PDF.value
|
||||||
|
|
||||||
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename):
|
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
|
||||||
return FileType.DOC.value
|
return FileType.DOC.value
|
||||||
|
|
||||||
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
|
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue