From d5139b10ec82abc9c3665922e8a08b32740495e6 Mon Sep 17 00:00:00 2001 From: yongtenglei Date: Tue, 2 Dec 2025 17:13:20 +0800 Subject: [PATCH] add async for retrieval_test --- api/apps/chunk_app.py | 53 +++++++++++++++++++++++------------------ api/apps/sdk/session.py | 52 ++++++++++++++++++++++------------------ 2 files changed, 59 insertions(+), 46 deletions(-) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index c2b384224..d96de64d0 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -311,25 +311,28 @@ async def retrieval_test(): use_kg = req.get("use_kg", False) top = int(req.get("top_k", 1024)) langs = req.get("cross_languages", []) - tenant_ids = [] + user_id = current_user.id - if req.get("search_id", ""): - search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) - meta_data_filter = search_config.get("meta_data_filter", {}) - metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": - chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters: dict = gen_meta_filter(chat_mdl, metas, question) - doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not doc_ids: - doc_ids = None - elif meta_data_filter.get("method") == "manual": - doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if meta_data_filter["manual"] and not doc_ids: - doc_ids = ["-999"] + def _retrieval_sync(): + local_doc_ids = list(doc_ids) if doc_ids else [] + tenant_ids = [] - try: - tenants = UserTenantService.query(user_id=current_user.id) + if req.get("search_id", ""): + search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) + metas = DocumentService.get_meta_by_kbs(kb_ids) + if meta_data_filter.get("method") == "auto": + chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) + filters: dict = gen_meta_filter(chat_mdl, metas, question) + local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not local_doc_ids: + local_doc_ids = None + elif meta_data_filter.get("method") == "manual": + local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) + if meta_data_filter["manual"] and not local_doc_ids: + local_doc_ids = ["-999"] + + tenants = UserTenantService.query(user_id=user_id) for kb_id in kb_ids: for tenant in tenants: if KnowledgebaseService.query( @@ -345,8 +348,9 @@ async def retrieval_test(): if not e: return get_data_error_result(message="Knowledgebase not found!") + _question = question if langs: - question = cross_languages(kb.tenant_id, None, question, langs) + _question = cross_languages(kb.tenant_id, None, _question, langs) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) @@ -356,19 +360,19 @@ async def retrieval_test(): if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - question += keyword_extraction(chat_mdl, question) + _question += keyword_extraction(chat_mdl, _question) - labels = label_question(question, [kb]) - ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size, + labels = label_question(_question, [kb]) + ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size, float(req.get("similarity_threshold", 0.0)), float(req.get("vector_similarity_weight", 0.3)), top, - doc_ids, rerank_mdl=rerank_mdl, + local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight", False), rank_feature=labels ) if use_kg: - ck = settings.kg_retriever.retrieval(question, + ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, @@ -381,6 +385,9 @@ async def retrieval_test(): ranks["labels"] = labels return get_json_result(data=ranks) + + try: + return await asyncio.to_thread(_retrieval_sync) except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message='No chunk found! Check the chunk status please!', diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 6276877a2..ee81be7b7 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import re import time @@ -963,28 +964,30 @@ async def retrieval_test_embedded(): use_kg = req.get("use_kg", False) top = int(req.get("top_k", 1024)) langs = req.get("cross_languages", []) - tenant_ids = [] - tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") - if req.get("search_id", ""): - search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) - meta_data_filter = search_config.get("meta_data_filter", {}) - metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": - chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters: dict = gen_meta_filter(chat_mdl, metas, question) - doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not doc_ids: - doc_ids = None - elif meta_data_filter.get("method") == "manual": - doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if meta_data_filter["manual"] and not doc_ids: - doc_ids = ["-999"] + def _retrieval_sync(): + local_doc_ids = list(doc_ids) if doc_ids else [] + tenant_ids = [] + _question = question + + if req.get("search_id", ""): + search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) + metas = DocumentService.get_meta_by_kbs(kb_ids) + if meta_data_filter.get("method") == "auto": + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) + filters: dict = gen_meta_filter(chat_mdl, metas, _question) + local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not local_doc_ids: + local_doc_ids = None + elif meta_data_filter.get("method") == "manual": + local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) + if meta_data_filter["manual"] and not local_doc_ids: + local_doc_ids = ["-999"] - try: tenants = UserTenantService.query(user_id=tenant_id) for kb_id in kb_ids: for tenant in tenants: @@ -1000,7 +1003,7 @@ async def retrieval_test_embedded(): return get_error_data_result(message="Knowledgebase not found!") if langs: - question = cross_languages(kb.tenant_id, None, question, langs) + _question = cross_languages(kb.tenant_id, None, _question, langs) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) @@ -1010,15 +1013,15 @@ async def retrieval_test_embedded(): if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - question += keyword_extraction(chat_mdl, question) + _question += keyword_extraction(chat_mdl, _question) - labels = label_question(question, [kb]) + labels = label_question(_question, [kb]) ranks = settings.retriever.retrieval( - question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, - doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels + _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, + local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels ) if use_kg: - ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl, + ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) @@ -1028,6 +1031,9 @@ async def retrieval_test_embedded(): ranks["labels"] = labels return get_json_result(data=ranks) + + try: + return await asyncio.to_thread(_retrieval_sync) except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message="No chunk found! Check the chunk status please!",