[OND211-2329]: Updated create_user and update_user APIs and corresponding test cases to be compliant with PEP 8, Ruff, and MyPy standards.

This commit is contained in:
Hetavi Shah 2025-11-11 12:21:25 +05:30
parent d442bf0504
commit 61b84b0302
3 changed files with 456 additions and 260 deletions

View file

@ -759,6 +759,7 @@ def user_add():
def create_user() -> Response: def create_user() -> Response:
""" """
Create a new user. Create a new user.
--- ---
tags: tags:
- User - User
@ -808,12 +809,21 @@ def create_user() -> Response:
description: Server error during user creation. description: Server error during user creation.
schema: schema:
type: object type: object
""" """
if request.json is None:
return get_json_result(
data=False,
message="Request body is required!",
code=RetCode.ARGUMENT_ERROR,
)
req: Dict[str, Any] = request.json req: Dict[str, Any] = request.json
email_address: str = req["email"] email_address: str = req["email"]
# Validate the email address # Validate the email address
email_match: Optional[Match[str]] = re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email_address) email_match: Optional[Match[str]] = re.match(
r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email_address
)
if not email_match: if not email_match:
return get_json_result( return get_json_result(
data=False, data=False,
@ -822,8 +832,9 @@ def create_user() -> Response:
) )
# Check if the email address is already used # Check if the email address is already used
existing_users: Any = UserService.query(email=email_address) existing_users_query = UserService.query(email=email_address)
if existing_users: existing_users_list: List[User] = list(existing_users_query)
if existing_users_list:
return get_json_result( return get_json_result(
data=False, data=False,
message=f"Email: {email_address} has already registered!", message=f"Email: {email_address} has already registered!",
@ -833,7 +844,7 @@ def create_user() -> Response:
# Construct user info data # Construct user info data
nickname: str = req["nickname"] nickname: str = req["nickname"]
is_superuser: bool = req.get("is_superuser", False) is_superuser: bool = req.get("is_superuser", False)
try: try:
password: str = decrypt(req["password"]) password: str = decrypt(req["password"])
except BaseException: except BaseException:
@ -855,13 +866,13 @@ def create_user() -> Response:
user_id: str = get_uuid() user_id: str = get_uuid()
try: try:
users: Any = user_register(user_id, user_dict) users_query = user_register(user_id, user_dict)
if not users: if not users_query:
raise Exception(f"Fail to create user {email_address}.") raise Exception(f"Fail to create user {email_address}.")
users_list: List[User] = list(users) users_list: List[User] = list(users_query)
if len(users_list) > 1: if len(users_list) > 1:
raise Exception(f"Same email: {email_address} exists!") raise Exception(f"Same email: {email_address} exists!")
user: User = users_list[0] user: User = users_list[0]
return get_json_result( return get_json_result(
data=user.to_dict(), data=user.to_dict(),
@ -901,10 +912,13 @@ def update_user() -> Response:
description: User ID to update (optional if email is provided). description: User ID to update (optional if email is provided).
email: email:
type: string type: string
description: User email to identify the user (optional if user_id is provided). If user_id is provided, this can be used as new_email. description: User email to identify the user (optional if user_id
is provided). If user_id is provided, this can be used as
new_email.
new_email: new_email:
type: string type: string
description: New email address (optional). Use this to update email when identifying user by user_id. description: New email address (optional). Use this to update email
when identifying user by user_id.
nickname: nickname:
type: string type: string
description: New nickname (optional). description: New nickname (optional).
@ -938,11 +952,18 @@ def update_user() -> Response:
schema: schema:
type: object type: object
""" """
if request.json is None:
return get_json_result(
data=False,
message="Request body is required!",
code=RetCode.ARGUMENT_ERROR,
)
req: Dict[str, Any] = request.json req: Dict[str, Any] = request.json
user_id: Optional[str] = req.get("user_id") user_id: Optional[str] = req.get("user_id")
email: Optional[str] = req.get("email") email: Optional[str] = req.get("email")
identified_by_user_id: bool = bool(user_id) identified_by_user_id: bool = bool(user_id)
# Validate that either user_id or email is provided # Validate that either user_id or email is provided
if not user_id and not email: if not user_id and not email:
return get_json_result( return get_json_result(
@ -950,24 +971,26 @@ def update_user() -> Response:
message="Either user_id or email must be provided!", message="Either user_id or email must be provided!",
code=RetCode.ARGUMENT_ERROR, code=RetCode.ARGUMENT_ERROR,
) )
# Find the user by user_id or email # Find the user by user_id or email
user: Optional[User] = None user: Optional[User] = None
if user_id: if user_id:
user = UserService.filter_by_id(user_id) user = UserService.filter_by_id(user_id)
elif email: elif email:
# Validate the email address format # Validate the email address format
email_match: Optional[Match[str]] = re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email) email_match: Optional[Match[str]] = re.match(
r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", email
)
if not email_match: if not email_match:
return get_json_result( return get_json_result(
data=False, data=False,
message=f"Invalid email address: {email}!", message=f"Invalid email address: {email}!",
code=RetCode.OPERATING_ERROR, code=RetCode.OPERATING_ERROR,
) )
users: Any = UserService.query(email=email) users_query = UserService.query(email=email)
users_list: List[User] = list(users) users_list: List[User] = list(users_query)
if not users_list: if not users_list:
return get_json_result( return get_json_result(
data=False, data=False,
@ -982,17 +1005,17 @@ def update_user() -> Response:
) )
user = users_list[0] user = users_list[0]
user_id = user.id user_id = user.id
if not user: if not user:
return get_json_result( return get_json_result(
data=False, data=False,
message="User not found!", message="User not found!",
code=RetCode.DATA_ERROR, code=RetCode.DATA_ERROR,
) )
# Build update dictionary # Build update dictionary
update_dict: Dict[str, Any] = {} update_dict: Dict[str, Any] = {}
# Handle nickname update # Handle nickname update
# Allow empty nickname (empty string is a valid value) # Allow empty nickname (empty string is a valid value)
if "nickname" in req: if "nickname" in req:
@ -1000,7 +1023,7 @@ def update_user() -> Response:
# Only skip if explicitly None, allow empty strings # Only skip if explicitly None, allow empty strings
if nickname is not None: if nickname is not None:
update_dict["nickname"] = nickname update_dict["nickname"] = nickname
# Handle password update # Handle password update
if "password" in req and req["password"]: if "password" in req and req["password"]:
try: try:
@ -1012,7 +1035,7 @@ def update_user() -> Response:
code=RetCode.SERVER_ERROR, code=RetCode.SERVER_ERROR,
message="Fail to decrypt password", message="Fail to decrypt password",
) )
# Handle email update # Handle email update
# If user_id was used to identify, "email" in req can be the new email # If user_id was used to identify, "email" in req can be the new email
# Otherwise, use "new_email" field # Otherwise, use "new_email" field
@ -1021,33 +1044,37 @@ def update_user() -> Response:
new_email = req["email"] new_email = req["email"]
elif "new_email" in req and req["new_email"]: elif "new_email" in req and req["new_email"]:
new_email = req["new_email"] new_email = req["new_email"]
if new_email: if new_email:
# Validate the new email address format # Validate the new email address format
email_match: Optional[Match[str]] = re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", new_email) email_match: Optional[Match[str]] = re.match(
r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,}$", new_email
)
if not email_match: if not email_match:
return get_json_result( return get_json_result(
data=False, data=False,
message=f"Invalid email address: {new_email}!", message=f"Invalid email address: {new_email}!",
code=RetCode.OPERATING_ERROR, code=RetCode.OPERATING_ERROR,
) )
# Check if the new email is already used by another user # Check if the new email is already used by another user
existing_users: Any = UserService.query(email=new_email) existing_users_query = UserService.query(email=new_email)
existing_users_list: List[User] = list(existing_users) existing_users_list: List[User] = list(existing_users_query)
if existing_users_list and existing_users_list[0].id != user_id: if existing_users_list and existing_users_list[0].id != user_id:
return get_json_result( return get_json_result(
data=False, data=False,
message=f"Email: {new_email} is already in use by another user!", message=(
f"Email: {new_email} is already in use by another user!"
),
code=RetCode.OPERATING_ERROR, code=RetCode.OPERATING_ERROR,
) )
update_dict["email"] = new_email update_dict["email"] = new_email
# Handle is_superuser update # Handle is_superuser update
if "is_superuser" in req: if "is_superuser" in req:
is_superuser: bool = req.get("is_superuser", False) is_superuser: bool = req.get("is_superuser", False)
update_dict["is_superuser"] = is_superuser update_dict["is_superuser"] = is_superuser
# If no fields to update, return error # If no fields to update, return error
if not update_dict: if not update_dict:
return get_json_result( return get_json_result(
@ -1055,7 +1082,7 @@ def update_user() -> Response:
message="No valid fields to update!", message="No valid fields to update!",
code=RetCode.ARGUMENT_ERROR, code=RetCode.ARGUMENT_ERROR,
) )
# Update the user # Update the user
try: try:
UserService.update_user(user_id, update_dict) UserService.update_user(user_id, update_dict)

View file

@ -13,57 +13,87 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from __future__ import annotations
import base64 import base64
import os import os
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from typing import Any
import pytest import pytest
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from Cryptodome.PublicKey import RSA
from common import create_user from common import create_user
from configs import INVALID_API_TOKEN from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowHttpApiAuth from libs.auth import RAGFlowHttpApiAuth
# ---------------------------------------------------------------------------
# Utility Functions
# ---------------------------------------------------------------------------
def encrypt_password(password: str) -> str: def encrypt_password(password: str) -> str:
""" """
Encrypt password for API calls without importing from api.utils.crypt Encrypt password for API calls without importing from api.utils.crypt.
Avoids ModuleNotFoundError caused by test helper module named `common`. Avoids ModuleNotFoundError caused by test helper module named `common`.
""" """
# test/testcases/test_http_api/test_user_management/test_create_user.py -> project root
current_dir: str = os.path.dirname(os.path.abspath(__file__)) current_dir: str = os.path.dirname(os.path.abspath(__file__))
project_base: str = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) project_base: str = os.path.abspath(
os.path.join(current_dir, "..", "..", "..", "..")
)
file_path: str = os.path.join(project_base, "conf", "public.pem") file_path: str = os.path.join(project_base, "conf", "public.pem")
rsa_key: RSA.RSAKey = RSA.importKey(open(file_path).read(), "Welcome") with open(file_path, encoding="utf-8") as pem_file:
cipher: Cipher_pkcs1_v1_5.Cipher_pkcs1_v1_5 = Cipher_pkcs1_v1_5.new(rsa_key) rsa_key: RSA.RsaKey = RSA.import_key(
password_base64: str = base64.b64encode(password.encode("utf-8")).decode("utf-8") pem_file.read(), passphrase="Welcome"
encrypted_password: str = cipher.encrypt(password_base64.encode()) )
return base64.b64encode(encrypted_password).decode("utf-8")
cipher: Cipher_pkcs1_v1_5.PKCS115_Cipher = Cipher_pkcs1_v1_5.new(rsa_key)
password_base64: str = base64.b64encode(password.encode()).decode()
encrypted_password: bytes = cipher.encrypt(password_base64.encode())
return base64.b64encode(encrypted_password).decode()
# ---------------------------------------------------------------------------
# Test Classes
# ---------------------------------------------------------------------------
@pytest.mark.p1 @pytest.mark.p1
class TestAuthorization: class TestAuthorization:
"""Tests for authentication behavior during user creation."""
@pytest.mark.parametrize( @pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message", ("invalid_auth", "expected_code", "expected_message"),
[ [
# Note: @login_required is commented out, so endpoint works without auth # Note: @login_required is commented out, so endpoint works
# Testing with None auth should succeed (code 0) if endpoint doesn't require auth # without auth
# Testing with None auth should succeed (code 0) if endpoint
# doesn't require auth
(None, 0, ""), (None, 0, ""),
# Invalid token should also work if auth is not required # Invalid token should also work if auth is not required
(RAGFlowHttpApiAuth(INVALID_API_TOKEN), 0, ""), (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 0, ""),
], ],
) )
def test_invalid_auth(self, invalid_auth, expected_code, expected_message): def test_invalid_auth(
self,
invalid_auth: RAGFlowHttpApiAuth | None,
expected_code: int,
expected_message: str,
) -> None:
"""Test user creation with invalid or missing authentication."""
# Use unique email to avoid conflicts # Use unique email to avoid conflicts
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com" unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload = { payload: dict[str, str] = {
"nickname": "test_user", "nickname": "test_user",
"email": unique_email, "email": unique_email,
"password": encrypt_password("test123"), "password": encrypt_password("test123"),
} }
res = create_user(invalid_auth, payload) res: dict[str, Any] = create_user(invalid_auth, payload)
assert res["code"] == expected_code, res assert res["code"] == expected_code, res
if expected_message: if expected_message:
assert expected_message in res["message"] assert expected_message in res["message"]
@ -71,25 +101,84 @@ class TestAuthorization:
@pytest.mark.usefixtures("clear_users") @pytest.mark.usefixtures("clear_users")
class TestUserCreate: class TestUserCreate:
"""Comprehensive tests for user creation API."""
@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"payload, expected_code, expected_message", ("payload", "expected_code", "expected_message"),
[ [
({"nickname": "valid_user", "email": "valid@example.com", "password": encrypt_password("test123")}, 0, ""), (
({"nickname": "", "email": "test@example.com", "password": encrypt_password("test123")}, 0, ""), # Empty nickname is accepted {
({"nickname": "test_user", "email": "", "password": encrypt_password("test123")}, 103, "Invalid email address"), "nickname": "valid_user",
({"nickname": "test_user", "email": "test@example.com", "password": ""}, 500, "Fail to decrypt password"), "email": "valid@example.com",
({"nickname": "test_user", "email": "test@example.com"}, 101, "required argument are missing"), "password": encrypt_password("test123"),
({"nickname": "test_user", "password": encrypt_password("test123")}, 101, "required argument are missing"), },
({"email": "test@example.com", "password": encrypt_password("test123")}, 101, "required argument are missing"), 0,
"",
),
(
{
"nickname": "",
"email": "test@example.com",
"password": encrypt_password("test123"),
},
0,
"",
), # Empty nickname is accepted
(
{
"nickname": "test_user",
"email": "",
"password": encrypt_password("test123"),
},
103,
"Invalid email address",
),
(
{
"nickname": "test_user",
"email": "test@example.com",
"password": "",
},
500,
"Fail to decrypt password",
),
(
{"nickname": "test_user", "email": "test@example.com"},
101,
"required argument are missing",
),
(
{
"nickname": "test_user",
"password": encrypt_password("test123"),
},
101,
"required argument are missing",
),
(
{
"email": "test@example.com",
"password": encrypt_password("test123"),
},
101,
"required argument are missing",
),
], ],
) )
def test_required_fields(self, HttpApiAuth: RAGFlowHttpApiAuth, payload: dict, expected_code: int, expected_message: str) -> None: def test_required_fields(
self,
HttpApiAuth: RAGFlowHttpApiAuth,
payload: dict[str, Any],
expected_code: int,
expected_message: str,
) -> None:
"""Test user creation with various required field combinations."""
if payload.get("email") and "@" in payload.get("email", ""): if payload.get("email") and "@" in payload.get("email", ""):
# Use unique email to avoid conflicts # Use unique email to avoid conflicts
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com" unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload["email"] = unique_email payload["email"] = unique_email
res = create_user(HttpApiAuth, payload) res: dict[str, Any] = create_user(HttpApiAuth, payload)
assert res["code"] == expected_code, res assert res["code"] == expected_code, res
if expected_code == 0: if expected_code == 0:
assert res["data"]["nickname"] == payload["nickname"] assert res["data"]["nickname"] == payload["nickname"]
@ -99,7 +188,7 @@ class TestUserCreate:
@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"email, expected_code, expected_message", ("email", "expected_code", "expected_message"),
[ [
("valid@example.com", 0, ""), ("valid@example.com", 0, ""),
("user.name@example.com", 0, ""), ("user.name@example.com", 0, ""),
@ -112,16 +201,23 @@ class TestUserCreate:
("", 103, "Invalid email address"), ("", 103, "Invalid email address"),
], ],
) )
def test_email_validation(self, HttpApiAuth, email, expected_code, expected_message): def test_email_validation(
self,
HttpApiAuth: RAGFlowHttpApiAuth,
email: str,
expected_code: int,
expected_message: str,
) -> None:
"""Test email validation with various email formats."""
if email and "@" in email and expected_code == 0: if email and "@" in email and expected_code == 0:
# Use unique email to avoid conflicts # Use unique email to avoid conflicts
email = f"test_{uuid.uuid4().hex[:8]}@example.com" email = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload = { payload: dict[str, str] = {
"nickname": "test_user", "nickname": "test_user",
"email": email, "email": email,
"password": encrypt_password("test123"), "password": encrypt_password("test123"),
} }
res = create_user(HttpApiAuth, payload) res: dict[str, Any] = create_user(HttpApiAuth, payload)
assert res["code"] == expected_code, res assert res["code"] == expected_code, res
if expected_code == 0: if expected_code == 0:
assert res["data"]["email"] == email assert res["data"]["email"] == email
@ -130,7 +226,7 @@ class TestUserCreate:
@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"nickname, expected_code, expected_message", ("nickname", "expected_code", "expected_message"),
[ [
("valid_nickname", 0, ""), ("valid_nickname", 0, ""),
("user123", 0, ""), ("user123", 0, ""),
@ -139,14 +235,21 @@ class TestUserCreate:
("", 0, ""), # Empty nickname is accepted by the API ("", 0, ""), # Empty nickname is accepted by the API
], ],
) )
def test_nickname(self, HttpApiAuth, nickname, expected_code, expected_message): def test_nickname(
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com" self,
payload = { HttpApiAuth: RAGFlowHttpApiAuth,
nickname: str,
expected_code: int,
expected_message: str,
) -> None:
"""Test nickname validation with various nickname formats."""
unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload: dict[str, str] = {
"nickname": nickname, "nickname": nickname,
"email": unique_email, "email": unique_email,
"password": encrypt_password("test123"), "password": encrypt_password("test123"),
} }
res = create_user(HttpApiAuth, payload) res: dict[str, Any] = create_user(HttpApiAuth, payload)
assert res["code"] == expected_code, res assert res["code"] == expected_code, res
if expected_code == 0: if expected_code == 0:
assert res["data"]["nickname"] == nickname assert res["data"]["nickname"] == nickname
@ -154,38 +257,47 @@ class TestUserCreate:
assert expected_message in res["message"] assert expected_message in res["message"]
@pytest.mark.p1 @pytest.mark.p1
def test_duplicate_email(self, HttpApiAuth): def test_duplicate_email(
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com" self, HttpApiAuth: RAGFlowHttpApiAuth
payload = { ) -> None:
"""Test that creating a user with duplicate email fails."""
unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload: dict[str, str] = {
"nickname": "test_user_1", "nickname": "test_user_1",
"email": unique_email, "email": unique_email,
"password": encrypt_password("test123"), "password": encrypt_password("test123"),
} }
res = create_user(HttpApiAuth, payload) res: dict[str, Any] = create_user(HttpApiAuth, payload)
assert res["code"] == 0 assert res["code"] == 0
# Try to create another user with the same email # Try to create another user with the same email
payload2 = { payload2: dict[str, str] = {
"nickname": "test_user_2", "nickname": "test_user_2",
"email": unique_email, "email": unique_email,
"password": encrypt_password("test123"), "password": encrypt_password("test123"),
} }
res2 = create_user(HttpApiAuth, payload2) res2: dict[str, Any] = create_user(HttpApiAuth, payload2)
assert res2["code"] == 103 assert res2["code"] == 103
assert "has already registered" in res2["message"] assert "has already registered" in res2["message"]
@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"is_superuser, expected_value", ("is_superuser", "expected_value"),
[ [
(True, True), (True, True),
(False, False), (False, False),
(None, False), # Default should be False (None, False), # Default should be False
], ],
) )
def test_is_superuser(self, HttpApiAuth, is_superuser, expected_value): def test_is_superuser(
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com" self,
payload = { HttpApiAuth: RAGFlowHttpApiAuth,
is_superuser: bool | None,
expected_value: bool,
) -> None:
"""Test is_superuser flag handling during user creation."""
unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload: dict[str, Any] = {
"nickname": "test_user", "nickname": "test_user",
"email": unique_email, "email": unique_email,
"password": encrypt_password("test123"), "password": encrypt_password("test123"),
@ -193,67 +305,89 @@ class TestUserCreate:
if is_superuser is not None: if is_superuser is not None:
payload["is_superuser"] = is_superuser payload["is_superuser"] = is_superuser
res = create_user(HttpApiAuth, payload) res: dict[str, Any] = create_user(HttpApiAuth, payload)
assert res["code"] == 0 assert res["code"] == 0
assert res["data"]["is_superuser"] == expected_value assert res["data"]["is_superuser"] == expected_value
@pytest.mark.p2 @pytest.mark.p2
def test_password_encryption(self, HttpApiAuth): def test_password_encryption(
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com" self, HttpApiAuth: RAGFlowHttpApiAuth
password = "test_password_123" ) -> None:
payload = { """Test that password is properly encrypted and hashed."""
unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
password: str = "test_password_123"
payload: dict[str, str] = {
"nickname": "test_user", "nickname": "test_user",
"email": unique_email, "email": unique_email,
"password": encrypt_password(password), "password": encrypt_password(password),
} }
res = create_user(HttpApiAuth, payload) res: dict[str, Any] = create_user(HttpApiAuth, payload)
assert res["code"] == 0 assert res["code"] == 0
# Password should be hashed in the response (not plain text) # Password should be hashed in the response (not plain text)
assert "password" in res["data"], f"Password field not found in response: {res['data'].keys()}" assert "password" in res["data"], (
assert res["data"]["password"].startswith("scrypt:"), f"Password is not hashed: {res['data']['password']}" f"Password field not found in response: {res['data'].keys()}"
)
assert res["data"]["password"].startswith("scrypt:"), (
f"Password is not hashed: {res['data']['password']}"
)
# Verify it's not the plain password # Verify it's not the plain password
assert res["data"]["password"] != password assert res["data"]["password"] != password
assert res["data"]["password"] != encrypt_password(password) assert res["data"]["password"] != encrypt_password(password)
@pytest.mark.p2 @pytest.mark.p2
def test_invalid_password_encryption(self, HttpApiAuth): def test_invalid_password_encryption(
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com" self, HttpApiAuth: RAGFlowHttpApiAuth
payload = { ) -> None:
"""Test that plain text password without encryption fails."""
unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload: dict[str, str] = {
"nickname": "test_user", "nickname": "test_user",
"email": unique_email, "email": unique_email,
"password": "plain_text_password", # Not encrypted "password": "plain_text_password", # Not encrypted
} }
res = create_user(HttpApiAuth, payload) res: dict[str, Any] = create_user(HttpApiAuth, payload)
# Should fail to decrypt password # Should fail to decrypt password
assert res["code"] == 500 assert res["code"] == 500
assert "Fail to decrypt password" in res["message"] assert "Fail to decrypt password" in res["message"]
@pytest.mark.p3 @pytest.mark.p3
def test_concurrent_create(self, HttpApiAuth): def test_concurrent_create(
count = 10 self, HttpApiAuth: RAGFlowHttpApiAuth
) -> None:
"""Test concurrent user creation with multiple threads."""
count: int = 10
with ThreadPoolExecutor(max_workers=5) as executor: with ThreadPoolExecutor(max_workers=5) as executor:
futures = [] futures: list[Future[dict[str, Any]]] = []
for i in range(count): for i in range(count):
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com" unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload = { payload: dict[str, str] = {
"nickname": f"test_user_{i}", "nickname": f"test_user_{i}",
"email": unique_email, "email": unique_email,
"password": encrypt_password("test123"), "password": encrypt_password("test123"),
} }
futures.append(executor.submit(create_user, HttpApiAuth, payload)) futures.append(
responses = list(as_completed(futures)) executor.submit(create_user, HttpApiAuth, payload)
assert len(responses) == count, responses )
assert all(future.result()["code"] == 0 for future in futures) responses: list[Future[dict[str, Any]]] = list(
as_completed(futures)
)
assert len(responses) == count, responses
assert all(
future.result()["code"] == 0 for future in futures
)
@pytest.mark.p2 @pytest.mark.p2
def test_user_creation_response_structure(self, HttpApiAuth): def test_user_creation_response_structure(
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com" self, HttpApiAuth: RAGFlowHttpApiAuth
payload = { ) -> None:
"""Test that user creation returns the expected response structure."""
unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload: dict[str, str] = {
"nickname": "test_user", "nickname": "test_user",
"email": unique_email, "email": unique_email,
"password": encrypt_password("test123"), "password": encrypt_password("test123"),
} }
res = create_user(HttpApiAuth, payload) res: dict[str, Any] = create_user(HttpApiAuth, payload)
assert res["code"] == 0 assert res["code"] == 0
assert "data" in res assert "data" in res
assert "id" in res["data"] assert "id" in res["data"]

View file

@ -13,47 +13,65 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from __future__ import annotations
import base64 import base64
import os import os
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from typing import Any
import pytest import pytest
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from Cryptodome.PublicKey import RSA
from common import create_user, update_user from common import create_user, update_user
from configs import INVALID_API_TOKEN from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowHttpApiAuth from libs.auth import RAGFlowHttpApiAuth
# ---------------------------------------------------------------------------
# Utility Functions
# ---------------------------------------------------------------------------
def encrypt_password(password: str) -> str: def encrypt_password(password: str) -> str:
""" """
Encrypt password for API calls without importing from api.utils.crypt Encrypt password for API calls without importing from api.utils.crypt.
Avoids ModuleNotFoundError caused by test helper module named `common`. Avoids ModuleNotFoundError caused by test helper module named `common`.
""" """
# test/testcases/test_http_api/test_user_management/test_update_user.py -> project root
current_dir: str = os.path.dirname(os.path.abspath(__file__)) current_dir: str = os.path.dirname(os.path.abspath(__file__))
project_base: str = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) project_base: str = os.path.abspath(
os.path.join(current_dir, "..", "..", "..", "..")
)
file_path: str = os.path.join(project_base, "conf", "public.pem") file_path: str = os.path.join(project_base, "conf", "public.pem")
rsa_key: RSA.RSAKey = RSA.importKey(open(file_path).read(), "Welcome") with open(file_path, encoding="utf-8") as pem_file:
cipher: Cipher_pkcs1_v1_5.Cipher_pkcs1_v1_5 = Cipher_pkcs1_v1_5.new(rsa_key) rsa_key: RSA.RsaKey = RSA.import_key(pem_file.read(), passphrase="Welcome")
password_base64: str = base64.b64encode(password.encode("utf-8")).decode("utf-8")
encrypted_password: str = cipher.encrypt(password_base64.encode()) cipher: Cipher_pkcs1_v1_5.PKCS115_Cipher = Cipher_pkcs1_v1_5.new(rsa_key)
return base64.b64encode(encrypted_password).decode("utf-8") password_base64: str = base64.b64encode(password.encode()).decode()
encrypted_password: bytes = cipher.encrypt(password_base64.encode())
return base64.b64encode(encrypted_password).decode()
@pytest.fixture # ---------------------------------------------------------------------------
def test_user(HttpApiAuth: RAGFlowHttpApiAuth) -> dict: # Fixtures
"""Create a test user for update tests""" # ---------------------------------------------------------------------------
@pytest.fixture(name="test_user")
def fixture_test_user(HttpApiAuth: RAGFlowHttpApiAuth) -> dict[str, Any]:
"""Create a temporary user for update tests."""
unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com" unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload: dict = { payload: dict[str, str] = {
"nickname": "test_user_original", "nickname": "test_user_original",
"email": unique_email, "email": unique_email,
"password": encrypt_password("test123"), "password": encrypt_password("test123"),
} }
res: dict = create_user(HttpApiAuth, payload)
res: dict[str, Any] = create_user(HttpApiAuth, payload)
assert res["code"] == 0, f"Failed to create test user: {res}" assert res["code"] == 0, f"Failed to create test user: {res}"
return { return {
"user_id": res["data"]["id"], "user_id": res["data"]["id"],
"email": unique_email, "email": unique_email,
@ -61,24 +79,34 @@ def test_user(HttpApiAuth: RAGFlowHttpApiAuth) -> dict:
} }
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.p1 @pytest.mark.p1
class TestAuthorization: class TestAuthorization:
"""Tests for authentication behavior during user updates."""
@pytest.mark.parametrize( @pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message", ("invalid_auth", "expected_code", "expected_message"),
[ [
# Note: @login_required is commented out, so endpoint works without auth # Endpoint works without auth (decorator commented out)
# Testing with None auth should succeed (code 0) if endpoint doesn't require auth
(None, 0, ""), (None, 0, ""),
# Invalid token should also work if auth is not required
(RAGFlowHttpApiAuth(INVALID_API_TOKEN), 0, ""), (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 0, ""),
], ],
) )
def test_invalid_auth(self, invalid_auth, expected_code, expected_message, test_user): def test_invalid_auth(
payload: dict = { self,
invalid_auth: RAGFlowHttpApiAuth | None,
expected_code: int,
expected_message: str,
test_user: dict[str, Any],
) -> None:
payload: dict[str, Any] = {
"user_id": test_user["user_id"], "user_id": test_user["user_id"],
"nickname": "updated_nickname", "nickname": "updated_nickname",
} }
res: dict = update_user(invalid_auth, payload) res: dict[str, Any] = update_user(invalid_auth, payload)
assert res["code"] == expected_code, res assert res["code"] == expected_code, res
if expected_message: if expected_message:
assert expected_message in res["message"] assert expected_message in res["message"]
@ -86,82 +114,93 @@ class TestAuthorization:
@pytest.mark.usefixtures("clear_users") @pytest.mark.usefixtures("clear_users")
class TestUserUpdate: class TestUserUpdate:
"""Comprehensive tests for user update API."""
@pytest.mark.p1 @pytest.mark.p1
def test_update_with_user_id(self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict) -> None: def test_update_with_user_id(
"""Test updating user by user_id""" self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict[str, Any]
payload: dict = { ) -> None:
payload: dict[str, Any] = {
"user_id": test_user["user_id"], "user_id": test_user["user_id"],
"nickname": "updated_nickname", "nickname": "updated_nickname",
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"]["nickname"] == "updated_nickname" assert res["data"]["nickname"] == "updated_nickname"
assert res["data"]["email"] == test_user["email"] assert res["data"]["email"] == test_user["email"]
assert "updated successfully" in res["message"].lower() assert "updated successfully" in res["message"].lower()
@pytest.mark.p1 @pytest.mark.p1
def test_update_with_email(self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict) -> None: def test_update_with_email(
"""Test updating user by email""" self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict[str, Any]
payload: dict = { ) -> None:
payload: dict[str, Any] = {
"email": test_user["email"], "email": test_user["email"],
"nickname": "updated_nickname_email", "nickname": "updated_nickname_email",
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"]["nickname"] == "updated_nickname_email" assert res["data"]["nickname"] == "updated_nickname_email"
assert res["data"]["email"] == test_user["email"] assert res["data"]["email"] == test_user["email"]
@pytest.mark.p1 @pytest.mark.p1
def test_update_missing_identifier(self, HttpApiAuth: RAGFlowHttpApiAuth) -> None: def test_update_missing_identifier(
"""Test update without user_id or email""" self, HttpApiAuth: RAGFlowHttpApiAuth
payload: dict = { ) -> None:
"nickname": "updated_nickname", """Test update without user_id or email."""
} payload: dict[str, str] = {"nickname": "updated_nickname"}
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == 101 # ARGUMENT_ERROR assert res["code"] == 101
assert "Either user_id or email must be provided" in res["message"] assert "Either user_id or email must be provided" in res["message"]
@pytest.mark.p1 @pytest.mark.p1
def test_update_user_not_found_by_id(self, HttpApiAuth: RAGFlowHttpApiAuth) -> None: def test_update_user_not_found_by_id(
"""Test update with non-existent user_id""" self, HttpApiAuth: RAGFlowHttpApiAuth
payload: dict = { ) -> None:
payload: dict[str, str] = {
"user_id": "non_existent_user_id_12345", "user_id": "non_existent_user_id_12345",
"nickname": "updated_nickname", "nickname": "updated_nickname",
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == 102 # DATA_ERROR assert res["code"] == 102
assert "User not found" in res["message"] assert "User not found" in res["message"]
@pytest.mark.p1 @pytest.mark.p1
def test_update_user_not_found_by_email(self, HttpApiAuth: RAGFlowHttpApiAuth) -> None: def test_update_user_not_found_by_email(
"""Test update with non-existent email""" self, HttpApiAuth: RAGFlowHttpApiAuth
payload: dict = { ) -> None:
payload: dict[str, str] = {
"email": "nonexistent@example.com", "email": "nonexistent@example.com",
"nickname": "updated_nickname", "nickname": "updated_nickname",
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == 102 # DATA_ERROR assert res["code"] == 102
assert "not found" in res["message"] assert "not found" in res["message"]
@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"nickname, expected_code, expected_message", ("nickname", "expected_code", "expected_message"),
[ [
("valid_nickname", 0, ""), ("valid_nickname", 0, ""),
("user123", 0, ""), ("user123", 0, ""),
("user_name", 0, ""), ("user_name", 0, ""),
("User Name", 0, ""), ("User Name", 0, ""),
("", 0, ""), # Empty nickname is accepted ("", 0, ""), # Empty nickname accepted
], ],
) )
def test_update_nickname( def test_update_nickname(
self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict, nickname: str, expected_code: int, expected_message: str self,
HttpApiAuth: RAGFlowHttpApiAuth,
test_user: dict[str, Any],
nickname: str,
expected_code: int,
expected_message: str,
) -> None: ) -> None:
payload: dict = { payload: dict[str, str] = {
"user_id": test_user["user_id"], "user_id": test_user["user_id"],
"nickname": nickname, "nickname": nickname,
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == expected_code, res assert res["code"] == expected_code, res
if expected_code == 0: if expected_code == 0:
assert res["data"]["nickname"] == nickname assert res["data"]["nickname"] == nickname
@ -169,31 +208,33 @@ class TestUserUpdate:
assert expected_message in res["message"] assert expected_message in res["message"]
@pytest.mark.p1 @pytest.mark.p1
def test_update_password(self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict) -> None: def test_update_password(
"""Test updating user password""" self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict[str, Any]
) -> None:
new_password: str = "new_password_456" new_password: str = "new_password_456"
payload: dict = { payload: dict[str, str] = {
"user_id": test_user["user_id"], "user_id": test_user["user_id"],
"password": encrypt_password(new_password), "password": encrypt_password(new_password),
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
assert "updated successfully" in res["message"].lower() assert "updated successfully" in res["message"].lower()
@pytest.mark.p1 @pytest.mark.p1
def test_update_password_invalid_encryption(self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict) -> None: def test_update_password_invalid_encryption(
"""Test updating password with invalid encryption""" self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict[str, Any]
payload: dict = { ) -> None:
payload: dict[str, str] = {
"user_id": test_user["user_id"], "user_id": test_user["user_id"],
"password": "plain_text_password", # Not encrypted "password": "plain_text_password",
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == 500 assert res["code"] == 500
assert "Fail to decrypt password" in res["message"] assert "Fail to decrypt password" in res["message"]
@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"new_email, expected_code, expected_message", ("new_email", "expected_code", "expected_message"),
[ [
("valid@example.com", 0, ""), ("valid@example.com", 0, ""),
("user.name@example.com", 0, ""), ("user.name@example.com", 0, ""),
@ -206,16 +247,20 @@ class TestUserUpdate:
], ],
) )
def test_update_email( def test_update_email(
self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict, new_email: str, expected_code: int, expected_message: str self,
HttpApiAuth: RAGFlowHttpApiAuth,
test_user: dict[str, Any],
new_email: str,
expected_code: int,
expected_message: str,
) -> None: ) -> None:
if "@" in new_email and expected_code == 0: if "@" in new_email and expected_code == 0:
# Use unique email to avoid conflicts
new_email = f"test_{uuid.uuid4().hex[:8]}@example.com" new_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload: dict = { payload: dict[str, str] = {
"user_id": test_user["user_id"], "user_id": test_user["user_id"],
"new_email": new_email, "new_email": new_email,
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == expected_code, res assert res["code"] == expected_code, res
if expected_code == 0: if expected_code == 0:
assert res["data"]["email"] == new_email assert res["data"]["email"] == new_email
@ -223,159 +268,149 @@ class TestUserUpdate:
assert expected_message in res["message"] assert expected_message in res["message"]
@pytest.mark.p1 @pytest.mark.p1
def test_update_email_duplicate(self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict) -> None: def test_update_email_duplicate(
"""Test updating email to an already used email""" self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict[str, Any]
# Create another user ) -> None:
unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com" unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
create_payload: dict = { create_payload: dict[str, str] = {
"nickname": "another_user", "nickname": "another_user",
"email": unique_email, "email": unique_email,
"password": encrypt_password("test123"), "password": encrypt_password("test123"),
} }
create_res: dict = create_user(HttpApiAuth, create_payload) create_res: dict[str, Any] = create_user(HttpApiAuth, create_payload)
assert create_res["code"] == 0 assert create_res["code"] == 0
# Try to update test_user's email to the same email update_payload: dict[str, str] = {
update_payload: dict = {
"user_id": test_user["user_id"], "user_id": test_user["user_id"],
"new_email": unique_email, "new_email": unique_email,
} }
res: dict = update_user(HttpApiAuth, update_payload) res: dict[str, Any] = update_user(HttpApiAuth, update_payload)
assert res["code"] == 103 # OPERATING_ERROR assert res["code"] == 103
assert "already in use" in res["message"] assert "already in use" in res["message"]
@pytest.mark.p1 @pytest.mark.p1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"is_superuser, expected_value", ("is_superuser", "expected_value"),
[ [(True, True), (False, False)],
(True, True),
(False, False),
],
) )
def test_update_is_superuser( def test_update_is_superuser(
self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict, is_superuser: bool, expected_value: bool self,
HttpApiAuth: RAGFlowHttpApiAuth,
test_user: dict[str, Any],
is_superuser: bool,
expected_value: bool,
) -> None: ) -> None:
payload: dict = { payload: dict[str, Any] = {
"user_id": test_user["user_id"], "user_id": test_user["user_id"],
"is_superuser": is_superuser, "is_superuser": is_superuser,
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"]["is_superuser"] == expected_value assert res["data"]["is_superuser"] is expected_value
@pytest.mark.p1 @pytest.mark.p1
def test_update_multiple_fields(self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict) -> None: def test_update_multiple_fields(
"""Test updating multiple fields at once""" self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict[str, Any]
) -> None:
new_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com" new_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload: dict = { payload: dict[str, Any] = {
"user_id": test_user["user_id"], "user_id": test_user["user_id"],
"nickname": "updated_multiple", "nickname": "updated_multiple",
"new_email": new_email, "new_email": new_email,
"is_superuser": True, "is_superuser": True,
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"]["nickname"] == "updated_multiple" assert res["data"]["nickname"] == "updated_multiple"
assert res["data"]["email"] == new_email assert res["data"]["email"] == new_email
assert res["data"]["is_superuser"] is True assert res["data"]["is_superuser"] is True
@pytest.mark.p1 @pytest.mark.p1
def test_update_no_fields(self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict) -> None: def test_update_no_fields(
"""Test update with no fields to update""" self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict[str, Any]
payload: dict = { ) -> None:
"user_id": test_user["user_id"], payload: dict[str, str] = {"user_id": test_user["user_id"]}
} res: dict[str, Any] = update_user(HttpApiAuth, payload)
res: dict = update_user(HttpApiAuth, payload) assert res["code"] == 101
assert res["code"] == 101 # ARGUMENT_ERROR
assert "No valid fields to update" in res["message"] assert "No valid fields to update" in res["message"]
@pytest.mark.p1 @pytest.mark.p1
def test_update_email_using_email_field_when_user_id_provided( def test_update_email_using_email_field_when_user_id_provided(
self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict[str, Any]
) -> None: ) -> None:
"""Test that when user_id is provided, 'email' field can be used as new_email"""
new_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com" new_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
payload: dict = { payload: dict[str, str] = {
"user_id": test_user["user_id"], "user_id": test_user["user_id"],
"email": new_email, # When user_id is provided, email is treated as new_email "email": new_email,
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == 0, res assert res["code"] == 0, res
assert res["data"]["email"] == new_email assert res["data"]["email"] == new_email
@pytest.mark.p2 @pytest.mark.p2
def test_update_response_structure(self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict) -> None: def test_update_response_structure(
"""Test that update response has correct structure""" self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict[str, Any]
payload: dict = { ) -> None:
payload: dict[str, Any] = {
"user_id": test_user["user_id"], "user_id": test_user["user_id"],
"nickname": "response_test", "nickname": "response_test",
} }
res: dict = update_user(HttpApiAuth, payload) res: dict[str, Any] = update_user(HttpApiAuth, payload)
assert res["code"] == 0 assert res["code"] == 0
assert "data" in res assert set(("id", "email", "nickname")) <= res["data"].keys()
assert "id" in res["data"]
assert "email" in res["data"]
assert "nickname" in res["data"]
assert res["data"]["nickname"] == "response_test" assert res["data"]["nickname"] == "response_test"
assert "updated successfully" in res["message"].lower() assert "updated successfully" in res["message"].lower()
@pytest.mark.p2 @pytest.mark.p2
def test_concurrent_updates(self, HttpApiAuth: RAGFlowHttpApiAuth) -> None: def test_concurrent_updates(
"""Test concurrent updates to different users""" self, HttpApiAuth: RAGFlowHttpApiAuth
# Create multiple users ) -> None:
users: list = [] """Test concurrent updates to different users."""
users: list[dict[str, Any]] = []
for i in range(5): for i in range(5):
unique_email: str = f"test_{uuid.uuid4().hex[:8]}@example.com" email: str = f"test_{uuid.uuid4().hex[:8]}@example.com"
create_payload: dict = { create_payload: dict[str, str] = {
"nickname": f"user_{i}", "nickname": f"user_{i}",
"email": unique_email, "email": email,
"password": encrypt_password("test123"), "password": encrypt_password("test123"),
} }
create_res: dict = create_user(HttpApiAuth, create_payload) create_res: dict[str, Any] = create_user(
HttpApiAuth, create_payload
)
assert create_res["code"] == 0 assert create_res["code"] == 0
users.append(create_res["data"]) users.append(create_res["data"])
# Update all users concurrently
with ThreadPoolExecutor(max_workers=5) as executor: with ThreadPoolExecutor(max_workers=5) as executor:
futures: list = [] futures: list[Future[dict[str, Any]]] = [
for i, user in enumerate(users): executor.submit(
payload: dict = { update_user,
"user_id": user["id"], HttpApiAuth,
"nickname": f"updated_user_{i}", {
} "user_id": u["id"],
futures.append(executor.submit(update_user, HttpApiAuth, payload)) "nickname": f"updated_user_{i}",
},
)
for i, u in enumerate(users)
]
responses: list = list(as_completed(futures)) for future in as_completed(futures):
assert len(responses) == 5 res: dict[str, Any] = future.result()
assert all(future.result()["code"] == 0 for future in futures) assert res["code"] == 0
@pytest.mark.p3 @pytest.mark.p3
def test_update_same_user_multiple_times(self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict) -> None: def test_update_same_user_multiple_times(
"""Test updating the same user multiple times""" self, HttpApiAuth: RAGFlowHttpApiAuth, test_user: dict[str, Any]
# First update ) -> None:
payload1: dict = { """Test repeated updates on the same user."""
"user_id": test_user["user_id"], for nickname in (
"nickname": "first_update", "first_update",
} "second_update",
res1: dict = update_user(HttpApiAuth, payload1) "third_update",
assert res1["code"] == 0 ):
assert res1["data"]["nickname"] == "first_update" payload: dict[str, str] = {
"user_id": test_user["user_id"],
# Second update "nickname": nickname,
payload2: dict = { }
"user_id": test_user["user_id"], res: dict[str, Any] = update_user(HttpApiAuth, payload)
"nickname": "second_update", assert res["code"] == 0
} assert res["data"]["nickname"] == nickname
res2: dict = update_user(HttpApiAuth, payload2)
assert res2["code"] == 0
assert res2["data"]["nickname"] == "second_update"
# Third update
payload3: dict = {
"user_id": test_user["user_id"],
"nickname": "third_update",
}
res3: dict = update_user(HttpApiAuth, payload3)
assert res3["code"] == 0
assert res3["data"]["nickname"] == "third_update"