Merge branch 'main' into main

This commit is contained in:
KevinHuSh 2024-01-15 19:47:15 +08:00 committed by GitHub
commit f5fa08540e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 37 additions and 18 deletions

View file

@ -158,6 +158,7 @@ class PaddleInferBenchmark(object):
return: return:
config_status(dict): dict style config info config_status(dict): dict style config info
""" """
if isinstance(config, paddle_infer.Config): if isinstance(config, paddle_infer.Config):
config_status = {} config_status = {}
config_status['runtime_device'] = "gpu" if config.use_gpu( config_status['runtime_device'] = "gpu" if config.use_gpu(

View file

@ -15,7 +15,6 @@
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from scipy.special import softmax
from scipy.interpolate import InterpolatedUnivariateSpline from scipy.interpolate import InterpolatedUnivariateSpline

View file

@ -26,6 +26,7 @@ from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get
from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint
from visualize import visualize_pose from visualize import visualize_pose
from benchmark_utils import PaddleInferBenchmark from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images from keypoint_postprocess import translate_to_ori_images

View file

@ -19,6 +19,7 @@ import json
from pathlib import Path from pathlib import Path
from functools import reduce from functools import reduce
import cv2 import cv2
import numpy as np import numpy as np
import math import math
@ -28,6 +29,7 @@ from paddle.inference import create_predictor
import sys import sys
# add deploy path of PaddleDetection to sys.path # add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..']))) parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path) sys.path.insert(0, parent_path)

View file

@ -19,6 +19,7 @@ import glob
from functools import reduce from functools import reduce
from PIL import Image from PIL import Image
import cv2 import cv2
import math import math
import numpy as np import numpy as np
@ -35,6 +36,7 @@ from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess
from visualize import visualize_pose from visualize import visualize_pose
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb from utils import argsparser, Timer, get_current_memory_mb
from benchmark_utils import PaddleInferBenchmark from benchmark_utils import PaddleInferBenchmark
from infer import Detector, get_test_images, print_arguments from infer import Detector, get_test_images, print_arguments

View file

@ -371,6 +371,7 @@ class CenterTrack(Detector):
online_scores, online_scores,
frame_id=frame_id, frame_id=frame_id,
ids2names=ids2names) ids2names=ids2names)
if seq_name is None: if seq_name is None:
seq_name = image_list[0].split('/')[-2] seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name) save_dir = os.path.join(self.output_dir, seq_name)
@ -442,12 +443,12 @@ class CenterTrack(Detector):
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
break break
if self.save_mot_txts: if self.save_mot_txts:
result_filename = os.path.join( result_filename = os.path.join(
self.output_dir, video_out_name.split('.')[-2] + '.txt') self.output_dir, video_out_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results, data_type, num_classes) write_mot_results(result_filename, results, data_type, num_classes)
writer.release() writer.release()

View file

@ -114,6 +114,7 @@ class JDE_Detector(Detector):
tracked_thresh=tracked_thresh, tracked_thresh=tracked_thresh,
metric_type=metric_type) metric_type=metric_type)
def postprocess(self, inputs, result): def postprocess(self, inputs, result):
# postprocess output of predictor # postprocess output of predictor
np_boxes = result['pred_dets'] np_boxes = result['pred_dets']
@ -247,6 +248,7 @@ class JDE_Detector(Detector):
online_scores, online_scores,
frame_id=frame_id, frame_id=frame_id,
ids2names=ids2names) ids2names=ids2names)
if seq_name is None: if seq_name is None:
seq_name = image_list[0].split('/')[-2] seq_name = image_list[0].split('/')[-2]
save_dir = os.path.join(self.output_dir, seq_name) save_dir = os.path.join(self.output_dir, seq_name)
@ -255,6 +257,7 @@ class JDE_Detector(Detector):
cv2.imwrite( cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im) os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
mot_results.append([online_tlwhs, online_scores, online_ids]) mot_results.append([online_tlwhs, online_scores, online_ids])
return mot_results return mot_results

View file

@ -16,6 +16,7 @@ import os
import json import json
import cv2 import cv2
import math import math
import numpy as np import numpy as np
import paddle import paddle
import yaml import yaml

View file

@ -28,6 +28,7 @@ from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH from rag.utils import ELASTICSEARCH
from rag.utils import MINIO from rag.utils import MINIO
from rag.utils import rmSpace, findMaxTm from rag.utils import rmSpace, findMaxTm
from rag.nlp import huchunk, huqie, search from rag.nlp import huchunk, huqie, search
from io import BytesIO from io import BytesIO
import pandas as pd import pandas as pd
@ -106,18 +107,18 @@ def build(row, cvmdl):
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" % set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024))) (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
return [] return []
# If just change the kb for doc
# res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]), idxnm=search.index_name(row["tenant_id"])) res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
# if ELASTICSEARCH.getTotal(res) > 0: if ELASTICSEARCH.getTotal(res) > 0:
# ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]), ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
# scripts=""" scripts="""
# if(!ctx._source.kb_id.contains('%s')) if(!ctx._source.kb_id.contains('%s'))
# ctx._source.kb_id.add('%s'); ctx._source.kb_id.add('%s');
# """ % (str(row["kb_id"]), str(row["kb_id"])), """ % (str(row["kb_id"]), str(row["kb_id"])),
# idxnm=search.index_name(row["tenant_id"]) idxnm=search.index_name(row["tenant_id"])
# ) )
# set_progress(row["id"], 1, "Done") set_progress(row["id"], 1, "Done")
# return [] return []
random.seed(time.time()) random.seed(time.time())
set_progress(row["id"], random.randint(0, 20) / set_progress(row["id"], random.randint(0, 20) /
@ -135,7 +136,9 @@ def build(row, cvmdl):
row["id"], -1, f"Internal server error: %s" % row["id"], -1, f"Internal server error: %s" %
str(e).replace( str(e).replace(
"'", "")) "'", ""))
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e))) cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
return [] return []
if not obj.text_chunks and not obj.table_chunks: if not obj.text_chunks and not obj.table_chunks:

View file

@ -58,6 +58,7 @@ def findMaxTm(fnm):
print("WARNING: can't find " + fnm) print("WARNING: can't find " + fnm)
return m return m
def num_tokens_from_string(string: str) -> int: def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string.""" """Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding('cl100k_base') encoding = tiktoken.get_encoding('cl100k_base')

View file

@ -277,3 +277,4 @@ def change_parser():
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

View file

@ -183,7 +183,9 @@ def rollback_user_registration(user_id):
except Exception as e: except Exception as e:
pass pass
def user_register(user_id, user): def user_register(user_id, user):
user_id = get_uuid() user_id = get_uuid()
user["id"] = user_id user["id"] = user_id
tenant = { tenant = {

View file

@ -467,6 +467,7 @@ class Knowledgebase(DataBaseModel):
doc_num = IntegerField(default=0) doc_num = IntegerField(default=0)
token_num = IntegerField(default=0) token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0) chunk_num = IntegerField(default=0)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID") parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")

View file

@ -86,3 +86,4 @@ class DocumentService(CommonService):
if num == 0:raise LookupError("Document not found which is supposed to be there") if num == 0:raise LookupError("Document not found which is supposed to be there")
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute() num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
return num return num

View file

@ -31,7 +31,6 @@ class LLMService(CommonService):
model = LLM model = LLM
class TenantLLMService(CommonService): class TenantLLMService(CommonService):
model = TenantLLM model = TenantLLM
@ -51,3 +50,4 @@ class TenantLLMService(CommonService):
if not objs:return if not objs:return
return objs[0] return objs[0]