This commit is contained in:
yongtenglei 2025-12-02 13:20:50 +08:00
parent d1916b0a03
commit ff16f4b3ab

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
# #
import asyncio
import json import json
import os.path import os.path
import pathlib import pathlib
@ -72,7 +73,7 @@ async def upload():
if not check_kb_team_permission(kb, current_user.id): if not check_kb_team_permission(kb, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
err, files = FileService.upload_document(kb, file_objs, current_user.id) err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id)
if err: if err:
return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR) return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR)
@ -390,7 +391,7 @@ async def rm():
if not DocumentService.accessible4deletion(doc_id, current_user.id): if not DocumentService.accessible4deletion(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
errors = FileService.delete_docs(doc_ids, current_user.id) errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id)
if errors: if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR) return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
@ -403,44 +404,48 @@ async def rm():
@validate_request("doc_ids", "run") @validate_request("doc_ids", "run")
async def run(): async def run():
req = await get_request_json() req = await get_request_json()
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try: try:
kb_table_num_map = {} def _run_sync():
for id in req["doc_ids"]: for doc_id in req["doc_ids"]:
info = {"run": str(req["run"]), "progress": 0} if not DocumentService.accessible(doc_id, current_user.id):
if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
info["progress_msg"] = ""
info["chunk_num"] = 0
info["token_num"] = 0
tenant_id = DocumentService.get_tenant_id(id) kb_table_num_map = {}
if not tenant_id: for id in req["doc_ids"]:
return get_data_error_result(message="Tenant not found!") info = {"run": str(req["run"]), "progress": 0}
e, doc = DocumentService.get_by_id(id) if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
if not e: info["progress_msg"] = ""
return get_data_error_result(message="Document not found!") info["chunk_num"] = 0
info["token_num"] = 0
if str(req["run"]) == TaskStatus.CANCEL.value: tenant_id = DocumentService.get_tenant_id(id)
if str(doc.run) == TaskStatus.RUNNING.value: if not tenant_id:
cancel_all_task_of(id) return get_data_error_result(message="Tenant not found!")
else: e, doc = DocumentService.get_by_id(id)
return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status") if not e:
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]): return get_data_error_result(message="Document not found!")
DocumentService.clear_chunk_num_when_rerun(doc.id)
DocumentService.update_by_id(id, info) if str(req["run"]) == TaskStatus.CANCEL.value:
if req.get("delete", False): if str(doc.run) == TaskStatus.RUNNING.value:
TaskService.filter_delete([Task.doc_id == id]) cancel_all_task_of(id)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): else:
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
DocumentService.clear_chunk_num_when_rerun(doc.id)
if str(req["run"]) == TaskStatus.RUNNING.value: DocumentService.update_by_id(id, info)
doc = doc.to_dict() if req.get("delete", False):
DocumentService.run(tenant_id, doc, kb_table_num_map) TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True) if str(req["run"]) == TaskStatus.RUNNING.value:
doc_dict = doc.to_dict()
DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
return get_json_result(data=True)
return await asyncio.to_thread(_run_sync)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@ -450,45 +455,49 @@ async def run():
@validate_request("doc_id", "name") @validate_request("doc_id", "name")
async def rename(): async def rename():
req = await get_request_json() req = await get_request_json()
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) def _rename_sync():
if not e: if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_data_error_result(message="Document not found!") return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): e, doc = DocumentService.get_by_id(req["doc_id"])
if d.name == req["name"]: if not e:
return get_data_error_result(message="Duplicated document name in the same knowledgebase.") return get_data_error_result(message="Document not found!")
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR)
if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}): for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
return get_data_error_result(message="Database error (Document rename)!") if d.name == req["name"]:
return get_data_error_result(message="Duplicated document name in the same knowledgebase.")
informs = File2DocumentService.get_by_document_id(req["doc_id"]) if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}):
if informs: return get_data_error_result(message="Database error (Document rename)!")
e, file = FileService.get_by_id(informs[0].file_id)
FileService.update_by_id(file.id, {"name": req["name"]})
tenant_id = DocumentService.get_tenant_id(req["doc_id"]) informs = File2DocumentService.get_by_document_id(req["doc_id"])
title_tks = rag_tokenizer.tokenize(req["name"]) if informs:
es_body = { e, file = FileService.get_by_id(informs[0].file_id)
"docnm_kwd": req["name"], FileService.update_by_id(file.id, {"name": req["name"]})
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks), tenant_id = DocumentService.get_tenant_id(req["doc_id"])
} title_tks = rag_tokenizer.tokenize(req["name"])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): es_body = {
settings.docStoreConn.update( "docnm_kwd": req["name"],
{"doc_id": req["doc_id"]}, "title_tks": title_tks,
es_body, "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
search.index_name(tenant_id), }
doc.kb_id, if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
) settings.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
search.index_name(tenant_id),
doc.kb_id,
)
return get_json_result(data=True)
return await asyncio.to_thread(_rename_sync)
return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@ -502,7 +511,8 @@ async def get(doc_id):
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id) b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
response = await make_response(settings.STORAGE_IMPL.get(b, n)) data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n)
response = await make_response(data)
ext = re.search(r"\.([^.]+)$", doc.name.lower()) ext = re.search(r"\.([^.]+)$", doc.name.lower())
ext = ext.group(1) if ext else None ext = ext.group(1) if ext else None
@ -523,8 +533,7 @@ async def get(doc_id):
async def download_attachment(attachment_id): async def download_attachment(attachment_id):
try: try:
ext = request.args.get("ext", "markdown") ext = request.args.get("ext", "markdown")
data = settings.STORAGE_IMPL.get(current_user.id, attachment_id) data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
# data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
response = await make_response(data) response = await make_response(data)
response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}")) response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))
@ -596,7 +605,8 @@ async def get_image(image_id):
if len(arr) != 2: if len(arr) != 2:
return get_data_error_result(message="Image not found.") return get_data_error_result(message="Image not found.")
bkt, nm = image_id.split("-") bkt, nm = image_id.split("-")
response = await make_response(settings.STORAGE_IMPL.get(bkt, nm)) data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm)
response = await make_response(data)
response.headers.set("Content-Type", "image/JPEG") response.headers.set("Content-Type", "image/JPEG")
return response return response
except Exception as e: except Exception as e: