diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4357bf982..2d0804c12 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -95,6 +95,38 @@ jobs: version: ">=0.11.x" args: "check" + - name: Check comments of changed Python files + if: ${{ !cancelled() && !failure() }} + run: | + if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then + CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \ + | grep -E '\.(py)$' || true) + + if [ -n "$CHANGED_FILES" ]; then + echo "Check comments of changed Python files with check_comment_ascii.py" + + readarray -t files <<< "$CHANGED_FILES" + HAS_ERROR=0 + + for file in "${files[@]}"; do + if [ -f "$file" ]; then + if python3 check_comment_ascii.py $file"; then + echo "✅ $file" + else + echo "❌ $file" + HAS_ERROR=1 + fi + fi + done + + if [ $HAS_ERROR -ne 0 ]; then + exit 1 + fi + else + echo "No Python files changed" + fi + fi + - name: Build ragflow:nightly run: | RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}} diff --git a/README.md b/README.md index 299bd67fd..297595c58 100644 --- a/README.md +++ b/README.md @@ -192,9 +192,10 @@ releases! 🌟 ```bash $ cd ragflow/docker - + # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0 - + # This steps ensures the **entrypoint.sh** file in the code matches the Docker image version. + # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d diff --git a/README_id.md b/README_id.md index c9017ddd1..b5230c8bc 100644 --- a/README_id.md +++ b/README_id.md @@ -192,6 +192,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io). $ cd ragflow/docker # Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases), contoh: git checkout v0.22.0 + # This steps ensures the **entrypoint.sh** file in the code matches the Docker image version. # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d diff --git a/README_ja.md b/README_ja.md index 24bce0874..96dc661e3 100644 --- a/README_ja.md +++ b/README_ja.md @@ -172,6 +172,7 @@ $ cd ragflow/docker # 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) 例: git checkout v0.22.0 + # この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。 # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d diff --git a/README_ko.md b/README_ko.md index bd5acf82d..51f4169ff 100644 --- a/README_ko.md +++ b/README_ko.md @@ -174,6 +174,7 @@ $ cd ragflow/docker # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0 + # 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다. # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d diff --git a/README_pt_br.md b/README_pt_br.md index 0769ea5e5..5d4d39e5e 100644 --- a/README_pt_br.md +++ b/README_pt_br.md @@ -192,6 +192,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io). $ cd ragflow/docker # Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases), ex.: git checkout v0.22.0 + # Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker. # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d diff --git a/README_tzh.md b/README_tzh.md index a78896453..57f2e8196 100644 --- a/README_tzh.md +++ b/README_tzh.md @@ -191,6 +191,7 @@ $ cd ragflow/docker # 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases),例:git checkout v0.22.0 + # 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。 # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d diff --git a/README_zh.md b/README_zh.md index c70073a3e..91004f549 100644 --- a/README_zh.md +++ b/README_zh.md @@ -192,6 +192,7 @@ $ cd ragflow/docker # 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases),例如:git checkout v0.22.0 + # 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。 # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 98dfbc92f..a27504139 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -368,11 +368,19 @@ Respond immediately with your final comprehensive answer. return "Error occurred." - def reset(self, temp=False): + def reset(self, only_output=False): """ Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession. """ + for k in self._param.outputs.keys(): + self._param.outputs[k]["value"] = None + for k, cpn in self.tools.items(): if hasattr(cpn, "reset") and callable(cpn.reset): cpn.reset() + if only_output: + return + for k in self._param.inputs.keys(): + self._param.inputs[k]["value"] = None + self._param.debug_inputs = {} diff --git a/agent/component/llm.py b/agent/component/llm.py index 6ce0f65a5..b08e0591e 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -222,7 +222,7 @@ class LLM(ComponentBase): output_structure = self._param.outputs['structured'] except Exception: pass - if output_structure: + if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"): schema=json.dumps(output_structure, ensure_ascii=False, indent=2) prompt += structured_output_prompt(schema) for _ in range(self._param.max_retries+1): diff --git a/api/apps/__init__.py b/api/apps/__init__.py index f2009db2c..e6249a443 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -96,12 +96,12 @@ login_manager.init_app(app) commands.register_commands(app) -def search_pages_path(pages_dir): +def search_pages_path(page_path): app_path_list = [ - path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".") + path for path in page_path.glob("*_app.py") if not path.name.startswith(".") ] api_path_list = [ - path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".") + path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".") ] app_path_list.extend(api_path_list) return app_path_list @@ -138,7 +138,7 @@ pages_dir = [ ] client_urls_prefix = [ - register_page(path) for dir in pages_dir for path in search_pages_path(dir) + register_page(path) for directory in pages_dir for path in search_pages_path(directory) ] @@ -177,5 +177,7 @@ def load_user(web_request): @app.teardown_request -def _db_close(exc): +def _db_close(exception): + if exception: + logging.exception(f"Request failed: {exception}") close_connection() diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 1ab1c462a..1c9a78239 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -13,41 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import json -import os -import re from datetime import datetime, timedelta -from flask import request, Response -from api.db.services.llm_service import LLMBundle +from flask import request from flask_login import login_required, current_user - -from api.db import VALID_FILE_TYPES, FileType -from api.db.db_models import APIToken, Task, File -from api.db.services import duplicate_name +from api.db.db_models import APIToken from api.db.services.api_service import APITokenService, API4ConversationService -from api.db.services.dialog_service import DialogService, chat -from api.db.services.document_service import DocumentService, doc_upload_and_parse -from api.db.services.file2document_service import File2DocumentService -from api.db.services.file_service import FileService -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.task_service import queue_tasks, TaskService from api.db.services.user_service import UserTenantService -from common.misc_utils import get_uuid -from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \ generate_confirmation_token - -from api.utils.file_utils import filename_type, thumbnail -from rag.app.tag import label_question -from rag.prompts.generator import keyword_extraction from common.time_utils import current_timestamp, datetime_format -from api.db.services.canvas_service import UserCanvasService -from agent.canvas import Canvas -from functools import partial -from pathlib import Path -from common import settings - @manager.route('/new_token', methods=['POST']) # noqa: F821 @login_required @@ -138,758 +113,3 @@ def stats(): except Exception as e: return server_error_response(e) - -@manager.route('/new_conversation', methods=['GET']) # noqa: F821 -def set_conversation(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - try: - if objs[0].source == "agent": - e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id) - if not e: - return server_error_response("canvas not found.") - if not isinstance(cvs.dsl, str): - cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) - canvas = Canvas(cvs.dsl, objs[0].tenant_id) - conv = { - "id": get_uuid(), - "dialog_id": cvs.id, - "user_id": request.args.get("user_id", ""), - "message": [{"role": "assistant", "content": canvas.get_prologue()}], - "source": "agent" - } - API4ConversationService.save(**conv) - return get_json_result(data=conv) - else: - e, dia = DialogService.get_by_id(objs[0].dialog_id) - if not e: - return get_data_error_result(message="Dialog not found") - conv = { - "id": get_uuid(), - "dialog_id": dia.id, - "user_id": request.args.get("user_id", ""), - "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}] - } - API4ConversationService.save(**conv) - return get_json_result(data=conv) - except Exception as e: - return server_error_response(e) - - -@manager.route('/completion', methods=['POST']) # noqa: F821 -@validate_request("conversation_id", "messages") -def completion(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - req = request.json - e, conv = API4ConversationService.get_by_id(req["conversation_id"]) - if not e: - return get_data_error_result(message="Conversation not found!") - if "quote" not in req: - req["quote"] = False - - msg = [] - for m in req["messages"]: - if m["role"] == "system": - continue - if m["role"] == "assistant" and not msg: - continue - msg.append(m) - if not msg[-1].get("id"): - msg[-1]["id"] = get_uuid() - message_id = msg[-1]["id"] - - def fillin_conv(ans): - nonlocal conv, message_id - if not conv.reference: - conv.reference.append(ans["reference"]) - else: - conv.reference[-1] = ans["reference"] - conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id} - ans["id"] = message_id - - def rename_field(ans): - reference = ans['reference'] - if not isinstance(reference, dict): - return - for chunk_i in reference.get('chunks', []): - if 'docnm_kwd' in chunk_i: - chunk_i['doc_name'] = chunk_i['docnm_kwd'] - chunk_i.pop('docnm_kwd') - - try: - if conv.source == "agent": - stream = req.get("stream", True) - conv.message.append(msg[-1]) - e, cvs = UserCanvasService.get_by_id(conv.dialog_id) - if not e: - return server_error_response("canvas not found.") - del req["conversation_id"] - del req["messages"] - - if not isinstance(cvs.dsl, str): - cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) - - if not conv.reference: - conv.reference = [] - conv.message.append({"role": "assistant", "content": "", "id": message_id}) - conv.reference.append({"chunks": [], "doc_aggs": []}) - - final_ans = {"reference": [], "content": ""} - canvas = Canvas(cvs.dsl, objs[0].tenant_id) - - canvas.messages.append(msg[-1]) - canvas.add_user_input(msg[-1]["content"]) - answer = canvas.run(stream=stream) - - assert answer is not None, "Nothing. Is it over?" - - if stream: - assert isinstance(answer, partial), "Nothing. Is it over?" - - def sse(): - nonlocal answer, cvs, conv - try: - for ans in answer(): - for k in ans.keys(): - final_ans[k] = ans[k] - ans = {"answer": ans["content"], "reference": ans.get("reference", [])} - fillin_conv(ans) - rename_field(ans) - yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, - ensure_ascii=False) + "\n\n" - - canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) - canvas.history.append(("assistant", final_ans["content"])) - if final_ans.get("reference"): - canvas.reference.append(final_ans["reference"]) - cvs.dsl = json.loads(str(canvas)) - API4ConversationService.append_message(conv.id, conv.to_dict()) - except Exception as e: - yield "data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - - resp = Response(sse(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - - final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" - canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) - if final_ans.get("reference"): - canvas.reference.append(final_ans["reference"]) - cvs.dsl = json.loads(str(canvas)) - - result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])} - fillin_conv(result) - API4ConversationService.append_message(conv.id, conv.to_dict()) - rename_field(result) - return get_json_result(data=result) - - # ******************For dialog****************** - conv.message.append(msg[-1]) - e, dia = DialogService.get_by_id(conv.dialog_id) - if not e: - return get_data_error_result(message="Dialog not found!") - del req["conversation_id"] - del req["messages"] - - if not conv.reference: - conv.reference = [] - conv.message.append({"role": "assistant", "content": "", "id": message_id}) - conv.reference.append({"chunks": [], "doc_aggs": []}) - - def stream(): - nonlocal dia, msg, req, conv - try: - for ans in chat(dia, msg, True, **req): - fillin_conv(ans) - rename_field(ans) - yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, - ensure_ascii=False) + "\n\n" - API4ConversationService.append_message(conv.id, conv.to_dict()) - except Exception as e: - yield "data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - - if req.get("stream", True): - resp = Response(stream(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - - answer = None - for ans in chat(dia, msg, **req): - answer = ans - fillin_conv(ans) - API4ConversationService.append_message(conv.id, conv.to_dict()) - break - rename_field(answer) - return get_json_result(data=answer) - - except Exception as e: - return server_error_response(e) - - -@manager.route('/conversation/', methods=['GET']) # noqa: F821 -# @login_required -def get_conversation(conversation_id): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - try: - e, conv = API4ConversationService.get_by_id(conversation_id) - if not e: - return get_data_error_result(message="Conversation not found!") - - conv = conv.to_dict() - if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token: - return get_json_result(data=False, message='Authentication error: API key is invalid for this conversation_id!"', - code=RetCode.AUTHENTICATION_ERROR) - - for referenct_i in conv['reference']: - if referenct_i is None or len(referenct_i) == 0: - continue - for chunk_i in referenct_i['chunks']: - if 'docnm_kwd' in chunk_i.keys(): - chunk_i['doc_name'] = chunk_i['docnm_kwd'] - chunk_i.pop('docnm_kwd') - return get_json_result(data=conv) - except Exception as e: - return server_error_response(e) - - -@manager.route('/document/upload', methods=['POST']) # noqa: F821 -@validate_request("kb_name") -def upload(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - kb_name = request.form.get("kb_name").strip() - tenant_id = objs[0].tenant_id - - try: - e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id) - if not e: - return get_data_error_result( - message="Can't find this knowledgebase!") - kb_id = kb.id - except Exception as e: - return server_error_response(e) - - if 'file' not in request.files: - return get_json_result( - data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) - - file = request.files['file'] - if file.filename == '': - return get_json_result( - data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) - - root_folder = FileService.get_root_folder(tenant_id) - pf_id = root_folder["id"] - FileService.init_knowledgebase_docs(pf_id, tenant_id) - kb_root_folder = FileService.get_kb_folder(tenant_id) - kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) - - try: - if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)): - return get_data_error_result( - message="Exceed the maximum file number of a free user!") - - filename = duplicate_name( - DocumentService.query, - name=file.filename, - kb_id=kb_id) - filetype = filename_type(filename) - if not filetype: - return get_data_error_result( - message="This type of file has not been supported yet!") - - location = filename - while settings.STORAGE_IMPL.obj_exist(kb_id, location): - location += "_" - blob = request.files['file'].read() - settings.STORAGE_IMPL.put(kb_id, location, blob) - doc = { - "id": get_uuid(), - "kb_id": kb.id, - "parser_id": kb.parser_id, - "parser_config": kb.parser_config, - "created_by": kb.tenant_id, - "type": filetype, - "name": filename, - "location": location, - "size": len(blob), - "thumbnail": thumbnail(filename, blob), - "suffix": Path(filename).suffix.lstrip("."), - } - - form_data = request.form - if "parser_id" in form_data.keys(): - if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]: - doc["parser_id"] = request.form.get("parser_id").strip() - if doc["type"] == FileType.VISUAL: - doc["parser_id"] = ParserType.PICTURE.value - if doc["type"] == FileType.AURAL: - doc["parser_id"] = ParserType.AUDIO.value - if re.search(r"\.(ppt|pptx|pages)$", filename): - doc["parser_id"] = ParserType.PRESENTATION.value - if re.search(r"\.(eml)$", filename): - doc["parser_id"] = ParserType.EMAIL.value - - doc_result = DocumentService.insert(doc) - FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id) - except Exception as e: - return server_error_response(e) - - if "run" in form_data.keys(): - if request.form.get("run").strip() == "1": - try: - info = {"run": 1, "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0} - DocumentService.update_by_id(doc["id"], info) - # if str(req["run"]) == TaskStatus.CANCEL.value: - tenant_id = DocumentService.get_tenant_id(doc["id"]) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") - - # e, doc = DocumentService.get_by_id(doc["id"]) - TaskService.filter_delete([Task.doc_id == doc["id"]]) - e, doc = DocumentService.get_by_id(doc["id"]) - doc = doc.to_dict() - doc["tenant_id"] = tenant_id - bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"]) - queue_tasks(doc, bucket, name, 0) - except Exception as e: - return server_error_response(e) - - return get_json_result(data=doc_result.to_json()) - - -@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821 -@validate_request("conversation_id") -def upload_parse(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - if 'file' not in request.files: - return get_json_result( - data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) - - file_objs = request.files.getlist('file') - for file_obj in file_objs: - if file_obj.filename == '': - return get_json_result( - data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) - - doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id) - return get_json_result(data=doc_ids) - - -@manager.route('/list_chunks', methods=['POST']) # noqa: F821 -# @login_required -def list_chunks(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - req = request.json - - try: - if "doc_name" in req.keys(): - tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name']) - doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name']) - - elif "doc_id" in req.keys(): - tenant_id = DocumentService.get_tenant_id(req['doc_id']) - doc_id = req['doc_id'] - else: - return get_json_result( - data=False, message="Can't find doc_name or doc_id" - ) - kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) - - res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids) - res = [ - { - "content": res_item["content_with_weight"], - "doc_name": res_item["docnm_kwd"], - "image_id": res_item["img_id"] - } for res_item in res - ] - - except Exception as e: - return server_error_response(e) - - return get_json_result(data=res) - -@manager.route('/get_chunk/', methods=['GET']) # noqa: F821 -# @login_required -def get_chunk(chunk_id): - from rag.nlp import search - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - try: - tenant_id = objs[0].tenant_id - kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) - chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids) - if chunk is None: - return server_error_response(Exception("Chunk not found")) - k = [] - for n in chunk.keys(): - if re.search(r"(_vec$|_sm_|_tks|_ltks)", n): - k.append(n) - for n in k: - del chunk[n] - - return get_json_result(data=chunk) - except Exception as e: - return server_error_response(e) - -@manager.route('/list_kb_docs', methods=['POST']) # noqa: F821 -# @login_required -def list_kb_docs(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - req = request.json - tenant_id = objs[0].tenant_id - kb_name = req.get("kb_name", "").strip() - - try: - e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id) - if not e: - return get_data_error_result( - message="Can't find this knowledgebase!") - kb_id = kb.id - - except Exception as e: - return server_error_response(e) - - page_number = int(req.get("page", 1)) - items_per_page = int(req.get("page_size", 15)) - orderby = req.get("orderby", "create_time") - desc = req.get("desc", True) - keywords = req.get("keywords", "") - status = req.get("status", []) - if status: - invalid_status = {s for s in status if s not in VALID_TASK_STATUS} - if invalid_status: - return get_data_error_result( - message=f"Invalid filter status conditions: {', '.join(invalid_status)}" - ) - types = req.get("types", []) - if types: - invalid_types = {t for t in types if t not in VALID_FILE_TYPES} - if invalid_types: - return get_data_error_result( - message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}" - ) - try: - docs, tol = DocumentService.get_by_kb_id( - kb_id, page_number, items_per_page, orderby, desc, keywords, status, types) - docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs] - - return get_json_result(data={"total": tol, "docs": docs}) - - except Exception as e: - return server_error_response(e) - - -@manager.route('/document/infos', methods=['POST']) # noqa: F821 -@validate_request("doc_ids") -def docinfos(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - req = request.json - doc_ids = req["doc_ids"] - docs = DocumentService.get_by_ids(doc_ids) - return get_json_result(data=list(docs.dicts())) - - -@manager.route('/document', methods=['DELETE']) # noqa: F821 -# @login_required -def document_rm(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - tenant_id = objs[0].tenant_id - req = request.json - try: - doc_ids = DocumentService.get_doc_ids_by_doc_names(req.get("doc_names", [])) - for doc_id in req.get("doc_ids", []): - if doc_id not in doc_ids: - doc_ids.append(doc_id) - - if not doc_ids: - return get_json_result( - data=False, message="Can't find doc_names or doc_ids" - ) - - except Exception as e: - return server_error_response(e) - - root_folder = FileService.get_root_folder(tenant_id) - pf_id = root_folder["id"] - FileService.init_knowledgebase_docs(pf_id, tenant_id) - - errors = "" - docs = DocumentService.get_by_ids(doc_ids) - doc_dic = {} - for doc in docs: - doc_dic[doc.id] = doc - - for doc_id in doc_ids: - try: - if doc_id not in doc_dic: - return get_data_error_result(message="Document not found!") - doc = doc_dic[doc_id] - tenant_id = DocumentService.get_tenant_id(doc_id) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") - - b, n = File2DocumentService.get_storage_address(doc_id=doc_id) - - if not DocumentService.remove_document(doc, tenant_id): - return get_data_error_result( - message="Database error (Document removal)!") - - f2d = File2DocumentService.get_by_document_id(doc_id) - FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) - File2DocumentService.delete_by_document_id(doc_id) - - settings.STORAGE_IMPL.rm(b, n) - except Exception as e: - errors += str(e) - - if errors: - return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR) - - return get_json_result(data=True) - - -@manager.route('/completion_aibotk', methods=['POST']) # noqa: F821 -@validate_request("Authorization", "conversation_id", "word") -def completion_faq(): - import base64 - req = request.json - - token = req["Authorization"] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - e, conv = API4ConversationService.get_by_id(req["conversation_id"]) - if not e: - return get_data_error_result(message="Conversation not found!") - if "quote" not in req: - req["quote"] = True - - msg = [{"role": "user", "content": req["word"]}] - if not msg[-1].get("id"): - msg[-1]["id"] = get_uuid() - message_id = msg[-1]["id"] - - def fillin_conv(ans): - nonlocal conv, message_id - if not conv.reference: - conv.reference.append(ans["reference"]) - else: - conv.reference[-1] = ans["reference"] - conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id} - ans["id"] = message_id - - try: - if conv.source == "agent": - conv.message.append(msg[-1]) - e, cvs = UserCanvasService.get_by_id(conv.dialog_id) - if not e: - return server_error_response("canvas not found.") - - if not isinstance(cvs.dsl, str): - cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) - - if not conv.reference: - conv.reference = [] - conv.message.append({"role": "assistant", "content": "", "id": message_id}) - conv.reference.append({"chunks": [], "doc_aggs": []}) - - final_ans = {"reference": [], "doc_aggs": []} - canvas = Canvas(cvs.dsl, objs[0].tenant_id) - - canvas.messages.append(msg[-1]) - canvas.add_user_input(msg[-1]["content"]) - answer = canvas.run(stream=False) - - assert answer is not None, "Nothing. Is it over?" - - data_type_picture = { - "type": 3, - "url": "base64 content" - } - data = [ - { - "type": 1, - "content": "" - } - ] - final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" - canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) - if final_ans.get("reference"): - canvas.reference.append(final_ans["reference"]) - cvs.dsl = json.loads(str(canvas)) - - ans = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])} - data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"]) - fillin_conv(ans) - API4ConversationService.append_message(conv.id, conv.to_dict()) - - chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])] - for chunk_idx in chunk_idxs[:1]: - if ans["reference"]["chunks"][chunk_idx]["img_id"]: - try: - bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-") - response = settings.STORAGE_IMPL.get(bkt, nm) - data_type_picture["url"] = base64.b64encode(response).decode('utf-8') - data.append(data_type_picture) - break - except Exception as e: - return server_error_response(e) - - response = {"code": 200, "msg": "success", "data": data} - return response - - # ******************For dialog****************** - conv.message.append(msg[-1]) - e, dia = DialogService.get_by_id(conv.dialog_id) - if not e: - return get_data_error_result(message="Dialog not found!") - del req["conversation_id"] - - if not conv.reference: - conv.reference = [] - conv.message.append({"role": "assistant", "content": "", "id": message_id}) - conv.reference.append({"chunks": [], "doc_aggs": []}) - - data_type_picture = { - "type": 3, - "url": "base64 content" - } - data = [ - { - "type": 1, - "content": "" - } - ] - ans = "" - for a in chat(dia, msg, stream=False, **req): - ans = a - break - data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"]) - fillin_conv(ans) - API4ConversationService.append_message(conv.id, conv.to_dict()) - - chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])] - for chunk_idx in chunk_idxs[:1]: - if ans["reference"]["chunks"][chunk_idx]["img_id"]: - try: - bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-") - response = settings.STORAGE_IMPL.get(bkt, nm) - data_type_picture["url"] = base64.b64encode(response).decode('utf-8') - data.append(data_type_picture) - break - except Exception as e: - return server_error_response(e) - - response = {"code": 200, "msg": "success", "data": data} - return response - - except Exception as e: - return server_error_response(e) - - -@manager.route('/retrieval', methods=['POST']) # noqa: F821 -@validate_request("kb_id", "question") -def retrieval(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - req = request.json - kb_ids = req.get("kb_id", []) - doc_ids = req.get("doc_ids", []) - question = req.get("question") - page = int(req.get("page", 1)) - size = int(req.get("page_size", 30)) - similarity_threshold = float(req.get("similarity_threshold", 0.2)) - vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) - top = int(req.get("top_k", 1024)) - highlight = bool(req.get("highlight", False)) - - try: - kbs = KnowledgebaseService.get_by_ids(kb_ids) - embd_nms = list(set([kb.embd_id for kb in kbs])) - if len(embd_nms) != 1: - return get_json_result( - data=False, message='Knowledge bases use different embedding models or does not exist."', - code=RetCode.AUTHENTICATION_ERROR) - - embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id) - rerank_mdl = None - if req.get("rerank_id"): - rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"]) - if req.get("keyword", False): - chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT) - question += keyword_extraction(chat_mdl, question) - ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size, - similarity_threshold, vector_similarity_weight, top, - doc_ids, rerank_mdl=rerank_mdl, highlight= highlight, - rank_feature=label_question(question, kbs)) - for c in ranks["chunks"]: - c.pop("vector", None) - return get_json_result(data=ranks) - except Exception as e: - if str(e).find("not_found") > 0: - return get_json_result(data=False, message='No chunk found! Check the chunk status please!', - code=RetCode.DATA_ERROR) - return server_error_response(e) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 0ac2951ae..bc0ea8b80 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -426,7 +426,6 @@ def test_db_connect(): try: import trino import os - from trino.auth import BasicAuthentication except Exception as e: return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}") @@ -438,7 +437,7 @@ def test_db_connect(): auth = None if http_scheme == "https" and req.get("password"): - auth = BasicAuthentication(req.get("username") or "ragflow", req["password"]) + auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"]) conn = trino.dbapi.connect( host=req["host"], @@ -471,8 +470,8 @@ def test_db_connect(): @login_required def getlistversion(canvas_id): try: - list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1) - return get_json_result(data=list) + versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1) + return get_json_result(data=versions) except Exception as e: return get_data_error_result(message=f"Error getting history files: {e}") diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py index 23965e617..80f791c93 100644 --- a/api/apps/connector_app.py +++ b/api/apps/connector_app.py @@ -55,7 +55,6 @@ def set_connector(): "timeout_secs": int(req.get("timeout_secs", 60 * 29)), "status": TaskStatus.SCHEDULE, } - conn["status"] = TaskStatus.SCHEDULE ConnectorService.save(**conn) time.sleep(1) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 984e57cac..d0465252a 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -85,7 +85,6 @@ def get(): if not e: return get_data_error_result(message="Conversation not found!") tenants = UserTenantService.query(user_id=current_user.id) - avatar = None for tenant in tenants: dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id) if dialog and len(dialog) > 0: diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index 99f700568..82c78ffed 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -154,15 +154,15 @@ def get_kb_names(kb_ids): @login_required def list_dialogs(): try: - diags = DialogService.query( + conversations = DialogService.query( tenant_id=current_user.id, status=StatusEnum.VALID.value, reverse=True, order_by=DialogService.model.create_time) - diags = [d.to_dict() for d in diags] - for d in diags: - d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"]) - return get_json_result(data=diags) + conversations = [d.to_dict() for d in conversations] + for conversation in conversations: + conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"]) + return get_json_result(data=conversations) except Exception as e: return server_error_response(e) diff --git a/api/apps/document_app.py b/api/apps/document_app.py index c2e37598e..12c19f978 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -308,7 +308,7 @@ def get_filter(): @manager.route("/infos", methods=["POST"]) # noqa: F821 @login_required -def docinfos(): +def doc_infos(): req = request.json doc_ids = req["doc_ids"] for doc_id in doc_ids: @@ -544,6 +544,7 @@ def change_parser(): return get_data_error_result(message="Tenant not found!") if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) + return None try: if "pipeline_id" in req and req["pipeline_id"] != "": diff --git a/api/apps/file_app.py b/api/apps/file_app.py index 279e32525..7daff6ed7 100644 --- a/api/apps/file_app.py +++ b/api/apps/file_app.py @@ -246,8 +246,8 @@ def rm(): try: if file.location: settings.STORAGE_IMPL.rm(file.parent_id, file.location) - except Exception: - logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}") + except Exception as e: + logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}") informs = File2DocumentService.get_by_file_id(file.id) for inform in informs: diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 7094c28d7..b7cf58a20 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -16,6 +16,7 @@ import json import logging import random +import re from flask import request from flask_login import login_required, current_user @@ -731,6 +732,8 @@ def delete_kb_task(): def cancel_task(task_id): REDIS_CONN.set(f"{task_id}-cancel", "x") + kb_task_id_field: str = "" + kb_task_finish_at: str = "" match pipeline_task_type: case PipelineTaskType.GRAPH_RAG: kb_task_id_field = "graphrag_task_id" @@ -807,7 +810,7 @@ def check_embedding(): offset=0, limit=1, indexNames=index_nm, knowledgebaseIds=[kb_id] ) - total = docStoreConn.getTotal(res0) + total = docStoreConn.get_total(res0) if total <= 0: return [] @@ -824,7 +827,7 @@ def check_embedding(): offset=off, limit=1, indexNames=index_nm, knowledgebaseIds=[kb_id] ) - ids = docStoreConn.getChunkIds(res1) + ids = docStoreConn.get_chunk_ids(res1) if not ids: continue @@ -845,8 +848,13 @@ def check_embedding(): "position_int": full_doc.get("position_int"), "top_int": full_doc.get("top_int"), "content_with_weight": full_doc.get("content_with_weight") or "", + "question_kwd": full_doc.get("question_kwd") or [] }) return out + + def _clean(s: str) -> str: + s = re.sub(r"]{0,12})?>", " ", s or "") + return s if s else "None" req = request.json kb_id = req.get("kb_id", "") embd_id = req.get("embd_id", "") @@ -859,8 +867,10 @@ def check_embedding(): results, eff_sims = [], [] for ck in samples: - txt = (ck.get("content_with_weight") or "").strip() - if not txt: + title = ck.get("doc_name") or "Title" + txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or "" + txt_in = _clean(txt_in) + if not txt_in: results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"}) continue @@ -869,8 +879,16 @@ def check_embedding(): continue try: - qv, _ = emb_mdl.encode_queries(txt) - sim = _cos_sim(qv, ck["vector"]) + v, _ = emb_mdl.encode([title, txt_in]) + sim_content = _cos_sim(v[1], ck["vector"]) + title_w = 0.1 + qv_mix = title_w * v[0] + (1 - title_w) * v[1] + sim_mix = _cos_sim(qv_mix, ck["vector"]) + sim = sim_content + mode = "content_only" + if sim_mix > sim: + sim = sim_mix + mode = "title+content" except Exception: return get_error_data_result(message="embedding failure") @@ -892,8 +910,9 @@ def check_embedding(): "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6), "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6), "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6), + "match_mode": mode, } - if summary["avg_cos_sim"] > 0.99: + if summary["avg_cos_sim"] > 0.9: return get_json_result(data={"summary": summary, "results": results}) return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results}) diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 8a315ce69..de5434de7 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -21,10 +21,11 @@ import json from flask import request from peewee import OperationalError from api.db.db_models import File -from api.db.services.document_service import DocumentService +from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService from api.db.services.user_service import TenantService from common.constants import RetCode, FileSource, StatusEnum from api.utils.api_utils import ( @@ -118,7 +119,6 @@ def create(tenant_id): req, err = validate_and_parse_json_request(request, CreateDatasetReq) if err is not None: return get_error_argument_result(err) - req = KnowledgebaseService.create_with_name( name = req.pop("name", None), tenant_id = tenant_id, @@ -144,7 +144,6 @@ def create(tenant_id): ok, k = KnowledgebaseService.get_by_id(req["id"]) if not ok: return get_error_data_result(message="Dataset created failed") - response_data = remap_dictionary_keys(k.to_dict()) return get_result(data=response_data) except Exception as e: @@ -532,3 +531,157 @@ def delete_knowledge_graph(tenant_id, dataset_id): search.index_name(kb.tenant_id), dataset_id) return get_result(data=True) + + +@manager.route("/datasets//run_graphrag", methods=["POST"]) # noqa: F821 +@token_required +def run_graphrag(tenant_id,dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result( + data=False, + message='No authorization.', + code=RetCode.AUTHENTICATION_ERROR + ) + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.graphrag_task_id + if task_id: + ok, task = TaskService.get_by_id(task_id) + if not ok: + logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}") + + if task and task.progress not in [-1, 1]: + return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.") + + documents, _ = DocumentService.get_by_kb_id( + kb_id=dataset_id, + page_number=0, + items_per_page=0, + orderby="create_time", + desc=False, + keywords="", + run_status=[], + types=[], + suffix=[], + ) + if not documents: + return get_error_data_result(message=f"No documents in Dataset {dataset_id}") + + sample_document = documents[0] + document_ids = [document["id"] for document in documents] + + task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) + + if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}): + logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}") + + return get_result(data={"graphrag_task_id": task_id}) + + +@manager.route("/datasets//trace_graphrag", methods=["GET"]) # noqa: F821 +@token_required +def trace_graphrag(tenant_id,dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result( + data=False, + message='No authorization.', + code=RetCode.AUTHENTICATION_ERROR + ) + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.graphrag_task_id + if not task_id: + return get_result(data={}) + + ok, task = TaskService.get_by_id(task_id) + if not ok: + return get_result(data={}) + + return get_result(data=task.to_dict()) + + +@manager.route("/datasets//run_raptor", methods=["POST"]) # noqa: F821 +@token_required +def run_raptor(tenant_id,dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result( + data=False, + message='No authorization.', + code=RetCode.AUTHENTICATION_ERROR + ) + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.raptor_task_id + if task_id: + ok, task = TaskService.get_by_id(task_id) + if not ok: + logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}") + + if task and task.progress not in [-1, 1]: + return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.") + + documents, _ = DocumentService.get_by_kb_id( + kb_id=dataset_id, + page_number=0, + items_per_page=0, + orderby="create_time", + desc=False, + keywords="", + run_status=[], + types=[], + suffix=[], + ) + if not documents: + return get_error_data_result(message=f"No documents in Dataset {dataset_id}") + + sample_document = documents[0] + document_ids = [document["id"] for document in documents] + + task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) + + if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}): + logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}") + + return get_result(data={"raptor_task_id": task_id}) + + +@manager.route("/datasets//trace_raptor", methods=["GET"]) # noqa: F821 +@token_required +def trace_raptor(tenant_id,dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result( + data=False, + message='No authorization.', + code=RetCode.AUTHENTICATION_ERROR + ) + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.raptor_task_id + if not task_id: + return get_result(data={}) + + ok, task = TaskService.get_by_id(task_id) + if not ok: + return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred") + + return get_result(data=task.to_dict()) \ No newline at end of file diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 4caf2cc8d..b54597f89 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -93,6 +93,10 @@ def upload(dataset_id, tenant_id): type: file required: true description: Document files to upload. + - in: formData + name: parent_path + type: string + description: Optional nested path under the parent folder. Uses '/' separators. responses: 200: description: Successfully uploaded documents. @@ -151,7 +155,7 @@ def upload(dataset_id, tenant_id): e, kb = KnowledgebaseService.get_by_id(dataset_id) if not e: raise LookupError(f"Can't find the dataset with ID {dataset_id}!") - err, files = FileService.upload_document(kb, file_objs, tenant_id) + err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=request.form.get("parent_path")) if err: return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR) # rename key's name diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py index 2ff16669d..3e65c87da 100644 --- a/api/db/services/connector_service.py +++ b/api/db/services/connector_service.py @@ -242,7 +242,7 @@ class Connector2KbService(CommonService): "id": get_uuid(), "connector_id": conn_id, "kb_id": kb_id, - "auto_parse": conn.get("auto_parse", "1") + "auto_parse": conn.get("auto_parse", "1") }) SyncLogsService.schedule(conn_id, kb_id, reindex=True) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index a64ae16de..530133164 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -309,7 +309,7 @@ class DocumentService(CommonService): chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), page * page_size, page_size, search.index_name(tenant_id), [doc.kb_id]) - chunk_ids = settings.docStoreConn.getChunkIds(chunks) + chunk_ids = settings.docStoreConn.get_chunk_ids(chunks) if not chunk_ids: break all_chunk_ids.extend(chunk_ids) @@ -322,7 +322,7 @@ class DocumentService(CommonService): settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail) settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) - graph_source = settings.docStoreConn.getFields( + graph_source = settings.docStoreConn.get_fields( settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"] ) if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]: diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 5a3632e97..2cf4931d0 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -31,7 +31,7 @@ from common.misc_utils import get_uuid from common.constants import TaskStatus, FileSource, ParserType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.task_service import TaskService -from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img +from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path from rag.llm.cv_model import GptV4 from common import settings @@ -329,7 +329,7 @@ class FileService(CommonService): current_id = start_id while current_id: e, file = cls.get_by_id(current_id) - if file.parent_id != file.id and e: + if e and file.parent_id != file.id: parent_folders.append(file) current_id = file.parent_id else: @@ -423,13 +423,15 @@ class FileService(CommonService): @classmethod @DB.connection_context() - def upload_document(self, kb, file_objs, user_id, src="local"): + def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str | None = None): root_folder = self.get_root_folder(user_id) pf_id = root_folder["id"] self.init_knowledgebase_docs(pf_id, user_id) kb_root_folder = self.get_kb_folder(user_id) kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) + safe_parent_path = sanitize_path(parent_path) + err, files = [], [] for file in file_objs: try: @@ -439,7 +441,7 @@ class FileService(CommonService): if filetype == FileType.OTHER.value: raise RuntimeError("This type of file has not been supported yet!") - location = filename + location = filename if not safe_parent_path else f"{safe_parent_path}/{filename}" while settings.STORAGE_IMPL.obj_exist(kb.id, location): location += "_" diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index 5f0fa70f4..e67ddd82d 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -164,3 +164,23 @@ def read_potential_broken_pdf(blob): return repaired return blob + + +def sanitize_path(raw_path: str | None) -> str: + """Normalize and sanitize a user-provided path segment. + + - Converts backslashes to forward slashes + - Strips leading/trailing slashes + - Removes '.' and '..' segments + - Restricts characters to A-Za-z0-9, underscore, dash, and '/' + """ + if not raw_path: + return "" + backslash_re = re.compile(r"[\\]+") + unsafe_re = re.compile(r"[^A-Za-z0-9_\-/]") + normalized = backslash_re.sub("/", raw_path) + normalized = normalized.strip("/") + parts = [seg for seg in normalized.split("/") if seg and seg not in (".", "..")] + sanitized = "/".join(parts) + sanitized = unsafe_re.sub("", sanitized) + return sanitized diff --git a/check_comment_ascii.py b/check_comment_ascii.py new file mode 100644 index 000000000..49cac90d7 --- /dev/null +++ b/check_comment_ascii.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +import sys +import tokenize +import ast +import pathlib +import re + +ASCII = re.compile(r"^[ -~]*\Z") # Only printable ASCII + + +def check(src: str, name: str) -> int: + """ + I'm a docstring + """ + ok = 1 + # A common comment begins with `#` + with tokenize.open(src) as fp: + for tk in tokenize.generate_tokens(fp.readline): + if tk.type == tokenize.COMMENT and not ASCII.fullmatch(tk.string): + print(f"{name}:{tk.start[0]}: non-ASCII comment: {tk.string}") + ok = 0 + # A docstring begins and ends with `'''` + for node in ast.walk(ast.parse(pathlib.Path(src).read_text(), filename=name)): + if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)): + if (doc := ast.get_docstring(node)) and not ASCII.fullmatch(doc): + print(f"{name}:{node.lineno}: non-ASCII docstring: {doc}") + ok = 0 + return ok + + +if __name__ == "__main__": + status = 0 + for file in sys.argv[1:]: + if not check(file, file): + status = 1 + sys.exit(status) diff --git a/common/data_source/google_util/oauth_flow.py b/common/data_source/google_util/oauth_flow.py index 7e39e5283..e6ba58274 100644 --- a/common/data_source/google_util/oauth_flow.py +++ b/common/data_source/google_util/oauth_flow.py @@ -3,15 +3,9 @@ import os import threading from typing import Any, Callable -import requests - from common.data_source.config import DocumentSource from common.data_source.google_util.constant import GOOGLE_SCOPES -GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code" -GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token" -DEFAULT_DEVICE_INTERVAL = 5 - def _get_requested_scopes(source: DocumentSource) -> list[str]: """Return the scopes to request, honoring an optional override env var.""" @@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag return result.get("value") -def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]: - if "client_id" in credentials: - return credentials["client_id"], credentials.get("client_secret") - for key in ("installed", "web"): - if key in credentials and isinstance(credentials[key], dict): - nested = credentials[key] - if "client_id" not in nested: - break - return nested["client_id"], nested.get("client_secret") - raise ValueError("Provided Google OAuth credentials are missing client_id.") - - -def start_device_authorization_flow( - credentials: dict[str, Any], - source: DocumentSource, -) -> tuple[dict[str, Any], dict[str, Any]]: - client_id, client_secret = _extract_client_info(credentials) - data = { - "client_id": client_id, - "scope": " ".join(_get_requested_scopes(source)), - } - if client_secret: - data["client_secret"] = client_secret - resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15) - resp.raise_for_status() - payload = resp.json() - state = { - "client_id": client_id, - "client_secret": client_secret, - "device_code": payload.get("device_code"), - "interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL), - } - response_data = { - "user_code": payload.get("user_code"), - "verification_url": payload.get("verification_url") or payload.get("verification_uri"), - "verification_url_complete": payload.get("verification_url_complete") - or payload.get("verification_uri_complete"), - "expires_in": payload.get("expires_in"), - "interval": state["interval"], - } - return state, response_data - - -def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]: - data = { - "client_id": state["client_id"], - "device_code": state["device_code"], - "grant_type": "urn:ietf:params:oauth:grant-type:device_code", - } - if state.get("client_secret"): - data["client_secret"] = state["client_secret"] - resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20) - resp.raise_for_status() - return resp.json() - - def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]: """Launch the standard Google OAuth local-server flow to mint user tokens.""" from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore @@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT") port = int(preferred_port) if preferred_port else 0 timeout_secs = _get_oauth_timeout_secs() - timeout_message = ( - f"Google OAuth verification timed out after {timeout_secs} seconds. " - "Close any pending consent windows and rerun the connector configuration to try again." - ) + timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again." print("Launching Google OAuth flow. A browser window should open shortly.") print("If it does not, copy the URL shown in the console into your browser manually.") @@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource instructions = [ "Google rejected one or more of the requested OAuth scopes.", "Fix options:", - " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes " - " (Drive metadata + Admin Directory read scopes), then re-run the flow.", + " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.", " 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.", - " 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes " - " (be aware the connector may lose functionality).", ] raise RuntimeError("\n".join(instructions)) from warning raise @@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource) client_config = {"web": credentials["web"]} if client_config is None: - raise ValueError( - "Provided Google OAuth credentials are missing both tokens and a client configuration." - ) + raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.") return _run_local_server_flow(client_config, source) diff --git a/deepdoc/parser/docling_parser.py b/deepdoc/parser/docling_parser.py index dd0f57ea4..9d67478c8 100644 --- a/deepdoc/parser/docling_parser.py +++ b/deepdoc/parser/docling_parser.py @@ -186,9 +186,6 @@ class DoclingParser(RAGFlowPdfParser): yield (DoclingContentType.EQUATION.value, text, bbox) def _transfer_to_sections(self, doc) -> list[tuple[str, str]]: - """ - 和 MinerUParser 保持一致:返回 [(section_text, line_tag), ...] - """ sections: list[tuple[str, str]] = [] for typ, payload, bbox in self._iter_doc_items(doc): if typ == DoclingContentType.TEXT.value: diff --git a/deepdoc/parser/figure_parser.py b/deepdoc/parser/figure_parser.py index a913822c3..f659b3847 100644 --- a/deepdoc/parser/figure_parser.py +++ b/deepdoc/parser/figure_parser.py @@ -34,6 +34,7 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions): if isinstance(figure_data[1], Image.Image) ] + def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs): try: vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT) @@ -50,7 +51,8 @@ def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs): callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.") return tbls -def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs): + +def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs): try: vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT) callback(0.7, "Visual model detected. Attempting to enhance figure extraction...") @@ -72,6 +74,7 @@ def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs): callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.") return tbls + shared_executor = ThreadPoolExecutor(max_workers=10) diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index f9bea6903..207fb0e84 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -117,7 +117,6 @@ def load_model(model_dir, nm, device_id: int | None = None): providers=['CUDAExecutionProvider'], provider_options=[cuda_provider_options] ) - run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(provider_device_id)) logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})") else: sess = ort.InferenceSession( diff --git a/docs/guides/dataset/add_data_source/_category_.json b/docs/guides/dataset/add_data_source/_category_.json new file mode 100644 index 000000000..42f2b164a --- /dev/null +++ b/docs/guides/dataset/add_data_source/_category_.json @@ -0,0 +1,8 @@ +{ + "label": "Add data source", + "position": 18, + "link": { + "type": "generated-index", + "description": "Add various data sources" + } +} diff --git a/docs/guides/dataset/add_data_source/add_google_drive.md b/docs/guides/dataset/add_data_source/add_google_drive.md new file mode 100644 index 000000000..b4fdf14f4 --- /dev/null +++ b/docs/guides/dataset/add_data_source/add_google_drive.md @@ -0,0 +1,137 @@ +--- +sidebar_position: 3 +slug: /add_google_drive +--- + +# Add Google Drive + +## 1. Create a Google Cloud Project + +You can either create a dedicated project for RAGFlow or use an existing +Google Cloud external project. + +**Steps:** +1. Open the project creation page\ +`https://console.cloud.google.com/projectcreate` +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image1.jpeg?raw=true) +2. Select **External** as the Audience +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image2.png?raw=true) +3. Click **Create** +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image3.jpeg?raw=true) + +------------------------------------------------------------------------ + +## 2. Configure OAuth Consent Screen + +1. Go to **APIs & Services → OAuth consent screen** +2. Ensure **User Type = External** +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image4.jpeg?raw=true) +3. Add your test users under **Test Users** by entering email addresses +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image5.jpeg?raw=true) +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image6.jpeg?raw=true) + +------------------------------------------------------------------------ + +## 3. Create OAuth Client Credentials + +1. Navigate to:\ + `https://console.cloud.google.com/auth/clients` +2. Create a **Web Application** +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image7.png?raw=true) +3. Enter a name for the client +4. Add the following **Authorized Redirect URIs**: + +``` +http://localhost:9380/v1/connector/google-drive/oauth/web/callback +``` + +### If using Docker deployment: + +**Authorized JavaScript origin:** +``` +http://localhost:80 +``` + +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image8.png?raw=true) +### If running from source: +**Authorized JavaScript origin:** +``` +http://localhost:9222 +``` + +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image9.png?raw=true) +5. After saving, click **Download JSON**. This file will later be + uploaded into RAGFlow. + +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image10.png?raw=true) + +------------------------------------------------------------------------ + +## 4. Add Scopes + +1. Open **Data Access → Add or remove scopes** + +2. Paste and add the following entries: + +``` +https://www.googleapis.com/auth/drive.readonly +https://www.googleapis.com/auth/drive.metadata.readonly +https://www.googleapis.com/auth/admin.directory.group.readonly +https://www.googleapis.com/auth/admin.directory.user.readonly +``` + +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image11.jpeg?raw=true) +3. Update and Save changes + +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image12.jpeg?raw=true) +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image13.jpeg?raw=true) + +------------------------------------------------------------------------ + +## 5. Enable Required APIs +Navigate to the Google API Library:\ +`https://console.cloud.google.com/apis/library` +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image14.png?raw=true) + +Enable the following APIs: + +- Google Drive API +- Admin SDK API +- Google Sheets API +- Google Docs API + + +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image15.png?raw=true) + +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image16.png?raw=true) + +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image17.png?raw=true) + +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image18.png?raw=true) + +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image19.png?raw=true) + +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image21.png?raw=true) + +------------------------------------------------------------------------ + +## 6. Add Google Drive As a Data Source in RAGFlow + +1. Go to **Data Sources** inside RAGFlow +2. Select **Google Drive** +3. Upload the previously downloaded JSON credentials +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image22.jpeg?raw=true) +4. Enter the shared Google Drive folder link (https://drive.google.com/drive), such as: +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image23.png?raw=true) + +5. Click **Authorize with Google** +A browser window will appear. +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image25.jpeg?raw=true) +Click: - **Continue** - **Select All → Continue** - Authorization should +succeed - Select **OK** to add the data source +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image26.jpeg?raw=true) +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image27.jpeg?raw=true) +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image28.png?raw=true) +![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image29.png?raw=true) + + diff --git a/docs/guides/dataset/best_practices/_category_.json b/docs/guides/dataset/best_practices/_category_.json index f55fe009b..79a1103d5 100644 --- a/docs/guides/dataset/best_practices/_category_.json +++ b/docs/guides/dataset/best_practices/_category_.json @@ -1,6 +1,6 @@ { "label": "Best practices", - "position": 11, + "position": 19, "link": { "type": "generated-index", "description": "Best practices on configuring a dataset." diff --git a/docs/guides/manage_users_and_services.md b/docs/guides/manage_users_and_services.md index 1d7f0fa64..a6e8a3314 100644 --- a/docs/guides/manage_users_and_services.md +++ b/docs/guides/manage_users_and_services.md @@ -64,7 +64,10 @@ The Admin CLI and Admin Service form a client-server architectural suite for RAG - -p: RAGFlow admin server port +## Default administrative account +- Username: admin@ragflow.io +- Password: admin ## Supported Commands diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index f2b86a735..481614d13 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -974,6 +974,237 @@ Failure: --- +### Construct knowledge graph + +**POST** `/api/v1/datasets/{dataset_id}/run_graphrag` + +Constructs a knowledge graph from a specified dataset. + +#### Request + +- Method: POST +- URL: `/api/v1/datasets/{dataset_id}/run_graphrag` +- Headers: + - `'Authorization: Bearer '` + +##### Request example + +```bash +curl --request POST \ + --url http://{address}/api/v1/datasets/{dataset_id}/run_graphrag \ + --header 'Authorization: Bearer ' +``` + +##### Request parameters + +- `dataset_id`: (*Path parameter*) + The ID of the target dataset. + +#### Response + +Success: + +```json +{ + "code":0, + "data":{ + "graphrag_task_id":"e498de54bfbb11f0ba028f704583b57b" + } +} +``` + +Failure: + +```json +{ + "code": 102, + "message": "Invalid Dataset ID" +} +``` + +--- + +### Get knowledge graph construction status + +**GET** `/api/v1/datasets/{dataset_id}/trace_graphrag` + +Retrieves the knowledge graph construction status for a specified dataset. + +#### Request + +- Method: GET +- URL: `/api/v1/datasets/{dataset_id}/trace_graphrag` +- Headers: + - `'Authorization: Bearer '` + +##### Request example + +```bash +curl --request GET \ + --url http://{address}/api/v1/datasets/{dataset_id}/trace_graphrag \ + --header 'Authorization: Bearer ' +``` + +##### Request parameters + +- `dataset_id`: (*Path parameter*) + The ID of the target dataset. + +#### Response + +Success: + +```json +{ + "code":0, + "data":{ + "begin_at":"Wed, 12 Nov 2025 19:36:56 GMT", + "chunk_ids":"", + "create_date":"Wed, 12 Nov 2025 19:36:56 GMT", + "create_time":1762947416350, + "digest":"39e43572e3dcd84f", + "doc_id":"44661c10bde211f0bc93c164a47ffc40", + "from_page":100000000, + "id":"e498de54bfbb11f0ba028f704583b57b", + "priority":0, + "process_duration":2.45419, + "progress":1.0, + "progress_msg":"19:36:56 created task graphrag\n19:36:57 Task has been received.\n19:36:58 [GraphRAG] doc:083661febe2411f0bc79456921e5745f has no available chunks, skip generation.\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 start (chunks=1, timeout=10000000000s)\n19:36:58 Graph already contains 44661c10bde211f0bc93c164a47ffc40\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 empty\n19:36:58 [GraphRAG] kb:33137ed0bde211f0bc93c164a47ffc40 no subgraphs generated successfully, end.\n19:36:58 Knowledge Graph done (0.72s)","retry_count":1, + "task_type":"graphrag", + "to_page":100000000, + "update_date":"Wed, 12 Nov 2025 19:36:58 GMT", + "update_time":1762947418454 + } +} +``` + +Failure: + +```json +{ + "code": 102, + "message": "Invalid Dataset ID" +} +``` + +--- + +### Construct RAPTOR + +**POST** `/api/v1/datasets/{dataset_id}/run_raptor` + +Construct a RAPTOR from a specified dataset. + +#### Request + +- Method: POST +- URL: `/api/v1/datasets/{dataset_id}/run_raptor` +- Headers: + - `'Authorization: Bearer '` + +##### Request example + +```bash +curl --request POST \ + --url http://{address}/api/v1/datasets/{dataset_id}/run_raptor \ + --header 'Authorization: Bearer ' +``` + +##### Request parameters + +- `dataset_id`: (*Path parameter*) + The ID of the target dataset. + +#### Response + +Success: + +```json +{ + "code":0, + "data":{ + "raptor_task_id":"50d3c31cbfbd11f0ba028f704583b57b" + } +} +``` + +Failure: + +```json +{ + "code": 102, + "message": "Invalid Dataset ID" +} +``` + +--- + +### Get RAPTOR construction status + +**GET** `/api/v1/datasets/{dataset_id}/trace_raptor` + +Retrieves the RAPTOR construction status for a specified dataset. + +#### Request + +- Method: GET +- URL: `/api/v1/datasets/{dataset_id}/trace_raptor` +- Headers: + - `'Authorization: Bearer '` + +##### Request example + +```bash +curl --request GET \ + --url http://{address}/api/v1/datasets/{dataset_id}/trace_raptor \ + --header 'Authorization: Bearer ' +``` + +##### Request parameters + +- `dataset_id`: (*Path parameter*) + The ID of the target dataset. + +#### Response + +Success: + +```json +{ + "code":0, + "data":{ + "begin_at":"Wed, 12 Nov 2025 19:47:07 GMT", + "chunk_ids":"", + "create_date":"Wed, 12 Nov 2025 19:47:07 GMT", + "create_time":1762948027427, + "digest":"8b279a6248cb8fc6", + "doc_id":"44661c10bde211f0bc93c164a47ffc40", + "from_page":100000000, + "id":"50d3c31cbfbd11f0ba028f704583b57b", + "priority":0, + "process_duration":0.948244, + "progress":1.0, + "progress_msg":"19:47:07 created task raptor\n19:47:07 Task has been received.\n19:47:07 Processing...\n19:47:07 Processing...\n19:47:07 Indexing done (0.01s).\n19:47:07 Task done (0.29s)", + "retry_count":1, + "task_type":"raptor", + "to_page":100000000, + "update_date":"Wed, 12 Nov 2025 19:47:07 GMT", + "update_time":1762948027948 + } +} +``` + +Failure: + +```json +{ + "code": 102, + "message": "Invalid Dataset ID" +} +``` + +--- + ## FILE MANAGEMENT WITHIN DATASET --- diff --git a/graphrag/general/extractor.py b/graphrag/general/extractor.py index 1df38ed1c..495e562ed 100644 --- a/graphrag/general/extractor.py +++ b/graphrag/general/extractor.py @@ -114,7 +114,7 @@ class Extractor: async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""): out_results = [] error_count = 0 - max_errors = 3 + max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3)) limiter = trio.Semaphore(max_concurrency) diff --git a/graphrag/search.py b/graphrag/search.py index 860f14bcb..b3a0104e1 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -69,7 +69,7 @@ class KGSearch(Dealer): def _ent_info_from_(self, es_res, sim_thr=0.3): res = {} flds = ["content_with_weight", "_score", "entity_kwd", "rank_flt", "n_hop_with_weight"] - es_res = self.dataStore.getFields(es_res, flds) + es_res = self.dataStore.get_fields(es_res, flds) for _, ent in es_res.items(): for f in flds: if f in ent and ent[f] is None: @@ -88,7 +88,7 @@ class KGSearch(Dealer): def _relation_info_from_(self, es_res, sim_thr=0.3): res = {} - es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", + es_res = self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"]) for _, ent in es_res.items(): if get_float(ent["_score"]) < sim_thr: @@ -300,7 +300,7 @@ class KGSearch(Dealer): fltr["entities_kwd"] = entities comm_res = self.dataStore.search(fields, [], fltr, [], OrderByExpr(), 0, topn, idxnms, kb_ids) - comm_res_fields = self.dataStore.getFields(comm_res, fields) + comm_res_fields = self.dataStore.get_fields(comm_res, fields) txts = [] for ii, (_, row) in enumerate(comm_res_fields.items()): obj = json.loads(row["content_with_weight"]) diff --git a/graphrag/utils.py b/graphrag/utils.py index 6a8df1e40..51a9c1abc 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -382,7 +382,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id): "removed_kwd": "N", } res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id])) - fields2 = settings.docStoreConn.getFields(res, fields) + fields2 = settings.docStoreConn.get_fields(res, fields) graph_doc_ids = set() for chunk_id in fields2.keys(): graph_doc_ids = set(fields2[chunk_id]["source_id"]) @@ -591,8 +591,8 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None): es_res = await trio.to_thread.run_sync( lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]) ) - # tot = settings.docStoreConn.getTotal(es_res) - es_res = settings.docStoreConn.getFields(es_res, flds) + # tot = settings.docStoreConn.get_total(es_res) + es_res = settings.docStoreConn.get_fields(es_res, flds) if len(es_res) == 0: break diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index 61a3b6f3a..80acf1d8f 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -482,7 +482,7 @@ def tree_merge(bull, sections, depth): root = Node(level=0, depth=target_level, texts=[]) root.build_tree(lines) - return [("\n").join(element) for element in root.get_tree() if element] + return [element for element in root.get_tree() if element] def hierarchical_merge(bull, sections, depth): diff --git a/rag/nlp/query.py b/rag/nlp/query.py index 68d2d2979..ec3628525 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -38,11 +38,11 @@ class FulltextQueryer: ] @staticmethod - def subSpecialChar(line): + def sub_special_char(line): return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip() @staticmethod - def isChinese(line): + def is_chinese(line): arr = re.split(r"[ \t]+", line) if len(arr) <= 3: return True @@ -92,7 +92,7 @@ class FulltextQueryer: otxt = txt txt = FulltextQueryer.rmWWW(txt) - if not self.isChinese(txt): + if not self.is_chinese(txt): txt = FulltextQueryer.rmWWW(txt) tks = rag_tokenizer.tokenize(txt).split() keywords = [t for t in tks if t] @@ -163,7 +163,7 @@ class FulltextQueryer: ) for m in sm ] - sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1] + sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1] sm = [m for m in sm if len(m) > 1] if len(keywords) < 32: @@ -171,7 +171,7 @@ class FulltextQueryer: keywords.extend(sm) tk_syns = self.syn.lookup(tk) - tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns] + tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns] if len(keywords) < 32: keywords.extend([s for s in tk_syns if s]) tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] @@ -180,7 +180,7 @@ class FulltextQueryer: if len(keywords) >= 32: break - tk = FulltextQueryer.subSpecialChar(tk) + tk = FulltextQueryer.sub_special_char(tk) if tk.find(" ") > 0: tk = '"%s"' % tk if tk_syns: @@ -198,7 +198,7 @@ class FulltextQueryer: syns = " OR ".join( [ '"%s"' - % rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s)) + % rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s)) for s in syns ] ) @@ -217,17 +217,17 @@ class FulltextQueryer: return None, keywords def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7): - from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity + from sklearn.metrics.pairwise import cosine_similarity import numpy as np - sims = CosineSimilarity([avec], bvecs) + sims = cosine_similarity([avec], bvecs) tksim = self.token_similarity(atks, btkss) if np.sum(sims[0]) == 0: return np.array(tksim), tksim, sims[0] return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0] def token_similarity(self, atks, btkss): - def toDict(tks): + def to_dict(tks): if isinstance(tks, str): tks = tks.split() d = defaultdict(int) @@ -236,8 +236,8 @@ class FulltextQueryer: d[t] += c return d - atks = toDict(atks) - btkss = [toDict(tks) for tks in btkss] + atks = to_dict(atks) + btkss = [to_dict(tks) for tks in btkss] return [self.similarity(atks, btks) for btks in btkss] def similarity(self, qtwt, dtwt): @@ -262,10 +262,10 @@ class FulltextQueryer: keywords = [f'"{k.strip()}"' for k in keywords] for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]: tk_syns = self.syn.lookup(tk) - tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns] + tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns] tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] - tk = FulltextQueryer.subSpecialChar(tk) + tk = FulltextQueryer.sub_special_char(tk) if tk.find(" ") > 0: tk = '"%s"' % tk if tk_syns: diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py index 3c4b97833..c95c18e74 100644 --- a/rag/nlp/rag_tokenizer.py +++ b/rag/nlp/rag_tokenizer.py @@ -35,7 +35,7 @@ class RagTokenizer: def rkey_(self, line): return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1] - def loadDict_(self, fnm): + def _load_dict(self, fnm): logging.info(f"[HUQIE]:Build trie from {fnm}") try: of = open(fnm, "r", encoding='utf-8') @@ -85,18 +85,18 @@ class RagTokenizer: self.trie_ = datrie.Trie(string.printable) # load data from dict file and save to trie file - self.loadDict_(self.DIR_ + ".txt") + self._load_dict(self.DIR_ + ".txt") - def loadUserDict(self, fnm): + def load_user_dict(self, fnm): try: self.trie_ = datrie.Trie.load(fnm + ".trie") return except Exception: self.trie_ = datrie.Trie(string.printable) - self.loadDict_(fnm) + self._load_dict(fnm) - def addUserDict(self, fnm): - self.loadDict_(fnm) + def add_user_dict(self, fnm): + self._load_dict(fnm) def _strQ2B(self, ustring): """Convert full-width characters to half-width characters""" @@ -221,7 +221,7 @@ class RagTokenizer: logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F)) return tks, B / len(tks) + L + F - def sortTks_(self, tkslist): + def _sort_tokens(self, tkslist): res = [] for tfts in tkslist: tks, s = self.score_(tfts) @@ -246,7 +246,7 @@ class RagTokenizer: return " ".join(res) - def maxForward_(self, line): + def _max_forward(self, line): res = [] s = 0 while s < len(line): @@ -270,7 +270,7 @@ class RagTokenizer: return self.score_(res) - def maxBackward_(self, line): + def _max_backward(self, line): res = [] s = len(line) - 1 while s >= 0: @@ -336,8 +336,8 @@ class RagTokenizer: continue # use maxforward for the first time - tks, s = self.maxForward_(L) - tks1, s1 = self.maxBackward_(L) + tks, s = self._max_forward(L) + tks1, s1 = self._max_backward(L) if self.DEBUG: logging.debug("[FW] {} {}".format(tks, s)) logging.debug("[BW] {} {}".format(tks1, s1)) @@ -369,7 +369,7 @@ class RagTokenizer: # backward tokens from_i to i are different from forward tokens from _j to j. tkslist = [] self.dfs_("".join(tks[_j:j]), 0, [], tkslist) - res.append(" ".join(self.sortTks_(tkslist)[0][0])) + res.append(" ".join(self._sort_tokens(tkslist)[0][0])) same = 1 while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]: @@ -385,7 +385,7 @@ class RagTokenizer: assert "".join(tks1[_i:]) == "".join(tks[_j:]) tkslist = [] self.dfs_("".join(tks[_j:]), 0, [], tkslist) - res.append(" ".join(self.sortTks_(tkslist)[0][0])) + res.append(" ".join(self._sort_tokens(tkslist)[0][0])) res = " ".join(res) logging.debug("[TKS] {}".format(self.merge_(res))) @@ -413,7 +413,7 @@ class RagTokenizer: if len(tkslist) < 2: res.append(tk) continue - stk = self.sortTks_(tkslist)[1][0] + stk = self._sort_tokens(tkslist)[1][0] if len(stk) == len(tk): stk = tk else: @@ -447,14 +447,13 @@ def is_number(s): def is_alphabet(s): - if (s >= u'\u0041' and s <= u'\u005a') or ( - s >= u'\u0061' and s <= u'\u007a'): + if (u'\u0041' <= s <= u'\u005a') or (u'\u0061' <= s <= u'\u007a'): return True else: return False -def naiveQie(txt): +def naive_qie(txt): tks = [] for t in txt.split(): if tks and re.match(r".*[a-zA-Z]$", tks[-1] @@ -469,14 +468,14 @@ tokenize = tokenizer.tokenize fine_grained_tokenize = tokenizer.fine_grained_tokenize tag = tokenizer.tag freq = tokenizer.freq -loadUserDict = tokenizer.loadUserDict -addUserDict = tokenizer.addUserDict +load_user_dict = tokenizer.load_user_dict +add_user_dict = tokenizer.add_user_dict tradi2simp = tokenizer._tradi2simp strQ2B = tokenizer._strQ2B if __name__ == '__main__': tknzr = RagTokenizer(debug=True) - # huqie.addUserDict("/tmp/tmp.new.tks.dict") + # huqie.add_user_dict("/tmp/tmp.new.tks.dict") tks = tknzr.tokenize( "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈") logging.info(tknzr.fine_grained_tokenize(tks)) @@ -506,7 +505,7 @@ if __name__ == '__main__': if len(sys.argv) < 2: sys.exit() tknzr.DEBUG = False - tknzr.loadUserDict(sys.argv[1]) + tknzr.load_user_dict(sys.argv[1]) of = open(sys.argv[2], "r") while True: line = of.readline() diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 1bf0abe04..f8b3d513f 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -102,7 +102,7 @@ class Dealer: orderBy.asc("top_int") orderBy.desc("create_timestamp_flt") res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids) - total = self.dataStore.getTotal(res) + total = self.dataStore.get_total(res) logging.debug("Dealer.search TOTAL: {}".format(total)) else: highlightFields = ["content_ltks", "title_tks"] @@ -115,7 +115,7 @@ class Dealer: matchExprs = [matchText] res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) - total = self.dataStore.getTotal(res) + total = self.dataStore.get_total(res) logging.debug("Dealer.search TOTAL: {}".format(total)) else: matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1)) @@ -127,20 +127,20 @@ class Dealer: res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) - total = self.dataStore.getTotal(res) + total = self.dataStore.get_total(res) logging.debug("Dealer.search TOTAL: {}".format(total)) # If result is empty, try again with lower min_match if total == 0: if filters.get("doc_id"): res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids) - total = self.dataStore.getTotal(res) + total = self.dataStore.get_total(res) else: matchText, _ = self.qryr.question(qst, min_match=0.1) matchDense.extra_options["similarity"] = 0.17 res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) - total = self.dataStore.getTotal(res) + total = self.dataStore.get_total(res) logging.debug("Dealer.search 2 TOTAL: {}".format(total)) for k in keywords: @@ -153,17 +153,17 @@ class Dealer: kwds.add(kk) logging.debug(f"TOTAL: {total}") - ids = self.dataStore.getChunkIds(res) + ids = self.dataStore.get_chunk_ids(res) keywords = list(kwds) - highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight") - aggs = self.dataStore.getAggregation(res, "docnm_kwd") + highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight") + aggs = self.dataStore.get_aggregation(res, "docnm_kwd") return self.SearchResult( total=total, ids=ids, query_vector=q_vec, aggregation=aggs, highlight=highlight, - field=self.dataStore.getFields(res, src + ["_score"]), + field=self.dataStore.get_fields(res, src + ["_score"]), keywords=keywords ) @@ -488,7 +488,7 @@ class Dealer: for p in range(offset, max_count, bs): es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id), kb_ids) - dict_chunks = self.dataStore.getFields(es_res, fields) + dict_chunks = self.dataStore.get_fields(es_res, fields) for id, doc in dict_chunks.items(): doc["id"] = id if dict_chunks: @@ -501,11 +501,11 @@ class Dealer: if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]): return [] res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"]) - return self.dataStore.getAggregation(res, "tag_kwd") + return self.dataStore.get_aggregation(res, "tag_kwd") def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000): res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"]) - res = self.dataStore.getAggregation(res, "tag_kwd") + res = self.dataStore.get_aggregation(res, "tag_kwd") total = np.sum([c for _, c in res]) return {t: (c + 1) / (total + S) for t, c in res} @@ -513,7 +513,7 @@ class Dealer: idx_nm = index_name(tenant_id) match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn) res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"]) - aggs = self.dataStore.getAggregation(res, "tag_kwd") + aggs = self.dataStore.get_aggregation(res, "tag_kwd") if not aggs: return False cnt = np.sum([c for _, c in aggs]) @@ -529,7 +529,7 @@ class Dealer: idx_nms = [index_name(tid) for tid in tenant_ids] match_txt, _ = self.qryr.question(question, min_match=0.0) res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"]) - aggs = self.dataStore.getAggregation(res, "tag_kwd") + aggs = self.dataStore.get_aggregation(res, "tag_kwd") if not aggs: return {} cnt = np.sum([c for _, c in aggs]) @@ -552,7 +552,7 @@ class Dealer: es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms, kb_ids) toc = [] - dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"]) + dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"]) for _, doc in dict_chunks.items(): try: toc.extend(json.loads(doc["content_with_weight"])) diff --git a/rag/nlp/term_weight.py b/rag/nlp/term_weight.py index 392117c18..28ed585ee 100644 --- a/rag/nlp/term_weight.py +++ b/rag/nlp/term_weight.py @@ -113,20 +113,20 @@ class Dealer: res.append(tk) return res - def tokenMerge(self, tks): - def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t) + def token_merge(self, tks): + def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t) res, i = [], 0 while i < len(tks): j = i - if i == 0 and oneTerm(tks[i]) and len( + if i == 0 and one_term(tks[i]) and len( tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位 res.append(" ".join(tks[0:2])) i = 2 continue while j < len( - tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]): + tks) and tks[j] and tks[j] not in self.stop_words and one_term(tks[j]): j += 1 if j - i > 1: if j - i < 5: @@ -232,7 +232,7 @@ class Dealer: tw = list(zip(tks, wts)) else: for tk in tks: - tt = self.tokenMerge(self.pretoken(tk, True)) + tt = self.token_merge(self.pretoken(tk, True)) idf1 = np.array([idf(freq(t), 10000000) for t in tt]) idf2 = np.array([idf(df(t), 1000000000) for t in tt]) wts = (0.3 * idf1 + 0.7 * idf2) * \ diff --git a/rag/raptor.py b/rag/raptor.py index e6efe3504..a455d0127 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -15,27 +15,35 @@ # import logging import re -import umap + import numpy as np -from sklearn.mixture import GaussianMixture import trio +import umap +from sklearn.mixture import GaussianMixture from api.db.services.task_service import has_canceled from common.connection_utils import timeout from common.exceptions import TaskCanceledException +from common.token_utils import truncate from graphrag.utils import ( - get_llm_cache, + chat_limiter, get_embed_cache, + get_llm_cache, set_embed_cache, set_llm_cache, - chat_limiter, ) -from common.token_utils import truncate class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: def __init__( - self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1 + self, + max_cluster, + llm_model, + embd_model, + prompt, + max_token=512, + threshold=0.1, + max_errors=3, ): self._max_cluster = max_cluster self._llm_model = llm_model @@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: self._threshold = threshold self._prompt = prompt self._max_token = max_token + self._max_errors = max(1, max_errors) + self._error_count = 0 - @timeout(60*20) + @timeout(60 * 20) async def _chat(self, system, history, gen_conf): - response = await trio.to_thread.run_sync( - lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf) - ) + cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)) + if cached: + return cached - if response: - return response - response = await trio.to_thread.run_sync( - lambda: self._llm_model.chat(system, history, gen_conf) - ) - response = re.sub(r"^.*", "", response, flags=re.DOTALL) - if response.find("**ERROR**") >= 0: - raise Exception(response) - await trio.to_thread.run_sync( - lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf) - ) - return response + last_exc = None + for attempt in range(3): + try: + response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf)) + response = re.sub(r"^.*", "", response, flags=re.DOTALL) + if response.find("**ERROR**") >= 0: + raise Exception(response) + await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)) + return response + except Exception as exc: + last_exc = exc + logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc) + if attempt < 2: + await trio.sleep(1 + attempt) + + raise last_exc if last_exc else Exception("LLM chat failed without exception") @timeout(20) async def _embedding_encode(self, txt): - response = await trio.to_thread.run_sync( - lambda: get_embed_cache(self._embd_model.llm_name, txt) - ) + response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt)) if response is not None: return response embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt])) @@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: n_clusters = np.arange(1, max_clusters) bics = [] for n in n_clusters: - if task_id: if has_canceled(task_id): logging.info(f"Task {task_id} cancelled during get optimal clusters.") @@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: layers = [(0, len(chunks))] start, end = 0, len(chunks) - @timeout(60*20) + @timeout(60 * 20) async def summarize(ck_idx: list[int]): nonlocal chunks @@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: raise TaskCanceledException(f"Task {task_id} was cancelled") texts = [chunks[i][0] for i in ck_idx] - len_per_chunk = int( - (self._llm_model.max_length - self._max_token) / len(texts) - ) - cluster_content = "\n".join( - [truncate(t, max(1, len_per_chunk)) for t in texts] - ) - async with chat_limiter: + len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) + cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) + try: + async with chat_limiter: + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.") + raise TaskCanceledException(f"Task {task_id} was cancelled") - if task_id and has_canceled(task_id): - logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.") - raise TaskCanceledException(f"Task {task_id} was cancelled") + cnt = await self._chat( + "You're a helpful assistant.", + [ + { + "role": "user", + "content": self._prompt.format(cluster_content=cluster_content), + } + ], + {"max_tokens": max(self._max_token, 512)}, # fix issue: #10235 + ) + cnt = re.sub( + "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", + "", + cnt, + ) + logging.debug(f"SUM: {cnt}") - cnt = await self._chat( - "You're a helpful assistant.", - [ - { - "role": "user", - "content": self._prompt.format( - cluster_content=cluster_content - ), - } - ], - {"max_tokens": max(self._max_token, 512)}, # fix issue: #10235 - ) - cnt = re.sub( - "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", - "", - cnt, - ) - logging.debug(f"SUM: {cnt}") + if task_id and has_canceled(task_id): + logging.info(f"Task {task_id} cancelled before RAPTOR embedding.") + raise TaskCanceledException(f"Task {task_id} was cancelled") - if task_id and has_canceled(task_id): - logging.info(f"Task {task_id} cancelled before RAPTOR embedding.") - raise TaskCanceledException(f"Task {task_id} was cancelled") - - embds = await self._embedding_encode(cnt) - chunks.append((cnt, embds)) + embds = await self._embedding_encode(cnt) + chunks.append((cnt, embds)) + except TaskCanceledException: + raise + except Exception as exc: + self._error_count += 1 + warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}" + logging.warning(warn_msg) + if callback: + callback(msg=warn_msg) + if self._error_count >= self._max_errors: + raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc labels = [] while end - start > 1: - if task_id: if has_canceled(task_id): logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.") @@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: if len(embeddings) == 2: await summarize([start, start + 1]) if callback: - callback( - msg="Cluster one layer: {} -> {}".format( - end - start, len(chunks) - end - ) - ) + callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) labels.extend([0, 0]) layers.append((end, len(chunks))) start = end @@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: nursery.start_soon(summarize, ck_idx) - assert len(chunks) - end == n_clusters, "{} vs. {}".format( - len(chunks) - end, n_clusters - ) + assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters) labels.extend(lbls) layers.append((end, len(chunks))) if callback: - callback( - msg="Cluster one layer: {} -> {}".format( - end - start, len(chunks) - end - ) - ) + callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) start = end end = len(chunks) diff --git a/rag/svr/cache_file_svr.py b/rag/svr/cache_file_svr.py index 89ab8b75f..3744c04ea 100644 --- a/rag/svr/cache_file_svr.py +++ b/rag/svr/cache_file_svr.py @@ -28,7 +28,7 @@ def collect(): logging.debug(doc_locations) if len(doc_locations) == 0: time.sleep(1) - return + return None return doc_locations diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index af8dfc186..370bd2a10 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -359,7 +359,7 @@ async def build_chunks(task, progress_callback): task_canceled = has_canceled(task["id"]) if task_canceled: progress_callback(-1, msg="Task has been canceled.") - return + return None if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0: examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]}) else: @@ -417,6 +417,7 @@ def build_TOC(task, docs, progress_callback): d["page_num_int"] = [100000000] d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() return d + return None def init_kb(row, vector_size: int): @@ -441,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None): tk_count = 0 if len(tts) == len(cnts): vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1])) - tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0) + tts = np.tile(vts[0], (len(cnts), 1)) tk_count += c @timeout(60) @@ -464,8 +465,10 @@ async def embedding(docs, mdl, parser_config=None, callback=None): if not filename_embd_weight: filename_embd_weight = 0.1 title_w = float(filename_embd_weight) - vects = (title_w * tts + (1 - title_w) * - cnts) if len(tts) == len(cnts) else cnts + if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape: + vects = title_w * tts + (1 - title_w) * cnts + else: + vects = cnts assert len(vects) == len(docs) vector_size = 0 @@ -648,6 +651,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si res = [] tk_count = 0 + max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3)) + async def generate(chunks, did): nonlocal tk_count, res raptor = Raptor( @@ -657,6 +662,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si raptor_config["prompt"], raptor_config["max_token"], raptor_config["threshold"], + max_errors=max_errors, ) original_length = len(chunks) chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"]) @@ -719,7 +725,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c task_canceled = has_canceled(task_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") - return + return False if b % 128 == 0: progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") if doc_store_result: @@ -737,7 +743,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c for chunk_id in chunk_ids: nursery.start_soon(delete_image, task_dataset_id, chunk_id) progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.") - return + return False return True diff --git a/rag/utils/azure_spn_conn.py b/rag/utils/azure_spn_conn.py index f47470d67..005d3ba6b 100644 --- a/rag/utils/azure_spn_conn.py +++ b/rag/utils/azure_spn_conn.py @@ -67,6 +67,8 @@ class RAGFlowAzureSpnBlob: logging.exception(f"Fail put {bucket}/{fnm}") self.__open__() time.sleep(1) + return None + return None def rm(self, bucket, fnm): try: @@ -84,7 +86,7 @@ class RAGFlowAzureSpnBlob: logging.exception(f"fail get {bucket}/{fnm}") self.__open__() time.sleep(1) - return + return None def obj_exist(self, bucket, fnm): try: @@ -102,4 +104,4 @@ class RAGFlowAzureSpnBlob: logging.exception(f"fail get {bucket}/{fnm}") self.__open__() time.sleep(1) - return \ No newline at end of file + return None \ No newline at end of file diff --git a/rag/utils/doc_store_conn.py b/rag/utils/doc_store_conn.py index c3fa61b0c..33f030011 100644 --- a/rag/utils/doc_store_conn.py +++ b/rag/utils/doc_store_conn.py @@ -241,23 +241,23 @@ class DocStoreConnection(ABC): """ @abstractmethod - def getTotal(self, res): + def get_total(self, res): raise NotImplementedError("Not implemented") @abstractmethod - def getChunkIds(self, res): + def get_chunk_ids(self, res): raise NotImplementedError("Not implemented") @abstractmethod - def getFields(self, res, fields: list[str]) -> dict[str, dict]: + def get_fields(self, res, fields: list[str]) -> dict[str, dict]: raise NotImplementedError("Not implemented") @abstractmethod - def getHighlight(self, res, keywords: list[str], fieldnm: str): + def get_highlight(self, res, keywords: list[str], fieldnm: str): raise NotImplementedError("Not implemented") @abstractmethod - def getAggregation(self, res, fieldnm: str): + def get_aggregation(self, res, fieldnm: str): raise NotImplementedError("Not implemented") """ diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index e99ee1375..5971950cf 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -471,12 +471,12 @@ class ESConnection(DocStoreConnection): Helper functions for search result """ - def getTotal(self, res): + def get_total(self, res): if isinstance(res["hits"]["total"], type({})): return res["hits"]["total"]["value"] return res["hits"]["total"] - def getChunkIds(self, res): + def get_chunk_ids(self, res): return [d["_id"] for d in res["hits"]["hits"]] def __getSource(self, res): @@ -487,7 +487,7 @@ class ESConnection(DocStoreConnection): rr.append(d["_source"]) return rr - def getFields(self, res, fields: list[str]) -> dict[str, dict]: + def get_fields(self, res, fields: list[str]) -> dict[str, dict]: res_fields = {} if not fields: return {} @@ -509,7 +509,7 @@ class ESConnection(DocStoreConnection): res_fields[d["id"]] = m return res_fields - def getHighlight(self, res, keywords: list[str], fieldnm: str): + def get_highlight(self, res, keywords: list[str], fieldnm: str): ans = {} for d in res["hits"]["hits"]: hlts = d.get("highlight") @@ -534,7 +534,7 @@ class ESConnection(DocStoreConnection): return ans - def getAggregation(self, res, fieldnm: str): + def get_aggregation(self, res, fieldnm: str): agg_field = "aggs_" + fieldnm if "aggregations" not in res or agg_field not in res["aggregations"]: return list() diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 03251e72c..ab575f9bc 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -470,7 +470,7 @@ class InfinityConnection(DocStoreConnection): df_list.append(kb_res) self.connPool.release_conn(inf_conn) res = concat_dataframes(df_list, ["id"]) - res_fields = self.getFields(res, res.columns.tolist()) + res_fields = self.get_fields(res, res.columns.tolist()) return res_fields.get(chunkId, None) def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]: @@ -599,7 +599,7 @@ class InfinityConnection(DocStoreConnection): col_to_remove = list(removeValue.keys()) row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df() logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}") - row_to_opt = self.getFields(row_to_opt, col_to_remove) + row_to_opt = self.get_fields(row_to_opt, col_to_remove) for id, old_v in row_to_opt.items(): for k, remove_v in removeValue.items(): if remove_v in old_v[k]: @@ -639,17 +639,17 @@ class InfinityConnection(DocStoreConnection): Helper functions for search result """ - def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int: + def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int: if isinstance(res, tuple): return res[1] return len(res) - def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]: + def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]: if isinstance(res, tuple): res = res[0] return list(res["id"]) - def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]: + def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]: if isinstance(res, tuple): res = res[0] if not fields: @@ -690,7 +690,7 @@ class InfinityConnection(DocStoreConnection): return res2.set_index("id").to_dict(orient="index") - def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str): + def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str): if isinstance(res, tuple): res = res[0] ans = {} @@ -732,7 +732,7 @@ class InfinityConnection(DocStoreConnection): ans[id] = txt return ans - def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str): + def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str): """ Manual aggregation for tag fields since Infinity doesn't provide native aggregation """ diff --git a/rag/utils/minio_conn.py b/rag/utils/minio_conn.py index 75cd2725b..e0913e98b 100644 --- a/rag/utils/minio_conn.py +++ b/rag/utils/minio_conn.py @@ -92,7 +92,7 @@ class RAGFlowMinio: logging.exception(f"Fail to get {bucket}/{filename}") self.__open__() time.sleep(1) - return + return None def obj_exist(self, bucket, filename, tenant_id=None): try: @@ -130,7 +130,7 @@ class RAGFlowMinio: logging.exception(f"Fail to get_presigned {bucket}/{fnm}:") self.__open__() time.sleep(1) - return + return None def remove_bucket(self, bucket): try: diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py index 54650b54b..c6cebf9ca 100644 --- a/rag/utils/opendal_conn.py +++ b/rag/utils/opendal_conn.py @@ -62,8 +62,7 @@ class OpenDALStorage: def health(self): bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1" - r = self._operator.write(f"{bucket}/{fnm}", binary) - return r + return self._operator.write(f"{bucket}/{fnm}", binary) def put(self, bucket, fnm, binary, tenant_id=None): self._operator.write(f"{bucket}/{fnm}", binary) diff --git a/rag/utils/opensearch_conn.py b/rag/utils/opensearch_conn.py index c862b52e9..2df1d65ee 100644 --- a/rag/utils/opensearch_conn.py +++ b/rag/utils/opensearch_conn.py @@ -455,12 +455,12 @@ class OSConnection(DocStoreConnection): Helper functions for search result """ - def getTotal(self, res): + def get_total(self, res): if isinstance(res["hits"]["total"], type({})): return res["hits"]["total"]["value"] return res["hits"]["total"] - def getChunkIds(self, res): + def get_chunk_ids(self, res): return [d["_id"] for d in res["hits"]["hits"]] def __getSource(self, res): @@ -471,7 +471,7 @@ class OSConnection(DocStoreConnection): rr.append(d["_source"]) return rr - def getFields(self, res, fields: list[str]) -> dict[str, dict]: + def get_fields(self, res, fields: list[str]) -> dict[str, dict]: res_fields = {} if not fields: return {} @@ -490,7 +490,7 @@ class OSConnection(DocStoreConnection): res_fields[d["id"]] = m return res_fields - def getHighlight(self, res, keywords: list[str], fieldnm: str): + def get_highlight(self, res, keywords: list[str], fieldnm: str): ans = {} for d in res["hits"]["hits"]: hlts = d.get("highlight") @@ -515,7 +515,7 @@ class OSConnection(DocStoreConnection): return ans - def getAggregation(self, res, fieldnm: str): + def get_aggregation(self, res, fieldnm: str): agg_field = "aggs_" + fieldnm if "aggregations" not in res or agg_field not in res["aggregations"]: return list() diff --git a/rag/utils/oss_conn.py b/rag/utils/oss_conn.py index 20cea0b94..b0114f668 100644 --- a/rag/utils/oss_conn.py +++ b/rag/utils/oss_conn.py @@ -141,7 +141,7 @@ class RAGFlowOSS: logging.exception(f"fail get {bucket}/{fnm}") self.__open__() time.sleep(1) - return + return None @use_prefix_path @use_default_bucket @@ -170,5 +170,5 @@ class RAGFlowOSS: logging.exception(f"fail get url {bucket}/{fnm}") self.__open__() time.sleep(1) - return + return None diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index 3c6565230..58b0fe15b 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -104,6 +104,7 @@ class RedisDB: if self.REDIS.get(a) == b: return True + return False def info(self): info = self.REDIS.info() @@ -124,7 +125,7 @@ class RedisDB: def exist(self, k): if not self.REDIS: - return + return None try: return self.REDIS.exists(k) except Exception as e: @@ -133,7 +134,7 @@ class RedisDB: def get(self, k): if not self.REDIS: - return + return None try: return self.REDIS.get(k) except Exception as e: diff --git a/rag/utils/s3_conn.py b/rag/utils/s3_conn.py index 9006fa586..11ac65cee 100644 --- a/rag/utils/s3_conn.py +++ b/rag/utils/s3_conn.py @@ -164,7 +164,7 @@ class RAGFlowS3: logging.exception(f"fail get {bucket}/{fnm}") self.__open__() time.sleep(1) - return + return None @use_prefix_path @use_default_bucket @@ -193,7 +193,7 @@ class RAGFlowS3: logging.exception(f"fail get url {bucket}/{fnm}") self.__open__() time.sleep(1) - return + return None @use_default_bucket def rm_bucket(self, bucket, *args, **kwargs): diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py index 507570fba..559b41f3c 100644 --- a/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py @@ -16,14 +16,15 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import create_dataset -from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN +from configs import DATASET_NAME_LIMIT, DEFAULT_PARSER_CONFIG, INVALID_API_TOKEN from hypothesis import example, given, settings from libs.auth import RAGFlowHttpApiAuth from utils import encode_avatar from utils.file_utils import create_image_file from utils.hypothesis_utils import valid_names -from configs import DEFAULT_PARSER_CONFIG + +from common import create_dataset + @pytest.mark.usefixtures("clear_datasets") class TestAuthorization: @@ -125,8 +126,8 @@ class TestDatasetCreate: assert res["code"] == 0, res res = create_dataset(HttpApiAuth, payload) - assert res["code"] == 103, res - assert res["message"] == f"Dataset name '{name}' already exists", res + assert res["code"] == 0, res + assert res["data"]["name"] == name + "(1)", res @pytest.mark.p3 def test_name_case_insensitive(self, HttpApiAuth): @@ -137,8 +138,8 @@ class TestDatasetCreate: payload = {"name": name.lower()} res = create_dataset(HttpApiAuth, payload) - assert res["code"] == 103, res - assert res["message"] == f"Dataset name '{name.lower()}' already exists", res + assert res["code"] == 0, res + assert res["data"]["name"] == name.lower() + "(1)", res @pytest.mark.p2 def test_avatar(self, HttpApiAuth, tmp_path): diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py index 049f288b3..a97cb66c8 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py @@ -17,13 +17,13 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from operator import attrgetter import pytest -from configs import DATASET_NAME_LIMIT, HOST_ADDRESS, INVALID_API_TOKEN +from configs import DATASET_NAME_LIMIT, DEFAULT_PARSER_CONFIG, HOST_ADDRESS, INVALID_API_TOKEN from hypothesis import example, given, settings from ragflow_sdk import DataSet, RAGFlow from utils import encode_avatar from utils.file_utils import create_image_file from utils.hypothesis_utils import valid_names -from configs import DEFAULT_PARSER_CONFIG + @pytest.mark.usefixtures("clear_datasets") class TestAuthorization: @@ -95,9 +95,8 @@ class TestDatasetCreate: payload = {"name": name} client.create_dataset(**payload) - with pytest.raises(Exception) as excinfo: - client.create_dataset(**payload) - assert str(excinfo.value) == f"Dataset name '{name}' already exists", str(excinfo.value) + dataset = client.create_dataset(**payload) + assert dataset.name == name + "(1)", str(dataset) @pytest.mark.p3 def test_name_case_insensitive(self, client): @@ -106,9 +105,8 @@ class TestDatasetCreate: client.create_dataset(**payload) payload = {"name": name.lower()} - with pytest.raises(Exception) as excinfo: - client.create_dataset(**payload) - assert str(excinfo.value) == f"Dataset name '{name.lower()}' already exists", str(excinfo.value) + dataset = client.create_dataset(**payload) + assert dataset.name == name.lower() + "(1)", str(dataset) @pytest.mark.p2 def test_avatar(self, client, tmp_path):