Merge branch 'main' of github.com:infiniflow/ragflow into feature/1111

This commit is contained in:
chanx 2025-11-14 11:55:19 +08:00
commit 2c2553cd1b
58 changed files with 915 additions and 1096 deletions

View file

@ -95,6 +95,38 @@ jobs:
version: ">=0.11.x"
args: "check"
- name: Check comments of changed Python files
if: ${{ !cancelled() && !failure() }}
run: |
if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
| grep -E '\.(py)$' || true)
if [ -n "$CHANGED_FILES" ]; then
echo "Check comments of changed Python files with check_comment_ascii.py"
readarray -t files <<< "$CHANGED_FILES"
HAS_ERROR=0
for file in "${files[@]}"; do
if [ -f "$file" ]; then
if python3 check_comment_ascii.py "$file"; then
echo "✅ $file"
else
echo "❌ $file"
HAS_ERROR=1
fi
fi
done
if [ $HAS_ERROR -ne 0 ]; then
exit 1
fi
else
echo "No Python files changed"
fi
fi
- name: Build ragflow:nightly
run: |
RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}

View file

@ -192,9 +192,10 @@ releases! 🌟
```bash
$ cd ragflow/docker
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View file

@ -192,6 +192,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
$ cd ragflow/docker
# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases), contoh: git checkout v0.22.0
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View file

@ -172,6 +172,7 @@
$ cd ragflow/docker
# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) 例: git checkout v0.22.0
# この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View file

@ -174,6 +174,7 @@
$ cd ragflow/docker
# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
# 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View file

@ -192,6 +192,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
$ cd ragflow/docker
# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases), ex.: git checkout v0.22.0
# Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker.
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View file

@ -191,6 +191,7 @@
$ cd ragflow/docker
# 可選：使用穩定版標籤（查看發佈：https://github.com/infiniflow/ragflow/releases），git checkout v0.22.0
# 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View file

@ -192,6 +192,7 @@
$ cd ragflow/docker
# 可选：使用稳定版本标签（查看发布：https://github.com/infiniflow/ragflow/releases），例如：git checkout v0.22.0
# 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。
# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

View file

@ -368,11 +368,19 @@ Respond immediately with your final comprehensive answer.
return "Error occurred."
def reset(self, temp=False):
def reset(self, only_output=False):
"""
Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
"""
for k in self._param.outputs.keys():
self._param.outputs[k]["value"] = None
for k, cpn in self.tools.items():
if hasattr(cpn, "reset") and callable(cpn.reset):
cpn.reset()
if only_output:
return
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
self._param.debug_inputs = {}
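For illustration, here is a toy model of the renamed flag's behavior. It is a simplified stand-in for the real component class (the tool-reset loop is omitted), not the actual implementation:

```python
# toy illustration of the new only_output flag; not the real agent component class
class ToyComponent:
    def __init__(self):
        self.outputs = {"answer": {"value": "42"}}
        self.inputs = {"question": {"value": "6*7"}}
        self.debug_inputs = {"trace": True}

    def reset(self, only_output=False):
        for k in self.outputs:
            self.outputs[k]["value"] = None
        if only_output:
            return
        for k in self.inputs:
            self.inputs[k]["value"] = None
        self.debug_inputs = {}

c = ToyComponent()
c.reset(only_output=True)
assert c.inputs["question"]["value"] == "6*7"   # inputs survive an output-only reset
c.reset()
assert c.inputs["question"]["value"] is None and c.debug_inputs == {}
```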

View file

@ -222,7 +222,7 @@ class LLM(ComponentBase):
output_structure = self._param.outputs['structured']
except Exception:
pass
if output_structure:
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"):
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
prompt += structured_output_prompt(schema)
for _ in range(self._param.max_retries+1):
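As a small illustration of the strengthened guard, the check can be mirrored on plain dicts (the schema below is hypothetical):

```python
import json

def needs_structured_prompt(output_structure) -> bool:
    # mirrors the new condition: a non-empty dict that actually declares "properties"
    return bool(output_structure and isinstance(output_structure, dict) and output_structure.get("properties"))

schema_ok = {"type": "object", "properties": {"title": {"type": "string"}}}
schema_empty = {"type": "object", "properties": {}}

assert needs_structured_prompt(schema_ok)
assert not needs_structured_prompt(schema_empty)  # previously this still triggered the structured-output prompt
assert not needs_structured_prompt("not a dict")
print(json.dumps(schema_ok, ensure_ascii=False, indent=2))
```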

View file

@ -96,12 +96,12 @@ login_manager.init_app(app)
commands.register_commands(app)
def search_pages_path(pages_dir):
def search_pages_path(page_path):
app_path_list = [
path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".")
path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
]
api_path_list = [
path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".")
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
]
app_path_list.extend(api_path_list)
return app_path_list
@ -138,7 +138,7 @@ pages_dir = [
]
client_urls_prefix = [
register_page(path) for dir in pages_dir for path in search_pages_path(dir)
register_page(path) for directory in pages_dir for path in search_pages_path(directory)
]
@ -177,5 +177,7 @@ def load_user(web_request):
@app.teardown_request
def _db_close(exc):
def _db_close(exception):
if exception:
logging.exception(f"Request failed: {exception}")
close_connection()

View file

@ -13,41 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
from datetime import datetime, timedelta
from flask import request, Response
from api.db.services.llm_service import LLMBundle
from flask import request
from flask_login import login_required, current_user
from api.db import VALID_FILE_TYPES, FileType
from api.db.db_models import APIToken, Task, File
from api.db.services import duplicate_name
from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService, API4ConversationService
from api.db.services.dialog_service import DialogService, chat
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService
from common.misc_utils import get_uuid
from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
generate_confirmation_token
from api.utils.file_utils import filename_type, thumbnail
from rag.app.tag import label_question
from rag.prompts.generator import keyword_extraction
from common.time_utils import current_timestamp, datetime_format
from api.db.services.canvas_service import UserCanvasService
from agent.canvas import Canvas
from functools import partial
from pathlib import Path
from common import settings
@manager.route('/new_token', methods=['POST']) # noqa: F821
@login_required
@ -138,758 +113,3 @@ def stats():
except Exception as e:
return server_error_response(e)
@manager.route('/new_conversation', methods=['GET']) # noqa: F821
def set_conversation():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
try:
if objs[0].source == "agent":
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
if not e:
return server_error_response("canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
conv = {
"id": get_uuid(),
"dialog_id": cvs.id,
"user_id": request.args.get("user_id", ""),
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent"
}
API4ConversationService.save(**conv)
return get_json_result(data=conv)
else:
e, dia = DialogService.get_by_id(objs[0].dialog_id)
if not e:
return get_data_error_result(message="Dialog not found")
conv = {
"id": get_uuid(),
"dialog_id": dia.id,
"user_id": request.args.get("user_id", ""),
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
}
API4ConversationService.save(**conv)
return get_json_result(data=conv)
except Exception as e:
return server_error_response(e)
@manager.route('/completion', methods=['POST']) # noqa: F821
@validate_request("conversation_id", "messages")
def completion():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
if "quote" not in req:
req["quote"] = False
msg = []
for m in req["messages"]:
if m["role"] == "system":
continue
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
def fillin_conv(ans):
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
ans["id"] = message_id
def rename_field(ans):
reference = ans['reference']
if not isinstance(reference, dict):
return
for chunk_i in reference.get('chunks', []):
if 'docnm_kwd' in chunk_i:
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
try:
if conv.source == "agent":
stream = req.get("stream", True)
conv.message.append(msg[-1])
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
if not e:
return server_error_response("canvas not found.")
del req["conversation_id"]
del req["messages"]
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
final_ans = {"reference": [], "content": ""}
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
canvas.messages.append(msg[-1])
canvas.add_user_input(msg[-1]["content"])
answer = canvas.run(stream=stream)
assert answer is not None, "Nothing. Is it over?"
if stream:
assert isinstance(answer, partial), "Nothing. Is it over?"
def sse():
nonlocal answer, cvs, conv
try:
for ans in answer():
for k in ans.keys():
final_ans[k] = ans[k]
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
fillin_conv(ans)
rename_field(ans)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
canvas.history.append(("assistant", final_ans["content"]))
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
resp = Response(sse(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
fillin_conv(result)
API4ConversationService.append_message(conv.id, conv.to_dict())
rename_field(result)
return get_json_result(data=result)
# ******************For dialog******************
conv.message.append(msg[-1])
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(message="Dialog not found!")
del req["conversation_id"]
del req["messages"]
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, True, **req):
fillin_conv(ans)
rename_field(ans)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
if req.get("stream", True):
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
answer = None
for ans in chat(dia, msg, **req):
answer = ans
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
break
rename_field(answer)
return get_json_result(data=answer)
except Exception as e:
return server_error_response(e)
@manager.route('/conversation/<conversation_id>', methods=['GET']) # noqa: F821
# @login_required
def get_conversation(conversation_id):
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
try:
e, conv = API4ConversationService.get_by_id(conversation_id)
if not e:
return get_data_error_result(message="Conversation not found!")
conv = conv.to_dict()
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
return get_json_result(data=False, message='Authentication error: API key is invalid for this conversation_id!"',
code=RetCode.AUTHENTICATION_ERROR)
for referenct_i in conv['reference']:
if referenct_i is None or len(referenct_i) == 0:
continue
for chunk_i in referenct_i['chunks']:
if 'docnm_kwd' in chunk_i.keys():
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
return get_json_result(data=conv)
except Exception as e:
return server_error_response(e)
@manager.route('/document/upload', methods=['POST']) # noqa: F821
@validate_request("kb_name")
def upload():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
kb_name = request.form.get("kb_name").strip()
tenant_id = objs[0].tenant_id
try:
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
if not e:
return get_data_error_result(
message="Can't find this knowledgebase!")
kb_id = kb.id
except Exception as e:
return server_error_response(e)
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
file = request.files['file']
if file.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, tenant_id)
kb_root_folder = FileService.get_kb_folder(tenant_id)
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
try:
if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
return get_data_error_result(
message="Exceed the maximum file number of a free user!")
filename = duplicate_name(
DocumentService.query,
name=file.filename,
kb_id=kb_id)
filetype = filename_type(filename)
if not filetype:
return get_data_error_result(
message="This type of file has not been supported yet!")
location = filename
while settings.STORAGE_IMPL.obj_exist(kb_id, location):
location += "_"
blob = request.files['file'].read()
settings.STORAGE_IMPL.put(kb_id, location, blob)
doc = {
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"parser_config": kb.parser_config,
"created_by": kb.tenant_id,
"type": filetype,
"name": filename,
"location": location,
"size": len(blob),
"thumbnail": thumbnail(filename, blob),
"suffix": Path(filename).suffix.lstrip("."),
}
form_data = request.form
if "parser_id" in form_data.keys():
if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
doc["parser_id"] = request.form.get("parser_id").strip()
if doc["type"] == FileType.VISUAL:
doc["parser_id"] = ParserType.PICTURE.value
if doc["type"] == FileType.AURAL:
doc["parser_id"] = ParserType.AUDIO.value
if re.search(r"\.(ppt|pptx|pages)$", filename):
doc["parser_id"] = ParserType.PRESENTATION.value
if re.search(r"\.(eml)$", filename):
doc["parser_id"] = ParserType.EMAIL.value
doc_result = DocumentService.insert(doc)
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
except Exception as e:
return server_error_response(e)
if "run" in form_data.keys():
if request.form.get("run").strip() == "1":
try:
info = {"run": 1, "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
DocumentService.update_by_id(doc["id"], info)
# if str(req["run"]) == TaskStatus.CANCEL.value:
tenant_id = DocumentService.get_tenant_id(doc["id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
# e, doc = DocumentService.get_by_id(doc["id"])
TaskService.filter_delete([Task.doc_id == doc["id"]])
e, doc = DocumentService.get_by_id(doc["id"])
doc = doc.to_dict()
doc["tenant_id"] = tenant_id
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
queue_tasks(doc, bucket, name, 0)
except Exception as e:
return server_error_response(e)
return get_json_result(data=doc_result.to_json())
@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821
@validate_request("conversation_id")
def upload_parse():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
return get_json_result(data=doc_ids)
@manager.route('/list_chunks', methods=['POST']) # noqa: F821
# @login_required
def list_chunks():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
try:
if "doc_name" in req.keys():
tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name'])
doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name'])
elif "doc_id" in req.keys():
tenant_id = DocumentService.get_tenant_id(req['doc_id'])
doc_id = req['doc_id']
else:
return get_json_result(
data=False, message="Can't find doc_name or doc_id"
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
"doc_name": res_item["docnm_kwd"],
"image_id": res_item["img_id"]
} for res_item in res
]
except Exception as e:
return server_error_response(e)
return get_json_result(data=res)
@manager.route('/get_chunk/<chunk_id>', methods=['GET']) # noqa: F821
# @login_required
def get_chunk(chunk_id):
from rag.nlp import search
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
try:
tenant_id = objs[0].tenant_id
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None:
return server_error_response(Exception("Chunk not found"))
k = []
for n in chunk.keys():
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
k.append(n)
for n in k:
del chunk[n]
return get_json_result(data=chunk)
except Exception as e:
return server_error_response(e)
@manager.route('/list_kb_docs', methods=['POST']) # noqa: F821
# @login_required
def list_kb_docs():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
tenant_id = objs[0].tenant_id
kb_name = req.get("kb_name", "").strip()
try:
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
if not e:
return get_data_error_result(
message="Can't find this knowledgebase!")
kb_id = kb.id
except Exception as e:
return server_error_response(e)
page_number = int(req.get("page", 1))
items_per_page = int(req.get("page_size", 15))
orderby = req.get("orderby", "create_time")
desc = req.get("desc", True)
keywords = req.get("keywords", "")
status = req.get("status", [])
if status:
invalid_status = {s for s in status if s not in VALID_TASK_STATUS}
if invalid_status:
return get_data_error_result(
message=f"Invalid filter status conditions: {', '.join(invalid_status)}"
)
types = req.get("types", [])
if types:
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
if invalid_types:
return get_data_error_result(
message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
)
try:
docs, tol = DocumentService.get_by_kb_id(
kb_id, page_number, items_per_page, orderby, desc, keywords, status, types)
docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]
return get_json_result(data={"total": tol, "docs": docs})
except Exception as e:
return server_error_response(e)
@manager.route('/document/infos', methods=['POST']) # noqa: F821
@validate_request("doc_ids")
def docinfos():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
doc_ids = req["doc_ids"]
docs = DocumentService.get_by_ids(doc_ids)
return get_json_result(data=list(docs.dicts()))
@manager.route('/document', methods=['DELETE']) # noqa: F821
# @login_required
def document_rm():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
tenant_id = objs[0].tenant_id
req = request.json
try:
doc_ids = DocumentService.get_doc_ids_by_doc_names(req.get("doc_names", []))
for doc_id in req.get("doc_ids", []):
if doc_id not in doc_ids:
doc_ids.append(doc_id)
if not doc_ids:
return get_json_result(
data=False, message="Can't find doc_names or doc_ids"
)
except Exception as e:
return server_error_response(e)
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, tenant_id)
errors = ""
docs = DocumentService.get_by_ids(doc_ids)
doc_dic = {}
for doc in docs:
doc_dic[doc.id] = doc
for doc_id in doc_ids:
try:
if doc_id not in doc_dic:
return get_data_error_result(message="Document not found!")
doc = doc_dic[doc_id]
tenant_id = DocumentService.get_tenant_id(doc_id)
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
if not DocumentService.remove_document(doc, tenant_id):
return get_data_error_result(
message="Database error (Document removal)!")
f2d = File2DocumentService.get_by_document_id(doc_id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc_id)
settings.STORAGE_IMPL.rm(b, n)
except Exception as e:
errors += str(e)
if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
return get_json_result(data=True)
@manager.route('/completion_aibotk', methods=['POST']) # noqa: F821
@validate_request("Authorization", "conversation_id", "word")
def completion_faq():
import base64
req = request.json
token = req["Authorization"]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
if "quote" not in req:
req["quote"] = True
msg = [{"role": "user", "content": req["word"]}]
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]
def fillin_conv(ans):
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
ans["id"] = message_id
try:
if conv.source == "agent":
conv.message.append(msg[-1])
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
if not e:
return server_error_response("canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
final_ans = {"reference": [], "doc_aggs": []}
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
canvas.messages.append(msg[-1])
canvas.add_user_input(msg[-1]["content"])
answer = canvas.run(stream=False)
assert answer is not None, "Nothing. Is it over?"
data_type_picture = {
"type": 3,
"url": "base64 content"
}
data = [
{
"type": 1,
"content": ""
}
]
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
ans = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
for chunk_idx in chunk_idxs[:1]:
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = settings.STORAGE_IMPL.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
break
except Exception as e:
return server_error_response(e)
response = {"code": 200, "msg": "success", "data": data}
return response
# ******************For dialog******************
conv.message.append(msg[-1])
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(message="Dialog not found!")
del req["conversation_id"]
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})
data_type_picture = {
"type": 3,
"url": "base64 content"
}
data = [
{
"type": 1,
"content": ""
}
]
ans = ""
for a in chat(dia, msg, stream=False, **req):
ans = a
break
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
for chunk_idx in chunk_idxs[:1]:
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = settings.STORAGE_IMPL.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
break
except Exception as e:
return server_error_response(e)
response = {"code": 200, "msg": "success", "data": data}
return response
except Exception as e:
return server_error_response(e)
@manager.route('/retrieval', methods=['POST']) # noqa: F821
@validate_request("kb_id", "question")
def retrieval():
token = request.headers.get('Authorization').split()[1]
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
req = request.json
kb_ids = req.get("kb_id", [])
doc_ids = req.get("doc_ids", [])
question = req.get("question")
page = int(req.get("page", 1))
size = int(req.get("page_size", 30))
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))
highlight = bool(req.get("highlight", False))
try:
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
return get_json_result(
data=False, message='Knowledge bases use different embedding models or does not exist."',
code=RetCode.AUTHENTICATION_ERROR)
embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id)
rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
if req.get("keyword", False):
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
rank_feature=label_question(question, kbs))
for c in ranks["chunks"]:
c.pop("vector", None)
return get_json_result(data=ranks)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR)
return server_error_response(e)

View file

@ -426,7 +426,6 @@ def test_db_connect():
try:
import trino
import os
from trino.auth import BasicAuthentication
except Exception as e:
return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")
@ -438,7 +437,7 @@ def test_db_connect():
auth = None
if http_scheme == "https" and req.get("password"):
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
conn = trino.dbapi.connect(
host=req["host"],
@ -471,8 +470,8 @@ def test_db_connect():
@login_required
def getlistversion(canvas_id):
try:
list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
return get_json_result(data=list)
versions = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
return get_json_result(data=versions)
except Exception as e:
return get_data_error_result(message=f"Error getting history files: {e}")

View file

@ -55,7 +55,6 @@ def set_connector():
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
"status": TaskStatus.SCHEDULE,
}
conn["status"] = TaskStatus.SCHEDULE
ConnectorService.save(**conn)
time.sleep(1)

View file

@ -85,7 +85,6 @@ def get():
if not e:
return get_data_error_result(message="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
avatar = None
for tenant in tenants:
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
if dialog and len(dialog) > 0:

View file

@ -154,15 +154,15 @@ def get_kb_names(kb_ids):
@login_required
def list_dialogs():
try:
diags = DialogService.query(
conversations = DialogService.query(
tenant_id=current_user.id,
status=StatusEnum.VALID.value,
reverse=True,
order_by=DialogService.model.create_time)
diags = [d.to_dict() for d in diags]
for d in diags:
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
return get_json_result(data=diags)
conversations = [d.to_dict() for d in conversations]
for conversation in conversations:
conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"])
return get_json_result(data=conversations)
except Exception as e:
return server_error_response(e)

View file

@ -308,7 +308,7 @@ def get_filter():
@manager.route("/infos", methods=["POST"]) # noqa: F821
@login_required
def docinfos():
def doc_infos():
req = request.json
doc_ids = req["doc_ids"]
for doc_id in doc_ids:
@ -544,6 +544,7 @@ def change_parser():
return get_data_error_result(message="Tenant not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return None
try:
if "pipeline_id" in req and req["pipeline_id"] != "":

View file

@ -246,8 +246,8 @@ def rm():
try:
if file.location:
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
except Exception as e:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
informs = File2DocumentService.get_by_file_id(file.id)
for inform in informs:

View file

@ -16,6 +16,7 @@
import json
import logging
import random
import re
from flask import request
from flask_login import login_required, current_user
@ -731,6 +732,8 @@ def delete_kb_task():
def cancel_task(task_id):
REDIS_CONN.set(f"{task_id}-cancel", "x")
kb_task_id_field: str = ""
kb_task_finish_at: str = ""
match pipeline_task_type:
case PipelineTaskType.GRAPH_RAG:
kb_task_id_field = "graphrag_task_id"
@ -807,7 +810,7 @@ def check_embedding():
offset=0, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
)
total = docStoreConn.getTotal(res0)
total = docStoreConn.get_total(res0)
if total <= 0:
return []
@ -824,7 +827,7 @@ def check_embedding():
offset=off, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
)
ids = docStoreConn.getChunkIds(res1)
ids = docStoreConn.get_chunk_ids(res1)
if not ids:
continue
@ -845,8 +848,13 @@ def check_embedding():
"position_int": full_doc.get("position_int"),
"top_int": full_doc.get("top_int"),
"content_with_weight": full_doc.get("content_with_weight") or "",
"question_kwd": full_doc.get("question_kwd") or []
})
return out
def _clean(s: str) -> str:
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
return s if s else "None"
req = request.json
kb_id = req.get("kb_id", "")
embd_id = req.get("embd_id", "")
@ -859,8 +867,10 @@ def check_embedding():
results, eff_sims = [], []
for ck in samples:
txt = (ck.get("content_with_weight") or "").strip()
if not txt:
title = ck.get("doc_name") or "Title"
txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
txt_in = _clean(txt_in)
if not txt_in:
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
continue
@ -869,8 +879,16 @@ def check_embedding():
continue
try:
qv, _ = emb_mdl.encode_queries(txt)
sim = _cos_sim(qv, ck["vector"])
v, _ = emb_mdl.encode([title, txt_in])
sim_content = _cos_sim(v[1], ck["vector"])
title_w = 0.1
qv_mix = title_w * v[0] + (1 - title_w) * v[1]
sim_mix = _cos_sim(qv_mix, ck["vector"])
sim = sim_content
mode = "content_only"
if sim_mix > sim:
sim = sim_mix
mode = "title+content"
except Exception:
return get_error_data_result(message="embedding failure")
@ -892,8 +910,9 @@ def check_embedding():
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
"match_mode": mode,
}
if summary["avg_cos_sim"] > 0.99:
if summary["avg_cos_sim"] > 0.9:
return get_json_result(data={"summary": summary, "results": results})
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
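For reference, a self-contained sketch of the title-weighted similarity comparison introduced above; random vectors stand in for real embeddings, and the local `cos_sim` is a stand-in for the route's `_cos_sim` helper (`title_w = 0.1` matches the route):

```python
import numpy as np

def cos_sim(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

rng = np.random.default_rng(0)
title_vec, content_vec, stored_vec = rng.normal(size=(3, 8))  # stand-ins for title, chunk text, stored chunk vector

title_w = 0.1
sim_content = cos_sim(content_vec, stored_vec)
sim_mix = cos_sim(title_w * title_vec + (1 - title_w) * content_vec, stored_vec)

sim, mode = sim_content, "content_only"
if sim_mix > sim:
    sim, mode = sim_mix, "title+content"
print(mode, round(sim, 6))
```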

View file

@ -21,10 +21,11 @@ import json
from flask import request
from peewee import OperationalError
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
from api.db.services.user_service import TenantService
from common.constants import RetCode, FileSource, StatusEnum
from api.utils.api_utils import (
@ -118,7 +119,6 @@ def create(tenant_id):
req, err = validate_and_parse_json_request(request, CreateDatasetReq)
if err is not None:
return get_error_argument_result(err)
req = KnowledgebaseService.create_with_name(
name = req.pop("name", None),
tenant_id = tenant_id,
@ -144,7 +144,6 @@ def create(tenant_id):
ok, k = KnowledgebaseService.get_by_id(req["id"])
if not ok:
return get_error_data_result(message="Dataset created failed")
response_data = remap_dictionary_keys(k.to_dict())
return get_result(data=response_data)
except Exception as e:
@ -532,3 +531,157 @@ def delete_knowledge_graph(tenant_id, dataset_id):
search.index_name(kb.tenant_id), dataset_id)
return get_result(data=True)
@manager.route("/datasets/<dataset_id>/run_graphrag", methods=["POST"]) # noqa: F821
@token_required
def run_graphrag(tenant_id,dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.graphrag_task_id
if task_id:
ok, task = TaskService.get_by_id(task_id)
if not ok:
logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}")
if task and task.progress not in [-1, 1]:
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
documents, _ = DocumentService.get_by_kb_id(
kb_id=dataset_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
if not documents:
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}")
return get_result(data={"graphrag_task_id": task_id})
@manager.route("/datasets/<dataset_id>/trace_graphrag", methods=["GET"]) # noqa: F821
@token_required
def trace_graphrag(tenant_id,dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.graphrag_task_id
if not task_id:
return get_result(data={})
ok, task = TaskService.get_by_id(task_id)
if not ok:
return get_result(data={})
return get_result(data=task.to_dict())
@manager.route("/datasets/<dataset_id>/run_raptor", methods=["POST"]) # noqa: F821
@token_required
def run_raptor(tenant_id,dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.raptor_task_id
if task_id:
ok, task = TaskService.get_by_id(task_id)
if not ok:
logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}")
if task and task.progress not in [-1, 1]:
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
documents, _ = DocumentService.get_by_kb_id(
kb_id=dataset_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
if not documents:
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}")
return get_result(data={"raptor_task_id": task_id})
@manager.route("/datasets/<dataset_id>/trace_raptor", methods=["GET"]) # noqa: F821
@token_required
def trace_raptor(tenant_id,dataset_id):
if not dataset_id:
return get_error_data_result(message='Lack of "Dataset ID"')
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return get_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
)
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
if not ok:
return get_error_data_result(message="Invalid Dataset ID")
task_id = kb.raptor_task_id
if not task_id:
return get_result(data={})
ok, task = TaskService.get_by_id(task_id)
if not ok:
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
return get_result(data=task.to_dict())
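Assuming these routes are mounted under the usual `/api/v1` HTTP API prefix and authenticated with an API key (host, key, and dataset ID below are placeholders), a client could start and poll a GraphRAG run roughly like this:

```python
import time
import requests

BASE = "http://localhost:9380/api/v1"              # assumption: default HTTP API address and prefix
HEADERS = {"Authorization": "Bearer <RAGFLOW_API_KEY>"}
dataset_id = "<dataset_id>"

# kick off a GraphRAG run over all documents in the dataset
resp = requests.post(f"{BASE}/datasets/{dataset_id}/run_graphrag", headers=HEADERS)
task_id = resp.json().get("data", {}).get("graphrag_task_id")

# poll the trace endpoint until the task reports success (progress == 1) or failure (progress == -1)
task = {}
while task.get("progress") not in (1, -1):
    time.sleep(5)
    task = requests.get(f"{BASE}/datasets/{dataset_id}/trace_graphrag", headers=HEADERS).json().get("data", {})
print(task_id, task.get("progress"))
```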

View file

@ -93,6 +93,10 @@ def upload(dataset_id, tenant_id):
type: file
required: true
description: Document files to upload.
- in: formData
name: parent_path
type: string
description: Optional nested path under the parent folder. Uses '/' separators.
responses:
200:
description: Successfully uploaded documents.
@ -151,7 +155,7 @@ def upload(dataset_id, tenant_id):
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if not e:
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
err, files = FileService.upload_document(kb, file_objs, tenant_id)
err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=request.form.get("parent_path"))
if err:
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
# rename key's name
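By way of example, a client could supply the new `parent_path` form field as below; the endpoint path, host, and key are assumptions based on the standard dataset document-upload route, and the file name is made up:

```python
import requests

with open("q1_report.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:9380/api/v1/datasets/<dataset_id>/documents",
        headers={"Authorization": "Bearer <RAGFLOW_API_KEY>"},
        files={"file": f},
        data={"parent_path": "reports/2024"},  # sanitized server-side, then prefixed to the stored location
    )
print(resp.json())
```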

View file

@ -242,7 +242,7 @@ class Connector2KbService(CommonService):
"id": get_uuid(),
"connector_id": conn_id,
"kb_id": kb_id,
"auto_parse": conn.get("auto_parse", "1")
"auto_parse": conn.get("auto_parse", "1")
})
SyncLogsService.schedule(conn_id, kb_id, reindex=True)

View file

@ -309,7 +309,7 @@ class DocumentService(CommonService):
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
@ -322,7 +322,7 @@ class DocumentService(CommonService):
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
graph_source = settings.docStoreConn.getFields(
graph_source = settings.docStoreConn.get_fields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:

View file

@ -31,7 +31,7 @@ from common.misc_utils import get_uuid
from common.constants import TaskStatus, FileSource, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path
from rag.llm.cv_model import GptV4
from common import settings
@ -329,7 +329,7 @@ class FileService(CommonService):
current_id = start_id
while current_id:
e, file = cls.get_by_id(current_id)
if file.parent_id != file.id and e:
if e and file.parent_id != file.id:
parent_folders.append(file)
current_id = file.parent_id
else:
@ -423,13 +423,15 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def upload_document(self, kb, file_objs, user_id, src="local"):
def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str | None = None):
root_folder = self.get_root_folder(user_id)
pf_id = root_folder["id"]
self.init_knowledgebase_docs(pf_id, user_id)
kb_root_folder = self.get_kb_folder(user_id)
kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
safe_parent_path = sanitize_path(parent_path)
err, files = [], []
for file in file_objs:
try:
@ -439,7 +441,7 @@ class FileService(CommonService):
if filetype == FileType.OTHER.value:
raise RuntimeError("This type of file has not been supported yet!")
location = filename
location = filename if not safe_parent_path else f"{safe_parent_path}/{filename}"
while settings.STORAGE_IMPL.obj_exist(kb.id, location):
location += "_"

View file

@ -164,3 +164,23 @@ def read_potential_broken_pdf(blob):
return repaired
return blob
def sanitize_path(raw_path: str | None) -> str:
"""Normalize and sanitize a user-provided path segment.
- Converts backslashes to forward slashes
- Strips leading/trailing slashes
- Removes '.' and '..' segments
- Restricts characters to A-Za-z0-9, underscore, dash, and '/'
"""
if not raw_path:
return ""
backslash_re = re.compile(r"[\\]+")
unsafe_re = re.compile(r"[^A-Za-z0-9_\-/]")
normalized = backslash_re.sub("/", raw_path)
normalized = normalized.strip("/")
parts = [seg for seg in normalized.split("/") if seg and seg not in (".", "..")]
sanitized = "/".join(parts)
sanitized = unsafe_re.sub("", sanitized)
return sanitized
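A few hypothetical inputs to show the effect of the helper (assumes the `sanitize_path` definition above is in scope):

```python
assert sanitize_path(None) == ""
assert sanitize_path("..\\reports/./2024\\Q1") == "reports/2024/Q1"  # backslashes normalized, '.' and '..' dropped
assert sanitize_path("/weird name!/sub dir/") == "weirdname/subdir"  # disallowed characters stripped
```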

check_comment_ascii.py (new file, 36 lines)
View file

@ -0,0 +1,36 @@
#!/usr/bin/env python3
import sys
import tokenize
import ast
import pathlib
import re
ASCII = re.compile(r"^[ -~]*\Z") # Only printable ASCII
def check(src: str, name: str) -> int:
"""
I'm a docstring
"""
ok = 1
# A common comment begins with `#`
with tokenize.open(src) as fp:
for tk in tokenize.generate_tokens(fp.readline):
if tk.type == tokenize.COMMENT and not ASCII.fullmatch(tk.string):
print(f"{name}:{tk.start[0]}: non-ASCII comment: {tk.string}")
ok = 0
# A docstring begins and ends with `'''`
for node in ast.walk(ast.parse(pathlib.Path(src).read_text(), filename=name)):
if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
if (doc := ast.get_docstring(node)) and not ASCII.fullmatch(doc):
print(f"{name}:{node.lineno}: non-ASCII docstring: {doc}")
ok = 0
return ok
if __name__ == "__main__":
status = 0
for file in sys.argv[1:]:
if not check(file, file):
status = 1
sys.exit(status)
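A quick sanity check of the `ASCII` pattern this script relies on (the sample strings are made up):

```python
import re

ASCII = re.compile(r"^[ -~]*\Z")  # printable ASCII only, same pattern as the checker

assert ASCII.fullmatch("# plain ASCII comment")
assert ASCII.fullmatch("")                    # empty strings pass
assert not ASCII.fullmatch("# caf\u00e9")     # non-ASCII characters fail
assert not ASCII.fullmatch("line\nbreak")     # newlines fall outside the printable range
```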

View file

@ -3,15 +3,9 @@ import os
import threading
from typing import Any, Callable
import requests
from common.data_source.config import DocumentSource
from common.data_source.google_util.constant import GOOGLE_SCOPES
GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
DEFAULT_DEVICE_INTERVAL = 5
def _get_requested_scopes(source: DocumentSource) -> list[str]:
"""Return the scopes to request, honoring an optional override env var."""
@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag
return result.get("value")
def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
if "client_id" in credentials:
return credentials["client_id"], credentials.get("client_secret")
for key in ("installed", "web"):
if key in credentials and isinstance(credentials[key], dict):
nested = credentials[key]
if "client_id" not in nested:
break
return nested["client_id"], nested.get("client_secret")
raise ValueError("Provided Google OAuth credentials are missing client_id.")
def start_device_authorization_flow(
credentials: dict[str, Any],
source: DocumentSource,
) -> tuple[dict[str, Any], dict[str, Any]]:
client_id, client_secret = _extract_client_info(credentials)
data = {
"client_id": client_id,
"scope": " ".join(_get_requested_scopes(source)),
}
if client_secret:
data["client_secret"] = client_secret
resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
resp.raise_for_status()
payload = resp.json()
state = {
"client_id": client_id,
"client_secret": client_secret,
"device_code": payload.get("device_code"),
"interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
}
response_data = {
"user_code": payload.get("user_code"),
"verification_url": payload.get("verification_url") or payload.get("verification_uri"),
"verification_url_complete": payload.get("verification_url_complete")
or payload.get("verification_uri_complete"),
"expires_in": payload.get("expires_in"),
"interval": state["interval"],
}
return state, response_data
def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
data = {
"client_id": state["client_id"],
"device_code": state["device_code"],
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
}
if state.get("client_secret"):
data["client_secret"] = state["client_secret"]
resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
resp.raise_for_status()
return resp.json()
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
port = int(preferred_port) if preferred_port else 0
timeout_secs = _get_oauth_timeout_secs()
timeout_message = (
f"Google OAuth verification timed out after {timeout_secs} seconds. "
"Close any pending consent windows and rerun the connector configuration to try again."
)
timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."
print("Launching Google OAuth flow. A browser window should open shortly.")
print("If it does not, copy the URL shown in the console into your browser manually.")
@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
instructions = [
"Google rejected one or more of the requested OAuth scopes.",
"Fix options:",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
" (be aware the connector may lose functionality).",
]
raise RuntimeError("\n".join(instructions)) from warning
raise
@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
client_config = {"web": credentials["web"]}
if client_config is None:
raise ValueError(
"Provided Google OAuth credentials are missing both tokens and a client configuration."
)
raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")
return _run_local_server_flow(client_config, source)

View file

@ -186,9 +186,6 @@ class DoclingParser(RAGFlowPdfParser):
yield (DoclingContentType.EQUATION.value, text, bbox)
def _transfer_to_sections(self, doc) -> list[tuple[str, str]]:
"""
与 MinerUParser 保持一致，返回 [(section_text, line_tag), ...]
"""
sections: list[tuple[str, str]] = []
for typ, payload, bbox in self._iter_doc_items(doc):
if typ == DoclingContentType.TEXT.value:

View file

@ -34,6 +34,7 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
if isinstance(figure_data[1], Image.Image)
]
def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
@ -50,7 +51,8 @@ def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
return tbls
def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs):
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
try:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
@ -72,6 +74,7 @@ def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs):
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
return tbls
shared_executor = ThreadPoolExecutor(max_workers=10)

View file

@ -117,7 +117,6 @@ def load_model(model_dir, nm, device_id: int | None = None):
providers=['CUDAExecutionProvider'],
provider_options=[cuda_provider_options]
)
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(provider_device_id))
logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})")
else:
sess = ort.InferenceSession(

View file

@ -0,0 +1,8 @@
{
"label": "Add data source",
"position": 18,
"link": {
"type": "generated-index",
"description": "Add various data sources"
}
}

View file

@ -0,0 +1,137 @@
---
sidebar_position: 3
slug: /add_google_drive
---
# Add Google Drive
## 1. Create a Google Cloud Project
You can either create a dedicated project for RAGFlow or use an existing
Google Cloud external project.
**Steps:**
1. Open the project creation page\
`https://console.cloud.google.com/projectcreate`
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image1.jpeg?raw=true)
2. Select **External** as the Audience
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image2.png?raw=true)
3. Click **Create**
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image3.jpeg?raw=true)
------------------------------------------------------------------------
## 2. Configure OAuth Consent Screen
1. Go to **APIs & Services → OAuth consent screen**
2. Ensure **User Type = External**
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image4.jpeg?raw=true)
3. Add your test users under **Test Users** by entering email addresses
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image5.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image6.jpeg?raw=true)
------------------------------------------------------------------------
## 3. Create OAuth Client Credentials
1. Navigate to:\
`https://console.cloud.google.com/auth/clients`
2. Create a **Web Application**
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image7.png?raw=true)
3. Enter a name for the client
4. Add the following **Authorized Redirect URIs**:
```
http://localhost:9380/v1/connector/google-drive/oauth/web/callback
```
### If using Docker deployment:
**Authorized JavaScript origin:**
```
http://localhost:80
```
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image8.png?raw=true)
### If running from source:
**Authorized JavaScript origin:**
```
http://localhost:9222
```
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image9.png?raw=true)
5. After saving, click **Download JSON**. This file will later be
uploaded into RAGFlow.
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image10.png?raw=true)
------------------------------------------------------------------------
## 4. Add Scopes
1. Open **Data Access → Add or remove scopes**
2. Paste and add the following entries:
```
https://www.googleapis.com/auth/drive.readonly
https://www.googleapis.com/auth/drive.metadata.readonly
https://www.googleapis.com/auth/admin.directory.group.readonly
https://www.googleapis.com/auth/admin.directory.user.readonly
```
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image11.jpeg?raw=true)
3. Click **Update**, then save the changes
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image12.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image13.jpeg?raw=true)
------------------------------------------------------------------------
## 5. Enable Required APIs
Navigate to the Google API Library:\
`https://console.cloud.google.com/apis/library`
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image14.png?raw=true)
Enable the following APIs:
- Google Drive API
- Admin SDK API
- Google Sheets API
- Google Docs API
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image15.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image16.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image17.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image18.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image19.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image21.png?raw=true)
------------------------------------------------------------------------
## 6. Add Google Drive as a Data Source in RAGFlow
1. Go to **Data Sources** inside RAGFlow
2. Select **Google Drive**
3. Upload the previously downloaded JSON credentials
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image22.jpeg?raw=true)
4. Enter the link to the shared Google Drive folder (a URL beginning with https://drive.google.com/drive), for example:
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image23.png?raw=true)
5. Click **Authorize with Google**
A browser window will appear.
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image25.jpeg?raw=true)
Click **Continue**, then **Select All → Continue**. Authorization should
succeed; select **OK** to add the data source.
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image26.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image27.jpeg?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image28.png?raw=true)
![placeholder-image](https://github.com/infiniflow/ragflow-docs/blob/040e4acd4c1eac6dc73dc44e934a6518de78d097/images/google_drive/image29.png?raw=true)

View file

@ -1,6 +1,6 @@
{
"label": "Best practices",
"position": 11,
"position": 19,
"link": {
"type": "generated-index",
"description": "Best practices on configuring a dataset."

View file

@ -64,7 +64,10 @@ The Admin CLI and Admin Service form a client-server architectural suite for RAG
- -p: RAGFlow admin server port
## Default administrative account
- Username: admin@ragflow.io
- Password: admin
## Supported Commands

View file

@ -974,6 +974,237 @@ Failure:
---
### Construct knowledge graph
**POST** `/api/v1/datasets/{dataset_id}/run_graphrag`
Constructs a knowledge graph from a specified dataset.
#### Request
- Method: POST
- URL: `/api/v1/datasets/{dataset_id}/run_graphrag`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request POST \
--url http://{address}/api/v1/datasets/{dataset_id}/run_graphrag \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"graphrag_task_id":"e498de54bfbb11f0ba028f704583b57b"
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
---
### Get knowledge graph construction status
**GET** `/api/v1/datasets/{dataset_id}/trace_graphrag`
Retrieves the knowledge graph construction status for a specified dataset.
#### Request
- Method: GET
- URL: `/api/v1/datasets/{dataset_id}/trace_graphrag`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request GET \
--url http://{address}/api/v1/datasets/{dataset_id}/trace_graphrag \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"begin_at":"Wed, 12 Nov 2025 19:36:56 GMT",
"chunk_ids":"",
"create_date":"Wed, 12 Nov 2025 19:36:56 GMT",
"create_time":1762947416350,
"digest":"39e43572e3dcd84f",
"doc_id":"44661c10bde211f0bc93c164a47ffc40",
"from_page":100000000,
"id":"e498de54bfbb11f0ba028f704583b57b",
"priority":0,
"process_duration":2.45419,
"progress":1.0,
"progress_msg":"19:36:56 created task graphrag\n19:36:57 Task has been received.\n19:36:58 [GraphRAG] doc:083661febe2411f0bc79456921e5745f has no available chunks, skip generation.\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 start (chunks=1, timeout=10000000000s)\n19:36:58 Graph already contains 44661c10bde211f0bc93c164a47ffc40\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 empty\n19:36:58 [GraphRAG] kb:33137ed0bde211f0bc93c164a47ffc40 no subgraphs generated successfully, end.\n19:36:58 Knowledge Graph done (0.72s)","retry_count":1,
"task_type":"graphrag",
"to_page":100000000,
"update_date":"Wed, 12 Nov 2025 19:36:58 GMT",
"update_time":1762947418454
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
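In practice, the two endpoints above are used together: trigger construction, then poll the trace endpoint until the task finishes. The following is a minimal, illustrative Python sketch of that pattern using the `requests` library; the address, API key, dataset ID, and the 5-second poll interval are placeholders, and error handling is omitted.

```python
import time

import requests

# Placeholders - substitute your own deployment address, API key, and dataset ID.
ADDRESS = "http://<address>"
API_KEY = "<YOUR_API_KEY>"
DATASET_ID = "<dataset_id>"

headers = {"Authorization": f"Bearer {API_KEY}"}

# Trigger knowledge graph construction.
resp = requests.post(f"{ADDRESS}/api/v1/datasets/{DATASET_ID}/run_graphrag", headers=headers)
resp.raise_for_status()
print("graphrag_task_id:", resp.json()["data"]["graphrag_task_id"])

# Poll the trace endpoint until the task reports completion (progress == 1.0).
while True:
    trace = requests.get(f"{ADDRESS}/api/v1/datasets/{DATASET_ID}/trace_graphrag", headers=headers).json()
    task = trace.get("data") or {}
    if task.get("progress", 0.0) >= 1.0:
        print(task.get("progress_msg", ""))
        break
    time.sleep(5)  # arbitrary poll interval
```

The same pattern applies to the RAPTOR endpoints below; only the endpoint names change.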
---
### Construct RAPTOR
**POST** `/api/v1/datasets/{dataset_id}/run_raptor`
Constructs a RAPTOR tree from a specified dataset.
#### Request
- Method: POST
- URL: `/api/v1/datasets/{dataset_id}/run_raptor`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request POST \
--url http://{address}/api/v1/datasets/{dataset_id}/run_raptor \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"raptor_task_id":"50d3c31cbfbd11f0ba028f704583b57b"
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
---
### Get RAPTOR construction status
**GET** `/api/v1/datasets/{dataset_id}/trace_raptor`
Retrieves the RAPTOR construction status for a specified dataset.
#### Request
- Method: GET
- URL: `/api/v1/datasets/{dataset_id}/trace_raptor`
- Headers:
- `'Authorization: Bearer <YOUR_API_KEY>'`
##### Request example
```bash
curl --request GET \
--url http://{address}/api/v1/datasets/{dataset_id}/trace_raptor \
--header 'Authorization: Bearer <YOUR_API_KEY>'
```
##### Request parameters
- `dataset_id`: (*Path parameter*)
The ID of the target dataset.
#### Response
Success:
```json
{
"code":0,
"data":{
"begin_at":"Wed, 12 Nov 2025 19:47:07 GMT",
"chunk_ids":"",
"create_date":"Wed, 12 Nov 2025 19:47:07 GMT",
"create_time":1762948027427,
"digest":"8b279a6248cb8fc6",
"doc_id":"44661c10bde211f0bc93c164a47ffc40",
"from_page":100000000,
"id":"50d3c31cbfbd11f0ba028f704583b57b",
"priority":0,
"process_duration":0.948244,
"progress":1.0,
"progress_msg":"19:47:07 created task raptor\n19:47:07 Task has been received.\n19:47:07 Processing...\n19:47:07 Processing...\n19:47:07 Indexing done (0.01s).\n19:47:07 Task done (0.29s)",
"retry_count":1,
"task_type":"raptor",
"to_page":100000000,
"update_date":"Wed, 12 Nov 2025 19:47:07 GMT",
"update_time":1762948027948
}
}
```
Failure:
```json
{
"code": 102,
"message": "Invalid Dataset ID"
}
```
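The `trace_graphrag` and `trace_raptor` responses share the same task-record fields, so their status can be summarized with the same helper. The sketch below is illustrative only; it assumes a response shaped like the examples above and reports the task type, progress, duration, and the last line of `progress_msg`.

```python
def summarize_task(trace_response: dict) -> str:
    """Build a one-line summary from a trace_graphrag / trace_raptor response."""
    task = trace_response.get("data") or {}
    status = "done" if task.get("progress", 0.0) >= 1.0 else "running"
    duration = task.get("process_duration", 0.0)
    # Keep only the last log line of progress_msg, if any.
    last_lines = (task.get("progress_msg") or "").strip().splitlines()[-1:] or [""]
    return f"[{task.get('task_type', '?')}] {status} in {duration:.2f}s - {last_lines[0]}"


# Example with the RAPTOR response shown above:
# print(summarize_task(response_json))  # -> "[raptor] done in 0.95s - 19:47:07 Task done (0.29s)"
```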
---
## FILE MANAGEMENT WITHIN DATASET
---

View file

@ -114,7 +114,7 @@ class Extractor:
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
out_results = []
error_count = 0
max_errors = 3
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
limiter = trio.Semaphore(max_concurrency)

View file

@ -69,7 +69,7 @@ class KGSearch(Dealer):
def _ent_info_from_(self, es_res, sim_thr=0.3):
res = {}
flds = ["content_with_weight", "_score", "entity_kwd", "rank_flt", "n_hop_with_weight"]
es_res = self.dataStore.getFields(es_res, flds)
es_res = self.dataStore.get_fields(es_res, flds)
for _, ent in es_res.items():
for f in flds:
if f in ent and ent[f] is None:
@ -88,7 +88,7 @@ class KGSearch(Dealer):
def _relation_info_from_(self, es_res, sim_thr=0.3):
res = {}
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
es_res = self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
"weight_int"])
for _, ent in es_res.items():
if get_float(ent["_score"]) < sim_thr:
@ -300,7 +300,7 @@ class KGSearch(Dealer):
fltr["entities_kwd"] = entities
comm_res = self.dataStore.search(fields, [], fltr, [],
OrderByExpr(), 0, topn, idxnms, kb_ids)
comm_res_fields = self.dataStore.getFields(comm_res, fields)
comm_res_fields = self.dataStore.get_fields(comm_res, fields)
txts = []
for ii, (_, row) in enumerate(comm_res_fields.items()):
obj = json.loads(row["content_with_weight"])

View file

@ -382,7 +382,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
"removed_kwd": "N",
}
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = settings.docStoreConn.getFields(res, fields)
fields2 = settings.docStoreConn.get_fields(res, fields)
graph_doc_ids = set()
for chunk_id in fields2.keys():
graph_doc_ids = set(fields2[chunk_id]["source_id"])
@ -591,8 +591,8 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
es_res = await trio.to_thread.run_sync(
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
)
# tot = settings.docStoreConn.getTotal(es_res)
es_res = settings.docStoreConn.getFields(es_res, flds)
# tot = settings.docStoreConn.get_total(es_res)
es_res = settings.docStoreConn.get_fields(es_res, flds)
if len(es_res) == 0:
break

View file

@ -482,7 +482,7 @@ def tree_merge(bull, sections, depth):
root = Node(level=0, depth=target_level, texts=[])
root.build_tree(lines)
return [("\n").join(element) for element in root.get_tree() if element]
return [element for element in root.get_tree() if element]
def hierarchical_merge(bull, sections, depth):

View file

@ -38,11 +38,11 @@ class FulltextQueryer:
]
@staticmethod
def subSpecialChar(line):
def sub_special_char(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def isChinese(line):
def is_chinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
@ -92,7 +92,7 @@ class FulltextQueryer:
otxt = txt
txt = FulltextQueryer.rmWWW(txt)
if not self.isChinese(txt):
if not self.is_chinese(txt):
txt = FulltextQueryer.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split()
keywords = [t for t in tks if t]
@ -163,7 +163,7 @@ class FulltextQueryer:
)
for m in sm
]
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(keywords) < 32:
@ -171,7 +171,7 @@ class FulltextQueryer:
keywords.extend(sm)
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
@ -180,7 +180,7 @@ class FulltextQueryer:
if len(keywords) >= 32:
break
tk = FulltextQueryer.subSpecialChar(tk)
tk = FulltextQueryer.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
@ -198,7 +198,7 @@ class FulltextQueryer:
syns = " OR ".join(
[
'"%s"'
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
% rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s))
for s in syns
]
)
@ -217,17 +217,17 @@ class FulltextQueryer:
return None, keywords
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
sims = CosineSimilarity([avec], bvecs)
sims = cosine_similarity([avec], bvecs)
tksim = self.token_similarity(atks, btkss)
if np.sum(sims[0]) == 0:
return np.array(tksim), tksim, sims[0]
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
def token_similarity(self, atks, btkss):
def toDict(tks):
def to_dict(tks):
if isinstance(tks, str):
tks = tks.split()
d = defaultdict(int)
@ -236,8 +236,8 @@ class FulltextQueryer:
d[t] += c
return d
atks = toDict(atks)
btkss = [toDict(tks) for tks in btkss]
atks = to_dict(atks)
btkss = [to_dict(tks) for tks in btkss]
return [self.similarity(atks, btks) for btks in btkss]
def similarity(self, qtwt, dtwt):
@ -262,10 +262,10 @@ class FulltextQueryer:
keywords = [f'"{k.strip()}"' for k in keywords]
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
tk = FulltextQueryer.subSpecialChar(tk)
tk = FulltextQueryer.sub_special_char(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:

View file

@ -35,7 +35,7 @@ class RagTokenizer:
def rkey_(self, line):
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
def loadDict_(self, fnm):
def _load_dict(self, fnm):
logging.info(f"[HUQIE]:Build trie from {fnm}")
try:
of = open(fnm, "r", encoding='utf-8')
@ -85,18 +85,18 @@ class RagTokenizer:
self.trie_ = datrie.Trie(string.printable)
# load data from dict file and save to trie file
self.loadDict_(self.DIR_ + ".txt")
self._load_dict(self.DIR_ + ".txt")
def loadUserDict(self, fnm):
def load_user_dict(self, fnm):
try:
self.trie_ = datrie.Trie.load(fnm + ".trie")
return
except Exception:
self.trie_ = datrie.Trie(string.printable)
self.loadDict_(fnm)
self._load_dict(fnm)
def addUserDict(self, fnm):
self.loadDict_(fnm)
def add_user_dict(self, fnm):
self._load_dict(fnm)
def _strQ2B(self, ustring):
"""Convert full-width characters to half-width characters"""
@ -221,7 +221,7 @@ class RagTokenizer:
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
return tks, B / len(tks) + L + F
def sortTks_(self, tkslist):
def _sort_tokens(self, tkslist):
res = []
for tfts in tkslist:
tks, s = self.score_(tfts)
@ -246,7 +246,7 @@ class RagTokenizer:
return " ".join(res)
def maxForward_(self, line):
def _max_forward(self, line):
res = []
s = 0
while s < len(line):
@ -270,7 +270,7 @@ class RagTokenizer:
return self.score_(res)
def maxBackward_(self, line):
def _max_backward(self, line):
res = []
s = len(line) - 1
while s >= 0:
@ -336,8 +336,8 @@ class RagTokenizer:
continue
# use maxforward for the first time
tks, s = self.maxForward_(L)
tks1, s1 = self.maxBackward_(L)
tks, s = self._max_forward(L)
tks1, s1 = self._max_backward(L)
if self.DEBUG:
logging.debug("[FW] {} {}".format(tks, s))
logging.debug("[BW] {} {}".format(tks1, s1))
@ -369,7 +369,7 @@ class RagTokenizer:
# backward tokens from_i to i are different from forward tokens from _j to j.
tkslist = []
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
same = 1
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
@ -385,7 +385,7 @@ class RagTokenizer:
assert "".join(tks1[_i:]) == "".join(tks[_j:])
tkslist = []
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
res = " ".join(res)
logging.debug("[TKS] {}".format(self.merge_(res)))
@ -413,7 +413,7 @@ class RagTokenizer:
if len(tkslist) < 2:
res.append(tk)
continue
stk = self.sortTks_(tkslist)[1][0]
stk = self._sort_tokens(tkslist)[1][0]
if len(stk) == len(tk):
stk = tk
else:
@ -447,14 +447,13 @@ def is_number(s):
def is_alphabet(s):
if (s >= u'\u0041' and s <= u'\u005a') or (
s >= u'\u0061' and s <= u'\u007a'):
if (u'\u0041' <= s <= u'\u005a') or (u'\u0061' <= s <= u'\u007a'):
return True
else:
return False
def naiveQie(txt):
def naive_qie(txt):
tks = []
for t in txt.split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
@ -469,14 +468,14 @@ tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
load_user_dict = tokenizer.load_user_dict
add_user_dict = tokenizer.add_user_dict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B
if __name__ == '__main__':
tknzr = RagTokenizer(debug=True)
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
# huqie.add_user_dict("/tmp/tmp.new.tks.dict")
tks = tknzr.tokenize(
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
logging.info(tknzr.fine_grained_tokenize(tks))
@ -506,7 +505,7 @@ if __name__ == '__main__':
if len(sys.argv) < 2:
sys.exit()
tknzr.DEBUG = False
tknzr.loadUserDict(sys.argv[1])
tknzr.load_user_dict(sys.argv[1])
of = open(sys.argv[2], "r")
while True:
line = of.readline()

View file

@ -102,7 +102,7 @@ class Dealer:
orderBy.asc("top_int")
orderBy.desc("create_timestamp_flt")
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
highlightFields = ["content_ltks", "title_tks"]
@ -115,7 +115,7 @@ class Dealer:
matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
@ -127,20 +127,20 @@ class Dealer:
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search TOTAL: {}".format(total))
# If result is empty, try again with lower min_match
if total == 0:
if filters.get("doc_id"):
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
else:
matchText, _ = self.qryr.question(qst, min_match=0.1)
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
total = self.dataStore.getTotal(res)
total = self.dataStore.get_total(res)
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
for k in keywords:
@ -153,17 +153,17 @@ class Dealer:
kwds.add(kk)
logging.debug(f"TOTAL: {total}")
ids = self.dataStore.getChunkIds(res)
ids = self.dataStore.get_chunk_ids(res)
keywords = list(kwds)
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight")
aggs = self.dataStore.get_aggregation(res, "docnm_kwd")
return self.SearchResult(
total=total,
ids=ids,
query_vector=q_vec,
aggregation=aggs,
highlight=highlight,
field=self.dataStore.getFields(res, src + ["_score"]),
field=self.dataStore.get_fields(res, src + ["_score"]),
keywords=keywords
)
@ -488,7 +488,7 @@ class Dealer:
for p in range(offset, max_count, bs):
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
kb_ids)
dict_chunks = self.dataStore.getFields(es_res, fields)
dict_chunks = self.dataStore.get_fields(es_res, fields)
for id, doc in dict_chunks.items():
doc["id"] = id
if dict_chunks:
@ -501,11 +501,11 @@ class Dealer:
if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
return []
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
return self.dataStore.getAggregation(res, "tag_kwd")
return self.dataStore.get_aggregation(res, "tag_kwd")
def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
res = self.dataStore.getAggregation(res, "tag_kwd")
res = self.dataStore.get_aggregation(res, "tag_kwd")
total = np.sum([c for _, c in res])
return {t: (c + 1) / (total + S) for t, c in res}
@ -513,7 +513,7 @@ class Dealer:
idx_nm = index_name(tenant_id)
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
if not aggs:
return False
cnt = np.sum([c for _, c in aggs])
@ -529,7 +529,7 @@ class Dealer:
idx_nms = [index_name(tid) for tid in tenant_ids]
match_txt, _ = self.qryr.question(question, min_match=0.0)
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
aggs = self.dataStore.getAggregation(res, "tag_kwd")
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
if not aggs:
return {}
cnt = np.sum([c for _, c in aggs])
@ -552,7 +552,7 @@ class Dealer:
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
kb_ids)
toc = []
dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])
for _, doc in dict_chunks.items():
try:
toc.extend(json.loads(doc["content_with_weight"]))

View file

@ -113,20 +113,20 @@ class Dealer:
res.append(tk)
return res
def tokenMerge(self, tks):
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
def token_merge(self, tks):
def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
res, i = [], 0
while i < len(tks):
j = i
if i == 0 and oneTerm(tks[i]) and len(
if i == 0 and one_term(tks[i]) and len(
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
res.append(" ".join(tks[0:2]))
i = 2
continue
while j < len(
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
tks) and tks[j] and tks[j] not in self.stop_words and one_term(tks[j]):
j += 1
if j - i > 1:
if j - i < 5:
@ -232,7 +232,7 @@ class Dealer:
tw = list(zip(tks, wts))
else:
for tk in tks:
tt = self.tokenMerge(self.pretoken(tk, True))
tt = self.token_merge(self.pretoken(tk, True))
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
wts = (0.3 * idf1 + 0.7 * idf2) * \

View file

@ -15,27 +15,35 @@
#
import logging
import re
import umap
import numpy as np
from sklearn.mixture import GaussianMixture
import trio
import umap
from sklearn.mixture import GaussianMixture
from api.db.services.task_service import has_canceled
from common.connection_utils import timeout
from common.exceptions import TaskCanceledException
from common.token_utils import truncate
from graphrag.utils import (
get_llm_cache,
chat_limiter,
get_embed_cache,
get_llm_cache,
set_embed_cache,
set_llm_cache,
chat_limiter,
)
from common.token_utils import truncate
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __init__(
self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
self,
max_cluster,
llm_model,
embd_model,
prompt,
max_token=512,
threshold=0.1,
max_errors=3,
):
self._max_cluster = max_cluster
self._llm_model = llm_model
@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._threshold = threshold
self._prompt = prompt
self._max_token = max_token
self._max_errors = max(1, max_errors)
self._error_count = 0
@timeout(60*20)
@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
response = await trio.to_thread.run_sync(
lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
)
cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
if cached:
return cached
if response:
return response
response = await trio.to_thread.run_sync(
lambda: self._llm_model.chat(system, history, gen_conf)
)
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await trio.to_thread.run_sync(
lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
)
return response
last_exc = None
for attempt in range(3):
try:
response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
return response
except Exception as exc:
last_exc = exc
logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
if attempt < 2:
await trio.sleep(1 + attempt)
raise last_exc if last_exc else Exception("LLM chat failed without exception")
@timeout(20)
async def _embedding_encode(self, txt):
response = await trio.to_thread.run_sync(
lambda: get_embed_cache(self._embd_model.llm_name, txt)
)
response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
if response is not None:
return response
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
n_clusters = np.arange(1, max_clusters)
bics = []
for n in n_clusters:
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
@timeout(60*20)
@timeout(60 * 20)
async def summarize(ck_idx: list[int]):
nonlocal chunks
@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
raise TaskCanceledException(f"Task {task_id} was cancelled")
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int(
(self._llm_model.max_length - self._max_token) / len(texts)
)
cluster_content = "\n".join(
[truncate(t, max(1, len_per_chunk)) for t in texts]
)
async with chat_limiter:
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
try:
async with chat_limiter:
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(cluster_content=cluster_content),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(
cluster_content=cluster_content
),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
except TaskCanceledException:
raise
except Exception as exc:
self._error_count += 1
warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
logging.warning(warn_msg)
if callback:
callback(msg=warn_msg)
if self._error_count >= self._max_errors:
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
labels = []
while end - start > 1:
if task_id:
if has_canceled(task_id):
logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
if len(embeddings) == 2:
await summarize([start, start + 1])
if callback:
callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
labels.extend([0, 0])
layers.append((end, len(chunks)))
start = end
@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
nursery.start_soon(summarize, ck_idx)
assert len(chunks) - end == n_clusters, "{} vs. {}".format(
len(chunks) - end, n_clusters
)
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
labels.extend(lbls)
layers.append((end, len(chunks)))
if callback:
callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
start = end
end = len(chunks)

View file

@ -28,7 +28,7 @@ def collect():
logging.debug(doc_locations)
if len(doc_locations) == 0:
time.sleep(1)
return
return None
return doc_locations

View file

@ -359,7 +359,7 @@ async def build_chunks(task, progress_callback):
task_canceled = has_canceled(task["id"])
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
return None
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else:
@ -417,6 +417,7 @@ def build_TOC(task, docs, progress_callback):
d["page_num_int"] = [100000000]
d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
return d
return None
def init_kb(row, vector_size: int):
@ -441,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
tk_count = 0
if len(tts) == len(cnts):
vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0)
tts = np.tile(vts[0], (len(cnts), 1))
tk_count += c
@timeout(60)
@ -464,8 +465,10 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
if not filename_embd_weight:
filename_embd_weight = 0.1
title_w = float(filename_embd_weight)
vects = (title_w * tts + (1 - title_w) *
cnts) if len(tts) == len(cnts) else cnts
if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape:
vects = title_w * tts + (1 - title_w) * cnts
else:
vects = cnts
assert len(vects) == len(docs)
vector_size = 0
@ -648,6 +651,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
res = []
tk_count = 0
max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
async def generate(chunks, did):
nonlocal tk_count, res
raptor = Raptor(
@ -657,6 +662,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
raptor_config["prompt"],
raptor_config["max_token"],
raptor_config["threshold"],
max_errors=max_errors,
)
original_length = len(chunks)
chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
@ -719,7 +725,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
return False
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
@ -737,7 +743,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
return
return False
return True

View file

@ -67,6 +67,8 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"Fail put {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return None
return None
def rm(self, bucket, fnm):
try:
@ -84,7 +86,7 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None
def obj_exist(self, bucket, fnm):
try:
@ -102,4 +104,4 @@ class RAGFlowAzureSpnBlob:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None

View file

@ -241,23 +241,23 @@ class DocStoreConnection(ABC):
"""
@abstractmethod
def getTotal(self, res):
def get_total(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getChunkIds(self, res):
def get_chunk_ids(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
@abstractmethod
def getHighlight(self, res, keywords: list[str], fieldnm: str):
def get_highlight(self, res, keywords: list[str], fieldnm: str):
raise NotImplementedError("Not implemented")
@abstractmethod
def getAggregation(self, res, fieldnm: str):
def get_aggregation(self, res, fieldnm: str):
raise NotImplementedError("Not implemented")
"""

View file

@ -471,12 +471,12 @@ class ESConnection(DocStoreConnection):
Helper functions for search result
"""
def getTotal(self, res):
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def getChunkIds(self, res):
def get_chunk_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):
@ -487,7 +487,7 @@ class ESConnection(DocStoreConnection):
rr.append(d["_source"])
return rr
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
@ -509,7 +509,7 @@ class ESConnection(DocStoreConnection):
res_fields[d["id"]] = m
return res_fields
def getHighlight(self, res, keywords: list[str], fieldnm: str):
def get_highlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
@ -534,7 +534,7 @@ class ESConnection(DocStoreConnection):
return ans
def getAggregation(self, res, fieldnm: str):
def get_aggregation(self, res, fieldnm: str):
agg_field = "aggs_" + fieldnm
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()

View file

@ -470,7 +470,7 @@ class InfinityConnection(DocStoreConnection):
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, ["id"])
res_fields = self.getFields(res, res.columns.tolist())
res_fields = self.get_fields(res, res.columns.tolist())
return res_fields.get(chunkId, None)
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
@ -599,7 +599,7 @@ class InfinityConnection(DocStoreConnection):
col_to_remove = list(removeValue.keys())
row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
row_to_opt = self.getFields(row_to_opt, col_to_remove)
row_to_opt = self.get_fields(row_to_opt, col_to_remove)
for id, old_v in row_to_opt.items():
for k, remove_v in removeValue.items():
if remove_v in old_v[k]:
@ -639,17 +639,17 @@ class InfinityConnection(DocStoreConnection):
Helper functions for search result
"""
def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple):
res = res[0]
if not fields:
@ -690,7 +690,7 @@ class InfinityConnection(DocStoreConnection):
return res2.set_index("id").to_dict(orient="index")
def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
if isinstance(res, tuple):
res = res[0]
ans = {}
@ -732,7 +732,7 @@ class InfinityConnection(DocStoreConnection):
ans[id] = txt
return ans
def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
"""
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
"""

View file

@ -92,7 +92,7 @@ class RAGFlowMinio:
logging.exception(f"Fail to get {bucket}/{filename}")
self.__open__()
time.sleep(1)
return
return None
def obj_exist(self, bucket, filename, tenant_id=None):
try:
@ -130,7 +130,7 @@ class RAGFlowMinio:
logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
self.__open__()
time.sleep(1)
return
return None
def remove_bucket(self, bucket):
try:

View file

@ -62,8 +62,7 @@ class OpenDALStorage:
def health(self):
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
r = self._operator.write(f"{bucket}/{fnm}", binary)
return r
return self._operator.write(f"{bucket}/{fnm}", binary)
def put(self, bucket, fnm, binary, tenant_id=None):
self._operator.write(f"{bucket}/{fnm}", binary)

View file

@ -455,12 +455,12 @@ class OSConnection(DocStoreConnection):
Helper functions for search result
"""
def getTotal(self, res):
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def getChunkIds(self, res):
def get_chunk_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):
@ -471,7 +471,7 @@ class OSConnection(DocStoreConnection):
rr.append(d["_source"])
return rr
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
@ -490,7 +490,7 @@ class OSConnection(DocStoreConnection):
res_fields[d["id"]] = m
return res_fields
def getHighlight(self, res, keywords: list[str], fieldnm: str):
def get_highlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
@ -515,7 +515,7 @@ class OSConnection(DocStoreConnection):
return ans
def getAggregation(self, res, fieldnm: str):
def get_aggregation(self, res, fieldnm: str):
agg_field = "aggs_" + fieldnm
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()

View file

@ -141,7 +141,7 @@ class RAGFlowOSS:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None
@use_prefix_path
@use_default_bucket
@ -170,5 +170,5 @@ class RAGFlowOSS:
logging.exception(f"fail get url {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None

View file

@ -104,6 +104,7 @@ class RedisDB:
if self.REDIS.get(a) == b:
return True
return False
def info(self):
info = self.REDIS.info()
@ -124,7 +125,7 @@ class RedisDB:
def exist(self, k):
if not self.REDIS:
return
return None
try:
return self.REDIS.exists(k)
except Exception as e:
@ -133,7 +134,7 @@ class RedisDB:
def get(self, k):
if not self.REDIS:
return
return None
try:
return self.REDIS.get(k)
except Exception as e:

View file

@ -164,7 +164,7 @@ class RAGFlowS3:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None
@use_prefix_path
@use_default_bucket
@ -193,7 +193,7 @@ class RAGFlowS3:
logging.exception(f"fail get url {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
return None
@use_default_bucket
def rm_bucket(self, bucket, *args, **kwargs):

View file

@ -16,14 +16,15 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from common import create_dataset
from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
from configs import DATASET_NAME_LIMIT, DEFAULT_PARSER_CONFIG, INVALID_API_TOKEN
from hypothesis import example, given, settings
from libs.auth import RAGFlowHttpApiAuth
from utils import encode_avatar
from utils.file_utils import create_image_file
from utils.hypothesis_utils import valid_names
from configs import DEFAULT_PARSER_CONFIG
from common import create_dataset
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@ -125,8 +126,8 @@ class TestDatasetCreate:
assert res["code"] == 0, res
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 103, res
assert res["message"] == f"Dataset name '{name}' already exists", res
assert res["code"] == 0, res
assert res["data"]["name"] == name + "(1)", res
@pytest.mark.p3
def test_name_case_insensitive(self, HttpApiAuth):
@ -137,8 +138,8 @@ class TestDatasetCreate:
payload = {"name": name.lower()}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 103, res
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
assert res["code"] == 0, res
assert res["data"]["name"] == name.lower() + "(1)", res
@pytest.mark.p2
def test_avatar(self, HttpApiAuth, tmp_path):

View file

@ -17,13 +17,13 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from operator import attrgetter
import pytest
from configs import DATASET_NAME_LIMIT, HOST_ADDRESS, INVALID_API_TOKEN
from configs import DATASET_NAME_LIMIT, DEFAULT_PARSER_CONFIG, HOST_ADDRESS, INVALID_API_TOKEN
from hypothesis import example, given, settings
from ragflow_sdk import DataSet, RAGFlow
from utils import encode_avatar
from utils.file_utils import create_image_file
from utils.hypothesis_utils import valid_names
from configs import DEFAULT_PARSER_CONFIG
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@ -95,9 +95,8 @@ class TestDatasetCreate:
payload = {"name": name}
client.create_dataset(**payload)
with pytest.raises(Exception) as excinfo:
client.create_dataset(**payload)
assert str(excinfo.value) == f"Dataset name '{name}' already exists", str(excinfo.value)
dataset = client.create_dataset(**payload)
assert dataset.name == name + "(1)", str(dataset)
@pytest.mark.p3
def test_name_case_insensitive(self, client):
@ -106,9 +105,8 @@ class TestDatasetCreate:
client.create_dataset(**payload)
payload = {"name": name.lower()}
with pytest.raises(Exception) as excinfo:
client.create_dataset(**payload)
assert str(excinfo.value) == f"Dataset name '{name.lower()}' already exists", str(excinfo.value)
dataset = client.create_dataset(**payload)
assert dataset.name == name.lower() + "(1)", str(dataset)
@pytest.mark.p2
def test_avatar(self, client, tmp_path):