diff --git a/api/apps/memory_app.py b/api/apps/memories_app.py similarity index 87% rename from api/apps/memory_app.py rename to api/apps/memories_app.py index fe6caa33f..425d44b96 100644 --- a/api/apps/memory_app.py +++ b/api/apps/memories_app.py @@ -15,22 +15,23 @@ # import logging +from quart import request from api.apps import login_required, current_user from api.db import TenantPermission from api.db.services.memory_service import MemoryService from api.db.services.user_service import UserTenantService -from api.utils.api_utils import validate_request, request_json, get_error_argument_result, get_json_result, \ +from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \ not_allowed_parameters from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT from common.constants import MemoryType, RetCode, ForgettingPolicy -@manager.route("/create", methods=["POST"]) # noqa: F821 +@manager.route("", methods=["POST"]) # noqa: F821 @login_required @validate_request("name", "memory_type", "embd_id", "llm_id") async def create_memory(): - req = await request_json() + req = await get_request_json() # check name length name = req["name"] memory_name = name.strip() @@ -64,11 +65,11 @@ async def create_memory(): return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) -@manager.route("/update/", methods=["PUT"]) # noqa: F821 +@manager.route("/", methods=["PUT"]) # noqa: F821 @login_required @not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id") async def update_memory(memory_id): - req = await request_json() + req = await get_request_json() update_dict = {} # check name length if "name" in req: @@ -132,7 +133,7 @@ async def update_memory(memory_id): return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) -@manager.route("/rm/", methods=["DELETE"]) # noqa: F821 +@manager.route("/", methods=["DELETE"]) # noqa: F821 @login_required async def delete_memory(memory_id): memory = MemoryService.get_by_memory_id(memory_id) @@ -146,19 +147,25 @@ async def delete_memory(memory_id): return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) -@manager.route("/list", methods=["POST"]) # noqa: F821 +@manager.route("", methods=["GET"]) # noqa: F821 @login_required async def list_memory(): - req = await request_json() + args = request.args try: - filter_dict = req.get("filter", {}) - keywords = req.get("keywords", "") - page = req.get("page", 1) - page_size = req.get("page_size", 50) - if not filter_dict.get("tenant_id"): + tenant_ids = args.getlist("tenant_id") + memory_types = args.getlist("memory_type") + storage_type = args.get("storage_type") + keywords = args.get("keywords", "") + page = int(args.get("page", 1)) + page_size = int(args.get("page_size", 50)) + # make filter dict + filter_dict = {"memory_type": memory_types, "storage_type": storage_type} + if not tenant_ids: # restrict to current user's tenants user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id) filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants] + else: + filter_dict["tenant_id"] = tenant_ids memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) [memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list] @@ -169,7 +176,7 @@ async def list_memory(): return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) -@manager.route("/config/", methods=["GET"]) # noqa: F821 +@manager.route("//config", methods=["GET"]) # noqa: F821 @login_required async def get_memory_config(memory_id): memory = MemoryService.get_with_owner_name_by_id(memory_id) diff --git a/test/testcases/test_web_api/common.py b/test/testcases/test_web_api/common.py index e1595439a..4f4abf722 100644 --- a/test/testcases/test_web_api/common.py +++ b/test/testcases/test_web_api/common.py @@ -28,7 +28,7 @@ CHUNK_API_URL = f"/{VERSION}/chunk" DIALOG_APP_URL = f"/{VERSION}/dialog" # SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions" # SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions" -MEMORY_API_URL = f"/{VERSION}/memory" +MEMORY_API_URL = f"/{VERSION}/memories" # KB APP @@ -262,30 +262,40 @@ def delete_dialogs(auth): # MEMORY APP def create_memory(auth, payload=None): - url = f"{HOST_ADDRESS}{MEMORY_API_URL}/create" + url = f"{HOST_ADDRESS}{MEMORY_API_URL}" res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() def update_memory(auth, memory_id:str, payload=None): - url = f"{HOST_ADDRESS}{MEMORY_API_URL}/update/{memory_id}" + url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}" res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() def delete_memory(auth, memory_id:str): - url = f"{HOST_ADDRESS}{MEMORY_API_URL}/rm/{memory_id}" + url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}" res = requests.delete(url=url, headers=HEADERS, auth=auth) return res.json() -def list_memory(auth, payload=None): - url = f"{HOST_ADDRESS}{MEMORY_API_URL}/list" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) +def list_memory(auth, params=None): + url = f"{HOST_ADDRESS}{MEMORY_API_URL}" + if params: + query_parts = [] + for key, value in params.items(): + if isinstance(value, list): + for item in value: + query_parts.append(f"{key}={item}") + else: + query_parts.append(f"{key}={value}") + query_string = "&".join(query_parts) + url = f"{url}?{query_string}" + res = requests.get(url=url, headers=HEADERS, auth=auth) return res.json() def get_memory_config(auth, memory_id:str): - url = f"{HOST_ADDRESS}{MEMORY_API_URL}/config/{memory_id}" + url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}/config" res = requests.get(url=url, headers=HEADERS, auth=auth) return res.json() diff --git a/test/testcases/test_web_api/test_memory_app/test_list_memory.py b/test/testcases/test_web_api/test_memory_app/test_list_memory.py index 735599325..e1095358a 100644 --- a/test/testcases/test_web_api/test_memory_app/test_list_memory.py +++ b/test/testcases/test_web_api/test_memory_app/test_list_memory.py @@ -30,14 +30,14 @@ class TestAuthorization: ], ) def test_auth_invalid(self, invalid_auth, expected_code, expected_message): - res = list_memory(invalid_auth, "some_memory_id") + res = list_memory(invalid_auth) assert res["code"] == expected_code, res assert res["message"] == expected_message, res class TestCapability: @pytest.mark.p3 - def test_memory_id(self, WebApiAuth): + def test_capability(self, WebApiAuth): count = 100 with ThreadPoolExecutor(max_workers=5) as executor: futures = [executor.submit(list_memory, WebApiAuth) for i in range(count)] @@ -78,14 +78,21 @@ class TestMemoryList: @pytest.mark.p2 def test_filter_memory_type(self, WebApiAuth): - res = list_memory(WebApiAuth, {"filter": {"memory_type": ["semantic"]}}) + res = list_memory(WebApiAuth, {"memory_type": ["semantic"]}) assert res["code"] == 0, res for memory in res["data"]["memory_list"]: assert "semantic" in memory["memory_type"], res + @pytest.mark.p2 + def test_filter_multi_memory_type(self, WebApiAuth): + res = list_memory(WebApiAuth, {"memory_type": ["episodic", "procedural"]}) + assert res["code"] == 0, res + for memory in res["data"]["memory_list"]: + assert "episodic" in memory["memory_type"] or "procedural" in memory["memory_type"], res + @pytest.mark.p2 def test_filter_storage_type(self, WebApiAuth): - res = list_memory(WebApiAuth, {"filter":{"storage_type": "table"}}) + res = list_memory(WebApiAuth, {"storage_type": "table"}) assert res["code"] == 0, res for memory in res["data"]["memory_list"]: assert memory["storage_type"] == "table", res