Merge branch 'main' into main

This commit is contained in:
KevinHuSh 2024-01-15 19:47:15 +08:00 committed by GitHub
commit f5fa08540e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 37 additions and 18 deletions

View file

@ -158,6 +158,7 @@ class PaddleInferBenchmark(object):
return:
config_status(dict): dict style config info
"""
if isinstance(config, paddle_infer.Config):
config_status = {}
config_status['runtime_device'] = "gpu" if config.use_gpu(

View file

@ -15,7 +15,6 @@
import numpy as np
import paddle
import paddle.nn as nn
from scipy.special import softmax
from scipy.interpolate import InterpolatedUnivariateSpline

View file

@ -26,6 +26,7 @@ from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get
from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint
from visualize import visualize_pose
from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images

View file

@ -19,6 +19,7 @@ import json
from pathlib import Path
from functools import reduce
import cv2
import numpy as np
import math
@ -28,6 +29,7 @@ from paddle.inference import create_predictor
import sys
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)

View file

@ -19,6 +19,7 @@ import glob
from functools import reduce
from PIL import Image
import cv2
import math
import numpy as np
@ -35,6 +36,7 @@ from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess
from visualize import visualize_pose
from paddle.inference import Config
from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb
from benchmark_utils import PaddleInferBenchmark
from infer import Detector, get_test_images, print_arguments

View file

@ -371,6 +371,7 @@ class CenterTrack(Detector):
online_scores,
frame_id=frame_id,
ids2names=ids2names)
if seq_name is None:
seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name)
@ -442,12 +443,12 @@ class CenterTrack(Detector):
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if self.save_mot_txts:
result_filename = os.path.join(
self.output_dir, video_out_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results, data_type, num_classes)
writer.release()

View file

@ -114,6 +114,7 @@ class JDE_Detector(Detector):
tracked_thresh=tracked_thresh,
metric_type=metric_type)
def postprocess(self, inputs, result):
# postprocess output of predictor
np_boxes = result['pred_dets']
@ -247,6 +248,7 @@ class JDE_Detector(Detector):
online_scores,
frame_id=frame_id,
ids2names=ids2names)
if seq_name is None:
seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name)
@ -255,6 +257,7 @@ class JDE_Detector(Detector):
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
mot_results.append([online_tlwhs, online_scores, online_ids])
return mot_results

View file

@ -16,6 +16,7 @@ import os
import json
import cv2
import math
import numpy as np
import paddle
import yaml

View file

@ -28,6 +28,7 @@ from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH
from rag.utils import MINIO
from rag.utils import rmSpace, findMaxTm
from rag.nlp import huchunk, huqie, search
from io import BytesIO
import pandas as pd
@ -106,18 +107,18 @@ def build(row, cvmdl):
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
return []
# If just change the kb for doc
# res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]), idxnm=search.index_name(row["tenant_id"]))
# if ELASTICSEARCH.getTotal(res) > 0:
# ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
# scripts="""
# if(!ctx._source.kb_id.contains('%s'))
# ctx._source.kb_id.add('%s');
# """ % (str(row["kb_id"]), str(row["kb_id"])),
# idxnm=search.index_name(row["tenant_id"])
# )
# set_progress(row["id"], 1, "Done")
# return []
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
if ELASTICSEARCH.getTotal(res) > 0:
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
scripts="""
if(!ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.add('%s');
""" % (str(row["kb_id"]), str(row["kb_id"])),
idxnm=search.index_name(row["tenant_id"])
)
set_progress(row["id"], 1, "Done")
return []
random.seed(time.time())
set_progress(row["id"], random.randint(0, 20) /
@ -135,7 +136,9 @@ def build(row, cvmdl):
row["id"], -1, f"Internal server error: %s" %
str(e).replace(
"'", ""))
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
return []
if not obj.text_chunks and not obj.table_chunks:

View file

@ -40,7 +40,7 @@ def findMaxDt(fnm):
print("WARNING: can't find " + fnm)
return m
def findMaxTm(fnm):
m = 0
try:
@ -58,6 +58,7 @@ def findMaxTm(fnm):
print("WARNING: can't find " + fnm)
return m
def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding('cl100k_base')

View file

@ -276,4 +276,5 @@ def change_parser():
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
return server_error_response(e)

View file

@ -183,7 +183,9 @@ def rollback_user_registration(user_id):
except Exception as e:
pass
def user_register(user_id, user):
user_id = get_uuid()
user["id"] = user_id
tenant = {

View file

@ -467,6 +467,7 @@ class Knowledgebase(DataBaseModel):
doc_num = IntegerField(default=0)
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")

View file

@ -85,4 +85,5 @@ class DocumentService(CommonService):
cls.model.id == doc_id).execute()
if num == 0:raise LookupError("Document not found which is supposed to be there")
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
return num
return num

View file

@ -31,7 +31,6 @@ class LLMService(CommonService):
model = LLM
class TenantLLMService(CommonService):
model = TenantLLM
@ -51,3 +50,4 @@ class TenantLLMService(CommonService):
if not objs:return
return objs[0]