From 5b5f19cbc1be926ade1499a3776ea4fb3b50e13f Mon Sep 17 00:00:00 2001 From: balibabu Date: Fri, 5 Dec 2025 18:04:49 +0800 Subject: [PATCH 01/12] Fix: Newly added models to OpenAI-API-Compatible are not displayed in the LLM dropdown menu in a timely manner. #11774 (#11775) ### What problem does this PR solve? Fix: Newly added models to OpenAI-API-Compatible are not displayed in the LLM dropdown menu in a timely manner. #11774 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- web/src/hooks/use-llm-request.tsx | 70 ++++++++++++++++++++----------- 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/web/src/hooks/use-llm-request.tsx b/web/src/hooks/use-llm-request.tsx index 260cd47f9..3436b7506 100644 --- a/web/src/hooks/use-llm-request.tsx +++ b/web/src/hooks/use-llm-request.tsx @@ -24,6 +24,15 @@ import { buildLlmUuid } from '@/utils/llm-util'; export const enum LLMApiAction { LlmList = 'llmList', + MyLlmList = 'myLlmList', + MyLlmListDetailed = 'myLlmListDetailed', + FactoryList = 'factoryList', + SaveApiKey = 'saveApiKey', + SaveTenantInfo = 'saveTenantInfo', + AddLlm = 'addLlm', + DeleteLlm = 'deleteLlm', + EnableLlm = 'enableLlm', + DeleteFactory = 'deleteFactory', } export const useFetchLlmList = (modelType?: LlmModelType) => { @@ -177,7 +186,7 @@ export const useComposeLlmOptionsByModelTypes = ( export const useFetchLlmFactoryList = (): ResponseGetType => { const { data, isFetching: loading } = useQuery({ - queryKey: ['factoryList'], + queryKey: [LLMApiAction.FactoryList], initialData: [], gcTime: 0, queryFn: async () => { @@ -196,7 +205,7 @@ export const useFetchMyLlmList = (): ResponseGetType< Record > => { const { data, isFetching: loading } = useQuery({ - queryKey: ['myLlmList'], + queryKey: [LLMApiAction.MyLlmList], initialData: {}, gcTime: 0, queryFn: async () => { @@ -213,7 +222,7 @@ export const useFetchMyLlmListDetailed = (): ResponseGetType< Record > => { const { data, isFetching: loading } = useQuery({ - queryKey: ['myLlmListDetailed'], + queryKey: [LLMApiAction.MyLlmListDetailed], initialData: {}, gcTime: 0, queryFn: async () => { @@ -271,14 +280,16 @@ export const useSaveApiKey = () => { isPending: loading, mutateAsync, } = useMutation({ - mutationKey: ['saveApiKey'], + mutationKey: [LLMApiAction.SaveApiKey], mutationFn: async (params: IApiKeySavingParams) => { const { data } = await userService.set_api_key(params); if (data.code === 0) { message.success(t('message.modified')); - queryClient.invalidateQueries({ queryKey: ['myLlmList'] }); - queryClient.invalidateQueries({ queryKey: ['myLlmListDetailed'] }); - queryClient.invalidateQueries({ queryKey: ['factoryList'] }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] }); + queryClient.invalidateQueries({ + queryKey: [LLMApiAction.MyLlmListDetailed], + }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] }); } return data.code; }, @@ -303,7 +314,7 @@ export const useSaveTenantInfo = () => { isPending: loading, mutateAsync, } = useMutation({ - mutationKey: ['saveTenantInfo'], + mutationKey: [LLMApiAction.SaveTenantInfo], mutationFn: async (params: ISystemModelSettingSavingParams) => { const { data } = await userService.set_tenant_info(params); if (data.code === 0) { @@ -324,13 +335,16 @@ export const useAddLlm = () => { isPending: loading, mutateAsync, } = useMutation({ - mutationKey: ['addLlm'], + mutationKey: [LLMApiAction.AddLlm], mutationFn: async (params: IAddLlmRequestBody) => { const { data } = await userService.add_llm(params); if (data.code === 0) { - queryClient.invalidateQueries({ queryKey: ['myLlmList'] }); - queryClient.invalidateQueries({ queryKey: ['myLlmListDetailed'] }); - queryClient.invalidateQueries({ queryKey: ['factoryList'] }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] }); + queryClient.invalidateQueries({ + queryKey: [LLMApiAction.MyLlmListDetailed], + }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.LlmList] }); message.success(t('message.modified')); } return data.code; @@ -348,13 +362,15 @@ export const useDeleteLlm = () => { isPending: loading, mutateAsync, } = useMutation({ - mutationKey: ['deleteLlm'], + mutationKey: [LLMApiAction.DeleteLlm], mutationFn: async (params: IDeleteLlmRequestBody) => { const { data } = await userService.delete_llm(params); if (data.code === 0) { - queryClient.invalidateQueries({ queryKey: ['myLlmList'] }); - queryClient.invalidateQueries({ queryKey: ['myLlmListDetailed'] }); - queryClient.invalidateQueries({ queryKey: ['factoryList'] }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] }); + queryClient.invalidateQueries({ + queryKey: [LLMApiAction.MyLlmListDetailed], + }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] }); message.success(t('message.deleted')); } return data.code; @@ -372,7 +388,7 @@ export const useEnableLlm = () => { isPending: loading, mutateAsync, } = useMutation({ - mutationKey: ['enableLlm'], + mutationKey: [LLMApiAction.EnableLlm], mutationFn: async (params: IDeleteLlmRequestBody & { enable: boolean }) => { const reqParam: IDeleteLlmRequestBody & { enable?: boolean; @@ -381,9 +397,11 @@ export const useEnableLlm = () => { delete reqParam.enable; const { data } = await userService.enable_llm(reqParam); if (data.code === 0) { - queryClient.invalidateQueries({ queryKey: ['myLlmList'] }); - queryClient.invalidateQueries({ queryKey: ['myLlmListDetailed'] }); - queryClient.invalidateQueries({ queryKey: ['factoryList'] }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] }); + queryClient.invalidateQueries({ + queryKey: [LLMApiAction.MyLlmListDetailed], + }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] }); message.success(t('message.modified')); } return data.code; @@ -401,14 +419,16 @@ export const useDeleteFactory = () => { isPending: loading, mutateAsync, } = useMutation({ - mutationKey: ['deleteFactory'], + mutationKey: [LLMApiAction.DeleteFactory], mutationFn: async (params: IDeleteLlmRequestBody) => { const { data } = await userService.deleteFactory(params); if (data.code === 0) { - queryClient.invalidateQueries({ queryKey: ['myLlmList'] }); - queryClient.invalidateQueries({ queryKey: ['myLlmListDetailed'] }); - queryClient.invalidateQueries({ queryKey: ['factoryList'] }); - queryClient.invalidateQueries({ queryKey: ['llmList'] }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.MyLlmList] }); + queryClient.invalidateQueries({ + queryKey: [LLMApiAction.MyLlmListDetailed], + }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.FactoryList] }); + queryClient.invalidateQueries({ queryKey: [LLMApiAction.LlmList] }); message.success(t('message.deleted')); } return data.code; From 15ef6dd72f99d26549500fb505bbc842effe4d06 Mon Sep 17 00:00:00 2001 From: Giles Lloyd <2910658+gileslloyd@users.noreply.github.com> Date: Fri, 5 Dec 2025 11:13:17 +0000 Subject: [PATCH 02/12] fix(mcp-server): Ensure all document meta-data is cached (#11767) ### What problem does this PR solve? The document metadata cache is built using the list documents endpoint with default pagination parameters of page=1, page_size=3. This means when using the MCP server to search a dataset, only chunks which come from the first 30 documents in the dataset will have metadata returned. Issue described in more detail here https://github.com/infiniflow/ragflow/issues/11533 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) Co-authored-by: Giles Lloyd --- mcp/server/server.py | 73 +++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 38 deletions(-) diff --git a/mcp/server/server.py b/mcp/server/server.py index 8d0d12c25..8350b184b 100644 --- a/mcp/server/server.py +++ b/mcp/server/server.py @@ -57,7 +57,6 @@ JSON_RESPONSE = True class RAGFlowConnector: _MAX_DATASET_CACHE = 32 - _MAX_DOCUMENT_CACHE = 128 _CACHE_TTL = 300 _dataset_metadata_cache: OrderedDict[str, tuple[dict, float | int]] = OrderedDict() # "dataset_id" -> (metadata, expiry_ts) @@ -116,8 +115,6 @@ class RAGFlowConnector: def _set_cached_document_metadata_by_dataset(self, dataset_id, doc_id_meta_list): self._document_metadata_cache[dataset_id] = (doc_id_meta_list, self._get_expiry_timestamp()) self._document_metadata_cache.move_to_end(dataset_id) - if len(self._document_metadata_cache) > self._MAX_DOCUMENT_CACHE: - self._document_metadata_cache.popitem(last=False) def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None): res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}) @@ -240,46 +237,46 @@ class RAGFlowConnector: docs = None if force_refresh else self._get_cached_document_metadata_by_dataset(dataset_id) if docs is None: - docs_res = self._get(f"/datasets/{dataset_id}/documents") - docs_data = docs_res.json() - if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"): - doc_id_meta_list = [] - docs = {} - for doc in docs_data["data"]["docs"]: - doc_id = doc.get("id") - if not doc_id: - continue - doc_meta = { - "document_id": doc_id, - "name": doc.get("name", ""), - "location": doc.get("location", ""), - "type": doc.get("type", ""), - "size": doc.get("size"), - "chunk_count": doc.get("chunk_count"), - # "chunk_method": doc.get("chunk_method", ""), - "create_date": doc.get("create_date", ""), - "update_date": doc.get("update_date", ""), - # "process_begin_at": doc.get("process_begin_at", ""), - # "process_duration": doc.get("process_duration"), - # "progress": doc.get("progress"), - # "progress_msg": doc.get("progress_msg", ""), - # "status": doc.get("status", ""), - # "run": doc.get("run", ""), - "token_count": doc.get("token_count"), - # "source_type": doc.get("source_type", ""), - "thumbnail": doc.get("thumbnail", ""), - "dataset_id": doc.get("dataset_id", dataset_id), - "meta_fields": doc.get("meta_fields", {}), - # "parser_config": doc.get("parser_config", {}) - } - doc_id_meta_list.append((doc_id, doc_meta)) - docs[doc_id] = doc_meta + page = 1 + page_size = 30 + doc_id_meta_list = [] + docs = {} + while page: + docs_res = self._get(f"/datasets/{dataset_id}/documents?page={page}") + docs_data = docs_res.json() + if docs_data.get("code") == 0 and docs_data.get("data", {}).get("docs"): + for doc in docs_data["data"]["docs"]: + doc_id = doc.get("id") + if not doc_id: + continue + doc_meta = { + "document_id": doc_id, + "name": doc.get("name", ""), + "location": doc.get("location", ""), + "type": doc.get("type", ""), + "size": doc.get("size"), + "chunk_count": doc.get("chunk_count"), + "create_date": doc.get("create_date", ""), + "update_date": doc.get("update_date", ""), + "token_count": doc.get("token_count"), + "thumbnail": doc.get("thumbnail", ""), + "dataset_id": doc.get("dataset_id", dataset_id), + "meta_fields": doc.get("meta_fields", {}), + } + doc_id_meta_list.append((doc_id, doc_meta)) + docs[doc_id] = doc_meta + + page += 1 + if docs_data.get("data", {}).get("total", 0) - page * page_size <= 0: + page = None + self._set_cached_document_metadata_by_dataset(dataset_id, doc_id_meta_list) if docs: document_cache.update(docs) - except Exception: + except Exception as e: # Gracefully handle metadata cache failures + logging.error(f"Problem building the document metadata cache: {str(e)}") pass return document_cache, dataset_cache From 7719fd6350b66ff6cee81a287b2981a81ac45c9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=91=E5=8D=BF?= <121151546+shaoqing404@users.noreply.github.com> Date: Fri, 5 Dec 2025 19:25:45 +0800 Subject: [PATCH 03/12] Fix MinerU API sanitized-output lookup and manual chunk tuple handling (#11702) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? This PR addresses **two independent issues** encountered when using the MinerU engine in Ragflow: 1. **MinerU API output path mismatch for non-ASCII filenames** MinerU sanitizes the root directory name inside the returned ZIP when the original filename contains non-ASCII characters (e.g., Chinese). Ragflow's client-side unzip logic assumed the original filename stem and therefore failed to locate `_content_list.json`. This PR adds: * root-directory detection * fallback lookup using sanitized names * a broadened `_read_output` search with a glob fallback ensuring output files are consistently located regardless of filename encoding. 2. **Chunker crash due to tuple-structure mismatch in manual mode** Some parsers (e.g., MinerU / Docling) return **2-tuple sections**, but Ragflow’s chunker expects **3-tuple sections**, leading to: `ValueError: not enough values to unpack (expected 3, got 2)` This PR normalizes all sections to a uniform structure `(text, layout, positions)`: * parse position tags when present * default to empty positions when missing preserving backward compatibility and preventing crashes. ### Type of change * [x] Bug Fix (non-breaking change which fixes an issue) [#11136](https://github.com/infiniflow/ragflow/issues/11136) [#11700](https://github.com/infiniflow/ragflow/issues/11700) [#11620](https://github.com/infiniflow/ragflow/issues/11620) [#11701](https://github.com/infiniflow/ragflow/pull/11701) we need your help [yongtenglei](https://github.com/yongtenglei) --------- Co-authored-by: Kevin Hu --- rag/app/manual.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/rag/app/manual.py b/rag/app/manual.py index 1eb86a043..363c6e9e7 100644 --- a/rag/app/manual.py +++ b/rag/app/manual.py @@ -219,23 +219,27 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, ) def _normalize_section(section): - # pad section to length 3: (txt, sec_id, poss) - if len(section) == 1: + # Pad/normalize to (txt, layout, positions) + if not isinstance(section, (list, tuple)): + section = (section, "", []) + elif len(section) == 1: section = (section[0], "", []) elif len(section) == 2: section = (section[0], "", section[1]) - elif len(section) != 3: - raise ValueError(f"Unexpected section length: {len(section)} (value={section!r})") + else: + section = (section[0], section[1], section[2]) txt, layoutno, poss = section if isinstance(poss, str): poss = pdf_parser.extract_positions(poss) - first = poss[0] # tuple: ([pn], x1, x2, y1, y2) - pn = first[0] - - if isinstance(pn, list): - pn = pn[0] # [pn] -> pn + if poss: + first = poss[0] # tuple: ([pn], x1, x2, y1, y2) + pn = first[0] + if isinstance(pn, list) and pn: + pn = pn[0] # [pn] -> pn poss[0] = (pn, *first[1:]) + if not poss: + poss = [] return (txt, layoutno, poss) From e4e0a880535eee0792024cbf17c8a40aff9aa194 Mon Sep 17 00:00:00 2001 From: TeslaZY Date: Fri, 5 Dec 2025 19:27:36 +0800 Subject: [PATCH 04/12] Feat: Fillup component return value not object (#11780) ### What problem does this PR solve? Fillup component return value not object ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- agent/component/fillup.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/agent/component/fillup.py b/agent/component/fillup.py index 7428912d4..10163d10c 100644 --- a/agent/component/fillup.py +++ b/agent/component/fillup.py @@ -18,6 +18,7 @@ import re from functools import partial from agent.component.base import ComponentParamBase, ComponentBase +from api.db.services.file_service import FileService class UserFillUpParam(ComponentParamBase): @@ -63,6 +64,13 @@ class UserFillUp(ComponentBase): for k, v in kwargs.get("inputs", {}).items(): if self.check_if_canceled("UserFillUp processing"): return + if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0: + if v.get("optional") and v.get("value", None) is None: + v = None + else: + v = FileService.get_files([v["value"]]) + else: + v = v.get("value") self.set_output(k, v) def thoughts(self) -> str: From 8de6b9780681617ad83333ba01d78b46faf8d22f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E6=B5=B7=E8=92=BC=E7=81=86?= Date: Fri, 5 Dec 2025 19:42:35 +0800 Subject: [PATCH 05/12] Feature (canvas): Add Api for download "message" component output's file (#11772) ### What problem does this PR solve? -Add Api for download "message" component output's file -Change the attachment output type check from tuple to dictionary,because 'attachement' is not instance of tuple -Update the message type to message_end to avoid the problem that content does not send an error message when the message type is ans ["data"] ["content"] ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) --- agent/canvas.py | 10 ++++++---- api/apps/sdk/files.py | 16 +++++++++++++++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/agent/canvas.py b/agent/canvas.py index b693ed434..cc40fd174 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -534,10 +534,12 @@ class Canvas(Graph): yield decorate("message", {"content": cpn_obj.output("content")}) cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content")) - if isinstance(cpn_obj.output("attachment"), tuple): - yield decorate("message", {"attachment": cpn_obj.output("attachment")}) - - yield decorate("message_end", {"reference": self.get_reference() if cite else None}) + message_end = {} + if isinstance(cpn_obj.output("attachment"), dict): + message_end["attachment"] = cpn_obj.output("attachment") + if cite: + message_end["reference"] = self.get_reference() + yield decorate("message_end", message_end) while partials: _cpn_obj = self.get_component_obj(partials[0]) diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py index 981d3975e..2e9fd6df3 100644 --- a/api/apps/sdk/files.py +++ b/api/apps/sdk/files.py @@ -14,7 +14,7 @@ # limitations under the License. # - +import asyncio import pathlib import re from quart import request, make_response @@ -29,6 +29,7 @@ from api.db import FileType from api.db.services import duplicate_name from api.db.services.file_service import FileService from api.utils.file_utils import filename_type +from api.utils.web_utils import CONTENT_TYPE_MAP from common import settings from common.constants import RetCode @@ -629,6 +630,19 @@ async def get(tenant_id, file_id): except Exception as e: return server_error_response(e) +@manager.route("/file/download/", methods=["GET"]) # noqa: F821 +@token_required +async def download_attachment(tenant_id,attachment_id): + try: + ext = request.args.get("ext", "markdown") + data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id) + response = await make_response(data) + response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}")) + + return response + + except Exception as e: + return server_error_response(e) @manager.route('/file/mv', methods=['POST']) # noqa: F821 @token_required From 6546f86b4e1e9c05f0bac9026b2aeb7de0e7a482 Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Mon, 8 Dec 2025 09:42:10 +0800 Subject: [PATCH 06/12] Fix errors (#11795) ### What problem does this PR solve? - typos - IDE warnings ### Type of change - [x] Refactoring --------- Signed-off-by: Jin Hai --- Dockerfile | 4 ++-- common/data_source/discord_connector.py | 4 ++-- deepdoc/parser/html_parser.py | 8 ++++---- deepdoc/vision/ocr.py | 4 ++-- deepdoc/vision/postprocess.py | 18 +++++++++--------- docker/docker-compose.yml | 4 ++-- docs/faq.mdx | 2 +- .../accelerate_agent_question_answering.md | 4 ++-- .../accelerate_question_answering.mdx | 2 +- docs/guides/manage_users_and_services.md | 2 +- docs/guides/models/deploy_local_llm.mdx | 2 +- example/sdk/dataset_example.py | 4 ++-- sandbox/executor_manager/core/container.py | 12 ++++++------ 13 files changed, 35 insertions(+), 35 deletions(-) diff --git a/Dockerfile b/Dockerfile index d16834125..95a9d54b7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -78,12 +78,12 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \ # A modern version of cargo is needed for the latest version of the Rust compiler. RUN apt update && apt install -y curl build-essential \ && if [ "$NEED_MIRROR" == "1" ]; then \ - # Use TUNA mirrors for rustup/rust dist files + # Use TUNA mirrors for rustup/rust dist files \ export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \ export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \ echo "Using TUNA mirrors for Rustup."; \ fi; \ - # Force curl to use HTTP/1.1 + # Force curl to use HTTP/1.1 \ curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \ && echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py index 46b23443c..4c19a6d5e 100644 --- a/common/data_source/discord_connector.py +++ b/common/data_source/discord_connector.py @@ -33,7 +33,7 @@ def _convert_message_to_document( metadata: dict[str, str | list[str]] = {} semantic_substring = "" - # Only messages from TextChannels will make it here but we have to check for it anyways + # Only messages from TextChannels will make it here, but we have to check for it anyway if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name): metadata["Channel"] = channel_name semantic_substring += f" in Channel: #{channel_name}" @@ -176,7 +176,7 @@ def _manage_async_retrieval( # parse requested_start_date_string to datetime pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None - # Set start_time to the later of start and pull_date, or whichever is provided + # Set start_time to the most recent of start and pull_date, or whichever is provided start_time = max(filter(None, [start, pull_date])) if start or pull_date else None end_time: datetime | None = end diff --git a/deepdoc/parser/html_parser.py b/deepdoc/parser/html_parser.py index 7e4467c16..dcf33a8bb 100644 --- a/deepdoc/parser/html_parser.py +++ b/deepdoc/parser/html_parser.py @@ -151,7 +151,7 @@ class RAGFlowHtmlParser: block_content = [] current_content = "" table_info_list = [] - lask_block_id = None + last_block_id = None for item in parser_result: content = item.get("content") tag_name = item.get("tag_name") @@ -160,11 +160,11 @@ class RAGFlowHtmlParser: if block_id: if title_flag: content = f"{TITLE_TAGS[tag_name]} {content}" - if lask_block_id != block_id: - if lask_block_id is not None: + if last_block_id != block_id: + if last_block_id is not None: block_content.append(current_content) current_content = content - lask_block_id = block_id + last_block_id = block_id else: current_content += (" " if current_content else "") + content else: diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index 207fb0e84..afa692127 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -582,7 +582,7 @@ class OCR: self.crop_image_res_index = 0 def get_rotate_crop_image(self, img, points): - ''' + """ img_height, img_width = img.shape[0:2] left = int(np.min(points[:, 0])) right = int(np.max(points[:, 0])) @@ -591,7 +591,7 @@ class OCR: img_crop = img[top:bottom, left:right, :].copy() points[:, 0] = points[:, 0] - left points[:, 1] = points[:, 1] - top - ''' + """ assert len(points) == 4, "shape of points must be 4*2" img_crop_width = int( max( diff --git a/deepdoc/vision/postprocess.py b/deepdoc/vision/postprocess.py index a61464382..7704bc582 100644 --- a/deepdoc/vision/postprocess.py +++ b/deepdoc/vision/postprocess.py @@ -67,10 +67,10 @@ class DBPostProcess: [[1, 1], [1, 1]]) def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height): - ''' + """ _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1} - ''' + """ bitmap = _bitmap height, width = bitmap.shape @@ -114,10 +114,10 @@ class DBPostProcess: return boxes, scores def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): - ''' + """ _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1} - ''' + """ bitmap = _bitmap height, width = bitmap.shape @@ -192,9 +192,9 @@ class DBPostProcess: return box, min(bounding_box[1]) def box_score_fast(self, bitmap, _box): - ''' + """ box_score_fast: use bbox mean score as the mean score - ''' + """ h, w = bitmap.shape[:2] box = _box.copy() xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1) @@ -209,9 +209,9 @@ class DBPostProcess: return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] def box_score_slow(self, bitmap, contour): - ''' - box_score_slow: use polyon mean score as the mean score - ''' + """ + box_score_slow: use polygon mean score as the mean score + """ h, w = bitmap.shape[:2] contour = contour.copy() contour = np.reshape(contour, (-1, 2)) diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index adb337511..b851687a5 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -25,7 +25,7 @@ services: # - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint) # - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP) - # Example configration to start Admin server: + # Example configuration to start Admin server: # command: # - --enable-adminserver ports: @@ -74,7 +74,7 @@ services: # - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint) # - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP) - # Example configration to start Admin server: + # Example configuration to start Admin server: # command: # - --enable-adminserver ports: diff --git a/docs/faq.mdx b/docs/faq.mdx index 55997e1c3..10c6bc57c 100644 --- a/docs/faq.mdx +++ b/docs/faq.mdx @@ -151,7 +151,7 @@ See [Build a RAGFlow Docker image](./develop/build_docker_image.mdx). ### Cannot access https://huggingface.co -A locally deployed RAGflow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails: +A locally deployed RAGFlow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails: ``` FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/hub/models--InfiniFlow--deepdoc/snapshots/be0c1e50eef6047b412d1800aa89aba4d275f997/ocr.res' diff --git a/docs/guides/agent/best_practices/accelerate_agent_question_answering.md b/docs/guides/agent/best_practices/accelerate_agent_question_answering.md index 76de06068..1161588bd 100644 --- a/docs/guides/agent/best_practices/accelerate_agent_question_answering.md +++ b/docs/guides/agent/best_practices/accelerate_agent_question_answering.md @@ -45,13 +45,13 @@ Click the light bulb icon above the *current* dialogue and scroll down the popup | Item name | Description | -| ----------------- | --------------------------------------------------------------------------------------------- | +| ----------------- |-----------------------------------------------------------------------------------------------| | Total | Total time spent on this conversation round, including chunk retrieval and answer generation. | | Check LLM | Time to validate the specified LLM. | | Create retriever | Time to create a chunk retriever. | | Bind embedding | Time to initialize an embedding model instance. | | Bind LLM | Time to initialize an LLM instance. | -| Tune question | Time to optimize the user query using the context of the mult-turn conversation. | +| Tune question | Time to optimize the user query using the context of the multi-turn conversation. | | Bind reranker | Time to initialize an reranker model instance for chunk retrieval. | | Generate keywords | Time to extract keywords from the user query. | | Retrieval | Time to retrieve the chunks. | diff --git a/docs/guides/chat/best_practices/accelerate_question_answering.mdx b/docs/guides/chat/best_practices/accelerate_question_answering.mdx index e404c1c2a..af4d2521b 100644 --- a/docs/guides/chat/best_practices/accelerate_question_answering.mdx +++ b/docs/guides/chat/best_practices/accelerate_question_answering.mdx @@ -37,7 +37,7 @@ Please note that rerank models are essential in certain scenarios. There is alwa | Create retriever | Time to create a chunk retriever. | | Bind embedding | Time to initialize an embedding model instance. | | Bind LLM | Time to initialize an LLM instance. | -| Tune question | Time to optimize the user query using the context of the mult-turn conversation. | +| Tune question | Time to optimize the user query using the context of the multi-turn conversation. | | Bind reranker | Time to initialize an reranker model instance for chunk retrieval. | | Generate keywords | Time to extract keywords from the user query. | | Retrieval | Time to retrieve the chunks. | diff --git a/docs/guides/manage_users_and_services.md b/docs/guides/manage_users_and_services.md index 94b933ec2..6c06c40f8 100644 --- a/docs/guides/manage_users_and_services.md +++ b/docs/guides/manage_users_and_services.md @@ -8,7 +8,7 @@ slug: /manage_users_and_services -The Admin CLI and Admin Service form a client-server architectural suite for RAGflow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGflow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled. +The Admin CLI and Admin Service form a client-server architectural suite for RAGFlow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGFlow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled. diff --git a/docs/guides/models/deploy_local_llm.mdx b/docs/guides/models/deploy_local_llm.mdx index dfee3fc78..997e526f3 100644 --- a/docs/guides/models/deploy_local_llm.mdx +++ b/docs/guides/models/deploy_local_llm.mdx @@ -305,7 +305,7 @@ With the Ollama service running, open a new terminal and run `./ollama pull -### 4. Configure RAGflow +### 4. Configure RAGFlow To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the configurations in RAGFlow. The steps are identical to those outlined in the *Deploy a local model using Ollama* section: diff --git a/example/sdk/dataset_example.py b/example/sdk/dataset_example.py index 3a0504d8d..a3931f143 100644 --- a/example/sdk/dataset_example.py +++ b/example/sdk/dataset_example.py @@ -14,9 +14,9 @@ # limitations under the License. # -''' +""" The example is about CRUD operations (Create, Read, Update, Delete) on a dataset. -''' +""" from ragflow_sdk import RAGFlow import sys diff --git a/sandbox/executor_manager/core/container.py b/sandbox/executor_manager/core/container.py index f953886c1..36cdded28 100644 --- a/sandbox/executor_manager/core/container.py +++ b/sandbox/executor_manager/core/container.py @@ -122,15 +122,15 @@ async def create_container(name: str, language: SupportLanguage) -> bool: logger.info(f"Sandbox config:\n\t {create_args}") try: - returncode, _, stderr = await async_run_command(*create_args, timeout=10) - if returncode != 0: + return_code, _, stderr = await async_run_command(*create_args, timeout=10) + if return_code != 0: logger.error(f"❌ Container creation failed {name}: {stderr}") return False if language == SupportLanguage.NODEJS: copy_cmd = ["docker", "exec", name, "bash", "-c", "cp -a /app/node_modules /workspace/"] - returncode, _, stderr = await async_run_command(*copy_cmd, timeout=10) - if returncode != 0: + return_code, _, stderr = await async_run_command(*copy_cmd, timeout=10) + if return_code != 0: logger.error(f"❌ Failed to prepare dependencies for {name}: {stderr}") return False @@ -185,7 +185,7 @@ async def allocate_container_blocking(language: SupportLanguage, timeout=10) -> async def container_is_running(name: str) -> bool: """Asynchronously check the container status""" try: - returncode, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2) - return returncode == 0 and stdout.strip() == "true" + return_code, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2) + return return_code == 0 and stdout.strip() == "true" except Exception: return False From 9b8971a9de11361e488253da75eea707e4598b4c Mon Sep 17 00:00:00 2001 From: buua436 Date: Mon, 8 Dec 2025 09:42:20 +0800 Subject: [PATCH 07/12] Fix:toc in pipeline (#11785) ### What problem does this PR solve? change: Fix toc in pipeline ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/flow/extractor/extractor.py | 19 +++++++++++-------- rag/flow/splitter/splitter.py | 2 +- rag/svr/task_executor.py | 3 ++- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/rag/flow/extractor/extractor.py b/rag/flow/extractor/extractor.py index 45698b204..1b97fd1ee 100644 --- a/rag/flow/extractor/extractor.py +++ b/rag/flow/extractor/extractor.py @@ -15,9 +15,8 @@ import json import logging import random -from copy import deepcopy, copy +from copy import deepcopy -import trio import xxhash from agent.component.llm import LLMParam, LLM @@ -38,13 +37,13 @@ class ExtractorParam(ProcessParamBase, LLMParam): class Extractor(ProcessBase, LLM): component_name = "Extractor" - def _build_TOC(self, docs): - self.callback(message="Start to generate table of content ...") + async def _build_TOC(self, docs): + self.callback(0.2,message="Start to generate table of content ...") docs = sorted(docs, key=lambda d:( d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0), d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0) )) - toc: list[dict] = trio.run(run_toc_from_text, [d["text"] for d in docs], self.chat_mdl) + toc = await run_toc_from_text([d["text"] for d in docs], self.chat_mdl) logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' ')) ii = 0 while ii < len(toc): @@ -61,7 +60,8 @@ class Extractor(ProcessBase, LLM): ii += 1 if toc: - d = copy.deepcopy(docs[-1]) + d = deepcopy(docs[-1]) + d["doc_id"] = self._canvas._doc_id d["content_with_weight"] = json.dumps(toc, ensure_ascii=False) d["toc_kwd"] = "toc" d["available_int"] = 0 @@ -85,11 +85,14 @@ class Extractor(ProcessBase, LLM): if chunks: if self._param.field_name == "toc": - toc = self._build_TOC(chunks) + for ck in chunks: + ck["doc_id"] = self._canvas._doc_id + ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest() + toc =await self._build_TOC(chunks) chunks.append(toc) self.set_output("chunks", chunks) return - + prog = 0 for i, ck in enumerate(chunks): args[chunks_key] = ck["text"] diff --git a/rag/flow/splitter/splitter.py b/rag/flow/splitter/splitter.py index c790790cb..1ef06839d 100644 --- a/rag/flow/splitter/splitter.py +++ b/rag/flow/splitter/splitter.py @@ -125,7 +125,7 @@ class Splitter(ProcessBase): { "text": RAGFlowPdfParser.remove_tag(c), "image": img, - "positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)] + "positions": [[pos[0][-1], *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)] } for c, img in zip(chunks, images) if c.strip() ] diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 8cf1bf290..b08aa7524 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -592,7 +592,8 @@ async def run_dataflow(task: dict): ck["docnm_kwd"] = task["name"] ck["create_time"] = str(datetime.now()).replace("T", " ")[:19] ck["create_timestamp_flt"] = datetime.now().timestamp() - ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest() + if not ck.get("id"): + ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest() if "questions" in ck: if "question_tks" not in ck: ck["question_kwd"] = ck["questions"].split("\n") From 51ec708c58d2a5edaf20bf8ed35e09cfcbe291c8 Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Mon, 8 Dec 2025 09:43:03 +0800 Subject: [PATCH 08/12] Refa: cleanup synchronous functions in chat_model and implement synchronization for conversation and dialog chats (#11779) ### What problem does this PR solve? Cleanup synchronous functions in chat_model and implement synchronization for conversation and dialog chats. ### Type of change - [x] Refactoring - [x] Performance Improvement --- api/apps/conversation_app.py | 12 +- api/apps/langfuse_app.py | 19 +- api/apps/llm_app.py | 4 +- api/apps/sdk/session.py | 25 +- api/db/services/conversation_service.py | 18 +- api/db/services/dialog_service.py | 33 +- api/db/services/evaluation_service.py | 233 ++++---- api/db/services/llm_service.py | 208 ++++--- rag/llm/__init__.py | 5 + rag/llm/chat_model.py | 707 ++++-------------------- 10 files changed, 421 insertions(+), 843 deletions(-) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 89630e4a4..337cb74df 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -23,7 +23,7 @@ from quart import Response, request from api.apps import current_user, login_required from api.db.db_models import APIToken from api.db.services.conversation_service import ConversationService, structure_answer -from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap +from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService from api.db.services.tenant_llm_service import TenantLLMService @@ -218,10 +218,10 @@ async def completion(): dia.llm_setting = chat_model_config is_embedded = bool(chat_model_id) - def stream(): + async def stream(): nonlocal dia, msg, req, conv try: - for ans in chat(dia, msg, True, **req): + async for ans in async_chat(dia, msg, True, **req): ans = structure_answer(conv, ans, message_id, conv.id) yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" if not is_embedded: @@ -241,7 +241,7 @@ async def completion(): else: answer = None - for ans in chat(dia, msg, **req): + async for ans in async_chat(dia, msg, **req): answer = structure_answer(conv, ans, message_id, conv.id) if not is_embedded: ConversationService.update_by_id(conv.id, conv.to_dict()) @@ -406,10 +406,10 @@ async def ask_about(): if search_app: search_config = search_app.get("search_config", {}) - def stream(): + async def stream(): nonlocal req, uid try: - for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config): + async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" diff --git a/api/apps/langfuse_app.py b/api/apps/langfuse_app.py index 8a05c0d4c..1d7993d36 100644 --- a/api/apps/langfuse_app.py +++ b/api/apps/langfuse_app.py @@ -34,8 +34,9 @@ async def set_api_key(): if not all([secret_key, public_key, host]): return get_error_data_result(message="Missing required fields") + current_user_id = current_user.id langfuse_keys = dict( - tenant_id=current_user.id, + tenant_id=current_user_id, secret_key=secret_key, public_key=public_key, host=host, @@ -45,23 +46,24 @@ async def set_api_key(): if not langfuse.auth_check(): return get_error_data_result(message="Invalid Langfuse keys") - langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id) + langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id) with DB.atomic(): try: if not langfuse_entry: TenantLangfuseService.save(**langfuse_keys) else: - TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys) + TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys) return get_json_result(data=langfuse_keys) except Exception as e: - server_error_response(e) + return server_error_response(e) @manager.route("/api_key", methods=["GET"]) # noqa: F821 @login_required @validate_request() def get_api_key(): - langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id) + current_user_id = current_user.id + langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id) if not langfuse_entry: return get_json_result(message="Have not record any Langfuse keys.") @@ -72,7 +74,7 @@ def get_api_key(): except langfuse.api.core.api_error.ApiError as api_err: return get_json_result(message=f"Error from Langfuse: {api_err}") except Exception as e: - server_error_response(e) + return server_error_response(e) langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"] langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"] @@ -84,7 +86,8 @@ def get_api_key(): @login_required @validate_request() def delete_api_key(): - langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id) + current_user_id = current_user.id + langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id) if not langfuse_entry: return get_json_result(message="Have not record any Langfuse keys.") @@ -93,4 +96,4 @@ def delete_api_key(): TenantLangfuseService.delete_model(langfuse_entry) return get_json_result(data=True) except Exception as e: - server_error_response(e) + return server_error_response(e) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 018fb4bca..d24a4bb44 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -74,7 +74,7 @@ async def set_api_key(): assert factory in ChatModel, f"Chat model from {factory} is not supported yet." mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra) try: - m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50}) + m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50}) if m.find("**ERROR**") >= 0: raise Exception(m) chat_passed = True @@ -217,7 +217,7 @@ async def add_llm(): **extra, ) try: - m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9}) + m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9}) if not tc and m.find("**ERROR**:") >= 0: raise Exception(m) except Exception as e: diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index e94f14fcc..fe4723984 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -26,9 +26,10 @@ from api.db.db_models import APIToken from api.db.services.api_service import API4ConversationService from api.db.services.canvas_service import UserCanvasService, completion_openai from api.db.services.canvas_service import completion as agent_completion -from api.db.services.conversation_service import ConversationService, iframe_completion -from api.db.services.conversation_service import completion as rag_completion -from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter +from api.db.services.conversation_service import ConversationService +from api.db.services.conversation_service import async_iframe_completion as iframe_completion +from api.db.services.conversation_service import async_completion as rag_completion +from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle @@ -141,7 +142,7 @@ async def chat_completion(tenant_id, chat_id): return resp else: answer = None - for ans in rag_completion(tenant_id, chat_id, **req): + async for ans in rag_completion(tenant_id, chat_id, **req): answer = ans break return get_result(data=answer) @@ -245,7 +246,7 @@ async def chat_completion_openai_like(tenant_id, chat_id): # The value for the usage field on all chunks except for the last one will be null. # The usage field on the last chunk contains token usage statistics for the entire request. # The choices field on the last chunk will always be an empty array []. - def streamed_response_generator(chat_id, dia, msg): + async def streamed_response_generator(chat_id, dia, msg): token_used = 0 answer_cache = "" reasoning_cache = "" @@ -274,7 +275,7 @@ async def chat_completion_openai_like(tenant_id, chat_id): } try: - for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference): + async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference): last_ans = ans answer = ans["answer"] @@ -342,7 +343,7 @@ async def chat_completion_openai_like(tenant_id, chat_id): return resp else: answer = None - for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference): + async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference): # focus answer content only answer = ans break @@ -733,10 +734,10 @@ async def ask_about(tenant_id): return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") uid = tenant_id - def stream(): + async def stream(): nonlocal req, uid try: - for ans in ask(req["question"], req["kb_ids"], uid): + async for ans in async_ask(req["question"], req["kb_ids"], uid): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps( @@ -827,7 +828,7 @@ async def chatbot_completions(dialog_id): resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - for answer in iframe_completion(dialog_id, **req): + async for answer in iframe_completion(dialog_id, **req): return get_result(data=answer) @@ -918,10 +919,10 @@ async def ask_about_embedded(): if search_app := SearchService.get_detail(search_id): search_config = search_app.get("search_config", {}) - def stream(): + async def stream(): nonlocal req, uid try: - for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config): + async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps( diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py index 60f8e55b1..aaec72bf5 100644 --- a/api/db/services/conversation_service.py +++ b/api/db/services/conversation_service.py @@ -19,7 +19,7 @@ from common.constants import StatusEnum from api.db.db_models import Conversation, DB from api.db.services.api_service import API4ConversationService from api.db.services.common_service import CommonService -from api.db.services.dialog_service import DialogService, chat +from api.db.services.dialog_service import DialogService, async_chat from common.misc_utils import get_uuid import json @@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id): conv.reference[-1] = reference return ans - -def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs): +async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs): assert name, "`name` can not be empty." dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value) assert dia, "You do not own the chat." @@ -112,7 +111,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None "reference": {}, "audio_binary": None, "id": None, - "session_id": session_id + "session_id": session_id }}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" @@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None if stream: try: - for ans in chat(dia, msg, True, **kwargs): + async for ans in async_chat(dia, msg, True, **kwargs): ans = structure_answer(conv, ans, message_id, session_id) yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n" ConversationService.update_by_id(conv.id, conv.to_dict()) @@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None else: answer = None - for ans in chat(dia, msg, False, **kwargs): + async for ans in async_chat(dia, msg, False, **kwargs): answer = structure_answer(conv, ans, message_id, session_id) ConversationService.update_by_id(conv.id, conv.to_dict()) break yield answer - -def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs): +async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs): e, dia = DialogService.get_by_id(dialog_id) assert e, "Dialog not found" if not session_id: @@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg if stream: try: - for ans in chat(dia, msg, True, **kwargs): + async for ans in async_chat(dia, msg, True, **kwargs): ans = structure_answer(conv, ans, message_id, session_id) yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" @@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg else: answer = None - for ans in chat(dia, msg, False, **kwargs): + async for ans in async_chat(dia, msg, False, **kwargs): answer = structure_answer(conv, ans, message_id, session_id) API4ConversationService.append_message(conv.id, conv.to_dict()) break diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 4afdd1f3c..43e345cd2 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -178,7 +178,8 @@ class DialogService(CommonService): offset += limit return res -def chat_solo(dialog, messages, stream=True): + +async def async_chat_solo(dialog, messages, stream=True): attachments = "" if "files" in messages[-1]: attachments = "\n\n".join(FileService.get_files(messages[-1]["files"])) @@ -197,7 +198,8 @@ def chat_solo(dialog, messages, stream=True): if stream: last_ans = "" delta_ans = "" - for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting): + answer = "" + async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting): answer = ans delta_ans = ans[len(last_ans):] if num_tokens_from_string(delta_ans) < 16: @@ -208,7 +210,7 @@ def chat_solo(dialog, messages, stream=True): if delta_ans: yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()} else: - answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting) + answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, answer)) yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()} @@ -347,13 +349,12 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): return [] return list(doc_ids) - -def chat(dialog, messages, stream=True, **kwargs): +async def async_chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"): - for ans in chat_solo(dialog, messages, stream): + async for ans in async_chat_solo(dialog, messages, stream): yield ans - return None + return chat_start_ts = timer() @@ -400,7 +401,7 @@ def chat(dialog, messages, stream=True, **kwargs): ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) if ans: yield ans - return None + return for p in prompt_config["parameters"]: if p["key"] == "knowledge": @@ -508,7 +509,8 @@ def chat(dialog, messages, stream=True, **kwargs): empty_res = prompt_config["empty_response"] yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)} - return {"answer": prompt_config["empty_response"], "reference": kbinfos} + yield {"answer": prompt_config["empty_response"], "reference": kbinfos} + return kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges) gen_conf = dialog.llm_setting @@ -612,7 +614,7 @@ def chat(dialog, messages, stream=True, **kwargs): if stream: last_ans = "" answer = "" - for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf): + async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf): if thought: ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) answer = ans @@ -626,19 +628,19 @@ def chat(dialog, messages, stream=True, **kwargs): yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} yield decorate_answer(thought + answer) else: - answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf) + answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, answer)) res = decorate_answer(answer) res["audio_binary"] = tts(tts_mdl, answer) yield res - return None + return def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): sys_prompt = """ -You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question. +You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question. Ensure that: 1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it. 2. Write only the SQL, no explanations or additional text. @@ -805,8 +807,7 @@ def tts(tts_mdl, text): return None return binascii.hexlify(bin).decode("utf-8") - -def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): +async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): doc_ids = search_config.get("doc_ids", []) rerank_mdl = None kb_ids = search_config.get("kb_ids", kb_ids) @@ -880,7 +881,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): return {"answer": answer, "reference": refs} answer = "" - for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}): + async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}): answer = ans yield {"answer": answer, "reference": {}} yield decorate_answer(answer) diff --git a/api/db/services/evaluation_service.py b/api/db/services/evaluation_service.py index 81b4c44fe..c5a24176d 100644 --- a/api/db/services/evaluation_service.py +++ b/api/db/services/evaluation_service.py @@ -25,14 +25,17 @@ Provides functionality for evaluating RAG system performance including: - Configuration recommendations """ +import asyncio import logging +import queue +import threading from typing import List, Dict, Any, Optional, Tuple from datetime import datetime from timeit import default_timer as timer from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult from api.db.services.common_service import CommonService -from api.db.services.dialog_service import DialogService, chat +from api.db.services.dialog_service import DialogService from common.misc_utils import get_uuid from common.time_utils import current_timestamp from common.constants import StatusEnum @@ -40,24 +43,24 @@ from common.constants import StatusEnum class EvaluationService(CommonService): """Service for managing RAG evaluations""" - + model = EvaluationDataset - + # ==================== Dataset Management ==================== - + @classmethod - def create_dataset(cls, name: str, description: str, kb_ids: List[str], + def create_dataset(cls, name: str, description: str, kb_ids: List[str], tenant_id: str, user_id: str) -> Tuple[bool, str]: """ Create a new evaluation dataset. - + Args: name: Dataset name description: Dataset description kb_ids: List of knowledge base IDs to evaluate against tenant_id: Tenant ID user_id: User ID who creates the dataset - + Returns: (success, dataset_id or error_message) """ @@ -74,15 +77,15 @@ class EvaluationService(CommonService): "update_time": current_timestamp(), "status": StatusEnum.VALID.value } - + if not EvaluationDataset.create(**dataset): return False, "Failed to create dataset" - + return True, dataset_id except Exception as e: logging.error(f"Error creating evaluation dataset: {e}") return False, str(e) - + @classmethod def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]: """Get dataset by ID""" @@ -94,9 +97,9 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error getting dataset {dataset_id}: {e}") return None - + @classmethod - def list_datasets(cls, tenant_id: str, user_id: str, + def list_datasets(cls, tenant_id: str, user_id: str, page: int = 1, page_size: int = 20) -> Dict[str, Any]: """List datasets for a tenant""" try: @@ -104,10 +107,10 @@ class EvaluationService(CommonService): (EvaluationDataset.tenant_id == tenant_id) & (EvaluationDataset.status == StatusEnum.VALID.value) ).order_by(EvaluationDataset.create_time.desc()) - + total = query.count() datasets = query.paginate(page, page_size) - + return { "total": total, "datasets": [d.to_dict() for d in datasets] @@ -115,7 +118,7 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error listing datasets: {e}") return {"total": 0, "datasets": []} - + @classmethod def update_dataset(cls, dataset_id: str, **kwargs) -> bool: """Update dataset""" @@ -127,7 +130,7 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error updating dataset {dataset_id}: {e}") return False - + @classmethod def delete_dataset(cls, dataset_id: str) -> bool: """Soft delete dataset""" @@ -139,18 +142,18 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error deleting dataset {dataset_id}: {e}") return False - + # ==================== Test Case Management ==================== - + @classmethod - def add_test_case(cls, dataset_id: str, question: str, + def add_test_case(cls, dataset_id: str, question: str, reference_answer: Optional[str] = None, relevant_doc_ids: Optional[List[str]] = None, relevant_chunk_ids: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]: """ Add a test case to a dataset. - + Args: dataset_id: Dataset ID question: Test question @@ -158,7 +161,7 @@ class EvaluationService(CommonService): relevant_doc_ids: Optional list of relevant document IDs relevant_chunk_ids: Optional list of relevant chunk IDs metadata: Optional additional metadata - + Returns: (success, case_id or error_message) """ @@ -174,15 +177,15 @@ class EvaluationService(CommonService): "metadata": metadata, "create_time": current_timestamp() } - + if not EvaluationCase.create(**case): return False, "Failed to create test case" - + return True, case_id except Exception as e: logging.error(f"Error adding test case: {e}") return False, str(e) - + @classmethod def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]: """Get all test cases for a dataset""" @@ -190,12 +193,12 @@ class EvaluationService(CommonService): cases = EvaluationCase.select().where( EvaluationCase.dataset_id == dataset_id ).order_by(EvaluationCase.create_time) - + return [c.to_dict() for c in cases] except Exception as e: logging.error(f"Error getting test cases for dataset {dataset_id}: {e}") return [] - + @classmethod def delete_test_case(cls, case_id: str) -> bool: """Delete a test case""" @@ -206,22 +209,22 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error deleting test case {case_id}: {e}") return False - + @classmethod def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]: """ Bulk import test cases from a list. - + Args: dataset_id: Dataset ID cases: List of test case dictionaries - + Returns: (success_count, failure_count) """ success_count = 0 failure_count = 0 - + for case_data in cases: success, _ = cls.add_test_case( dataset_id=dataset_id, @@ -231,28 +234,28 @@ class EvaluationService(CommonService): relevant_chunk_ids=case_data.get("relevant_chunk_ids"), metadata=case_data.get("metadata") ) - + if success: success_count += 1 else: failure_count += 1 - + return success_count, failure_count - + # ==================== Evaluation Execution ==================== - + @classmethod - def start_evaluation(cls, dataset_id: str, dialog_id: str, + def start_evaluation(cls, dataset_id: str, dialog_id: str, user_id: str, name: Optional[str] = None) -> Tuple[bool, str]: """ Start an evaluation run. - + Args: dataset_id: Dataset ID dialog_id: Dialog configuration to evaluate user_id: User ID who starts the run name: Optional run name - + Returns: (success, run_id or error_message) """ @@ -261,12 +264,12 @@ class EvaluationService(CommonService): success, dialog = DialogService.get_by_id(dialog_id) if not success: return False, "Dialog not found" - + # Create evaluation run run_id = get_uuid() if not name: name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - + run = { "id": run_id, "dataset_id": dataset_id, @@ -279,92 +282,128 @@ class EvaluationService(CommonService): "create_time": current_timestamp(), "complete_time": None } - + if not EvaluationRun.create(**run): return False, "Failed to create evaluation run" - + # Execute evaluation asynchronously (in production, use task queue) # For now, we'll execute synchronously cls._execute_evaluation(run_id, dataset_id, dialog) - + return True, run_id except Exception as e: logging.error(f"Error starting evaluation: {e}") return False, str(e) - + @classmethod def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any): """ Execute evaluation for all test cases. - + This method runs the RAG pipeline for each test case and computes metrics. """ try: # Get all test cases test_cases = cls.get_test_cases(dataset_id) - + if not test_cases: EvaluationRun.update( status="FAILED", complete_time=current_timestamp() ).where(EvaluationRun.id == run_id).execute() return - + # Execute each test case results = [] for case in test_cases: result = cls._evaluate_single_case(run_id, case, dialog) if result: results.append(result) - + # Compute summary metrics metrics_summary = cls._compute_summary_metrics(results) - + # Update run status EvaluationRun.update( status="COMPLETED", metrics_summary=metrics_summary, complete_time=current_timestamp() ).where(EvaluationRun.id == run_id).execute() - + except Exception as e: logging.error(f"Error executing evaluation {run_id}: {e}") EvaluationRun.update( status="FAILED", complete_time=current_timestamp() ).where(EvaluationRun.id == run_id).execute() - + @classmethod - def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any], + def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any], dialog: Any) -> Optional[Dict[str, Any]]: """ Evaluate a single test case. - + Args: run_id: Evaluation run ID case: Test case dictionary dialog: Dialog configuration - + Returns: Result dictionary or None if failed """ try: # Prepare messages messages = [{"role": "user", "content": case["question"]}] - + # Execute RAG pipeline start_time = timer() answer = "" retrieved_chunks = [] - + + + def _sync_from_async_gen(async_gen): + result_queue: queue.Queue = queue.Queue() + + def runner(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def consume(): + try: + async for item in async_gen: + result_queue.put(item) + except Exception as e: + result_queue.put(e) + finally: + result_queue.put(StopIteration) + + loop.run_until_complete(consume()) + loop.close() + + threading.Thread(target=runner, daemon=True).start() + + while True: + item = result_queue.get() + if item is StopIteration: + break + if isinstance(item, Exception): + raise item + yield item + + + def chat(dialog, messages, stream=True, **kwargs): + from api.db.services.dialog_service import async_chat + + return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs)) + for ans in chat(dialog, messages, stream=False): if isinstance(ans, dict): answer = ans.get("answer", "") retrieved_chunks = ans.get("reference", {}).get("chunks", []) break - + execution_time = timer() - start_time - + # Compute metrics metrics = cls._compute_metrics( question=case["question"], @@ -374,7 +413,7 @@ class EvaluationService(CommonService): relevant_chunk_ids=case.get("relevant_chunk_ids"), dialog=dialog ) - + # Save result result_id = get_uuid() result = { @@ -388,14 +427,14 @@ class EvaluationService(CommonService): "token_usage": None, # TODO: Track token usage "create_time": current_timestamp() } - + EvaluationResult.create(**result) - + return result except Exception as e: logging.error(f"Error evaluating case {case.get('id')}: {e}") return None - + @classmethod def _compute_metrics(cls, question: str, generated_answer: str, reference_answer: Optional[str], @@ -404,69 +443,69 @@ class EvaluationService(CommonService): dialog: Any) -> Dict[str, float]: """ Compute evaluation metrics for a single test case. - + Returns: Dictionary of metric names to values """ metrics = {} - + # Retrieval metrics (if ground truth chunks provided) if relevant_chunk_ids: retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks] metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids)) - + # Generation metrics if generated_answer: # Basic metrics metrics["answer_length"] = len(generated_answer) metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0 - + # TODO: Implement advanced metrics using LLM-as-judge # - Faithfulness (hallucination detection) # - Answer relevance # - Context relevance # - Semantic similarity (if reference answer provided) - + return metrics - + @classmethod - def _compute_retrieval_metrics(cls, retrieved_ids: List[str], + def _compute_retrieval_metrics(cls, retrieved_ids: List[str], relevant_ids: List[str]) -> Dict[str, float]: """ Compute retrieval metrics. - + Args: retrieved_ids: List of retrieved chunk IDs relevant_ids: List of relevant chunk IDs (ground truth) - + Returns: Dictionary of retrieval metrics """ if not relevant_ids: return {} - + retrieved_set = set(retrieved_ids) relevant_set = set(relevant_ids) - + # Precision: proportion of retrieved that are relevant precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0 - + # Recall: proportion of relevant that were retrieved recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0 - + # F1 score f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 - + # Hit rate: whether any relevant chunk was retrieved hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0 - + # MRR (Mean Reciprocal Rank): position of first relevant chunk mrr = 0.0 for i, chunk_id in enumerate(retrieved_ids, 1): if chunk_id in relevant_set: mrr = 1.0 / i break - + return { "precision": precision, "recall": recall, @@ -474,45 +513,45 @@ class EvaluationService(CommonService): "hit_rate": hit_rate, "mrr": mrr } - + @classmethod def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]: """ Compute summary metrics across all test cases. - + Args: results: List of result dictionaries - + Returns: Summary metrics dictionary """ if not results: return {} - + # Aggregate metrics metric_sums = {} metric_counts = {} - + for result in results: metrics = result.get("metrics", {}) for key, value in metrics.items(): if isinstance(value, (int, float)): metric_sums[key] = metric_sums.get(key, 0) + value metric_counts[key] = metric_counts.get(key, 0) + 1 - + # Compute averages summary = { "total_cases": len(results), "avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results) } - + for key in metric_sums: summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key] - + return summary - + # ==================== Results & Analysis ==================== - + @classmethod def get_run_results(cls, run_id: str) -> Dict[str, Any]: """Get results for an evaluation run""" @@ -520,11 +559,11 @@ class EvaluationService(CommonService): run = EvaluationRun.get_by_id(run_id) if not run: return {} - + results = EvaluationResult.select().where( EvaluationResult.run_id == run_id ).order_by(EvaluationResult.create_time) - + return { "run": run.to_dict(), "results": [r.to_dict() for r in results] @@ -532,15 +571,15 @@ class EvaluationService(CommonService): except Exception as e: logging.error(f"Error getting run results {run_id}: {e}") return {} - + @classmethod def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]: """ Analyze evaluation results and provide configuration recommendations. - + Args: run_id: Evaluation run ID - + Returns: List of recommendation dictionaries """ @@ -548,10 +587,10 @@ class EvaluationService(CommonService): run = EvaluationRun.get_by_id(run_id) if not run or not run.metrics_summary: return [] - + metrics = run.metrics_summary recommendations = [] - + # Low precision: retrieving irrelevant chunks if metrics.get("avg_precision", 1.0) < 0.7: recommendations.append({ @@ -564,7 +603,7 @@ class EvaluationService(CommonService): "Reduce top_k to return fewer chunks" ] }) - + # Low recall: missing relevant chunks if metrics.get("avg_recall", 1.0) < 0.7: recommendations.append({ @@ -578,7 +617,7 @@ class EvaluationService(CommonService): "Check chunk size - may be too large or too small" ] }) - + # Slow response time if metrics.get("avg_execution_time", 0) > 5.0: recommendations.append({ @@ -591,7 +630,7 @@ class EvaluationService(CommonService): "Consider caching frequently asked questions" ] }) - + return recommendations except Exception as e: logging.error(f"Error generating recommendations for run {run_id}: {e}") diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 6a63713ec..86356a7a7 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -16,15 +16,17 @@ import asyncio import inspect import logging +import queue import re import threading -from common.token_utils import num_tokens_from_string from functools import partial from typing import Generator -from common.constants import LLMType + from api.db.db_models import LLM from api.db.services.common_service import CommonService from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService +from common.constants import LLMType +from common.token_utils import num_tokens_from_string class LLMService(CommonService): @@ -33,6 +35,7 @@ class LLMService(CommonService): def get_init_tenant_llm(user_id): from common import settings + tenant_llm = [] model_configs = { @@ -193,7 +196,7 @@ class LLMBundle(LLM4Tenant): generation = self.langfuse.start_generation( trace_context=self.trace_context, name="stream_transcription", - metadata={"model": self.llm_name} + metadata={"model": self.llm_name}, ) final_text = "" used_tokens = 0 @@ -217,32 +220,34 @@ class LLMBundle(LLM4Tenant): if self.langfuse: generation.update( output={"output": final_text}, - usage_details={"total_tokens": used_tokens} + usage_details={"total_tokens": used_tokens}, ) generation.end() return if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name}) - full_text, used_tokens = mdl.transcription(audio) - if not TenantLLMService.increase_usage( - self.tenant_id, self.llm_type, used_tokens - ): - logging.error( - f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}" + generation = self.langfuse.start_generation( + trace_context=self.trace_context, + name="stream_transcription", + metadata={"model": self.llm_name}, ) + + full_text, used_tokens = mdl.transcription(audio) + if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): + logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}") + if self.langfuse: generation.update( output={"output": full_text}, - usage_details={"total_tokens": used_tokens} + usage_details={"total_tokens": used_tokens}, ) generation.end() yield { "event": "final", "text": full_text, - "streaming": False + "streaming": False, } def tts(self, text: str) -> Generator[bytes, None, None]: @@ -289,61 +294,79 @@ class LLMBundle(LLM4Tenant): return kwargs else: return {k: v for k, v in kwargs.items() if k in allowed_params} + + def _run_coroutine_sync(self, coro): + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + result_queue: queue.Queue = queue.Queue() + + def runner(): + try: + result_queue.put((True, asyncio.run(coro))) + except Exception as e: + result_queue.put((False, e)) + + thread = threading.Thread(target=runner, daemon=True) + thread.start() + thread.join() + + success, value = result_queue.get_nowait() + if success: + return value + raise value + def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str: - if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history}) + return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs)) - chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs) - if self.is_tools and self.mdl.is_tools: - chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs) + def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs): + result_queue: queue.Queue = queue.Queue() - use_kwargs = self._clean_param(chat_partial, **kwargs) - txt, used_tokens = chat_partial(**use_kwargs) - txt = self._remove_reasoning_content(txt) + def runner(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) - if not self.verbose_tool_use: - txt = re.sub(r".*?", "", txt, flags=re.DOTALL) + async def consume(): + try: + async for item in async_gen_fn(*args, **kwargs): + result_queue.put(item) + except Exception as e: + result_queue.put(e) + finally: + result_queue.put(StopIteration) - if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): - logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) + loop.run_until_complete(consume()) + loop.close() - if self.langfuse: - generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) - generation.end() + threading.Thread(target=runner, daemon=True).start() - return txt + while True: + item = result_queue.get() + if item is StopIteration: + break + if isinstance(item, Exception): + raise item + yield item def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): - if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history}) - ans = "" - chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf) - total_tokens = 0 - if self.is_tools and self.mdl.is_tools: - chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf) - use_kwargs = self._clean_param(chat_partial, **kwargs) - for txt in chat_partial(**use_kwargs): + for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs): if isinstance(txt, int): - total_tokens = txt - if self.langfuse: - generation.update(output={"output": ans}) - generation.end() break if txt.endswith(""): - ans = ans[: -len("")] + ans = txt[: -len("")] + continue if not self.verbose_tool_use: txt = re.sub(r".*?", "", txt, flags=re.DOTALL) - ans += txt + # cancatination has beend done in async_chat_streamly + ans = txt yield ans - if total_tokens > 0: - if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): - logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) - def _bridge_sync_stream(self, gen): loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() @@ -352,7 +375,7 @@ class LLMBundle(LLM4Tenant): try: for item in gen: loop.call_soon_threadsafe(queue.put_nowait, item) - except Exception as e: # pragma: no cover + except Exception as e: loop.call_soon_threadsafe(queue.put_nowait, e) finally: loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration) @@ -361,18 +384,27 @@ class LLMBundle(LLM4Tenant): return queue async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs): - chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs) - if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"): - chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs) + if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"): + base_fn = self.mdl.async_chat_with_tools + elif hasattr(self.mdl, "async_chat"): + base_fn = self.mdl.async_chat + else: + raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools") + generation = None + if self.langfuse: + generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history}) + + chat_partial = partial(base_fn, system, history, gen_conf) use_kwargs = self._clean_param(chat_partial, **kwargs) - if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools: - txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs) - elif hasattr(self.mdl, "async_chat"): - txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs) - else: - txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs) + try: + txt, used_tokens = await chat_partial(**use_kwargs) + except Exception as e: + if generation: + generation.update(output={"error": str(e)}) + generation.end() + raise txt = self._remove_reasoning_content(txt) if not self.verbose_tool_use: @@ -381,49 +413,51 @@ class LLMBundle(LLM4Tenant): if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) + if generation: + generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) + generation.end() + return txt async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): total_tokens = 0 ans = "" - if self.is_tools and self.mdl.is_tools: + if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"): stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None) - else: + elif hasattr(self.mdl, "async_chat_streamly"): stream_fn = getattr(self.mdl, "async_chat_streamly", None) + else: + raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools") + + generation = None + if self.langfuse: + generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history}) if stream_fn: chat_partial = partial(stream_fn, system, history, gen_conf) use_kwargs = self._clean_param(chat_partial, **kwargs) - async for txt in chat_partial(**use_kwargs): - if isinstance(txt, int): - total_tokens = txt - break + try: + async for txt in chat_partial(**use_kwargs): + if isinstance(txt, int): + total_tokens = txt + break - if txt.endswith(""): - ans = ans[: -len("")] + if txt.endswith(""): + ans = ans[: -len("")] - if not self.verbose_tool_use: - txt = re.sub(r".*?", "", txt, flags=re.DOTALL) + if not self.verbose_tool_use: + txt = re.sub(r".*?", "", txt, flags=re.DOTALL) - ans += txt - yield ans + ans += txt + yield ans + except Exception as e: + if generation: + generation.update(output={"error": str(e)}) + generation.end() + raise if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) + if generation: + generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens}) + generation.end() return - - chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf) - use_kwargs = self._clean_param(chat_partial, **kwargs) - queue = self._bridge_sync_stream(chat_partial(**use_kwargs)) - while True: - item = await queue.get() - if item is StopAsyncIteration: - break - if isinstance(item, Exception): - raise item - if isinstance(item, int): - total_tokens = item - break - yield item - - if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): - logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 3ff5311fc..67bf0bb09 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -52,6 +52,8 @@ class SupportedLiteLLMProvider(StrEnum): JiekouAI = "Jiekou.AI" ZHIPU_AI = "ZHIPU-AI" MiniMax = "MiniMax" + DeerAPI = "DeerAPI" + GPUStack = "GPUStack" FACTORY_DEFAULT_BASE_URL = { @@ -75,6 +77,7 @@ FACTORY_DEFAULT_BASE_URL = { SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai", SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4", SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1", + SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1", } @@ -108,6 +111,8 @@ LITELLM_PROVIDER_PREFIX = { SupportedLiteLLMProvider.JiekouAI: "openai/", SupportedLiteLLMProvider.ZHIPU_AI: "openai/", SupportedLiteLLMProvider.MiniMax: "openai/", + SupportedLiteLLMProvider.DeerAPI: "openai/", + SupportedLiteLLMProvider.GPUStack: "openai/", } ChatModel = globals().get("ChatModel", {}) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index e69ff1868..9f5457224 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -19,7 +19,6 @@ import logging import os import random import re -import threading import time from abc import ABC from copy import deepcopy @@ -78,11 +77,9 @@ class Base(ABC): self.toolcall_sessions = {} def _get_delay(self): - """Calculate retry delay time""" return self.base_delay * random.uniform(10, 150) def _classify_error(self, error): - """Classify error based on error message content""" error_str = str(error).lower() keywords_mapping = [ @@ -139,89 +136,7 @@ class Base(ABC): return gen_conf - def _bridge_sync_stream(self, gen): - """Run a sync generator in a thread and yield asynchronously.""" - loop = asyncio.get_running_loop() - queue: asyncio.Queue = asyncio.Queue() - - def worker(): - try: - for item in gen: - loop.call_soon_threadsafe(queue.put_nowait, item) - except Exception as exc: # pragma: no cover - defensive - loop.call_soon_threadsafe(queue.put_nowait, exc) - finally: - loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration) - - threading.Thread(target=worker, daemon=True).start() - return queue - - def _chat(self, history, gen_conf, **kwargs): - logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) - if self.model_name.lower().find("qwq") >= 0: - logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly") - - final_ans = "" - tol_token = 0 - for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs): - if delta.startswith("") or delta.endswith(""): - continue - final_ans += delta - tol_token = tol - - if len(final_ans.strip()) == 0: - final_ans = "**ERROR**: Empty response from reasoning model" - - return final_ans.strip(), tol_token - - if self.model_name.lower().find("qwen3") >= 0: - kwargs["extra_body"] = {"enable_thinking": False} - - response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs) - - if not response.choices or not response.choices[0].message or not response.choices[0].message.content: - return "", 0 - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - ans = self._length_stop(ans) - return ans, total_token_count_from_response(response) - - def _chat_streamly(self, history, gen_conf, **kwargs): - logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) - reasoning_start = False - - if kwargs.get("stop") or "stop" in gen_conf: - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop")) - else: - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf) - - for resp in response: - if not resp.choices: - continue - if not resp.choices[0].delta.content: - resp.choices[0].delta.content = "" - if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: - ans = "" - if not reasoning_start: - reasoning_start = True - ans = "" - ans += resp.choices[0].delta.reasoning_content + "" - else: - reasoning_start = False - ans = resp.choices[0].delta.content - - tol = total_token_count_from_response(resp) - if not tol: - tol = num_tokens_from_string(resp.choices[0].delta.content) - - if resp.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - yield ans, tol - - async def _async_chat_stream(self, history, gen_conf, **kwargs): + async def _async_chat_streamly(self, history, gen_conf, **kwargs): logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) reasoning_start = False @@ -265,13 +180,19 @@ class Base(ABC): gen_conf = self._clean_conf(gen_conf) ans = "" total_tokens = 0 - try: - async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs): - ans = delta_ans - total_tokens += tol - yield delta_ans - except openai.APIError as e: - yield ans + "\n**ERROR**: " + str(e) + + for attempt in range(self.max_retries + 1): + try: + async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs): + ans = delta_ans + total_tokens += tol + yield ans + except Exception as e: + e = await self._exceptions_async(e, attempt) + if e: + yield e + yield total_tokens + return yield total_tokens @@ -307,7 +228,7 @@ class Base(ABC): logging.error(f"sync base giving up: {msg}") return msg - async def _exceptions_async(self, e, attempt) -> str | None: + async def _exceptions_async(self, e, attempt): logging.exception("OpenAI async completion") error_code = self._classify_error(e) if attempt == self.max_retries: @@ -357,61 +278,6 @@ class Base(ABC): self.toolcall_session = toolcall_session self.tools = tools - def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): - gen_conf = self._clean_conf(gen_conf) - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - - ans = "" - tk_count = 0 - hist = deepcopy(history) - # Implement exponential backoff retry strategy - for attempt in range(self.max_retries + 1): - history = hist - try: - for _ in range(self.max_rounds + 1): - logging.info(f"{self.tools=}") - response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf) - tk_count += total_token_count_from_response(response) - if any([not response.choices, not response.choices[0].message]): - raise Exception(f"500 response structure error. Response: {response}") - - if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls: - if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content: - ans += "" + response.choices[0].message.reasoning_content + "" - - ans += response.choices[0].message.content - if response.choices[0].finish_reason == "length": - ans = self._length_stop(ans) - - return ans, tk_count - - for tool_call in response.choices[0].message.tool_calls: - logging.info(f"Response {tool_call=}") - name = tool_call.function.name - try: - args = json_repair.loads(tool_call.function.arguments) - tool_response = self.toolcall_session.tool_call(name, args) - history = self._append_history(history, tool_call, tool_response) - ans += self._verbose_tool_use(name, args, tool_response) - except Exception as e: - logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) - ans += self._verbose_tool_use(name, {}, str(e)) - - logging.warning(f"Exceed max rounds: {self.max_rounds}") - history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) - response, token_count = self._chat(history, gen_conf) - ans += response - tk_count += token_count - return ans, tk_count - except Exception as e: - e = self._exceptions(e, attempt) - if e: - return e, tk_count - - assert False, "Shouldn't be here." - async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) if system and history and history[0].get("role") != "system": @@ -466,140 +332,6 @@ class Base(ABC): assert False, "Shouldn't be here." - def chat(self, system, history, gen_conf={}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - gen_conf = self._clean_conf(gen_conf) - - # Implement exponential backoff retry strategy - for attempt in range(self.max_retries + 1): - try: - return self._chat(history, gen_conf, **kwargs) - except Exception as e: - e = self._exceptions(e, attempt) - if e: - return e, 0 - assert False, "Shouldn't be here." - - def _wrap_toolcall_message(self, stream): - final_tool_calls = {} - - for chunk in stream: - for tool_call in chunk.choices[0].delta.tool_calls or []: - index = tool_call.index - - if index not in final_tool_calls: - final_tool_calls[index] = tool_call - - final_tool_calls[index].function.arguments += tool_call.function.arguments - - return final_tool_calls - - def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): - gen_conf = self._clean_conf(gen_conf) - tools = self.tools - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - - total_tokens = 0 - hist = deepcopy(history) - # Implement exponential backoff retry strategy - for attempt in range(self.max_retries + 1): - history = hist - try: - for _ in range(self.max_rounds + 1): - reasoning_start = False - logging.info(f"{tools=}") - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf) - final_tool_calls = {} - answer = "" - for resp in response: - if resp.choices[0].delta.tool_calls: - for tool_call in resp.choices[0].delta.tool_calls or []: - index = tool_call.index - - if index not in final_tool_calls: - if not tool_call.function.arguments: - tool_call.function.arguments = "" - final_tool_calls[index] = tool_call - else: - final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else "" - continue - - if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]): - raise Exception("500 response structure error.") - - if not resp.choices[0].delta.content: - resp.choices[0].delta.content = "" - - if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: - ans = "" - if not reasoning_start: - reasoning_start = True - ans = "" - ans += resp.choices[0].delta.reasoning_content + "" - yield ans - else: - reasoning_start = False - answer += resp.choices[0].delta.content - yield resp.choices[0].delta.content - - tol = total_token_count_from_response(resp) - if not tol: - total_tokens += num_tokens_from_string(resp.choices[0].delta.content) - else: - total_tokens = tol - - finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" - if finish_reason == "length": - yield self._length_stop("") - - if answer: - yield total_tokens - return - - for tool_call in final_tool_calls.values(): - name = tool_call.function.name - try: - args = json_repair.loads(tool_call.function.arguments) - yield self._verbose_tool_use(name, args, "Begin to call...") - tool_response = self.toolcall_session.tool_call(name, args) - history = self._append_history(history, tool_call, tool_response) - yield self._verbose_tool_use(name, args, tool_response) - except Exception as e: - logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) - yield self._verbose_tool_use(name, {}, str(e)) - - logging.warning(f"Exceed max rounds: {self.max_rounds}") - history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf) - for resp in response: - if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]): - raise Exception("500 response structure error.") - if not resp.choices[0].delta.content: - resp.choices[0].delta.content = "" - continue - tol = total_token_count_from_response(resp) - if not tol: - total_tokens += num_tokens_from_string(resp.choices[0].delta.content) - else: - total_tokens = tol - answer += resp.choices[0].delta.content - yield resp.choices[0].delta.content - - yield total_tokens - return - - except Exception as e: - e = self._exceptions(e, attempt) - if e: - yield e - yield total_tokens - return - - assert False, "Shouldn't be here." - async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) tools = self.tools @@ -715,9 +447,10 @@ class Base(ABC): logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) if self.model_name.lower().find("qwq") >= 0: logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly") + final_ans = "" tol_token = 0 - async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs): + async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs): if delta.startswith("") or delta.endswith(""): continue final_ans += delta @@ -754,57 +487,6 @@ class Base(ABC): return e, 0 assert False, "Shouldn't be here." - def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - gen_conf = self._clean_conf(gen_conf) - ans = "" - total_tokens = 0 - try: - for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs): - yield delta_ans - total_tokens += tol - except openai.APIError as e: - yield ans + "\n**ERROR**: " + str(e) - - yield total_tokens - - def _calculate_dynamic_ctx(self, history): - """Calculate dynamic context window size""" - - def count_tokens(text): - """Calculate token count for text""" - # Simple calculation: 1 token per ASCII character - # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.) - total = 0 - for char in text: - if ord(char) < 128: # ASCII characters - total += 1 - else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.) - total += 2 - return total - - # Calculate total tokens for all messages - total_tokens = 0 - for message in history: - content = message.get("content", "") - # Calculate content tokens - content_tokens = count_tokens(content) - # Add role marker token overhead - role_tokens = 4 - total_tokens += content_tokens + role_tokens - - # Apply 1.2x buffer ratio - total_tokens_with_buffer = int(total_tokens * 1.2) - - if total_tokens_with_buffer <= 8192: - ctx_size = 8192 - else: - ctx_multiplier = (total_tokens_with_buffer // 8192) + 1 - ctx_size = ctx_multiplier * 8192 - - return ctx_size - class GptTurbo(Base): _FACTORY_NAME = "OpenAI" @@ -1504,16 +1186,6 @@ class GoogleChat(Base): yield total_tokens -class GPUStackChat(Base): - _FACTORY_NAME = "GPUStack" - - def __init__(self, key=None, model_name="", base_url="", **kwargs): - if not base_url: - raise ValueError("Local llm url cannot be None") - base_url = urljoin(base_url, "v1") - super().__init__(key, model_name, base_url, **kwargs) - - class TokenPonyChat(Base): _FACTORY_NAME = "TokenPony" @@ -1523,15 +1195,6 @@ class TokenPonyChat(Base): super().__init__(key, model_name, base_url, **kwargs) -class DeerAPIChat(Base): - _FACTORY_NAME = "DeerAPI" - - def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs): - if not base_url: - base_url = "https://api.deerapi.com/v1" - super().__init__(key, model_name, base_url, **kwargs) - - class LiteLLMBase(ABC): _FACTORY_NAME = [ "Tongyi-Qianwen", @@ -1562,6 +1225,8 @@ class LiteLLMBase(ABC): "Jiekou.AI", "ZHIPU-AI", "MiniMax", + "DeerAPI", + "GPUStack", ] def __init__(self, key, model_name, base_url=None, **kwargs): @@ -1589,11 +1254,9 @@ class LiteLLMBase(ABC): self.provider_order = json.loads(key).get("provider_order", "") def _get_delay(self): - """Calculate retry delay time""" return self.base_delay * random.uniform(10, 150) def _classify_error(self, error): - """Classify error based on error message content""" error_str = str(error).lower() keywords_mapping = [ @@ -1619,72 +1282,6 @@ class LiteLLMBase(ABC): del gen_conf["max_tokens"] return gen_conf - def _chat(self, history, gen_conf, **kwargs): - logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) - if self.model_name.lower().find("qwen3") >= 0: - kwargs["extra_body"] = {"enable_thinking": False} - - completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf) - response = litellm.completion( - **completion_args, - drop_params=True, - timeout=self.timeout, - ) - # response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs) - if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]): - return "", 0 - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - ans = self._length_stop(ans) - - return ans, total_token_count_from_response(response) - - def _chat_streamly(self, history, gen_conf, **kwargs): - logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) - gen_conf = self._clean_conf(gen_conf) - reasoning_start = False - - completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf) - stop = kwargs.get("stop") - if stop: - completion_args["stop"] = stop - response = litellm.completion( - **completion_args, - drop_params=True, - timeout=self.timeout, - ) - - for resp in response: - if not hasattr(resp, "choices") or not resp.choices: - continue - - delta = resp.choices[0].delta - if not hasattr(delta, "content") or delta.content is None: - delta.content = "" - - if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content: - ans = "" - if not reasoning_start: - reasoning_start = True - ans = "" - ans += delta.reasoning_content + "" - else: - reasoning_start = False - ans = delta.content - - tol = total_token_count_from_response(resp) - if not tol: - tol = num_tokens_from_string(delta.content) - - finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" - if finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - - yield ans, tol - async def async_chat(self, system, history, gen_conf, **kwargs): hist = list(history) if history else [] if system: @@ -1795,22 +1392,7 @@ class LiteLLMBase(ABC): def _should_retry(self, error_code: str) -> bool: return error_code in self._retryable_errors - def _exceptions(self, e, attempt) -> str | None: - logging.exception("OpenAI chat_with_tools") - # Classify the error - error_code = self._classify_error(e) - if attempt == self.max_retries: - error_code = LLMErrorCode.ERROR_MAX_RETRIES - - if self._should_retry(error_code): - delay = self._get_delay() - logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") - time.sleep(delay) - return None - - return f"{ERROR_PREFIX}: {error_code} - {str(e)}" - - async def _exceptions_async(self, e, attempt) -> str | None: + async def _exceptions_async(self, e, attempt): logging.exception("LiteLLMBase async completion") error_code = self._classify_error(e) if attempt == self.max_retries: @@ -1859,71 +1441,7 @@ class LiteLLMBase(ABC): self.toolcall_session = toolcall_session self.tools = tools - def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs): - completion_args = { - "model": self.model_name, - "messages": history, - "api_key": self.api_key, - "num_retries": self.max_retries, - **kwargs, - } - if stream: - completion_args.update( - { - "stream": stream, - } - ) - if tools and self.tools: - completion_args.update( - { - "tools": self.tools, - "tool_choice": "auto", - } - ) - if self.provider in FACTORY_DEFAULT_BASE_URL: - completion_args.update({"api_base": self.base_url}) - elif self.provider == SupportedLiteLLMProvider.Bedrock: - completion_args.pop("api_key", None) - completion_args.pop("api_base", None) - completion_args.update( - { - "aws_access_key_id": self.bedrock_ak, - "aws_secret_access_key": self.bedrock_sk, - "aws_region_name": self.bedrock_region, - } - ) - - if self.provider == SupportedLiteLLMProvider.OpenRouter: - if self.provider_order: - - def _to_order_list(x): - if x is None: - return [] - if isinstance(x, str): - return [s.strip() for s in x.split(",") if s.strip()] - if isinstance(x, (list, tuple)): - return [str(s).strip() for s in x if str(s).strip()] - return [] - - extra_body = {} - provider_cfg = {} - provider_order = _to_order_list(self.provider_order) - provider_cfg["order"] = provider_order - provider_cfg["allow_fallbacks"] = False - extra_body["provider"] = provider_cfg - completion_args.update({"extra_body": extra_body}) - - # Ollama deployments commonly sit behind a reverse proxy that enforces - # Bearer auth. Ensure the Authorization header is set when an API key - # is provided, while respecting any user-supplied headers. #11350 - extra_headers = deepcopy(completion_args.get("extra_headers") or {}) - if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers: - extra_headers["Authorization"] = f"Bearer {self.api_key}" - if extra_headers: - completion_args["extra_headers"] = extra_headers - return completion_args - - def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) @@ -1931,16 +1449,14 @@ class LiteLLMBase(ABC): ans = "" tk_count = 0 hist = deepcopy(history) - - # Implement exponential backoff retry strategy for attempt in range(self.max_retries + 1): - history = deepcopy(hist) # deepcopy is required here + history = deepcopy(hist) try: for _ in range(self.max_rounds + 1): logging.info(f"{self.tools=}") completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf) - response = litellm.completion( + response = await litellm.acompletion( **completion_args, drop_params=True, timeout=self.timeout, @@ -1966,7 +1482,7 @@ class LiteLLMBase(ABC): name = tool_call.function.name try: args = json_repair.loads(tool_call.function.arguments) - tool_response = self.toolcall_session.tool_call(name, args) + tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args) history = self._append_history(history, tool_call, tool_response) ans += self._verbose_tool_use(name, args, tool_response) except Exception as e: @@ -1977,49 +1493,19 @@ class LiteLLMBase(ABC): logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) - response, token_count = self._chat(history, gen_conf) + response, token_count = await self.async_chat("", history, gen_conf) ans += response tk_count += token_count return ans, tk_count except Exception as e: - e = self._exceptions(e, attempt) + e = await self._exceptions_async(e, attempt) if e: return e, tk_count assert False, "Shouldn't be here." - def chat(self, system, history, gen_conf={}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - gen_conf = self._clean_conf(gen_conf) - - # Implement exponential backoff retry strategy - for attempt in range(self.max_retries + 1): - try: - response = self._chat(history, gen_conf, **kwargs) - return response - except Exception as e: - e = self._exceptions(e, attempt) - if e: - return e, 0 - assert False, "Shouldn't be here." - - def _wrap_toolcall_message(self, stream): - final_tool_calls = {} - - for chunk in stream: - for tool_call in chunk.choices[0].delta.tool_calls or []: - index = tool_call.index - - if index not in final_tool_calls: - final_tool_calls[index] = tool_call - - final_tool_calls[index].function.arguments += tool_call.function.arguments - - return final_tool_calls - - def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): gen_conf = self._clean_conf(gen_conf) tools = self.tools if system and history and history[0].get("role") != "system": @@ -2028,16 +1514,15 @@ class LiteLLMBase(ABC): total_tokens = 0 hist = deepcopy(history) - # Implement exponential backoff retry strategy for attempt in range(self.max_retries + 1): - history = deepcopy(hist) # deepcopy is required here + history = deepcopy(hist) try: for _ in range(self.max_rounds + 1): reasoning_start = False logging.info(f"{tools=}") completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) - response = litellm.completion( + response = await litellm.acompletion( **completion_args, drop_params=True, timeout=self.timeout, @@ -2046,7 +1531,7 @@ class LiteLLMBase(ABC): final_tool_calls = {} answer = "" - for resp in response: + async for resp in response: if not hasattr(resp, "choices") or not resp.choices: continue @@ -2082,7 +1567,7 @@ class LiteLLMBase(ABC): if not tol: total_tokens += num_tokens_from_string(delta.content) else: - total_tokens += tol + total_tokens = tol finish_reason = getattr(resp.choices[0], "finish_reason", "") if finish_reason == "length": @@ -2097,31 +1582,25 @@ class LiteLLMBase(ABC): try: args = json_repair.loads(tool_call.function.arguments) yield self._verbose_tool_use(name, args, "Begin to call...") - tool_response = self.toolcall_session.tool_call(name, args) + tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args) history = self._append_history(history, tool_call, tool_response) yield self._verbose_tool_use(name, args, tool_response) except Exception as e: logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - history.append( - { - "role": "tool", - "tool_call_id": tool_call.id, - "content": f"Tool call error: \n{tool_call}\nException:\n{str(e)}", - } - ) + history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) yield self._verbose_tool_use(name, {}, str(e)) logging.warning(f"Exceed max rounds: {self.max_rounds}") history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf) - response = litellm.completion( + response = await litellm.acompletion( **completion_args, drop_params=True, timeout=self.timeout, ) - for resp in response: + async for resp in response: if not hasattr(resp, "choices") or not resp.choices: continue delta = resp.choices[0].delta @@ -2131,14 +1610,14 @@ class LiteLLMBase(ABC): if not tol: total_tokens += num_tokens_from_string(delta.content) else: - total_tokens += tol + total_tokens = tol yield delta.content yield total_tokens return except Exception as e: - e = self._exceptions(e, attempt) + e = await self._exceptions_async(e, attempt) if e: yield e yield total_tokens @@ -2146,53 +1625,71 @@ class LiteLLMBase(ABC): assert False, "Shouldn't be here." - def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - gen_conf = self._clean_conf(gen_conf) - ans = "" - total_tokens = 0 - try: - for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs): - yield delta_ans - total_tokens += tol - except openai.APIError as e: - yield ans + "\n**ERROR**: " + str(e) + def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs): + completion_args = { + "model": self.model_name, + "messages": history, + "api_key": self.api_key, + "num_retries": self.max_retries, + **kwargs, + } + if stream: + completion_args.update( + { + "stream": stream, + } + ) + if tools and self.tools: + completion_args.update( + { + "tools": self.tools, + "tool_choice": "auto", + } + ) + if self.provider in FACTORY_DEFAULT_BASE_URL: + completion_args.update({"api_base": self.base_url}) + elif self.provider == SupportedLiteLLMProvider.Bedrock: + completion_args.pop("api_key", None) + completion_args.pop("api_base", None) + completion_args.update( + { + "aws_access_key_id": self.bedrock_ak, + "aws_secret_access_key": self.bedrock_sk, + "aws_region_name": self.bedrock_region, + } + ) + elif self.provider == SupportedLiteLLMProvider.OpenRouter: + if self.provider_order: - yield total_tokens + def _to_order_list(x): + if x is None: + return [] + if isinstance(x, str): + return [s.strip() for s in x.split(",") if s.strip()] + if isinstance(x, (list, tuple)): + return [str(s).strip() for s in x if str(s).strip()] + return [] - def _calculate_dynamic_ctx(self, history): - """Calculate dynamic context window size""" + extra_body = {} + provider_cfg = {} + provider_order = _to_order_list(self.provider_order) + provider_cfg["order"] = provider_order + provider_cfg["allow_fallbacks"] = False + extra_body["provider"] = provider_cfg + completion_args.update({"extra_body": extra_body}) + elif self.provider == SupportedLiteLLMProvider.GPUStack: + completion_args.update( + { + "api_base": self.base_url, + } + ) - def count_tokens(text): - """Calculate token count for text""" - # Simple calculation: 1 token per ASCII character - # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.) - total = 0 - for char in text: - if ord(char) < 128: # ASCII characters - total += 1 - else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.) - total += 2 - return total - - # Calculate total tokens for all messages - total_tokens = 0 - for message in history: - content = message.get("content", "") - # Calculate content tokens - content_tokens = count_tokens(content) - # Add role marker token overhead - role_tokens = 4 - total_tokens += content_tokens + role_tokens - - # Apply 1.2x buffer ratio - total_tokens_with_buffer = int(total_tokens * 1.2) - - if total_tokens_with_buffer <= 8192: - ctx_size = 8192 - else: - ctx_multiplier = (total_tokens_with_buffer // 8192) + 1 - ctx_size = ctx_multiplier * 8192 - - return ctx_size + # Ollama deployments commonly sit behind a reverse proxy that enforces + # Bearer auth. Ensure the Authorization header is set when an API key + # is provided, while respecting any user-supplied headers. #11350 + extra_headers = deepcopy(completion_args.get("extra_headers") or {}) + if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers: + extra_headers["Authorization"] = f"Bearer {self.api_key}" + if extra_headers: + completion_args["extra_headers"] = extra_headers + return completion_args From 3285f09c92146fd09e28d9536b187f145b6c1821 Mon Sep 17 00:00:00 2001 From: Mustafa Aldemir Date: Mon, 8 Dec 2025 02:50:03 +0100 Subject: [PATCH 09/12] Add huggingface-hub dependency (#11794) ### What problem does this PR solve? When a script has a block like this at the top, then uv run download_deps.py ignores the [project].dependencies in pyproject.toml and only uses that dependencies = [...] list. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- download_deps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/download_deps.py b/download_deps.py index 352e27f91..550a806a5 100644 --- a/download_deps.py +++ b/download_deps.py @@ -5,6 +5,7 @@ # requires-python = ">=3.10" # dependencies = [ # "nltk", +# "huggingface-hub" # ] # /// From 660fa8888b7d9f9a2dd0b8d0c4813fec1982a3f0 Mon Sep 17 00:00:00 2001 From: chanx <1243304602@qq.com> Date: Mon, 8 Dec 2025 10:17:56 +0800 Subject: [PATCH 10/12] Features: Memory page rendering and other bug fixes (#11784) ### What problem does this PR solve? Features: Memory page rendering and other bug fixes - Rendering of the Memory list page - Rendering of the message list page in Memory - Fixed an issue where the empty state was incorrectly displayed when search criteria were applied - Added a web link for the API-Key - modifying the index_mode attribute of the Confluence data source. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) --- web/src/assets/svg/home-icon/memory.svg | 16 + web/src/components/dynamic-form.tsx | 169 +++++++--- web/src/components/empty/constant.tsx | 10 + web/src/components/empty/empty.tsx | 27 +- .../llm-setting-items/llm-form-field.tsx | 12 +- web/src/components/ragflow-form.tsx | 18 +- web/src/constants/llm.ts | 50 +++ web/src/hooks/logic-hooks.ts | 21 +- web/src/hooks/logic-hooks/navigate-hooks.ts | 19 ++ web/src/interfaces/common.ts | 1 + web/src/locales/en.ts | 30 ++ web/src/locales/zh.ts | 6 + web/src/pages/agents/index.tsx | 17 +- .../configuration/common-item.tsx | 79 +++-- web/src/pages/datasets/index.tsx | 17 +- web/src/pages/memories/add-or-edit-modal.tsx | 75 +++++ web/src/pages/memories/constants/index.tsx | 41 +++ web/src/pages/memories/hooks.ts | 288 ++++++++++++++++++ web/src/pages/memories/index.tsx | 163 ++++++++++ web/src/pages/memories/interface.ts | 121 ++++++++ web/src/pages/memories/memory-card.tsx | 32 ++ web/src/pages/memories/memory-dropdown.tsx | 74 +++++ web/src/pages/memory/constant.tsx | 3 + .../pages/memory/hooks/use-memory-messages.ts | 59 ++++ .../pages/memory/hooks/use-memory-setting.ts | 59 ++++ web/src/pages/memory/index.tsx | 17 ++ web/src/pages/memory/memory-message/index.tsx | 51 ++++ .../pages/memory/memory-message/interface.ts | 19 ++ .../memory/memory-message/message-table.tsx | 225 ++++++++++++++ web/src/pages/memory/memory-setting/index.tsx | 13 + web/src/pages/memory/sidebar/hooks.tsx | 17 ++ web/src/pages/memory/sidebar/index.tsx | 88 ++++++ web/src/pages/next-chats/index.tsx | 17 +- web/src/pages/next-searches/hooks.ts | 46 +-- web/src/pages/next-searches/index.tsx | 71 +++-- .../component/confluence-token-field.tsx | 97 +++--- .../user-setting/data-source/contant.tsx | 6 +- .../data-source-detail-page/index.tsx | 4 +- .../setting-model/components/llm-header.tsx | 34 +++ .../setting-model/components/un-add-model.tsx | 21 +- .../modal/api-key-modal/index.tsx | 3 +- .../modal/azure-openai-modal/index.tsx | 3 +- .../modal/bedrock-modal/index.tsx | 3 +- .../modal/fish-audio-modal/index.tsx | 3 +- .../modal/google-modal/index.tsx | 3 +- .../modal/hunyuan-modal/index.tsx | 3 +- .../modal/next-tencent-modal/index.tsx | 3 +- .../modal/ollama-modal/index.tsx | 7 +- .../setting-model/modal/spark-modal/index.tsx | 3 +- .../modal/volcengine-modal/index.tsx | 3 +- .../setting-model/modal/yiyan-modal/index.tsx | 3 +- web/src/routes.ts | 40 +++ web/src/services/memory-service.ts | 43 +++ web/src/services/search-service.ts | 5 +- web/src/utils/api.ts | 7 + 55 files changed, 2047 insertions(+), 218 deletions(-) create mode 100644 web/src/assets/svg/home-icon/memory.svg create mode 100644 web/src/pages/memories/add-or-edit-modal.tsx create mode 100644 web/src/pages/memories/constants/index.tsx create mode 100644 web/src/pages/memories/hooks.ts create mode 100644 web/src/pages/memories/index.tsx create mode 100644 web/src/pages/memories/interface.ts create mode 100644 web/src/pages/memories/memory-card.tsx create mode 100644 web/src/pages/memories/memory-dropdown.tsx create mode 100644 web/src/pages/memory/constant.tsx create mode 100644 web/src/pages/memory/hooks/use-memory-messages.ts create mode 100644 web/src/pages/memory/hooks/use-memory-setting.ts create mode 100644 web/src/pages/memory/index.tsx create mode 100644 web/src/pages/memory/memory-message/index.tsx create mode 100644 web/src/pages/memory/memory-message/interface.ts create mode 100644 web/src/pages/memory/memory-message/message-table.tsx create mode 100644 web/src/pages/memory/memory-setting/index.tsx create mode 100644 web/src/pages/memory/sidebar/hooks.tsx create mode 100644 web/src/pages/memory/sidebar/index.tsx create mode 100644 web/src/pages/user-setting/setting-model/components/llm-header.tsx create mode 100644 web/src/services/memory-service.ts diff --git a/web/src/assets/svg/home-icon/memory.svg b/web/src/assets/svg/home-icon/memory.svg new file mode 100644 index 000000000..f50d755f4 --- /dev/null +++ b/web/src/assets/svg/home-icon/memory.svg @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + diff --git a/web/src/components/dynamic-form.tsx b/web/src/components/dynamic-form.tsx index ca0b08763..2fdc55af0 100644 --- a/web/src/components/dynamic-form.tsx +++ b/web/src/components/dynamic-form.tsx @@ -1,6 +1,13 @@ import { zodResolver } from '@hookform/resolvers/zod'; -import { forwardRef, useEffect, useImperativeHandle, useMemo } from 'react'; import { + forwardRef, + useEffect, + useImperativeHandle, + useMemo, + useState, +} from 'react'; +import { + ControllerRenderProps, DefaultValues, FieldValues, SubmitHandler, @@ -26,6 +33,7 @@ import { Textarea } from '@/components/ui/textarea'; import { cn } from '@/lib/utils'; import { t } from 'i18next'; import { Loader } from 'lucide-react'; +import { MultiSelect, MultiSelectOptionType } from './ui/multi-select'; // Field type enumeration export enum FormFieldType { @@ -35,14 +43,17 @@ export enum FormFieldType { Number = 'number', Textarea = 'textarea', Select = 'select', + MultiSelect = 'multi-select', Checkbox = 'checkbox', Tag = 'tag', + Custom = 'custom', } // Field configuration interface export interface FormFieldConfig { name: string; label: string; + hideLabel?: boolean; type: FormFieldType; hidden?: boolean; required?: boolean; @@ -57,7 +68,7 @@ export interface FormFieldConfig { max?: number; message?: string; }; - render?: (fieldProps: any) => React.ReactNode; + render?: (fieldProps: ControllerRenderProps) => React.ReactNode; horizontal?: boolean; onChange?: (value: any) => void; tooltip?: React.ReactNode; @@ -78,10 +89,10 @@ interface DynamicFormProps { className?: string; children?: React.ReactNode; defaultValues?: DefaultValues; - onFieldUpdate?: ( - fieldName: string, - updatedField: Partial, - ) => void; + // onFieldUpdate?: ( + // fieldName: string, + // updatedField: Partial, + // ) => void; labelClassName?: string; } @@ -92,6 +103,10 @@ export interface DynamicFormRef { reset: (values?: any) => void; watch: (field: string, callback: (value: any) => void) => () => void; updateFieldType: (fieldName: string, newType: FormFieldType) => void; + onFieldUpdate: ( + fieldName: string, + newFieldProperties: Partial, + ) => void; } // Generate Zod validation schema based on field configurations @@ -110,6 +125,14 @@ const generateSchema = (fields: FormFieldConfig[]): ZodSchema => { case FormFieldType.Email: fieldSchema = z.string().email('Please enter a valid email address'); break; + case FormFieldType.MultiSelect: + fieldSchema = z.array(z.string()).optional(); + if (field.required) { + fieldSchema = z.array(z.string()).min(1, { + message: `${field.label} is required`, + }); + } + break; case FormFieldType.Number: fieldSchema = z.coerce.number(); if (field.validation?.min !== undefined) { @@ -275,7 +298,10 @@ const generateDefaultValues = ( defaultValues[field.name] = field.defaultValue; } else if (field.type === FormFieldType.Checkbox) { defaultValues[field.name] = false; - } else if (field.type === FormFieldType.Tag) { + } else if ( + field.type === FormFieldType.Tag || + field.type === FormFieldType.MultiSelect + ) { defaultValues[field.name] = []; } else { defaultValues[field.name] = ''; @@ -291,17 +317,21 @@ const DynamicForm = { Root: forwardRef( ( { - fields, + fields: originFields, onSubmit, className = '', children, defaultValues: formDefaultValues = {} as DefaultValues, - onFieldUpdate, + // onFieldUpdate, labelClassName, }: DynamicFormProps, ref: React.Ref, ) => { // Generate validation schema and default values + const [fields, setFields] = useState(originFields); + useMemo(() => { + setFields(originFields); + }, [originFields]); const schema = useMemo(() => generateSchema(fields), [fields]); const defaultValues = useMemo(() => { @@ -406,43 +436,54 @@ const DynamicForm = { }, [fields, form]); // Expose form methods via ref - useImperativeHandle(ref, () => ({ - submit: () => form.handleSubmit(onSubmit)(), - getValues: () => form.getValues(), - reset: (values?: T) => { - if (values) { - form.reset(values); - } else { - form.reset(); - } - }, - setError: form.setError, - clearErrors: form.clearErrors, - trigger: form.trigger, - watch: (field: string, callback: (value: any) => void) => { - const { unsubscribe } = form.watch((values: any) => { - if (values && values[field] !== undefined) { - callback(values[field]); - } - }); - return unsubscribe; - }, - - onFieldUpdate: ( - fieldName: string, - updatedField: Partial, - ) => { - setTimeout(() => { - if (onFieldUpdate) { - onFieldUpdate(fieldName, updatedField); + useImperativeHandle( + ref, + () => ({ + submit: () => form.handleSubmit(onSubmit)(), + getValues: () => form.getValues(), + reset: (values?: T) => { + if (values) { + form.reset(values); } else { - console.warn( - 'onFieldUpdate prop is not provided. Cannot update field type.', - ); + form.reset(); } - }, 0); - }, - })); + }, + setError: form.setError, + clearErrors: form.clearErrors, + trigger: form.trigger, + watch: (field: string, callback: (value: any) => void) => { + const { unsubscribe } = form.watch((values: any) => { + if (values && values[field] !== undefined) { + callback(values[field]); + } + }); + return unsubscribe; + }, + + onFieldUpdate: ( + fieldName: string, + updatedField: Partial, + ) => { + setFields((prevFields: any) => + prevFields.map((field: any) => + field.name === fieldName + ? { ...field, ...updatedField } + : field, + ), + ); + // setTimeout(() => { + // if (onFieldUpdate) { + // onFieldUpdate(fieldName, updatedField); + // } else { + // console.warn( + // 'onFieldUpdate prop is not provided. Cannot update field type.', + // ); + // } + // }, 0); + }, + }), + [form], + ); useEffect(() => { if (formDefaultValues && Object.keys(formDefaultValues).length > 0) { @@ -459,6 +500,9 @@ const DynamicForm = { // Render form fields const renderField = (field: FormFieldConfig) => { if (field.render) { + if (field.type === FormFieldType.Custom && field.hideLabel) { + return
{field.render({})}
; + } return ( ); + case FormFieldType.MultiSelect: + return ( + + {(fieldProps) => { + console.log('multi select value', fieldProps); + const finalFieldProps = { + ...fieldProps, + onValueChange: (value: string[]) => { + if (fieldProps.onChange) { + fieldProps.onChange(value); + } + field.onChange?.(value); + }, + }; + return ( + { + // console.log(data); + // field.onChange?.(data); + // }} + options={field.options as MultiSelectOptionType[]} + /> + ); + }} + + ); + case FormFieldType.Checkbox: return ( , title: t('empty.agentTitle'), + notFound: t('empty.notFoundAgent'), }, [EmptyCardType.Dataset]: { icon: , title: t('empty.datasetTitle'), + notFound: t('empty.notFoundDataset'), }, [EmptyCardType.Chat]: { icon: , title: t('empty.chatTitle'), + notFound: t('empty.notFoundChat'), }, [EmptyCardType.Search]: { icon: , title: t('empty.searchTitle'), + notFound: t('empty.notFoundSearch'), + }, + [EmptyCardType.Memory]: { + icon: , + title: t('empty.memoryTitle'), + notFound: t('empty.notFoundMemory'), }, }; diff --git a/web/src/components/empty/empty.tsx b/web/src/components/empty/empty.tsx index abf28dd3a..3623f43d0 100644 --- a/web/src/components/empty/empty.tsx +++ b/web/src/components/empty/empty.tsx @@ -76,9 +76,10 @@ export const EmptyAppCard = (props: { onClick?: () => void; showIcon?: boolean; className?: string; + isSearch?: boolean; size?: 'small' | 'large'; }) => { - const { type, showIcon, className } = props; + const { type, showIcon, className, isSearch } = props; let defaultClass = ''; let style = {}; switch (props.size) { @@ -95,19 +96,29 @@ export const EmptyAppCard = (props: { break; } return ( -
+
-
- -
+ {!isSearch && ( +
+ +
+ )}
); diff --git a/web/src/components/llm-setting-items/llm-form-field.tsx b/web/src/components/llm-setting-items/llm-form-field.tsx index 594c17df4..b4106ed0e 100644 --- a/web/src/components/llm-setting-items/llm-form-field.tsx +++ b/web/src/components/llm-setting-items/llm-form-field.tsx @@ -9,13 +9,19 @@ export type LLMFormFieldProps = { name?: string; }; -export function LLMFormField({ options, name }: LLMFormFieldProps) { - const { t } = useTranslation(); - +export const useModelOptions = () => { const modelOptions = useComposeLlmOptionsByModelTypes([ LlmModelType.Chat, LlmModelType.Image2text, ]); + return { + modelOptions, + }; +}; + +export function LLMFormField({ options, name }: LLMFormFieldProps) { + const { t } = useTranslation(); + const { modelOptions } = useModelOptions(); return ( diff --git a/web/src/components/ragflow-form.tsx b/web/src/components/ragflow-form.tsx index c59776824..5f21980b0 100644 --- a/web/src/components/ragflow-form.tsx +++ b/web/src/components/ragflow-form.tsx @@ -53,14 +53,16 @@ export function RAGFlowFormItem({ {label} )} - - {typeof children === 'function' - ? children(field) - : isValidElement(children) - ? cloneElement(children, { ...field }) - : children} - - +
+ + {typeof children === 'function' + ? children(field) + : isValidElement(children) + ? cloneElement(children, { ...field }) + : children} + + +
)} /> diff --git a/web/src/constants/llm.ts b/web/src/constants/llm.ts index c7757f805..a5f5e4b82 100644 --- a/web/src/constants/llm.ts +++ b/web/src/constants/llm.ts @@ -126,3 +126,53 @@ export const IconMap = { [LLMFactory.JiekouAI]: 'jiekouai', [LLMFactory.Builtin]: 'builtin', }; + +export const APIMapUrl = { + [LLMFactory.OpenAI]: 'https://platform.openai.com/api-keys', + [LLMFactory.Anthropic]: 'https://console.anthropic.com/settings/keys', + [LLMFactory.Gemini]: 'https://aistudio.google.com/app/apikey', + [LLMFactory.DeepSeek]: 'https://platform.deepseek.com/api_keys', + [LLMFactory.Moonshot]: 'https://platform.moonshot.cn/console/api-keys', + [LLMFactory.TongYiQianWen]: 'https://dashscope.console.aliyun.com/apiKey', + [LLMFactory.ZhipuAI]: 'https://open.bigmodel.cn/usercenter/apikeys', + [LLMFactory.XAI]: 'https://x.ai/api/', + [LLMFactory.HuggingFace]: 'https://huggingface.co/settings/tokens', + [LLMFactory.Mistral]: 'https://console.mistral.ai/api-keys/', + [LLMFactory.Cohere]: 'https://dashboard.cohere.com/api-keys', + [LLMFactory.BaiduYiYan]: 'https://wenxin.baidu.com/user/key', + [LLMFactory.Meituan]: 'https://longcat.chat/platform/api_keys', + [LLMFactory.Bedrock]: + 'https://us-east-2.console.aws.amazon.com/bedrock/home#/api-keys', + [LLMFactory.AzureOpenAI]: + 'https://portal.azure.com/#create/Microsoft.CognitiveServicesOpenAI', + [LLMFactory.OpenRouter]: 'https://openrouter.ai/keys', + [LLMFactory.XunFeiSpark]: 'https://console.xfyun.cn/services/cbm', + [LLMFactory.MiniMax]: + 'https://platform.minimaxi.com/user-center/basic-information', + [LLMFactory.Groq]: 'https://console.groq.com/keys', + [LLMFactory.NVIDIA]: 'https://build.nvidia.com/settings/api-keys', + [LLMFactory.SILICONFLOW]: 'https://cloud.siliconflow.cn/account/ak', + [LLMFactory.Replicate]: 'https://replicate.com/account/api-tokens', + [LLMFactory.VolcEngine]: 'https://console.volcengine.com/ark', + [LLMFactory.Jina]: 'https://jina.ai/embeddings/', + [LLMFactory.TencentHunYuan]: + 'https://console.cloud.tencent.com/hunyuan/api-key', + [LLMFactory.TencentCloud]: 'https://console.cloud.tencent.com/cam/capi', + [LLMFactory.ModelScope]: 'https://modelscope.cn/my/myaccesstoken', + [LLMFactory.GoogleCloud]: 'https://console.cloud.google.com/apis/credentials', + [LLMFactory.FishAudio]: 'https://fish.audio/app/api-keys/', + [LLMFactory.GiteeAI]: + 'https://ai.gitee.com/hhxzgrjn/dashboard/settings/tokens', + [LLMFactory.StepFun]: 'https://platform.stepfun.com/interface-key', + [LLMFactory.BaiChuan]: 'https://platform.baichuan-ai.com/console/apikey', + [LLMFactory.PPIO]: 'https://ppio.com/settings/key-management', + [LLMFactory.VoyageAI]: 'https://dash.voyageai.com/api-keys', + [LLMFactory.TogetherAI]: 'https://api.together.xyz/settings/api-keys', + [LLMFactory.NovitaAI]: 'https://novita.ai/dashboard/key', + [LLMFactory.Upstage]: 'https://console.upstage.ai/api-keys', + [LLMFactory.CometAPI]: 'https://api.cometapi.com/console/token', + [LLMFactory.Ai302]: 'https://302.ai/apis/list', + [LLMFactory.DeerAPI]: 'https://api.deerapi.com/token', + [LLMFactory.TokenPony]: 'https://www.tokenpony.cn/#/user/keys', + [LLMFactory.DeepInfra]: 'https://deepinfra.com/dash/api_keys', +}; diff --git a/web/src/hooks/logic-hooks.ts b/web/src/hooks/logic-hooks.ts index 73b389fd7..4fa4ef218 100644 --- a/web/src/hooks/logic-hooks.ts +++ b/web/src/hooks/logic-hooks.ts @@ -1,6 +1,7 @@ import { Authorization } from '@/constants/authorization'; import { MessageType } from '@/constants/chat'; import { LanguageTranslationMap } from '@/constants/common'; +import { Pagination } from '@/interfaces/common'; import { ResponseType } from '@/interfaces/database/base'; import { IAnswer, @@ -12,7 +13,7 @@ import { IKnowledgeFile } from '@/interfaces/database/knowledge'; import api from '@/utils/api'; import { getAuthorization } from '@/utils/authorization-util'; import { buildMessageUuid } from '@/utils/chat'; -import { PaginationProps, message } from 'antd'; +import { message } from 'antd'; import { FormInstance } from 'antd/lib'; import axios from 'axios'; import { EventSourceParserStream } from 'eventsource-parser/stream'; @@ -71,8 +72,8 @@ export const useGetPaginationWithRouter = () => { size: pageSize, } = useSetPaginationParams(); - const onPageChange: PaginationProps['onChange'] = useCallback( - (pageNumber: number, pageSize: number) => { + const onPageChange: Pagination['onChange'] = useCallback( + (pageNumber: number, pageSize?: number) => { setPaginationParams(pageNumber, pageSize); }, [setPaginationParams], @@ -88,7 +89,7 @@ export const useGetPaginationWithRouter = () => { [setPaginationParams, pageSize], ); - const pagination: PaginationProps = useMemo(() => { + const pagination: Pagination = useMemo(() => { return { showQuickJumper: true, total: 0, @@ -97,7 +98,7 @@ export const useGetPaginationWithRouter = () => { pageSize: pageSize, pageSizeOptions: [1, 2, 10, 20, 50, 100], onChange: onPageChange, - showTotal: (total) => `${t('total')} ${total}`, + showTotal: (total: number) => `${t('total')} ${total}`, }; }, [t, onPageChange, page, pageSize]); @@ -109,7 +110,7 @@ export const useGetPaginationWithRouter = () => { export const useHandleSearchChange = () => { const [searchString, setSearchString] = useState(''); - const { setPagination } = useGetPaginationWithRouter(); + const { pagination, setPagination } = useGetPaginationWithRouter(); const handleInputChange = useCallback( (e: React.ChangeEvent) => { const value = e.target.value; @@ -119,21 +120,21 @@ export const useHandleSearchChange = () => { [setPagination], ); - return { handleInputChange, searchString }; + return { handleInputChange, searchString, pagination, setPagination }; }; export const useGetPagination = () => { const [pagination, setPagination] = useState({ page: 1, pageSize: 10 }); const { t } = useTranslate('common'); - const onPageChange: PaginationProps['onChange'] = useCallback( + const onPageChange: Pagination['onChange'] = useCallback( (pageNumber: number, pageSize: number) => { setPagination({ page: pageNumber, pageSize }); }, [], ); - const currentPagination: PaginationProps = useMemo(() => { + const currentPagination: Pagination = useMemo(() => { return { showQuickJumper: true, total: 0, @@ -142,7 +143,7 @@ export const useGetPagination = () => { pageSize: pagination.pageSize, pageSizeOptions: [1, 2, 10, 20, 50, 100], onChange: onPageChange, - showTotal: (total) => `${t('total')} ${total}`, + showTotal: (total: number) => `${t('total')} ${total}`, }; }, [t, onPageChange, pagination]); diff --git a/web/src/hooks/logic-hooks/navigate-hooks.ts b/web/src/hooks/logic-hooks/navigate-hooks.ts index 2f4f770e0..3fd62b689 100644 --- a/web/src/hooks/logic-hooks/navigate-hooks.ts +++ b/web/src/hooks/logic-hooks/navigate-hooks.ts @@ -25,6 +25,17 @@ export const useNavigatePage = () => { [navigate], ); + const navigateToMemoryList = useCallback( + ({ isCreate = false }: { isCreate?: boolean }) => { + if (isCreate) { + navigate(Routes.Memories + '?isCreate=true'); + } else { + navigate(Routes.Memories); + } + }, + [navigate], + ); + const navigateToDataset = useCallback( (id: string) => () => { // navigate(`${Routes.DatasetBase}${Routes.DataSetOverview}/${id}`); @@ -105,6 +116,12 @@ export const useNavigatePage = () => { }, [navigate], ); + const navigateToMemory = useCallback( + (id: string) => () => { + navigate(`${Routes.Memory}${Routes.MemoryMessage}/${id}`); + }, + [navigate], + ); const navigateToChunkParsedResult = useCallback( (id: string, knowledgeId?: string) => () => { @@ -196,5 +213,7 @@ export const useNavigatePage = () => { navigateToDataflowResult, navigateToDataFile, navigateToDataSourceDetail, + navigateToMemory, + navigateToMemoryList, }; }; diff --git a/web/src/interfaces/common.ts b/web/src/interfaces/common.ts index 21553d653..771ff5aa6 100644 --- a/web/src/interfaces/common.ts +++ b/web/src/interfaces/common.ts @@ -2,6 +2,7 @@ export interface Pagination { current: number; pageSize: number; total: number; + onChange?: (page: number, pageSize: number) => void; } export interface BaseState { diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 479cafb64..215d8b9f4 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -99,6 +99,29 @@ export default { search: 'Search', welcome: 'Welcome to', dataset: 'Dataset', + Memories: 'Memory', + }, + memory: { + memory: 'Memory', + createMemory: 'Create Memory', + name: 'Name', + memoryNamePlaceholder: 'memory name', + memoryType: 'Memory type', + embeddingModel: 'Embedding model', + selectModel: 'Select model', + llm: 'LLM', + }, + memoryDetail: { + messages: { + sessionId: 'Session ID', + agent: 'Agent', + type: 'Type', + validDate: 'Valid date', + forgetAt: 'Forget at', + source: 'Source', + enable: 'Enable', + action: 'Action', + }, }, knowledgeList: { welcome: 'Welcome back', @@ -2044,14 +2067,21 @@ Important structured information may include: names, dates, locations, events, k delFilesContent: 'Selected {{count}} files', delChat: 'Delete chat', delMember: 'Delete member', + delMemory: 'Delete memory', }, empty: { noMCP: 'No MCP servers available', agentTitle: 'No agent app created yet', + notFoundAgent: 'Agent app not found', datasetTitle: 'No dataset created yet', + notFoundDataset: 'Dataset not found', chatTitle: 'No chat app created yet', + notFoundChat: 'Chat app not found', searchTitle: 'No search app created yet', + notFoundSearch: 'Search app not found', + memoryTitle: 'No memory created yet', + notFoundMemory: 'Memory not found', addNow: 'Add Now', }, diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 4179557c3..5d114594c 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -1900,9 +1900,15 @@ Tokenizer 会根据所选方式将内容存储为对应的数据结构。`, empty: { noMCP: '暂无 MCP 服务器可用', agentTitle: '尚未创建智能体', + notFoundAgent: '未查询到智能体', datasetTitle: '尚未创建数据集', + notFoundDataset: '未查询到数据集', chatTitle: '尚未创建聊天应用', + notFoundChat: '未查询到聊天应用', searchTitle: '尚未创建搜索应用', + notFoundSearch: '未查询到搜索应用', + memoryTitle: '尚未创建记忆', + notFoundMemory: '未查询到记忆', addNow: '立即添加', }, }, diff --git a/web/src/pages/agents/index.tsx b/web/src/pages/agents/index.tsx index 3600d3bd6..fd091c138 100644 --- a/web/src/pages/agents/index.tsx +++ b/web/src/pages/agents/index.tsx @@ -81,19 +81,20 @@ export default function Agents() { }, [isCreate, showCreatingModal, searchUrl, setSearchUrl]); return ( <> - {(!data?.length || data?.length <= 0) && ( + {(!data?.length || data?.length <= 0) && !searchString && (
showCreatingModal()} />
)}
- {!!data?.length && ( + {(!!data?.length || searchString) && ( <>
+ {(!data?.length || data?.length <= 0) && searchString && ( +
+ showCreatingModal()} + /> +
+ )}
{data.map((x) => { diff --git a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx index c6d18af13..a27dcfe6d 100644 --- a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx +++ b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx @@ -12,7 +12,7 @@ import { Switch } from '@/components/ui/switch'; import { useTranslate } from '@/hooks/common-hooks'; import { cn } from '@/lib/utils'; import { useMemo, useState } from 'react'; -import { useFormContext } from 'react-hook-form'; +import { FieldValues, useFormContext } from 'react-hook-form'; import { useHandleKbEmbedding, useHasParsedDocument, @@ -65,17 +65,59 @@ export function ChunkMethodItem(props: IProps) { /> ); } -export function EmbeddingModelItem({ line = 1, isEdit }: IProps) { + +export const EmbeddingSelect = ({ + isEdit, + field, + name, +}: { + isEdit: boolean; + field: FieldValues; + name?: string; +}) => { const { t } = useTranslate('knowledgeConfiguration'); const form = useFormContext(); const embeddingModelOptions = useSelectEmbeddingModelOptions(); const { handleChange } = useHandleKbEmbedding(); const disabled = useHasParsedDocument(isEdit); const oldValue = useMemo(() => { - const embdStr = form.getValues('embd_id'); + const embdStr = form.getValues(name || 'embd_id'); return embdStr || ''; }, [form]); const [loading, setLoading] = useState(false); + return ( + + { + field.onChange(value); + if (isEdit && disabled) { + setLoading(true); + const res = await handleChange({ + embed_id: value, + callback: field.onChange, + }); + if (res.code !== 0) { + field.onChange(oldValue); + } + setLoading(false); + } + }} + value={field.value} + options={embeddingModelOptions} + placeholder={t('embeddingModelPlaceholder')} + /> + + ); +}; + +export function EmbeddingModelItem({ line = 1, isEdit }: IProps) { + const { t } = useTranslate('knowledgeConfiguration'); + const form = useFormContext(); return ( <> - - { - field.onChange(value); - if (isEdit && disabled) { - setLoading(true); - const res = await handleChange({ - embed_id: value, - callback: field.onChange, - }); - if (res.code !== 0) { - field.onChange(oldValue); - } - setLoading(false); - } - }} - value={field.value} - options={embeddingModelOptions} - placeholder={t('embeddingModelPlaceholder')} - triggerClassName="!bg-bg-base" - /> - +
diff --git a/web/src/pages/datasets/index.tsx b/web/src/pages/datasets/index.tsx index c0515fc99..110a1e485 100644 --- a/web/src/pages/datasets/index.tsx +++ b/web/src/pages/datasets/index.tsx @@ -70,18 +70,19 @@ export default function Datasets() { return ( <>
- {(!kbs?.length || kbs?.length <= 0) && ( + {(!kbs?.length || kbs?.length <= 0) && !searchString && (
showModal()} />
)} - {!!kbs?.length && ( + {(!!kbs?.length || searchString) && ( <> + {(!kbs?.length || kbs?.length <= 0) && searchString && ( +
+ showModal()} + /> +
+ )}
{kbs.map((dataset) => { diff --git a/web/src/pages/memories/add-or-edit-modal.tsx b/web/src/pages/memories/add-or-edit-modal.tsx new file mode 100644 index 000000000..e5ec1082e --- /dev/null +++ b/web/src/pages/memories/add-or-edit-modal.tsx @@ -0,0 +1,75 @@ +import { DynamicForm, DynamicFormRef } from '@/components/dynamic-form'; +import { useModelOptions } from '@/components/llm-setting-items/llm-form-field'; +import { HomeIcon } from '@/components/svg-icon'; +import { Modal } from '@/components/ui/modal/modal'; +import { t } from 'i18next'; +import { useCallback, useEffect, useState } from 'react'; +import { createMemoryFields } from './constants'; +import { IMemory } from './interface'; + +type IProps = { + open: boolean; + onClose: () => void; + onSubmit?: (data: any) => void; + initialMemory: IMemory; + loading?: boolean; +}; +export const AddOrEditModal = (props: IProps) => { + const { open, onClose, onSubmit, initialMemory } = props; + // const [fields, setFields] = useState(createMemoryFields); + // const formRef = useRef(null); + const [formInstance, setFormInstance] = useState(null); + + const formCallbackRef = useCallback((node: DynamicFormRef | null) => { + if (node) { + // formRef.current = node; + setFormInstance(node); + } + }, []); + const { modelOptions } = useModelOptions(); + + useEffect(() => { + if (initialMemory && initialMemory.id) { + formInstance?.onFieldUpdate('memory_type', { hidden: true }); + formInstance?.onFieldUpdate('embedding', { hidden: true }); + formInstance?.onFieldUpdate('llm', { hidden: true }); + } else { + formInstance?.onFieldUpdate('llm', { options: modelOptions as any }); + } + }, [modelOptions, formInstance, initialMemory]); + + return ( + +
+ +
+ {t('memory.createMemory')} +
+ } + showfooter={false} + confirmLoading={props.loading} + > + {}} + defaultValues={initialMemory} + > +
+ + { + onSubmit?.(data); + }} + /> +
+
+ + ); +}; diff --git a/web/src/pages/memories/constants/index.tsx b/web/src/pages/memories/constants/index.tsx new file mode 100644 index 000000000..004298465 --- /dev/null +++ b/web/src/pages/memories/constants/index.tsx @@ -0,0 +1,41 @@ +import { FormFieldConfig, FormFieldType } from '@/components/dynamic-form'; +import { EmbeddingSelect } from '@/pages/dataset/dataset-setting/configuration/common-item'; +import { t } from 'i18next'; + +export const createMemoryFields = [ + { + name: 'memory_name', + label: t('memory.name'), + placeholder: t('memory.memoryNamePlaceholder'), + required: true, + }, + { + name: 'memory_type', + label: t('memory.memoryType'), + type: FormFieldType.MultiSelect, + placeholder: t('memory.descriptionPlaceholder'), + options: [ + { label: 'Raw', value: 'raw' }, + { label: 'Semantic', value: 'semantic' }, + { label: 'Episodic', value: 'episodic' }, + { label: 'Procedural', value: 'procedural' }, + ], + required: true, + }, + { + name: 'embedding', + label: t('memory.embeddingModel'), + placeholder: t('memory.selectModel'), + required: true, + // hideLabel: true, + // type: 'custom', + render: (field) => , + }, + { + name: 'llm', + label: t('memory.llm'), + placeholder: t('memory.selectModel'), + required: true, + type: FormFieldType.Select, + }, +] as FormFieldConfig[]; diff --git a/web/src/pages/memories/hooks.ts b/web/src/pages/memories/hooks.ts new file mode 100644 index 000000000..d1fa157e3 --- /dev/null +++ b/web/src/pages/memories/hooks.ts @@ -0,0 +1,288 @@ +// src/pages/next-memoryes/hooks.ts + +import message from '@/components/ui/message'; +import { useSetModalState } from '@/hooks/common-hooks'; +import { useHandleSearchChange } from '@/hooks/logic-hooks'; +import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks'; +import memoryService, { updateMemoryById } from '@/services/memory-service'; +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; +import { useDebounce } from 'ahooks'; +import { useCallback, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useParams, useSearchParams } from 'umi'; +import { + CreateMemoryResponse, + DeleteMemoryProps, + DeleteMemoryResponse, + ICreateMemoryProps, + IMemory, + IMemoryAppDetailProps, + MemoryDetailResponse, + MemoryListResponse, +} from './interface'; + +export const useCreateMemory = () => { + const { t } = useTranslation(); + + const { + data, + isError, + mutateAsync: createMemoryMutation, + } = useMutation({ + mutationKey: ['createMemory'], + mutationFn: async (props) => { + const { data: response } = await memoryService.createMemory(props); + if (response.code !== 0) { + throw new Error(response.message || 'Failed to create memory'); + } + return response.data; + }, + onSuccess: () => { + message.success(t('message.created')); + }, + onError: (error) => { + message.error(t('message.error', { error: error.message })); + }, + }); + + const createMemory = useCallback( + (props: ICreateMemoryProps) => { + return createMemoryMutation(props); + }, + [createMemoryMutation], + ); + + return { data, isError, createMemory }; +}; + +export const useFetchMemoryList = () => { + const { handleInputChange, searchString, pagination, setPagination } = + useHandleSearchChange(); + const debouncedSearchString = useDebounce(searchString, { wait: 500 }); + const { data, isLoading, isError, refetch } = useQuery< + MemoryListResponse, + Error + >({ + queryKey: [ + 'memoryList', + { + debouncedSearchString, + ...pagination, + }, + ], + queryFn: async () => { + const { data: response } = await memoryService.getMemoryList( + { + params: { + keywords: debouncedSearchString, + page_size: pagination.pageSize, + page: pagination.current, + }, + data: {}, + }, + true, + ); + if (response.code !== 0) { + throw new Error(response.message || 'Failed to fetch memory list'); + } + console.log(response); + return response; + }, + }); + + // const setMemoryListParams = (newParams: MemoryListParams) => { + // setMemoryParams((prevParams) => ({ + // ...prevParams, + // ...newParams, + // })); + // }; + + return { + data, + isLoading, + isError, + pagination, + searchString, + handleInputChange, + setPagination, + refetch, + }; +}; + +export const useFetchMemoryDetail = (tenantId?: string) => { + const { id } = useParams(); + + const [memoryParams] = useSearchParams(); + const shared_id = memoryParams.get('shared_id'); + const memoryId = id || shared_id; + let param: { id: string | null; tenant_id?: string } = { + id: memoryId, + }; + if (shared_id) { + param = { + id: memoryId, + tenant_id: tenantId, + }; + } + const fetchMemoryDetailFunc = shared_id + ? memoryService.getMemoryDetailShare + : memoryService.getMemoryDetail; + + const { data, isLoading, isError } = useQuery({ + queryKey: ['memoryDetail', memoryId], + enabled: !shared_id || !!tenantId, + queryFn: async () => { + const { data: response } = await fetchMemoryDetailFunc(param); + if (response.code !== 0) { + throw new Error(response.message || 'Failed to fetch memory detail'); + } + return response; + }, + }); + + return { data: data?.data, isLoading, isError }; +}; + +export const useDeleteMemory = () => { + const { t } = useTranslation(); + const queryClient = useQueryClient(); + const { + data, + isError, + mutateAsync: deleteMemoryMutation, + } = useMutation({ + mutationKey: ['deleteMemory'], + mutationFn: async (props) => { + const { data: response } = await memoryService.deleteMemory(props); + if (response.code !== 0) { + throw new Error(response.message || 'Failed to delete memory'); + } + + queryClient.invalidateQueries({ queryKey: ['memoryList'] }); + return response; + }, + onSuccess: () => { + message.success(t('message.deleted')); + }, + onError: (error) => { + message.error(t('message.error', { error: error.message })); + }, + }); + + const deleteMemory = useCallback( + (props: DeleteMemoryProps) => { + return deleteMemoryMutation(props); + }, + [deleteMemoryMutation], + ); + + return { data, isError, deleteMemory }; +}; + +export const useUpdateMemory = () => { + const { t } = useTranslation(); + const queryClient = useQueryClient(); + const { + data, + isError, + mutateAsync: updateMemoryMutation, + } = useMutation({ + mutationKey: ['updateMemory'], + mutationFn: async (formData) => { + const { data: response } = await updateMemoryById(formData.id, formData); + if (response.code !== 0) { + throw new Error(response.message || 'Failed to update memory'); + } + return response.data; + }, + onSuccess: (data, variables) => { + message.success(t('message.updated')); + queryClient.invalidateQueries({ + queryKey: ['memoryDetail', variables.id], + }); + }, + onError: (error) => { + message.error(t('message.error', { error: error.message })); + }, + }); + + const updateMemory = useCallback( + (formData: IMemoryAppDetailProps) => { + return updateMemoryMutation(formData); + }, + [updateMemoryMutation], + ); + + return { data, isError, updateMemory }; +}; + +export const useRenameMemory = () => { + const [memory, setMemory] = useState({} as IMemory); + const { navigateToMemory } = useNavigatePage(); + const { + visible: openCreateModal, + hideModal: hideChatRenameModal, + showModal: showChatRenameModal, + } = useSetModalState(); + const { updateMemory } = useUpdateMemory(); + const { createMemory } = useCreateMemory(); + const [loading, setLoading] = useState(false); + + const handleShowChatRenameModal = useCallback( + (record?: IMemory) => { + if (record) { + setMemory(record); + } + showChatRenameModal(); + }, + [showChatRenameModal], + ); + + const handleHideModal = useCallback(() => { + hideChatRenameModal(); + setMemory({} as IMemory); + }, [hideChatRenameModal]); + + const onMemoryRenameOk = useCallback( + async (data: ICreateMemoryProps, callBack?: () => void) => { + let res; + setLoading(true); + if (memory?.id) { + try { + // const reponse = await memoryService.getMemoryDetail({ + // id: memory?.id, + // }); + // const detail = reponse.data?.data; + // console.log('detail-->', detail); + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + // const { id, created_by, update_time, ...memoryDataTemp } = detail; + res = await updateMemory({ + // ...memoryDataTemp, + name: data.memory_name, + id: memory?.id, + } as unknown as IMemoryAppDetailProps); + } catch (e) { + console.error('error', e); + } + } else { + res = await createMemory(data); + } + if (res && !memory?.id) { + navigateToMemory(res?.id)(); + } + callBack?.(); + setLoading(false); + handleHideModal(); + }, + [memory, createMemory, handleHideModal, navigateToMemory, updateMemory], + ); + return { + memoryRenameLoading: loading, + initialMemory: memory, + onMemoryRenameOk, + openCreateModal, + hideMemoryModal: handleHideModal, + showMemoryRenameModal: handleShowChatRenameModal, + }; +}; diff --git a/web/src/pages/memories/index.tsx b/web/src/pages/memories/index.tsx new file mode 100644 index 000000000..49407765d --- /dev/null +++ b/web/src/pages/memories/index.tsx @@ -0,0 +1,163 @@ +import { CardContainer } from '@/components/card-container'; +import { EmptyCardType } from '@/components/empty/constant'; +import { EmptyAppCard } from '@/components/empty/empty'; +import ListFilterBar from '@/components/list-filter-bar'; +import { Button } from '@/components/ui/button'; +import { RAGFlowPagination } from '@/components/ui/ragflow-pagination'; +import { useTranslate } from '@/hooks/common-hooks'; +import { pick } from 'lodash'; +import { Plus } from 'lucide-react'; +import { useCallback, useEffect } from 'react'; +import { useSearchParams } from 'umi'; +import { AddOrEditModal } from './add-or-edit-modal'; +import { useFetchMemoryList, useRenameMemory } from './hooks'; +import { ICreateMemoryProps } from './interface'; +import { MemoryCard } from './memory-card'; + +export default function MemoryList() { + // const { data } = useFetchFlowList(); + const { t } = useTranslate('memory'); + // const [isEdit, setIsEdit] = useState(false); + const { + data: list, + pagination, + searchString, + handleInputChange, + setPagination, + refetch: refetchList, + } = useFetchMemoryList(); + + const { + openCreateModal, + showMemoryRenameModal, + hideMemoryModal, + searchRenameLoading, + onMemoryRenameOk, + initialMemory, + } = useRenameMemory(); + + const onMemoryConfirm = (data: ICreateMemoryProps) => { + onMemoryRenameOk(data, () => { + refetchList(); + }); + }; + const openCreateModalFun = useCallback(() => { + // setIsEdit(false); + showMemoryRenameModal(); + }, [showMemoryRenameModal]); + const handlePageChange = useCallback( + (page: number, pageSize?: number) => { + setPagination({ page, pageSize }); + }, + [setPagination], + ); + + const [searchUrl, setMemoryUrl] = useSearchParams(); + const isCreate = searchUrl.get('isCreate') === 'true'; + useEffect(() => { + if (isCreate) { + openCreateModalFun(); + searchUrl.delete('isCreate'); + setMemoryUrl(searchUrl); + } + }, [isCreate, openCreateModalFun, searchUrl, setMemoryUrl]); + + return ( +
+ {(!list?.data?.memory_list?.length || + list?.data?.memory_list?.length <= 0) && + !searchString && ( +
+ openCreateModalFun()} + /> +
+ )} + {(!!list?.data?.memory_list?.length || searchString) && ( + <> +
+ + + +
+ {(!list?.data?.memory_list?.length || + list?.data?.memory_list?.length <= 0) && + searchString && ( +
+ openCreateModalFun()} + /> +
+ )} +
+ + {list?.data.memory_list.map((x) => { + return ( + { + showMemoryRenameModal(x); + }} + > + ); + })} + +
+ {list?.data.total && list?.data.total > 0 && ( +
+ +
+ )} + + )} + {/* {openCreateModal && ( + } + > + )} */} + {openCreateModal && ( + + )} +
+ ); +} diff --git a/web/src/pages/memories/interface.ts b/web/src/pages/memories/interface.ts new file mode 100644 index 000000000..46cba578c --- /dev/null +++ b/web/src/pages/memories/interface.ts @@ -0,0 +1,121 @@ +export interface ICreateMemoryProps { + memory_name: string; + memory_type: Array; + embedding: string; + llm: string; +} + +export interface CreateMemoryResponse { + id: string; + name: string; + description: string; +} + +export interface MemoryListParams { + keywords?: string; + parser_id?: string; + page?: number; + page_size?: number; + orderby?: string; + desc?: boolean; + owner_ids?: string; +} +export type MemoryType = 'raw' | 'semantic' | 'episodic' | 'procedural'; +export type StorageType = 'table' | 'graph'; +export type Permissions = 'me' | 'team'; +export type ForgettingPolicy = 'fifo' | 'lru'; + +export interface IMemory { + id: string; + name: string; + avatar: string; + tenant_id: string; + owner_name: string; + memory_type: MemoryType[]; + storage_type: StorageType; + embedding: string; + llm: string; + permissions: Permissions; + description: string; + memory_size: number; + forgetting_policy: ForgettingPolicy; + temperature: string; + system_prompt: string; + user_prompt: string; +} +export interface MemoryListResponse { + code: number; + data: { + memory_list: Array; + total: number; + }; + message: string; +} + +export interface DeleteMemoryProps { + memory_id: string; +} + +export interface DeleteMemoryResponse { + code: number; + data: boolean; + message: string; +} + +export interface IllmSettingProps { + llm_id: string; + parameter: string; + temperature?: number; + top_p?: number; + frequency_penalty?: number; + presence_penalty?: number; +} +interface IllmSettingEnableProps { + temperatureEnabled?: boolean; + topPEnabled?: boolean; + presencePenaltyEnabled?: boolean; + frequencyPenaltyEnabled?: boolean; +} +export interface IMemoryAppDetailProps { + avatar: any; + created_by: string; + description: string; + id: string; + name: string; + memory_config: { + cross_languages: string[]; + doc_ids: string[]; + chat_id: string; + highlight: boolean; + kb_ids: string[]; + keyword: boolean; + query_mindmap: boolean; + related_memory: boolean; + rerank_id: string; + use_rerank?: boolean; + similarity_threshold: number; + summary: boolean; + llm_setting: IllmSettingProps & IllmSettingEnableProps; + top_k: number; + use_kg: boolean; + vector_similarity_weight: number; + web_memory: boolean; + chat_settingcross_languages: string[]; + meta_data_filter?: { + method: string; + manual: { key: string; op: string; value: string }[]; + }; + }; + tenant_id: string; + update_time: number; +} + +export interface MemoryDetailResponse { + code: number; + data: IMemoryAppDetailProps; + message: string; +} + +// export type IUpdateMemoryProps = Omit & { +// id: string; +// }; diff --git a/web/src/pages/memories/memory-card.tsx b/web/src/pages/memories/memory-card.tsx new file mode 100644 index 000000000..716b19313 --- /dev/null +++ b/web/src/pages/memories/memory-card.tsx @@ -0,0 +1,32 @@ +import { HomeCard } from '@/components/home-card'; +import { MoreButton } from '@/components/more-button'; +import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks'; +import { IMemory } from './interface'; +import { MemoryDropdown } from './memory-dropdown'; + +interface IProps { + data: IMemory; + showMemoryRenameModal: (data: IMemory) => void; +} +export function MemoryCard({ data, showMemoryRenameModal }: IProps) { + const { navigateToMemory } = useNavigatePage(); + + return ( + + + + } + onClick={navigateToMemory(data?.id)} + /> + ); +} diff --git a/web/src/pages/memories/memory-dropdown.tsx b/web/src/pages/memories/memory-dropdown.tsx new file mode 100644 index 000000000..2bcdcac1a --- /dev/null +++ b/web/src/pages/memories/memory-dropdown.tsx @@ -0,0 +1,74 @@ +import { + ConfirmDeleteDialog, + ConfirmDeleteDialogNode, +} from '@/components/confirm-delete-dialog'; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from '@/components/ui/dropdown-menu'; +import { PenLine, Trash2 } from 'lucide-react'; +import { MouseEventHandler, PropsWithChildren, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { IMemoryAppProps, useDeleteMemory } from './hooks'; + +export function MemoryDropdown({ + children, + dataset, + showMemoryRenameModal, +}: PropsWithChildren & { + dataset: IMemoryAppProps; + showMemoryRenameModal: (dataset: IMemoryAppProps) => void; +}) { + const { t } = useTranslation(); + const { deleteMemory } = useDeleteMemory(); + const handleShowChatRenameModal: MouseEventHandler = + useCallback( + (e) => { + e.stopPropagation(); + showMemoryRenameModal(dataset); + }, + [dataset, showMemoryRenameModal], + ); + const handleDelete: MouseEventHandler = useCallback(() => { + deleteMemory({ search_id: dataset.id }); + }, [dataset.id, deleteMemory]); + + return ( + + {children} + + + {t('common.rename')} + + + + ), + }} + > + { + e.preventDefault(); + }} + onClick={(e) => { + e.stopPropagation(); + }} + > + {t('common.delete')} + + + + + ); +} diff --git a/web/src/pages/memory/constant.tsx b/web/src/pages/memory/constant.tsx new file mode 100644 index 000000000..1377fccc5 --- /dev/null +++ b/web/src/pages/memory/constant.tsx @@ -0,0 +1,3 @@ +export enum MemoryApiAction { + FetchMemoryDetail = 'fetchMemoryDetail', +} diff --git a/web/src/pages/memory/hooks/use-memory-messages.ts b/web/src/pages/memory/hooks/use-memory-messages.ts new file mode 100644 index 000000000..e6345ae1a --- /dev/null +++ b/web/src/pages/memory/hooks/use-memory-messages.ts @@ -0,0 +1,59 @@ +import { useHandleSearchChange } from '@/hooks/logic-hooks'; +import { getMemoryDetailById } from '@/services/memory-service'; +import { useQuery } from '@tanstack/react-query'; +import { useParams, useSearchParams } from 'umi'; +import { MemoryApiAction } from '../constant'; +import { IMessageTableProps } from '../memory-message/interface'; + +export const useFetchMemoryMessageList = (props?: { + refreshCount?: number; +}) => { + const { refreshCount } = props || {}; + const { id } = useParams(); + const [searchParams] = useSearchParams(); + const memoryBaseId = searchParams.get('id') || id; + const { handleInputChange, searchString, pagination, setPagination } = + useHandleSearchChange(); + + let queryKey: (MemoryApiAction | number)[] = [ + MemoryApiAction.FetchMemoryDetail, + ]; + if (typeof refreshCount === 'number') { + queryKey = [MemoryApiAction.FetchMemoryDetail, refreshCount]; + } + + const { data, isFetching: loading } = useQuery({ + queryKey: [...queryKey, searchString, pagination], + initialData: {} as IMessageTableProps, + gcTime: 0, + queryFn: async () => { + if (memoryBaseId) { + const { data } = await getMemoryDetailById(memoryBaseId as string, { + // filter: { + // agent_id: '', + // }, + keyword: searchString, + page: pagination.current, + page_size: pagination.pageSize, + }); + // setPagination({ + // page: data?.page ?? 1, + // pageSize: data?.page_size ?? 10, + // total: data?.total ?? 0, + // }); + return data?.data ?? {}; + } else { + return {}; + } + }, + }); + + return { + data, + loading, + handleInputChange, + searchString, + pagination, + setPagination, + }; +}; diff --git a/web/src/pages/memory/hooks/use-memory-setting.ts b/web/src/pages/memory/hooks/use-memory-setting.ts new file mode 100644 index 000000000..bbca1c6ee --- /dev/null +++ b/web/src/pages/memory/hooks/use-memory-setting.ts @@ -0,0 +1,59 @@ +import { useHandleSearchChange } from '@/hooks/logic-hooks'; +import { IMemory } from '@/pages/memories/interface'; +import { getMemoryDetailById } from '@/services/memory-service'; +import { useQuery } from '@tanstack/react-query'; +import { useParams, useSearchParams } from 'umi'; +import { MemoryApiAction } from '../constant'; + +export const useFetchMemoryBaseConfiguration = (props?: { + refreshCount?: number; +}) => { + const { refreshCount } = props || {}; + const { id } = useParams(); + const [searchParams] = useSearchParams(); + const memoryBaseId = searchParams.get('id') || id; + const { handleInputChange, searchString, pagination, setPagination } = + useHandleSearchChange(); + + let queryKey: (MemoryApiAction | number)[] = [ + MemoryApiAction.FetchMemoryDetail, + ]; + if (typeof refreshCount === 'number') { + queryKey = [MemoryApiAction.FetchMemoryDetail, refreshCount]; + } + + const { data, isFetching: loading } = useQuery({ + queryKey: [...queryKey, searchString, pagination], + initialData: {} as IMemory, + gcTime: 0, + queryFn: async () => { + if (memoryBaseId) { + const { data } = await getMemoryDetailById(memoryBaseId as string, { + // filter: { + // agent_id: '', + // }, + keyword: searchString, + page: pagination.current, + page_size: pagination.size, + }); + // setPagination({ + // page: data?.page ?? 1, + // pageSize: data?.page_size ?? 10, + // total: data?.total ?? 0, + // }); + return data?.data ?? {}; + } else { + return {}; + } + }, + }); + + return { + data, + loading, + handleInputChange, + searchString, + pagination, + setPagination, + }; +}; diff --git a/web/src/pages/memory/index.tsx b/web/src/pages/memory/index.tsx new file mode 100644 index 000000000..3536a71b7 --- /dev/null +++ b/web/src/pages/memory/index.tsx @@ -0,0 +1,17 @@ +import Spotlight from '@/components/spotlight'; +import { Outlet } from 'umi'; +import { SideBar } from './sidebar'; + +export default function DatasetWrapper() { + return ( +
+
+ +
+ + +
+
+
+ ); +} diff --git a/web/src/pages/memory/memory-message/index.tsx b/web/src/pages/memory/memory-message/index.tsx new file mode 100644 index 000000000..c0ec80823 --- /dev/null +++ b/web/src/pages/memory/memory-message/index.tsx @@ -0,0 +1,51 @@ +import ListFilterBar from '@/components/list-filter-bar'; +import { t } from 'i18next'; +import { useFetchMemoryMessageList } from '../hooks/use-memory-messages'; +import { MemoryTable } from './message-table'; + +export default function MemoryMessage() { + const { + searchString, + // documents, + data, + pagination, + handleInputChange, + setPagination, + // filterValue, + // handleFilterSubmit, + loading, + } = useFetchMemoryMessageList(); + return ( +
+ +
{t('knowledgeDetails.subbarFiles')}
+
+ {t('knowledgeDetails.datasetDescription')} +
+
+ } + > + +
+
message
+
+ + ); +} diff --git a/web/src/pages/memory/memory-message/interface.ts b/web/src/pages/memory/memory-message/interface.ts new file mode 100644 index 000000000..234ca438a --- /dev/null +++ b/web/src/pages/memory/memory-message/interface.ts @@ -0,0 +1,19 @@ +export interface IMessageInfo { + message_id: number; + message_type: 'semantic' | 'raw' | 'procedural'; + source_id: string | '-'; + id: string; + user_id: string; + agent_id: string; + agent_name: string; + session_id: string; + valid_at: string; + invalid_at: string; + forget_at: string; + status: boolean; +} + +export interface IMessageTableProps { + messages: { message_list: Array; total: number }; + storage_type: string; +} diff --git a/web/src/pages/memory/memory-message/message-table.tsx b/web/src/pages/memory/memory-message/message-table.tsx new file mode 100644 index 000000000..2174c2f79 --- /dev/null +++ b/web/src/pages/memory/memory-message/message-table.tsx @@ -0,0 +1,225 @@ +import { + ColumnDef, + ColumnFiltersState, + SortingState, + VisibilityState, + flexRender, + getCoreRowModel, + getFilteredRowModel, + getPaginationRowModel, + getSortedRowModel, + useReactTable, +} from '@tanstack/react-table'; +import * as React from 'react'; + +import { EmptyType } from '@/components/empty/constant'; +import Empty from '@/components/empty/empty'; +import { Button } from '@/components/ui/button'; +import { RAGFlowPagination } from '@/components/ui/ragflow-pagination'; +import { Switch } from '@/components/ui/switch'; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from '@/components/ui/table'; +import { Pagination } from '@/interfaces/common'; +import { t } from 'i18next'; +import { pick } from 'lodash'; +import { Eraser, TextSelect } from 'lucide-react'; +import { useMemo } from 'react'; +import { IMessageInfo } from './interface'; + +export type MemoryTableProps = { + messages: Array; + total: number; + pagination: Pagination; + setPagination: (params: { page: number; pageSize: number }) => void; +}; + +export function MemoryTable({ + messages, + total, + pagination, + setPagination, +}: MemoryTableProps) { + const [sorting, setSorting] = React.useState([]); + const [columnFilters, setColumnFilters] = React.useState( + [], + ); + const [columnVisibility, setColumnVisibility] = + React.useState({}); + + // Define columns for the memory table + const columns: ColumnDef[] = useMemo( + () => [ + { + accessorKey: 'session_id', + header: () => {t('memoryDetail.messages.sessionId')}, + cell: ({ row }) => ( +
+ {row.getValue('session_id')} +
+ ), + }, + { + accessorKey: 'agent_name', + header: () => {t('memoryDetail.messages.agent')}, + cell: ({ row }) => ( +
+ {row.getValue('agent_name')} +
+ ), + }, + { + accessorKey: 'message_type', + header: () => {t('memoryDetail.messages.type')}, + cell: ({ row }) => ( +
+ {row.getValue('message_type')} +
+ ), + }, + { + accessorKey: 'valid_at', + header: () => {t('memoryDetail.messages.validDate')}, + cell: ({ row }) => ( +
{row.getValue('valid_at')}
+ ), + }, + { + accessorKey: 'forget_at', + header: () => {t('memoryDetail.messages.forgetAt')}, + cell: ({ row }) => ( +
{row.getValue('forget_at')}
+ ), + }, + { + accessorKey: 'source_id', + header: () => {t('memoryDetail.messages.source')}, + cell: ({ row }) => ( +
{row.getValue('source_id')}
+ ), + }, + { + accessorKey: 'status', + header: () => {t('memoryDetail.messages.enable')}, + cell: ({ row }) => { + const isEnabled = row.getValue('status') as boolean; + return ( +
+ {}} /> +
+ ); + }, + }, + { + accessorKey: 'action', + header: () => {t('memoryDetail.messages.action')}, + meta: { + cellClassName: 'w-12', + }, + cell: () => ( +
+ + +
+ ), + }, + ], + [], + ); + + const currentPagination = useMemo(() => { + return { + pageIndex: (pagination.current || 1) - 1, + pageSize: pagination.pageSize || 10, + }; + }, [pagination]); + + const table = useReactTable({ + data: messages, + columns, + onSortingChange: setSorting, + onColumnFiltersChange: setColumnFilters, + getCoreRowModel: getCoreRowModel(), + getPaginationRowModel: getPaginationRowModel(), + getSortedRowModel: getSortedRowModel(), + getFilteredRowModel: getFilteredRowModel(), + onColumnVisibilityChange: setColumnVisibility, + manualPagination: true, + state: { + sorting, + columnFilters, + columnVisibility, + pagination: currentPagination, + }, + rowCount: total, + }); + + return ( +
+ + + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => ( + + {header.isPlaceholder + ? null + : flexRender( + header.column.columnDef.header, + header.getContext(), + )} + + ))} + + ))} + + + {table.getRowModel().rows?.length ? ( + table.getRowModel().rows.map((row) => ( + + {row.getVisibleCells().map((cell) => ( + + {flexRender(cell.column.columnDef.cell, cell.getContext())} + + ))} + + )) + ) : ( + + + + + + )} + +
+ +
+ { + setPagination({ page, pageSize }); + }} + /> +
+
+ ); +} diff --git a/web/src/pages/memory/memory-setting/index.tsx b/web/src/pages/memory/memory-setting/index.tsx new file mode 100644 index 000000000..27be86cfd --- /dev/null +++ b/web/src/pages/memory/memory-setting/index.tsx @@ -0,0 +1,13 @@ +export default function MemoryMessage() { + return ( +
+
+
11
+
11
+
+
+
setting
+
+
+ ); +} diff --git a/web/src/pages/memory/sidebar/hooks.tsx b/web/src/pages/memory/sidebar/hooks.tsx new file mode 100644 index 000000000..1dd28785a --- /dev/null +++ b/web/src/pages/memory/sidebar/hooks.tsx @@ -0,0 +1,17 @@ +import { Routes } from '@/routes'; +import { useCallback } from 'react'; +import { useNavigate, useParams } from 'umi'; + +export const useHandleMenuClick = () => { + const navigate = useNavigate(); + const { id } = useParams(); + + const handleMenuClick = useCallback( + (key: Routes) => () => { + navigate(`${Routes.Memory}${key}/${id}`); + }, + [id, navigate], + ); + + return { handleMenuClick }; +}; diff --git a/web/src/pages/memory/sidebar/index.tsx b/web/src/pages/memory/sidebar/index.tsx new file mode 100644 index 000000000..98928fe43 --- /dev/null +++ b/web/src/pages/memory/sidebar/index.tsx @@ -0,0 +1,88 @@ +import { RAGFlowAvatar } from '@/components/ragflow-avatar'; +import { Button } from '@/components/ui/button'; +import { useSecondPathName } from '@/hooks/route-hook'; +import { cn, formatBytes } from '@/lib/utils'; +import { Routes } from '@/routes'; +import { formatPureDate } from '@/utils/date'; +import { Banknote, Logs } from 'lucide-react'; +import { useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useFetchMemoryBaseConfiguration } from '../hooks/use-memory-setting'; +import { useHandleMenuClick } from './hooks'; + +type PropType = { + refreshCount?: number; +}; + +export function SideBar({ refreshCount }: PropType) { + const pathName = useSecondPathName(); + const { handleMenuClick } = useHandleMenuClick(); + // refreshCount: be for avatar img sync update on top left + const { data } = useFetchMemoryBaseConfiguration({ refreshCount }); + const { t } = useTranslation(); + + const items = useMemo(() => { + const list = [ + { + icon: , + label: t(`knowledgeDetails.overview`), + key: Routes.MemoryMessage, + }, + { + icon: , + label: t(`knowledgeDetails.configuration`), + key: Routes.MemorySetting, + }, + ]; + return list; + }, [t]); + + return ( + + ); +} diff --git a/web/src/pages/next-chats/index.tsx b/web/src/pages/next-chats/index.tsx index e6667252f..15940e6d8 100644 --- a/web/src/pages/next-chats/index.tsx +++ b/web/src/pages/next-chats/index.tsx @@ -50,18 +50,19 @@ export default function ChatList() { return (
- {data.dialogs?.length <= 0 && ( + {data.dialogs?.length <= 0 && !searchString && (
handleShowCreateModal()} />
)} - {data.dialogs?.length > 0 && ( + {(data.dialogs?.length > 0 || searchString) && ( <>
+ {data.dialogs?.length <= 0 && searchString && ( +
+ handleShowCreateModal()} + /> +
+ )}
{data.dialogs.map((x) => { diff --git a/web/src/pages/next-searches/hooks.ts b/web/src/pages/next-searches/hooks.ts index 1787d7f23..64699e64a 100644 --- a/web/src/pages/next-searches/hooks.ts +++ b/web/src/pages/next-searches/hooks.ts @@ -2,9 +2,11 @@ import message from '@/components/ui/message'; import { useSetModalState } from '@/hooks/common-hooks'; +import { useHandleSearchChange } from '@/hooks/logic-hooks'; import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks'; import searchService from '@/services/search-service'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; +import { useDebounce } from 'ahooks'; import { useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useParams, useSearchParams } from 'umi'; @@ -84,21 +86,34 @@ interface SearchListResponse { message: string; } -export const useFetchSearchList = (params?: SearchListParams) => { - const [searchParams, setSearchParams] = useState({ - page: 1, - page_size: 50, - ...params, - }); +export const useFetchSearchList = () => { + const { handleInputChange, searchString, pagination, setPagination } = + useHandleSearchChange(); + const debouncedSearchString = useDebounce(searchString, { wait: 500 }); const { data, isLoading, isError, refetch } = useQuery< SearchListResponse, Error >({ - queryKey: ['searchList', searchParams], + queryKey: [ + 'searchList', + { + debouncedSearchString, + ...pagination, + }, + ], queryFn: async () => { - const { data: response } = - await searchService.getSearchList(searchParams); + const { data: response } = await searchService.getSearchList( + { + params: { + keywords: debouncedSearchString, + page_size: pagination.pageSize, + page: pagination.current, + }, + data: {}, + }, + true, + ); if (response.code !== 0) { throw new Error(response.message || 'Failed to fetch search list'); } @@ -106,19 +121,14 @@ export const useFetchSearchList = (params?: SearchListParams) => { }, }); - const setSearchListParams = (newParams: SearchListParams) => { - setSearchParams((prevParams) => ({ - ...prevParams, - ...newParams, - })); - }; - return { data, isLoading, isError, - searchParams, - setSearchListParams, + pagination, + searchString, + handleInputChange, + setPagination, refetch, }; }; diff --git a/web/src/pages/next-searches/index.tsx b/web/src/pages/next-searches/index.tsx index 4267dbbe7..89a44a9cc 100644 --- a/web/src/pages/next-searches/index.tsx +++ b/web/src/pages/next-searches/index.tsx @@ -7,6 +7,7 @@ import { RenameDialog } from '@/components/rename-dialog'; import { Button } from '@/components/ui/button'; import { RAGFlowPagination } from '@/components/ui/ragflow-pagination'; import { useTranslate } from '@/hooks/common-hooks'; +import { pick } from 'lodash'; import { Plus } from 'lucide-react'; import { useCallback, useEffect } from 'react'; import { useSearchParams } from 'umi'; @@ -19,10 +20,13 @@ export default function SearchList() { // const [isEdit, setIsEdit] = useState(false); const { data: list, - searchParams, - setSearchListParams, + pagination, + searchString, + handleInputChange, + setPagination, refetch: refetchList, } = useFetchSearchList(); + const { openCreateModal, showSearchRenameModal, @@ -32,9 +36,9 @@ export default function SearchList() { initialSearchName, } = useRenameSearch(); - const handleSearchChange = (value: string) => { - console.log(value); - }; + // const handleSearchChange = (value: string) => { + // console.log(value); + // }; const onSearchRenameConfirm = (name: string) => { onSearchRenameOk(name, () => { refetchList(); @@ -44,10 +48,12 @@ export default function SearchList() { // setIsEdit(false); showSearchRenameModal(); }, [showSearchRenameModal]); - const handlePageChange = (page: number, pageSize: number) => { - // setIsEdit(false); - setSearchListParams({ ...searchParams, page, page_size: pageSize }); - }; + const handlePageChange = useCallback( + (page: number, pageSize?: number) => { + setPagination({ page, pageSize }); + }, + [setPagination], + ); const [searchUrl, setSearchUrl] = useSearchParams(); const isCreate = searchUrl.get('isCreate') === 'true'; @@ -62,25 +68,28 @@ export default function SearchList() { return (
{(!list?.data?.search_apps?.length || - list?.data?.search_apps?.length <= 0) && ( -
- openCreateModalFun()} - /> -
- )} - {!!list?.data?.search_apps?.length && ( + list?.data?.search_apps?.length <= 0) && + !searchString && ( +
+ openCreateModalFun()} + /> +
+ )} + {(!!list?.data?.search_apps?.length || searchString) && ( <>
handleSearchChange(e.target.value)} + searchString={searchString} + onSearchChange={handleInputChange} >
+ {(!list?.data?.search_apps?.length || + list?.data?.search_apps?.length <= 0) && + searchString && ( +
+ openCreateModalFun()} + /> +
+ )}
{list?.data.search_apps.map((x) => { @@ -111,8 +134,8 @@ export default function SearchList() { {list?.data.total && list?.data.total > 0 && (
diff --git a/web/src/pages/user-setting/data-source/component/confluence-token-field.tsx b/web/src/pages/user-setting/data-source/component/confluence-token-field.tsx index 5fe50b931..6c7e201d4 100644 --- a/web/src/pages/user-setting/data-source/component/confluence-token-field.tsx +++ b/web/src/pages/user-setting/data-source/component/confluence-token-field.tsx @@ -1,9 +1,10 @@ -import { useEffect, useMemo } from 'react'; +import { useCallback, useEffect, useMemo, useState } from 'react'; import { ControllerRenderProps, useFormContext } from 'react-hook-form'; import { Checkbox } from '@/components/ui/checkbox'; import { Input } from '@/components/ui/input'; import { cn } from '@/lib/utils'; +import { debounce } from 'lodash'; /* ---------------- Token Field ---------------- */ @@ -48,15 +49,15 @@ type ConfluenceIndexingMode = 'everything' | 'space' | 'page'; export type ConfluenceIndexingModeFieldProps = ControllerRenderProps; export const ConfluenceIndexingModeField = ( - fieldProps: ConfluenceIndexingModeFieldProps, + fieldProps: ControllerRenderProps, ) => { const { value, onChange, disabled } = fieldProps; + const [mode, setMode] = useState( + value || 'everything', + ); const { watch, setValue } = useFormContext(); - const mode = useMemo( - () => (value as ConfluenceIndexingMode) || 'everything', - [value], - ); + useEffect(() => setMode(value), [value]); const spaceValue = watch('config.space'); const pageIdValue = watch('config.page_id'); @@ -66,27 +67,40 @@ export const ConfluenceIndexingModeField = ( if (!value) onChange('everything'); }, [value, onChange]); - const handleModeChange = (nextMode?: string) => { - const normalized = (nextMode || 'everything') as ConfluenceIndexingMode; - onChange(normalized); + const handleModeChange = useCallback( + (nextMode?: string) => { + let normalized: ConfluenceIndexingMode = 'everything'; + if (nextMode) { + normalized = nextMode as ConfluenceIndexingMode; + setMode(normalized); + onChange(normalized); + } else { + setMode(mode); + normalized = mode; + onChange(mode); + // onChange(mode); + } + if (normalized === 'everything') { + setValue('config.space', ''); + setValue('config.page_id', ''); + setValue('config.index_recursively', false); + } else if (normalized === 'space') { + setValue('config.page_id', ''); + setValue('config.index_recursively', false); + } else if (normalized === 'page') { + setValue('config.space', ''); + } + }, + [mode, onChange, setValue], + ); - if (normalized === 'everything') { - setValue('config.space', '', { shouldDirty: true, shouldTouch: true }); - setValue('config.page_id', '', { shouldDirty: true, shouldTouch: true }); - setValue('config.index_recursively', false, { - shouldDirty: true, - shouldTouch: true, - }); - } else if (normalized === 'space') { - setValue('config.page_id', '', { shouldDirty: true, shouldTouch: true }); - setValue('config.index_recursively', false, { - shouldDirty: true, - shouldTouch: true, - }); - } else if (normalized === 'page') { - setValue('config.space', '', { shouldDirty: true, shouldTouch: true }); - } - }; + const debouncedHandleChange = useMemo( + () => + debounce(() => { + handleModeChange(); + }, 300), + [handleModeChange], + ); return (
@@ -127,12 +141,11 @@ export const ConfluenceIndexingModeField = ( - setValue('config.space', e.target.value, { - shouldDirty: true, - shouldTouch: true, - }) - } + onChange={(e) => { + const value = e.target.value; + setValue('config.space', value); + debouncedHandleChange(); + }} placeholder="e.g. KB" disabled={disabled} /> @@ -148,12 +161,10 @@ export const ConfluenceIndexingModeField = ( - setValue('config.page_id', e.target.value, { - shouldDirty: true, - shouldTouch: true, - }) - } + onChange={(e) => { + setValue('config.page_id', e.target.value); + debouncedHandleChange(); + }} placeholder="e.g. 123456" disabled={disabled} /> @@ -164,12 +175,10 @@ export const ConfluenceIndexingModeField = (
- setValue('config.index_recursively', Boolean(checked), { - shouldDirty: true, - shouldTouch: true, - }) - } + onCheckedChange={(checked) => { + setValue('config.index_recursively', Boolean(checked)); + debouncedHandleChange(); + }} disabled={disabled} /> diff --git a/web/src/pages/user-setting/data-source/contant.tsx b/web/src/pages/user-setting/data-source/contant.tsx index 34ced0ae2..b3e86e118 100644 --- a/web/src/pages/user-setting/data-source/contant.tsx +++ b/web/src/pages/user-setting/data-source/contant.tsx @@ -1,6 +1,7 @@ import { FormFieldType } from '@/components/dynamic-form'; import SvgIcon from '@/components/svg-icon'; import { t } from 'i18next'; +import { ControllerRenderProps } from 'react-hook-form'; import { ConfluenceIndexingModeField } from './component/confluence-token-field'; import GmailTokenField from './component/gmail-token-field'; import GoogleDriveTokenField from './component/google-drive-token-field'; @@ -237,7 +238,9 @@ export const DataSourceFormFields = { required: false, horizontal: true, labelClassName: 'self-start pt-4', - render: (fieldProps) => , + render: (fieldProps: ControllerRenderProps) => ( + + ), }, { label: 'Space Key', @@ -598,6 +601,7 @@ export const DataSourceFormDefaultValues = { confluence_username: '', confluence_access_token: '', }, + index_mode: 'everything', }, }, [DataSourceKey.GOOGLE_DRIVE]: { diff --git a/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx b/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx index fe54dda64..f399fd21d 100644 --- a/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx +++ b/web/src/pages/user-setting/data-source/data-source-detail-page/index.tsx @@ -136,7 +136,7 @@ const SourceDetailPage = () => { ...customFields, ] as FormFieldConfig[]; - const neweFields = fields.map((field) => { + const newFields = fields.map((field) => { return { ...field, horizontal: true, @@ -145,7 +145,7 @@ const SourceDetailPage = () => { }, }; }); - setFields(neweFields); + setFields(newFields); const defultValueTemp = { ...(DataSourceFormDefaultValues[ diff --git a/web/src/pages/user-setting/setting-model/components/llm-header.tsx b/web/src/pages/user-setting/setting-model/components/llm-header.tsx new file mode 100644 index 000000000..0c90cf6b7 --- /dev/null +++ b/web/src/pages/user-setting/setting-model/components/llm-header.tsx @@ -0,0 +1,34 @@ +import { LlmIcon } from '@/components/svg-icon'; +import { Button } from '@/components/ui/button'; +import { APIMapUrl } from '@/constants/llm'; +import { t } from 'i18next'; +import { ArrowUpRight, Plus } from 'lucide-react'; + +export const LLMHeader = ({ name }: { name: string }) => { + return ( +
+ +
+
{name}
+ {!!APIMapUrl[name as keyof typeof APIMapUrl] && ( + + )} +
+ +
+ ); +}; diff --git a/web/src/pages/user-setting/setting-model/components/un-add-model.tsx b/web/src/pages/user-setting/setting-model/components/un-add-model.tsx index e73f32c95..f4592a796 100644 --- a/web/src/pages/user-setting/setting-model/components/un-add-model.tsx +++ b/web/src/pages/user-setting/setting-model/components/un-add-model.tsx @@ -2,9 +2,10 @@ import { LlmIcon } from '@/components/svg-icon'; import { Button } from '@/components/ui/button'; import { SearchInput } from '@/components/ui/input'; +import { APIMapUrl } from '@/constants/llm'; import { useTranslate } from '@/hooks/common-hooks'; import { useSelectLlmList } from '@/hooks/use-llm-request'; -import { Plus } from 'lucide-react'; +import { ArrowUpRight, Plus } from 'lucide-react'; import { FC, useMemo, useState } from 'react'; type TagType = @@ -128,10 +129,26 @@ export const AvailableModels: FC<{ >
-
+
{model.name}
+ {!!APIMapUrl[model.name as keyof typeof APIMapUrl] && ( + + )}