From 0ebbb60102dd6b7ea571b0465b8ad3bf881d40cc Mon Sep 17 00:00:00 2001 From: writinwaters <93570324+writinwaters@users.noreply.github.com> Date: Mon, 1 Dec 2025 11:24:29 +0800 Subject: [PATCH 01/13] Docs: deploying a local model using Jina not supported (#11624) ### What problem does this PR solve? ### Type of change - [x] Documentation Update --- docs/guides/models/deploy_local_llm.mdx | 32 ------------------------- 1 file changed, 32 deletions(-) diff --git a/docs/guides/models/deploy_local_llm.mdx b/docs/guides/models/deploy_local_llm.mdx index 8eadfad94..dfee3fc78 100644 --- a/docs/guides/models/deploy_local_llm.mdx +++ b/docs/guides/models/deploy_local_llm.mdx @@ -314,35 +314,3 @@ To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the con 3. [Update System Model Settings](#6-update-system-model-settings) 4. [Update Chat Configuration](#7-update-chat-configuration) -## Deploy a local model using jina - -To deploy a local model, e.g., **gpt2**, using jina: - -### 1. Check firewall settings - -Ensure that your host machine's firewall allows inbound connections on port 12345. - -```bash -sudo ufw allow 12345/tcp -``` - -### 2. Install jina package - -```bash -pip install jina -``` - -### 3. Deploy a local model - -Step 1: Navigate to the **rag/svr** directory. - -```bash -cd rag/svr -``` - -Step 2: Run **jina_server.py**, specifying either the model's name or its local directory: - -```bash -python jina_server.py --model_name gpt2 -``` -> The script only supports models downloaded from Hugging Face. From 7499608a8bb09606cd9b4a10f454eaf590301810 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 1 Dec 2025 11:26:20 +0800 Subject: [PATCH 02/13] feat: add Redis username support (#11608) ### What problem does this PR solve? Support for Redis 6+ ACL authentication (username) close #11606 ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Documentation Update --- conf/service_conf.yaml | 1 + docker/service_conf.yaml.template | 1 + docs/configurations.md | 9 +++++++++ rag/utils/redis_conn.py | 3 +++ 4 files changed, 14 insertions(+) diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml index 6b3cef80e..07a7b32a9 100644 --- a/conf/service_conf.yaml +++ b/conf/service_conf.yaml @@ -38,6 +38,7 @@ oceanbase: port: 2881 redis: db: 1 + username: '' password: 'infini_rag_flow' host: 'localhost:6379' task_executor: diff --git a/docker/service_conf.yaml.template b/docker/service_conf.yaml.template index fa85453ab..72e7a6d73 100644 --- a/docker/service_conf.yaml.template +++ b/docker/service_conf.yaml.template @@ -38,6 +38,7 @@ oceanbase: port: ${OCEANBASE_PORT:-2881} redis: db: 1 + username: '${REDIS_USERNAME:-}' password: '${REDIS_PASSWORD:-infini_rag_flow}' host: '${REDIS_HOST:-redis}:6379' user_default_llm: diff --git a/docs/configurations.md b/docs/configurations.md index 7574c6d12..f2602767c 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -89,6 +89,8 @@ RAGFlow utilizes MinIO as its object storage solution, leveraging its scalabilit - `REDIS_PORT` The port used to expose the Redis service to the host machine, allowing **external** access to the Redis service running inside the Docker container. Defaults to `6379`. +- `REDIS_USERNAME` + Optional Redis ACL username when using Redis 6+ authentication. - `REDIS_PASSWORD` The password for Redis. @@ -160,6 +162,13 @@ If you cannot download the RAGFlow Docker image, try the following mirrors. - `password`: The password for MinIO. - `host`: The MinIO serving IP *and* port inside the Docker container. Defaults to `minio:9000`. +### `redis` + +- `host`: The Redis serving IP *and* port inside the Docker container. Defaults to `redis:6379`. +- `db`: The Redis database index to use. Defaults to `1`. +- `username`: Optional Redis ACL username (Redis 6+). +- `password`: The password for the specified Redis user. + ### `oauth` The OAuth configuration for signing up or signing in to RAGFlow using a third-party account. diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index a8bc43b57..b7cc15c63 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -86,6 +86,9 @@ class RedisDB: "db": int(self.config.get("db", 1)), "decode_responses": True, } + username = self.config.get("username") + if username: + conn_params["username"] = username password = self.config.get("password") if password: conn_params["password"] = password From 9a8ce9d3e2d0b7ca94176cd341694425222691db Mon Sep 17 00:00:00 2001 From: dzikus Date: Mon, 1 Dec 2025 04:26:34 +0100 Subject: [PATCH 03/13] fix: increase Quart RESPONSE_TIMEOUT and BODY_TIMEOUT for slow LLM responses (#11612) ### What problem does this PR solve? Quart framework has default RESPONSE_TIMEOUT and BODY_TIMEOUT of 60 seconds. This causes the frontend chat to hang exactly after 60 seconds when using slow LLM backends (e.g., Ollama on CPU, or remote APIs with high latency). This fix adds configurable timeout settings via environment variables with sensible defaults (600 seconds = 10 minutes) to match other timeout configurations in RAGFlow. Fixes issues with chat timeout when: - Using local Ollama on CPU (response time ~2 minutes) - Using remote LLM APIs with high latency - Processing complex RAG queries with many chunks ### Type of change - [X] Bug Fix (non-breaking change which fixes an issue) Co-authored-by: Grzegorz Sterniczuk --- api/apps/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/api/apps/__init__.py b/api/apps/__init__.py index a6e33c13b..e034f460b 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -82,6 +82,11 @@ app.url_map.strict_slashes = False app.json_encoder = CustomJSONEncoder app.errorhandler(Exception)(server_error_response) +# Configure Quart timeouts for slow LLM responses (e.g., local Ollama on CPU) +# Default Quart timeouts are 60 seconds which is too short for many LLM backends +app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600)) +app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600)) + ## convince for dev and debug # app.config["LOGIN_DISABLED"] = True app.config["SESSION_PERMANENT"] = False From 9d0309aedce5bbc508cf2e1542da844660f81ec8 Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Mon, 1 Dec 2025 12:17:43 +0800 Subject: [PATCH 04/13] Fix: [MinerU] Missing output file (#11623) ### What problem does this PR solve? Add fallbacks for MinerU output path. #11613, #11620. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- deepdoc/parser/mineru_parser.py | 49 ++++++++++++++++++++++++--------- rag/llm/chat_model.py | 22 ++++++++------- 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/deepdoc/parser/mineru_parser.py b/deepdoc/parser/mineru_parser.py index d4834de39..9670bdcf9 100644 --- a/deepdoc/parser/mineru_parser.py +++ b/deepdoc/parser/mineru_parser.py @@ -190,7 +190,7 @@ class MinerUParser(RAGFlowPdfParser): self._run_mineru_executable(input_path, output_dir, method, backend, lang, server_url, callback) def _run_mineru_api(self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, callback: Optional[Callable] = None): - OUTPUT_ZIP_PATH = os.path.join(str(output_dir), "output.zip") + output_zip_path = os.path.join(str(output_dir), "output.zip") pdf_file_path = str(input_path) @@ -230,16 +230,16 @@ class MinerUParser(RAGFlowPdfParser): response.raise_for_status() if response.headers.get("Content-Type") == "application/zip": - self.logger.info(f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...") + self.logger.info(f"[MinerU] zip file returned, saving to {output_zip_path}...") if callback: - callback(0.30, f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...") + callback(0.30, f"[MinerU] zip file returned, saving to {output_zip_path}...") - with open(OUTPUT_ZIP_PATH, "wb") as f: + with open(output_zip_path, "wb") as f: f.write(response.content) self.logger.info(f"[MinerU] Unzip to {output_path}...") - self._extract_zip_no_root(OUTPUT_ZIP_PATH, output_path, pdf_file_name + "/") + self._extract_zip_no_root(output_zip_path, output_path, pdf_file_name + "/") if callback: callback(0.40, f"[MinerU] Unzip to {output_path}...") @@ -459,13 +459,36 @@ class MinerUParser(RAGFlowPdfParser): return poss def _read_output(self, output_dir: Path, file_stem: str, method: str = "auto", backend: str = "pipeline") -> list[dict[str, Any]]: - subdir = output_dir / file_stem / method - if backend.startswith("vlm-"): - subdir = output_dir / file_stem / "vlm" - json_file = subdir / f"{file_stem}_content_list.json" + candidates = [] + seen = set() - if not json_file.exists(): - raise FileNotFoundError(f"[MinerU] Missing output file: {json_file}") + def add_candidate_path(p: Path): + if p not in seen: + seen.add(p) + candidates.append(p) + + if backend.startswith("vlm-"): + add_candidate_path(output_dir / file_stem / "vlm") + if method: + add_candidate_path(output_dir / file_stem / method) + add_candidate_path(output_dir / file_stem / "auto") + else: + if method: + add_candidate_path(output_dir / file_stem / method) + add_candidate_path(output_dir / file_stem / "vlm") + add_candidate_path(output_dir / file_stem / "auto") + + json_file = None + subdir = None + for sub in candidates: + jf = sub / f"{file_stem}_content_list.json" + if jf.exists(): + subdir = sub + json_file = jf + break + + if not json_file: + raise FileNotFoundError(f"[MinerU] Missing output file, tried: {', '.join(str(c / (file_stem + '_content_list.json')) for c in candidates)}") with open(json_file, "r", encoding="utf-8") as f: data = json.load(f) @@ -520,7 +543,7 @@ class MinerUParser(RAGFlowPdfParser): method: str = "auto", server_url: Optional[str] = None, delete_output: bool = True, - parse_method: str = "raw" + parse_method: str = "raw", ) -> tuple: import shutil @@ -570,7 +593,7 @@ class MinerUParser(RAGFlowPdfParser): self.logger.info(f"[MinerU] Parsed {len(outputs)} blocks from PDF.") if callback: callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.") - + return self._transfer_to_sections(outputs, parse_method), self._transfer_to_tables(outputs) finally: if temp_pdf and temp_pdf.exists(): diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 9fbc88348..726aecd8b 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -33,9 +33,9 @@ from openai.lib.azure import AzureOpenAI from strenum import StrEnum from zhipuai import ZhipuAI +from common.token_utils import num_tokens_from_string, total_token_count_from_response from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider from rag.nlp import is_chinese, is_english -from common.token_utils import num_tokens_from_string, total_token_count_from_response # Error message constants @@ -66,7 +66,7 @@ LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to class Base(ABC): def __init__(self, key, model_name, base_url, **kwargs): - timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) + timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600)) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) self.model_name = model_name # Configure retry parameters @@ -127,7 +127,7 @@ class Base(ABC): "tool_choice", "logprobs", "top_logprobs", - "extra_headers" + "extra_headers", } gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf} @@ -1213,7 +1213,7 @@ class GoogleChat(Base): # Build GenerateContentConfig try: - from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part + from google.genai.types import Content, GenerateContentConfig, Part, ThinkingConfig except ImportError as e: logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0") raise @@ -1242,14 +1242,14 @@ class GoogleChat(Base): role = "model" if item["role"] == "assistant" else item["role"] content = Content( role=role, - parts=[Part(text=item["content"])] + parts=[Part(text=item["content"])], ) contents.append(content) response = self.client.models.generate_content( model=self.model_name, contents=contents, - config=config + config=config, ) ans = response.text @@ -1299,7 +1299,7 @@ class GoogleChat(Base): # Build GenerateContentConfig try: - from google.genai.types import GenerateContentConfig, ThinkingConfig, Content, Part + from google.genai.types import Content, GenerateContentConfig, Part, ThinkingConfig except ImportError as e: logging.error(f"[GoogleChat] Failed to import google-genai: {e}. Please install: pip install google-genai>=1.41.0") raise @@ -1326,7 +1326,7 @@ class GoogleChat(Base): role = "model" if item["role"] == "assistant" else item["role"] content = Content( role=role, - parts=[Part(text=item["content"])] + parts=[Part(text=item["content"])], ) contents.append(content) @@ -1334,7 +1334,7 @@ class GoogleChat(Base): for chunk in self.client.models.generate_content_stream( model=self.model_name, contents=contents, - config=config + config=config, ): text = chunk.text ans = text @@ -1406,7 +1406,7 @@ class LiteLLMBase(ABC): ] def __init__(self, key, model_name, base_url=None, **kwargs): - self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) + self.timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600)) self.provider = kwargs.get("provider", "") self.prefix = LITELLM_PROVIDER_PREFIX.get(self.provider, "") self.model_name = f"{self.prefix}{model_name}" @@ -1625,6 +1625,7 @@ class LiteLLMBase(ABC): if self.provider == SupportedLiteLLMProvider.OpenRouter: if self.provider_order: + def _to_order_list(x): if x is None: return [] @@ -1633,6 +1634,7 @@ class LiteLLMBase(ABC): if isinstance(x, (list, tuple)): return [str(s).strip() for s in x if str(s).strip()] return [] + extra_body = {} provider_cfg = {} provider_order = _to_order_list(self.provider_order) From 88a28212b385317a8b9c00c600f2cbdd4af79278 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Mon, 1 Dec 2025 12:42:35 +0800 Subject: [PATCH 05/13] Fix: Table parse method issue. (#11627) ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/document_app.py | 1 + api/db/services/dialog_service.py | 27 ++++++++++++++------------- rag/utils/es_conn.py | 6 +++--- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/api/apps/document_app.py b/api/apps/document_app.py index bd2262919..4755453d4 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -706,6 +706,7 @@ async def set_meta(): except Exception as e: return server_error_response(e) + @manager.route("/upload_info", methods=["POST"]) # noqa: F821 async def upload_info(): files = await request.files diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index ae79b45a6..d5d0e1664 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -676,7 +676,11 @@ Please write the SQL, only SQL, without any other explanations or text. if kb_ids: kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" if "where" not in sql.lower(): - sql += f" WHERE {kb_filter}" + o = sql.lower().split("order by") + if len(o) > 1: + sql = o[0] + f" WHERE {kb_filter} order by " + o[1] + else: + sql += f" WHERE {kb_filter}" else: sql += f" AND {kb_filter}" @@ -684,10 +688,9 @@ Please write the SQL, only SQL, without any other explanations or text. tried_times += 1 return settings.retriever.sql_retrieval(sql, format="json"), sql - tbl, sql = get_table() - if tbl is None: - return None - if tbl.get("error") and tried_times <= 2: + try: + tbl, sql = get_table() + except Exception as e: user_prompt = """ Table name: {}; Table of database fields are as follows: @@ -701,16 +704,14 @@ Please write the SQL, only SQL, without any other explanations or text. The SQL error you provided last time is as follows: {} - Error issued by database as follows: - {} - Please correct the error and write SQL again, only SQL, without any other explanations or text. - """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"]) - tbl, sql = get_table() - logging.debug("TRY it again: {}".format(sql)) + """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e) + try: + tbl, sql = get_table() + except Exception: + return - logging.debug("GET table: {}".format(tbl)) - if tbl.get("error") or len(tbl["rows"]) == 0: + if len(tbl["rows"]) == 0: return None docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"]) diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index 5971950cf..cca3fc7c7 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -575,9 +575,9 @@ class ESConnection(DocStoreConnection): time.sleep(3) self._connect() continue - except Exception: - logger.exception("ESConnection.sql got exception") - break + except Exception as e: + logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}") + raise Exception(f"SQL error: {e}\n\nSQL: {sql}") logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!") return None From 6ea4248bdc6582486bca6be7072186a1fc29b2b9 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Mon, 1 Dec 2025 14:03:09 +0800 Subject: [PATCH 06/13] Feat: support parent-child in search procedure. (#11629) ### What problem does this PR solve? #7996 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/db/services/dialog_service.py | 1 + rag/nlp/search.py | 49 ++++++++++++++++++++++++++++++- rag/svr/task_executor.py | 2 +- 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index d5d0e1664..24e46ad83 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -482,6 +482,7 @@ def chat(dialog, messages, stream=True, **kwargs): cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n) if cks: kbinfos["chunks"] = cks + kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids) if prompt_config.get("tavily_api_key"): tav = Tavily(prompt_config["tavily_api_key"]) tav_res = tav.retrieve_chunks(" ".join(questions)) diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 084f7b48f..4f64c1f8f 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -17,7 +17,7 @@ import json import logging import re import math -from collections import OrderedDict +from collections import OrderedDict, defaultdict from dataclasses import dataclass from rag.prompts.generator import relevant_chunks_with_toc @@ -640,3 +640,50 @@ class Dealer: chunks.append(d) return sorted(chunks, key=lambda x:x["similarity"]*-1)[:topn] + + def retrieval_by_children(self, chunks:list[dict], tenant_ids:list[str]): + if not chunks: + return [] + idx_nms = [index_name(tid) for tid in tenant_ids] + mom_chunks = defaultdict([]) + i = 0 + while i < len(chunks): + ck = chunks[i] + if not ck.get("mom_id"): + i += 1 + continue + mom_chunks[ck["mom_id"]].append(chunks.pop(i)) + + if not mom_chunks: + return chunks + + if not chunks: + chunks = [] + + vector_size = 1024 + for id, cks in mom_chunks.items(): + chunk = self.dataStore.get(id, idx_nms, [ck["kb_id"] for ck in cks]) + d = { + "chunk_id": id, + "content_ltks": " ".join([ck["content_ltks"] for ck in cks]), + "content_with_weight": chunk["content_with_weight"], + "doc_id": chunk["doc_id"], + "docnm_kwd": chunk.get("docnm_kwd", ""), + "kb_id": chunk["kb_id"], + "important_kwd": [kwd for ck in cks for kwd in ck.get("important_kwd", [])], + "image_id": chunk.get("img_id", ""), + "similarity": np.mean([ck["similarity"] for ck in cks]), + "vector_similarity": np.mean([ck["similarity"] for ck in cks]), + "term_similarity": np.mean([ck["similarity"] for ck in cks]), + "vector": [0.0] * vector_size, + "positions": chunk.get("position_int", []), + "doc_type_kwd": chunk.get("doc_type_kwd", "") + } + for k in cks[0].keys(): + if k[-4:] == "_vec": + d["vector"] = cks[0][k] + vector_size = len(cks[0][k]) + break + chunks.append(d) + + return sorted(chunks, key=lambda x:x["similarity"]*-1) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index d7cbced0c..714b886eb 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -734,7 +734,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c mom_ck["available_int"] = 0 flds = list(mom_ck.keys()) for fld in flds: - if fld not in ["id", "content_with_weight", "doc_id", "kb_id", "available_int"]: + if fld not in ["id", "content_with_weight", "doc_id", "kb_id", "available_int", "position_int"]: del mom_ck[fld] mothers.append(mom_ck) From b6c472268790b1648f56aae35af8a79510925300 Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Mon, 1 Dec 2025 14:24:06 +0800 Subject: [PATCH 07/13] Refa: make RAGFlow more asynchronous (#11601) ### What problem does this PR solve? Try to make this more asynchronous. Verified in chat and agent scenarios, reducing blocking behavior. #11551, #11579. However, the impact of these changes still requires further investigation to ensure everything works as expected. ### Type of change - [x] Refactoring --- agent/canvas.py | 113 +++++--- agent/component/llm.py | 51 +++- agent/component/message.py | 50 +++- api/apps/__init__.py | 4 +- api/apps/api_app.py | 8 +- api/apps/auth/github.py | 49 +++- api/apps/auth/oauth.py | 53 +++- api/apps/auth/oidc.py | 13 +- api/apps/canvas_app.py | 25 +- api/apps/chunk_app.py | 14 +- api/apps/connector_app.py | 10 +- api/apps/conversation_app.py | 24 +- api/apps/dialog_app.py | 9 +- api/apps/document_app.py | 25 +- api/apps/file2document_app.py | 8 +- api/apps/file_app.py | 10 +- api/apps/kb_app.py | 30 +- api/apps/langfuse_app.py | 5 +- api/apps/llm_app.py | 13 +- api/apps/mcp_server_app.py | 23 +- api/apps/sdk/agents.py | 8 +- api/apps/sdk/chat.py | 8 +- api/apps/sdk/dify_retrieval.py | 6 +- api/apps/sdk/doc.py | 32 +-- api/apps/sdk/files.py | 25 +- api/apps/sdk/session.py | 46 ++-- api/apps/search_app.py | 10 +- api/apps/tenant_app.py | 5 +- api/apps/user_app.py | 81 +++--- api/db/services/file_service.py | 2 +- api/db/services/llm_service.py | 83 +++++- api/ragflow_server.py | 5 +- api/utils/api_utils.py | 42 ++- common/http_client.py | 157 +++++++++++ rag/llm/__init__.py | 3 + rag/llm/chat_model.py | 471 +++++++++++++++++++++++++++----- 36 files changed, 1162 insertions(+), 359 deletions(-) create mode 100644 common/http_client.py diff --git a/agent/canvas.py b/agent/canvas.py index 5344d70c3..c447b77b3 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio +import base64 +import inspect import json import logging import re @@ -79,6 +82,7 @@ class Graph: self.dsl = json.loads(dsl) self._tenant_id = tenant_id self.task_id = task_id if task_id else get_uuid() + self._thread_pool = ThreadPoolExecutor(max_workers=5) self.load() def load(self): @@ -357,6 +361,7 @@ class Canvas(Graph): async def run(self, **kwargs): st = time.perf_counter() + self._loop = asyncio.get_running_loop() self.message_id = get_uuid() created_at = int(time.time()) self.add_user_input(kwargs.get("query")) @@ -372,7 +377,7 @@ class Canvas(Graph): for k in kwargs.keys(): if k in ["query", "user_id", "files"] and kwargs[k]: if k == "files": - self.globals[f"sys.{k}"] = FileService.get_files(kwargs[k]) + self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k]) else: self.globals[f"sys.{k}"] = kwargs[k] if not self.globals["sys.conversation_turns"] : @@ -402,31 +407,39 @@ class Canvas(Graph): yield decorate("workflow_started", {"inputs": kwargs.get("inputs")}) self.retrieval.append({"chunks": {}, "doc_aggs": {}}) - def _run_batch(f, t): + async def _run_batch(f, t): if self.is_canceled(): msg = f"Task {self.task_id} has been canceled during batch execution." logging.info(msg) raise TaskCanceledException(msg) - with ThreadPoolExecutor(max_workers=5) as executor: - thr = [] - i = f - while i < t: - cpn = self.get_component_obj(self.path[i]) - if cpn.component_name.lower() in ["begin", "userfillup"]: - thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {}))) - i += 1 + loop = asyncio.get_running_loop() + tasks = [] + i = f + while i < t: + cpn = self.get_component_obj(self.path[i]) + task_fn = None + + if cpn.component_name.lower() in ["begin", "userfillup"]: + task_fn = partial(cpn.invoke, inputs=kwargs.get("inputs", {})) + i += 1 + else: + for _, ele in cpn.get_input_elements().items(): + if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0: + self.path.pop(i) + t -= 1 + break else: - for _, ele in cpn.get_input_elements().items(): - if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0: - self.path.pop(i) - t -= 1 - break - else: - thr.append(executor.submit(cpn.invoke, **cpn.get_input())) - i += 1 - for t in thr: - t.result() + task_fn = partial(cpn.invoke, **cpn.get_input()) + i += 1 + + if task_fn is None: + continue + + tasks.append(loop.run_in_executor(self._thread_pool, task_fn)) + + if tasks: + await asyncio.gather(*tasks) def _node_finished(cpn_obj): return decorate("node_finished",{ @@ -453,7 +466,7 @@ class Canvas(Graph): "component_type": self.get_component_type(self.path[i]), "thoughts": self.get_component_thoughts(self.path[i]) }) - _run_batch(idx, to) + await _run_batch(idx, to) to = len(self.path) # post processing of components invocation for i in range(idx, to): @@ -462,16 +475,29 @@ class Canvas(Graph): if cpn_obj.component_name.lower() == "message": if isinstance(cpn_obj.output("content"), partial): _m = "" - for m in cpn_obj.output("content")(): - if not m: - continue - if m == "": - yield decorate("message", {"content": "", "start_to_think": True}) - elif m == "": - yield decorate("message", {"content": "", "end_to_think": True}) - else: - yield decorate("message", {"content": m}) - _m += m + stream = cpn_obj.output("content")() + if inspect.isasyncgen(stream): + async for m in stream: + if not m: + continue + if m == "": + yield decorate("message", {"content": "", "start_to_think": True}) + elif m == "": + yield decorate("message", {"content": "", "end_to_think": True}) + else: + yield decorate("message", {"content": m}) + _m += m + else: + for m in stream: + if not m: + continue + if m == "": + yield decorate("message", {"content": "", "start_to_think": True}) + elif m == "": + yield decorate("message", {"content": "", "end_to_think": True}) + else: + yield decorate("message", {"content": m}) + _m += m cpn_obj.set_output("content", _m) cite = re.search(r"\[ID:[ 0-9]+\]", _m) else: @@ -621,6 +647,31 @@ class Canvas(Graph): def get_component_input_elements(self, cpnnm): return self.components[cpnnm]["obj"].get_input_elements() + async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]: + if not files: + return [] + def image_to_base64(file): + return "data:{};base64,{}".format(file["mime_type"], + base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) + loop = asyncio.get_running_loop() + tasks = [] + for file in files: + if file["mime_type"].find("image") >=0: + tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file)) + continue + tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) + return await asyncio.gather(*tasks) + + def get_files(self, files: Union[None, list[dict]]) -> list[str]: + """ + Synchronous wrapper for get_files_async, used by sync component invoke paths. + """ + loop = getattr(self, "_loop", None) + if loop and loop.is_running(): + return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result() + + return asyncio.run(self.get_files_async(files)) + def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None): agent_ids = agent_id.split("-->") agent_name = self.get_component_name(agent_ids[0]) diff --git a/agent/component/llm.py b/agent/component/llm.py index 0f5317676..a29a36860 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -205,6 +205,55 @@ class LLM(ComponentBase): for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs): yield delta(txt) + async def _stream_output_async(self, prompt, msg): + _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) + answer = "" + last_idx = 0 + endswith_think = False + + def delta(txt): + nonlocal answer, last_idx, endswith_think + delta_ans = txt[last_idx:] + answer = txt + + if delta_ans.find("") == 0: + last_idx += len("") + return "" + elif delta_ans.find("") > 0: + delta_ans = txt[last_idx:last_idx + delta_ans.find("")] + last_idx += delta_ans.find("") + return delta_ans + elif delta_ans.endswith(""): + endswith_think = True + elif endswith_think: + endswith_think = False + return "" + + last_idx = len(answer) + if answer.endswith(""): + last_idx -= len("") + return re.sub(r"(|)", "", delta_ans) + + stream_kwargs = {"images": self.imgs} if self.imgs else {} + async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs): + if self.check_if_canceled("LLM streaming"): + return + + if isinstance(ans, int): + continue + + if ans.find("**ERROR**") >= 0: + if self.get_exception_default_value(): + self.set_output("content", self.get_exception_default_value()) + yield self.get_exception_default_value() + else: + self.set_output("_ERROR", ans) + return + + yield delta(ans) + + self.set_output("content", answer) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) def _invoke(self, **kwargs): if self.check_if_canceled("LLM processing"): @@ -250,7 +299,7 @@ class LLM(ComponentBase): downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else [] ex = self.exception_handler() if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]): - self.set_output("content", partial(self._stream_output, prompt, msg)) + self.set_output("content", partial(self._stream_output_async, prompt, msg)) return for _ in range(self._param.max_retries+1): diff --git a/agent/component/message.py b/agent/component/message.py index ac1d2beae..28349a7c3 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio +import inspect import json import os import random @@ -66,8 +68,12 @@ class Message(ComponentBase): v = "" ans = "" if isinstance(v, partial): - for t in v(): - ans += t + iter_obj = v() + if inspect.isasyncgen(iter_obj): + ans = asyncio.run(self._consume_async_gen(iter_obj)) + else: + for t in iter_obj: + ans += t elif isinstance(v, list) and delimiter: ans = delimiter.join([str(vv) for vv in v]) elif not isinstance(v, str): @@ -89,7 +95,13 @@ class Message(ComponentBase): _kwargs[_n] = v return script, _kwargs - def _stream(self, rand_cnt:str): + async def _consume_async_gen(self, agen): + buf = "" + async for t in agen: + buf += t + return buf + + async def _stream(self, rand_cnt:str): s = 0 all_content = "" cache = {} @@ -111,15 +123,27 @@ class Message(ComponentBase): v = "" if isinstance(v, partial): cnt = "" - for t in v(): - if self.check_if_canceled("Message streaming"): - return + iter_obj = v() + if inspect.isasyncgen(iter_obj): + async for t in iter_obj: + if self.check_if_canceled("Message streaming"): + return - all_content += t - cnt += t - yield t + all_content += t + cnt += t + yield t + else: + for t in iter_obj: + if self.check_if_canceled("Message streaming"): + return + + all_content += t + cnt += t + yield t self.set_input_value(exp, cnt) continue + elif inspect.isawaitable(v): + v = await v elif not isinstance(v, str): try: v = json.dumps(v, ensure_ascii=False) @@ -181,7 +205,7 @@ class Message(ComponentBase): import pypandoc doc_id = get_uuid() - + if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}: self._param.output_format = "markdown" @@ -231,11 +255,11 @@ class Message(ComponentBase): settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content) self.set_output("attachment", { - "doc_id":doc_id, - "format":self._param.output_format, + "doc_id":doc_id, + "format":self._param.output_format, "file_name":f"{doc_id[:8]}.{self._param.output_format}"}) logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})") except Exception as e: - logging.error(f"Error converting content to {self._param.output_format}: {e}") \ No newline at end of file + logging.error(f"Error converting content to {self._param.output_format}: {e}") diff --git a/api/apps/__init__.py b/api/apps/__init__.py index e034f460b..4d9c7c501 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging import os import sys -import logging from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from quart import Blueprint, Quart, request, g, current_app, session -from werkzeug.wrappers.request import Request from flasgger import Swagger from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from quart_cors import cors @@ -40,7 +39,6 @@ settings.init_settings() __all__ = ["app"] -Request.json = property(lambda self: self.get_json(force=True, silent=True)) app = Quart(__name__) app = cors(app, allow_origin="*") diff --git a/api/apps/api_app.py b/api/apps/api_app.py index aa9c9fd6b..97d7dc943 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -18,8 +18,7 @@ from quart import request from api.db.db_models import APIToken from api.db.services.api_service import APITokenService, API4ConversationService from api.db.services.user_service import UserTenantService -from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \ - generate_confirmation_token +from api.utils.api_utils import generate_confirmation_token, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from common.time_utils import current_timestamp, datetime_format from api.apps import login_required, current_user @@ -27,7 +26,7 @@ from api.apps import login_required, current_user @manager.route('/new_token', methods=['POST']) # noqa: F821 @login_required async def new_token(): - req = await request.json + req = await get_request_json() try: tenants = UserTenantService.query(user_id=current_user.id) if not tenants: @@ -73,7 +72,7 @@ def token_list(): @validate_request("tokens", "tenant_id") @login_required async def rm(): - req = await request.json + req = await get_request_json() try: for token in req["tokens"]: APITokenService.filter_delete( @@ -116,4 +115,3 @@ def stats(): return get_json_result(data=res) except Exception as e: return server_error_response(e) - diff --git a/api/apps/auth/github.py b/api/apps/auth/github.py index f48d4a5fc..918ff60db 100644 --- a/api/apps/auth/github.py +++ b/api/apps/auth/github.py @@ -14,7 +14,7 @@ # limitations under the License. # -import requests +from common.http_client import async_request, sync_request from .oauth import OAuthClient, UserInfo @@ -34,24 +34,49 @@ class GithubOAuthClient(OAuthClient): def fetch_user_info(self, access_token, **kwargs): """ - Fetch GitHub user info. + Fetch GitHub user info (synchronous). """ user_info = {} try: headers = {"Authorization": f"Bearer {access_token}"} - # user info - response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout) + response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout) response.raise_for_status() user_info.update(response.json()) - # email info - response = requests.get(self.userinfo_url+"/emails", headers=headers, timeout=self.http_request_timeout) - response.raise_for_status() - email_info = response.json() - user_info["email"] = next( - (email for email in email_info if email["primary"]), None - )["email"] + email_response = sync_request( + "GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout + ) + email_response.raise_for_status() + email_info = email_response.json() + user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] return self.normalize_user_info(user_info) - except requests.exceptions.RequestException as e: + except Exception as e: + raise ValueError(f"Failed to fetch github user info: {e}") + + async def async_fetch_user_info(self, access_token, **kwargs): + """Async variant of fetch_user_info using httpx.""" + user_info = {} + headers = {"Authorization": f"Bearer {access_token}"} + try: + response = await async_request( + "GET", + self.userinfo_url, + headers=headers, + timeout=self.http_request_timeout, + ) + response.raise_for_status() + user_info.update(response.json()) + + email_response = await async_request( + "GET", + self.userinfo_url + "/emails", + headers=headers, + timeout=self.http_request_timeout, + ) + email_response.raise_for_status() + email_info = email_response.json() + user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] + return self.normalize_user_info(user_info) + except Exception as e: raise ValueError(f"Failed to fetch github user info: {e}") diff --git a/api/apps/auth/oauth.py b/api/apps/auth/oauth.py index 6f7e0e5b5..5b2afcea1 100644 --- a/api/apps/auth/oauth.py +++ b/api/apps/auth/oauth.py @@ -14,8 +14,8 @@ # limitations under the License. # -import requests import urllib.parse +from common.http_client import async_request, sync_request class UserInfo: @@ -74,15 +74,40 @@ class OAuthClient: "redirect_uri": self.redirect_uri, "grant_type": "authorization_code" } - response = requests.post( + response = sync_request( + "POST", self.token_url, data=payload, headers={"Accept": "application/json"}, - timeout=self.http_request_timeout + timeout=self.http_request_timeout, ) response.raise_for_status() return response.json() - except requests.exceptions.RequestException as e: + except Exception as e: + raise ValueError(f"Failed to exchange authorization code for token: {e}") + + async def async_exchange_code_for_token(self, code): + """ + Async variant of exchange_code_for_token using httpx. + """ + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "redirect_uri": self.redirect_uri, + "grant_type": "authorization_code", + } + try: + response = await async_request( + "POST", + self.token_url, + data=payload, + headers={"Accept": "application/json"}, + timeout=self.http_request_timeout, + ) + response.raise_for_status() + return response.json() + except Exception as e: raise ValueError(f"Failed to exchange authorization code for token: {e}") @@ -92,11 +117,27 @@ class OAuthClient: """ try: headers = {"Authorization": f"Bearer {access_token}"} - response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout) + response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout) response.raise_for_status() user_info = response.json() return self.normalize_user_info(user_info) - except requests.exceptions.RequestException as e: + except Exception as e: + raise ValueError(f"Failed to fetch user info: {e}") + + async def async_fetch_user_info(self, access_token, **kwargs): + """Async variant of fetch_user_info using httpx.""" + headers = {"Authorization": f"Bearer {access_token}"} + try: + response = await async_request( + "GET", + self.userinfo_url, + headers=headers, + timeout=self.http_request_timeout, + ) + response.raise_for_status() + user_info = response.json() + return self.normalize_user_info(user_info) + except Exception as e: raise ValueError(f"Failed to fetch user info: {e}") diff --git a/api/apps/auth/oidc.py b/api/apps/auth/oidc.py index cafcaadfd..80ac79399 100644 --- a/api/apps/auth/oidc.py +++ b/api/apps/auth/oidc.py @@ -15,7 +15,7 @@ # import jwt -import requests +from common.http_client import sync_request from .oauth import OAuthClient @@ -50,10 +50,10 @@ class OIDCClient(OAuthClient): """ try: metadata_url = f"{issuer}/.well-known/openid-configuration" - response = requests.get(metadata_url, timeout=7) + response = sync_request("GET", metadata_url, timeout=7) response.raise_for_status() return response.json() - except requests.exceptions.RequestException as e: + except Exception as e: raise ValueError(f"Failed to fetch OIDC metadata: {e}") @@ -95,6 +95,13 @@ class OIDCClient(OAuthClient): user_info.update(super().fetch_user_info(access_token).to_dict()) return self.normalize_user_info(user_info) + async def async_fetch_user_info(self, access_token, id_token=None, **kwargs): + user_info = {} + if id_token: + user_info = self.parse_id_token(id_token) + user_info.update((await super().async_fetch_user_info(access_token)).to_dict()) + return self.normalize_user_info(user_info) + def normalize_user_info(self, user_info): return super().normalize_user_info(user_info) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index afdb3269b..fe32dca0b 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging from functools import partial @@ -29,7 +30,7 @@ from api.db.services.user_canvas_version import UserCanvasVersionService from common.constants import RetCode from common.misc_utils import get_uuid from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \ - request_json + get_request_json from agent.canvas import Canvas from peewee import MySQLDatabase, PostgresqlDatabase from api.db.db_models import APIToken, Task @@ -52,7 +53,7 @@ def templates(): @validate_request("canvas_ids") @login_required async def rm(): - req = await request_json() + req = await get_request_json() for i in req["canvas_ids"]: if not UserCanvasService.accessible(i, current_user.id): return get_json_result( @@ -66,7 +67,7 @@ async def rm(): @validate_request("dsl", "title") @login_required async def save(): - req = await request_json() + req = await get_request_json() if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) req["dsl"] = json.loads(req["dsl"]) @@ -125,17 +126,17 @@ def getsse(canvas_id): @validate_request("id") @login_required async def run(): - req = await request_json() + req = await get_request_json() query = req.get("query", "") files = req.get("files", []) inputs = req.get("inputs", {}) user_id = req.get("user_id", current_user.id) - if not UserCanvasService.accessible(req["id"], current_user.id): + if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', code=RetCode.OPERATING_ERROR) - e, cvs = UserCanvasService.get_by_id(req["id"]) + e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"]) if not e: return get_data_error_result(message="canvas not found.") @@ -145,7 +146,7 @@ async def run(): if cvs.canvas_category == CanvasCategory.DataFlow: task_id = get_uuid() Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"]) - ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0) + ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, files[0], 0) if not ok: return get_data_error_result(message=error_message) return get_json_result(data={"message_id": task_id}) @@ -182,7 +183,7 @@ async def run(): @validate_request("id", "dsl", "component_id") @login_required async def rerun(): - req = await request_json() + req = await get_request_json() doc = PipelineOperationLogService.get_documents_info(req["id"]) if not doc: return get_data_error_result(message="Document not found.") @@ -220,7 +221,7 @@ def cancel(task_id): @validate_request("id") @login_required async def reset(): - req = await request_json() + req = await get_request_json() if not UserCanvasService.accessible(req["id"], current_user.id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', @@ -278,7 +279,7 @@ def input_form(): @validate_request("id", "component_id", "params") @login_required async def debug(): - req = await request_json() + req = await get_request_json() if not UserCanvasService.accessible(req["id"], current_user.id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', @@ -310,7 +311,7 @@ async def debug(): @validate_request("db_type", "database", "username", "host", "port", "password") @login_required async def test_db_connect(): - req = await request_json() + req = await get_request_json() try: if req["db_type"] in ["mysql", "mariadb"]: db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], @@ -455,7 +456,7 @@ def list_canvas(): @validate_request("id", "title", "permission") @login_required async def setting(): - req = await request_json() + req = await get_request_json() req["user_id"] = current_user.id if not UserCanvasService.accessible(req["id"], current_user.id): diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index b43fb9af1..d5d928342 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -27,7 +27,7 @@ from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ - request_json + get_request_json from rag.app.qa import beAdoc, rmPrefix from rag.app.tag import label_question from rag.nlp import rag_tokenizer, search @@ -42,7 +42,7 @@ from api.apps import login_required, current_user @login_required @validate_request("doc_id") async def list_chunk(): - req = await request_json() + req = await get_request_json() doc_id = req["doc_id"] page = int(req.get("page", 1)) size = int(req.get("size", 30)) @@ -123,7 +123,7 @@ def get(): @login_required @validate_request("doc_id", "chunk_id", "content_with_weight") async def set(): - req = await request_json() + req = await get_request_json() d = { "id": req["chunk_id"], "content_with_weight": req["content_with_weight"]} @@ -180,7 +180,7 @@ async def set(): @login_required @validate_request("chunk_ids", "available_int", "doc_id") async def switch(): - req = await request_json() + req = await get_request_json() try: e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: @@ -200,7 +200,7 @@ async def switch(): @login_required @validate_request("chunk_ids", "doc_id") async def rm(): - req = await request_json() + req = await get_request_json() try: e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: @@ -224,7 +224,7 @@ async def rm(): @login_required @validate_request("doc_id", "content_with_weight") async def create(): - req = await request_json() + req = await get_request_json() chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest() d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), "content_with_weight": req["content_with_weight"]} @@ -282,7 +282,7 @@ async def create(): @login_required @validate_request("kb_id", "question") async def retrieval_test(): - req = await request_json() + req = await get_request_json() page = int(req.get("page", 1)) size = int(req.get("size", 30)) question = req["question"] diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py index 34da2293b..49d8005a6 100644 --- a/api/apps/connector_app.py +++ b/api/apps/connector_app.py @@ -26,7 +26,7 @@ from google_auth_oauthlib.flow import Flow from api.db import InputType from api.db.services.connector_service import ConnectorService, SyncLogsService -from api.utils.api_utils import get_data_error_result, get_json_result, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request from common.constants import RetCode, TaskStatus from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, DocumentSource from common.data_source.google_util.constant import GOOGLE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES @@ -38,7 +38,7 @@ from api.apps import login_required, current_user @manager.route("/set", methods=["POST"]) # noqa: F821 @login_required async def set_connector(): - req = await request.json + req = await get_request_json() if req.get("id"): conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req} ConnectorService.update_by_id(req["id"], conn) @@ -90,7 +90,7 @@ def list_logs(connector_id): @manager.route("//resume", methods=["PUT"]) # noqa: F821 @login_required async def resume(connector_id): - req = await request.json + req = await get_request_json() if req.get("resume"): ConnectorService.resume(connector_id, TaskStatus.SCHEDULE) else: @@ -102,7 +102,7 @@ async def resume(connector_id): @login_required @validate_request("kb_id") async def rebuild(connector_id): - req = await request.json + req = await get_request_json() err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id) if err: return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR) @@ -211,7 +211,7 @@ async def start_google_web_oauth(): message="Google OAuth redirect URI is not configured on the server.", ) - req = await request.json or {} + req = await get_request_json() raw_credentials = req.get("credentials", "") try: diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 77b799016..a2ac131f3 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -26,7 +26,7 @@ from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_service import TenantService, UserTenantService -from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from rag.prompts.template import load_prompt from rag.prompts.generator import chunks_format from common.constants import RetCode, LLMType @@ -35,7 +35,7 @@ from common.constants import RetCode, LLMType @manager.route("/set", methods=["POST"]) # noqa: F821 @login_required async def set_conversation(): - req = await request.json + req = await get_request_json() conv_id = req.get("conversation_id") is_new = req.get("is_new") name = req.get("name", "New conversation") @@ -78,7 +78,7 @@ async def set_conversation(): @manager.route("/get", methods=["GET"]) # noqa: F821 @login_required -def get(): +async def get(): conv_id = request.args["conversation_id"] try: e, conv = ConversationService.get_by_id(conv_id) @@ -129,7 +129,7 @@ def getsse(dialog_id): @manager.route("/rm", methods=["POST"]) # noqa: F821 @login_required async def rm(): - req = await request.json + req = await get_request_json() conv_ids = req["conversation_ids"] try: for cid in conv_ids: @@ -150,7 +150,7 @@ async def rm(): @manager.route("/list", methods=["GET"]) # noqa: F821 @login_required -def list_conversation(): +async def list_conversation(): dialog_id = request.args["dialog_id"] try: if not DialogService.query(tenant_id=current_user.id, id=dialog_id): @@ -167,7 +167,7 @@ def list_conversation(): @login_required @validate_request("conversation_id", "messages") async def completion(): - req = await request.json + req = await get_request_json() msg = [] for m in req["messages"]: if m["role"] == "system": @@ -252,7 +252,7 @@ async def completion(): @manager.route("/tts", methods=["POST"]) # noqa: F821 @login_required async def tts(): - req = await request.json + req = await get_request_json() text = req["text"] tenants = TenantService.get_info_by(current_user.id) @@ -285,7 +285,7 @@ async def tts(): @login_required @validate_request("conversation_id", "message_id") async def delete_msg(): - req = await request.json + req = await get_request_json() e, conv = ConversationService.get_by_id(req["conversation_id"]) if not e: return get_data_error_result(message="Conversation not found!") @@ -308,7 +308,7 @@ async def delete_msg(): @login_required @validate_request("conversation_id", "message_id") async def thumbup(): - req = await request.json + req = await get_request_json() e, conv = ConversationService.get_by_id(req["conversation_id"]) if not e: return get_data_error_result(message="Conversation not found!") @@ -335,7 +335,7 @@ async def thumbup(): @login_required @validate_request("question", "kb_ids") async def ask_about(): - req = await request.json + req = await get_request_json() uid = current_user.id search_id = req.get("search_id", "") @@ -367,7 +367,7 @@ async def ask_about(): @login_required @validate_request("question", "kb_ids") async def mindmap(): - req = await request.json + req = await get_request_json() search_id = req.get("search_id", "") search_app = SearchService.get_detail(search_id) if search_id else {} search_config = search_app.get("search_config", {}) if search_app else {} @@ -385,7 +385,7 @@ async def mindmap(): @login_required @validate_request("question") async def related_questions(): - req = await request.json + req = await get_request_json() search_id = req.get("search_id", "") search_config = {} diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index cbefc7752..0f5aebe0b 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -21,10 +21,9 @@ from common.constants import StatusEnum from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import TenantService, UserTenantService -from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from common.misc_utils import get_uuid from common.constants import RetCode -from api.utils.api_utils import get_json_result from api.apps import login_required, current_user @@ -32,7 +31,7 @@ from api.apps import login_required, current_user @validate_request("prompt_config") @login_required async def set_dialog(): - req = await request.json + req = await get_request_json() dialog_id = req.get("dialog_id", "") is_create = not dialog_id name = req.get("name", "New Dialog") @@ -181,7 +180,7 @@ async def list_dialogs_next(): else: desc = True - req = await request.get_json() + req = await get_request_json() owner_ids = req.get("owner_ids", []) try: if not owner_ids: @@ -209,7 +208,7 @@ async def list_dialogs_next(): @login_required @validate_request("dialog_ids") async def rm(): - req = await request.json + req = await get_request_json() dialog_list=[] tenants = UserTenantService.query(user_id=current_user.id) try: diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 4755453d4..a56f11317 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -36,7 +36,7 @@ from api.utils.api_utils import ( get_data_error_result, get_json_result, server_error_response, - validate_request, request_json, + validate_request, get_request_json, ) from api.utils.file_utils import filename_type, thumbnail from common.file_utils import get_project_base_directory @@ -153,7 +153,7 @@ async def web_crawl(): @login_required @validate_request("name", "kb_id") async def create(): - req = await request_json() + req = await get_request_json() kb_id = req["kb_id"] if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) @@ -230,7 +230,7 @@ async def list_docs(): create_time_from = int(request.args.get("create_time_from", 0)) create_time_to = int(request.args.get("create_time_to", 0)) - req = await request.get_json() + req = await get_request_json() run_status = req.get("run_status", []) if run_status: @@ -271,7 +271,7 @@ async def list_docs(): @manager.route("/filter", methods=["POST"]) # noqa: F821 @login_required async def get_filter(): - req = await request.get_json() + req = await get_request_json() kb_id = req.get("kb_id") if not kb_id: @@ -309,7 +309,7 @@ async def get_filter(): @manager.route("/infos", methods=["POST"]) # noqa: F821 @login_required async def doc_infos(): - req = await request_json() + req = await get_request_json() doc_ids = req["doc_ids"] for doc_id in doc_ids: if not DocumentService.accessible(doc_id, current_user.id): @@ -341,7 +341,7 @@ def thumbnails(): @login_required @validate_request("doc_ids", "status") async def change_status(): - req = await request.get_json() + req = await get_request_json() doc_ids = req.get("doc_ids", []) status = str(req.get("status", "")) @@ -381,7 +381,7 @@ async def change_status(): @login_required @validate_request("doc_id") async def rm(): - req = await request_json() + req = await get_request_json() doc_ids = req["doc_id"] if isinstance(doc_ids, str): doc_ids = [doc_ids] @@ -402,7 +402,7 @@ async def rm(): @login_required @validate_request("doc_ids", "run") async def run(): - req = await request_json() + req = await get_request_json() for doc_id in req["doc_ids"]: if not DocumentService.accessible(doc_id, current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) @@ -449,7 +449,7 @@ async def run(): @login_required @validate_request("doc_id", "name") async def rename(): - req = await request_json() + req = await get_request_json() if not DocumentService.accessible(req["doc_id"], current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: @@ -539,7 +539,7 @@ async def download_attachment(attachment_id): @validate_request("doc_id") async def change_parser(): - req = await request_json() + req = await get_request_json() if not DocumentService.accessible(req["doc_id"], current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) @@ -624,7 +624,8 @@ async def upload_and_parse(): @manager.route("/parse", methods=["POST"]) # noqa: F821 @login_required async def parse(): - url = await request.json.get("url") if await request.json else "" + req = await get_request_json() + url = req.get("url", "") if url: if not is_valid_url(url): return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR) @@ -679,7 +680,7 @@ async def parse(): @login_required @validate_request("doc_id", "meta") async def set_meta(): - req = await request_json() + req = await get_request_json() if not DocumentService.accessible(req["doc_id"], current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: diff --git a/api/apps/file2document_app.py b/api/apps/file2document_app.py index 1f8921e92..54c314e74 100644 --- a/api/apps/file2document_app.py +++ b/api/apps/file2document_app.py @@ -19,22 +19,20 @@ from pathlib import Path from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService -from quart import request from api.apps import login_required, current_user from api.db.services.knowledgebase_service import KnowledgebaseService -from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from common.misc_utils import get_uuid from common.constants import RetCode from api.db import FileType from api.db.services.document_service import DocumentService -from api.utils.api_utils import get_json_result @manager.route('/convert', methods=['POST']) # noqa: F821 @login_required @validate_request("file_ids", "kb_ids") async def convert(): - req = await request.json + req = await get_request_json() kb_ids = req["kb_ids"] file_ids = req["file_ids"] file2documents = [] @@ -104,7 +102,7 @@ async def convert(): @login_required @validate_request("file_ids") async def rm(): - req = await request.json + req = await get_request_json() file_ids = req["file_ids"] if not file_ids: return get_json_result( diff --git a/api/apps/file_app.py b/api/apps/file_app.py index e262b3d7b..bbb5b3ddb 100644 --- a/api/apps/file_app.py +++ b/api/apps/file_app.py @@ -29,7 +29,7 @@ from common.constants import RetCode, FileSource from api.db import FileType from api.db.services import duplicate_name from api.db.services.file_service import FileService -from api.utils.api_utils import get_json_result +from api.utils.api_utils import get_json_result, get_request_json from api.utils.file_utils import filename_type from api.utils.web_utils import CONTENT_TYPE_MAP from common import settings @@ -124,7 +124,7 @@ async def upload(): @login_required @validate_request("name") async def create(): - req = await request.json + req = await get_request_json() pf_id = req.get("parent_id") input_file_type = req.get("type") if not pf_id: @@ -239,7 +239,7 @@ def get_all_parent_folders(): @login_required @validate_request("file_ids") async def rm(): - req = await request.json + req = await get_request_json() file_ids = req["file_ids"] def _delete_single_file(file): @@ -300,7 +300,7 @@ async def rm(): @login_required @validate_request("file_id", "name") async def rename(): - req = await request.json + req = await get_request_json() try: e, file = FileService.get_by_id(req["file_id"]) if not e: @@ -369,7 +369,7 @@ async def get(file_id): @login_required @validate_request("src_file_ids", "dest_file_id") async def move(): - req = await request.json + req = await get_request_json() try: file_ids = req["src_file_ids"] dest_parent_id = req["dest_file_id"] diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 4e8015d7f..7ff01cc19 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -30,7 +30,7 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogS from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID from api.db.services.user_service import TenantService, UserTenantService from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \ - request_json + get_request_json from api.db import VALID_FILE_TYPES from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.db_models import File @@ -48,7 +48,7 @@ from api.apps import login_required, current_user @login_required @validate_request("name") async def create(): - req = await request_json() + req = await get_request_json() e, res = KnowledgebaseService.create_with_name( name = req.pop("name", None), tenant_id = current_user.id, @@ -72,7 +72,7 @@ async def create(): @validate_request("kb_id", "name", "description", "parser_id") @not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") async def update(): - req = await request_json() + req = await get_request_json() if not isinstance(req["name"], str): return get_data_error_result(message="Dataset name must be string.") if req["name"].strip() == "": @@ -182,7 +182,7 @@ async def list_kbs(): else: desc = True - req = await request_json() + req = await get_request_json() owner_ids = req.get("owner_ids", []) try: if not owner_ids: @@ -209,7 +209,7 @@ async def list_kbs(): @login_required @validate_request("kb_id") async def rm(): - req = await request_json() + req = await get_request_json() if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id): return get_json_result( data=False, @@ -286,7 +286,7 @@ def list_tags_from_kbs(): @manager.route('//rm_tags', methods=['POST']) # noqa: F821 @login_required async def rm_tags(kb_id): - req = await request_json() + req = await get_request_json() if not KnowledgebaseService.accessible(kb_id, current_user.id): return get_json_result( data=False, @@ -306,7 +306,7 @@ async def rm_tags(kb_id): @manager.route('//rename_tag', methods=['POST']) # noqa: F821 @login_required async def rename_tags(kb_id): - req = await request_json() + req = await get_request_json() if not KnowledgebaseService.accessible(kb_id, current_user.id): return get_json_result( data=False, @@ -428,7 +428,7 @@ async def list_pipeline_logs(): if create_date_to > create_date_from: return get_data_error_result(message="Create data filter is abnormal.") - req = await request_json() + req = await get_request_json() operation_status = req.get("operation_status", []) if operation_status: @@ -470,7 +470,7 @@ async def list_pipeline_dataset_logs(): if create_date_to > create_date_from: return get_data_error_result(message="Create data filter is abnormal.") - req = await request_json() + req = await get_request_json() operation_status = req.get("operation_status", []) if operation_status: @@ -492,7 +492,7 @@ async def delete_pipeline_logs(): if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - req = await request_json() + req = await get_request_json() log_ids = req.get("log_ids", []) PipelineOperationLogService.delete_by_ids(log_ids) @@ -517,7 +517,7 @@ def pipeline_log_detail(): @manager.route("/run_graphrag", methods=["POST"]) # noqa: F821 @login_required async def run_graphrag(): - req = await request_json() + req = await get_request_json() kb_id = req.get("kb_id", "") if not kb_id: @@ -586,7 +586,7 @@ def trace_graphrag(): @manager.route("/run_raptor", methods=["POST"]) # noqa: F821 @login_required async def run_raptor(): - req = await request_json() + req = await get_request_json() kb_id = req.get("kb_id", "") if not kb_id: @@ -655,7 +655,7 @@ def trace_raptor(): @manager.route("/run_mindmap", methods=["POST"]) # noqa: F821 @login_required async def run_mindmap(): - req = await request_json() + req = await get_request_json() kb_id = req.get("kb_id", "") if not kb_id: @@ -857,11 +857,11 @@ async def check_embedding(): "question_kwd": full_doc.get("question_kwd") or [] }) return out - + def _clean(s: str) -> str: s = re.sub(r"]{0,12})?>", " ", s or "") return s if s else "None" - req = await request_json() + req = await get_request_json() kb_id = req.get("kb_id", "") embd_id = req.get("embd_id", "") n = int(req.get("check_num", 5)) diff --git a/api/apps/langfuse_app.py b/api/apps/langfuse_app.py index ffdc6a5fd..8a05c0d4c 100644 --- a/api/apps/langfuse_app.py +++ b/api/apps/langfuse_app.py @@ -15,20 +15,19 @@ # -from quart import request from api.apps import current_user, login_required from langfuse import Langfuse from api.db.db_models import DB from api.db.services.langfuse_service import TenantLangfuseService -from api.utils.api_utils import get_error_data_result, get_json_result, server_error_response, validate_request +from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request @manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821 @login_required @validate_request("secret_key", "public_key", "host") async def set_api_key(): - req = await request.get_json() + req = await get_request_json() secret_key = req.get("secret_key", "") public_key = req.get("public_key", "") host = req.get("host", "") diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 29da88c4f..018fb4bca 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -21,10 +21,9 @@ from quart import request from api.apps import login_required, current_user from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService from api.db.services.llm_service import LLMService -from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from common.constants import StatusEnum, LLMType from api.db.db_models import TenantLLM -from api.utils.api_utils import get_json_result, get_allowed_llm_factories from rag.utils.base64_image import test_image from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel @@ -54,7 +53,7 @@ def factories(): @login_required @validate_request("llm_factory", "api_key") async def set_api_key(): - req = await request.json + req = await get_request_json() # test if api key works chat_passed, embd_passed, rerank_passed = False, False, False factory = req["llm_factory"] @@ -124,7 +123,7 @@ async def set_api_key(): @login_required @validate_request("llm_factory") async def add_llm(): - req = await request.json + req = await get_request_json() factory = req["llm_factory"] api_key = req.get("api_key", "x") llm_name = req.get("llm_name") @@ -269,7 +268,7 @@ async def add_llm(): @login_required @validate_request("llm_factory", "llm_name") async def delete_llm(): - req = await request.json + req = await get_request_json() TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]]) return get_json_result(data=True) @@ -278,7 +277,7 @@ async def delete_llm(): @login_required @validate_request("llm_factory", "llm_name") async def enable_llm(): - req = await request.json + req = await get_request_json() TenantLLMService.filter_update( [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))} ) @@ -289,7 +288,7 @@ async def enable_llm(): @login_required @validate_request("llm_factory") async def delete_factory(): - req = await request.json + req = await get_request_json() TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]]) return get_json_result(data=True) diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py index 583f721c4..863aac963 100644 --- a/api/apps/mcp_server_app.py +++ b/api/apps/mcp_server_app.py @@ -22,8 +22,7 @@ from api.db.services.user_service import TenantService from common.constants import RetCode, VALID_MCP_SERVER_TYPES from common.misc_utils import get_uuid -from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ - get_mcp_tools +from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request from api.utils.web_utils import get_float, safe_json_parse from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions @@ -40,7 +39,7 @@ async def list_mcp() -> Response: else: desc = True - req = await request.get_json() + req = await get_request_json() mcp_ids = req.get("mcp_ids", []) try: servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or [] @@ -73,7 +72,7 @@ def detail() -> Response: @login_required @validate_request("name", "url", "server_type") async def create() -> Response: - req = await request.get_json() + req = await get_request_json() server_type = req.get("server_type", "") if server_type not in VALID_MCP_SERVER_TYPES: @@ -128,7 +127,7 @@ async def create() -> Response: @login_required @validate_request("mcp_id") async def update() -> Response: - req = await request.get_json() + req = await get_request_json() mcp_id = req.get("mcp_id", "") e, mcp_server = MCPServerService.get_by_id(mcp_id) @@ -184,7 +183,7 @@ async def update() -> Response: @login_required @validate_request("mcp_ids") async def rm() -> Response: - req = await request.get_json() + req = await get_request_json() mcp_ids = req.get("mcp_ids", []) try: @@ -202,7 +201,7 @@ async def rm() -> Response: @login_required @validate_request("mcpServers") async def import_multiple() -> Response: - req = await request.get_json() + req = await get_request_json() servers = req.get("mcpServers", {}) if not servers: return get_data_error_result(message="No MCP servers provided.") @@ -269,7 +268,7 @@ async def import_multiple() -> Response: @login_required @validate_request("mcp_ids") async def export_multiple() -> Response: - req = await request.get_json() + req = await get_request_json() mcp_ids = req.get("mcp_ids", []) if not mcp_ids: @@ -301,7 +300,7 @@ async def export_multiple() -> Response: @login_required @validate_request("mcp_ids") async def list_tools() -> Response: - req = await request.get_json() + req = await get_request_json() mcp_ids = req.get("mcp_ids", []) if not mcp_ids: return get_data_error_result(message="No MCP server IDs provided.") @@ -348,7 +347,7 @@ async def list_tools() -> Response: @login_required @validate_request("mcp_id", "tool_name", "arguments") async def test_tool() -> Response: - req = await request.get_json() + req = await get_request_json() mcp_id = req.get("mcp_id", "") if not mcp_id: return get_data_error_result(message="No MCP server ID provided.") @@ -381,7 +380,7 @@ async def test_tool() -> Response: @login_required @validate_request("mcp_id", "tools") async def cache_tool() -> Response: - req = await request.get_json() + req = await get_request_json() mcp_id = req.get("mcp_id", "") if not mcp_id: return get_data_error_result(message="No MCP server ID provided.") @@ -404,7 +403,7 @@ async def cache_tool() -> Response: @manager.route("/test_mcp", methods=["POST"]) # noqa: F821 @validate_request("url", "server_type") async def test_mcp() -> Response: - req = await request.get_json() + req = await get_request_json() url = req.get("url", "") if not url: diff --git a/api/apps/sdk/agents.py b/api/apps/sdk/agents.py index b20a22ad8..20e897388 100644 --- a/api/apps/sdk/agents.py +++ b/api/apps/sdk/agents.py @@ -25,7 +25,7 @@ from api.db.services.canvas_service import UserCanvasService from api.db.services.user_canvas_version import UserCanvasVersionService from common.constants import RetCode from common.misc_utils import get_uuid -from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required +from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required from api.utils.api_utils import get_result from quart import request, Response @@ -53,7 +53,7 @@ def list_agents(tenant_id): @manager.route("/agents", methods=["POST"]) # noqa: F821 @token_required async def create_agent(tenant_id: str): - req: dict[str, Any] = cast(dict[str, Any], await request.json) + req: dict[str, Any] = cast(dict[str, Any], await get_request_json()) req["user_id"] = tenant_id if req.get("dsl") is not None: @@ -90,7 +90,7 @@ async def create_agent(tenant_id: str): @manager.route("/agents/", methods=["PUT"]) # noqa: F821 @token_required async def update_agent(tenant_id: str, agent_id: str): - req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await request.json)).items() if v is not None} + req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None} req["user_id"] = tenant_id if req.get("dsl") is not None: @@ -136,7 +136,7 @@ def delete_agent(tenant_id: str, agent_id: str): @manager.route('/webhook/', methods=['POST']) # noqa: F821 @token_required async def webhook(tenant_id: str, agent_id: str): - req = await request.json + req = await get_request_json() if not UserCanvasService.accessible(req["id"], tenant_id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index 0abf7374d..8c9619555 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -21,13 +21,13 @@ from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_service import TenantService from common.misc_utils import get_uuid from common.constants import RetCode, StatusEnum -from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, request_json +from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, get_request_json @manager.route("/chats", methods=["POST"]) # noqa: F821 @token_required async def create(tenant_id): - req = await request_json() + req = await get_request_json() ids = [i for i in req.get("dataset_ids", []) if i] for kb_id in ids: kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) @@ -146,7 +146,7 @@ async def create(tenant_id): async def update(tenant_id, chat_id): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): return get_error_data_result(message="You do not own the chat") - req = await request_json() + req = await get_request_json() ids = req.get("dataset_ids", []) if "show_quotation" in req: req["do_refer"] = req.pop("show_quotation") @@ -229,7 +229,7 @@ async def update(tenant_id, chat_id): async def delete_chats(tenant_id): errors = [] success_count = 0 - req = await request_json() + req = await get_request_json() if not req: ids = None else: diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index 55ea54faf..9665754eb 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -15,12 +15,12 @@ # import logging -from quart import request, jsonify +from quart import jsonify from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from api.utils.api_utils import validate_request, build_error_result, apikey_required +from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request from rag.app.tag import label_question from api.db.services.dialog_service import meta_filter, convert_conditions from common.constants import RetCode, LLMType @@ -113,7 +113,7 @@ async def retrieval(tenant_id): 404: description: Knowledge base or document not found """ - req = await request.json + req = await get_request_json() question = req["query"] kb_id = req["knowledge_id"] use_kg = req.get("use_kg", False) diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index aebf925cc..0a007f148 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -36,7 +36,7 @@ from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.task_service import TaskService, queue_tasks from api.db.services.dialog_service import meta_filter, convert_conditions from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \ - request_json + get_request_json from rag.app.qa import beAdoc, rmPrefix from rag.app.tag import label_question from rag.nlp import rag_tokenizer, search @@ -231,7 +231,7 @@ async def update_doc(tenant_id, dataset_id, document_id): schema: type: object """ - req = await request_json() + req = await get_request_json() if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(message="You don't own the dataset.") e, kb = KnowledgebaseService.get_by_id(dataset_id) @@ -536,7 +536,7 @@ def list_docs(dataset_id, tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") q = request.args - document_id = q.get("id") + document_id = q.get("id") name = q.get("name") if document_id and not DocumentService.query(id=document_id, kb_id=dataset_id): @@ -545,16 +545,16 @@ def list_docs(dataset_id, tenant_id): return get_error_data_result(message=f"You don't own the document {name}.") page = int(q.get("page", 1)) - page_size = int(q.get("page_size", 30)) + page_size = int(q.get("page_size", 30)) orderby = q.get("orderby", "create_time") desc = str(q.get("desc", "true")).strip().lower() != "false" keywords = q.get("keywords", "") # filters - align with OpenAPI parameter names - suffix = q.getlist("suffix") - run_status = q.getlist("run") - create_time_from = int(q.get("create_time_from", 0)) - create_time_to = int(q.get("create_time_to", 0)) + suffix = q.getlist("suffix") + run_status = q.getlist("run") + create_time_from = int(q.get("create_time_from", 0)) + create_time_to = int(q.get("create_time_to", 0)) # map run status (accept text or numeric) - align with API parameter run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"} @@ -575,7 +575,7 @@ def list_docs(dataset_id, tenant_id): # rename keys + map run status back to text for output key_mapping = { "chunk_num": "chunk_count", - "kb_id": "dataset_id", + "kb_id": "dataset_id", "token_num": "token_count", "parser_id": "chunk_method", } @@ -631,7 +631,7 @@ async def delete(tenant_id, dataset_id): """ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") - req = await request_json() + req = await get_request_json() if not req: doc_ids = None else: @@ -741,7 +741,7 @@ async def parse(tenant_id, dataset_id): """ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - req = await request_json() + req = await get_request_json() if not req.get("document_ids"): return get_error_data_result("`document_ids` is required") doc_list = req.get("document_ids") @@ -824,7 +824,7 @@ async def stop_parsing(tenant_id, dataset_id): """ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - req = await request_json() + req = await get_request_json() if not req.get("document_ids"): return get_error_data_result("`document_ids` is required") @@ -1096,7 +1096,7 @@ async def add_chunk(tenant_id, dataset_id, document_id): if not doc: return get_error_data_result(message=f"You don't own the document {document_id}.") doc = doc[0] - req = await request_json() + req = await get_request_json() if not str(req.get("content", "")).strip(): return get_error_data_result(message="`content` is required") if "important_keywords" in req: @@ -1202,7 +1202,7 @@ async def rm_chunk(tenant_id, dataset_id, document_id): docs = DocumentService.get_by_ids([document_id]) if not docs: raise LookupError(f"Can't find the document with ID {document_id}!") - req = await request_json() + req = await get_request_json() condition = {"doc_id": document_id} if "chunk_ids" in req: unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk") @@ -1288,7 +1288,7 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): if not doc: return get_error_data_result(message=f"You don't own the document {document_id}.") doc = doc[0] - req = await request_json() + req = await get_request_json() if "content" in req and req["content"] is not None: content = req["content"] else: @@ -1411,7 +1411,7 @@ async def retrieval_test(tenant_id): format: float description: Similarity score. """ - req = await request_json() + req = await get_request_json() if not req.get("dataset_ids"): return get_error_data_result("`dataset_ids` is required.") kb_ids = req["dataset_ids"] diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py index 6377ea7c8..fde3befa8 100644 --- a/api/apps/sdk/files.py +++ b/api/apps/sdk/files.py @@ -23,12 +23,11 @@ from pathlib import Path from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.utils.api_utils import server_error_response, token_required +from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required from common.misc_utils import get_uuid from api.db import FileType from api.db.services import duplicate_name from api.db.services.file_service import FileService -from api.utils.api_utils import get_json_result from api.utils.file_utils import filename_type from common import settings from common.constants import RetCode @@ -193,9 +192,9 @@ async def create(tenant_id): type: type: string """ - req = await request.json - pf_id = await request.json.get("parent_id") - input_file_type = await request.json.get("type") + req = await get_request_json() + pf_id = req.get("parent_id") + input_file_type = req.get("type") if not pf_id: root_folder = FileService.get_root_folder(tenant_id) pf_id = root_folder["id"] @@ -229,7 +228,7 @@ async def create(tenant_id): @manager.route('/file/list', methods=['GET']) # noqa: F821 @token_required -def list_files(tenant_id): +async def list_files(tenant_id): """ List files under a specific folder. --- @@ -321,7 +320,7 @@ def list_files(tenant_id): @manager.route('/file/root_folder', methods=['GET']) # noqa: F821 @token_required -def get_root_folder(tenant_id): +async def get_root_folder(tenant_id): """ Get user's root folder. --- @@ -357,7 +356,7 @@ def get_root_folder(tenant_id): @manager.route('/file/parent_folder', methods=['GET']) # noqa: F821 @token_required -def get_parent_folder(): +async def get_parent_folder(): """ Get parent folder info of a file. --- @@ -402,7 +401,7 @@ def get_parent_folder(): @manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821 @token_required -def get_all_parent_folders(tenant_id): +async def get_all_parent_folders(tenant_id): """ Get all parent folders of a file. --- @@ -481,7 +480,7 @@ async def rm(tenant_id): type: boolean example: true """ - req = await request.json + req = await get_request_json() file_ids = req["file_ids"] try: for file_id in file_ids: @@ -556,7 +555,7 @@ async def rename(tenant_id): type: boolean example: true """ - req = await request.json + req = await get_request_json() try: e, file = FileService.get_by_id(req["file_id"]) if not e: @@ -667,7 +666,7 @@ async def move(tenant_id): type: boolean example: true """ - req = await request.json + req = await get_request_json() try: file_ids = req["src_file_ids"] parent_id = req["dest_file_id"] @@ -694,7 +693,7 @@ async def move(tenant_id): @manager.route('/file/convert', methods=['POST']) # noqa: F821 @token_required async def convert(tenant_id): - req = await request.json + req = await get_request_json() kb_ids = req["kb_ids"] file_ids = req["file_ids"] file2documents = [] diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 074401ede..6276877a2 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -35,7 +35,7 @@ from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from common.misc_utils import get_uuid from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \ - get_result, server_error_response, token_required, validate_request + get_result, get_request_json, server_error_response, token_required, validate_request from rag.app.tag import label_question from rag.prompts.template import load_prompt from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format @@ -45,7 +45,7 @@ from common import settings @manager.route("/chats//sessions", methods=["POST"]) # noqa: F821 @token_required async def create(tenant_id, chat_id): - req = await request.json + req = await get_request_json() req["dialog_id"] = chat_id dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) if not dia: @@ -73,7 +73,7 @@ async def create(tenant_id, chat_id): @manager.route("/agents//sessions", methods=["POST"]) # noqa: F821 @token_required -def create_agent_session(tenant_id, agent_id): +async def create_agent_session(tenant_id, agent_id): user_id = request.args.get("user_id", tenant_id) e, cvs = UserCanvasService.get_by_id(agent_id) if not e: @@ -98,7 +98,7 @@ def create_agent_session(tenant_id, agent_id): @manager.route("/chats//sessions/", methods=["PUT"]) # noqa: F821 @token_required async def update(tenant_id, chat_id, session_id): - req = await request.json + req = await get_request_json() req["dialog_id"] = chat_id conv_id = session_id conv = ConversationService.query(id=conv_id, dialog_id=chat_id) @@ -120,7 +120,7 @@ async def update(tenant_id, chat_id, session_id): @manager.route("/chats//completions", methods=["POST"]) # noqa: F821 @token_required async def chat_completion(tenant_id, chat_id): - req = await request.json + req = await get_request_json() if not req: req = {"question": ""} if not req.get("session_id"): @@ -206,7 +206,7 @@ async def chat_completion_openai_like(tenant_id, chat_id): if reference: print(completion.choices[0].message.reference) """ - req = await request.get_json() + req = await get_request_json() need_reference = bool(req.get("reference", False)) @@ -384,7 +384,7 @@ async def chat_completion_openai_like(tenant_id, chat_id): @validate_request("model", "messages") # noqa: F821 @token_required async def agents_completion_openai_compatibility(tenant_id, agent_id): - req = await request.json + req = await get_request_json() tiktokenenc = tiktoken.get_encoding("cl100k_base") messages = req.get("messages", []) if not messages: @@ -442,7 +442,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id): @manager.route("/agents//completions", methods=["POST"]) # noqa: F821 @token_required async def agent_completions(tenant_id, agent_id): - req = await request.json + req = await get_request_json() if req.get("stream", True): @@ -491,7 +491,7 @@ async def agent_completions(tenant_id, agent_id): @manager.route("/chats//sessions", methods=["GET"]) # noqa: F821 @token_required -def list_session(tenant_id, chat_id): +async def list_session(tenant_id, chat_id): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): return get_error_data_result(message=f"You don't own the assistant {chat_id}.") id = request.args.get("id") @@ -545,7 +545,7 @@ def list_session(tenant_id, chat_id): @manager.route("/agents//sessions", methods=["GET"]) # noqa: F821 @token_required -def list_agent_session(tenant_id, agent_id): +async def list_agent_session(tenant_id, agent_id): if not UserCanvasService.query(user_id=tenant_id, id=agent_id): return get_error_data_result(message=f"You don't own the agent {agent_id}.") id = request.args.get("id") @@ -614,7 +614,7 @@ async def delete(tenant_id, chat_id): errors = [] success_count = 0 - req = await request.json + req = await get_request_json() convs = ConversationService.query(dialog_id=chat_id) if not req: ids = None @@ -662,7 +662,7 @@ async def delete(tenant_id, chat_id): async def delete_agent_session(tenant_id, agent_id): errors = [] success_count = 0 - req = await request.json + req = await get_request_json() cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id) if not cvs: return get_error_data_result(f"You don't own the agent {agent_id}") @@ -715,7 +715,7 @@ async def delete_agent_session(tenant_id, agent_id): @manager.route("/sessions/ask", methods=["POST"]) # noqa: F821 @token_required async def ask_about(tenant_id): - req = await request.json + req = await get_request_json() if not req.get("question"): return get_error_data_result("`question` is required.") if not req.get("dataset_ids"): @@ -754,7 +754,7 @@ async def ask_about(tenant_id): @manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821 @token_required async def related_questions(tenant_id): - req = await request.json + req = await get_request_json() if not req.get("question"): return get_error_data_result("`question` is required.") question = req["question"] @@ -805,7 +805,7 @@ Related search terms: @manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821 async def chatbot_completions(dialog_id): - req = await request.json + req = await get_request_json() token = request.headers.get("Authorization").split() if len(token) != 2: @@ -831,7 +831,7 @@ async def chatbot_completions(dialog_id): @manager.route("/chatbots//info", methods=["GET"]) # noqa: F821 -def chatbots_inputs(dialog_id): +async def chatbots_inputs(dialog_id): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -855,7 +855,7 @@ def chatbots_inputs(dialog_id): @manager.route("/agentbots//completions", methods=["POST"]) # noqa: F821 async def agent_bot_completions(agent_id): - req = await request.json + req = await get_request_json() token = request.headers.get("Authorization").split() if len(token) != 2: @@ -878,7 +878,7 @@ async def agent_bot_completions(agent_id): @manager.route("/agentbots//inputs", methods=["GET"]) # noqa: F821 -def begin_inputs(agent_id): +async def begin_inputs(agent_id): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -908,7 +908,7 @@ async def ask_about_embedded(): if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - req = await request.json + req = await get_request_json() uid = objs[0].tenant_id search_id = req.get("search_id", "") @@ -947,7 +947,7 @@ async def retrieval_test_embedded(): if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - req = await request.json + req = await get_request_json() page = int(req.get("page", 1)) size = int(req.get("size", 30)) question = req["question"] @@ -1046,7 +1046,7 @@ async def related_questions_embedded(): if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - req = await request.json + req = await get_request_json() tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") @@ -1081,7 +1081,7 @@ Related search terms: @manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821 -def detail_share_embedded(): +async def detail_share_embedded(): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -1123,7 +1123,7 @@ async def mindmap(): return get_error_data_result(message='Authentication error: API key is invalid!"') tenant_id = objs[0].tenant_id - req = await request.json + req = await get_request_json() search_id = req.get("search_id", "") search_app = SearchService.get_detail(search_id) if search_id else {} diff --git a/api/apps/search_app.py b/api/apps/search_app.py index d350b93c3..d82c3b27d 100644 --- a/api/apps/search_app.py +++ b/api/apps/search_app.py @@ -24,14 +24,14 @@ from api.db.services.search_service import SearchService from api.db.services.user_service import TenantService, UserTenantService from common.misc_utils import get_uuid from common.constants import RetCode, StatusEnum -from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, get_request_json, server_error_response, validate_request @manager.route("/create", methods=["post"]) # noqa: F821 @login_required @validate_request("name") async def create(): - req = await request.get_json() + req = await get_request_json() search_name = req["name"] description = req.get("description", "") if not isinstance(search_name, str): @@ -66,7 +66,7 @@ async def create(): @validate_request("search_id", "name", "search_config", "tenant_id") @not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") async def update(): - req = await request.get_json() + req = await get_request_json() if not isinstance(req["name"], str): return get_data_error_result(message="Search name must be string.") if req["name"].strip() == "": @@ -150,7 +150,7 @@ async def list_search_app(): else: desc = True - req = await request.get_json() + req = await get_request_json() owner_ids = req.get("owner_ids", []) try: if not owner_ids: @@ -174,7 +174,7 @@ async def list_search_app(): @login_required @validate_request("search_id") async def rm(): - req = await request.get_json() + req = await get_request_json() search_id = req["search_id"] if not SearchService.accessible4deletion(search_id, current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) diff --git a/api/apps/tenant_app.py b/api/apps/tenant_app.py index 380838bcd..fdb764e65 100644 --- a/api/apps/tenant_app.py +++ b/api/apps/tenant_app.py @@ -14,7 +14,6 @@ # limitations under the License. # -from quart import request from api.db import UserTenantRole from api.db.db_models import UserTenant from api.db.services.user_service import UserTenantService, UserService @@ -22,7 +21,7 @@ from api.db.services.user_service import UserTenantService, UserService from common.constants import RetCode, StatusEnum from common.misc_utils import get_uuid from common.time_utils import delta_seconds -from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result +from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from api.utils.web_utils import send_invite_email from common import settings from api.apps import smtp_mail_server, login_required, current_user @@ -56,7 +55,7 @@ async def create(tenant_id): message='No authorization.', code=RetCode.AUTHENTICATION_ERROR) - req = await request.json + req = await get_request_json() invite_user_email = req["email"] invite_users = UserService.query(email=invite_user_email) if not invite_users: diff --git a/api/apps/user_app.py b/api/apps/user_app.py index ae1355da8..78407b242 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -39,6 +39,7 @@ from common.connection_utils import construct_response from api.utils.api_utils import ( get_data_error_result, get_json_result, + get_request_json, server_error_response, validate_request, ) @@ -57,6 +58,7 @@ from api.utils.web_utils import ( captcha_key, ) from common import settings +from common.http_client import async_request @manager.route("/login", methods=["POST", "GET"]) # noqa: F821 @@ -90,7 +92,7 @@ async def login(): schema: type: object """ - json_body = await request.json + json_body = await get_request_json() if not json_body: return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!") @@ -136,7 +138,7 @@ async def login(): @manager.route("/login/channels", methods=["GET"]) # noqa: F821 -def get_login_channels(): +async def get_login_channels(): """ Get all supported authentication channels. """ @@ -157,7 +159,7 @@ def get_login_channels(): @manager.route("/login/", methods=["GET"]) # noqa: F821 -def oauth_login(channel): +async def oauth_login(channel): channel_config = settings.OAUTH_CONFIG.get(channel) if not channel_config: raise ValueError(f"Invalid channel name: {channel}") @@ -170,7 +172,7 @@ def oauth_login(channel): @manager.route("/oauth/callback/", methods=["GET"]) # noqa: F821 -def oauth_callback(channel): +async def oauth_callback(channel): """ Handle the OAuth/OIDC callback for various channels dynamically. """ @@ -192,7 +194,10 @@ def oauth_callback(channel): return redirect("/?error=missing_code") # Exchange authorization code for access token - token_info = auth_cli.exchange_code_for_token(code) + if hasattr(auth_cli, "async_exchange_code_for_token"): + token_info = await auth_cli.async_exchange_code_for_token(code) + else: + token_info = auth_cli.exchange_code_for_token(code) access_token = token_info.get("access_token") if not access_token: return redirect("/?error=token_failed") @@ -200,7 +205,10 @@ def oauth_callback(channel): id_token = token_info.get("id_token") # Fetch user info - user_info = auth_cli.fetch_user_info(access_token, id_token=id_token) + if hasattr(auth_cli, "async_fetch_user_info"): + user_info = await auth_cli.async_fetch_user_info(access_token, id_token=id_token) + else: + user_info = auth_cli.fetch_user_info(access_token, id_token=id_token) if not user_info.email: return redirect("/?error=email_missing") @@ -259,7 +267,7 @@ def oauth_callback(channel): @manager.route("/github_callback", methods=["GET"]) # noqa: F821 -def github_callback(): +async def github_callback(): """ **Deprecated**, Use `/oauth/callback/` instead. @@ -279,9 +287,8 @@ def github_callback(): schema: type: object """ - import requests - - res = requests.post( + res = await async_request( + "POST", settings.GITHUB_OAUTH.get("url"), data={ "client_id": settings.GITHUB_OAUTH.get("client_id"), @@ -299,7 +306,7 @@ def github_callback(): session["access_token"] = res["access_token"] session["access_token_from"] = "github" - user_info = user_info_from_github(session["access_token"]) + user_info = await user_info_from_github(session["access_token"]) email_address = user_info["email"] users = UserService.query(email=email_address) user_id = get_uuid() @@ -348,7 +355,7 @@ def github_callback(): @manager.route("/feishu_callback", methods=["GET"]) # noqa: F821 -def feishu_callback(): +async def feishu_callback(): """ Feishu OAuth callback endpoint. --- @@ -366,9 +373,8 @@ def feishu_callback(): schema: type: object """ - import requests - - app_access_token_res = requests.post( + app_access_token_res = await async_request( + "POST", settings.FEISHU_OAUTH.get("app_access_token_url"), data=json.dumps( { @@ -382,7 +388,8 @@ def feishu_callback(): if app_access_token_res["code"] != 0: return redirect("/?error=%s" % app_access_token_res) - res = requests.post( + res = await async_request( + "POST", settings.FEISHU_OAUTH.get("user_access_token_url"), data=json.dumps( { @@ -403,7 +410,7 @@ def feishu_callback(): return redirect("/?error=contact:user.email:readonly not in scope") session["access_token"] = res["data"]["access_token"] session["access_token_from"] = "feishu" - user_info = user_info_from_feishu(session["access_token"]) + user_info = await user_info_from_feishu(session["access_token"]) email_address = user_info["email"] users = UserService.query(email=email_address) user_id = get_uuid() @@ -451,36 +458,34 @@ def feishu_callback(): return redirect("/?auth=%s" % user.get_id()) -def user_info_from_feishu(access_token): - import requests - +async def user_info_from_feishu(access_token): headers = { "Content-Type": "application/json; charset=utf-8", "Authorization": f"Bearer {access_token}", } - res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers) + res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers) user_info = res.json()["data"] user_info["email"] = None if user_info.get("email") == "" else user_info["email"] return user_info -def user_info_from_github(access_token): - import requests - +async def user_info_from_github(access_token): headers = {"Accept": "application/json", "Authorization": f"token {access_token}"} - res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers) + res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers) user_info = res.json() - email_info = requests.get( + email_info_response = await async_request( + "GET", f"https://api.github.com/user/emails?access_token={access_token}", headers=headers, - ).json() + ) + email_info = email_info_response.json() user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] return user_info @manager.route("/logout", methods=["GET"]) # noqa: F821 @login_required -def log_out(): +async def log_out(): """ User logout endpoint. --- @@ -531,7 +536,7 @@ async def setting_user(): type: object """ update_dict = {} - request_data = await request.json + request_data = await get_request_json() if request_data.get("password"): new_password = request_data.get("new_password") if not check_password_hash(current_user.password, decrypt(request_data["password"])): @@ -570,7 +575,7 @@ async def setting_user(): @manager.route("/info", methods=["GET"]) # noqa: F821 @login_required -def user_profile(): +async def user_profile(): """ Get user profile information. --- @@ -698,7 +703,7 @@ async def user_add(): code=RetCode.OPERATING_ERROR, ) - req = await request.json + req = await get_request_json() email_address = req["email"] # Validate the email address @@ -755,7 +760,7 @@ async def user_add(): @manager.route("/tenant_info", methods=["GET"]) # noqa: F821 @login_required -def tenant_info(): +async def tenant_info(): """ Get tenant information. --- @@ -831,14 +836,14 @@ async def set_tenant_info(): schema: type: object """ - req = await request.json + req = await get_request_json() try: tid = req.pop("tenant_id") TenantService.update_by_id(tid, req) return get_json_result(data=True) except Exception as e: return server_error_response(e) - + @manager.route("/forget/captcha", methods=["GET"]) # noqa: F821 async def forget_get_captcha(): @@ -875,7 +880,7 @@ async def forget_send_otp(): - Verify the image captcha stored at captcha:{email} (case-insensitive). - On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email. """ - req = await request.get_json() + req = await get_request_json() email = req.get("email") or "" captcha = (req.get("captcha") or "").strip() @@ -931,7 +936,7 @@ async def forget_send_otp(): ) except Exception: return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email") - + return get_json_result(data=True, code=RetCode.SUCCESS, message="verification passed, email sent") @@ -941,7 +946,7 @@ async def forget(): POST: Verify email + OTP and reset password, then log the user in. Request JSON: { email, otp, new_password, confirm_new_password } """ - req = await request.get_json() + req = await get_request_json() email = req.get("email") or "" otp = (req.get("otp") or "").strip() new_pwd = req.get("new_password") @@ -1006,4 +1011,4 @@ async def forget(): user.update_date = datetime_format(datetime.now()) user.save() msg = "Password reset successful. Logged in." - return construct_response(data=user.to_json(), auth=user.get_id(), message=msg) + return await construct_response(data=user.to_json(), auth=user.get_id(), message=msg) diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 11ef5b454..d5a8535ef 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -655,7 +655,7 @@ class FileService(CommonService): return structured(file.filename, filename_type(file.filename), file.read(), file.content_type) @staticmethod - def get_files(self, files: Union[None, list[dict]]) -> list[str]: + def get_files(files: Union[None, list[dict]]) -> list[str]: if not files: return [] def image_to_base64(file): diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 4d4ccaa57..a681341d4 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import inspect import logging import re +import threading from common.token_utils import num_tokens_from_string from functools import partial from typing import Generator @@ -242,7 +244,7 @@ class LLMBundle(LLM4Tenant): if not self.verbose_tool_use: txt = re.sub(r".*?", "", txt, flags=re.DOTALL) - if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): + if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) if self.langfuse: @@ -279,5 +281,80 @@ class LLMBundle(LLM4Tenant): yield ans if total_tokens > 0: - if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name): - logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt)) + if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): + logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) + + def _bridge_sync_stream(self, gen): + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + + def worker(): + try: + for item in gen: + loop.call_soon_threadsafe(queue.put_nowait, item) + except Exception as e: # pragma: no cover + loop.call_soon_threadsafe(queue.put_nowait, e) + finally: + loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration) + + threading.Thread(target=worker, daemon=True).start() + return queue + + async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs): + chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs) + if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"): + chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs) + + use_kwargs = self._clean_param(chat_partial, **kwargs) + + if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools: + txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs) + elif hasattr(self.mdl, "async_chat"): + txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs) + else: + txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs) + + txt = self._remove_reasoning_content(txt) + if not self.verbose_tool_use: + txt = re.sub(r".*?", "", txt, flags=re.DOTALL) + + if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): + logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) + + return txt + + async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): + total_tokens = 0 + if self.is_tools and self.mdl.is_tools: + stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None) + else: + stream_fn = getattr(self.mdl, "async_chat_streamly", None) + + if stream_fn: + chat_partial = partial(stream_fn, system, history, gen_conf) + use_kwargs = self._clean_param(chat_partial, **kwargs) + async for txt in chat_partial(**use_kwargs): + if isinstance(txt, int): + total_tokens = txt + break + yield txt + if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): + logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) + return + + chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf) + use_kwargs = self._clean_param(chat_partial, **kwargs) + queue = self._bridge_sync_stream(chat_partial(**use_kwargs)) + while True: + item = await queue.get() + if item is StopAsyncIteration: + break + if isinstance(item, Exception): + raise item + if isinstance(item, int): + total_tokens = item + break + yield item + + if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): + logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) diff --git a/api/ragflow_server.py b/api/ragflow_server.py index f6cb7bc2b..59622fe68 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -25,7 +25,6 @@ import logging import os import signal import sys -import time import traceback import threading import uuid @@ -69,7 +68,7 @@ def signal_handler(sig, frame): logging.info("Received interrupt signal, shutting down...") shutdown_all_mcp_sessions() stop_event.set() - time.sleep(1) + stop_event.wait(1) sys.exit(0) if __name__ == '__main__': @@ -163,5 +162,5 @@ if __name__ == '__main__': except Exception: traceback.print_exc() stop_event.set() - time.sleep(1) + stop_event.wait(1) os.kill(os.getpid(), signal.SIGKILL) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 314211694..8f17e1de0 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -22,6 +22,7 @@ import os import time from copy import deepcopy from functools import wraps +from typing import Any import requests import trio @@ -45,11 +46,40 @@ from common import settings requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) -async def request_json(): +async def _coerce_request_data() -> dict: + """Fetch JSON body with sane defaults; fallback to form data.""" + payload: Any = None + last_error: Exception | None = None + try: - return await request.json - except Exception: - return {} + payload = await request.get_json(force=True, silent=True) + except Exception as e: + last_error = e + payload = None + + if payload is None: + try: + form = await request.form + payload = form.to_dict() + except Exception as e: + last_error = e + payload = None + + if payload is None: + if last_error is not None: + raise last_error + raise ValueError("No JSON body or form data found in request.") + + if isinstance(payload, dict): + return payload or {} + + if isinstance(payload, str): + raise AttributeError("'str' object has no attribute 'get'") + + raise TypeError(f"Unsupported request payload type: {type(payload)!r}") + +async def get_request_json(): + return await _coerce_request_data() def serialize_for_json(obj): """ @@ -137,7 +167,7 @@ def validate_request(*args, **kwargs): def wrapper(func): @wraps(func) async def decorated_function(*_args, **_kwargs): - errs = process_args(await request.json or (await request.form).to_dict()) + errs = process_args(await _coerce_request_data()) if errs: return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs) if inspect.iscoroutinefunction(func): @@ -152,7 +182,7 @@ def validate_request(*args, **kwargs): def not_allowed_parameters(*params): def decorator(func): async def wrapper(*args, **kwargs): - input_arguments = await request.json or (await request.form).to_dict() + input_arguments = await _coerce_request_data() for param in params: if param in input_arguments: return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed") diff --git a/common/http_client.py b/common/http_client.py new file mode 100644 index 000000000..2ffbb3bce --- /dev/null +++ b/common/http_client.py @@ -0,0 +1,157 @@ +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import os +import time +from typing import Any, Dict, Optional + +import httpx + +logger = logging.getLogger(__name__) + +# Default knobs; keep conservative to avoid unexpected behavioural changes. +DEFAULT_TIMEOUT = float(os.environ.get("HTTP_CLIENT_TIMEOUT", "15")) +# Align with requests default: follow redirects with a max of 30 unless overridden. +DEFAULT_FOLLOW_REDIRECTS = bool(int(os.environ.get("HTTP_CLIENT_FOLLOW_REDIRECTS", "1"))) +DEFAULT_MAX_REDIRECTS = int(os.environ.get("HTTP_CLIENT_MAX_REDIRECTS", "30")) +DEFAULT_MAX_RETRIES = int(os.environ.get("HTTP_CLIENT_MAX_RETRIES", "2")) +DEFAULT_BACKOFF_FACTOR = float(os.environ.get("HTTP_CLIENT_BACKOFF_FACTOR", "0.5")) +DEFAULT_PROXY = os.environ.get("HTTP_CLIENT_PROXY") +DEFAULT_USER_AGENT = os.environ.get("HTTP_CLIENT_USER_AGENT", "ragflow-http-client") + + +def _clean_headers(headers: Optional[Dict[str, str]], auth_token: Optional[str] = None) -> Optional[Dict[str, str]]: + merged_headers: Dict[str, str] = {} + if DEFAULT_USER_AGENT: + merged_headers["User-Agent"] = DEFAULT_USER_AGENT + if auth_token: + merged_headers["Authorization"] = auth_token + if headers is None: + return merged_headers or None + merged_headers.update({str(k): str(v) for k, v in headers.items() if v is not None}) + return merged_headers or None + + +def _get_delay(backoff_factor: float, attempt: int) -> float: + return backoff_factor * (2**attempt) + + +async def async_request( + method: str, + url: str, + *, + timeout: float | httpx.Timeout | None = None, + follow_redirects: bool | None = None, + max_redirects: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + auth_token: Optional[str] = None, + retries: Optional[int] = None, + backoff_factor: Optional[float] = None, + proxies: Any = None, + **kwargs: Any, +) -> httpx.Response: + """Lightweight async HTTP wrapper using httpx.AsyncClient with safe defaults.""" + timeout = timeout if timeout is not None else DEFAULT_TIMEOUT + follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects + max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects + retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0) + backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor + headers = _clean_headers(headers, auth_token=auth_token) + proxies = DEFAULT_PROXY if proxies is None else proxies + + async with httpx.AsyncClient( + timeout=timeout, + follow_redirects=follow_redirects, + max_redirects=max_redirects, + proxies=proxies, + ) as client: + last_exc: Exception | None = None + for attempt in range(retries + 1): + try: + start = time.monotonic() + response = await client.request(method=method, url=url, headers=headers, **kwargs) + duration = time.monotonic() - start + logger.debug(f"async_request {method} {url} -> {response.status_code} in {duration:.3f}s") + return response + except httpx.RequestError as exc: + last_exc = exc + if attempt >= retries: + logger.warning(f"async_request exhausted retries for {method} {url}: {exc}") + raise + delay = _get_delay(backoff_factor, attempt) + logger.warning(f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s") + await asyncio.sleep(delay) + raise last_exc # pragma: no cover + + +def sync_request( + method: str, + url: str, + *, + timeout: float | httpx.Timeout | None = None, + follow_redirects: bool | None = None, + max_redirects: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + auth_token: Optional[str] = None, + retries: Optional[int] = None, + backoff_factor: Optional[float] = None, + proxies: Any = None, + **kwargs: Any, +) -> httpx.Response: + """Synchronous counterpart to async_request, for CLI/tests or sync contexts.""" + timeout = timeout if timeout is not None else DEFAULT_TIMEOUT + follow_redirects = DEFAULT_FOLLOW_REDIRECTS if follow_redirects is None else follow_redirects + max_redirects = DEFAULT_MAX_REDIRECTS if max_redirects is None else max_redirects + retries = DEFAULT_MAX_RETRIES if retries is None else max(retries, 0) + backoff_factor = DEFAULT_BACKOFF_FACTOR if backoff_factor is None else backoff_factor + headers = _clean_headers(headers, auth_token=auth_token) + proxies = DEFAULT_PROXY if proxies is None else proxies + + with httpx.Client( + timeout=timeout, + follow_redirects=follow_redirects, + max_redirects=max_redirects, + proxies=proxies, + ) as client: + last_exc: Exception | None = None + for attempt in range(retries + 1): + try: + start = time.monotonic() + response = client.request(method=method, url=url, headers=headers, **kwargs) + duration = time.monotonic() - start + logger.debug(f"sync_request {method} {url} -> {response.status_code} in {duration:.3f}s") + return response + except httpx.RequestError as exc: + last_exc = exc + if attempt >= retries: + logger.warning(f"sync_request exhausted retries for {method} {url}: {exc}") + raise + delay = _get_delay(backoff_factor, attempt) + logger.warning(f"sync_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: {exc}; retrying in {delay:.2f}s") + time.sleep(delay) + raise last_exc # pragma: no cover + + +__all__ = [ + "async_request", + "sync_request", + "DEFAULT_TIMEOUT", + "DEFAULT_FOLLOW_REDIRECTS", + "DEFAULT_MAX_REDIRECTS", + "DEFAULT_MAX_RETRIES", + "DEFAULT_BACKOFF_FACTOR", + "DEFAULT_PROXY", + "DEFAULT_USER_AGENT", +] diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 897fec65f..1913646a2 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -50,6 +50,7 @@ class SupportedLiteLLMProvider(StrEnum): GiteeAI = "GiteeAI" AI_302 = "302.AI" JiekouAI = "Jiekou.AI" + ZHIPU_AI = "ZHIPU-AI" FACTORY_DEFAULT_BASE_URL = { @@ -71,6 +72,7 @@ FACTORY_DEFAULT_BASE_URL = { SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1", SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/", SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai", + SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4", } @@ -102,6 +104,7 @@ LITELLM_PROVIDER_PREFIX = { SupportedLiteLLMProvider.GiteeAI: "openai/", SupportedLiteLLMProvider.AI_302: "openai/", SupportedLiteLLMProvider.JiekouAI: "openai/", + SupportedLiteLLMProvider.ZHIPU_AI: "openai/", } ChatModel = globals().get("ChatModel", {}) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 726aecd8b..1f38292ba 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -19,6 +19,7 @@ import logging import os import random import re +import threading import time from abc import ABC from copy import deepcopy @@ -28,10 +29,9 @@ import json_repair import litellm import openai import requests -from openai import OpenAI +from openai import AsyncOpenAI, OpenAI from openai.lib.azure import AzureOpenAI from strenum import StrEnum -from zhipuai import ZhipuAI from common.token_utils import num_tokens_from_string, total_token_count_from_response from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider @@ -68,6 +68,7 @@ class Base(ABC): def __init__(self, key, model_name, base_url, **kwargs): timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600)) self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) + self.async_client = AsyncOpenAI(api_key=key, base_url=base_url, timeout=timeout) self.model_name = model_name # Configure retry parameters self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) @@ -139,6 +140,23 @@ class Base(ABC): return gen_conf + def _bridge_sync_stream(self, gen): + """Run a sync generator in a thread and yield asynchronously.""" + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + + def worker(): + try: + for item in gen: + loop.call_soon_threadsafe(queue.put_nowait, item) + except Exception as exc: # pragma: no cover - defensive + loop.call_soon_threadsafe(queue.put_nowait, exc) + finally: + loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration) + + threading.Thread(target=worker, daemon=True).start() + return queue + def _chat(self, history, gen_conf, **kwargs): logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) if self.model_name.lower().find("qwq") >= 0: @@ -204,6 +222,60 @@ class Base(ABC): ans += LENGTH_NOTIFICATION_EN yield ans, tol + async def _async_chat_stream(self, history, gen_conf, **kwargs): + logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) + reasoning_start = False + + request_kwargs = {"model": self.model_name, "messages": history, "stream": True, **gen_conf} + stop = kwargs.get("stop") + if stop: + request_kwargs["stop"] = stop + + response = await self.async_client.chat.completions.create(**request_kwargs) + + async for resp in response: + if not resp.choices: + continue + if not resp.choices[0].delta.content: + resp.choices[0].delta.content = "" + if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: + ans = "" + if not reasoning_start: + reasoning_start = True + ans = "" + ans += resp.choices[0].delta.reasoning_content + "" + else: + reasoning_start = False + ans = resp.choices[0].delta.content + + tol = total_token_count_from_response(resp) + if not tol: + tol = num_tokens_from_string(resp.choices[0].delta.content) + + finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" + if finish_reason == "length": + if is_chinese(ans): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + yield ans, tol + + async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): + if system and history and history[0].get("role") != "system": + history.insert(0, {"role": "system", "content": system}) + gen_conf = self._clean_conf(gen_conf) + ans = "" + total_tokens = 0 + try: + async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs): + ans = delta_ans + total_tokens += tol + yield delta_ans + except openai.APIError as e: + yield ans + "\n**ERROR**: " + str(e) + + yield total_tokens + def _length_stop(self, ans): if is_chinese([ans]): return ans + LENGTH_NOTIFICATION_CN @@ -232,7 +304,25 @@ class Base(ABC): time.sleep(delay) return None - return f"{ERROR_PREFIX}: {error_code} - {str(e)}" + msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}" + logging.error(f"sync base giving up: {msg}") + return msg + + async def _exceptions_async(self, e, attempt) -> str | None: + logging.exception("OpenAI async completion") + error_code = self._classify_error(e) + if attempt == self.max_retries: + error_code = LLMErrorCode.ERROR_MAX_RETRIES + + if self._should_retry(error_code): + delay = self._get_delay() + logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") + await asyncio.sleep(delay) + return None + + msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}" + logging.error(f"async base giving up: {msg}") + return msg def _verbose_tool_use(self, name, args, res): return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "" @@ -323,6 +413,60 @@ class Base(ABC): assert False, "Shouldn't be here." + async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): + gen_conf = self._clean_conf(gen_conf) + if system and history and history[0].get("role") != "system": + history.insert(0, {"role": "system", "content": system}) + + ans = "" + tk_count = 0 + hist = deepcopy(history) + for attempt in range(self.max_retries + 1): + history = deepcopy(hist) + try: + for _ in range(self.max_rounds + 1): + logging.info(f"{self.tools=}") + response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf) + tk_count += total_token_count_from_response(response) + if any([not response.choices, not response.choices[0].message]): + raise Exception(f"500 response structure error. Response: {response}") + + if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls: + if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content: + ans += "" + response.choices[0].message.reasoning_content + "" + + ans += response.choices[0].message.content + if response.choices[0].finish_reason == "length": + ans = self._length_stop(ans) + + return ans, tk_count + + for tool_call in response.choices[0].message.tool_calls: + logging.info(f"Response {tool_call=}") + name = tool_call.function.name + try: + args = json_repair.loads(tool_call.function.arguments) + tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args) + history = self._append_history(history, tool_call, tool_response) + ans += self._verbose_tool_use(name, args, tool_response) + except Exception as e: + logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") + history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) + ans += self._verbose_tool_use(name, {}, str(e)) + + logging.warning(f"Exceed max rounds: {self.max_rounds}") + history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) + response, token_count = await self._async_chat(history, gen_conf) + ans += response + tk_count += token_count + return ans, tk_count + except Exception as e: + e = await self._exceptions_async(e, attempt) + if e: + return e, tk_count + + assert False, "Shouldn't be here." + def chat(self, system, history, gen_conf={}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) @@ -457,6 +601,160 @@ class Base(ABC): assert False, "Shouldn't be here." + async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): + gen_conf = self._clean_conf(gen_conf) + tools = self.tools + if system and history and history[0].get("role") != "system": + history.insert(0, {"role": "system", "content": system}) + + total_tokens = 0 + hist = deepcopy(history) + + for attempt in range(self.max_retries + 1): + history = deepcopy(hist) + try: + for _ in range(self.max_rounds + 1): + reasoning_start = False + logging.info(f"{tools=}") + + response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf) + + final_tool_calls = {} + answer = "" + + async for resp in response: + if not hasattr(resp, "choices") or not resp.choices: + continue + + delta = resp.choices[0].delta + + if hasattr(delta, "tool_calls") and delta.tool_calls: + for tool_call in delta.tool_calls: + index = tool_call.index + if index not in final_tool_calls: + if not tool_call.function.arguments: + tool_call.function.arguments = "" + final_tool_calls[index] = tool_call + else: + final_tool_calls[index].function.arguments += tool_call.function.arguments or "" + continue + + if not hasattr(delta, "content") or delta.content is None: + delta.content = "" + + if hasattr(delta, "reasoning_content") and delta.reasoning_content: + ans = "" + if not reasoning_start: + reasoning_start = True + ans = "" + ans += delta.reasoning_content + "" + yield ans + else: + reasoning_start = False + answer += delta.content + yield delta.content + + tol = total_token_count_from_response(resp) + if not tol: + total_tokens += num_tokens_from_string(delta.content) + else: + total_tokens = tol + + finish_reason = getattr(resp.choices[0], "finish_reason", "") + if finish_reason == "length": + yield self._length_stop("") + + if answer: + yield total_tokens + return + + for tool_call in final_tool_calls.values(): + name = tool_call.function.name + try: + args = json_repair.loads(tool_call.function.arguments) + yield self._verbose_tool_use(name, args, "Begin to call...") + tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args) + history = self._append_history(history, tool_call, tool_response) + yield self._verbose_tool_use(name, args, tool_response) + except Exception as e: + logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") + history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) + yield self._verbose_tool_use(name, {}, str(e)) + + logging.warning(f"Exceed max rounds: {self.max_rounds}") + history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) + + response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf) + + async for resp in response: + if not hasattr(resp, "choices") or not resp.choices: + continue + delta = resp.choices[0].delta + if not hasattr(delta, "content") or delta.content is None: + continue + tol = total_token_count_from_response(resp) + if not tol: + total_tokens += num_tokens_from_string(delta.content) + else: + total_tokens = tol + yield delta.content + + yield total_tokens + return + + except Exception as e: + e = await self._exceptions_async(e, attempt) + if e: + logging.error(f"async_chat_streamly failed: {e}") + yield e + yield total_tokens + return + + assert False, "Shouldn't be here." + + async def _async_chat(self, history, gen_conf, **kwargs): + logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) + if self.model_name.lower().find("qwq") >= 0: + logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly") + final_ans = "" + tol_token = 0 + async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs): + if delta.startswith("") or delta.endswith(""): + continue + final_ans += delta + tol_token = tol + + if len(final_ans.strip()) == 0: + final_ans = "**ERROR**: Empty response from reasoning model" + + return final_ans.strip(), tol_token + + if self.model_name.lower().find("qwen3") >= 0: + kwargs["extra_body"] = {"enable_thinking": False} + + response = await self.async_client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs) + + if not response.choices or not response.choices[0].message or not response.choices[0].message.content: + return "", 0 + ans = response.choices[0].message.content.strip() + if response.choices[0].finish_reason == "length": + ans = self._length_stop(ans) + return ans, total_token_count_from_response(response) + + async def async_chat(self, system, history, gen_conf={}, **kwargs): + if system and history and history[0].get("role") != "system": + history.insert(0, {"role": "system", "content": system}) + gen_conf = self._clean_conf(gen_conf) + + for attempt in range(self.max_retries + 1): + try: + return await self._async_chat(history, gen_conf, **kwargs) + except Exception as e: + e = await self._exceptions_async(e, attempt) + if e: + return e, 0 + assert False, "Shouldn't be here." + def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) @@ -642,66 +940,6 @@ class BaiChuanChat(Base): yield total_tokens -class ZhipuChat(Base): - _FACTORY_NAME = "ZHIPU-AI" - - def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs): - super().__init__(key, model_name, base_url=base_url, **kwargs) - - self.client = ZhipuAI(api_key=key) - self.model_name = model_name - - def _clean_conf(self, gen_conf): - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - gen_conf = self._clean_conf_plealty(gen_conf) - return gen_conf - - def _clean_conf_plealty(self, gen_conf): - if "presence_penalty" in gen_conf: - del gen_conf["presence_penalty"] - if "frequency_penalty" in gen_conf: - del gen_conf["frequency_penalty"] - return gen_conf - - def chat_with_tools(self, system: str, history: list, gen_conf: dict): - gen_conf = self._clean_conf_plealty(gen_conf) - - return super().chat_with_tools(system, history, gen_conf) - - def chat_streamly(self, system, history, gen_conf={}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - gen_conf = self._clean_conf(gen_conf) - ans = "" - tk_count = 0 - try: - logging.info(json.dumps(history, ensure_ascii=False, indent=2)) - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf) - for resp in response: - if not resp.choices[0].delta.content: - continue - delta = resp.choices[0].delta.content - ans = delta - if resp.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - tk_count = total_token_count_from_response(resp) - if resp.choices[0].finish_reason == "stop": - tk_count = total_token_count_from_response(resp) - yield ans - except Exception as e: - yield ans + "\n**ERROR**: " + str(e) - - yield tk_count - - def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict): - gen_conf = self._clean_conf_plealty(gen_conf) - return super().chat_streamly_with_tools(system, history, gen_conf) - - class LocalAIChat(Base): _FACTORY_NAME = "LocalAI" @@ -1403,6 +1641,7 @@ class LiteLLMBase(ABC): "GiteeAI", "302.AI", "Jiekou.AI", + "ZHIPU-AI", ] def __init__(self, key, model_name, base_url=None, **kwargs): @@ -1482,6 +1721,7 @@ class LiteLLMBase(ABC): def _chat_streamly(self, history, gen_conf, **kwargs): logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) + gen_conf = self._clean_conf(gen_conf) reasoning_start = False completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) @@ -1525,6 +1765,96 @@ class LiteLLMBase(ABC): yield ans, tol + async def async_chat(self, history, gen_conf, **kwargs): + logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) + if self.model_name.lower().find("qwen3") >= 0: + kwargs["extra_body"] = {"enable_thinking": False} + + completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf) + + for attempt in range(self.max_retries + 1): + try: + response = await litellm.acompletion( + **completion_args, + drop_params=True, + timeout=self.timeout, + ) + + if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): + return "", 0 + ans = response.choices[0].message.content.strip() + if response.choices[0].finish_reason == "length": + ans = self._length_stop(ans) + + return ans, total_token_count_from_response(response) + except Exception as e: + e = await self._exceptions_async(e, attempt) + if e: + return e, 0 + + assert False, "Shouldn't be here." + + async def async_chat_streamly(self, system, history, gen_conf, **kwargs): + if system and history and history[0].get("role") != "system": + history.insert(0, {"role": "system", "content": system}) + logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) + gen_conf = self._clean_conf(gen_conf) + reasoning_start = False + total_tokens = 0 + + completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) + stop = kwargs.get("stop") + if stop: + completion_args["stop"] = stop + + for attempt in range(self.max_retries + 1): + try: + stream = await litellm.acompletion( + **completion_args, + drop_params=True, + timeout=self.timeout, + ) + + async for resp in stream: + if not hasattr(resp, "choices") or not resp.choices: + continue + + delta = resp.choices[0].delta + if not hasattr(delta, "content") or delta.content is None: + delta.content = "" + + if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content: + ans = "" + if not reasoning_start: + reasoning_start = True + ans = "" + ans += delta.reasoning_content + "" + else: + reasoning_start = False + ans = delta.content + + tol = total_token_count_from_response(resp) + if not tol: + tol = num_tokens_from_string(delta.content) + total_tokens += tol + + finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" + if finish_reason == "length": + if is_chinese(ans): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + + yield ans + yield total_tokens + return + except Exception as e: + e = await self._exceptions_async(e, attempt) + if e: + yield e + yield total_tokens + return + def _length_stop(self, ans): if is_chinese([ans]): return ans + LENGTH_NOTIFICATION_CN @@ -1555,6 +1885,21 @@ class LiteLLMBase(ABC): return f"{ERROR_PREFIX}: {error_code} - {str(e)}" + async def _exceptions_async(self, e, attempt) -> str | None: + logging.exception("LiteLLMBase async completion") + error_code = self._classify_error(e) + if attempt == self.max_retries: + error_code = LLMErrorCode.ERROR_MAX_RETRIES + + if self._should_retry(error_code): + delay = self._get_delay() + logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") + await asyncio.sleep(delay) + return None + msg = f"{ERROR_PREFIX}: {error_code} - {str(e)}" + logging.error(f"async_chat_streamly giving up: {msg}") + return msg + def _verbose_tool_use(self, name, args, res): return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "" From 41cff3e09e03a2a141b96a1b1067de60222aa27f Mon Sep 17 00:00:00 2001 From: Billy Bao Date: Mon, 1 Dec 2025 14:24:35 +0800 Subject: [PATCH 08/13] Fix: jina embedding issue (#11628) ### What problem does this PR solve? Fix: jina embedding issue #11614 Feat: Add jina embedding v4 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- conf/llm_factories.json | 6 +++++ rag/llm/embedding_model.py | 50 +++++++++++++++----------------------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/conf/llm_factories.json b/conf/llm_factories.json index d3b2dcc1c..3c84bd03d 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -1194,6 +1194,12 @@ "tags": "TEXT EMBEDDING", "max_tokens": 8196, "model_type": "embedding" + }, + { + "llm_name": "jina-embeddings-v4", + "tags": "TEXT EMBEDDING", + "max_tokens": 32768, + "model_type": "embedding" } ] }, diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 7f2f9ee7d..445ecab5a 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -349,35 +349,6 @@ class YoudaoEmbed(Base): return np.array(embds[0]), num_tokens_from_string(text) -class JinaEmbed(Base): - _FACTORY_NAME = "Jina" - - def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"): - self.base_url = "https://api.jina.ai/v1/embeddings" - self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} - self.model_name = model_name - - def encode(self, texts: list): - texts = [truncate(t, 8196) for t in texts] - batch_size = 16 - ress = [] - token_count = 0 - for i in range(0, len(texts), batch_size): - data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"} - response = requests.post(self.base_url, headers=self.headers, json=data) - try: - res = response.json() - ress.extend([d["embedding"] for d in res["data"]]) - token_count += self.total_token_count(res) - except Exception as _e: - log_exception(_e, response) - return np.array(ress), token_count - - def encode_queries(self, text): - embds, cnt = self.encode([text]) - return np.array(embds[0]), cnt - - class JinaMultiVecEmbed(Base): _FACTORY_NAME = "Jina" @@ -403,11 +374,28 @@ class JinaMultiVecEmbed(Base): img_b64s = base64.b64encode(text).decode('utf8') input.append({"image": img_b64s}) # base64 encoded image for i in range(0, len(texts), batch_size): - data = {"model": self.model_name, "task": task, "truncate": True, "return_multivector": True, "input": input[i : i + batch_size]} + data = {"model": self.model_name, "input": input[i : i + batch_size]} + if "v4" in self.model_name: + data["return_multivector"] = True + + if "v3" in self.model_name or "v4" in self.model_name: + data['task'] = task + data['truncate'] = True + response = requests.post(self.base_url, headers=self.headers, json=data) try: res = response.json() - ress.extend([d["embeddings"] for d in res["data"]]) + for d in res['data']: + if data.get("return_multivector", False): # v4 + token_embs = np.asarray(d['embeddings'], dtype=np.float32) + chunk_emb = token_embs.mean(axis=0) + + else: + # v2/v3 + chunk_emb = np.asarray(d['embedding'], dtype=np.float32) + + ress.append(chunk_emb) + token_count += self.total_token_count(res) except Exception as _e: log_exception(_e, response) From 21d8ffca5651bb2ecd823b5bb3ec1820febf617d Mon Sep 17 00:00:00 2001 From: Zhichang Yu Date: Mon, 1 Dec 2025 14:58:33 +0800 Subject: [PATCH 09/13] Fix workflows --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a5a44a29c..39c526104 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -31,7 +31,7 @@ jobs: name: ragflow_tests # https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution # https://github.com/orgs/community/discussions/26261 - if: ${{ github.event_name != 'pull_request_target' || (contains(github.event.pull_request.labels.*.name, 'ci') && github.event.pull_request.mergeable != false) }} + if: ${{ github.event_name != 'pull_request_target' || contains(github.event.pull_request.labels.*.name, 'ci') }} runs-on: [ "self-hosted", "ragflow-test" ] steps: # https://github.com/hmarr/debug-action From 221947acc410ff162599e9fd73006878db6934a3 Mon Sep 17 00:00:00 2001 From: Zhichang Yu Date: Mon, 1 Dec 2025 15:33:07 +0800 Subject: [PATCH 10/13] Fix workflows --- .github/workflows/tests.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 39c526104..a61492238 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,7 +12,7 @@ on: # The only difference between pull_request and pull_request_target is the context in which the workflow runs: # — pull_request_target workflows use the workflow files from the default branch, and secrets are available. # — pull_request workflows use the workflow files from the pull request branch, and secrets are unavailable. - pull_request_target: + pull_request: types: [ synchronize, ready_for_review ] paths-ignore: - 'docs/**' @@ -31,7 +31,7 @@ jobs: name: ragflow_tests # https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution # https://github.com/orgs/community/discussions/26261 - if: ${{ github.event_name != 'pull_request_target' || contains(github.event.pull_request.labels.*.name, 'ci') }} + if: ${{ github.event_name != 'pull_request' || (github.event.pull_request.draft == false && contains(github.event.pull_request.labels.*.name, 'ci')) }} runs-on: [ "self-hosted", "ragflow-test" ] steps: # https://github.com/hmarr/debug-action @@ -53,7 +53,7 @@ jobs: - name: Check workflow duplication if: ${{ !cancelled() && !failure() }} run: | - if [[ ${GITHUB_EVENT_NAME} != "pull_request_target" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then + if [[ ${GITHUB_EVENT_NAME} != "pull_request" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then HEAD=$(git rev-parse HEAD) # Find a PR that introduced a given commit gh auth login --with-token <<< "${{ secrets.GITHUB_TOKEN }}" @@ -78,7 +78,7 @@ jobs: fi fi fi - elif [[ ${GITHUB_EVENT_NAME} == "pull_request_target" ]]; then + elif [[ ${GITHUB_EVENT_NAME} == "pull_request" ]]; then PR_NUMBER=${{ github.event.pull_request.number }} PR_SHA_FP=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/PR_${PR_NUMBER} # Calculate the hash of the current workspace content @@ -98,7 +98,7 @@ jobs: - name: Check comments of changed Python files if: ${{ false }} run: | - if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then + if [[ ${{ github.event_name }} == 'pull_request' || ${{ github.event_name }} == 'pull_request_target' ]]; then CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \ | grep -E '\.(py)$' || true) From 1120575021e583546b8977e7c94303dcf44b9d32 Mon Sep 17 00:00:00 2001 From: balibabu Date: Mon, 1 Dec 2025 16:29:02 +0800 Subject: [PATCH 11/13] Feat: Files uploaded via the dialog box can be uploaded without binding to a dataset. #9590 (#11630) ### What problem does this PR solve? Feat: Files uploaded via the dialog box can be uploaded without binding to a dataset. #9590 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- web/src/components/message-input/index.tsx | 372 ------------------ web/src/components/message-item/index.tsx | 39 +- .../uploaded-message-files.tsx | 12 +- web/src/hooks/document-hooks.ts | 34 -- web/src/hooks/use-chat-request.ts | 2 + web/src/interfaces/database/chat.ts | 13 +- .../next-chats/hooks/use-send-chat-message.ts | 12 +- .../hooks/use-send-multiple-message.ts | 12 +- .../pages/next-chats/hooks/use-upload-file.ts | 23 +- web/src/services/knowledge-service.ts | 5 - web/src/utils/api.ts | 2 +- 11 files changed, 60 insertions(+), 466 deletions(-) delete mode 100644 web/src/components/message-input/index.tsx diff --git a/web/src/components/message-input/index.tsx b/web/src/components/message-input/index.tsx deleted file mode 100644 index 95a4ee195..000000000 --- a/web/src/components/message-input/index.tsx +++ /dev/null @@ -1,372 +0,0 @@ -import { useTranslate } from '@/hooks/common-hooks'; -import { - useDeleteDocument, - useFetchDocumentInfosByIds, - useRemoveNextDocument, - useUploadAndParseDocument, -} from '@/hooks/document-hooks'; -import { cn } from '@/lib/utils'; -import { getExtension } from '@/utils/document-util'; -import { formatBytes } from '@/utils/file-util'; -import { - CloseCircleOutlined, - InfoCircleOutlined, - LoadingOutlined, -} from '@ant-design/icons'; -import type { GetProp, UploadFile } from 'antd'; -import { - Button, - Card, - Divider, - Flex, - Input, - List, - Space, - Spin, - Typography, - Upload, - UploadProps, -} from 'antd'; -import get from 'lodash/get'; -import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react'; -import { - ChangeEventHandler, - memo, - useCallback, - useEffect, - useRef, - useState, -} from 'react'; -import FileIcon from '../file-icon'; -import styles from './index.less'; - -type FileType = Parameters>[0]; -const { Text } = Typography; - -const { TextArea } = Input; - -const getFileId = (file: UploadFile) => get(file, 'response.data.0'); - -const getFileIds = (fileList: UploadFile[]) => { - const ids = fileList.reduce((pre, cur) => { - return pre.concat(get(cur, 'response.data', [])); - }, []); - - return ids; -}; - -const isUploadSuccess = (file: UploadFile) => { - const code = get(file, 'response.code'); - return typeof code === 'number' && code === 0; -}; - -interface IProps { - disabled: boolean; - value: string; - sendDisabled: boolean; - sendLoading: boolean; - onPressEnter(documentIds: string[]): void; - onInputChange: ChangeEventHandler; - conversationId: string; - uploadMethod?: string; - isShared?: boolean; - showUploadIcon?: boolean; - createConversationBeforeUploadDocument?(message: string): Promise; - stopOutputMessage?(): void; -} - -const getBase64 = (file: FileType): Promise => - new Promise((resolve, reject) => { - const reader = new FileReader(); - reader.readAsDataURL(file as any); - reader.onload = () => resolve(reader.result as string); - reader.onerror = (error) => reject(error); - }); - -const MessageInput = ({ - isShared = false, - disabled, - value, - onPressEnter, - sendDisabled, - sendLoading, - onInputChange, - conversationId, - showUploadIcon = true, - createConversationBeforeUploadDocument, - uploadMethod = 'upload_and_parse', - stopOutputMessage, -}: IProps) => { - const { t } = useTranslate('chat'); - const { removeDocument } = useRemoveNextDocument(); - const { deleteDocument } = useDeleteDocument(); - const { data: documentInfos, setDocumentIds } = useFetchDocumentInfosByIds(); - const { uploadAndParseDocument } = useUploadAndParseDocument(uploadMethod); - const conversationIdRef = useRef(conversationId); - - const [fileList, setFileList] = useState([]); - - const handlePreview = async (file: UploadFile) => { - if (!file.url && !file.preview) { - file.preview = await getBase64(file.originFileObj as FileType); - } - }; - - const handleChange: UploadProps['onChange'] = async ({ - // fileList: newFileList, - file, - }) => { - let nextConversationId: string = conversationId; - if (createConversationBeforeUploadDocument) { - const creatingRet = await createConversationBeforeUploadDocument( - file.name, - ); - if (creatingRet?.code === 0) { - nextConversationId = creatingRet.data.id; - } - } - setFileList((list) => { - list.push({ - ...file, - status: 'uploading', - originFileObj: file as any, - }); - return [...list]; - }); - const ret = await uploadAndParseDocument({ - conversationId: nextConversationId, - fileList: [file], - }); - setFileList((list) => { - const nextList = list.filter((x) => x.uid !== file.uid); - nextList.push({ - ...file, - originFileObj: file as any, - response: ret, - percent: 100, - status: ret?.code === 0 ? 'done' : 'error', - }); - return nextList; - }); - }; - - const isUploadingFile = fileList.some((x) => x.status === 'uploading'); - - const handlePressEnter = useCallback(async () => { - if (isUploadingFile) return; - const ids = getFileIds(fileList.filter((x) => isUploadSuccess(x))); - - onPressEnter(ids); - setFileList([]); - }, [fileList, onPressEnter, isUploadingFile]); - - const handleKeyDown = useCallback( - async (event: React.KeyboardEvent) => { - // check if it was shift + enter - if (event.key === 'Enter' && event.shiftKey) return; - if (event.key !== 'Enter') return; - if (sendDisabled || isUploadingFile || sendLoading) return; - - event.preventDefault(); - handlePressEnter(); - }, - [sendDisabled, isUploadingFile, sendLoading, handlePressEnter], - ); - - const handleRemove = useCallback( - async (file: UploadFile) => { - const ids = get(file, 'response.data', []); - // Upload Successfully - if (Array.isArray(ids) && ids.length) { - if (isShared) { - await deleteDocument(ids); - } else { - await removeDocument(ids[0]); - } - setFileList((preList) => { - return preList.filter((x) => getFileId(x) !== ids[0]); - }); - } else { - // Upload failed - setFileList((preList) => { - return preList.filter((x) => x.uid !== file.uid); - }); - } - }, - [removeDocument, deleteDocument, isShared], - ); - - const handleStopOutputMessage = useCallback(() => { - stopOutputMessage?.(); - }, [stopOutputMessage]); - - const getDocumentInfoById = useCallback( - (id: string) => { - return documentInfos.find((x) => x.id === id); - }, - [documentInfos], - ); - - useEffect(() => { - const ids = getFileIds(fileList); - setDocumentIds(ids); - }, [fileList, setDocumentIds]); - - useEffect(() => { - if ( - conversationIdRef.current && - conversationId !== conversationIdRef.current - ) { - setFileList([]); - } - conversationIdRef.current = conversationId; - }, [conversationId, setFileList]); - - return ( - -