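Add async execution for retrieval_test and retrieval_test_embedded (run the blocking retrieval in a worker thread via asyncio.to_thread)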
This commit is contained in:
parent
4e8a8348f1
commit
d5139b10ec
2 changed files with 59 additions and 46 deletions
|
|
@ -311,25 +311,28 @@ async def retrieval_test():
|
||||||
use_kg = req.get("use_kg", False)
|
use_kg = req.get("use_kg", False)
|
||||||
top = int(req.get("top_k", 1024))
|
top = int(req.get("top_k", 1024))
|
||||||
langs = req.get("cross_languages", [])
|
langs = req.get("cross_languages", [])
|
||||||
tenant_ids = []
|
user_id = current_user.id
|
||||||
|
|
||||||
if req.get("search_id", ""):
|
def _retrieval_sync():
|
||||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
tenant_ids = []
|
||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
|
||||||
if meta_data_filter.get("method") == "auto":
|
|
||||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
|
||||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
|
||||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not doc_ids:
|
|
||||||
doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "manual":
|
|
||||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
|
||||||
if meta_data_filter["manual"] and not doc_ids:
|
|
||||||
doc_ids = ["-999"]
|
|
||||||
|
|
||||||
try:
|
if req.get("search_id", ""):
|
||||||
tenants = UserTenantService.query(user_id=current_user.id)
|
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||||
|
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||||
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
|
if meta_data_filter.get("method") == "auto":
|
||||||
|
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||||
|
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||||
|
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||||
|
if not local_doc_ids:
|
||||||
|
local_doc_ids = None
|
||||||
|
elif meta_data_filter.get("method") == "manual":
|
||||||
|
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||||
|
if meta_data_filter["manual"] and not local_doc_ids:
|
||||||
|
local_doc_ids = ["-999"]
|
||||||
|
|
||||||
|
tenants = UserTenantService.query(user_id=user_id)
|
||||||
for kb_id in kb_ids:
|
for kb_id in kb_ids:
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
if KnowledgebaseService.query(
|
if KnowledgebaseService.query(
|
||||||
|
|
@ -345,8 +348,9 @@ async def retrieval_test():
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Knowledgebase not found!")
|
return get_data_error_result(message="Knowledgebase not found!")
|
||||||
|
|
||||||
|
_question = question
|
||||||
if langs:
|
if langs:
|
||||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
_question = cross_languages(kb.tenant_id, None, _question, langs)
|
||||||
|
|
||||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||||
|
|
||||||
|
|
@ -356,19 +360,19 @@ async def retrieval_test():
|
||||||
|
|
||||||
if req.get("keyword", False):
|
if req.get("keyword", False):
|
||||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||||
question += keyword_extraction(chat_mdl, question)
|
_question += keyword_extraction(chat_mdl, _question)
|
||||||
|
|
||||||
labels = label_question(question, [kb])
|
labels = label_question(_question, [kb])
|
||||||
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||||
float(req.get("similarity_threshold", 0.0)),
|
float(req.get("similarity_threshold", 0.0)),
|
||||||
float(req.get("vector_similarity_weight", 0.3)),
|
float(req.get("vector_similarity_weight", 0.3)),
|
||||||
top,
|
top,
|
||||||
doc_ids, rerank_mdl=rerank_mdl,
|
local_doc_ids, rerank_mdl=rerank_mdl,
|
||||||
highlight=req.get("highlight", False),
|
highlight=req.get("highlight", False),
|
||||||
rank_feature=labels
|
rank_feature=labels
|
||||||
)
|
)
|
||||||
if use_kg:
|
if use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(question,
|
ck = settings.kg_retriever.retrieval(_question,
|
||||||
tenant_ids,
|
tenant_ids,
|
||||||
kb_ids,
|
kb_ids,
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
|
|
@ -381,6 +385,9 @@ async def retrieval_test():
|
||||||
ranks["labels"] = labels
|
ranks["labels"] = labels
|
||||||
|
|
||||||
return get_json_result(data=ranks)
|
return get_json_result(data=ranks)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(_retrieval_sync)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
|
@ -963,28 +964,30 @@ async def retrieval_test_embedded():
|
||||||
use_kg = req.get("use_kg", False)
|
use_kg = req.get("use_kg", False)
|
||||||
top = int(req.get("top_k", 1024))
|
top = int(req.get("top_k", 1024))
|
||||||
langs = req.get("cross_languages", [])
|
langs = req.get("cross_languages", [])
|
||||||
tenant_ids = []
|
|
||||||
|
|
||||||
tenant_id = objs[0].tenant_id
|
tenant_id = objs[0].tenant_id
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
return get_error_data_result(message="permission denined.")
|
return get_error_data_result(message="permission denined.")
|
||||||
|
|
||||||
if req.get("search_id", ""):
|
def _retrieval_sync():
|
||||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
tenant_ids = []
|
||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
_question = question
|
||||||
if meta_data_filter.get("method") == "auto":
|
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
if req.get("search_id", ""):
|
||||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||||
if not doc_ids:
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
doc_ids = None
|
if meta_data_filter.get("method") == "auto":
|
||||||
elif meta_data_filter.get("method") == "manual":
|
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
filters: dict = gen_meta_filter(chat_mdl, metas, _question)
|
||||||
if meta_data_filter["manual"] and not doc_ids:
|
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||||
doc_ids = ["-999"]
|
if not local_doc_ids:
|
||||||
|
local_doc_ids = None
|
||||||
|
elif meta_data_filter.get("method") == "manual":
|
||||||
|
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||||
|
if meta_data_filter["manual"] and not local_doc_ids:
|
||||||
|
local_doc_ids = ["-999"]
|
||||||
|
|
||||||
try:
|
|
||||||
tenants = UserTenantService.query(user_id=tenant_id)
|
tenants = UserTenantService.query(user_id=tenant_id)
|
||||||
for kb_id in kb_ids:
|
for kb_id in kb_ids:
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
|
|
@ -1000,7 +1003,7 @@ async def retrieval_test_embedded():
|
||||||
return get_error_data_result(message="Knowledgebase not found!")
|
return get_error_data_result(message="Knowledgebase not found!")
|
||||||
|
|
||||||
if langs:
|
if langs:
|
||||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
_question = cross_languages(kb.tenant_id, None, _question, langs)
|
||||||
|
|
||||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||||
|
|
||||||
|
|
@ -1010,15 +1013,15 @@ async def retrieval_test_embedded():
|
||||||
|
|
||||||
if req.get("keyword", False):
|
if req.get("keyword", False):
|
||||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||||
question += keyword_extraction(chat_mdl, question)
|
_question += keyword_extraction(chat_mdl, _question)
|
||||||
|
|
||||||
labels = label_question(question, [kb])
|
labels = label_question(_question, [kb])
|
||||||
ranks = settings.retriever.retrieval(
|
ranks = settings.retriever.retrieval(
|
||||||
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||||
)
|
)
|
||||||
if use_kg:
|
if use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl,
|
ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
|
||||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||||
if ck["content_with_weight"]:
|
if ck["content_with_weight"]:
|
||||||
ranks["chunks"].insert(0, ck)
|
ranks["chunks"].insert(0, ck)
|
||||||
|
|
@ -1028,6 +1031,9 @@ async def retrieval_test_embedded():
|
||||||
ranks["labels"] = labels
|
ranks["labels"] = labels
|
||||||
|
|
||||||
return get_json_result(data=ranks)
|
return get_json_result(data=ranks)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(_retrieval_sync)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
|
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue