Merge branch 'main' of github.com:infiniflow/ragflow into feature/1111

This commit is contained in: commit 2c2553cd1b

58 changed files with 915 additions and 1096 deletions

.github/workflows/tests.yml (vendored, 32 changes)

@@ -95,6 +95,38 @@ jobs:
           version: ">=0.11.x"
           args: "check"

+      - name: Check comments of changed Python files
+        if: ${{ !cancelled() && !failure() }}
+        run: |
+          if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then
+            CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
+              | grep -E '\.(py)$' || true)
+
+            if [ -n "$CHANGED_FILES" ]; then
+              echo "Check comments of changed Python files with check_comment_ascii.py"
+
+              readarray -t files <<< "$CHANGED_FILES"
+              HAS_ERROR=0
+
+              for file in "${files[@]}"; do
+                if [ -f "$file" ]; then
+                  if python3 check_comment_ascii.py "$file"; then
+                    echo "✅ $file"
+                  else
+                    echo "❌ $file"
+                    HAS_ERROR=1
+                  fi
+                fi
+              done
+
+              if [ $HAS_ERROR -ne 0 ]; then
+                exit 1
+              fi
+            else
+              echo "No Python files changed"
+            fi
+          fi
+
       - name: Build ragflow:nightly
         run: |
           RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}

@@ -192,9 +192,10 @@ releases! 🌟

 ```bash
 $ cd ragflow/docker

+# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
 # This step ensures the **entrypoint.sh** file in the code matches the Docker image version.

 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d

@@ -192,6 +192,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
 $ cd ragflow/docker

+# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases), contoh: git checkout v0.22.0
 # This step ensures the **entrypoint.sh** file in the code matches the Docker image version.

 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d

@@ -172,6 +172,7 @@
 $ cd ragflow/docker

+# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) 例: git checkout v0.22.0
 # この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。

 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d

@@ -174,6 +174,7 @@
 $ cd ragflow/docker

+# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
 # 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다.

 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d

@@ -192,6 +192,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
 $ cd ragflow/docker

+# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases), ex.: git checkout v0.22.0
 # Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker.

 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d

@@ -191,6 +191,7 @@
 $ cd ragflow/docker

+# 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases),例:git checkout v0.22.0
 # 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。

 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d

@@ -192,6 +192,7 @@
 $ cd ragflow/docker

+# 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases),例如:git checkout v0.22.0
 # 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。

 # Use CPU for DeepDoc tasks:
 $ docker compose -f docker-compose.yml up -d

@@ -368,11 +368,19 @@ Respond immediately with your final comprehensive answer.

         return "Error occurred."

-    def reset(self, temp=False):
+    def reset(self, only_output=False):
+        """
+        Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
+        """
         for k in self._param.outputs.keys():
             self._param.outputs[k]["value"] = None

+        for k, cpn in self.tools.items():
+            if hasattr(cpn, "reset") and callable(cpn.reset):
+                cpn.reset()
+        if only_output:
+            return
         for k in self._param.inputs.keys():
             self._param.inputs[k]["value"] = None
         self._param.debug_inputs = {}

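The renamed flag makes the two reset depths explicit. A minimal illustration, assuming a hypothetical `agent` instance of the component shown above:

```python
# Between turns: clear produced outputs and reset tool state (e.g. an
# MCPToolCallSession), but keep user inputs and debug inputs intact.
agent.reset(only_output=True)

# Full reset: outputs, tool state, inputs, and debug inputs are all cleared.
agent.reset()
```
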
@@ -222,7 +222,7 @@ class LLM(ComponentBase):
             output_structure = self._param.outputs['structured']
         except Exception:
             pass
-        if output_structure:
+        if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"):
             schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
             prompt += structured_output_prompt(schema)
         for _ in range(self._param.max_retries+1):

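The tightened guard means structured-output prompting only engages for a real, non-empty JSON-schema-like dict. A sketch with a hypothetical schema value that passes the new condition:

```python
import json

# Hypothetical structured-output schema for illustration.
output_structure = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "score": {"type": "number"},
    },
}

# Mirrors the new condition: a dict with a non-empty "properties" mapping.
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"):
    schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
    print(schema)  # this string is appended to the prompt in the handler above
```
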
@@ -96,12 +96,12 @@ login_manager.init_app(app)
 commands.register_commands(app)


-def search_pages_path(pages_dir):
+def search_pages_path(page_path):
     app_path_list = [
-        path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".")
+        path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
     ]
     api_path_list = [
-        path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".")
+        path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
     ]
     app_path_list.extend(api_path_list)
     return app_path_list

@@ -138,7 +138,7 @@ pages_dir = [
 ]

 client_urls_prefix = [
-    register_page(path) for dir in pages_dir for path in search_pages_path(dir)
+    register_page(path) for directory in pages_dir for path in search_pages_path(directory)
 ]

@@ -177,5 +177,7 @@ def load_user(web_request):


 @app.teardown_request
-def _db_close(exc):
+def _db_close(exception):
+    if exception:
+        logging.exception(f"Request failed: {exception}")
     close_connection()

@@ -13,41 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
 import json
 import os
 import re
 from datetime import datetime, timedelta
-from flask import request, Response
-from api.db.services.llm_service import LLMBundle
+from flask import request
 from flask_login import login_required, current_user

 from api.db import VALID_FILE_TYPES, FileType
-from api.db.db_models import APIToken, Task, File
-from api.db.services import duplicate_name
+from api.db.db_models import APIToken
 from api.db.services.api_service import APITokenService, API4ConversationService
 from api.db.services.dialog_service import DialogService, chat
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import queue_tasks, TaskService
 from api.db.services.user_service import UserTenantService
 from common.misc_utils import get_uuid
 from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
     generate_confirmation_token

 from api.utils.file_utils import filename_type, thumbnail
 from rag.app.tag import label_question
 from rag.prompts.generator import keyword_extraction
 from common.time_utils import current_timestamp, datetime_format

 from api.db.services.canvas_service import UserCanvasService
 from agent.canvas import Canvas
 from functools import partial
 from pathlib import Path
 from common import settings


 @manager.route('/new_token', methods=['POST'])  # noqa: F821
 @login_required

@@ -138,758 +113,3 @@ def stats():
    except Exception as e:
        return server_error_response(e)


@manager.route('/new_conversation', methods=['GET'])  # noqa: F821
def set_conversation():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
    try:
        if objs[0].source == "agent":
            e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
            if not e:
                return server_error_response("canvas not found.")
            if not isinstance(cvs.dsl, str):
                cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
            canvas = Canvas(cvs.dsl, objs[0].tenant_id)
            conv = {
                "id": get_uuid(),
                "dialog_id": cvs.id,
                "user_id": request.args.get("user_id", ""),
                "message": [{"role": "assistant", "content": canvas.get_prologue()}],
                "source": "agent"
            }
            API4ConversationService.save(**conv)
            return get_json_result(data=conv)
        else:
            e, dia = DialogService.get_by_id(objs[0].dialog_id)
            if not e:
                return get_data_error_result(message="Dialog not found")
            conv = {
                "id": get_uuid(),
                "dialog_id": dia.id,
                "user_id": request.args.get("user_id", ""),
                "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
            }
            API4ConversationService.save(**conv)
            return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)


@manager.route('/completion', methods=['POST'])  # noqa: F821
@validate_request("conversation_id", "messages")
def completion():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
    req = request.json
    e, conv = API4ConversationService.get_by_id(req["conversation_id"])
    if not e:
        return get_data_error_result(message="Conversation not found!")
    if "quote" not in req:
        req["quote"] = False

    msg = []
    for m in req["messages"]:
        if m["role"] == "system":
            continue
        if m["role"] == "assistant" and not msg:
            continue
        msg.append(m)
    if not msg[-1].get("id"):
        msg[-1]["id"] = get_uuid()
    message_id = msg[-1]["id"]

    def fillin_conv(ans):
        nonlocal conv, message_id
        if not conv.reference:
            conv.reference.append(ans["reference"])
        else:
            conv.reference[-1] = ans["reference"]
        conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
        ans["id"] = message_id

    def rename_field(ans):
        reference = ans['reference']
        if not isinstance(reference, dict):
            return
        for chunk_i in reference.get('chunks', []):
            if 'docnm_kwd' in chunk_i:
                chunk_i['doc_name'] = chunk_i['docnm_kwd']
                chunk_i.pop('docnm_kwd')

    try:
        if conv.source == "agent":
            stream = req.get("stream", True)
            conv.message.append(msg[-1])
            e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
            if not e:
                return server_error_response("canvas not found.")
            del req["conversation_id"]
            del req["messages"]

            if not isinstance(cvs.dsl, str):
                cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

            if not conv.reference:
                conv.reference = []
            conv.message.append({"role": "assistant", "content": "", "id": message_id})
            conv.reference.append({"chunks": [], "doc_aggs": []})

            final_ans = {"reference": [], "content": ""}
            canvas = Canvas(cvs.dsl, objs[0].tenant_id)

            canvas.messages.append(msg[-1])
            canvas.add_user_input(msg[-1]["content"])
            answer = canvas.run(stream=stream)

            assert answer is not None, "Nothing. Is it over?"

            if stream:
                assert isinstance(answer, partial), "Nothing. Is it over?"

                def sse():
                    nonlocal answer, cvs, conv
                    try:
                        for ans in answer():
                            for k in ans.keys():
                                final_ans[k] = ans[k]
                            ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
                            fillin_conv(ans)
                            rename_field(ans)
                            yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
                                                       ensure_ascii=False) + "\n\n"

                        canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
                        canvas.history.append(("assistant", final_ans["content"]))
                        if final_ans.get("reference"):
                            canvas.reference.append(final_ans["reference"])
                        cvs.dsl = json.loads(str(canvas))
                        API4ConversationService.append_message(conv.id, conv.to_dict())
                    except Exception as e:
                        yield "data:" + json.dumps({"code": 500, "message": str(e),
                                                    "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                                   ensure_ascii=False) + "\n\n"
                    yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

                resp = Response(sse(), mimetype="text/event-stream")
                resp.headers.add_header("Cache-control", "no-cache")
                resp.headers.add_header("Connection", "keep-alive")
                resp.headers.add_header("X-Accel-Buffering", "no")
                resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
                return resp

            final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
            canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
            if final_ans.get("reference"):
                canvas.reference.append(final_ans["reference"])
            cvs.dsl = json.loads(str(canvas))

            result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
            fillin_conv(result)
            API4ConversationService.append_message(conv.id, conv.to_dict())
            rename_field(result)
            return get_json_result(data=result)

        # ******************For dialog******************
        conv.message.append(msg[-1])
        e, dia = DialogService.get_by_id(conv.dialog_id)
        if not e:
            return get_data_error_result(message="Dialog not found!")
        del req["conversation_id"]
        del req["messages"]

        if not conv.reference:
            conv.reference = []
        conv.message.append({"role": "assistant", "content": "", "id": message_id})
        conv.reference.append({"chunks": [], "doc_aggs": []})

        def stream():
            nonlocal dia, msg, req, conv
            try:
                for ans in chat(dia, msg, True, **req):
                    fillin_conv(ans)
                    rename_field(ans)
                    yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
                                               ensure_ascii=False) + "\n\n"
                API4ConversationService.append_message(conv.id, conv.to_dict())
            except Exception as e:
                yield "data:" + json.dumps({"code": 500, "message": str(e),
                                            "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                           ensure_ascii=False) + "\n\n"
            yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

        if req.get("stream", True):
            resp = Response(stream(), mimetype="text/event-stream")
            resp.headers.add_header("Cache-control", "no-cache")
            resp.headers.add_header("Connection", "keep-alive")
            resp.headers.add_header("X-Accel-Buffering", "no")
            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
            return resp

        answer = None
        for ans in chat(dia, msg, **req):
            answer = ans
            fillin_conv(ans)
            API4ConversationService.append_message(conv.id, conv.to_dict())
            break
        rename_field(answer)
        return get_json_result(data=answer)

    except Exception as e:
        return server_error_response(e)


@manager.route('/conversation/<conversation_id>', methods=['GET'])  # noqa: F821
# @login_required
def get_conversation(conversation_id):
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)

    try:
        e, conv = API4ConversationService.get_by_id(conversation_id)
        if not e:
            return get_data_error_result(message="Conversation not found!")

        conv = conv.to_dict()
        if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
            return get_json_result(data=False, message='Authentication error: API key is invalid for this conversation_id!"',
                                   code=RetCode.AUTHENTICATION_ERROR)

        for referenct_i in conv['reference']:
            if referenct_i is None or len(referenct_i) == 0:
                continue
            for chunk_i in referenct_i['chunks']:
                if 'docnm_kwd' in chunk_i.keys():
                    chunk_i['doc_name'] = chunk_i['docnm_kwd']
                    chunk_i.pop('docnm_kwd')
        return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)


@manager.route('/document/upload', methods=['POST'])  # noqa: F821
@validate_request("kb_name")
def upload():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)

    kb_name = request.form.get("kb_name").strip()
    tenant_id = objs[0].tenant_id

    try:
        e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
        if not e:
            return get_data_error_result(
                message="Can't find this knowledgebase!")
        kb_id = kb.id
    except Exception as e:
        return server_error_response(e)

    if 'file' not in request.files:
        return get_json_result(
            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)

    file = request.files['file']
    if file.filename == '':
        return get_json_result(
            data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)

    root_folder = FileService.get_root_folder(tenant_id)
    pf_id = root_folder["id"]
    FileService.init_knowledgebase_docs(pf_id, tenant_id)
    kb_root_folder = FileService.get_kb_folder(tenant_id)
    kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])

    try:
        if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
            return get_data_error_result(
                message="Exceed the maximum file number of a free user!")

        filename = duplicate_name(
            DocumentService.query,
            name=file.filename,
            kb_id=kb_id)
        filetype = filename_type(filename)
        if not filetype:
            return get_data_error_result(
                message="This type of file has not been supported yet!")

        location = filename
        while settings.STORAGE_IMPL.obj_exist(kb_id, location):
            location += "_"
        blob = request.files['file'].read()
        settings.STORAGE_IMPL.put(kb_id, location, blob)
        doc = {
            "id": get_uuid(),
            "kb_id": kb.id,
            "parser_id": kb.parser_id,
            "parser_config": kb.parser_config,
            "created_by": kb.tenant_id,
            "type": filetype,
            "name": filename,
            "location": location,
            "size": len(blob),
            "thumbnail": thumbnail(filename, blob),
            "suffix": Path(filename).suffix.lstrip("."),
        }

        form_data = request.form
        if "parser_id" in form_data.keys():
            if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
                doc["parser_id"] = request.form.get("parser_id").strip()
        if doc["type"] == FileType.VISUAL:
            doc["parser_id"] = ParserType.PICTURE.value
        if doc["type"] == FileType.AURAL:
            doc["parser_id"] = ParserType.AUDIO.value
        if re.search(r"\.(ppt|pptx|pages)$", filename):
            doc["parser_id"] = ParserType.PRESENTATION.value
        if re.search(r"\.(eml)$", filename):
            doc["parser_id"] = ParserType.EMAIL.value

        doc_result = DocumentService.insert(doc)
        FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
    except Exception as e:
        return server_error_response(e)

    if "run" in form_data.keys():
        if request.form.get("run").strip() == "1":
            try:
                info = {"run": 1, "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
                DocumentService.update_by_id(doc["id"], info)
                # if str(req["run"]) == TaskStatus.CANCEL.value:
                tenant_id = DocumentService.get_tenant_id(doc["id"])
                if not tenant_id:
                    return get_data_error_result(message="Tenant not found!")

                # e, doc = DocumentService.get_by_id(doc["id"])
                TaskService.filter_delete([Task.doc_id == doc["id"]])
                e, doc = DocumentService.get_by_id(doc["id"])
                doc = doc.to_dict()
                doc["tenant_id"] = tenant_id
                bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
                queue_tasks(doc, bucket, name, 0)
            except Exception as e:
                return server_error_response(e)

    return get_json_result(data=doc_result.to_json())


@manager.route('/document/upload_and_parse', methods=['POST'])  # noqa: F821
@validate_request("conversation_id")
def upload_parse():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)

    if 'file' not in request.files:
        return get_json_result(
            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)

    file_objs = request.files.getlist('file')
    for file_obj in file_objs:
        if file_obj.filename == '':
            return get_json_result(
                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)

    doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
    return get_json_result(data=doc_ids)


@manager.route('/list_chunks', methods=['POST'])  # noqa: F821
# @login_required
def list_chunks():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)

    req = request.json

    try:
        if "doc_name" in req.keys():
            tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name'])
            doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name'])

        elif "doc_id" in req.keys():
            tenant_id = DocumentService.get_tenant_id(req['doc_id'])
            doc_id = req['doc_id']
        else:
            return get_json_result(
                data=False, message="Can't find doc_name or doc_id"
            )
        kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)

        res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
        res = [
            {
                "content": res_item["content_with_weight"],
                "doc_name": res_item["docnm_kwd"],
                "image_id": res_item["img_id"]
            } for res_item in res
        ]

    except Exception as e:
        return server_error_response(e)

    return get_json_result(data=res)


@manager.route('/get_chunk/<chunk_id>', methods=['GET'])  # noqa: F821
# @login_required
def get_chunk(chunk_id):
    from rag.nlp import search
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
    try:
        tenant_id = objs[0].tenant_id
        kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
        chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
        if chunk is None:
            return server_error_response(Exception("Chunk not found"))
        k = []
        for n in chunk.keys():
            if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
                k.append(n)
        for n in k:
            del chunk[n]

        return get_json_result(data=chunk)
    except Exception as e:
        return server_error_response(e)


@manager.route('/list_kb_docs', methods=['POST'])  # noqa: F821
# @login_required
def list_kb_docs():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)

    req = request.json
    tenant_id = objs[0].tenant_id
    kb_name = req.get("kb_name", "").strip()

    try:
        e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
        if not e:
            return get_data_error_result(
                message="Can't find this knowledgebase!")
        kb_id = kb.id

    except Exception as e:
        return server_error_response(e)

    page_number = int(req.get("page", 1))
    items_per_page = int(req.get("page_size", 15))
    orderby = req.get("orderby", "create_time")
    desc = req.get("desc", True)
    keywords = req.get("keywords", "")
    status = req.get("status", [])
    if status:
        invalid_status = {s for s in status if s not in VALID_TASK_STATUS}
        if invalid_status:
            return get_data_error_result(
                message=f"Invalid filter status conditions: {', '.join(invalid_status)}"
            )
    types = req.get("types", [])
    if types:
        invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
        if invalid_types:
            return get_data_error_result(
                message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
            )
    try:
        docs, tol = DocumentService.get_by_kb_id(
            kb_id, page_number, items_per_page, orderby, desc, keywords, status, types)
        docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]

        return get_json_result(data={"total": tol, "docs": docs})

    except Exception as e:
        return server_error_response(e)


@manager.route('/document/infos', methods=['POST'])  # noqa: F821
@validate_request("doc_ids")
def docinfos():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
    req = request.json
    doc_ids = req["doc_ids"]
    docs = DocumentService.get_by_ids(doc_ids)
    return get_json_result(data=list(docs.dicts()))


@manager.route('/document', methods=['DELETE'])  # noqa: F821
# @login_required
def document_rm():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)

    tenant_id = objs[0].tenant_id
    req = request.json
    try:
        doc_ids = DocumentService.get_doc_ids_by_doc_names(req.get("doc_names", []))
        for doc_id in req.get("doc_ids", []):
            if doc_id not in doc_ids:
                doc_ids.append(doc_id)

        if not doc_ids:
            return get_json_result(
                data=False, message="Can't find doc_names or doc_ids"
            )

    except Exception as e:
        return server_error_response(e)

    root_folder = FileService.get_root_folder(tenant_id)
    pf_id = root_folder["id"]
    FileService.init_knowledgebase_docs(pf_id, tenant_id)

    errors = ""
    docs = DocumentService.get_by_ids(doc_ids)
    doc_dic = {}
    for doc in docs:
        doc_dic[doc.id] = doc

    for doc_id in doc_ids:
        try:
            if doc_id not in doc_dic:
                return get_data_error_result(message="Document not found!")
            doc = doc_dic[doc_id]
            tenant_id = DocumentService.get_tenant_id(doc_id)
            if not tenant_id:
                return get_data_error_result(message="Tenant not found!")

            b, n = File2DocumentService.get_storage_address(doc_id=doc_id)

            if not DocumentService.remove_document(doc, tenant_id):
                return get_data_error_result(
                    message="Database error (Document removal)!")

            f2d = File2DocumentService.get_by_document_id(doc_id)
            FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
            File2DocumentService.delete_by_document_id(doc_id)

            settings.STORAGE_IMPL.rm(b, n)
        except Exception as e:
            errors += str(e)

    if errors:
        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)

    return get_json_result(data=True)


@manager.route('/completion_aibotk', methods=['POST'])  # noqa: F821
@validate_request("Authorization", "conversation_id", "word")
def completion_faq():
    import base64
    req = request.json

    token = req["Authorization"]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)

    e, conv = API4ConversationService.get_by_id(req["conversation_id"])
    if not e:
        return get_data_error_result(message="Conversation not found!")
    if "quote" not in req:
        req["quote"] = True

    msg = [{"role": "user", "content": req["word"]}]
    if not msg[-1].get("id"):
        msg[-1]["id"] = get_uuid()
    message_id = msg[-1]["id"]

    def fillin_conv(ans):
        nonlocal conv, message_id
        if not conv.reference:
            conv.reference.append(ans["reference"])
        else:
            conv.reference[-1] = ans["reference"]
        conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
        ans["id"] = message_id

    try:
        if conv.source == "agent":
            conv.message.append(msg[-1])
            e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
            if not e:
                return server_error_response("canvas not found.")

            if not isinstance(cvs.dsl, str):
                cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

            if not conv.reference:
                conv.reference = []
            conv.message.append({"role": "assistant", "content": "", "id": message_id})
            conv.reference.append({"chunks": [], "doc_aggs": []})

            final_ans = {"reference": [], "doc_aggs": []}
            canvas = Canvas(cvs.dsl, objs[0].tenant_id)

            canvas.messages.append(msg[-1])
            canvas.add_user_input(msg[-1]["content"])
            answer = canvas.run(stream=False)

            assert answer is not None, "Nothing. Is it over?"

            data_type_picture = {
                "type": 3,
                "url": "base64 content"
            }
            data = [
                {
                    "type": 1,
                    "content": ""
                }
            ]
            final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
            canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
            if final_ans.get("reference"):
                canvas.reference.append(final_ans["reference"])
            cvs.dsl = json.loads(str(canvas))

            ans = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
            data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
            fillin_conv(ans)
            API4ConversationService.append_message(conv.id, conv.to_dict())

            chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
            for chunk_idx in chunk_idxs[:1]:
                if ans["reference"]["chunks"][chunk_idx]["img_id"]:
                    try:
                        bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
                        response = settings.STORAGE_IMPL.get(bkt, nm)
                        data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
                        data.append(data_type_picture)
                        break
                    except Exception as e:
                        return server_error_response(e)

            response = {"code": 200, "msg": "success", "data": data}
            return response

        # ******************For dialog******************
        conv.message.append(msg[-1])
        e, dia = DialogService.get_by_id(conv.dialog_id)
        if not e:
            return get_data_error_result(message="Dialog not found!")
        del req["conversation_id"]

        if not conv.reference:
            conv.reference = []
        conv.message.append({"role": "assistant", "content": "", "id": message_id})
        conv.reference.append({"chunks": [], "doc_aggs": []})

        data_type_picture = {
            "type": 3,
            "url": "base64 content"
        }
        data = [
            {
                "type": 1,
                "content": ""
            }
        ]
        ans = ""
        for a in chat(dia, msg, stream=False, **req):
            ans = a
            break
        data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
        fillin_conv(ans)
        API4ConversationService.append_message(conv.id, conv.to_dict())

        chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
        for chunk_idx in chunk_idxs[:1]:
            if ans["reference"]["chunks"][chunk_idx]["img_id"]:
                try:
                    bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
                    response = settings.STORAGE_IMPL.get(bkt, nm)
                    data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
                    data.append(data_type_picture)
                    break
                except Exception as e:
                    return server_error_response(e)

        response = {"code": 200, "msg": "success", "data": data}
        return response

    except Exception as e:
        return server_error_response(e)


@manager.route('/retrieval', methods=['POST'])  # noqa: F821
@validate_request("kb_id", "question")
def retrieval():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)

    req = request.json
    kb_ids = req.get("kb_id", [])
    doc_ids = req.get("doc_ids", [])
    question = req.get("question")
    page = int(req.get("page", 1))
    size = int(req.get("page_size", 30))
    similarity_threshold = float(req.get("similarity_threshold", 0.2))
    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
    top = int(req.get("top_k", 1024))
    highlight = bool(req.get("highlight", False))

    try:
        kbs = KnowledgebaseService.get_by_ids(kb_ids)
        embd_nms = list(set([kb.embd_id for kb in kbs]))
        if len(embd_nms) != 1:
            return get_json_result(
                data=False, message='Knowledge bases use different embedding models or does not exist."',
                code=RetCode.AUTHENTICATION_ERROR)

        embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id)
        rerank_mdl = None
        if req.get("rerank_id"):
            rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
        if req.get("keyword", False):
            chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
            question += keyword_extraction(chat_mdl, question)
        ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
                                             similarity_threshold, vector_similarity_weight, top,
                                             doc_ids, rerank_mdl=rerank_mdl, highlight=highlight,
                                             rank_feature=label_question(question, kbs))
        for c in ranks["chunks"]:
            c.pop("vector", None)
        return get_json_result(data=ranks)
    except Exception as e:
        if str(e).find("not_found") > 0:
            return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
                                   code=RetCode.DATA_ERROR)
        return server_error_response(e)

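For reference, the deleted `/completion` endpoint streamed answers as server-sent events, one `data: {json}` frame per chunk, terminated by a frame whose `data` is `true`. A minimal client sketch under assumptions (base URL, API key, and conversation id are hypothetical; requires the `requests` package):

```python
import json
import requests

# Hypothetical values; the payload shape follows the deleted handler above.
API_KEY = "ragflow-xxxx"
resp = requests.post(
    "http://localhost:9380/v1/api/completion",  # assumed legacy base URL
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"conversation_id": "conv-id",
          "messages": [{"role": "user", "content": "Hi"}],
          "stream": True},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data:"):
        continue
    payload = json.loads(line[len("data:"):])
    if payload["data"] is True:   # final frame, stream is done
        break
    print(payload["data"]["answer"])
```
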
@@ -426,7 +426,6 @@ def test_db_connect():
    try:
        import trino
        import os
        from trino.auth import BasicAuthentication
    except Exception as e:
        return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")

@@ -438,7 +437,7 @@ def test_db_connect():

     auth = None
     if http_scheme == "https" and req.get("password"):
-        auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
+        auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])

     conn = trino.dbapi.connect(
         host=req["host"],

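For context, a minimal sketch of the connection test this handler performs, using the trino-python-client package (host and credentials are placeholders):

```python
import trino
from trino.auth import BasicAuthentication

# Hypothetical connection parameters for illustration.
conn = trino.dbapi.connect(
    host="trino.example.com",
    port=443,
    user="ragflow",
    http_scheme="https",
    auth=BasicAuthentication("ragflow", "secret"),  # only meaningful over https
)
cur = conn.cursor()
cur.execute("SELECT 1")  # cheap round-trip to confirm connectivity
print(cur.fetchall())
```
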
@@ -471,8 +470,8 @@ def test_db_connect():
 @login_required
 def getlistversion(canvas_id):
     try:
-        list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
-        return get_json_result(data=list)
+        versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
+        return get_json_result(data=versions)
     except Exception as e:
         return get_data_error_result(message=f"Error getting history files: {e}")

@@ -55,7 +55,6 @@ def set_connector():
         "timeout_secs": int(req.get("timeout_secs", 60 * 29)),
         "status": TaskStatus.SCHEDULE,
     }
-    conn["status"] = TaskStatus.SCHEDULE
     ConnectorService.save(**conn)

     time.sleep(1)

@@ -85,7 +85,6 @@ def get():
     if not e:
         return get_data_error_result(message="Conversation not found!")
     tenants = UserTenantService.query(user_id=current_user.id)
-    avatar = None
     for tenant in tenants:
         dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
         if dialog and len(dialog) > 0:

@@ -154,15 +154,15 @@ def get_kb_names(kb_ids):
 @login_required
 def list_dialogs():
     try:
-        diags = DialogService.query(
+        conversations = DialogService.query(
             tenant_id=current_user.id,
             status=StatusEnum.VALID.value,
             reverse=True,
             order_by=DialogService.model.create_time)
-        diags = [d.to_dict() for d in diags]
-        for d in diags:
-            d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
-        return get_json_result(data=diags)
+        conversations = [d.to_dict() for d in conversations]
+        for conversation in conversations:
+            conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"])
+        return get_json_result(data=conversations)
     except Exception as e:
         return server_error_response(e)

@@ -308,7 +308,7 @@ def get_filter():

 @manager.route("/infos", methods=["POST"])  # noqa: F821
 @login_required
-def docinfos():
+def doc_infos():
     req = request.json
     doc_ids = req["doc_ids"]
     for doc_id in doc_ids:

@@ -544,6 +544,7 @@ def change_parser():
             return get_data_error_result(message="Tenant not found!")
         if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
             settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+        return None

     try:
         if "pipeline_id" in req and req["pipeline_id"] != "":

@@ -246,8 +246,8 @@ def rm():
         try:
             if file.location:
                 settings.STORAGE_IMPL.rm(file.parent_id, file.location)
-        except Exception:
-            logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
+        except Exception as e:
+            logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")

         informs = File2DocumentService.get_by_file_id(file.id)
         for inform in informs:

@@ -16,6 +16,7 @@
 import json
+import logging
 import random
 import re

 from flask import request
 from flask_login import login_required, current_user

@@ -731,6 +732,8 @@ def delete_kb_task():
     def cancel_task(task_id):
         REDIS_CONN.set(f"{task_id}-cancel", "x")

+        kb_task_id_field: str = ""
+        kb_task_finish_at: str = ""
         match pipeline_task_type:
             case PipelineTaskType.GRAPH_RAG:
                 kb_task_id_field = "graphrag_task_id"

@@ -807,7 +810,7 @@ def check_embedding():
             offset=0, limit=1,
             indexNames=index_nm, knowledgebaseIds=[kb_id]
         )
-        total = docStoreConn.getTotal(res0)
+        total = docStoreConn.get_total(res0)
         if total <= 0:
             return []

@@ -824,7 +827,7 @@ def check_embedding():
             offset=off, limit=1,
             indexNames=index_nm, knowledgebaseIds=[kb_id]
         )
-        ids = docStoreConn.getChunkIds(res1)
+        ids = docStoreConn.get_chunk_ids(res1)
         if not ids:
             continue

@@ -845,8 +848,13 @@ def check_embedding():
                 "position_int": full_doc.get("position_int"),
                 "top_int": full_doc.get("top_int"),
+                "content_with_weight": full_doc.get("content_with_weight") or "",
+                "question_kwd": full_doc.get("question_kwd") or []
             })
         return out

+    def _clean(s: str) -> str:
+        s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
+        return s if s else "None"
     req = request.json
     kb_id = req.get("kb_id", "")
     embd_id = req.get("embd_id", "")

@@ -859,8 +867,10 @@ def check_embedding():

     results, eff_sims = [], []
     for ck in samples:
-        txt = (ck.get("content_with_weight") or "").strip()
-        if not txt:
+        title = ck.get("doc_name") or "Title"
+        txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
+        txt_in = _clean(txt_in)
+        if not txt_in:
             results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
             continue

@@ -869,8 +879,16 @@ def check_embedding():
             continue

         try:
-            qv, _ = emb_mdl.encode_queries(txt)
-            sim = _cos_sim(qv, ck["vector"])
+            v, _ = emb_mdl.encode([title, txt_in])
+            sim_content = _cos_sim(v[1], ck["vector"])
+            title_w = 0.1
+            qv_mix = title_w * v[0] + (1 - title_w) * v[1]
+            sim_mix = _cos_sim(qv_mix, ck["vector"])
+            sim = sim_content
+            mode = "content_only"
+            if sim_mix > sim:
+                sim = sim_mix
+                mode = "title+content"
         except Exception:
             return get_error_data_result(message="embedding failure")

@@ -892,8 +910,9 @@ def check_embedding():
         "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
         "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
         "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
+        "match_mode": mode,
     }
-    if summary["avg_cos_sim"] > 0.99:
+    if summary["avg_cos_sim"] > 0.9:
         return get_json_result(data={"summary": summary, "results": results})
     return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})

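The title-weighted query above is a simple linear blend of two embedding vectors, and the handler keeps whichever similarity comes out higher. A standalone sketch with toy vectors (assumes numpy; the 0.1 title weight mirrors the handler):

```python
import numpy as np

def cos_sim(a, b):
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy vectors standing in for the title/content embeddings and the stored chunk vector.
title_vec = np.array([1.0, 0.0, 0.0])
content_vec = np.array([0.6, 0.8, 0.0])
chunk_vec = np.array([0.55, 0.8, 0.1])

title_w = 0.1
mixed = title_w * title_vec + (1 - title_w) * content_vec

# Keep whichever similarity is higher, as check_embedding() now does.
best = max(cos_sim(content_vec, chunk_vec), cos_sim(mixed, chunk_vec))
print(round(best, 6))
```
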
@@ -21,10 +21,11 @@ import json
 from flask import request
 from peewee import OperationalError
 from api.db.db_models import File
-from api.db.services.document_service import DocumentService
+from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
 from api.db.services.user_service import TenantService
 from common.constants import RetCode, FileSource, StatusEnum
 from api.utils.api_utils import (

@@ -118,7 +119,6 @@ def create(tenant_id):
    req, err = validate_and_parse_json_request(request, CreateDatasetReq)
    if err is not None:
        return get_error_argument_result(err)

    req = KnowledgebaseService.create_with_name(
        name = req.pop("name", None),
        tenant_id = tenant_id,

@@ -144,7 +144,6 @@ def create(tenant_id):
        ok, k = KnowledgebaseService.get_by_id(req["id"])
        if not ok:
            return get_error_data_result(message="Dataset created failed")

        response_data = remap_dictionary_keys(k.to_dict())
        return get_result(data=response_data)
    except Exception as e:

@@ -532,3 +531,157 @@ def delete_knowledge_graph(tenant_id, dataset_id):
                              search.index_name(kb.tenant_id), dataset_id)

    return get_result(data=True)


@manager.route("/datasets/<dataset_id>/run_graphrag", methods=["POST"])  # noqa: F821
@token_required
def run_graphrag(tenant_id,dataset_id):
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')
    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )

    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not ok:
        return get_error_data_result(message="Invalid Dataset ID")

    task_id = kb.graphrag_task_id
    if task_id:
        ok, task = TaskService.get_by_id(task_id)
        if not ok:
            logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}")

        if task and task.progress not in [-1, 1]:
            return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")

    documents, _ = DocumentService.get_by_kb_id(
        kb_id=dataset_id,
        page_number=0,
        items_per_page=0,
        orderby="create_time",
        desc=False,
        keywords="",
        run_status=[],
        types=[],
        suffix=[],
    )
    if not documents:
        return get_error_data_result(message=f"No documents in Dataset {dataset_id}")

    sample_document = documents[0]
    document_ids = [document["id"] for document in documents]

    task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))

    if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
        logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}")

    return get_result(data={"graphrag_task_id": task_id})


@manager.route("/datasets/<dataset_id>/trace_graphrag", methods=["GET"])  # noqa: F821
@token_required
def trace_graphrag(tenant_id,dataset_id):
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')
    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )

    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not ok:
        return get_error_data_result(message="Invalid Dataset ID")

    task_id = kb.graphrag_task_id
    if not task_id:
        return get_result(data={})

    ok, task = TaskService.get_by_id(task_id)
    if not ok:
        return get_result(data={})

    return get_result(data=task.to_dict())


@manager.route("/datasets/<dataset_id>/run_raptor", methods=["POST"])  # noqa: F821
@token_required
def run_raptor(tenant_id,dataset_id):
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')
    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )

    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not ok:
        return get_error_data_result(message="Invalid Dataset ID")

    task_id = kb.raptor_task_id
    if task_id:
        ok, task = TaskService.get_by_id(task_id)
        if not ok:
            logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}")

        if task and task.progress not in [-1, 1]:
            return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")

    documents, _ = DocumentService.get_by_kb_id(
        kb_id=dataset_id,
        page_number=0,
        items_per_page=0,
        orderby="create_time",
        desc=False,
        keywords="",
        run_status=[],
        types=[],
        suffix=[],
    )
    if not documents:
        return get_error_data_result(message=f"No documents in Dataset {dataset_id}")

    sample_document = documents[0]
    document_ids = [document["id"] for document in documents]

    task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))

    if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
        logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}")

    return get_result(data={"raptor_task_id": task_id})


@manager.route("/datasets/<dataset_id>/trace_raptor", methods=["GET"])  # noqa: F821
@token_required
def trace_raptor(tenant_id,dataset_id):
    if not dataset_id:
        return get_error_data_result(message='Lack of "Dataset ID"')

    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
        return get_result(
            data=False,
            message='No authorization.',
            code=RetCode.AUTHENTICATION_ERROR
        )
    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
    if not ok:
        return get_error_data_result(message="Invalid Dataset ID")

    task_id = kb.raptor_task_id
    if not task_id:
        return get_result(data={})

    ok, task = TaskService.get_by_id(task_id)
    if not ok:
        return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")

    return get_result(data=task.to_dict())

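A client-side sketch of the new run/trace pair (host, API key, and dataset id are placeholders; the route paths come from the decorators above, and the base prefix is assumed to match the token-authenticated SDK API):

```python
import requests

# Hypothetical values for illustration.
BASE = "http://localhost:9380/api/v1"
HEADERS = {"Authorization": "Bearer <API_KEY>"}
DATASET_ID = "<dataset-id>"

# Kick off GraphRAG over every document in the dataset.
r = requests.post(f"{BASE}/datasets/{DATASET_ID}/run_graphrag", headers=HEADERS)
task_id = r.json()["data"]["graphrag_task_id"]

# Poll the task; per the handlers above, progress 1 means done and -1 means failed.
status = requests.get(f"{BASE}/datasets/{DATASET_ID}/trace_graphrag", headers=HEADERS).json()
print(task_id, status["data"].get("progress"))
```
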
@@ -93,6 +93,10 @@ def upload(dataset_id, tenant_id):
             type: file
             required: true
             description: Document files to upload.
+          - in: formData
+            name: parent_path
+            type: string
+            description: Optional nested path under the parent folder. Uses '/' separators.
       responses:
         200:
           description: Successfully uploaded documents.

@@ -151,7 +155,7 @@ def upload(dataset_id, tenant_id):
     e, kb = KnowledgebaseService.get_by_id(dataset_id)
     if not e:
         raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
-    err, files = FileService.upload_document(kb, file_objs, tenant_id)
+    err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=request.form.get("parent_path"))
     if err:
         return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
     # rename key's name

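A usage sketch for the new `parent_path` form field (host, API key, dataset id, and exact route prefix are assumptions; the field name follows the swagger block above):

```python
import requests

# Hypothetical values for illustration.
BASE = "http://localhost:9380/api/v1"
HEADERS = {"Authorization": "Bearer <API_KEY>"}

with open("report.pdf", "rb") as fh:
    r = requests.post(
        f"{BASE}/datasets/<dataset-id>/documents",  # assumed upload route
        headers=HEADERS,
        files={"file": fh},
        data={"parent_path": "projects/2024/q3"},  # nested folder, '/'-separated
    )
print(r.json())
```
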
@@ -242,7 +242,7 @@ class Connector2KbService(CommonService):
             "id": get_uuid(),
             "connector_id": conn_id,
             "kb_id": kb_id,
-            "auto_parse": conn.get("auto_parse", "1")
+            "auto_parse": conn.get("auto_parse", "1")
         })
         SyncLogsService.schedule(conn_id, kb_id, reindex=True)

@@ -309,7 +309,7 @@ class DocumentService(CommonService):
             chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
                                                   page * page_size, page_size, search.index_name(tenant_id),
                                                   [doc.kb_id])
-            chunk_ids = settings.docStoreConn.getChunkIds(chunks)
+            chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
             if not chunk_ids:
                 break
             all_chunk_ids.extend(chunk_ids)

@@ -322,7 +322,7 @@ class DocumentService(CommonService):
         settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
         settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)

-        graph_source = settings.docStoreConn.getFields(
+        graph_source = settings.docStoreConn.get_fields(
             settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
         )
         if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:

@@ -31,7 +31,7 @@ from common.misc_utils import get_uuid
 from common.constants import TaskStatus, FileSource, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import TaskService
-from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
+from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path
 from rag.llm.cv_model import GptV4
 from common import settings

@@ -329,7 +329,7 @@ class FileService(CommonService):
         current_id = start_id
         while current_id:
             e, file = cls.get_by_id(current_id)
-            if file.parent_id != file.id and e:
+            if e and file.parent_id != file.id:
                 parent_folders.append(file)
                 current_id = file.parent_id
             else:

@@ -423,13 +423,15 @@ class FileService(CommonService):

     @classmethod
     @DB.connection_context()
-    def upload_document(self, kb, file_objs, user_id, src="local"):
+    def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str | None = None):
         root_folder = self.get_root_folder(user_id)
         pf_id = root_folder["id"]
         self.init_knowledgebase_docs(pf_id, user_id)
         kb_root_folder = self.get_kb_folder(user_id)
         kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])

+        safe_parent_path = sanitize_path(parent_path)
+
         err, files = [], []
         for file in file_objs:
             try:

@@ -439,7 +441,7 @@ class FileService(CommonService):
                 if filetype == FileType.OTHER.value:
                     raise RuntimeError("This type of file has not been supported yet!")

-                location = filename
+                location = filename if not safe_parent_path else f"{safe_parent_path}/{filename}"
                 while settings.STORAGE_IMPL.obj_exist(kb.id, location):
                     location += "_"

@ -164,3 +164,23 @@ def read_potential_broken_pdf(blob):
            return repaired

    return blob


def sanitize_path(raw_path: str | None) -> str:
    """Normalize and sanitize a user-provided path segment.

    - Converts backslashes to forward slashes
    - Strips leading/trailing slashes
    - Removes '.' and '..' segments
    - Restricts characters to A-Za-z0-9, underscore, dash, and '/'
    """
    if not raw_path:
        return ""
    backslash_re = re.compile(r"[\\]+")
    unsafe_re = re.compile(r"[^A-Za-z0-9_\-/]")
    normalized = backslash_re.sub("/", raw_path)
    normalized = normalized.strip("/")
    parts = [seg for seg in normalized.split("/") if seg and seg not in (".", "..")]
    sanitized = "/".join(parts)
    sanitized = unsafe_re.sub("", sanitized)
    return sanitized
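To make the sanitization concrete, here is a minimal sketch of its behavior (inputs invented for illustration; only the `sanitize_path` defined above is assumed):

```python
# Hypothetical inputs; expected outputs follow from the regexes above.
assert sanitize_path(None) == ""
assert sanitize_path("docs/sub") == "docs/sub"
assert sanitize_path("..\\..\\etc/passwd") == "etc/passwd"  # traversal segments dropped
assert sanitize_path("/a/./b/") == "a/b"                    # '.' segments and edge slashes removed
assert sanitize_path("we!rd na me") == "werdname"           # unsafe characters stripped
```

The upload path in `FileService.upload_document` then prefixes object locations with this sanitized segment, so a hostile `parent_path` cannot escape the bucket namespace.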
36 check_comment_ascii.py Normal file
@ -0,0 +1,36 @@
#!/usr/bin/env python3
import sys
import tokenize
import ast
import pathlib
import re

ASCII = re.compile(r"^[ -~]*\Z")  # Only printable ASCII


def check(src: str, name: str) -> int:
    """
    I'm a docstring
    """
    ok = 1
    # A common comment begins with `#`
    with tokenize.open(src) as fp:
        for tk in tokenize.generate_tokens(fp.readline):
            if tk.type == tokenize.COMMENT and not ASCII.fullmatch(tk.string):
                print(f"{name}:{tk.start[0]}: non-ASCII comment: {tk.string}")
                ok = 0
    # A docstring begins and ends with `'''`
    for node in ast.walk(ast.parse(pathlib.Path(src).read_text(), filename=name)):
        if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
            if (doc := ast.get_docstring(node)) and not ASCII.fullmatch(doc):
                print(f"{name}:{node.lineno}: non-ASCII docstring: {doc}")
                ok = 0
    return ok


if __name__ == "__main__":
    status = 0
    for file in sys.argv[1:]:
        if not check(file, file):
            status = 1
    sys.exit(status)
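As a quick illustration of how this gate behaves (a sketch — it assumes you run it from the repository root, where `check_comment_ascii.py` lives):

```python
import pathlib
import subprocess
import sys
import tempfile

# Write a throwaway module containing a non-ASCII comment, then run the checker on it.
with tempfile.TemporaryDirectory() as tmp:
    bad = pathlib.Path(tmp) / "bad.py"
    bad.write_text("x = 1  # 中文注释\n", encoding="utf-8")
    proc = subprocess.run([sys.executable, "check_comment_ascii.py", str(bad)])
    print(proc.returncode)  # expected: 1 (the non-ASCII comment is reported)
```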
@ -3,15 +3,9 @@ import os
import threading
from typing import Any, Callable

import requests

from common.data_source.config import DocumentSource
from common.data_source.google_util.constant import GOOGLE_SCOPES

GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
DEFAULT_DEVICE_INTERVAL = 5


def _get_requested_scopes(source: DocumentSource) -> list[str]:
    """Return the scopes to request, honoring an optional override env var."""
@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag
    return result.get("value")


def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
    if "client_id" in credentials:
        return credentials["client_id"], credentials.get("client_secret")
    for key in ("installed", "web"):
        if key in credentials and isinstance(credentials[key], dict):
            nested = credentials[key]
            if "client_id" not in nested:
                break
            return nested["client_id"], nested.get("client_secret")
    raise ValueError("Provided Google OAuth credentials are missing client_id.")


def start_device_authorization_flow(
    credentials: dict[str, Any],
    source: DocumentSource,
) -> tuple[dict[str, Any], dict[str, Any]]:
    client_id, client_secret = _extract_client_info(credentials)
    data = {
        "client_id": client_id,
        "scope": " ".join(_get_requested_scopes(source)),
    }
    if client_secret:
        data["client_secret"] = client_secret
    resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
    resp.raise_for_status()
    payload = resp.json()
    state = {
        "client_id": client_id,
        "client_secret": client_secret,
        "device_code": payload.get("device_code"),
        "interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
    }
    response_data = {
        "user_code": payload.get("user_code"),
        "verification_url": payload.get("verification_url") or payload.get("verification_uri"),
        "verification_url_complete": payload.get("verification_url_complete")
        or payload.get("verification_uri_complete"),
        "expires_in": payload.get("expires_in"),
        "interval": state["interval"],
    }
    return state, response_data


def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
    data = {
        "client_id": state["client_id"],
        "device_code": state["device_code"],
        "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
    }
    if state.get("client_secret"):
        data["client_secret"] = state["client_secret"]
    resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
    resp.raise_for_status()
    return resp.json()


def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
    """Launch the standard Google OAuth local-server flow to mint user tokens."""
    from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
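For context, the device-code helpers removed in this hunk were meant to be driven roughly as follows — a hedged sketch, not the connector's actual call site (`creds` and the `DocumentSource.GOOGLE_DRIVE` member are assumptions for illustration):

```python
import time
import requests

# creds: a Google client-secret dict, e.g. {"installed": {"client_id": "...", "client_secret": "..."}}
state, info = start_device_authorization_flow(creds, DocumentSource.GOOGLE_DRIVE)
print(f"Visit {info['verification_url']} and enter code {info['user_code']}")

while True:
    time.sleep(state["interval"])  # Google requires waiting between polls
    try:
        token = poll_device_authorization_flow(state)  # raises until the user approves
        break
    except requests.HTTPError as exc:
        # authorization_pending / slow_down come back as HTTP 4xx until approval
        if exc.response.json().get("error") not in ("authorization_pending", "slow_down"):
            raise
```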
@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
    preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
    port = int(preferred_port) if preferred_port else 0
    timeout_secs = _get_oauth_timeout_secs()
    timeout_message = (
        f"Google OAuth verification timed out after {timeout_secs} seconds. "
        "Close any pending consent windows and rerun the connector configuration to try again."
    )
    timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."

    print("Launching Google OAuth flow. A browser window should open shortly.")
    print("If it does not, copy the URL shown in the console into your browser manually.")
@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
        instructions = [
            "Google rejected one or more of the requested OAuth scopes.",
            "Fix options:",
            " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
            "    (Drive metadata + Admin Directory read scopes), then re-run the flow.",
            " 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
            " 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
            " 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
            "    (be aware the connector may lose functionality).",
        ]
        raise RuntimeError("\n".join(instructions)) from warning
        raise
@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
        client_config = {"web": credentials["web"]}

    if client_config is None:
        raise ValueError(
            "Provided Google OAuth credentials are missing both tokens and a client configuration."
        )
        raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")

    return _run_local_server_flow(client_config, source)
@ -186,9 +186,6 @@ class DoclingParser(RAGFlowPdfParser):
                yield (DoclingContentType.EQUATION.value, text, bbox)

    def _transfer_to_sections(self, doc) -> list[tuple[str, str]]:
        """
        Consistent with MinerUParser: returns [(section_text, line_tag), ...]
        """
        sections: list[tuple[str, str]] = []
        for typ, payload, bbox in self._iter_doc_items(doc):
            if typ == DoclingContentType.TEXT.value:
@ -34,6 +34,7 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
        if isinstance(figure_data[1], Image.Image)
    ]


def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
    try:
        vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
@ -50,7 +51,8 @@ def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
        callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
        return tbls

def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs):

def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
    try:
        vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
        callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
@ -72,6 +74,7 @@ def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs):
        callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
        return tbls


shared_executor = ThreadPoolExecutor(max_workers=10)
@ -117,7 +117,6 @@ def load_model(model_dir, nm, device_id: int | None = None):
            providers=['CUDAExecutionProvider'],
            provider_options=[cuda_provider_options]
        )
        run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(provider_device_id))
        logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})")
    else:
        sess = ort.InferenceSession(
8 docs/guides/dataset/add_data_source/_category_.json Normal file
@ -0,0 +1,8 @@
{
  "label": "Add data source",
  "position": 18,
  "link": {
    "type": "generated-index",
    "description": "Add various data sources"
  }
}
137 docs/guides/dataset/add_data_source/add_google_drive.md Normal file
@ -0,0 +1,137 @@
---
sidebar_position: 3
slug: /add_google_drive
---

# Add Google Drive

## 1. Create a Google Cloud Project

You can either create a dedicated project for RAGFlow or use an existing Google Cloud external project.

**Steps:**

1. Open the project creation page\
   `https://console.cloud.google.com/projectcreate`
   
2. Select **External** as the Audience
   
3. Click **Create**
   

------------------------------------------------------------------------

## 2. Configure OAuth Consent Screen

1. Go to **APIs & Services → OAuth consent screen**
2. Ensure **User Type = External**
   
3. Add your test users under **Test Users** by entering their email addresses
   
   

------------------------------------------------------------------------

## 3. Create OAuth Client Credentials

1. Navigate to:\
   `https://console.cloud.google.com/auth/clients`
2. Create a **Web Application**
   
3. Enter a name for the client
4. Add the following **Authorized Redirect URIs**:

   ```
   http://localhost:9380/v1/connector/google-drive/oauth/web/callback
   ```

### If using Docker deployment:

**Authorized JavaScript origin:**

```
http://localhost:80
```



### If running from source:

**Authorized JavaScript origin:**

```
http://localhost:9222
```



5. After saving, click **Download JSON**. This file will later be uploaded into RAGFlow.



------------------------------------------------------------------------

## 4. Add Scopes

1. Open **Data Access → Add or remove scopes**

2. Paste and add the following entries:

   ```
   https://www.googleapis.com/auth/drive.readonly
   https://www.googleapis.com/auth/drive.metadata.readonly
   https://www.googleapis.com/auth/admin.directory.group.readonly
   https://www.googleapis.com/auth/admin.directory.user.readonly
   ```

   

3. Update and save the changes

   
   

------------------------------------------------------------------------

## 5. Enable Required APIs

Navigate to the Google API Library:\
`https://console.cloud.google.com/apis/library`



Enable the following APIs:

- Google Drive API
- Admin SDK API
- Google Sheets API
- Google Docs API













------------------------------------------------------------------------

## 6. Add Google Drive As a Data Source in RAGFlow

1. Go to **Data Sources** inside RAGFlow
2. Select **Google Drive**
3. Upload the previously downloaded JSON credentials
   
4. Enter the shared Google Drive folder link (https://drive.google.com/drive), for example:
   

5. Click **Authorize with Google**. A browser window will appear.
   

   In that window:

   - Click **Continue**
   - Click **Select All → Continue**
   - Authorization should succeed
   - Select **OK** to add the data source

   
   
   
   
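Before uploading the JSON to RAGFlow, it can help to sanity-check it locally. A minimal sketch, assuming Google's standard client-secret layout (a top-level `"web"` object holding `client_id`, `client_secret`, and `redirect_uris`):

```python
import json

# Adjust the path to wherever the file from step 3.5 was downloaded.
with open("client_secret.json", encoding="utf-8") as fp:
    creds = json.load(fp)

web = creds.get("web", {})
print("client_id:", web.get("client_id"))
print("redirect_uris:", web.get("redirect_uris"))
# The registered redirect URI must match RAGFlow's callback:
#   http://localhost:9380/v1/connector/google-drive/oauth/web/callback
```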
@ -1,6 +1,6 @@
{
  "label": "Best practices",
  "position": 11,
  "position": 19,
  "link": {
    "type": "generated-index",
    "description": "Best practices on configuring a dataset."
@ -64,7 +64,10 @@ The Admin CLI and Admin Service form a client-server architectural suite for RAG

- -p: RAGFlow admin server port

## Default administrative account

- Username: admin@ragflow.io
- Password: admin

## Supported Commands

@ -974,6 +974,237 @@ Failure:

---

### Construct knowledge graph

**POST** `/api/v1/datasets/{dataset_id}/run_graphrag`

Constructs a knowledge graph from a specified dataset.

#### Request

- Method: POST
- URL: `/api/v1/datasets/{dataset_id}/run_graphrag`
- Headers:
  - `'Authorization: Bearer <YOUR_API_KEY>'`

##### Request example

```bash
curl --request POST \
     --url http://{address}/api/v1/datasets/{dataset_id}/run_graphrag \
     --header 'Authorization: Bearer <YOUR_API_KEY>'
```

##### Request parameters

- `dataset_id`: (*Path parameter*)
  The ID of the target dataset.

#### Response

Success:

```json
{
    "code":0,
    "data":{
        "graphrag_task_id":"e498de54bfbb11f0ba028f704583b57b"
    }
}
```

Failure:

```json
{
    "code": 102,
    "message": "Invalid Dataset ID"
}
```

---

### Get knowledge graph construction status

**GET** `/api/v1/datasets/{dataset_id}/trace_graphrag`

Retrieves the knowledge graph construction status for a specified dataset.

#### Request

- Method: GET
- URL: `/api/v1/datasets/{dataset_id}/trace_graphrag`
- Headers:
  - `'Authorization: Bearer <YOUR_API_KEY>'`

##### Request example

```bash
curl --request GET \
     --url http://{address}/api/v1/datasets/{dataset_id}/trace_graphrag \
     --header 'Authorization: Bearer <YOUR_API_KEY>'
```

##### Request parameters

- `dataset_id`: (*Path parameter*)
  The ID of the target dataset.

#### Response

Success:

```json
{
    "code":0,
    "data":{
        "begin_at":"Wed, 12 Nov 2025 19:36:56 GMT",
        "chunk_ids":"",
        "create_date":"Wed, 12 Nov 2025 19:36:56 GMT",
        "create_time":1762947416350,
        "digest":"39e43572e3dcd84f",
        "doc_id":"44661c10bde211f0bc93c164a47ffc40",
        "from_page":100000000,
        "id":"e498de54bfbb11f0ba028f704583b57b",
        "priority":0,
        "process_duration":2.45419,
        "progress":1.0,
        "progress_msg":"19:36:56 created task graphrag\n19:36:57 Task has been received.\n19:36:58 [GraphRAG] doc:083661febe2411f0bc79456921e5745f has no available chunks, skip generation.\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 start (chunks=1, timeout=10000000000s)\n19:36:58 Graph already contains 44661c10bde211f0bc93c164a47ffc40\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 empty\n19:36:58 [GraphRAG] kb:33137ed0bde211f0bc93c164a47ffc40 no subgraphs generated successfully, end.\n19:36:58 Knowledge Graph done (0.72s)",
        "retry_count":1,
        "task_type":"graphrag",
        "to_page":100000000,
        "update_date":"Wed, 12 Nov 2025 19:36:58 GMT",
        "update_time":1762947418454
    }
}
```

Failure:

```json
{
    "code": 102,
    "message": "Invalid Dataset ID"
}
```

---
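A minimal Python sketch of the intended call pattern — start construction, then poll the trace endpoint until `progress` reaches 1.0 (the address, dataset ID, and API key are placeholders):

```python
import time
import requests

BASE = "http://localhost:9380/api/v1"  # placeholder address
HEADERS = {"Authorization": "Bearer <YOUR_API_KEY>"}
dataset_id = "<DATASET_ID>"

# Kick off knowledge-graph construction.
r = requests.post(f"{BASE}/datasets/{dataset_id}/run_graphrag", headers=HEADERS)
r.raise_for_status()
print("started task:", r.json()["data"]["graphrag_task_id"])

# Poll the construction status until the task finishes.
while True:
    data = requests.get(f"{BASE}/datasets/{dataset_id}/trace_graphrag", headers=HEADERS).json()["data"]
    if data.get("progress", 0.0) >= 1.0:
        break
    time.sleep(5)
print(data["progress_msg"])
```

The same pattern applies to RAPTOR via `run_raptor` and `trace_raptor` below.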

### Construct RAPTOR

**POST** `/api/v1/datasets/{dataset_id}/run_raptor`

Constructs a RAPTOR tree from a specified dataset.

#### Request

- Method: POST
- URL: `/api/v1/datasets/{dataset_id}/run_raptor`
- Headers:
  - `'Authorization: Bearer <YOUR_API_KEY>'`

##### Request example

```bash
curl --request POST \
     --url http://{address}/api/v1/datasets/{dataset_id}/run_raptor \
     --header 'Authorization: Bearer <YOUR_API_KEY>'
```

##### Request parameters

- `dataset_id`: (*Path parameter*)
  The ID of the target dataset.

#### Response

Success:

```json
{
    "code":0,
    "data":{
        "raptor_task_id":"50d3c31cbfbd11f0ba028f704583b57b"
    }
}
```

Failure:

```json
{
    "code": 102,
    "message": "Invalid Dataset ID"
}
```

---

### Get RAPTOR construction status

**GET** `/api/v1/datasets/{dataset_id}/trace_raptor`

Retrieves the RAPTOR construction status for a specified dataset.

#### Request

- Method: GET
- URL: `/api/v1/datasets/{dataset_id}/trace_raptor`
- Headers:
  - `'Authorization: Bearer <YOUR_API_KEY>'`

##### Request example

```bash
curl --request GET \
     --url http://{address}/api/v1/datasets/{dataset_id}/trace_raptor \
     --header 'Authorization: Bearer <YOUR_API_KEY>'
```

##### Request parameters

- `dataset_id`: (*Path parameter*)
  The ID of the target dataset.

#### Response

Success:

```json
{
    "code":0,
    "data":{
        "begin_at":"Wed, 12 Nov 2025 19:47:07 GMT",
        "chunk_ids":"",
        "create_date":"Wed, 12 Nov 2025 19:47:07 GMT",
        "create_time":1762948027427,
        "digest":"8b279a6248cb8fc6",
        "doc_id":"44661c10bde211f0bc93c164a47ffc40",
        "from_page":100000000,
        "id":"50d3c31cbfbd11f0ba028f704583b57b",
        "priority":0,
        "process_duration":0.948244,
        "progress":1.0,
        "progress_msg":"19:47:07 created task raptor\n19:47:07 Task has been received.\n19:47:07 Processing...\n19:47:07 Processing...\n19:47:07 Indexing done (0.01s).\n19:47:07 Task done (0.29s)",
        "retry_count":1,
        "task_type":"raptor",
        "to_page":100000000,
        "update_date":"Wed, 12 Nov 2025 19:47:07 GMT",
        "update_time":1762948027948
    }
}
```

Failure:

```json
{
    "code": 102,
    "message": "Invalid Dataset ID"
}
```

---

## FILE MANAGEMENT WITHIN DATASET

---
@ -114,7 +114,7 @@ class Extractor:
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
    out_results = []
    error_count = 0
    max_errors = 3
    max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))

    limiter = trio.Semaphore(max_concurrency)

@ -69,7 +69,7 @@ class KGSearch(Dealer):
    def _ent_info_from_(self, es_res, sim_thr=0.3):
        res = {}
        flds = ["content_with_weight", "_score", "entity_kwd", "rank_flt", "n_hop_with_weight"]
        es_res = self.dataStore.getFields(es_res, flds)
        es_res = self.dataStore.get_fields(es_res, flds)
        for _, ent in es_res.items():
            for f in flds:
                if f in ent and ent[f] is None:
@ -88,7 +88,7 @@ class KGSearch(Dealer):

    def _relation_info_from_(self, es_res, sim_thr=0.3):
        res = {}
        es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
        es_res = self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
                                                   "weight_int"])
        for _, ent in es_res.items():
            if get_float(ent["_score"]) < sim_thr:
@ -300,7 +300,7 @@ class KGSearch(Dealer):
        fltr["entities_kwd"] = entities
        comm_res = self.dataStore.search(fields, [], fltr, [],
                                         OrderByExpr(), 0, topn, idxnms, kb_ids)
        comm_res_fields = self.dataStore.getFields(comm_res, fields)
        comm_res_fields = self.dataStore.get_fields(comm_res, fields)
        txts = []
        for ii, (_, row) in enumerate(comm_res_fields.items()):
            obj = json.loads(row["content_with_weight"])
@ -382,7 +382,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
        "removed_kwd": "N",
    }
    res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
    fields2 = settings.docStoreConn.getFields(res, fields)
    fields2 = settings.docStoreConn.get_fields(res, fields)
    graph_doc_ids = set()
    for chunk_id in fields2.keys():
        graph_doc_ids = set(fields2[chunk_id]["source_id"])
@ -591,8 +591,8 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
        es_res = await trio.to_thread.run_sync(
            lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
        )
        # tot = settings.docStoreConn.getTotal(es_res)
        es_res = settings.docStoreConn.getFields(es_res, flds)
        # tot = settings.docStoreConn.get_total(es_res)
        es_res = settings.docStoreConn.get_fields(es_res, flds)

        if len(es_res) == 0:
            break
@ -482,7 +482,7 @@ def tree_merge(bull, sections, depth):
    root = Node(level=0, depth=target_level, texts=[])
    root.build_tree(lines)

    return [("\n").join(element) for element in root.get_tree() if element]
    return [element for element in root.get_tree() if element]

def hierarchical_merge(bull, sections, depth):

@ -38,11 +38,11 @@ class FulltextQueryer:
    ]

    @staticmethod
    def subSpecialChar(line):
    def sub_special_char(line):
        return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()

    @staticmethod
    def isChinese(line):
    def is_chinese(line):
        arr = re.split(r"[ \t]+", line)
        if len(arr) <= 3:
            return True
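To illustrate what the renamed helper does (output follows directly from the regex above):

```python
# Escapes full-text query metacharacters so user text is matched literally.
print(FulltextQueryer.sub_special_char("C++ (v2) [beta]"))
# -> C\+\+ \(v2\) \[beta\]
```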
@ -92,7 +92,7 @@ class FulltextQueryer:
        otxt = txt
        txt = FulltextQueryer.rmWWW(txt)

        if not self.isChinese(txt):
        if not self.is_chinese(txt):
            txt = FulltextQueryer.rmWWW(txt)
            tks = rag_tokenizer.tokenize(txt).split()
            keywords = [t for t in tks if t]
@ -163,7 +163,7 @@ class FulltextQueryer:
                )
                for m in sm
            ]
            sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
            sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1]
            sm = [m for m in sm if len(m) > 1]

            if len(keywords) < 32:
@ -171,7 +171,7 @@ class FulltextQueryer:
                keywords.extend(sm)

            tk_syns = self.syn.lookup(tk)
            tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
            tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
            if len(keywords) < 32:
                keywords.extend([s for s in tk_syns if s])
            tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
@ -180,7 +180,7 @@ class FulltextQueryer:
            if len(keywords) >= 32:
                break

            tk = FulltextQueryer.subSpecialChar(tk)
            tk = FulltextQueryer.sub_special_char(tk)
            if tk.find(" ") > 0:
                tk = '"%s"' % tk
            if tk_syns:
@ -198,7 +198,7 @@ class FulltextQueryer:
                syns = " OR ".join(
                    [
                        '"%s"'
                        % rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
                        % rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s))
                        for s in syns
                    ]
                )
@ -217,17 +217,17 @@ class FulltextQueryer:
        return None, keywords

    def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
        from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
        from sklearn.metrics.pairwise import cosine_similarity
        import numpy as np

        sims = CosineSimilarity([avec], bvecs)
        sims = cosine_similarity([avec], bvecs)
        tksim = self.token_similarity(atks, btkss)
        if np.sum(sims[0]) == 0:
            return np.array(tksim), tksim, sims[0]
        return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]

    def token_similarity(self, atks, btkss):
        def toDict(tks):
        def to_dict(tks):
            if isinstance(tks, str):
                tks = tks.split()
            d = defaultdict(int)
@ -236,8 +236,8 @@ class FulltextQueryer:
                d[t] += c
            return d

        atks = toDict(atks)
        btkss = [toDict(tks) for tks in btkss]
        atks = to_dict(atks)
        btkss = [to_dict(tks) for tks in btkss]
        return [self.similarity(atks, btks) for btks in btkss]

    def similarity(self, qtwt, dtwt):
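A standalone sketch of the blend `hybrid_similarity` computes — embedding cosine similarity weighted against token overlap (toy vectors and scores invented for illustration):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

tkweight, vtweight = 0.3, 0.7        # the method's default weights
avec = np.array([0.1, 0.9, 0.0])     # toy query embedding
bvecs = np.array([[0.1, 0.8, 0.1],   # toy chunk embeddings
                  [0.9, 0.0, 0.1]])
tksim = np.array([0.5, 0.0])         # toy token-overlap scores

cos = cosine_similarity([avec], bvecs)[0]
hybrid = vtweight * cos + tkweight * tksim
print(hybrid)  # chunk 0 scores higher on both signals
```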
@ -262,10 +262,10 @@ class FulltextQueryer:
        keywords = [f'"{k.strip()}"' for k in keywords]
        for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
            tk_syns = self.syn.lookup(tk)
            tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
            tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
            tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
            tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
            tk = FulltextQueryer.subSpecialChar(tk)
            tk = FulltextQueryer.sub_special_char(tk)
            if tk.find(" ") > 0:
                tk = '"%s"' % tk
            if tk_syns:
@ -35,7 +35,7 @@ class RagTokenizer:
    def rkey_(self, line):
        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]

    def loadDict_(self, fnm):
    def _load_dict(self, fnm):
        logging.info(f"[HUQIE]:Build trie from {fnm}")
        try:
            of = open(fnm, "r", encoding='utf-8')
@ -85,18 +85,18 @@ class RagTokenizer:
            self.trie_ = datrie.Trie(string.printable)

        # load data from dict file and save to trie file
        self.loadDict_(self.DIR_ + ".txt")
        self._load_dict(self.DIR_ + ".txt")

    def loadUserDict(self, fnm):
    def load_user_dict(self, fnm):
        try:
            self.trie_ = datrie.Trie.load(fnm + ".trie")
            return
        except Exception:
            self.trie_ = datrie.Trie(string.printable)
            self.loadDict_(fnm)
            self._load_dict(fnm)

    def addUserDict(self, fnm):
        self.loadDict_(fnm)
    def add_user_dict(self, fnm):
        self._load_dict(fnm)

    def _strQ2B(self, ustring):
        """Convert full-width characters to half-width characters"""
@ -221,7 +221,7 @@ class RagTokenizer:
        logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
        return tks, B / len(tks) + L + F

    def sortTks_(self, tkslist):
    def _sort_tokens(self, tkslist):
        res = []
        for tfts in tkslist:
            tks, s = self.score_(tfts)
@ -246,7 +246,7 @@ class RagTokenizer:

        return " ".join(res)

    def maxForward_(self, line):
    def _max_forward(self, line):
        res = []
        s = 0
        while s < len(line):
@ -270,7 +270,7 @@ class RagTokenizer:

        return self.score_(res)

    def maxBackward_(self, line):
    def _max_backward(self, line):
        res = []
        s = len(line) - 1
        while s >= 0:
@ -336,8 +336,8 @@ class RagTokenizer:
                continue

            # use maxforward for the first time
            tks, s = self.maxForward_(L)
            tks1, s1 = self.maxBackward_(L)
            tks, s = self._max_forward(L)
            tks1, s1 = self._max_backward(L)
            if self.DEBUG:
                logging.debug("[FW] {} {}".format(tks, s))
                logging.debug("[BW] {} {}".format(tks1, s1))
@ -369,7 +369,7 @@ class RagTokenizer:
                # backward tokens from_i to i are different from forward tokens from _j to j.
                tkslist = []
                self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))
                res.append(" ".join(self._sort_tokens(tkslist)[0][0]))

                same = 1
                while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
@ -385,7 +385,7 @@ class RagTokenizer:
            assert "".join(tks1[_i:]) == "".join(tks[_j:])
            tkslist = []
            self.dfs_("".join(tks[_j:]), 0, [], tkslist)
            res.append(" ".join(self.sortTks_(tkslist)[0][0]))
            res.append(" ".join(self._sort_tokens(tkslist)[0][0]))

        res = " ".join(res)
        logging.debug("[TKS] {}".format(self.merge_(res)))
@ -413,7 +413,7 @@ class RagTokenizer:
            if len(tkslist) < 2:
                res.append(tk)
                continue
            stk = self.sortTks_(tkslist)[1][0]
            stk = self._sort_tokens(tkslist)[1][0]
            if len(stk) == len(tk):
                stk = tk
            else:
@ -447,14 +447,13 @@ def is_number(s):


def is_alphabet(s):
    if (s >= u'\u0041' and s <= u'\u005a') or (
            s >= u'\u0061' and s <= u'\u007a'):
    if (u'\u0041' <= s <= u'\u005a') or (u'\u0061' <= s <= u'\u007a'):
        return True
    else:
        return False


def naiveQie(txt):
def naive_qie(txt):
    tks = []
    for t in txt.split():
        if tks and re.match(r".*[a-zA-Z]$", tks[-1]
@ -469,14 +468,14 @@ tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
load_user_dict = tokenizer.load_user_dict
add_user_dict = tokenizer.add_user_dict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B

if __name__ == '__main__':
    tknzr = RagTokenizer(debug=True)
    # huqie.addUserDict("/tmp/tmp.new.tks.dict")
    # huqie.add_user_dict("/tmp/tmp.new.tks.dict")
    tks = tknzr.tokenize(
        "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
    logging.info(tknzr.fine_grained_tokenize(tks))
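A short sketch of the renamed module-level API (the import path and dictionary file are assumptions for illustration):

```python
from rag.nlp import rag_tokenizer

# Optionally extend the built-in trie with a custom dictionary (path is hypothetical).
rag_tokenizer.load_user_dict("/path/to/custom.tks.dict")

tks = rag_tokenizer.tokenize("RAGFlow hybrid retrieval")
print(tks)                                       # coarse tokens
print(rag_tokenizer.fine_grained_tokenize(tks))  # finer segmentation of the same text
```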
@ -506,7 +505,7 @@ if __name__ == '__main__':
        if len(sys.argv) < 2:
            sys.exit()
        tknzr.DEBUG = False
        tknzr.loadUserDict(sys.argv[1])
        tknzr.load_user_dict(sys.argv[1])
        of = open(sys.argv[2], "r")
        while True:
            line = of.readline()
@ -102,7 +102,7 @@ class Dealer:
            orderBy.asc("top_int")
            orderBy.desc("create_timestamp_flt")
            res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
            total = self.dataStore.getTotal(res)
            total = self.dataStore.get_total(res)
            logging.debug("Dealer.search TOTAL: {}".format(total))
        else:
            highlightFields = ["content_ltks", "title_tks"]
@ -115,7 +115,7 @@ class Dealer:
                matchExprs = [matchText]
                res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                            idx_names, kb_ids, rank_feature=rank_feature)
                total = self.dataStore.getTotal(res)
                total = self.dataStore.get_total(res)
                logging.debug("Dealer.search TOTAL: {}".format(total))
            else:
                matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
@ -127,20 +127,20 @@ class Dealer:

                res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                            idx_names, kb_ids, rank_feature=rank_feature)
                total = self.dataStore.getTotal(res)
                total = self.dataStore.get_total(res)
                logging.debug("Dealer.search TOTAL: {}".format(total))

                # If result is empty, try again with lower min_match
                if total == 0:
                    if filters.get("doc_id"):
                        res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
                        total = self.dataStore.getTotal(res)
                        total = self.dataStore.get_total(res)
                    else:
                        matchText, _ = self.qryr.question(qst, min_match=0.1)
                        matchDense.extra_options["similarity"] = 0.17
                        res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
                                                    orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
                        total = self.dataStore.getTotal(res)
                        total = self.dataStore.get_total(res)
                        logging.debug("Dealer.search 2 TOTAL: {}".format(total))

        for k in keywords:
@ -153,17 +153,17 @@ class Dealer:
                    kwds.add(kk)

        logging.debug(f"TOTAL: {total}")
        ids = self.dataStore.getChunkIds(res)
        ids = self.dataStore.get_chunk_ids(res)
        keywords = list(kwds)
        highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
        aggs = self.dataStore.getAggregation(res, "docnm_kwd")
        highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight")
        aggs = self.dataStore.get_aggregation(res, "docnm_kwd")
        return self.SearchResult(
            total=total,
            ids=ids,
            query_vector=q_vec,
            aggregation=aggs,
            highlight=highlight,
            field=self.dataStore.getFields(res, src + ["_score"]),
            field=self.dataStore.get_fields(res, src + ["_score"]),
            keywords=keywords
        )

@ -488,7 +488,7 @@ class Dealer:
        for p in range(offset, max_count, bs):
            es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
                                           kb_ids)
            dict_chunks = self.dataStore.getFields(es_res, fields)
            dict_chunks = self.dataStore.get_fields(es_res, fields)
            for id, doc in dict_chunks.items():
                doc["id"] = id
            if dict_chunks:
@ -501,11 +501,11 @@ class Dealer:
        if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
            return []
        res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
        return self.dataStore.getAggregation(res, "tag_kwd")
        return self.dataStore.get_aggregation(res, "tag_kwd")

    def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
        res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
        res = self.dataStore.getAggregation(res, "tag_kwd")
        res = self.dataStore.get_aggregation(res, "tag_kwd")
        total = np.sum([c for _, c in res])
        return {t: (c + 1) / (total + S) for t, c in res}
@ -513,7 +513,7 @@ class Dealer:
        idx_nm = index_name(tenant_id)
        match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
        res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
        aggs = self.dataStore.getAggregation(res, "tag_kwd")
        aggs = self.dataStore.get_aggregation(res, "tag_kwd")
        if not aggs:
            return False
        cnt = np.sum([c for _, c in aggs])
@ -529,7 +529,7 @@ class Dealer:
        idx_nms = [index_name(tid) for tid in tenant_ids]
        match_txt, _ = self.qryr.question(question, min_match=0.0)
        res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
        aggs = self.dataStore.getAggregation(res, "tag_kwd")
        aggs = self.dataStore.get_aggregation(res, "tag_kwd")
        if not aggs:
            return {}
        cnt = np.sum([c for _, c in aggs])
@ -552,7 +552,7 @@ class Dealer:
        es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
                                       kb_ids)
        toc = []
        dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
        dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])
        for _, doc in dict_chunks.items():
            try:
                toc.extend(json.loads(doc["content_with_weight"]))
@ -113,20 +113,20 @@ class Dealer:
            res.append(tk)
        return res

    def tokenMerge(self, tks):
        def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
    def token_merge(self, tks):
        def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)

        res, i = [], 0
        while i < len(tks):
            j = i
            if i == 0 and oneTerm(tks[i]) and len(
            if i == 0 and one_term(tks[i]) and len(
                    tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])):  # e.g. 多 工位 ("multi" + "workstation")
                res.append(" ".join(tks[0:2]))
                i = 2
                continue

            while j < len(
                    tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
                    tks) and tks[j] and tks[j] not in self.stop_words and one_term(tks[j]):
                j += 1
            if j - i > 1:
                if j - i < 5:
@ -232,7 +232,7 @@ class Dealer:
            tw = list(zip(tks, wts))
        else:
            for tk in tks:
                tt = self.tokenMerge(self.pretoken(tk, True))
                tt = self.token_merge(self.pretoken(tk, True))
                idf1 = np.array([idf(freq(t), 10000000) for t in tt])
                idf2 = np.array([idf(df(t), 1000000000) for t in tt])
                wts = (0.3 * idf1 + 0.7 * idf2) * \
154 rag/raptor.py
@ -15,27 +15,35 @@
#
import logging
import re
import umap

import numpy as np
from sklearn.mixture import GaussianMixture
import trio
import umap
from sklearn.mixture import GaussianMixture

from api.db.services.task_service import has_canceled
from common.connection_utils import timeout
from common.exceptions import TaskCanceledException
from common.token_utils import truncate
from graphrag.utils import (
    get_llm_cache,
    chat_limiter,
    get_embed_cache,
    get_llm_cache,
    set_embed_cache,
    set_llm_cache,
    chat_limiter,
)
from common.token_utils import truncate


class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
    def __init__(
        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
        self,
        max_cluster,
        llm_model,
        embd_model,
        prompt,
        max_token=512,
        threshold=0.1,
        max_errors=3,
    ):
        self._max_cluster = max_cluster
        self._llm_model = llm_model
@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        self._threshold = threshold
        self._prompt = prompt
        self._max_token = max_token
        self._max_errors = max(1, max_errors)
        self._error_count = 0

    @timeout(60*20)
    @timeout(60 * 20)
    async def _chat(self, system, history, gen_conf):
        response = await trio.to_thread.run_sync(
            lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
        )
        cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
        if cached:
            return cached

        if response:
            return response
        response = await trio.to_thread.run_sync(
            lambda: self._llm_model.chat(system, history, gen_conf)
        )
        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        await trio.to_thread.run_sync(
            lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
        )
        return response
        last_exc = None
        for attempt in range(3):
            try:
                response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
                response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
                if response.find("**ERROR**") >= 0:
                    raise Exception(response)
                await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
                return response
            except Exception as exc:
                last_exc = exc
                logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
                if attempt < 2:
                    await trio.sleep(1 + attempt)

        raise last_exc if last_exc else Exception("LLM chat failed without exception")

    @timeout(20)
    async def _embedding_encode(self, txt):
        response = await trio.to_thread.run_sync(
            lambda: get_embed_cache(self._embd_model.llm_name, txt)
        )
        response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
        if response is not None:
            return response
        embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        n_clusters = np.arange(1, max_clusters)
        bics = []
        for n in n_clusters:

            if task_id:
                if has_canceled(task_id):
                    logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        layers = [(0, len(chunks))]
        start, end = 0, len(chunks)

        @timeout(60*20)
        @timeout(60 * 20)
        async def summarize(ck_idx: list[int]):
            nonlocal chunks

@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                raise TaskCanceledException(f"Task {task_id} was cancelled")

            texts = [chunks[i][0] for i in ck_idx]
            len_per_chunk = int(
                (self._llm_model.max_length - self._max_token) / len(texts)
            )
            cluster_content = "\n".join(
                [truncate(t, max(1, len_per_chunk)) for t in texts]
            )
            async with chat_limiter:
            len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
            cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
            try:
                async with chat_limiter:
                    if task_id and has_canceled(task_id):
                        logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
                        raise TaskCanceledException(f"Task {task_id} was cancelled")

                if task_id and has_canceled(task_id):
                    logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
                    raise TaskCanceledException(f"Task {task_id} was cancelled")
                    cnt = await self._chat(
                        "You're a helpful assistant.",
                        [
                            {
                                "role": "user",
                                "content": self._prompt.format(cluster_content=cluster_content),
                            }
                        ],
                        {"max_tokens": max(self._max_token, 512)},  # fix issue: #10235
                    )
                    cnt = re.sub(
                        "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
                        "",
                        cnt,
                    )
                    logging.debug(f"SUM: {cnt}")

                cnt = await self._chat(
                    "You're a helpful assistant.",
                    [
                        {
                            "role": "user",
                            "content": self._prompt.format(
                                cluster_content=cluster_content
                            ),
                        }
                    ],
                    {"max_tokens": max(self._max_token, 512)},  # fix issue: #10235
                )
                cnt = re.sub(
                    "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
                    "",
                    cnt,
                )
                logging.debug(f"SUM: {cnt}")
                    if task_id and has_canceled(task_id):
                        logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
                        raise TaskCanceledException(f"Task {task_id} was cancelled")

                if task_id and has_canceled(task_id):
                    logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
                    raise TaskCanceledException(f"Task {task_id} was cancelled")

                embds = await self._embedding_encode(cnt)
                chunks.append((cnt, embds))
                    embds = await self._embedding_encode(cnt)
                    chunks.append((cnt, embds))
            except TaskCanceledException:
                raise
            except Exception as exc:
                self._error_count += 1
                warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
                logging.warning(warn_msg)
                if callback:
                    callback(msg=warn_msg)
                if self._error_count >= self._max_errors:
                    raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc

        labels = []
        while end - start > 1:

            if task_id:
                if has_canceled(task_id):
                    logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
            if len(embeddings) == 2:
                await summarize([start, start + 1])
                if callback:
                    callback(
                        msg="Cluster one layer: {} -> {}".format(
                            end - start, len(chunks) - end
                        )
                    )
                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
                labels.extend([0, 0])
                layers.append((end, len(chunks)))
                start = end
@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:

                nursery.start_soon(summarize, ck_idx)

            assert len(chunks) - end == n_clusters, "{} vs. {}".format(
                len(chunks) - end, n_clusters
            )
            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
            labels.extend(lbls)
            layers.append((end, len(chunks)))
            if callback:
                callback(
                    msg="Cluster one layer: {} -> {}".format(
                        end - start, len(chunks) - end
                    )
                )
                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
            start = end
            end = len(chunks)
@ -28,7 +28,7 @@ def collect():
    logging.debug(doc_locations)
    if len(doc_locations) == 0:
        time.sleep(1)
        return
        return None
    return doc_locations

@ -359,7 +359,7 @@ async def build_chunks(task, progress_callback):
            task_canceled = has_canceled(task["id"])
            if task_canceled:
                progress_callback(-1, msg="Task has been canceled.")
                return
                return None
            if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
                examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
            else:
@ -417,6 +417,7 @@ def build_TOC(task, docs, progress_callback):
        d["page_num_int"] = [100000000]
        d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
        return d
    return None


def init_kb(row, vector_size: int):
@ -441,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
    tk_count = 0
    if len(tts) == len(cnts):
        vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
        tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0)
        tts = np.tile(vts[0], (len(cnts), 1))
        tk_count += c

    @timeout(60)
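The switch to `np.tile` matters for shapes: concatenating n copies of a 1-D title vector yields one long 1-D array, while tiling produces the (n, d) matrix that the later title/content blend expects. A small sketch with toy dimensions:

```python
import numpy as np

v = np.array([0.1, 0.2, 0.3])  # one title embedding, d = 3
n = 4                          # number of chunks

flat = np.concatenate([v for _ in range(n)], axis=0)
print(flat.shape)              # (12,)  -- unusable for row-wise blending

stacked = np.tile(v, (n, 1))
print(stacked.shape)           # (4, 3) -- one copy of v per chunk row
```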
@ -464,8 +465,10 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
    if not filename_embd_weight:
        filename_embd_weight = 0.1
    title_w = float(filename_embd_weight)
    vects = (title_w * tts + (1 - title_w) *
             cnts) if len(tts) == len(cnts) else cnts
    if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape:
        vects = title_w * tts + (1 - title_w) * cnts
    else:
        vects = cnts

    assert len(vects) == len(docs)
    vector_size = 0
@ -648,6 +651,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si

    res = []
    tk_count = 0
    max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))

    async def generate(chunks, did):
        nonlocal tk_count, res
        raptor = Raptor(
@ -657,6 +662,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
            raptor_config["prompt"],
            raptor_config["max_token"],
            raptor_config["threshold"],
            max_errors=max_errors,
        )
        original_length = len(chunks)
        chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
@ -719,7 +725,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
        task_canceled = has_canceled(task_id)
        if task_canceled:
            progress_callback(-1, msg="Task has been canceled.")
            return
            return False
        if b % 128 == 0:
            progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
        if doc_store_result:
@ -737,7 +743,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
            for chunk_id in chunk_ids:
                nursery.start_soon(delete_image, task_dataset_id, chunk_id)
            progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
            return
            return False
    return True

@ -67,6 +67,8 @@ class RAGFlowAzureSpnBlob:
            logging.exception(f"Fail put {bucket}/{fnm}")
            self.__open__()
            time.sleep(1)
            return None
        return None

    def rm(self, bucket, fnm):
        try:
@ -84,7 +86,7 @@ class RAGFlowAzureSpnBlob:
            logging.exception(f"fail get {bucket}/{fnm}")
            self.__open__()
            time.sleep(1)
            return
            return None

    def obj_exist(self, bucket, fnm):
        try:
@ -102,4 +104,4 @@ class RAGFlowAzureSpnBlob:
            logging.exception(f"fail get {bucket}/{fnm}")
            self.__open__()
            time.sleep(1)
            return
            return None
@ -241,23 +241,23 @@ class DocStoreConnection(ABC):
    """

    @abstractmethod
    def getTotal(self, res):
    def get_total(self, res):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getChunkIds(self, res):
    def get_chunk_ids(self, res):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getHighlight(self, res, keywords: list[str], fieldnm: str):
    def get_highlight(self, res, keywords: list[str], fieldnm: str):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getAggregation(self, res, fieldnm: str):
    def get_aggregation(self, res, fieldnm: str):
        raise NotImplementedError("Not implemented")

    """
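For callers, the rename is a pure camelCase-to-snake_case sweep over these helpers; any external code still on the old names needs the same one-line substitutions. A hedged sketch of the post-rename call pattern (`store` stands for any concrete `DocStoreConnection`; the IDs are placeholders):

```python
res = store.search(["content_with_weight"], [], {"doc_id": "<DOC_ID>"}, [],
                   OrderByExpr(), 0, 10, "<INDEX_NAME>", ["<KB_ID>"])
total = store.get_total(res)        # was getTotal
ids = store.get_chunk_ids(res)      # was getChunkIds
fields = store.get_fields(res, ["content_with_weight"])  # was getFields
print(total, ids[:3], len(fields))
```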
@ -471,12 +471,12 @@ class ESConnection(DocStoreConnection):
    Helper functions for search result
    """

    def getTotal(self, res):
    def get_total(self, res):
        if isinstance(res["hits"]["total"], type({})):
            return res["hits"]["total"]["value"]
        return res["hits"]["total"]

    def getChunkIds(self, res):
    def get_chunk_ids(self, res):
        return [d["_id"] for d in res["hits"]["hits"]]

    def __getSource(self, res):
@ -487,7 +487,7 @@ class ESConnection(DocStoreConnection):
            rr.append(d["_source"])
        return rr

    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
        res_fields = {}
        if not fields:
            return {}
@ -509,7 +509,7 @@ class ESConnection(DocStoreConnection):
            res_fields[d["id"]] = m
        return res_fields

    def getHighlight(self, res, keywords: list[str], fieldnm: str):
    def get_highlight(self, res, keywords: list[str], fieldnm: str):
        ans = {}
        for d in res["hits"]["hits"]:
            hlts = d.get("highlight")
@ -534,7 +534,7 @@ class ESConnection(DocStoreConnection):

        return ans

    def getAggregation(self, res, fieldnm: str):
    def get_aggregation(self, res, fieldnm: str):
        agg_field = "aggs_" + fieldnm
        if "aggregations" not in res or agg_field not in res["aggregations"]:
            return list()
@ -470,7 +470,7 @@ class InfinityConnection(DocStoreConnection):
        df_list.append(kb_res)
        self.connPool.release_conn(inf_conn)
        res = concat_dataframes(df_list, ["id"])
        res_fields = self.getFields(res, res.columns.tolist())
        res_fields = self.get_fields(res, res.columns.tolist())
        return res_fields.get(chunkId, None)

    def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
@ -599,7 +599,7 @@ class InfinityConnection(DocStoreConnection):
        col_to_remove = list(removeValue.keys())
        row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
        logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
        row_to_opt = self.getFields(row_to_opt, col_to_remove)
        row_to_opt = self.get_fields(row_to_opt, col_to_remove)
        for id, old_v in row_to_opt.items():
            for k, remove_v in removeValue.items():
                if remove_v in old_v[k]:
@ -639,17 +639,17 @@ class InfinityConnection(DocStoreConnection):
    Helper functions for search result
    """

    def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
    def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
        if isinstance(res, tuple):
            return res[1]
        return len(res)

    def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
    def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
        if isinstance(res, tuple):
            res = res[0]
        return list(res["id"])

    def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
    def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
        if isinstance(res, tuple):
            res = res[0]
        if not fields:
|
@ -690,7 +690,7 @@ class InfinityConnection(DocStoreConnection):
|
|||
|
||||
return res2.set_index("id").to_dict(orient="index")
|
||||
|
||||
def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
|
||||
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
ans = {}
|
||||
|
|
@ -732,7 +732,7 @@ class InfinityConnection(DocStoreConnection):
|
|||
ans[id] = txt
|
||||
return ans
|
||||
|
||||
def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
|
||||
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
|
||||
"""
|
||||
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
|
||||
"""
|
||||
|
|
|
|||
|
|
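Note: as the docstring says, Infinity has no native aggregations, so `get_aggregation` counts tag values in Python over the result DataFrame. A minimal, self-contained sketch of that kind of manual tag counting; the column name `tag_kwd` and the `"###"` separator are assumptions for illustration, not taken from this diff:

```python
from collections import Counter

import pandas as pd

def manual_tag_aggregation(df: pd.DataFrame, fieldnm: str) -> list[tuple[str, int]]:
    # Count each tag's frequency across all rows; tags are assumed here
    # to be stored as a "###"-joined string per row.
    counter: Counter = Counter()
    for value in df.get(fieldnm, pd.Series(dtype=str)).dropna():
        counter.update(t for t in str(value).split("###") if t)
    return list(counter.items())

df = pd.DataFrame({"tag_kwd": ["a###b", "b###c", "b"]})
print(manual_tag_aggregation(df, "tag_kwd"))  # [('a', 1), ('b', 3), ('c', 1)]
```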
@@ -92,7 +92,7 @@ class RAGFlowMinio:
             logging.exception(f"Fail to get {bucket}/{filename}")
             self.__open__()
             time.sleep(1)
-        return
+        return None

     def obj_exist(self, bucket, filename, tenant_id=None):
         try:

@@ -130,7 +130,7 @@ class RAGFlowMinio:
             logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
             self.__open__()
             time.sleep(1)
-        return
+        return None

     def remove_bucket(self, bucket):
         try:
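Note: across the storage backends (MinIO here, OSS and S3 below), bare `return` statements on the failure path become explicit `return None`, which makes the "no result" outcome visible to readers and type checkers without changing behavior. A minimal illustration; the `fetch` function is hypothetical, not repository code:

```python
from typing import Optional

def fetch(store: dict, key: str) -> Optional[bytes]:
    # Explicit `return None` documents that the miss path intentionally
    # yields None, rather than silently falling off the end.
    try:
        return store[key]
    except KeyError:
        return None

assert fetch({"k": b"v"}, "k") == b"v"
assert fetch({}, "k") is None
```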
@@ -62,8 +62,7 @@ class OpenDALStorage:

     def health(self):
         bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
-        r = self._operator.write(f"{bucket}/{fnm}", binary)
-        return r
+        return self._operator.write(f"{bucket}/{fnm}", binary)

     def put(self, bucket, fnm, binary, tenant_id=None):
         self._operator.write(f"{bucket}/{fnm}", binary)
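Note: dropping the single-use temporary in `health` is a pure refactor; the probe still returns whatever the operator's write call yields. A generic sketch of the same simplification, with illustrative names:

```python
# Before: a single-use temporary obscures that the result is passed through.
def health_before(op) -> object:
    r = op.write("bucket/probe", b"_t@@@1")
    return r

# After: return the write result directly; behavior is unchanged.
def health_after(op) -> object:
    return op.write("bucket/probe", b"_t@@@1")
```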
@@ -455,12 +455,12 @@ class OSConnection(DocStoreConnection):
     Helper functions for search result
     """

-    def getTotal(self, res):
+    def get_total(self, res):
         if isinstance(res["hits"]["total"], type({})):
             return res["hits"]["total"]["value"]
         return res["hits"]["total"]

-    def getChunkIds(self, res):
+    def get_chunk_ids(self, res):
         return [d["_id"] for d in res["hits"]["hits"]]

     def __getSource(self, res):

@@ -471,7 +471,7 @@ class OSConnection(DocStoreConnection):
             rr.append(d["_source"])
         return rr

-    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
+    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
         res_fields = {}
         if not fields:
             return {}

@@ -490,7 +490,7 @@ class OSConnection(DocStoreConnection):
             res_fields[d["id"]] = m
         return res_fields

-    def getHighlight(self, res, keywords: list[str], fieldnm: str):
+    def get_highlight(self, res, keywords: list[str], fieldnm: str):
         ans = {}
         for d in res["hits"]["hits"]:
             hlts = d.get("highlight")

@@ -515,7 +515,7 @@ class OSConnection(DocStoreConnection):
         return ans

-    def getAggregation(self, res, fieldnm: str):
+    def get_aggregation(self, res, fieldnm: str):
         agg_field = "aggs_" + fieldnm
         if "aggregations" not in res or agg_field not in res["aggregations"]:
             return list()
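Note: OSConnection mirrors the ESConnection renames exactly, so external code that subclasses or calls these connections must migrate in one step. If a transition period were wanted, a deprecated alias is one option; this shim is a suggestion, not part of this commit:

```python
import warnings

class SearchConn:
    # Hypothetical minimal connection exposing only the new name.
    def get_total(self, res) -> int:
        return res["hits"]["total"]

    # Transitional alias for external callers; warns but still works.
    def getTotal(self, res) -> int:
        warnings.warn("getTotal() is deprecated; use get_total()",
                      DeprecationWarning, stacklevel=2)
        return self.get_total(res)

conn = SearchConn()
assert conn.getTotal({"hits": {"total": 3}}) == 3  # emits DeprecationWarning
```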
@@ -141,7 +141,7 @@ class RAGFlowOSS:
             logging.exception(f"fail get {bucket}/{fnm}")
             self.__open__()
             time.sleep(1)
-        return
+        return None

     @use_prefix_path
     @use_default_bucket

@@ -170,5 +170,5 @@ class RAGFlowOSS:
             logging.exception(f"fail get url {bucket}/{fnm}")
             self.__open__()
             time.sleep(1)
-        return
+        return None
@@ -104,6 +104,7 @@ class RedisDB:

         if self.REDIS.get(a) == b:
             return True
+        return False

     def info(self):
         info = self.REDIS.info()

@@ -124,7 +125,7 @@ class RedisDB:

     def exist(self, k):
         if not self.REDIS:
-            return
+            return None
         try:
             return self.REDIS.exists(k)
         except Exception as e:

@@ -133,7 +134,7 @@ class RedisDB:

     def get(self, k):
         if not self.REDIS:
-            return
+            return None
         try:
             return self.REDIS.get(k)
         except Exception as e:
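Note: `RedisDB.exist` and `RedisDB.get` now return None explicitly when no client is configured, and on this reading of the hunk the health check gains an explicit `return False`. A small self-contained sketch of why the explicit None matters to callers; the `Cache` class is a hypothetical stand-in for the Redis wrapper:

```python
class Cache:
    # Hypothetical stand-in for a Redis wrapper whose client may be absent.
    def __init__(self, client=None):
        self.client = client

    def get(self, k):
        if not self.client:
            return None  # explicit: "no client" yields None, not a fall-through
        return self.client.get(k)

cache = Cache(client=None)
if cache.get("task_queue") is None:
    print("cache unavailable or key missing")
```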
@@ -164,7 +164,7 @@ class RAGFlowS3:
             logging.exception(f"fail get {bucket}/{fnm}")
             self.__open__()
             time.sleep(1)
-        return
+        return None

     @use_prefix_path
     @use_default_bucket

@@ -193,7 +193,7 @@ class RAGFlowS3:
             logging.exception(f"fail get url {bucket}/{fnm}")
             self.__open__()
             time.sleep(1)
-        return
+        return None

     @use_default_bucket
     def rm_bucket(self, bucket, *args, **kwargs):
@@ -16,14 +16,15 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed

 import pytest
-from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
+from common import create_dataset
+from configs import DATASET_NAME_LIMIT, DEFAULT_PARSER_CONFIG, INVALID_API_TOKEN
 from hypothesis import example, given, settings
 from libs.auth import RAGFlowHttpApiAuth
 from utils import encode_avatar
 from utils.file_utils import create_image_file
 from utils.hypothesis_utils import valid_names
-from configs import DEFAULT_PARSER_CONFIG
-
-from common import create_dataset


 @pytest.mark.usefixtures("clear_datasets")
 class TestAuthorization:

@@ -125,8 +126,8 @@ class TestDatasetCreate:
         assert res["code"] == 0, res

         res = create_dataset(HttpApiAuth, payload)
-        assert res["code"] == 103, res
-        assert res["message"] == f"Dataset name '{name}' already exists", res
+        assert res["code"] == 0, res
+        assert res["data"]["name"] == name + "(1)", res

     @pytest.mark.p3
     def test_name_case_insensitive(self, HttpApiAuth):

@@ -137,8 +138,8 @@ class TestDatasetCreate:

         payload = {"name": name.lower()}
         res = create_dataset(HttpApiAuth, payload)
-        assert res["code"] == 103, res
-        assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
+        assert res["code"] == 0, res
+        assert res["data"]["name"] == name.lower() + "(1)", res

     @pytest.mark.p2
     def test_avatar(self, HttpApiAuth, tmp_path):
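Note: these tests encode the new server behavior: creating a dataset whose name already exists (including a case-insensitive match) now succeeds with code 0 and auto-suffixes the name with `(1)` instead of failing with code 103. A self-contained sketch of that dedup rule; the suffixing logic is inferred from the assertions, not taken from the server code:

```python
def dedup_name(name: str, existing: set[str]) -> str:
    # Case-insensitive duplicate check, mirroring what the assertions
    # above expect: creating "Dataset" then "dataset" yields "dataset(1)".
    lowered = {e.lower() for e in existing}
    if name.lower() not in lowered:
        return name
    i = 1
    while f"{name}({i})".lower() in lowered:
        i += 1
    return f"{name}({i})"

print(dedup_name("dataset", {"Dataset"}))  # dataset(1)
```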
@@ -17,13 +17,13 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 from operator import attrgetter

 import pytest
-from configs import DATASET_NAME_LIMIT, HOST_ADDRESS, INVALID_API_TOKEN
+from configs import DATASET_NAME_LIMIT, DEFAULT_PARSER_CONFIG, HOST_ADDRESS, INVALID_API_TOKEN
 from hypothesis import example, given, settings
 from ragflow_sdk import DataSet, RAGFlow
 from utils import encode_avatar
 from utils.file_utils import create_image_file
 from utils.hypothesis_utils import valid_names
-from configs import DEFAULT_PARSER_CONFIG


 @pytest.mark.usefixtures("clear_datasets")
 class TestAuthorization:

@@ -95,9 +95,8 @@ class TestDatasetCreate:
         payload = {"name": name}
         client.create_dataset(**payload)

-        with pytest.raises(Exception) as excinfo:
-            client.create_dataset(**payload)
-        assert str(excinfo.value) == f"Dataset name '{name}' already exists", str(excinfo.value)
+        dataset = client.create_dataset(**payload)
+        assert dataset.name == name + "(1)", str(dataset)

     @pytest.mark.p3
     def test_name_case_insensitive(self, client):

@@ -106,9 +105,8 @@ class TestDatasetCreate:
         client.create_dataset(**payload)

         payload = {"name": name.lower()}
-        with pytest.raises(Exception) as excinfo:
-            client.create_dataset(**payload)
-        assert str(excinfo.value) == f"Dataset name '{name.lower()}' already exists", str(excinfo.value)
+        dataset = client.create_dataset(**payload)
+        assert dataset.name == name.lower() + "(1)", str(dataset)

     @pytest.mark.p2
     def test_avatar(self, client, tmp_path):
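Note: on the SDK side the same behavior change means duplicate creation no longer raises; it returns the renamed dataset object. A hedged usage sketch against a running server; the address and API key are placeholders:

```python
from ragflow_sdk import RAGFlow

# Placeholders: substitute a real server address and API key.
client = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://127.0.0.1:9380")

first = client.create_dataset(name="reports")
second = client.create_dataset(name="reports")  # no longer raises
assert second.name == "reports(1)"
```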