fix: restful api

This commit is contained in:
Lynn 2025-12-09 11:58:14 +08:00
parent 75a9d1d3ce
commit 02ce5c12e0
3 changed files with 50 additions and 26 deletions

View file

@ -15,22 +15,23 @@
#
import logging
from quart import request
from api.apps import login_required, current_user
from api.db import TenantPermission
from api.db.services.memory_service import MemoryService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import validate_request, request_json, get_error_argument_result, get_json_result, \
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
not_allowed_parameters
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
from common.constants import MemoryType, RetCode, ForgettingPolicy
@manager.route("/create", methods=["POST"]) # noqa: F821
@manager.route("", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name", "memory_type", "embd_id", "llm_id")
async def create_memory():
req = await request_json()
req = await get_request_json()
# check name length
name = req["name"]
memory_name = name.strip()
@ -64,11 +65,11 @@ async def create_memory():
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/update/<memory_id>", methods=["PUT"]) # noqa: F821
@manager.route("/<memory_id>", methods=["PUT"]) # noqa: F821
@login_required
@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id")
async def update_memory(memory_id):
req = await request_json()
req = await get_request_json()
update_dict = {}
# check name length
if "name" in req:
@ -132,7 +133,7 @@ async def update_memory(memory_id):
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/rm/<memory_id>", methods=["DELETE"]) # noqa: F821
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def delete_memory(memory_id):
memory = MemoryService.get_by_memory_id(memory_id)
@ -146,19 +147,25 @@ async def delete_memory(memory_id):
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/list", methods=["POST"]) # noqa: F821
@manager.route("", methods=["GET"]) # noqa: F821
@login_required
async def list_memory():
req = await request_json()
args = request.args
try:
filter_dict = req.get("filter", {})
keywords = req.get("keywords", "")
page = req.get("page", 1)
page_size = req.get("page_size", 50)
if not filter_dict.get("tenant_id"):
tenant_ids = args.getlist("tenant_id")
memory_types = args.getlist("memory_type")
storage_type = args.get("storage_type")
keywords = args.get("keywords", "")
page = int(args.get("page", 1))
page_size = int(args.get("page_size", 50))
# make filter dict
filter_dict = {"memory_type": memory_types, "storage_type": storage_type}
if not tenant_ids:
# restrict to current user's tenants
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
else:
filter_dict["tenant_id"] = tenant_ids
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
@ -169,7 +176,7 @@ async def list_memory():
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/config/<memory_id>", methods=["GET"]) # noqa: F821
@manager.route("/<memory_id>/config", methods=["GET"]) # noqa: F821
@login_required
async def get_memory_config(memory_id):
memory = MemoryService.get_with_owner_name_by_id(memory_id)

View file

@ -28,7 +28,7 @@ CHUNK_API_URL = f"/{VERSION}/chunk"
DIALOG_APP_URL = f"/{VERSION}/dialog"
# SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions"
# SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions"
MEMORY_API_URL = f"/{VERSION}/memory"
MEMORY_API_URL = f"/{VERSION}/memories"
# KB APP
@ -262,30 +262,40 @@ def delete_dialogs(auth):
# MEMORY APP
def create_memory(auth, payload=None):
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/create"
url = f"{HOST_ADDRESS}{MEMORY_API_URL}"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def update_memory(auth, memory_id:str, payload=None):
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/update/{memory_id}"
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}"
res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def delete_memory(auth, memory_id:str):
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/rm/{memory_id}"
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}"
res = requests.delete(url=url, headers=HEADERS, auth=auth)
return res.json()
def list_memory(auth, payload=None):
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/list"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
def list_memory(auth, params=None):
url = f"{HOST_ADDRESS}{MEMORY_API_URL}"
if params:
query_parts = []
for key, value in params.items():
if isinstance(value, list):
for item in value:
query_parts.append(f"{key}={item}")
else:
query_parts.append(f"{key}={value}")
query_string = "&".join(query_parts)
url = f"{url}?{query_string}"
res = requests.get(url=url, headers=HEADERS, auth=auth)
return res.json()
def get_memory_config(auth, memory_id:str):
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/config/{memory_id}"
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}/config"
res = requests.get(url=url, headers=HEADERS, auth=auth)
return res.json()

View file

@ -30,14 +30,14 @@ class TestAuthorization:
],
)
def test_auth_invalid(self, invalid_auth, expected_code, expected_message):
res = list_memory(invalid_auth, "some_memory_id")
res = list_memory(invalid_auth)
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res
class TestCapability:
@pytest.mark.p3
def test_memory_id(self, WebApiAuth):
def test_capability(self, WebApiAuth):
count = 100
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(list_memory, WebApiAuth) for i in range(count)]
@ -78,14 +78,21 @@ class TestMemoryList:
@pytest.mark.p2
def test_filter_memory_type(self, WebApiAuth):
res = list_memory(WebApiAuth, {"filter": {"memory_type": ["semantic"]}})
res = list_memory(WebApiAuth, {"memory_type": ["semantic"]})
assert res["code"] == 0, res
for memory in res["data"]["memory_list"]:
assert "semantic" in memory["memory_type"], res
@pytest.mark.p2
def test_filter_multi_memory_type(self, WebApiAuth):
res = list_memory(WebApiAuth, {"memory_type": ["episodic", "procedural"]})
assert res["code"] == 0, res
for memory in res["data"]["memory_list"]:
assert "episodic" in memory["memory_type"] or "procedural" in memory["memory_type"], res
@pytest.mark.p2
def test_filter_storage_type(self, WebApiAuth):
res = list_memory(WebApiAuth, {"filter":{"storage_type": "table"}})
res = list_memory(WebApiAuth, {"storage_type": "table"})
assert res["code"] == 0, res
for memory in res["data"]["memory_list"]:
assert memory["storage_type"] == "table", res