fix: restful api
This commit is contained in:
parent
75a9d1d3ce
commit
02ce5c12e0
3 changed files with 50 additions and 26 deletions
|
|
@ -15,22 +15,23 @@
|
|||
#
|
||||
import logging
|
||||
|
||||
from quart import request
|
||||
from api.apps import login_required, current_user
|
||||
from api.db import TenantPermission
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import validate_request, request_json, get_error_argument_result, get_json_result, \
|
||||
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
|
||||
not_allowed_parameters
|
||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
||||
from common.constants import MemoryType, RetCode, ForgettingPolicy
|
||||
|
||||
|
||||
@manager.route("/create", methods=["POST"]) # noqa: F821
|
||||
@manager.route("", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name", "memory_type", "embd_id", "llm_id")
|
||||
async def create_memory():
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
# check name length
|
||||
name = req["name"]
|
||||
memory_name = name.strip()
|
||||
|
|
@ -64,11 +65,11 @@ async def create_memory():
|
|||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/update/<memory_id>", methods=["PUT"]) # noqa: F821
|
||||
@manager.route("/<memory_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id")
|
||||
async def update_memory(memory_id):
|
||||
req = await request_json()
|
||||
req = await get_request_json()
|
||||
update_dict = {}
|
||||
# check name length
|
||||
if "name" in req:
|
||||
|
|
@ -132,7 +133,7 @@ async def update_memory(memory_id):
|
|||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/rm/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
async def delete_memory(memory_id):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
|
|
@ -146,19 +147,25 @@ async def delete_memory(memory_id):
|
|||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
@manager.route("", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def list_memory():
|
||||
req = await request_json()
|
||||
args = request.args
|
||||
try:
|
||||
filter_dict = req.get("filter", {})
|
||||
keywords = req.get("keywords", "")
|
||||
page = req.get("page", 1)
|
||||
page_size = req.get("page_size", 50)
|
||||
if not filter_dict.get("tenant_id"):
|
||||
tenant_ids = args.getlist("tenant_id")
|
||||
memory_types = args.getlist("memory_type")
|
||||
storage_type = args.get("storage_type")
|
||||
keywords = args.get("keywords", "")
|
||||
page = int(args.get("page", 1))
|
||||
page_size = int(args.get("page_size", 50))
|
||||
# make filter dict
|
||||
filter_dict = {"memory_type": memory_types, "storage_type": storage_type}
|
||||
if not tenant_ids:
|
||||
# restrict to current user's tenants
|
||||
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
|
||||
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
|
||||
else:
|
||||
filter_dict["tenant_id"] = tenant_ids
|
||||
|
||||
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
|
||||
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
|
||||
|
|
@ -169,7 +176,7 @@ async def list_memory():
|
|||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/config/<memory_id>", methods=["GET"]) # noqa: F821
|
||||
@manager.route("/<memory_id>/config", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_memory_config(memory_id):
|
||||
memory = MemoryService.get_with_owner_name_by_id(memory_id)
|
||||
|
|
@ -28,7 +28,7 @@ CHUNK_API_URL = f"/{VERSION}/chunk"
|
|||
DIALOG_APP_URL = f"/{VERSION}/dialog"
|
||||
# SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions"
|
||||
# SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions"
|
||||
MEMORY_API_URL = f"/{VERSION}/memory"
|
||||
MEMORY_API_URL = f"/{VERSION}/memories"
|
||||
|
||||
|
||||
# KB APP
|
||||
|
|
@ -262,30 +262,40 @@ def delete_dialogs(auth):
|
|||
|
||||
# MEMORY APP
|
||||
def create_memory(auth, payload=None):
|
||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/create"
|
||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}"
|
||||
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
|
||||
return res.json()
|
||||
|
||||
|
||||
def update_memory(auth, memory_id:str, payload=None):
|
||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/update/{memory_id}"
|
||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}"
|
||||
res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload)
|
||||
return res.json()
|
||||
|
||||
|
||||
def delete_memory(auth, memory_id:str):
|
||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/rm/{memory_id}"
|
||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}"
|
||||
res = requests.delete(url=url, headers=HEADERS, auth=auth)
|
||||
return res.json()
|
||||
|
||||
|
||||
def list_memory(auth, payload=None):
|
||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/list"
|
||||
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
|
||||
def list_memory(auth, params=None):
|
||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}"
|
||||
if params:
|
||||
query_parts = []
|
||||
for key, value in params.items():
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
query_parts.append(f"{key}={item}")
|
||||
else:
|
||||
query_parts.append(f"{key}={value}")
|
||||
query_string = "&".join(query_parts)
|
||||
url = f"{url}?{query_string}"
|
||||
res = requests.get(url=url, headers=HEADERS, auth=auth)
|
||||
return res.json()
|
||||
|
||||
|
||||
def get_memory_config(auth, memory_id:str):
|
||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/config/{memory_id}"
|
||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}/config"
|
||||
res = requests.get(url=url, headers=HEADERS, auth=auth)
|
||||
return res.json()
|
||||
|
|
|
|||
|
|
@ -30,14 +30,14 @@ class TestAuthorization:
|
|||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_code, expected_message):
|
||||
res = list_memory(invalid_auth, "some_memory_id")
|
||||
res = list_memory(invalid_auth)
|
||||
assert res["code"] == expected_code, res
|
||||
assert res["message"] == expected_message, res
|
||||
|
||||
|
||||
class TestCapability:
|
||||
@pytest.mark.p3
|
||||
def test_memory_id(self, WebApiAuth):
|
||||
def test_capability(self, WebApiAuth):
|
||||
count = 100
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(list_memory, WebApiAuth) for i in range(count)]
|
||||
|
|
@ -78,14 +78,21 @@ class TestMemoryList:
|
|||
|
||||
@pytest.mark.p2
|
||||
def test_filter_memory_type(self, WebApiAuth):
|
||||
res = list_memory(WebApiAuth, {"filter": {"memory_type": ["semantic"]}})
|
||||
res = list_memory(WebApiAuth, {"memory_type": ["semantic"]})
|
||||
assert res["code"] == 0, res
|
||||
for memory in res["data"]["memory_list"]:
|
||||
assert "semantic" in memory["memory_type"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_filter_multi_memory_type(self, WebApiAuth):
|
||||
res = list_memory(WebApiAuth, {"memory_type": ["episodic", "procedural"]})
|
||||
assert res["code"] == 0, res
|
||||
for memory in res["data"]["memory_list"]:
|
||||
assert "episodic" in memory["memory_type"] or "procedural" in memory["memory_type"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_filter_storage_type(self, WebApiAuth):
|
||||
res = list_memory(WebApiAuth, {"filter":{"storage_type": "table"}})
|
||||
res = list_memory(WebApiAuth, {"storage_type": "table"})
|
||||
assert res["code"] == 0, res
|
||||
for memory in res["data"]["memory_list"]:
|
||||
assert memory["storage_type"] == "table", res
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue