From fd7e55b23dcec277db4699489babe6865a9a8ce9 Mon Sep 17 00:00:00 2001 From: Russell Valentine Date: Tue, 9 Dec 2025 21:08:11 -0600 Subject: [PATCH 1/7] executor_manager updated docker version (#11806) ### What problem does this PR solve? The docker version(24.0.7) installed in the executor manager image is incompatible with the latest stable docker (29.1.3). The minimum API version that v29.1.3 can use is 1.44, but 24.0.7 only speaks API version 1.43. ### Type of change - [X] Other (please describe): This could break things for people who still have an old docker installed on their system. A better approach could be a setting to share --- sandbox/executor_manager/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sandbox/executor_manager/Dockerfile b/sandbox/executor_manager/Dockerfile index 85f4f36c7..c26919f34 100644 --- a/sandbox/executor_manager/Dockerfile +++ b/sandbox/executor_manager/Dockerfile @@ -5,7 +5,7 @@ RUN grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian. apt-get install -y curl gcc && \ rm -rf /var/lib/apt/lists/* -RUN curl -fsSL https://mirrors.aliyun.com/docker-ce/linux/static/stable/x86_64/docker-24.0.7.tgz -o docker.tgz && \ +RUN curl -fsSL https://mirrors.aliyun.com/docker-ce/linux/static/stable/x86_64/docker-29.1.0.tgz -o docker.tgz && \ tar -xzf docker.tgz && \ mv docker/docker /usr/bin/docker && \ rm -rf docker docker.tgz From a1164b9c895be7bfb1686fe33b2b8d8bf3cd6b31 Mon Sep 17 00:00:00 2001 From: Lynn Date: Wed, 10 Dec 2025 13:34:08 +0800 Subject: [PATCH 2/7] Feat/memory (#11812) ### What problem does this PR solve? Manage and display memory datasets. 
### Type of change - [x] New Feature (non-breaking change which adds functionality) --- .gitignore | 3 + api/apps/memories_app.py | 185 ++++++++++++++++++ api/constants.py | 2 + api/db/db_models.py | 25 ++- api/db/services/memory_service.py | 150 ++++++++++++++ api/utils/memory_utils.py | 54 +++++ common/constants.py | 17 ++ test/testcases/test_web_api/common.py | 41 ++++ .../test_web_api/test_memory_app/conftest.py | 40 ++++ .../test_memory_app/test_create_memory.py | 106 ++++++++++ .../test_memory_app/test_list_memory.py | 118 +++++++++++ .../test_memory_app/test_rm_memory.py | 53 +++++ .../test_memory_app/test_update_memory.py | 161 +++++++++++++++ 13 files changed, 953 insertions(+), 2 deletions(-) create mode 100644 api/apps/memories_app.py create mode 100644 api/db/services/memory_service.py create mode 100644 api/utils/memory_utils.py create mode 100644 test/testcases/test_web_api/test_memory_app/conftest.py create mode 100644 test/testcases/test_web_api/test_memory_app/test_create_memory.py create mode 100644 test/testcases/test_web_api/test_memory_app/test_list_memory.py create mode 100644 test/testcases/test_web_api/test_memory_app/test_rm_memory.py create mode 100644 test/testcases/test_web_api/test_memory_app/test_update_memory.py diff --git a/.gitignore b/.gitignore index fbf80b3aa..11aa54493 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,6 @@ ragflow_cli.egg-info # Default backup dir backup + + +.hypothesis \ No newline at end of file diff --git a/api/apps/memories_app.py b/api/apps/memories_app.py new file mode 100644 index 000000000..9a5cae936 --- /dev/null +++ b/api/apps/memories_app.py @@ -0,0 +1,185 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging + +from quart import request +from api.apps import login_required, current_user +from api.db import TenantPermission +from api.db.services.memory_service import MemoryService +from api.db.services.user_service import UserTenantService +from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \ + not_allowed_parameters +from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human +from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT +from common.constants import MemoryType, RetCode, ForgettingPolicy + + +@manager.route("", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("name", "memory_type", "embd_id", "llm_id") +async def create_memory(): + req = await get_request_json() + # check name length + name = req["name"] + memory_name = name.strip() + if len(memory_name) == 0: + return get_error_argument_result("Memory name cannot be empty or whitespace.") + if len(memory_name) > MEMORY_NAME_LIMIT: + return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") + # check memory_type valid + memory_type = set(req["memory_type"]) + invalid_type = memory_type - {e.name.lower() for e in MemoryType} + if invalid_type: + return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.") + memory_type = list(memory_type) + + try: + res, memory = MemoryService.create_memory( + tenant_id=current_user.id, + name=memory_name, + memory_type=memory_type, + embd_id=req["embd_id"], + 
llm_id=req["llm_id"] + ) + + if res: + return get_json_result(message=True, data=format_ret_data_from_memory(memory)) + + else: + return get_json_result(message=memory, code=RetCode.SERVER_ERROR) + + except Exception as e: + return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) + + +@manager.route("/<memory_id>", methods=["PUT"]) # noqa: F821 +@login_required +@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id") +async def update_memory(memory_id): + req = await get_request_json() + update_dict = {} + # check name length + if "name" in req: + name = req["name"] + memory_name = name.strip() + if len(memory_name) == 0: + return get_error_argument_result("Memory name cannot be empty or whitespace.") + if len(memory_name) > MEMORY_NAME_LIMIT: + return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") + update_dict["name"] = memory_name + # check permissions valid + if req.get("permissions"): + if req["permissions"] not in [e.value for e in TenantPermission]: + return get_error_argument_result(f"Unknown permission '{req['permissions']}'.") + update_dict["permissions"] = req["permissions"] + if req.get("llm_id"): + update_dict["llm_id"] = req["llm_id"] + # check memory_size valid + if req.get("memory_size"): + if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT: + return get_error_argument_result(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.") + update_dict["memory_size"] = req["memory_size"] + # check forgetting_policy valid + if req.get("forgetting_policy"): + if req["forgetting_policy"] not in [e.value for e in ForgettingPolicy]: + return get_error_argument_result(f"Forgetting policy '{req['forgetting_policy']}' is not supported.") + update_dict["forgetting_policy"] = req["forgetting_policy"] + # check temperature valid + if "temperature" in req: + temperature = float(req["temperature"]) + if not 0 <= temperature <= 1: + return get_error_argument_result("Temperature 
should be in range [0, 1].") + update_dict["temperature"] = temperature + # allow update to empty fields + for field in ["avatar", "description", "system_prompt", "user_prompt"]: + if field in req: + update_dict[field] = req[field] + current_memory = MemoryService.get_by_memory_id(memory_id) + if not current_memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + + memory_dict = current_memory.to_dict() + memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) + to_update = {} + for k, v in update_dict.items(): + if isinstance(v, list) and set(memory_dict[k]) != set(v): + to_update[k] = v + elif memory_dict[k] != v: + to_update[k] = v + + if not to_update: + return get_json_result(message=True, data=memory_dict) + + try: + MemoryService.update_memory(memory_id, to_update) + updated_memory = MemoryService.get_by_memory_id(memory_id) + return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory)) + + except Exception as e: + logging.error(e) + return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) + + +@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821 +@login_required +async def delete_memory(memory_id): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return get_json_result(message=True, code=RetCode.NOT_FOUND) + try: + MemoryService.delete_memory(memory_id) + return get_json_result(message=True) + except Exception as e: + logging.error(e) + return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) + + +@manager.route("", methods=["GET"]) # noqa: F821 +@login_required +async def list_memory(): + args = request.args + try: + tenant_ids = args.getlist("tenant_id") + memory_types = args.getlist("memory_type") + storage_type = args.get("storage_type") + keywords = args.get("keywords", "") + page = int(args.get("page", 1)) + page_size = int(args.get("page_size", 50)) + # make filter dict + filter_dict = {"memory_type": 
memory_types, "storage_type": storage_type} + if not tenant_ids: + # restrict to current user's tenants + user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id) + filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants] + else: + filter_dict["tenant_id"] = tenant_ids + + memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) + [memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list] + return get_json_result(message=True, data={"memory_list": memory_list, "total_count": count}) + + except Exception as e: + logging.error(e) + return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) + + +@manager.route("/<memory_id>/config", methods=["GET"]) # noqa: F821 +@login_required +async def get_memory_config(memory_id): + memory = MemoryService.get_with_owner_name_by_id(memory_id) + if not memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + return get_json_result(message=True, data=format_ret_data_from_memory(memory)) diff --git a/api/constants.py b/api/constants.py index 464b7d8e6..9edaa844c 100644 --- a/api/constants.py +++ b/api/constants.py @@ -24,3 +24,5 @@ REQUEST_MAX_WAIT_SEC = 300 DATASET_NAME_LIMIT = 128 FILE_NAME_LEN_LIMIT = 255 +MEMORY_NAME_LIMIT = 128 +MEMORY_SIZE_LIMIT = 10*1024*1024 # Byte diff --git a/api/db/db_models.py b/api/db/db_models.py index 3d2192b2d..65fe1fd6e 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -1177,6 +1177,27 @@ class EvaluationResult(DataBaseModel): db_table = "evaluation_results" + +class Memory(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + name = CharField(max_length=128, null=False, index=False, help_text="Memory name") + avatar = TextField(null=True, help_text="avatar base64 string") + tenant_id = CharField(max_length=32, null=False, index=True) + memory_type = IntegerField(null=False, default=1, index=True, help_text="Bit 
flags (LSB->MSB): 1=raw, 2=semantic, 4=episodic, 8=procedural. E.g., 5 enables raw + episodic.") + storage_type = CharField(max_length=32, default='table', null=False, index=True, help_text="table|graph") + embd_id = CharField(max_length=128, null=False, index=False, help_text="embedding model ID") + llm_id = CharField(max_length=128, null=False, index=False, help_text="chat model ID") + permissions = CharField(max_length=16, null=False, index=True, help_text="me|team", default="me") + description = TextField(null=True, help_text="description") + memory_size = IntegerField(default=5242880, null=False, index=False) + forgetting_policy = CharField(max_length=32, null=False, default="fifo", index=False, help_text="lru|fifo") + temperature = FloatField(default=0.5, index=False) + system_prompt = TextField(null=True, help_text="system prompt", index=False) + user_prompt = TextField(null=True, help_text="user prompt", index=False) + + class Meta: + db_table = "memory" + + def migrate_db(): logging.disable(logging.ERROR) migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) @@ -1357,7 +1378,7 @@ def migrate_db(): migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False))) except Exception: pass - + # RAG Evaluation tables try: migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True))) @@ -1395,5 +1416,5 @@ def migrate_db(): migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1))) except Exception: pass - + logging.disable(logging.NOTSET) diff --git a/api/db/services/memory_service.py b/api/db/services/memory_service.py new file mode 100644 index 000000000..bc071a66f --- /dev/null +++ b/api/db/services/memory_service.py @@ -0,0 +1,150 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import List + +from api.apps import current_user +from api.db.db_models import DB, Memory, User +from api.db.services import duplicate_name +from api.db.services.common_service import CommonService +from api.utils.memory_utils import calculate_memory_type +from api.constants import MEMORY_NAME_LIMIT +from common.misc_utils import get_uuid +from common.time_utils import get_format_time, current_timestamp + + +class MemoryService(CommonService): + # Service class for manage memory operations + model = Memory + + @classmethod + @DB.connection_context() + def get_by_memory_id(cls, memory_id: str): + return cls.model.select().where(cls.model.id == memory_id).first() + + @classmethod + @DB.connection_context() + def get_with_owner_name_by_id(cls, memory_id: str): + fields = [ + cls.model.id, + cls.model.name, + cls.model.avatar, + cls.model.tenant_id, + User.nickname.alias("owner_name"), + cls.model.memory_type, + cls.model.storage_type, + cls.model.embd_id, + cls.model.llm_id, + cls.model.permissions, + cls.model.description, + cls.model.memory_size, + cls.model.forgetting_policy, + cls.model.temperature, + cls.model.system_prompt, + cls.model.user_prompt + ] + memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where( + cls.model.id == memory_id + ).first() + return memory + + @classmethod + @DB.connection_context() + def get_by_filter(cls, filter_dict: dict, keywords: str, page: int = 1, page_size: int = 50): + fields = [ + cls.model.id, + cls.model.name, + cls.model.avatar, + cls.model.tenant_id, + 
User.nickname.alias("owner_name"), + cls.model.memory_type, + cls.model.storage_type, + cls.model.permissions, + cls.model.description + ] + memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)) + if filter_dict.get("tenant_id"): + memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"])) + if filter_dict.get("memory_type"): + memory_type_int = calculate_memory_type(filter_dict["memory_type"]) + memories = memories.where(cls.model.memory_type.bin_and(memory_type_int) > 0) + if filter_dict.get("storage_type"): + memories = memories.where(cls.model.storage_type == filter_dict["storage_type"]) + if keywords: + memories = memories.where(cls.model.name.contains(keywords)) + count = memories.count() + memories = memories.order_by(cls.model.update_time.desc()) + memories = memories.paginate(page, page_size) + + return list(memories.dicts()), count + + @classmethod + @DB.connection_context() + def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, llm_id: str): + # Deduplicate name within tenant + memory_name = duplicate_name( + cls.query, + name=name, + tenant_id=tenant_id + ) + if len(memory_name) > MEMORY_NAME_LIMIT: + return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}." + + # build create dict + memory_info = { + "id": get_uuid(), + "name": memory_name, + "memory_type": calculate_memory_type(memory_type), + "tenant_id": tenant_id, + "embd_id": embd_id, + "llm_id": llm_id, + "create_time": current_timestamp(), + "create_date": get_format_time(), + "update_time": current_timestamp(), + "update_date": get_format_time(), + } + obj = cls.model(**memory_info).save(force_insert=True) + + if not obj: + return False, "Could not create new memory." 
+ + db_row = cls.model.select().where(cls.model.id == memory_info["id"]).first() + + return obj, db_row + + @classmethod + @DB.connection_context() + def update_memory(cls, memory_id: str, update_dict: dict): + if not update_dict: + return 0 + if "temperature" in update_dict and isinstance(update_dict["temperature"], str): + update_dict["temperature"] = float(update_dict["temperature"]) + if "name" in update_dict: + update_dict["name"] = duplicate_name( + cls.query, + name=update_dict["name"], + tenant_id=current_user.id + ) + update_dict.update({ + "update_time": current_timestamp(), + "update_date": get_format_time() + }) + + return cls.model.update(update_dict).where(cls.model.id == memory_id).execute() + + @classmethod + @DB.connection_context() + def delete_memory(cls, memory_id: str): + return cls.model.delete().where(cls.model.id == memory_id).execute() diff --git a/api/utils/memory_utils.py b/api/utils/memory_utils.py new file mode 100644 index 000000000..bb7894951 --- /dev/null +++ b/api/utils/memory_utils.py @@ -0,0 +1,54 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from typing import List +from common.constants import MemoryType + +def format_ret_data_from_memory(memory): + return { + "id": memory.id, + "name": memory.name, + "avatar": memory.avatar, + "tenant_id": memory.tenant_id, + "owner_name": memory.owner_name if hasattr(memory, "owner_name") else None, + "memory_type": get_memory_type_human(memory.memory_type), + "storage_type": memory.storage_type, + "embd_id": memory.embd_id, + "llm_id": memory.llm_id, + "permissions": memory.permissions, + "description": memory.description, + "memory_size": memory.memory_size, + "forgetting_policy": memory.forgetting_policy, + "temperature": memory.temperature, + "system_prompt": memory.system_prompt, + "user_prompt": memory.user_prompt, + "create_time": memory.create_time, + "create_date": memory.create_date, + "update_time": memory.update_time, + "update_date": memory.update_date + } + + +def get_memory_type_human(memory_type: int) -> List[str]: + return [mem_type.name.lower() for mem_type in MemoryType if memory_type & mem_type.value] + + +def calculate_memory_type(memory_type_name_list: List[str]) -> int: + memory_type = 0 + type_value_map = {mem_type.name.lower(): mem_type.value for mem_type in MemoryType} + for mem_type in memory_type_name_list: + if mem_type in type_value_map: + memory_type |= type_value_map[mem_type] + return memory_type diff --git a/common/constants.py b/common/constants.py index 171319250..98e9faf36 100644 --- a/common/constants.py +++ b/common/constants.py @@ -151,6 +151,23 @@ class Storage(Enum): OPENDAL = 6 GCS = 7 + +class MemoryType(Enum): + RAW = 0b0001 # 1 << 0 = 1 (0b00000001) + SEMANTIC = 0b0010 # 1 << 1 = 2 (0b00000010) + EPISODIC = 0b0100 # 1 << 2 = 4 (0b00000100) + PROCEDURAL = 0b1000 # 1 << 3 = 8 (0b00001000) + + +class MemoryStorageType(StrEnum): + TABLE = "table" + GRAPH = "graph" + + +class ForgettingPolicy(StrEnum): + FIFO = "fifo" + + # environment # ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT" # ENV_RAGFLOW_SECRET_KEY = 
"RAGFLOW_SECRET_KEY" diff --git a/test/testcases/test_web_api/common.py b/test/testcases/test_web_api/common.py index c7ec156d1..4f4abf722 100644 --- a/test/testcases/test_web_api/common.py +++ b/test/testcases/test_web_api/common.py @@ -28,6 +28,7 @@ CHUNK_API_URL = f"/{VERSION}/chunk" DIALOG_APP_URL = f"/{VERSION}/dialog" # SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions" # SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions" +MEMORY_API_URL = f"/{VERSION}/memories" # KB APP @@ -258,3 +259,43 @@ def delete_dialogs(auth): dialog_ids = [dialog["id"] for dialog in res["data"]] if dialog_ids: delete_dialog(auth, {"dialog_ids": dialog_ids}) + +# MEMORY APP +def create_memory(auth, payload=None): + url = f"{HOST_ADDRESS}{MEMORY_API_URL}" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def update_memory(auth, memory_id:str, payload=None): + url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}" + res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def delete_memory(auth, memory_id:str): + url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}" + res = requests.delete(url=url, headers=HEADERS, auth=auth) + return res.json() + + +def list_memory(auth, params=None): + url = f"{HOST_ADDRESS}{MEMORY_API_URL}" + if params: + query_parts = [] + for key, value in params.items(): + if isinstance(value, list): + for item in value: + query_parts.append(f"{key}={item}") + else: + query_parts.append(f"{key}={value}") + query_string = "&".join(query_parts) + url = f"{url}?{query_string}" + res = requests.get(url=url, headers=HEADERS, auth=auth) + return res.json() + + +def get_memory_config(auth, memory_id:str): + url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}/config" + res = requests.get(url=url, headers=HEADERS, auth=auth) + return res.json() diff --git a/test/testcases/test_web_api/test_memory_app/conftest.py 
b/test/testcases/test_web_api/test_memory_app/conftest.py new file mode 100644 index 000000000..11c7c2a10 --- /dev/null +++ b/test/testcases/test_web_api/test_memory_app/conftest.py @@ -0,0 +1,40 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +import random +from test_web_api.common import create_memory, list_memory, delete_memory + +@pytest.fixture(scope="function") +def add_memory_func(request, WebApiAuth): + def cleanup(): + memory_list_res = list_memory(WebApiAuth) + exist_memory_ids = [memory["id"] for memory in memory_list_res["data"]["memory_list"]] + for memory_id in exist_memory_ids: + delete_memory(WebApiAuth, memory_id) + + request.addfinalizer(cleanup) + + memory_ids = [] + for i in range(3): + payload = { + "name": f"test_memory_{i}", + "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), + "embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5", + "llm_id": "ZHIPU-AI@glm-4-flash" + } + res = create_memory(WebApiAuth, payload) + memory_ids.append(res["data"]["id"]) + return memory_ids diff --git a/test/testcases/test_web_api/test_memory_app/test_create_memory.py b/test/testcases/test_web_api/test_memory_app/test_create_memory.py new file mode 100644 index 000000000..d91500bc9 --- /dev/null +++ b/test/testcases/test_web_api/test_memory_app/test_create_memory.py @@ -0,0 +1,106 @@ +# +# Copyright 2025 The InfiniFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import re + +import pytest +from test_web_api.common import create_memory +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth +from hypothesis import example, given, settings +from test.testcases.utils.hypothesis_utils import valid_names + + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ids=["empty_auth", "invalid_api_token"] + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = create_memory(invalid_auth) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestMemoryCreate: + @pytest.mark.p1 + @given(name=valid_names()) + @example("d" * 128) + @settings(max_examples=20) + def test_name(self, WebApiAuth, name): + payload = { + "name": name, + "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), + "embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5", + "llm_id": "ZHIPU-AI@glm-4-flash" + } + res = create_memory(WebApiAuth, payload) + assert res["code"] == 0, res + pattern = rf'^{name}|{name}(?:\((\d+)\))?$' + escaped_name = re.escape(res["data"]["name"]) + assert re.match(pattern, escaped_name), res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, 
expected_message", + [ + ("", "Memory name cannot be empty or whitespace."), + (" ", "Memory name cannot be empty or whitespace."), + ("a" * 129, f"Memory name '{'a'*129}' exceeds limit of 128."), + ], + ids=["empty_name", "space_name", "too_long_name"], + ) + def test_name_invalid(self, WebApiAuth, name, expected_message): + payload = { + "name": name, + "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), + "embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5", + "llm_id": "ZHIPU-AI@glm-4-flash" + } + res = create_memory(WebApiAuth, payload) + assert res["message"] == expected_message, res + + @pytest.mark.p2 + @given(name=valid_names()) + def test_type_invalid(self, WebApiAuth, name): + payload = { + "name": name, + "memory_type": ["something"], + "embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5", + "llm_id": "ZHIPU-AI@glm-4-flash" + } + res = create_memory(WebApiAuth, payload) + assert res["message"] == f"Memory type '{ {'something'} }' is not supported.", res + + @pytest.mark.p3 + def test_name_duplicated(self, WebApiAuth): + name = "duplicated_name_test" + payload = { + "name": name, + "memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)), + "embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5", + "llm_id": "ZHIPU-AI@glm-4-flash" + } + res1 = create_memory(WebApiAuth, payload) + assert res1["code"] == 0, res1 + + res2 = create_memory(WebApiAuth, payload) + assert res2["code"] == 0, res2 diff --git a/test/testcases/test_web_api/test_memory_app/test_list_memory.py b/test/testcases/test_web_api/test_memory_app/test_list_memory.py new file mode 100644 index 000000000..e1095358a --- /dev/null +++ b/test/testcases/test_web_api/test_memory_app/test_list_memory.py @@ -0,0 +1,118 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from test_web_api.common import list_memory, get_memory_config +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = list_memory(invalid_auth) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestCapability: + @pytest.mark.p3 + def test_capability(self, WebApiAuth): + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(list_memory, WebApiAuth) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + +@pytest.mark.usefixtures("add_memory_func") +class TestMemoryList: + @pytest.mark.p1 + def test_params_unset(self, WebApiAuth): + res = list_memory(WebApiAuth, None) + assert res["code"] == 0, res + + @pytest.mark.p1 + def test_params_empty(self, WebApiAuth): + res = list_memory(WebApiAuth, {}) + assert res["code"] == 0, res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + 
({"page": 1, "page_size": 10}, 3), + ({"page": 2, "page_size": 10}, 0), + ({"page": 1, "page_size": 2}, 2), + ({"page": 2, "page_size": 2}, 1), + ({"page": 5, "page_size": 10}, 0), + ], + ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page" , "normal_middle_page", + "full_data_single_page"], + ) + def test_page(self, WebApiAuth, params, expected_page_size): + # have added 3 memories in fixture + res = list_memory(WebApiAuth, params) + assert res["code"] == 0, res + assert len(res["data"]["memory_list"]) == expected_page_size, res + + @pytest.mark.p2 + def test_filter_memory_type(self, WebApiAuth): + res = list_memory(WebApiAuth, {"memory_type": ["semantic"]}) + assert res["code"] == 0, res + for memory in res["data"]["memory_list"]: + assert "semantic" in memory["memory_type"], res + + @pytest.mark.p2 + def test_filter_multi_memory_type(self, WebApiAuth): + res = list_memory(WebApiAuth, {"memory_type": ["episodic", "procedural"]}) + assert res["code"] == 0, res + for memory in res["data"]["memory_list"]: + assert "episodic" in memory["memory_type"] or "procedural" in memory["memory_type"], res + + @pytest.mark.p2 + def test_filter_storage_type(self, WebApiAuth): + res = list_memory(WebApiAuth, {"storage_type": "table"}) + assert res["code"] == 0, res + for memory in res["data"]["memory_list"]: + assert memory["storage_type"] == "table", res + + @pytest.mark.p2 + def test_match_keyword(self, WebApiAuth): + res = list_memory(WebApiAuth, {"keywords": "s"}) + assert res["code"] == 0, res + for memory in res["data"]["memory_list"]: + assert "s" in memory["name"], res + + @pytest.mark.p1 + def test_get_config(self, WebApiAuth): + memory_list = list_memory(WebApiAuth, {}) + assert memory_list["code"] == 0, memory_list + + memory_config = get_memory_config(WebApiAuth, memory_list["data"]["memory_list"][0]["id"]) + assert memory_config["code"] == 0, memory_config + assert memory_config["data"]["id"] == memory_list["data"]["memory_list"][0]["id"], 
memory_config + for field in ["name", "avatar", "tenant_id", "owner_name", "memory_type", "storage_type", + "embd_id", "llm_id", "permissions", "description", "memory_size", "forgetting_policy", + "temperature", "system_prompt", "user_prompt"]: + assert field in memory_config["data"], memory_config diff --git a/test/testcases/test_web_api/test_memory_app/test_rm_memory.py b/test/testcases/test_web_api/test_memory_app/test_rm_memory.py new file mode 100644 index 000000000..e6faf5d3f --- /dev/null +++ b/test/testcases/test_web_api/test_memory_app/test_rm_memory.py @@ -0,0 +1,53 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import pytest +from test_web_api.common import (list_memory, delete_memory) +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = delete_memory(invalid_auth, "some_memory_id") + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestMemoryDelete: + @pytest.mark.p1 + def test_memory_id(self, WebApiAuth, add_memory_func): + memory_ids = add_memory_func + res = delete_memory(WebApiAuth, memory_ids[0]) + assert res["code"] == 0, res + + res = list_memory(WebApiAuth) + assert res["data"]["total_count"] == 2, res + + @pytest.mark.p2 + @pytest.mark.usefixtures("add_memory_func") + def test_id_wrong_uuid(self, WebApiAuth): + res = delete_memory(WebApiAuth, "d94a8dc02c9711f0930f7fbc369eab6d") + assert res["code"] == 404, res + + res = list_memory(WebApiAuth) + assert len(res["data"]["memory_list"]) == 3, res diff --git a/test/testcases/test_web_api/test_memory_app/test_update_memory.py b/test/testcases/test_web_api/test_memory_app/test_update_memory.py new file mode 100644 index 000000000..4def9d8b1 --- /dev/null +++ b/test/testcases/test_web_api/test_memory_app/test_update_memory.py @@ -0,0 +1,161 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from test_web_api.common import update_memory +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth +from hypothesis import HealthCheck, example, given, settings +from utils import encode_avatar +from utils.file_utils import create_image_file +from utils.hypothesis_utils import valid_names + + +class TestAuthorization: + @pytest.mark.p1 + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ids=["empty_auth", "invalid_api_token"] + ) + def test_auth_invalid(self, invalid_auth, expected_code, expected_message): + res = update_memory(invalid_auth, "memory_id") + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestMemoryUpdate: + + @pytest.mark.p1 + @given(name=valid_names()) + @example("f" * 128) + @settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture]) + def test_name(self, WebApiAuth, add_memory_func, name): + memory_ids = add_memory_func + payload = {"name": name} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 0, res + assert res["data"]["name"] == name, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "name, expected_message", + [ + ("", "Memory name cannot be empty or whitespace."), + (" ", "Memory name cannot be empty or whitespace."), + ("a" * 129, f"Memory name '{'a' * 129}' exceeds limit of 128."), + ] + ) + def test_name_invalid(self, WebApiAuth, add_memory_func, name, expected_message): + memory_ids = add_memory_func + payload = {"name": name} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 101, res + assert res["message"] == expected_message, res + + @pytest.mark.p2 + def test_duplicate_name(self, WebApiAuth, add_memory_func): + memory_ids = 
add_memory_func + payload = {"name": "Test_Memory"} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 0, res + + payload = {"name": "Test_Memory"} + res = update_memory(WebApiAuth, memory_ids[1], payload) + assert res["code"] == 0, res + assert res["data"]["name"] == "Test_Memory(1)", res + + @pytest.mark.p1 + def test_avatar(self, WebApiAuth, add_memory_func, tmp_path): + memory_ids = add_memory_func + fn = create_image_file(tmp_path / "ragflow_test.png") + payload = {"avatar": f"data:image/png;base64,{encode_avatar(fn)}"} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 0, res + assert res["data"]["avatar"] == f"data:image/png;base64,{encode_avatar(fn)}", res + + @pytest.mark.p1 + def test_description(self, WebApiAuth, add_memory_func): + memory_ids = add_memory_func + description = "This is a test description." + payload = {"description": description} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 0, res + assert res["data"]["description"] == description, res + + @pytest.mark.p1 + def test_llm(self, WebApiAuth, add_memory_func): + memory_ids = add_memory_func + llm_id = "ZHIPU-AI@glm-4" + payload = {"llm_id": llm_id} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 0, res + assert res["data"]["llm_id"] == llm_id, res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "permission", + [ + "me", + "team" + ], + ids=["me", "team"] + ) + def test_permission(self, WebApiAuth, add_memory_func, permission): + memory_ids = add_memory_func + payload = {"permissions": permission} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 0, res + assert res["data"]["permissions"] == permission.lower().strip(), res + + + @pytest.mark.p1 + def test_memory_size(self, WebApiAuth, add_memory_func): + memory_ids = add_memory_func + memory_size = 1048576 # 1 MB + payload = {"memory_size": memory_size} + res = 
update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 0, res + assert res["data"]["memory_size"] == memory_size, res + + @pytest.mark.p1 + def test_temperature(self, WebApiAuth, add_memory_func): + memory_ids = add_memory_func + temperature = 0.7 + payload = {"temperature": temperature} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 0, res + assert res["data"]["temperature"] == temperature, res + + @pytest.mark.p1 + def test_system_prompt(self, WebApiAuth, add_memory_func): + memory_ids = add_memory_func + system_prompt = "This is a system prompt." + payload = {"system_prompt": system_prompt} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 0, res + assert res["data"]["system_prompt"] == system_prompt, res + + @pytest.mark.p1 + def test_user_prompt(self, WebApiAuth, add_memory_func): + memory_ids = add_memory_func + user_prompt = "This is a user prompt." + payload = {"user_prompt": user_prompt} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 0, res + assert res["data"]["user_prompt"] == user_prompt, res From 80f3ccf1acc55c5d01be80a6061d569a80e09d8f Mon Sep 17 00:00:00 2001 From: chanx <1243304602@qq.com> Date: Wed, 10 Dec 2025 13:38:24 +0800 Subject: [PATCH 3/7] Fix:Modify the name of the Overlapped percent field (#11866) ### What problem does this PR solve? 
Fix:Modify the name of the Overlapped percent field ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- web/src/locales/en.ts | 1 + .../pages/dataset/dataset-setting/configuration/common-item.tsx | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 980e08750..c03408beb 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -329,6 +329,7 @@ export default { reRankModelWaring: 'Re-rank model is very time consuming.', }, knowledgeConfiguration: { + overlappedPercent: 'Overlapped percent', generationScopeTip: 'Determines whether RAPTOR is generated for the entire dataset or for a single file.', scopeDataset: 'Dataset', diff --git a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx index 7ca33ffb5..2e6b7400a 100644 --- a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx +++ b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx @@ -292,7 +292,7 @@ export function OverlappedPercent() { return ( From ab4b62031f4e7b5261e7dd4d41af863ea5e46fd9 Mon Sep 17 00:00:00 2001 From: buua436 Date: Wed, 10 Dec 2025 16:44:06 +0800 Subject: [PATCH 4/7] Fix:csv parse in Table (#11870) ### What problem does this PR solve? 
change: csv parse in Table ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/app/table.py | 30 ++++++++++++++++++- web/package-lock.json | 17 +++++++++++ web/package.json | 2 ++ .../document-preview/csv-preview.tsx | 18 ++++++----- 4 files changed, 59 insertions(+), 8 deletions(-) diff --git a/rag/app/table.py b/rag/app/table.py index 7a21a738a..a87a858bf 100644 --- a/rag/app/table.py +++ b/rag/app/table.py @@ -15,6 +15,8 @@ # import copy +import csv +import io import logging import re from io import BytesIO @@ -323,7 +325,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese callback(0.1, "Start to parse.") excel_parser = Excel() dfs = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback) - elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE): + elif re.search(r"\.txt$", filename, re.IGNORECASE): callback(0.1, "Start to parse.") txt = get_text(filename, binary) lines = txt.split("\n") @@ -344,7 +346,33 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese callback(0.3, ("Extract records: {}~{}".format(from_page, min(len(lines), to_page)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) dfs = [pd.DataFrame(np.array(rows), columns=headers)] + elif re.search(r"\.csv$", filename, re.IGNORECASE): + callback(0.1, "Start to parse.") + txt = get_text(filename, binary) + delimiter = kwargs.get("delimiter", ",") + reader = csv.reader(io.StringIO(txt), delimiter=delimiter) + all_rows = list(reader) + if not all_rows: + raise ValueError("Empty CSV file") + + headers = all_rows[0] + fails = [] + rows = [] + + for i, row in enumerate(all_rows[1 + from_page : 1 + to_page]): + if len(row) != len(headers): + fails.append(str(i + from_page)) + continue + rows.append(row) + + callback( + 0.3, + (f"Extract records: {from_page}~{from_page + len(rows)}" + + (f"{len(fails)} failure, line: {','.join(fails[:3])}..." 
if fails else "")) + ) + + dfs = [pd.DataFrame(rows, columns=headers)] else: raise NotImplementedError("file type not supported yet(excel, text, csv supported)") diff --git a/web/package-lock.json b/web/package-lock.json index 880a1c9b4..ed94049b2 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -45,6 +45,7 @@ "@tanstack/react-query": "^5.40.0", "@tanstack/react-query-devtools": "^5.51.5", "@tanstack/react-table": "^8.20.5", + "@types/papaparse": "^5.5.1", "@uiw/react-markdown-preview": "^5.1.3", "@xyflow/react": "^12.3.6", "ahooks": "^3.7.10", @@ -73,6 +74,7 @@ "mammoth": "^1.7.2", "next-themes": "^0.4.6", "openai-speech-stream-player": "^1.0.8", + "papaparse": "^5.5.3", "pptx-preview": "^1.0.5", "rc-tween-one": "^3.0.6", "react": "^18.2.0", @@ -10632,6 +10634,15 @@ "integrity": "sha512-37i+OaWTh9qeK4LSHPsyRC7NahnGotNuZvjLSgcPzblpHB3rrCJxAOgI5gCdKm7coonsaX1Of0ILiTcnZjbfxA==", "peer": true }, + "node_modules/@types/papaparse": { + "version": "5.5.1", + "resolved": "https://registry.npmmirror.com/@types/papaparse/-/papaparse-5.5.1.tgz", + "integrity": "sha512-esEO+VISsLIyE+JZBmb89NzsYYbpwV8lmv2rPo6oX5y9KhBaIP7hhHgjuTut54qjdKVMufTEcrh5fUl9+58huw==", + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@types/parse-json": { "version": "4.0.2", "resolved": "https://registry.npmmirror.com/@types/parse-json/-/parse-json-4.0.2.tgz", @@ -27413,6 +27424,12 @@ "resolved": "https://registry.npmmirror.com/pako/-/pako-1.0.11.tgz", "integrity": "sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==" }, + "node_modules/papaparse": { + "version": "5.5.3", + "resolved": "https://registry.npmmirror.com/papaparse/-/papaparse-5.5.3.tgz", + "integrity": "sha512-5QvjGxYVjxO59MGU2lHVYpRWBBtKHnlIAcSe1uNFCkkptUh63NFRj0FJQm7nR67puEruUci/ZkjmEFrjCAyP4A==", + "license": "MIT" + }, "node_modules/param-case": { "version": "3.0.4", "resolved": 
"https://registry.npmmirror.com/param-case/-/param-case-3.0.4.tgz", diff --git a/web/package.json b/web/package.json index f183c8008..051c4b9d7 100644 --- a/web/package.json +++ b/web/package.json @@ -58,6 +58,7 @@ "@tanstack/react-query": "^5.40.0", "@tanstack/react-query-devtools": "^5.51.5", "@tanstack/react-table": "^8.20.5", + "@types/papaparse": "^5.5.1", "@uiw/react-markdown-preview": "^5.1.3", "@xyflow/react": "^12.3.6", "ahooks": "^3.7.10", @@ -86,6 +87,7 @@ "mammoth": "^1.7.2", "next-themes": "^0.4.6", "openai-speech-stream-player": "^1.0.8", + "papaparse": "^5.5.3", "pptx-preview": "^1.0.5", "rc-tween-one": "^3.0.6", "react": "^18.2.0", diff --git a/web/src/components/document-preview/csv-preview.tsx b/web/src/components/document-preview/csv-preview.tsx index 45b05454e..fa1cf1ed8 100644 --- a/web/src/components/document-preview/csv-preview.tsx +++ b/web/src/components/document-preview/csv-preview.tsx @@ -2,6 +2,7 @@ import message from '@/components/ui/message'; import { Spin } from '@/components/ui/spin'; import request from '@/utils/request'; import classNames from 'classnames'; +import Papa from 'papaparse'; import React, { useEffect, useRef, useState } from 'react'; interface CSVData { @@ -20,14 +21,17 @@ const CSVFileViewer: React.FC = ({ url }) => { const containerRef = useRef(null); // const url = useGetDocumentUrl(); const parseCSV = (csvText: string): CSVData => { - console.log('Parsing CSV data:', csvText); - const lines = csvText.split('\n'); - const headers = lines[0].split(',').map((header) => header.trim()); - const rows = lines - .slice(1) - .map((line) => line.split(',').map((cell) => cell.trim())); + const result = Papa.parse(csvText, { + header: false, + skipEmptyLines: false, + }); - return { headers, rows }; + const rows = result.data as string[][]; + + const headers = rows[0]; + const dataRows = rows.slice(1); + + return { headers, rows: dataRows }; }; useEffect(() => { From 3cb72377d7e8e98d287df347ff0b15abe72ff4b8 Mon Sep 17 
00:00:00 2001 From: buua436 Date: Wed, 10 Dec 2025 19:08:45 +0800 Subject: [PATCH 5/7] Refa:remove sensitive information (#11873) ### What problem does this PR solve? change: remove sensitive information ### Type of change - [x] Refactoring --- .github/workflows/tests.yml | 2 ++ admin/client/admin_client.py | 12 ++++---- admin/server/auth.py | 6 ++-- api/apps/canvas_app.py | 10 ++++++- api/db/init_data.py | 2 +- api/db/joint_services/user_account_service.py | 2 +- api/db/services/llm_service.py | 4 +-- common/data_source/confluence_connector.py | 5 +++- common/data_source/jira/connector.py | 7 ++--- common/http_client.py | 28 +++++++++++++++++-- rag/utils/opendal_conn.py | 8 +++++- 11 files changed, 62 insertions(+), 24 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5341d83ae..a5bdc1735 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,6 @@ name: tests +permissions: + contents: read on: push: diff --git a/admin/client/admin_client.py b/admin/client/admin_client.py index 4b210d2b5..8cad14bab 100644 --- a/admin/client/admin_client.py +++ b/admin/client/admin_client.py @@ -351,7 +351,7 @@ class AdminCLI(Cmd): def verify_admin(self, arguments: dict, single_command: bool): self.host = arguments['host'] self.port = arguments['port'] - print(f"Attempt to access ip: {self.host}, port: {self.port}") + print("Attempt to access server for admin login") url = f"http://{self.host}:{self.port}/api/v1/admin/login" attempt_count = 3 @@ -390,7 +390,7 @@ class AdminCLI(Cmd): print(f"Bad response,status: {response.status_code}, password is wrong") except Exception as e: print(str(e)) - print(f"Can't access {self.host}, port: {self.port}") + print("Can't access server for admin login (connection failed)") def _format_service_detail_table(self, data): if isinstance(data, list): @@ -674,7 +674,7 @@ class AdminCLI(Cmd): user_name: str = user_name_tree.children[0].strip("'\"") password_tree: Tree = 
command['password'] password: str = password_tree.children[0].strip("'\"") - print(f"Alter user: {user_name}, password: {password}") + print(f"Alter user: {user_name}, password: ******") url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/password' response = self.session.put(url, json={'new_password': encrypt(password)}) res_json = response.json() @@ -689,7 +689,7 @@ class AdminCLI(Cmd): password_tree: Tree = command['password'] password: str = password_tree.children[0].strip("'\"") role: str = command['role'] - print(f"Create user: {user_name}, password: {password}, role: {role}") + print(f"Create user: {user_name}, password: ******, role: {role}") url = f'http://{self.host}:{self.port}/api/v1/admin/users' response = self.session.post( url, @@ -951,7 +951,7 @@ def main(): args = cli.parse_connection_args(sys.argv) if 'error' in args: - print(f"Error: {args['error']}") + print("Error: Invalid connection arguments") return if 'command' in args: @@ -960,7 +960,7 @@ def main(): return if cli.verify_admin(args, single_command=True): command: str = args['command'] - print(f"Run single command: {command}") + # print(f"Run single command: {command}") cli.run_single_command(command) else: if cli.verify_admin(args, single_command=False): diff --git a/admin/server/auth.py b/admin/server/auth.py index 6c8bc2cb8..486b9a4fb 100644 --- a/admin/server/auth.py +++ b/admin/server/auth.py @@ -176,11 +176,11 @@ def login_verify(f): "message": "Access denied", "data": None }), 200 - except Exception as e: - error_msg = str(e) + except Exception: + logging.exception("An error occurred during admin login verification.") return jsonify({ "code": 500, - "message": error_msg + "message": "An internal server error occurred." 
}), 200 return f(*args, **kwargs) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index fe32dca0b..ed8c8c7a0 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -342,7 +342,15 @@ async def test_db_connect(): f"UID={req['username']};" f"PWD={req['password']};" ) - logging.info(conn_str) + redacted_conn_str = ( + f"DATABASE={req['database']};" + f"HOSTNAME={req['host']};" + f"PORT={req['port']};" + f"PROTOCOL=TCPIP;" + f"UID={req['username']};" + f"PWD=****;" + ) + logging.info(redacted_conn_str) conn = ibm_db.connect(conn_str, "", "") stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1") ibm_db.fetch_assoc(stmt) diff --git a/api/db/init_data.py b/api/db/init_data.py index d4873d332..7454965eb 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -73,7 +73,7 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_ UserTenantService.insert(**usr_tenant) TenantLLMService.insert_many(tenant_llm) logging.info( - f"Super user initialized. email: {email}, password: {password}. Changing the password after login is strongly recommended.") + f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.") chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) msg = chat_mdl.chat(system="", history=[ diff --git a/api/db/joint_services/user_account_service.py b/api/db/joint_services/user_account_service.py index 34ceee648..48937653e 100644 --- a/api/db/joint_services/user_account_service.py +++ b/api/db/joint_services/user_account_service.py @@ -273,7 +273,7 @@ def delete_user_data(user_id: str) -> dict: except Exception as e: logging.exception(e) - return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"} + return {"success": False, "message": "An internal error occurred during user deletion. 
Some operations may have completed.","details": done_msg} def delete_user_agents(user_id: str) -> dict: diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 86356a7a7..e4bf64aac 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -109,7 +109,7 @@ class LLMBundle(LLM4Tenant): llm_name = getattr(self, "llm_name", None) if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name): - logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens)) + logging.error("LLMBundle.encode can't update token usage for /EMBEDDING used_tokens: {}".format(used_tokens)) if self.langfuse: generation.update(usage_details={"total_tokens": used_tokens}) @@ -124,7 +124,7 @@ class LLMBundle(LLM4Tenant): emd, used_tokens = self.mdl.encode_queries(query) llm_name = getattr(self, "llm_name", None) if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name): - logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens)) + logging.error("LLMBundle.encode_queries can't update token usage for /EMBEDDING used_tokens: {}".format(used_tokens)) if self.langfuse: generation.update(usage_details={"total_tokens": used_tokens}) diff --git a/common/data_source/confluence_connector.py b/common/data_source/confluence_connector.py index a057d0694..aff225703 100644 --- a/common/data_source/confluence_connector.py +++ b/common/data_source/confluence_connector.py @@ -1110,7 +1110,10 @@ def _make_attachment_link( ) -> str | None: download_link = "" - if "api.atlassian.com" in confluence_client.url: + from urllib.parse import urlparse + netloc =urlparse(confluence_client.url).hostname + if netloc == "api.atlassian.com" or (netloc and netloc.endswith(".api.atlassian.com")): + # if "api.atlassian.com" in confluence_client.url: # 
https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get if not parent_content_id: logging.warning( diff --git a/common/data_source/jira/connector.py b/common/data_source/jira/connector.py index 06a0a9069..2a93aaf51 100644 --- a/common/data_source/jira/connector.py +++ b/common/data_source/jira/connector.py @@ -135,7 +135,7 @@ class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync except ValueError as exc: raise ConnectorValidationError(str(exc)) from exc else: - logger.warning(f"[Jira] Scoped token requested but Jira base URL {self.jira_base_url} does not appear to be an Atlassian Cloud domain; scoped token ignored.") + logger.warning("[Jira] Scoped token requested but Jira base URL does not appear to be an Atlassian Cloud domain; scoped token ignored.") user_email = credentials.get("jira_user_email") or credentials.get("username") api_token = credentials.get("jira_api_token") or credentials.get("token") or credentials.get("api_token") @@ -245,7 +245,7 @@ class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync while True: attempt += 1 jql = self._build_jql(attempt_start, end) - logger.info(f"[Jira] Executing Jira JQL attempt {attempt} (start={attempt_start}, end={end}, buffered_retry={retried_with_buffer}): {jql}") + logger.info(f"[Jira] Executing Jira JQL attempt {attempt} (start={attempt_start}, end={end}, buffered_retry={retried_with_buffer})") try: return (yield from self._load_from_checkpoint_internal(jql, checkpoint, start_filter=start)) except Exception as exc: @@ -927,9 +927,6 @@ def main(config: dict[str, Any] | None = None) -> None: base_url = config.get("base_url") credentials = config.get("credentials", {}) - print(f"[Jira] {config=}", flush=True) - print(f"[Jira] {credentials=}", flush=True) - if not base_url: raise RuntimeError("Jira base URL must be provided via config or CLI arguments.") if not 
(credentials.get("jira_api_token") or (credentials.get("jira_user_email") and credentials.get("jira_password"))): diff --git a/common/http_client.py b/common/http_client.py index 91ac0cadc..5c57f8638 100644 --- a/common/http_client.py +++ b/common/http_client.py @@ -16,6 +16,7 @@ import logging import os import time from typing import Any, Dict, Optional +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse import httpx @@ -52,6 +53,27 @@ def _get_delay(backoff_factor: float, attempt: int) -> float: return backoff_factor * (2**attempt) +# List of sensitive parameters to redact from URLs before logging +_SENSITIVE_QUERY_KEYS = {"client_secret", "secret", "code", "access_token", "refresh_token", "password", "token", "app_secret"} + +def _redact_sensitive_url_params(url: str) -> str: + try: + parsed = urlparse(url) + if not parsed.query: + return url + clean_query = [] + for k, v in parse_qsl(parsed.query, keep_blank_values=True): + if k.lower() in _SENSITIVE_QUERY_KEYS: + clean_query.append((k, "***REDACTED***")) + else: + clean_query.append((k, v)) + new_query = urlencode(clean_query, doseq=True) + redacted_url = urlunparse(parsed._replace(query=new_query)) + return redacted_url + except Exception: + return url + + async def async_request( method: str, url: str, @@ -94,19 +116,19 @@ async def async_request( ) duration = time.monotonic() - start logger.debug( - f"async_request {method} {url} -> {response.status_code} in {duration:.3f}s" + f"async_request {method} {_redact_sensitive_url_params(url)} -> {response.status_code} in {duration:.3f}s" ) return response except httpx.RequestError as exc: last_exc = exc if attempt >= retries: logger.warning( - f"async_request exhausted retries for {method} {url}: {exc}" + f"async_request exhausted retries for {method} {_redact_sensitive_url_params(url)}: {exc}" ) raise delay = _get_delay(backoff_factor, attempt) logger.warning( - f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {url}: 
{exc}; retrying in {delay:.2f}s" + f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {_redact_sensitive_url_params(url)}: {exc}; retrying in {delay:.2f}s" ) await asyncio.sleep(delay) raise last_exc # pragma: no cover diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py index c6cebf9ca..a260daebc 100644 --- a/rag/utils/opendal_conn.py +++ b/rag/utils/opendal_conn.py @@ -41,7 +41,13 @@ def get_opendal_config(): scheme = opendal_config.get("scheme") config_data = opendal_config.get("config", {}) kwargs = {"scheme": scheme, **config_data} - logging.info("Loaded OpenDAL configuration from yaml: %s", kwargs) + redacted_kwargs = kwargs.copy() + if 'password' in redacted_kwargs: + redacted_kwargs['password'] = '***REDACTED***' + if 'connection_string' in redacted_kwargs and 'password' in redacted_kwargs: + import re + redacted_kwargs['connection_string'] = re.sub(r':[^@]+@', ':***REDACTED***@', redacted_kwargs['connection_string']) + logging.info("Loaded OpenDAL configuration from yaml: %s", redacted_kwargs) return kwargs except Exception as e: logging.error("Failed to load OpenDAL configuration from yaml: %s", str(e)) From badf33e3b9939bce42011cece8338e496ab81f66 Mon Sep 17 00:00:00 2001 From: He Wang Date: Wed, 10 Dec 2025 19:13:37 +0800 Subject: [PATCH 6/7] feat: enhance OBConnection.search (#11876) ### What problem does this PR solve? Enhance OBConnection.search for better performance. Main changes: 1. Use string type of vector array in distance func for better parsing performance. 2. Manually set max_connections as pool size instead of using default value. 3. Set 'fulltext_search_columns' when starting. 4. Cache the results of the table existence check (we will never drop the table). 5. Remove unused 'group_results' logic. 6. Add the `USE_FULLTEXT_FIRST_FUSION_SEARCH` flag, and the corresponding fusion search SQL when it's false. 
### Type of change - [x] Performance Improvement --- rag/utils/ob_conn.py | 266 ++++++++++++++++++++++++++----------------- 1 file changed, 164 insertions(+), 102 deletions(-) diff --git a/rag/utils/ob_conn.py b/rag/utils/ob_conn.py index 6218a8c4e..3c00be421 100644 --- a/rag/utils/ob_conn.py +++ b/rag/utils/ob_conn.py @@ -17,13 +17,16 @@ import json import logging import os import re +import threading import time from typing import Any, Optional +import numpy as np from elasticsearch_dsl import Q, Search from pydantic import BaseModel from pymysql.converters import escape_string from pyobvector import ObVecClient, FtsIndexParam, FtsParser, ARRAY, VECTOR +from pyobvector.client import ClusterVersionException from pyobvector.client.hybrid_search import HybridSearch from pyobvector.util import ObVersion from sqlalchemy import text, Column, String, Integer, JSON, Double, Row, Table @@ -106,17 +109,6 @@ index_columns: list[str] = [ "removed_kwd", ] -fulltext_search_columns: list[str] = [ - "docnm_kwd", - "content_with_weight", - "title_tks", - "title_sm_tks", - "important_tks", - "question_tks", - "content_ltks", - "content_sm_ltks" -] - fts_columns_origin: list[str] = [ "docnm_kwd^10", "content_with_weight", @@ -138,7 +130,7 @@ fulltext_index_name_template = "fts_idx_%s" # MATCH AGAINST: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002017607 fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)" # cosine_distance: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012938 -vector_search_template = "cosine_distance(%s, %s)" +vector_search_template = "cosine_distance(%s, '%s')" class SearchResult(BaseModel): @@ -362,18 +354,28 @@ class OBConnection(DocStoreConnection): port = mysql_config.get("port", 2881) self.username = mysql_config.get("user", "root@test") self.password = mysql_config.get("password", "infini_rag_flow") + max_connections = mysql_config.get("max_connections", 300) else: logger.info("Use 
customized config to create OceanBase connection.") host = ob_config.get("host", "localhost") port = ob_config.get("port", 2881) self.username = ob_config.get("user", "root@test") self.password = ob_config.get("password", "infini_rag_flow") + max_connections = ob_config.get("max_connections", 300) self.db_name = ob_config.get("db_name", "test") self.uri = f"{host}:{port}" logger.info(f"Use OceanBase '{self.uri}' as the doc engine.") + # Set the maximum number of connections that can be created above the pool_size. + # By default, this is half of max_connections, but at least 10. + # This allows the pool to handle temporary spikes in demand without exhausting resources. + max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10))) + # Set the number of seconds to wait before giving up when trying to get a connection from the pool. + # Default is 30 seconds, but can be overridden with the OB_POOL_TIMEOUT environment variable. + pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30")) + for _ in range(ATTEMPT_TIME): try: self.client = ObVecClient( @@ -383,6 +385,9 @@ class OBConnection(DocStoreConnection): db_name=self.db_name, pool_pre_ping=True, pool_recycle=3600, + pool_size=max_connections, + max_overflow=max_overflow, + pool_timeout=pool_timeout, ) break except Exception as e: @@ -398,6 +403,37 @@ class OBConnection(DocStoreConnection): self._check_ob_version() self._try_to_update_ob_query_timeout() + self.es = None + if self.enable_hybrid_search: + try: + self.es = HybridSearch( + uri=self.uri, + user=self.username, + password=self.password, + db_name=self.db_name, + pool_pre_ping=True, + pool_recycle=3600, + pool_size=max_connections, + max_overflow=max_overflow, + pool_timeout=pool_timeout, + ) + logger.info("OceanBase Hybrid Search feature is enabled") + except ClusterVersionException as e: + logger.info("Failed to initialize HybridSearch client, fallback to use SQL", exc_info=e) + self.es = None + + if self.es is not None and 
self.search_original_content: + logger.info("HybridSearch is enabled, forcing search_original_content to False") + self.search_original_content = False + # Determine which columns to use for full-text search dynamically: + # If HybridSearch is enabled (self.es is not None), we must use tokenized columns (fts_columns_tks) + # for compatibility and performance with HybridSearch. Otherwise, we use the original content columns + # (fts_columns_origin), which may be controlled by an environment variable. + self.fulltext_search_columns = fts_columns_origin if self.search_original_content else fts_columns_tks + + self._table_exists_cache: set[str] = set() + self._table_exists_cache_lock = threading.RLock() + logger.info(f"OceanBase {self.uri} is healthy.") def _check_ob_version(self): @@ -417,18 +453,6 @@ class OBConnection(DocStoreConnection): f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}" ) - self.es = None - if not ob_version < ObVersion.from_db_version_nums(4, 4, 1, 0) and self.enable_hybrid_search: - self.es = HybridSearch( - uri=self.uri, - user=self.username, - password=self.password, - db_name=self.db_name, - pool_pre_ping=True, - pool_recycle=3600, - ) - logger.info("OceanBase Hybrid Search feature is enabled") - def _try_to_update_ob_query_timeout(self): try: val = self._get_variable_value("ob_query_timeout") @@ -455,9 +479,19 @@ class OBConnection(DocStoreConnection): return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y'] self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true') + logger.info(f"ENABLE_FULLTEXT_SEARCH={self.enable_fulltext_search}") + self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true') + logger.info(f"USE_FULLTEXT_HINT={self.use_fulltext_hint}") + self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true') + logger.info(f"SEARCH_ORIGINAL_CONTENT={self.search_original_content}") + self.enable_hybrid_search = 
is_true('ENABLE_HYBRID_SEARCH', 'false') + logger.info(f"ENABLE_HYBRID_SEARCH={self.enable_hybrid_search}") + + self.use_fulltext_first_fusion_search = is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true') + logger.info(f"USE_FULLTEXT_FIRST_FUSION_SEARCH={self.use_fulltext_first_fusion_search}") """ Database operations @@ -478,6 +512,43 @@ class OBConnection(DocStoreConnection): return row[1] raise Exception(f"Variable '{var_name}' not found.") + def _check_table_exists_cached(self, table_name: str) -> bool: + """ + Check table existence with cache to reduce INFORMATION_SCHEMA queries under high concurrency. + Only caches when table exists. Does not cache when table does not exist. + Thread-safe implementation: read operations are lock-free (GIL-protected), + write operations are protected by RLock to ensure cache consistency. + + Args: + table_name: Table name + + Returns: + Whether the table exists with all required indexes and columns + """ + if table_name in self._table_exists_cache: + return True + + try: + if not self.client.check_table_exists(table_name): + return False + for column_name in index_columns: + if not self._index_exists(table_name, index_name_template % (table_name, column_name)): + return False + for fts_column in self.fulltext_search_columns: + column_name = fts_column.split("^")[0] + if not self._index_exists(table_name, fulltext_index_name_template % column_name): + return False + for column in [column_order_id, column_group_id]: + if not self._column_exist(table_name, column.name): + return False + except Exception as e: + raise Exception(f"OBConnection._check_table_exists_cached error: {str(e)}") + + with self._table_exists_cache_lock: + if table_name not in self._table_exists_cache: + self._table_exists_cache.add(table_name) + return True + """ Table operations """ @@ -500,8 +571,7 @@ class OBConnection(DocStoreConnection): process_func=lambda: self._add_index(indexName, column_name), ) - fts_columns = fts_columns_origin if 
self.search_original_content else fts_columns_tks - for fts_column in fts_columns: + for fts_column in self.fulltext_search_columns: column_name = fts_column.split("^")[0] _try_with_lock( lock_name=f"ob_add_fulltext_idx_{indexName}_{column_name}", @@ -546,24 +616,7 @@ class OBConnection(DocStoreConnection): raise Exception(f"OBConnection.deleteIndex error: {str(e)}") def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool: - try: - if not self.client.check_table_exists(indexName): - return False - for column_name in index_columns: - if not self._index_exists(indexName, index_name_template % (indexName, column_name)): - return False - fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks - for fts_column in fts_columns: - column_name = fts_column.split("^")[0] - if not self._index_exists(indexName, fulltext_index_name_template % column_name): - return False - for column in [column_order_id, column_group_id]: - if not self._column_exist(indexName, column.name): - return False - except Exception as e: - raise Exception(f"OBConnection.indexExist error: {str(e)}") - - return True + return self._check_table_exists_cached(indexName) def _get_count(self, table_name: str, filter_list: list[str] = None) -> int: where_clause = "WHERE " + " AND ".join(filter_list) if len(filter_list) > 0 else "" @@ -853,10 +906,8 @@ class OBConnection(DocStoreConnection): fulltext_query = escape_string(fulltext_query.strip()) fulltext_topn = m.topn - fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks - # get fulltext match expression and weight values - for field in fts_columns: + for field in self.fulltext_search_columns: parts = field.split("^") column_name: str = parts[0] column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0 @@ -885,7 +936,8 @@ class OBConnection(DocStoreConnection): fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for 
col, expr in fulltext_search_expr.items())})" if vector_data: - vector_search_expr = vector_search_template % (vector_column_name, vector_data) + vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]" + vector_search_expr = vector_search_template % (vector_column_name, vector_data_str) # use (1 - cosine_distance) as score, which should be [-1, 1] # https://www.oceanbase.com/docs/common-oceanbase-database-standalone-1000000003577323 vector_search_score_expr = f"(1 - {vector_search_expr})" @@ -910,11 +962,15 @@ class OBConnection(DocStoreConnection): if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields: output_fields.append("_score") - group_results = kwargs.get("group_results", False) + if limit: + if vector_topn is not None: + limit = min(vector_topn, limit) + if fulltext_topn is not None: + limit = min(fulltext_topn, limit) for index_name in indexNames: - if not self.client.check_table_exists(index_name): + if not self._check_table_exists_cached(index_name): continue fulltext_search_hint = f"/*+ UNION_MERGE({index_name} {' '.join(fulltext_search_idx_list)}) */" if self.use_fulltext_hint else "" @@ -922,29 +978,7 @@ class OBConnection(DocStoreConnection): if search_type == "fusion": # fusion search, usually for chat num_candidates = vector_topn + fulltext_topn - if group_results: - count_sql = ( - f"WITH fulltext_results AS (" - f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance" - f" FROM {index_name}" - f" WHERE {filters_expr} AND {fulltext_search_filter}" - f" ORDER BY relevance DESC" - f" LIMIT {num_candidates}" - f")," - f" scored_results AS (" - f" SELECT *" - f" FROM fulltext_results" - f" WHERE {vector_search_filter}" - f")," - f" group_results AS (" - f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id) as rn" - f" FROM scored_results" - f")" - f" SELECT COUNT(*)" - f" FROM group_results" - f" WHERE rn = 1" - ) - else: + if self.use_fulltext_first_fusion_search: 
count_sql = ( f"WITH fulltext_results AS (" f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance" @@ -955,6 +989,22 @@ class OBConnection(DocStoreConnection): f")" f" SELECT COUNT(*) FROM fulltext_results WHERE {vector_search_filter}" ) + else: + count_sql = ( + f"WITH fulltext_results AS (" + f" SELECT {fulltext_search_hint} id FROM {index_name}" + f" WHERE {filters_expr} AND {fulltext_search_filter}" + f" ORDER BY {fulltext_search_score_expr}" + f" LIMIT {fulltext_topn}" + f")," + f"vector_results AS (" + f" SELECT id FROM {index_name}" + f" WHERE {filters_expr} AND {vector_search_filter}" + f" ORDER BY {vector_search_expr}" + f" APPROXIMATE LIMIT {vector_topn}" + f")" + f" SELECT COUNT(*) FROM fulltext_results f FULL OUTER JOIN vector_results v ON f.id = v.id" + ) logger.debug("OBConnection.search with count sql: %s", count_sql) start_time = time.time() @@ -976,32 +1026,8 @@ class OBConnection(DocStoreConnection): if total_count == 0: continue - score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})" - if group_results: - fusion_sql = ( - f"WITH fulltext_results AS (" - f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance" - f" FROM {index_name}" - f" WHERE {filters_expr} AND {fulltext_search_filter}" - f" ORDER BY relevance DESC" - f" LIMIT {num_candidates}" - f")," - f" scored_results AS (" - f" SELECT *, {score_expr} AS _score" - f" FROM fulltext_results" - f" WHERE {vector_search_filter}" - f")," - f" group_results AS (" - f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id ORDER BY _score DESC) as rn" - f" FROM scored_results" - f")" - f" SELECT {fields_expr}, _score" - f" FROM group_results" - f" WHERE rn = 1" - f" ORDER BY _score DESC" - f" LIMIT {offset}, {limit}" - ) - else: + if self.use_fulltext_first_fusion_search: + score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * 
{vector_similarity_weight} + {pagerank_score_expr})" fusion_sql = ( f"WITH fulltext_results AS (" f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance" @@ -1016,6 +1042,38 @@ class OBConnection(DocStoreConnection): f" ORDER BY _score DESC" f" LIMIT {offset}, {limit}" ) + else: + pagerank_score_expr = f"(CAST(IFNULL(f.{PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)" + score_expr = f"(f.relevance * {1 - vector_similarity_weight} + v.similarity * {vector_similarity_weight} + {pagerank_score_expr})" + fields_expr = ", ".join([f"t.{f} as {f}" for f in output_fields if f != "_score"]) + fusion_sql = ( + f"WITH fulltext_results AS (" + f" SELECT {fulltext_search_hint} id, pagerank_fea, {fulltext_search_score_expr} AS relevance" + f" FROM {index_name}" + f" WHERE {filters_expr} AND {fulltext_search_filter}" + f" ORDER BY relevance DESC" + f" LIMIT {fulltext_topn}" + f")," + f"vector_results AS (" + f" SELECT id, pagerank_fea, {vector_search_score_expr} AS similarity" + f" FROM {index_name}" + f" WHERE {filters_expr} AND {vector_search_filter}" + f" ORDER BY {vector_search_expr}" + f" APPROXIMATE LIMIT {vector_topn}" + f")," + f"combined_results AS (" + f" SELECT COALESCE(f.id, v.id) AS id, {score_expr} AS score" + f" FROM fulltext_results f" + f" FULL OUTER JOIN vector_results v" + f" ON f.id = v.id" + f")" + f" SELECT {fields_expr}, c.score as _score" + f" FROM combined_results c" + f" JOIN {index_name} t" + f" ON c.id = t.id" + f" ORDER BY score DESC" + f" LIMIT {offset}, {limit}" + ) logger.debug("OBConnection.search with fusion sql: %s", fusion_sql) start_time = time.time() @@ -1234,10 +1292,14 @@ class OBConnection(DocStoreConnection): for row in rows: result.chunks.append(self._row_to_entity(row, output_fields)) + + if result.total == 0: + result.total = len(result.chunks) + return result def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: - if not self.client.check_table_exists(indexName): + if not 
self._check_table_exists_cached(indexName): return None try: @@ -1336,7 +1398,7 @@ class OBConnection(DocStoreConnection): return res def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool: - if not self.client.check_table_exists(indexName): + if not self._check_table_exists_cached(indexName): return True condition["kb_id"] = knowledgebaseId @@ -1387,7 +1449,7 @@ class OBConnection(DocStoreConnection): return False def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: - if not self.client.check_table_exists(indexName): + if not self._check_table_exists_cached(indexName): return 0 condition["kb_id"] = knowledgebaseId From 34d29d7e8b1b920ea13275ff66332bd27bc05f9c Mon Sep 17 00:00:00 2001 From: balibabu Date: Wed, 10 Dec 2025 19:13:57 +0800 Subject: [PATCH 7/7] Feat: Add configuration for webhook to the begin node. #10427 (#11875) ### What problem does this PR solve? Feat: Add configuration for webhook to the begin node. #10427 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- .../schema-editor/add-field-button.tsx | 5 +- .../components/schema-editor/context.ts | 9 + .../components/schema-editor/interface.ts | 1 + .../schema-editor/schema-property-editor.tsx | 4 + .../schema-editor/schema-visual-editor.tsx | 21 +- web/src/constants/agent.tsx | 16 ++ web/src/locales/en.ts | 31 +++ web/src/locales/zh.ts | 31 +++ .../agent/canvas/node/variable-display.tsx | 26 +-- web/src/pages/agent/constant/index.tsx | 35 +++ web/src/pages/agent/form/agent-form/index.tsx | 12 +- web/src/pages/agent/form/begin-form/index.tsx | 156 ++++++++----- .../form/begin-form/use-handle-mode-change.ts | 76 +++++++ .../begin-form/use-show-schema-dialog.tsx | 28 +++ .../agent/form/begin-form/use-watch-change.ts | 13 ++ .../agent/form/begin-form/webhook/auth.tsx | 139 ++++++++++++ .../begin-form/webhook/dynamic-response.tsx | 213 ++++++++++++++++++ .../agent/form/begin-form/webhook/index.tsx | 134 
+++++++++++ .../form/begin-form/webhook/response.tsx | 30 +++ .../form/components/dynamic-string-form.tsx | 46 ++++ .../schema-dialog.tsx} | 12 +- .../schema-panel.tsx} | 2 +- .../hooks/use-build-structured-output.ts | 130 +++++++++-- .../pages/agent/hooks/use-get-begin-query.tsx | 4 +- web/src/pages/agent/utils.ts | 40 ++++ 25 files changed, 1097 insertions(+), 117 deletions(-) create mode 100644 web/src/components/jsonjoy-builder/components/schema-editor/context.ts create mode 100644 web/src/components/jsonjoy-builder/components/schema-editor/interface.ts create mode 100644 web/src/pages/agent/form/begin-form/use-handle-mode-change.ts create mode 100644 web/src/pages/agent/form/begin-form/use-show-schema-dialog.tsx create mode 100644 web/src/pages/agent/form/begin-form/webhook/auth.tsx create mode 100644 web/src/pages/agent/form/begin-form/webhook/dynamic-response.tsx create mode 100644 web/src/pages/agent/form/begin-form/webhook/index.tsx create mode 100644 web/src/pages/agent/form/begin-form/webhook/response.tsx create mode 100644 web/src/pages/agent/form/components/dynamic-string-form.tsx rename web/src/pages/agent/form/{agent-form/structured-output-dialog.tsx => components/schema-dialog.tsx} (81%) rename web/src/pages/agent/form/{agent-form/structured-output-panel.tsx => components/schema-panel.tsx} (78%) diff --git a/web/src/components/jsonjoy-builder/components/schema-editor/add-field-button.tsx b/web/src/components/jsonjoy-builder/components/schema-editor/add-field-button.tsx index 7a25705f9..fe06c1952 100644 --- a/web/src/components/jsonjoy-builder/components/schema-editor/add-field-button.tsx +++ b/web/src/components/jsonjoy-builder/components/schema-editor/add-field-button.tsx @@ -20,6 +20,7 @@ import { CirclePlus, HelpCircle, Info } from 'lucide-react'; import { useId, useState, type FC, type FormEvent } from 'react'; import { useTranslation } from '../../hooks/use-translation'; import type { NewField, SchemaType } from '../../types/json-schema'; 
+import { KeyInputProps } from './interface'; import SchemaTypeSelector from './schema-type-selector'; interface AddFieldButtonProps { @@ -27,9 +28,10 @@ interface AddFieldButtonProps { variant?: 'primary' | 'secondary'; } -const AddFieldButton: FC = ({ +const AddFieldButton: FC = ({ onAddField, variant = 'primary', + pattern, }) => { const [dialogOpen, setDialogOpen] = useState(false); const [fieldName, setFieldName] = useState(''); @@ -120,6 +122,7 @@ const AddFieldButton: FC = ({ placeholder={t.fieldNamePlaceholder} className="font-mono text-sm w-full" required + searchValue={pattern} /> diff --git a/web/src/components/jsonjoy-builder/components/schema-editor/context.ts b/web/src/components/jsonjoy-builder/components/schema-editor/context.ts new file mode 100644 index 000000000..3fbb14a26 --- /dev/null +++ b/web/src/components/jsonjoy-builder/components/schema-editor/context.ts @@ -0,0 +1,9 @@ +import React, { useContext } from 'react'; +import { KeyInputProps } from './interface'; + +export const KeyInputContext = React.createContext({}); + +export function useInputPattern() { + const x = useContext(KeyInputContext); + return x.pattern; +} diff --git a/web/src/components/jsonjoy-builder/components/schema-editor/interface.ts b/web/src/components/jsonjoy-builder/components/schema-editor/interface.ts new file mode 100644 index 000000000..39e74a641 --- /dev/null +++ b/web/src/components/jsonjoy-builder/components/schema-editor/interface.ts @@ -0,0 +1 @@ +export type KeyInputProps = { pattern?: RegExp | string }; diff --git a/web/src/components/jsonjoy-builder/components/schema-editor/schema-property-editor.tsx b/web/src/components/jsonjoy-builder/components/schema-editor/schema-property-editor.tsx index f95031e9c..347d69d26 100644 --- a/web/src/components/jsonjoy-builder/components/schema-editor/schema-property-editor.tsx +++ b/web/src/components/jsonjoy-builder/components/schema-editor/schema-property-editor.tsx @@ -16,6 +16,7 @@ import { withObjectSchema, } from 
'../../types/json-schema'; import type { ValidationTreeNode } from '../../types/validation'; +import { useInputPattern } from './context'; import TypeDropdown from './type-dropdown'; import TypeEditor from './type-editor'; @@ -54,6 +55,8 @@ export const SchemaPropertyEditor: React.FC = ({ 'object' as SchemaType, ); + const pattern = useInputPattern(); + // Update temp values when props change useEffect(() => { setTempName(name); @@ -123,6 +126,7 @@ export const SchemaPropertyEditor: React.FC = ({ className="h-8 text-sm font-medium min-w-[120px] max-w-full z-10" autoFocus onFocus={(e) => e.target.select()} + searchValue={pattern} /> ) : ( - + )} {structuredOutputDialogVisible && ( - + > )} ); diff --git a/web/src/pages/agent/form/begin-form/index.tsx b/web/src/pages/agent/form/begin-form/index.tsx index ad4eb9d3e..c86f24cac 100644 --- a/web/src/pages/agent/form/begin-form/index.tsx +++ b/web/src/pages/agent/form/begin-form/index.tsx @@ -12,6 +12,7 @@ import { RAGFlowSelect } from '@/components/ui/select'; import { Switch } from '@/components/ui/switch'; import { Textarea } from '@/components/ui/textarea'; import { FormTooltip } from '@/components/ui/tooltip'; +import { WebhookAlgorithmList } from '@/constants/agent'; import { zodResolver } from '@hookform/resolvers/zod'; import { t } from 'i18next'; import { Plus } from 'lucide-react'; @@ -24,37 +25,71 @@ import { INextOperatorForm } from '../../interface'; import { ParameterDialog } from './parameter-dialog'; import { QueryTable } from './query-table'; import { useEditQueryRecord } from './use-edit-query'; +import { useHandleModeChange } from './use-handle-mode-change'; import { useValues } from './use-values'; import { useWatchFormChange } from './use-watch-change'; +import { WebHook } from './webhook'; const ModeOptions = [ { value: AgentDialogueMode.Conversational, label: t('flow.conversational') }, { value: AgentDialogueMode.Task, label: t('flow.task') }, + { value: AgentDialogueMode.Webhook, label: 
t('flow.webhook.name') }, ]; +const FormSchema = z.object({ + enablePrologue: z.boolean().optional(), + prologue: z.string().trim().optional(), + mode: z.string(), + inputs: z + .array( + z.object({ + key: z.string(), + type: z.string(), + value: z.string(), + optional: z.boolean(), + name: z.string(), + options: z.array(z.union([z.number(), z.string(), z.boolean()])), + }), + ) + .optional(), + methods: z.string().optional(), + content_types: z.string().optional(), + security: z + .object({ + auth_type: z.string(), + ip_whitelist: z.array(z.object({ value: z.string() })), + rate_limit: z.object({ + limit: z.number(), + per: z.string().optional(), + }), + max_body_size: z.string(), + jwt: z + .object({ + algorithm: z.string().default(WebhookAlgorithmList[0]).optional(), + }) + .optional(), + }) + .optional(), + schema: z.record(z.any()).optional(), + response: z + .object({ + status: z.number(), + headers_template: z.array( + z.object({ key: z.string(), value: z.string() }), + ), + body_template: z.array(z.object({ key: z.string(), value: z.string() })), + }) + .optional(), + execution_mode: z.string().optional(), +}); + +export type BeginFormSchemaType = z.infer; + function BeginForm({ node }: INextOperatorForm) { const { t } = useTranslation(); const values = useValues(node); - const FormSchema = z.object({ - enablePrologue: z.boolean().optional(), - prologue: z.string().trim().optional(), - mode: z.string(), - inputs: z - .array( - z.object({ - key: z.string(), - type: z.string(), - value: z.string(), - optional: z.boolean(), - name: z.string(), - options: z.array(z.union([z.number(), z.string(), z.boolean()])), - }), - ) - .optional(), - }); - const form = useForm({ defaultValues: values, resolver: zodResolver(FormSchema), @@ -72,6 +107,8 @@ function BeginForm({ node }: INextOperatorForm) { const previousModeRef = useRef(mode); + const { handleModeChange } = useHandleModeChange(form); + useEffect(() => { if ( previousModeRef.current === AgentDialogueMode.Task 
&& @@ -111,6 +148,10 @@ function BeginForm({ node }: INextOperatorForm) { placeholder={t('common.pleaseSelect')} options={ModeOptions} {...field} + onChange={(val) => { + handleModeChange(val); + field.onChange(val); + }} > @@ -158,44 +199,49 @@ function BeginForm({ node }: INextOperatorForm) { )} /> )} - {/* Create a hidden field to make Form instance record this */} -
} - /> - - {t('flow.input')} - - - } - rightContent={ - + } > - - - } - > - - - {visible && ( - + + + {visible && ( + + )} + )} diff --git a/web/src/pages/agent/form/begin-form/use-handle-mode-change.ts b/web/src/pages/agent/form/begin-form/use-handle-mode-change.ts new file mode 100644 index 000000000..e85ed5a6e --- /dev/null +++ b/web/src/pages/agent/form/begin-form/use-handle-mode-change.ts @@ -0,0 +1,76 @@ +import { useCallback } from 'react'; +import { UseFormReturn } from 'react-hook-form'; +import { + AgentDialogueMode, + RateLimitPerList, + WebhookExecutionMode, + WebhookMaxBodySize, + WebhookSecurityAuthType, +} from '../../constant'; + +// const WebhookSchema = { +// query: { +// type: 'object', +// required: [], +// properties: { +// // debug: { type: 'boolean' }, +// // event: { type: 'string' }, +// }, +// }, + +// headers: { +// type: 'object', +// required: [], +// properties: { +// // 'X-Trace-ID': { type: 'string' }, +// }, +// }, + +// body: { +// type: 'object', +// required: [], +// properties: { +// id: { type: 'string' }, +// payload: { type: 'object' }, +// }, +// }, +// }; + +const schema = { + properties: { + query: { + type: 'object', + description: '', + }, + headers: { + type: 'object', + description: '', + }, + body: { + type: 'object', + description: '', + }, + }, +}; + +const initialFormValuesMap = { + schema: schema, + 'security.auth_type': WebhookSecurityAuthType.Basic, + 'security.rate_limit.per': RateLimitPerList[0], + 'security.max_body_size': WebhookMaxBodySize[0], + execution_mode: WebhookExecutionMode.Immediately, +}; + +export function useHandleModeChange(form: UseFormReturn) { + const handleModeChange = useCallback( + (mode: AgentDialogueMode) => { + if (mode === AgentDialogueMode.Webhook) { + Object.entries(initialFormValuesMap).forEach(([key, value]) => { + form.setValue(key, value, { shouldDirty: true }); + }); + } + }, + [form], + ); + return { handleModeChange }; +} diff --git 
a/web/src/pages/agent/form/begin-form/use-show-schema-dialog.tsx b/web/src/pages/agent/form/begin-form/use-show-schema-dialog.tsx new file mode 100644 index 000000000..0bc6261e5 --- /dev/null +++ b/web/src/pages/agent/form/begin-form/use-show-schema-dialog.tsx @@ -0,0 +1,28 @@ +import { JSONSchema } from '@/components/jsonjoy-builder'; +import { useSetModalState } from '@/hooks/common-hooks'; +import { useCallback } from 'react'; +import { UseFormReturn } from 'react-hook-form'; + +export function useShowSchemaDialog(form: UseFormReturn) { + const { + visible: schemaDialogVisible, + showModal: showSchemaDialog, + hideModal: hideSchemaDialog, + } = useSetModalState(); + + const handleSchemaDialogOk = useCallback( + (values: JSONSchema) => { + // Sync data to canvas + form.setValue('schema', values); + hideSchemaDialog(); + }, + [form, hideSchemaDialog], + ); + + return { + schemaDialogVisible, + showSchemaDialog, + hideSchemaDialog, + handleSchemaDialogOk, + }; +} diff --git a/web/src/pages/agent/form/begin-form/use-watch-change.ts b/web/src/pages/agent/form/begin-form/use-watch-change.ts index f0da58068..02158e969 100644 --- a/web/src/pages/agent/form/begin-form/use-watch-change.ts +++ b/web/src/pages/agent/form/begin-form/use-watch-change.ts @@ -1,6 +1,7 @@ import { omit } from 'lodash'; import { useEffect } from 'react'; import { UseFormReturn, useWatch } from 'react-hook-form'; +import { AgentDialogueMode } from '../../constant'; import { BeginQuery } from '../../interface'; import useGraphStore from '../../store'; @@ -20,9 +21,21 @@ export function useWatchFormChange(id?: string, form?: UseFormReturn) { if (id) { values = form?.getValues() || {}; + let outputs: Record = {}; + + // For webhook mode, use schema properties as direct outputs + // Each property (body, headers, query) should be able to show secondary menu + if ( + values.mode === AgentDialogueMode.Webhook && + values.schema?.properties + ) { + outputs = { ...values.schema.properties }; + } + const 
nextValues = { ...values, inputs: transferInputsArrayToObject(values.inputs), + outputs, }; updateNodeForm(id, nextValues); diff --git a/web/src/pages/agent/form/begin-form/webhook/auth.tsx b/web/src/pages/agent/form/begin-form/webhook/auth.tsx new file mode 100644 index 000000000..4a739b491 --- /dev/null +++ b/web/src/pages/agent/form/begin-form/webhook/auth.tsx @@ -0,0 +1,139 @@ +import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { Input } from '@/components/ui/input'; +import { WebhookAlgorithmList } from '@/constants/agent'; +import { WebhookSecurityAuthType } from '@/pages/agent/constant'; +import { buildOptions } from '@/utils/form'; +import { useCallback } from 'react'; +import { useFormContext, useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; + +const AlgorithmOptions = buildOptions(WebhookAlgorithmList); + +const RequiredClaimsOptions = buildOptions(['exp', 'sub']); + +export function Auth() { + const { t } = useTranslation(); + const form = useFormContext(); + + const authType = useWatch({ + name: 'security.auth_type', + control: form.control, + }); + + const renderTokenAuth = useCallback( + () => ( + <> + + + + + + + + ), + [t], + ); + + const renderBasicAuth = useCallback( + () => ( + <> + + + + + + + + ), + [t], + ); + + const renderJwtAuth = useCallback( + () => ( + <> + + + + + + + + + + + + + + + + + ), + [t], + ); + + const renderHmacAuth = useCallback( + () => ( + <> + + + + + + + + + + + ), + [t], + ); + + const AuthMap = { + [WebhookSecurityAuthType.Token]: renderTokenAuth, + [WebhookSecurityAuthType.Basic]: renderBasicAuth, + [WebhookSecurityAuthType.Jwt]: renderJwtAuth, + [WebhookSecurityAuthType.Hmac]: renderHmacAuth, + [WebhookSecurityAuthType.None]: () => null, + }; + + return AuthMap[ + (authType ?? 
WebhookSecurityAuthType.None) as WebhookSecurityAuthType + ](); +} diff --git a/web/src/pages/agent/form/begin-form/webhook/dynamic-response.tsx b/web/src/pages/agent/form/begin-form/webhook/dynamic-response.tsx new file mode 100644 index 000000000..18030feff --- /dev/null +++ b/web/src/pages/agent/form/begin-form/webhook/dynamic-response.tsx @@ -0,0 +1,213 @@ +import { BoolSegmented } from '@/components/bool-segmented'; +import { KeyInput } from '@/components/key-input'; +import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { useIsDarkTheme } from '@/components/theme-provider'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Separator } from '@/components/ui/separator'; +import { Textarea } from '@/components/ui/textarea'; +import { Editor, loader } from '@monaco-editor/react'; +import { X } from 'lucide-react'; +import { ReactNode, useCallback } from 'react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { InputMode, TypesWithArray } from '../../../constant'; +import { + InputModeOptions, + buildConversationVariableSelectOptions, +} from '../../../utils'; +import { DynamicFormHeader } from '../../components/dynamic-fom-header'; +import { QueryVariable } from '../../components/query-variable'; + +loader.config({ paths: { vs: '/vs' } }); + +type SelectKeysProps = { + name: string; + label: ReactNode; + tooltip?: string; + keyField?: string; + valueField?: string; + operatorField?: string; + nodeId?: string; +}; + +const VariableTypeOptions = buildConversationVariableSelectOptions(); + +const modeField = 'input_mode'; + +const ConstantValueMap = { + [TypesWithArray.Boolean]: true, + [TypesWithArray.Number]: 0, + [TypesWithArray.String]: '', + [TypesWithArray.ArrayBoolean]: '[]', + [TypesWithArray.ArrayNumber]: '[]', + [TypesWithArray.ArrayString]: '[]', + [TypesWithArray.ArrayObject]: 
'[]', + [TypesWithArray.Object]: '{}', +}; + +export function DynamicResponse({ + name, + label, + tooltip, + keyField = 'key', + valueField = 'value', + operatorField = 'type', +}: SelectKeysProps) { + const form = useFormContext(); + const isDarkTheme = useIsDarkTheme(); + + const { fields, remove, append } = useFieldArray({ + name: name, + control: form.control, + }); + + const initializeValue = useCallback( + (mode: string, variableType: string, valueFieldAlias: string) => { + if (mode === InputMode.Variable) { + form.setValue(valueFieldAlias, '', { shouldDirty: true }); + } else { + const val = ConstantValueMap[variableType as TypesWithArray]; + form.setValue(valueFieldAlias, val, { shouldDirty: true }); + } + }, + [form], + ); + + const handleModeChange = useCallback( + (mode: string, valueFieldAlias: string, operatorFieldAlias: string) => { + const variableType = form.getValues(operatorFieldAlias); + initializeValue(mode, variableType, valueFieldAlias); + }, + [form, initializeValue], + ); + + const handleVariableTypeChange = useCallback( + (variableType: string, valueFieldAlias: string, modeFieldAlias: string) => { + const mode = form.getValues(modeFieldAlias); + + initializeValue(mode, variableType, valueFieldAlias); + }, + [form, initializeValue], + ); + + const renderParameter = useCallback( + (operatorFieldName: string, modeFieldName: string) => { + const mode = form.getValues(modeFieldName); + const logicalOperator = form.getValues(operatorFieldName); + + if (mode === InputMode.Constant) { + if (logicalOperator === TypesWithArray.Boolean) { + return ; + } + + if (logicalOperator === TypesWithArray.Number) { + return ; + } + + if (logicalOperator === TypesWithArray.String) { + return ; + } + + return ( + + ); + } + + return ( + + ); + }, + [form, isDarkTheme], + ); + + return ( +
+ + append({ + [keyField]: '', + [valueField]: '', + [modeField]: InputMode.Constant, + [operatorField]: TypesWithArray.String, + }) + } + > +
+ {fields.map((field, index) => { + const keyFieldAlias = `${name}.${index}.${keyField}`; + const valueFieldAlias = `${name}.${index}.${valueField}`; + const operatorFieldAlias = `${name}.${index}.${operatorField}`; + const modeFieldAlias = `${name}.${index}.${modeField}`; + + return ( +
+
+
+ + + + + + {(field) => ( + { + handleVariableTypeChange( + val, + valueFieldAlias, + modeFieldAlias, + ); + field.onChange(val); + }} + options={VariableTypeOptions} + > + )} + + + + {(field) => ( + { + handleModeChange( + val, + valueFieldAlias, + operatorFieldAlias, + ); + field.onChange(val); + }} + options={InputModeOptions} + > + )} + +
+ + {renderParameter(operatorFieldAlias, modeFieldAlias)} + +
+ + +
+ ); + })} +
+
+ ); +} diff --git a/web/src/pages/agent/form/begin-form/webhook/index.tsx b/web/src/pages/agent/form/begin-form/webhook/index.tsx new file mode 100644 index 000000000..86e844b07 --- /dev/null +++ b/web/src/pages/agent/form/begin-form/webhook/index.tsx @@ -0,0 +1,134 @@ +import { Collapse } from '@/components/collapse'; +import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Separator } from '@/components/ui/separator'; +import { Textarea } from '@/components/ui/textarea'; +import { buildOptions } from '@/utils/form'; +import { useFormContext, useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { + RateLimitPerList, + WebhookContentType, + WebhookExecutionMode, + WebhookMaxBodySize, + WebhookMethod, + WebhookSecurityAuthType, +} from '../../../constant'; +import { DynamicStringForm } from '../../components/dynamic-string-form'; +import { SchemaDialog } from '../../components/schema-dialog'; +import { SchemaPanel } from '../../components/schema-panel'; +import { useShowSchemaDialog } from '../use-show-schema-dialog'; +import { Auth } from './auth'; +import { WebhookResponse } from './response'; + +const RateLimitPerOptions = buildOptions(RateLimitPerList); + +export function WebHook() { + const { t } = useTranslation(); + const form = useFormContext(); + + const executionMode = useWatch({ + control: form.control, + name: 'execution_mode', + }); + + const { + showSchemaDialog, + schemaDialogVisible, + hideSchemaDialog, + handleSchemaDialogOk, + } = useShowSchemaDialog(form); + + const schema = form.getValues('schema'); + + return ( + <> + + + + + + + Security}> +
+ + + + + + + + + + + + + + +
+
+ + + + + + + {executionMode === WebhookExecutionMode.Immediately && ( + + )} + +
+ Schema + +
+ + {schemaDialogVisible && ( + + )} + + ); +} diff --git a/web/src/pages/agent/form/begin-form/webhook/response.tsx b/web/src/pages/agent/form/begin-form/webhook/response.tsx new file mode 100644 index 000000000..a50d212e0 --- /dev/null +++ b/web/src/pages/agent/form/begin-form/webhook/response.tsx @@ -0,0 +1,30 @@ +import { Collapse } from '@/components/collapse'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { Input } from '@/components/ui/input'; +import { useTranslation } from 'react-i18next'; +import { DynamicResponse } from './dynamic-response'; + +export function WebhookResponse() { + const { t } = useTranslation(); + + return ( + Response}> +
+ + + + + +
+
+ ); +} diff --git a/web/src/pages/agent/form/components/dynamic-string-form.tsx b/web/src/pages/agent/form/components/dynamic-string-form.tsx new file mode 100644 index 000000000..224e92310 --- /dev/null +++ b/web/src/pages/agent/form/components/dynamic-string-form.tsx @@ -0,0 +1,46 @@ +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Trash2 } from 'lucide-react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { DynamicFormHeader, FormListHeaderProps } from './dynamic-fom-header'; + +type DynamicStringFormProps = { name: string } & FormListHeaderProps; +export function DynamicStringForm({ name, label }: DynamicStringFormProps) { + const form = useFormContext(); + + const { fields, append, remove } = useFieldArray({ + name: name, + control: form.control, + }); + + return ( +
+ append({ value: '' })} + > +
+ {fields.map((field, index) => ( +
+ + + + +
+ ))} +
+
+ ); +} diff --git a/web/src/pages/agent/form/agent-form/structured-output-dialog.tsx b/web/src/pages/agent/form/components/schema-dialog.tsx similarity index 81% rename from web/src/pages/agent/form/agent-form/structured-output-dialog.tsx rename to web/src/pages/agent/form/components/schema-dialog.tsx index 6ce305bff..4d67e00c0 100644 --- a/web/src/pages/agent/form/agent-form/structured-output-dialog.tsx +++ b/web/src/pages/agent/form/components/schema-dialog.tsx @@ -3,6 +3,7 @@ import { JsonSchemaVisualizer, SchemaVisualEditor, } from '@/components/jsonjoy-builder'; +import { KeyInputProps } from '@/components/jsonjoy-builder/components/schema-editor/interface'; import { Button } from '@/components/ui/button'; import { Dialog, @@ -16,11 +17,12 @@ import { IModalProps } from '@/interfaces/common'; import { useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; -export function StructuredOutputDialog({ +export function SchemaDialog({ hideModal, onOk, initialValues, -}: IModalProps) { + pattern, +}: IModalProps & KeyInputProps) { const { t } = useTranslation(); const [schema, setSchema] = useState(initialValues); @@ -36,7 +38,11 @@ export function StructuredOutputDialog({
- +
diff --git a/web/src/pages/agent/form/agent-form/structured-output-panel.tsx b/web/src/pages/agent/form/components/schema-panel.tsx similarity index 78% rename from web/src/pages/agent/form/agent-form/structured-output-panel.tsx rename to web/src/pages/agent/form/components/schema-panel.tsx index 64e13c6eb..e76ff726e 100644 --- a/web/src/pages/agent/form/agent-form/structured-output-panel.tsx +++ b/web/src/pages/agent/form/components/schema-panel.tsx @@ -1,6 +1,6 @@ import { JSONSchema, JsonSchemaVisualizer } from '@/components/jsonjoy-builder'; -export function StructuredOutputPanel({ value }: { value: JSONSchema }) { +export function SchemaPanel({ value }: { value: JSONSchema }) { return (
state); + const { getOperatorTypeFromId, getNode } = useGraphStore((state) => state); const showSecondaryMenu = useCallback( (value: string, outputLabel: string) => { const nodeId = getNodeId(value); - return ( - getOperatorTypeFromId(nodeId) === Operator.Agent && + const operatorType = getOperatorTypeFromId(nodeId); + + // For Agent nodes, show secondary menu for 'structured' field + if ( + operatorType === Operator.Agent && outputLabel === AgentStructuredOutputField - ); + ) { + return true; + } + + // For Begin nodes in webhook mode, show secondary menu for schema properties (body, headers, query, etc.) + if (operatorType === Operator.Begin) { + const node = getNode(nodeId); + const mode = get(node, 'data.form.mode'); + if (mode === AgentDialogueMode.Webhook) { + // Check if this output field is from the schema + const outputs = get(node, 'data.form.outputs', {}); + const outputField = outputs[outputLabel]; + // Show secondary menu if the field is an object or has properties + return ( + outputField && + (outputField.type === 'object' || + (outputField.properties && + Object.keys(outputField.properties).length > 0)) + ); + } + } + + return false; }, - [getOperatorTypeFromId], + [getOperatorTypeFromId, getNode], ); return showSecondaryMenu; } +function useGetBeginOutputsOrSchema() { + const { getNode } = useGraphStore((state) => state); + + const getBeginOutputs = useCallback(() => { + const node = getNode(BeginId); + const outputs = get(node, 'data.form.outputs', {}); + return outputs; + }, [getNode]); + + const getBeginSchema = useCallback(() => { + const node = getNode(BeginId); + const outputs = get(node, 'data.form.schema', {}); + return outputs; + }, [getNode]); + + return { getBeginOutputs, getBeginSchema }; +} export function useGetStructuredOutputByValue() { - const { getNode } = useGraphStore((state) => state); + const { getNode, getOperatorTypeFromId } = useGraphStore((state) => state); + + const { getBeginOutputs } = useGetBeginOutputsOrSchema(); 
const getStructuredOutput = useCallback( (value: string) => { - const node = getNode(getNodeId(value)); - const structuredOutput = get( - node, - `data.form.outputs.${AgentStructuredOutputField}`, - ); + const nodeId = getNodeId(value); + const node = getNode(nodeId); + const operatorType = getOperatorTypeFromId(nodeId); + const fields = splitValue(value); + const outputLabel = fields.at(1); + + let structuredOutput; + if (operatorType === Operator.Agent) { + structuredOutput = get( + node, + `data.form.outputs.${AgentStructuredOutputField}`, + ); + } else if (operatorType === Operator.Begin) { + // For Begin nodes in webhook mode, get the specific schema property + const outputs = getBeginOutputs(); + if (outputLabel) { + structuredOutput = outputs[outputLabel]; + } + } return structuredOutput; }, - [getNode], + [getBeginOutputs, getNode, getOperatorTypeFromId], ); return getStructuredOutput; @@ -66,13 +126,14 @@ export function useFindAgentStructuredOutputLabel() { icon?: ReactNode; }>, ) => { - // agent structured output const fields = splitValue(value); + const operatorType = getOperatorTypeFromId(fields.at(0)); + + // Handle Agent structured fields if ( - getOperatorTypeFromId(fields.at(0)) === Operator.Agent && + operatorType === Operator.Agent && fields.at(1)?.startsWith(AgentStructuredOutputField) ) { - // is agent structured output const agentOption = options.find((x) => value.includes(x.value)); const jsonSchemaFields = fields .at(1) @@ -84,6 +145,19 @@ export function useFindAgentStructuredOutputLabel() { value: value, }; } + + // Handle Begin webhook fields + if (operatorType === Operator.Begin && fields.at(1)) { + const fieldOption = options + .filter((x) => x.parentLabel === BeginId) + .find((x) => value.startsWith(x.value)); + + return { + ...fieldOption, + label: fields.at(1), + value: value, + }; + } }, [getOperatorTypeFromId], ); @@ -94,6 +168,7 @@ export function useFindAgentStructuredOutputLabel() { export function 
useFindAgentStructuredOutputTypeByValue() { const { getOperatorTypeFromId } = useGraphStore((state) => state); const filterStructuredOutput = useGetStructuredOutputByValue(); + const { getBeginSchema } = useGetBeginOutputsOrSchema(); const findTypeByValue = useCallback( ( @@ -136,10 +211,12 @@ export function useFindAgentStructuredOutputTypeByValue() { } const fields = splitValue(value); const nodeId = fields.at(0); + const operatorType = getOperatorTypeFromId(nodeId); const jsonSchema = filterStructuredOutput(value); + // Handle Agent structured fields if ( - getOperatorTypeFromId(nodeId) === Operator.Agent && + operatorType === Operator.Agent && fields.at(1)?.startsWith(AgentStructuredOutputField) ) { const jsonSchemaFields = fields @@ -151,13 +228,32 @@ export function useFindAgentStructuredOutputTypeByValue() { return type; } } + + // Handle Begin webhook fields (body, headers, query, etc.) + if (operatorType === Operator.Begin) { + const outputLabel = fields.at(1); + const schema = getBeginSchema(); + if (outputLabel && schema) { + const jsonSchemaFields = fields.at(1); + if (jsonSchemaFields) { + const type = findTypeByValue(schema, jsonSchemaFields); + return type; + } + } + } }, - [filterStructuredOutput, findTypeByValue, getOperatorTypeFromId], + [ + filterStructuredOutput, + findTypeByValue, + getBeginSchema, + getOperatorTypeFromId, + ], ); return findAgentStructuredOutputTypeByValue; } +// TODO: Consider merging with useFindAgentStructuredOutputLabel export function useFindAgentStructuredOutputLabelByValue() { const { getNode } = useGraphStore((state) => state); diff --git a/web/src/pages/agent/hooks/use-get-begin-query.tsx b/web/src/pages/agent/hooks/use-get-begin-query.tsx index 387f59821..46825e5a4 100644 --- a/web/src/pages/agent/hooks/use-get-begin-query.tsx +++ b/web/src/pages/agent/hooks/use-get-begin-query.tsx @@ -314,10 +314,12 @@ export function useFilterQueryVariableOptionsByTypes({ ? 
toLower(y.type).includes(toLower(x)) : toLower(y.type) === toLower(x), ) || + // agent structured output isAgentStructured( y.value, y.value.slice(-AgentStructuredOutputField.length), - ), // agent structured output + ) || + y.value.startsWith(BeginId), // begin node outputs ), }; }) diff --git a/web/src/pages/agent/utils.ts b/web/src/pages/agent/utils.ts index 6825dd9f5..592d92e45 100644 --- a/web/src/pages/agent/utils.ts +++ b/web/src/pages/agent/utils.ts @@ -24,6 +24,7 @@ import { import pipe from 'lodash/fp/pipe'; import isObject from 'lodash/isObject'; import { + AgentDialogueMode, CategorizeAnchorPointPositions, FileType, FileTypeSuffixMap, @@ -34,6 +35,7 @@ import { Operator, TypesWithArray, } from './constant'; +import { BeginFormSchemaType } from './form/begin-form'; import { DataOperationsFormSchemaType } from './form/data-operations-form'; import { ExtractorFormSchemaType } from './form/extractor-form'; import { HierarchicalMergerFormSchemaType } from './form/hierarchical-merger-form'; @@ -312,6 +314,41 @@ function transformDataOperationsParams(params: DataOperationsFormSchemaType) { }; } +export function transformArrayToObject( + list?: Array<{ key: string; value: string }>, +) { + if (!Array.isArray(list)) return {}; + return list?.reduce>((pre, cur) => { + if (cur.key) { + pre[cur.key] = cur.value; + } + return pre; + }, {}); +} + +function transformBeginParams(params: BeginFormSchemaType) { + if (params.mode === AgentDialogueMode.Webhook) { + return { + ...params, + security: { + ...params.security, + ip_whitelist: params.security?.ip_whitelist.map((x) => x.value), + }, + response: { + ...params.response, + headers_template: transformArrayToObject( + params.response?.headers_template, + ), + body_template: transformArrayToObject(params.response?.body_template), + }, + }; + } + + return { + ...params, + }; +} + // construct a dsl based on the node information of the graph export const buildDslComponentsByGraph = ( nodes: RAGFlowNodeType[], @@ -361,6 
+398,9 @@ export const buildDslComponentsByGraph = ( case Operator.DataOperations: params = transformDataOperationsParams(params); break; + case Operator.Begin: + params = transformBeginParams(params); + break; default: break; }