fix: restful api
This commit is contained in:
parent
75a9d1d3ce
commit
02ce5c12e0
3 changed files with 50 additions and 26 deletions
|
|
@ -15,22 +15,23 @@
|
||||||
#
|
#
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from quart import request
|
||||||
from api.apps import login_required, current_user
|
from api.apps import login_required, current_user
|
||||||
from api.db import TenantPermission
|
from api.db import TenantPermission
|
||||||
from api.db.services.memory_service import MemoryService
|
from api.db.services.memory_service import MemoryService
|
||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService
|
||||||
from api.utils.api_utils import validate_request, request_json, get_error_argument_result, get_json_result, \
|
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
|
||||||
not_allowed_parameters
|
not_allowed_parameters
|
||||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||||
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
||||||
from common.constants import MemoryType, RetCode, ForgettingPolicy
|
from common.constants import MemoryType, RetCode, ForgettingPolicy
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/create", methods=["POST"]) # noqa: F821
|
@manager.route("", methods=["POST"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("name", "memory_type", "embd_id", "llm_id")
|
@validate_request("name", "memory_type", "embd_id", "llm_id")
|
||||||
async def create_memory():
|
async def create_memory():
|
||||||
req = await request_json()
|
req = await get_request_json()
|
||||||
# check name length
|
# check name length
|
||||||
name = req["name"]
|
name = req["name"]
|
||||||
memory_name = name.strip()
|
memory_name = name.strip()
|
||||||
|
|
@ -64,11 +65,11 @@ async def create_memory():
|
||||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/update/<memory_id>", methods=["PUT"]) # noqa: F821
|
@manager.route("/<memory_id>", methods=["PUT"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id")
|
@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id")
|
||||||
async def update_memory(memory_id):
|
async def update_memory(memory_id):
|
||||||
req = await request_json()
|
req = await get_request_json()
|
||||||
update_dict = {}
|
update_dict = {}
|
||||||
# check name length
|
# check name length
|
||||||
if "name" in req:
|
if "name" in req:
|
||||||
|
|
@ -132,7 +133,7 @@ async def update_memory(memory_id):
|
||||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/rm/<memory_id>", methods=["DELETE"]) # noqa: F821
|
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
async def delete_memory(memory_id):
|
async def delete_memory(memory_id):
|
||||||
memory = MemoryService.get_by_memory_id(memory_id)
|
memory = MemoryService.get_by_memory_id(memory_id)
|
||||||
|
|
@ -146,19 +147,25 @@ async def delete_memory(memory_id):
|
||||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
@manager.route("", methods=["GET"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
async def list_memory():
|
async def list_memory():
|
||||||
req = await request_json()
|
args = request.args
|
||||||
try:
|
try:
|
||||||
filter_dict = req.get("filter", {})
|
tenant_ids = args.getlist("tenant_id")
|
||||||
keywords = req.get("keywords", "")
|
memory_types = args.getlist("memory_type")
|
||||||
page = req.get("page", 1)
|
storage_type = args.get("storage_type")
|
||||||
page_size = req.get("page_size", 50)
|
keywords = args.get("keywords", "")
|
||||||
if not filter_dict.get("tenant_id"):
|
page = int(args.get("page", 1))
|
||||||
|
page_size = int(args.get("page_size", 50))
|
||||||
|
# make filter dict
|
||||||
|
filter_dict = {"memory_type": memory_types, "storage_type": storage_type}
|
||||||
|
if not tenant_ids:
|
||||||
# restrict to current user's tenants
|
# restrict to current user's tenants
|
||||||
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
|
user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
|
||||||
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
|
filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
|
||||||
|
else:
|
||||||
|
filter_dict["tenant_id"] = tenant_ids
|
||||||
|
|
||||||
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
|
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
|
||||||
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
|
[memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list]
|
||||||
|
|
@ -169,7 +176,7 @@ async def list_memory():
|
||||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/config/<memory_id>", methods=["GET"]) # noqa: F821
|
@manager.route("/<memory_id>/config", methods=["GET"]) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
async def get_memory_config(memory_id):
|
async def get_memory_config(memory_id):
|
||||||
memory = MemoryService.get_with_owner_name_by_id(memory_id)
|
memory = MemoryService.get_with_owner_name_by_id(memory_id)
|
||||||
|
|
@ -28,7 +28,7 @@ CHUNK_API_URL = f"/{VERSION}/chunk"
|
||||||
DIALOG_APP_URL = f"/{VERSION}/dialog"
|
DIALOG_APP_URL = f"/{VERSION}/dialog"
|
||||||
# SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions"
|
# SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions"
|
||||||
# SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions"
|
# SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions"
|
||||||
MEMORY_API_URL = f"/{VERSION}/memory"
|
MEMORY_API_URL = f"/{VERSION}/memories"
|
||||||
|
|
||||||
|
|
||||||
# KB APP
|
# KB APP
|
||||||
|
|
@ -262,30 +262,40 @@ def delete_dialogs(auth):
|
||||||
|
|
||||||
# MEMORY APP
|
# MEMORY APP
|
||||||
def create_memory(auth, payload=None):
|
def create_memory(auth, payload=None):
|
||||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/create"
|
url = f"{HOST_ADDRESS}{MEMORY_API_URL}"
|
||||||
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
|
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
|
||||||
return res.json()
|
return res.json()
|
||||||
|
|
||||||
|
|
||||||
def update_memory(auth, memory_id:str, payload=None):
|
def update_memory(auth, memory_id:str, payload=None):
|
||||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/update/{memory_id}"
|
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}"
|
||||||
res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload)
|
res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload)
|
||||||
return res.json()
|
return res.json()
|
||||||
|
|
||||||
|
|
||||||
def delete_memory(auth, memory_id:str):
|
def delete_memory(auth, memory_id:str):
|
||||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/rm/{memory_id}"
|
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}"
|
||||||
res = requests.delete(url=url, headers=HEADERS, auth=auth)
|
res = requests.delete(url=url, headers=HEADERS, auth=auth)
|
||||||
return res.json()
|
return res.json()
|
||||||
|
|
||||||
|
|
||||||
def list_memory(auth, payload=None):
|
def list_memory(auth, params=None):
|
||||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/list"
|
url = f"{HOST_ADDRESS}{MEMORY_API_URL}"
|
||||||
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
|
if params:
|
||||||
|
query_parts = []
|
||||||
|
for key, value in params.items():
|
||||||
|
if isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
query_parts.append(f"{key}={item}")
|
||||||
|
else:
|
||||||
|
query_parts.append(f"{key}={value}")
|
||||||
|
query_string = "&".join(query_parts)
|
||||||
|
url = f"{url}?{query_string}"
|
||||||
|
res = requests.get(url=url, headers=HEADERS, auth=auth)
|
||||||
return res.json()
|
return res.json()
|
||||||
|
|
||||||
|
|
||||||
def get_memory_config(auth, memory_id:str):
|
def get_memory_config(auth, memory_id:str):
|
||||||
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/config/{memory_id}"
|
url = f"{HOST_ADDRESS}{MEMORY_API_URL}/{memory_id}/config"
|
||||||
res = requests.get(url=url, headers=HEADERS, auth=auth)
|
res = requests.get(url=url, headers=HEADERS, auth=auth)
|
||||||
return res.json()
|
return res.json()
|
||||||
|
|
|
||||||
|
|
@ -30,14 +30,14 @@ class TestAuthorization:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_auth_invalid(self, invalid_auth, expected_code, expected_message):
|
def test_auth_invalid(self, invalid_auth, expected_code, expected_message):
|
||||||
res = list_memory(invalid_auth, "some_memory_id")
|
res = list_memory(invalid_auth)
|
||||||
assert res["code"] == expected_code, res
|
assert res["code"] == expected_code, res
|
||||||
assert res["message"] == expected_message, res
|
assert res["message"] == expected_message, res
|
||||||
|
|
||||||
|
|
||||||
class TestCapability:
|
class TestCapability:
|
||||||
@pytest.mark.p3
|
@pytest.mark.p3
|
||||||
def test_memory_id(self, WebApiAuth):
|
def test_capability(self, WebApiAuth):
|
||||||
count = 100
|
count = 100
|
||||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||||
futures = [executor.submit(list_memory, WebApiAuth) for i in range(count)]
|
futures = [executor.submit(list_memory, WebApiAuth) for i in range(count)]
|
||||||
|
|
@ -78,14 +78,21 @@ class TestMemoryList:
|
||||||
|
|
||||||
@pytest.mark.p2
|
@pytest.mark.p2
|
||||||
def test_filter_memory_type(self, WebApiAuth):
|
def test_filter_memory_type(self, WebApiAuth):
|
||||||
res = list_memory(WebApiAuth, {"filter": {"memory_type": ["semantic"]}})
|
res = list_memory(WebApiAuth, {"memory_type": ["semantic"]})
|
||||||
assert res["code"] == 0, res
|
assert res["code"] == 0, res
|
||||||
for memory in res["data"]["memory_list"]:
|
for memory in res["data"]["memory_list"]:
|
||||||
assert "semantic" in memory["memory_type"], res
|
assert "semantic" in memory["memory_type"], res
|
||||||
|
|
||||||
|
@pytest.mark.p2
|
||||||
|
def test_filter_multi_memory_type(self, WebApiAuth):
|
||||||
|
res = list_memory(WebApiAuth, {"memory_type": ["episodic", "procedural"]})
|
||||||
|
assert res["code"] == 0, res
|
||||||
|
for memory in res["data"]["memory_list"]:
|
||||||
|
assert "episodic" in memory["memory_type"] or "procedural" in memory["memory_type"], res
|
||||||
|
|
||||||
@pytest.mark.p2
|
@pytest.mark.p2
|
||||||
def test_filter_storage_type(self, WebApiAuth):
|
def test_filter_storage_type(self, WebApiAuth):
|
||||||
res = list_memory(WebApiAuth, {"filter":{"storage_type": "table"}})
|
res = list_memory(WebApiAuth, {"storage_type": "table"})
|
||||||
assert res["code"] == 0, res
|
assert res["code"] == 0, res
|
||||||
for memory in res["data"]["memory_list"]:
|
for memory in res["data"]["memory_list"]:
|
||||||
assert memory["storage_type"] == "table", res
|
assert memory["storage_type"] == "table", res
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue