add async for retrieval_test

This commit is contained in:
yongtenglei 2025-12-02 17:13:20 +08:00
parent 4e8a8348f1
commit d5139b10ec
2 changed files with 59 additions and 46 deletions

View file

@ -311,25 +311,28 @@ async def retrieval_test():
use_kg = req.get("use_kg", False) use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024)) top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", []) langs = req.get("cross_languages", [])
tenant_ids = [] user_id = current_user.id
if req.get("search_id", ""): def _retrieval_sync():
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) local_doc_ids = list(doc_ids) if doc_ids else []
meta_data_filter = search_config.get("meta_data_filter", {}) tenant_ids = []
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
filters: dict = gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not doc_ids:
doc_ids = ["-999"]
try: if req.get("search_id", ""):
tenants = UserTenantService.query(user_id=current_user.id) search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
filters: dict = gen_meta_filter(chat_mdl, metas, question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None
elif meta_data_filter.get("method") == "manual":
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not local_doc_ids:
local_doc_ids = ["-999"]
tenants = UserTenantService.query(user_id=user_id)
for kb_id in kb_ids: for kb_id in kb_ids:
for tenant in tenants: for tenant in tenants:
if KnowledgebaseService.query( if KnowledgebaseService.query(
@ -345,8 +348,9 @@ async def retrieval_test():
if not e: if not e:
return get_data_error_result(message="Knowledgebase not found!") return get_data_error_result(message="Knowledgebase not found!")
_question = question
if langs: if langs:
question = cross_languages(kb.tenant_id, None, question, langs) _question = cross_languages(kb.tenant_id, None, _question, langs)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@ -356,19 +360,19 @@ async def retrieval_test():
if req.get("keyword", False): if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question) _question += keyword_extraction(chat_mdl, _question)
labels = label_question(question, [kb]) labels = label_question(_question, [kb])
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size, ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
float(req.get("similarity_threshold", 0.0)), float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)), float(req.get("vector_similarity_weight", 0.3)),
top, top,
doc_ids, rerank_mdl=rerank_mdl, local_doc_ids, rerank_mdl=rerank_mdl,
highlight=req.get("highlight", False), highlight=req.get("highlight", False),
rank_feature=labels rank_feature=labels
) )
if use_kg: if use_kg:
ck = settings.kg_retriever.retrieval(question, ck = settings.kg_retriever.retrieval(_question,
tenant_ids, tenant_ids,
kb_ids, kb_ids,
embd_mdl, embd_mdl,
@ -381,6 +385,9 @@ async def retrieval_test():
ranks["labels"] = labels ranks["labels"] = labels
return get_json_result(data=ranks) return get_json_result(data=ranks)
try:
return await asyncio.to_thread(_retrieval_sync)
except Exception as e: except Exception as e:
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!', return get_json_result(data=False, message='No chunk found! Check the chunk status please!',

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import asyncio
import json import json
import re import re
import time import time
@ -963,28 +964,30 @@ async def retrieval_test_embedded():
use_kg = req.get("use_kg", False) use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024)) top = int(req.get("top_k", 1024))
langs = req.get("cross_languages", []) langs = req.get("cross_languages", [])
tenant_ids = []
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
if not tenant_id: if not tenant_id:
return get_error_data_result(message="permission denined.") return get_error_data_result(message="permission denined.")
if req.get("search_id", ""): def _retrieval_sync():
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) local_doc_ids = list(doc_ids) if doc_ids else []
meta_data_filter = search_config.get("meta_data_filter", {}) tenant_ids = []
metas = DocumentService.get_meta_by_kbs(kb_ids) _question = question
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) if req.get("search_id", ""):
filters: dict = gen_meta_filter(chat_mdl, metas, question) search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) meta_data_filter = search_config.get("meta_data_filter", {})
if not doc_ids: metas = DocumentService.get_meta_by_kbs(kb_ids)
doc_ids = None if meta_data_filter.get("method") == "auto":
elif meta_data_filter.get("method") == "manual": chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) filters: dict = gen_meta_filter(chat_mdl, metas, _question)
if meta_data_filter["manual"] and not doc_ids: local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
doc_ids = ["-999"] if not local_doc_ids:
local_doc_ids = None
elif meta_data_filter.get("method") == "manual":
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
if meta_data_filter["manual"] and not local_doc_ids:
local_doc_ids = ["-999"]
try:
tenants = UserTenantService.query(user_id=tenant_id) tenants = UserTenantService.query(user_id=tenant_id)
for kb_id in kb_ids: for kb_id in kb_ids:
for tenant in tenants: for tenant in tenants:
@ -1000,7 +1003,7 @@ async def retrieval_test_embedded():
return get_error_data_result(message="Knowledgebase not found!") return get_error_data_result(message="Knowledgebase not found!")
if langs: if langs:
question = cross_languages(kb.tenant_id, None, question, langs) _question = cross_languages(kb.tenant_id, None, _question, langs)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@ -1010,15 +1013,15 @@ async def retrieval_test_embedded():
if req.get("keyword", False): if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question) _question += keyword_extraction(chat_mdl, _question)
labels = label_question(question, [kb]) labels = label_question(_question, [kb])
ranks = settings.retriever.retrieval( ranks = settings.retriever.retrieval(
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
) )
if use_kg: if use_kg:
ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl, ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT)) LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]: if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck) ranks["chunks"].insert(0, ck)
@ -1028,6 +1031,9 @@ async def retrieval_test_embedded():
ranks["labels"] = labels ranks["labels"] = labels
return get_json_result(data=ranks) return get_json_result(data=ranks)
try:
return await asyncio.to_thread(_retrieval_sync)
except Exception as e: except Exception as e:
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_json_result(data=False, message="No chunk found! Check the chunk status please!", return get_json_result(data=False, message="No chunk found! Check the chunk status please!",