[OND211-2329]: Added create department API and tests.
This commit is contained in:
parent
043b06a24d
commit
00b476783c
3 changed files with 511 additions and 5 deletions
|
|
@ -13,19 +13,33 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask import Response, request
|
||||
from flask_login import current_user, login_required
|
||||
|
||||
from api.apps import smtp_mail_server
|
||||
from api.db import UserTenantRole
|
||||
from api.db import FileType, UserTenantRole
|
||||
from api.db.db_models import UserTenant
|
||||
from api.db.services.user_service import UserTenantService, UserService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import get_init_tenant_llm
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import (
|
||||
TenantService,
|
||||
UserService,
|
||||
UserTenantService,
|
||||
)
|
||||
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import delta_seconds
|
||||
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result
|
||||
from api.utils.api_utils import (
|
||||
get_data_error_result,
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
)
|
||||
from api.utils.web_utils import send_invite_email
|
||||
from common import settings
|
||||
|
||||
|
|
@ -119,6 +133,234 @@ def rm(tenant_id, user_id):
|
|||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/create", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("name")
|
||||
def create_team() -> Response:
|
||||
"""
|
||||
Create a new team (tenant). Requires authentication - any registered user can create a team.
|
||||
|
||||
---
|
||||
tags:
|
||||
- Team
|
||||
security:
|
||||
- ApiKeyAuth: []
|
||||
parameters:
|
||||
- in: body
|
||||
name: body
|
||||
description: Team creation details.
|
||||
required: true
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: Team name.
|
||||
user_id:
|
||||
type: string
|
||||
description: User ID to set as team owner (optional, defaults to
|
||||
current authenticated user).
|
||||
llm_id:
|
||||
type: string
|
||||
description: LLM model ID (optional, defaults to system default).
|
||||
embd_id:
|
||||
type: string
|
||||
description: Embedding model ID (optional, defaults to system default).
|
||||
asr_id:
|
||||
type: string
|
||||
description: ASR model ID (optional, defaults to system default).
|
||||
parser_ids:
|
||||
type: string
|
||||
description: Document parser IDs (optional, defaults to system default).
|
||||
img2txt_id:
|
||||
type: string
|
||||
description: Image-to-text model ID (optional, defaults to system default).
|
||||
rerank_id:
|
||||
type: string
|
||||
description: Rerank model ID (optional, defaults to system default).
|
||||
credit:
|
||||
type: integer
|
||||
description: Initial credit amount (optional, defaults to 512).
|
||||
responses:
|
||||
200:
|
||||
description: Team created successfully.
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: object
|
||||
description: Created team information.
|
||||
message:
|
||||
type: string
|
||||
description: Success message.
|
||||
401:
|
||||
description: Unauthorized - authentication required.
|
||||
schema:
|
||||
type: object
|
||||
400:
|
||||
description: Invalid request or user not found.
|
||||
schema:
|
||||
type: object
|
||||
500:
|
||||
description: Server error during team creation.
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
# Explicitly check authentication status
|
||||
if not current_user.is_authenticated:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message="Unauthorized",
|
||||
code=RetCode.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
if request.json is None:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message="Request body is required!",
|
||||
code=RetCode.ARGUMENT_ERROR,
|
||||
)
|
||||
|
||||
req: Dict[str, Any] = request.json
|
||||
team_name: str = req.get("name", "").strip()
|
||||
user_id: Optional[str] = req.get("user_id")
|
||||
|
||||
# Optional configuration parameters (use defaults from settings if not provided)
|
||||
llm_id: Optional[str] = req.get("llm_id")
|
||||
embd_id: Optional[str] = req.get("embd_id")
|
||||
asr_id: Optional[str] = req.get("asr_id")
|
||||
parser_ids: Optional[str] = req.get("parser_ids")
|
||||
img2txt_id: Optional[str] = req.get("img2txt_id")
|
||||
rerank_id: Optional[str] = req.get("rerank_id")
|
||||
credit: Optional[int] = req.get("credit")
|
||||
|
||||
# Validate team name
|
||||
if not team_name:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message="Team name is required!",
|
||||
code=RetCode.ARGUMENT_ERROR,
|
||||
)
|
||||
|
||||
if len(team_name) > 100:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message="Team name must be 100 characters or less!",
|
||||
code=RetCode.ARGUMENT_ERROR,
|
||||
)
|
||||
|
||||
# Determine user_id (use provided or current_user as default)
|
||||
owner_user_id: Optional[str] = user_id
|
||||
if not owner_user_id:
|
||||
# Use current authenticated user as default
|
||||
owner_user_id = current_user.id
|
||||
|
||||
# Verify user exists
|
||||
user: Optional[Any] = UserService.filter_by_id(owner_user_id)
|
||||
if not user:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message=f"User with ID {owner_user_id} not found!",
|
||||
code=RetCode.DATA_ERROR,
|
||||
)
|
||||
|
||||
# Generate tenant ID
|
||||
tenant_id: str = get_uuid()
|
||||
|
||||
# Create tenant with optional parameters (use defaults from settings if not provided)
|
||||
tenant: Dict[str, Any] = {
|
||||
"id": tenant_id,
|
||||
"name": team_name,
|
||||
"llm_id": llm_id if llm_id is not None else settings.CHAT_MDL,
|
||||
"embd_id": embd_id if embd_id is not None else settings.EMBEDDING_MDL,
|
||||
"asr_id": asr_id if asr_id is not None else settings.ASR_MDL,
|
||||
"parser_ids": parser_ids if parser_ids is not None else settings.PARSERS,
|
||||
"img2txt_id": img2txt_id if img2txt_id is not None else settings.IMAGE2TEXT_MDL,
|
||||
"rerank_id": rerank_id if rerank_id is not None else settings.RERANK_MDL,
|
||||
"credit": credit if credit is not None else 512,
|
||||
"status": StatusEnum.VALID.value,
|
||||
}
|
||||
|
||||
# Create user-tenant relationship
|
||||
usr_tenant: Dict[str, Any] = {
|
||||
"id": get_uuid(),
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": owner_user_id,
|
||||
"invited_by": owner_user_id,
|
||||
"role": UserTenantRole.OWNER,
|
||||
"status": StatusEnum.VALID.value,
|
||||
}
|
||||
|
||||
# Create root file folder
|
||||
file_id: str = get_uuid()
|
||||
file: Dict[str, Any] = {
|
||||
"id": file_id,
|
||||
"parent_id": file_id,
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": owner_user_id,
|
||||
"name": "/",
|
||||
"type": FileType.FOLDER.value,
|
||||
"size": 0,
|
||||
"location": "",
|
||||
}
|
||||
|
||||
try:
|
||||
# Get tenant LLM configurations
|
||||
tenant_llm: list[Dict[str, Any]] = get_init_tenant_llm(tenant_id)
|
||||
|
||||
# Insert all records
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
FileService.insert(file)
|
||||
|
||||
# Return created team info
|
||||
team_data: Dict[str, Any] = {
|
||||
"id": tenant_id,
|
||||
"name": team_name,
|
||||
"owner_id": owner_user_id,
|
||||
"llm_id": tenant["llm_id"],
|
||||
"embd_id": tenant["embd_id"],
|
||||
}
|
||||
|
||||
return get_json_result(
|
||||
data=team_data,
|
||||
message=f"Team '{team_name}' created successfully!",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
# Rollback on error
|
||||
try:
|
||||
TenantService.delete_by_id(tenant_id)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
UserTenantService.filter_delete(
|
||||
[
|
||||
UserTenant.tenant_id == tenant_id,
|
||||
UserTenant.user_id == owner_user_id,
|
||||
]
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
TenantLLMService.delete_by_tenant_id(tenant_id)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
FileService.delete_by_id(file_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message=f"Team creation failure, error: {str(e)}",
|
||||
code=RetCode.EXCEPTION_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def tenant_list():
|
||||
|
|
|
|||
|
|
@ -275,3 +275,13 @@ def delete_user(auth, payload=None, *, headers=HEADERS):
|
|||
url = f"{HOST_ADDRESS}{USER_API_URL}/delete"
|
||||
res = requests.delete(url=url, headers=headers, auth=auth, json=payload)
|
||||
return res.json()
|
||||
|
||||
|
||||
# TEAM MANAGEMENT
|
||||
TEAM_API_URL = f"/{VERSION}/tenant"
|
||||
|
||||
|
||||
def create_team(auth, payload=None, *, headers=HEADERS):
|
||||
url = f"{HOST_ADDRESS}{TEAM_API_URL}/create"
|
||||
res = requests.post(url=url, headers=headers, auth=auth, json=payload)
|
||||
return res.json()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,254 @@
|
|||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from common import create_team
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test Classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
class TestAuthorization:
|
||||
"""Tests for authentication behavior during team creation."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invalid_auth", "expected_code", "expected_message"),
|
||||
[
|
||||
# Endpoint now requires @login_required (JWT token auth)
|
||||
(None, 401, "Unauthorized"),
|
||||
(RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "Unauthorized"),
|
||||
],
|
||||
)
|
||||
def test_invalid_auth(
|
||||
self,
|
||||
invalid_auth: RAGFlowWebApiAuth | None,
|
||||
expected_code: int,
|
||||
expected_message: str,
|
||||
WebApiAuth: RAGFlowWebApiAuth,
|
||||
) -> None:
|
||||
"""Test team creation with invalid or missing authentication."""
|
||||
# Try to create team with invalid auth
|
||||
team_payload: dict[str, str] = {
|
||||
"name": "Test Team Auth",
|
||||
}
|
||||
res: dict[str, Any] = create_team(invalid_auth, team_payload)
|
||||
assert res["code"] == expected_code, res
|
||||
if expected_message:
|
||||
assert expected_message in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
class TestTeamCreate:
|
||||
"""Comprehensive tests for team creation API."""
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_create_team_with_name_and_user_id(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test creating a team with name and user_id."""
|
||||
# Create team (user_id is optional, defaults to current authenticated user)
|
||||
team_name: str = f"Test Team {uuid.uuid4().hex[:8]}"
|
||||
team_payload: dict[str, str] = {
|
||||
"name": team_name,
|
||||
}
|
||||
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
|
||||
assert res["code"] == 0, res
|
||||
assert "data" in res
|
||||
assert res["data"]["name"] == team_name
|
||||
assert "owner_id" in res["data"]
|
||||
assert "id" in res["data"]
|
||||
assert "deleted successfully" not in res["message"].lower()
|
||||
assert "created successfully" in res["message"].lower()
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_create_team_missing_name(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test creating a team without name."""
|
||||
# Try to create team without name
|
||||
team_payload: dict[str, str] = {}
|
||||
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
|
||||
assert res["code"] == 101
|
||||
assert "name" in res["message"].lower() or "required" in res[
|
||||
"message"
|
||||
].lower()
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_create_team_empty_name(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test creating a team with empty name."""
|
||||
# Try to create team with empty name
|
||||
team_payload: dict[str, str] = {"name": ""}
|
||||
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
|
||||
assert res["code"] == 101
|
||||
assert "name" in res["message"].lower() or "required" in res[
|
||||
"message"
|
||||
].lower()
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_create_team_name_too_long(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test creating a team with name exceeding 100 characters."""
|
||||
# Try to create team with name too long
|
||||
long_name: str = "A" * 101
|
||||
team_payload: dict[str, str] = {"name": long_name}
|
||||
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
|
||||
assert res["code"] == 101
|
||||
assert "100" in res["message"] or "length" in res["message"].lower()
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_create_team_invalid_user_id(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test creating a team with non-existent user_id."""
|
||||
team_payload: dict[str, str] = {
|
||||
"name": "Test Team Invalid User",
|
||||
"user_id": "non_existent_user_id_12345",
|
||||
}
|
||||
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
|
||||
assert res["code"] == 102
|
||||
assert "not found" in res["message"].lower()
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_create_team_missing_user_id(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test creating a team without user_id (should use current authenticated user)."""
|
||||
team_payload: dict[str, str] = {"name": "Test Team No User"}
|
||||
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
|
||||
# Should succeed since user_id defaults to current authenticated user
|
||||
assert res["code"] == 0
|
||||
assert "data" in res
|
||||
assert "owner_id" in res["data"]
|
||||
assert "created successfully" in res["message"].lower()
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_create_team_response_structure(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test that team creation returns the expected response structure."""
|
||||
# Create team
|
||||
team_name: str = f"Test Team Structure {uuid.uuid4().hex[:8]}"
|
||||
team_payload: dict[str, str] = {
|
||||
"name": team_name,
|
||||
}
|
||||
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
|
||||
assert res["code"] == 0
|
||||
assert "data" in res
|
||||
assert isinstance(res["data"], dict)
|
||||
assert "id" in res["data"]
|
||||
assert "name" in res["data"]
|
||||
assert "owner_id" in res["data"]
|
||||
assert res["data"]["name"] == team_name
|
||||
assert "message" in res
|
||||
assert "created successfully" in res["message"].lower()
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_create_multiple_teams_same_user(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test creating multiple teams for the same user."""
|
||||
# Create first team
|
||||
team_name_1: str = f"Team 1 {uuid.uuid4().hex[:8]}"
|
||||
team_payload_1: dict[str, str] = {
|
||||
"name": team_name_1,
|
||||
}
|
||||
res1: dict[str, Any] = create_team(WebApiAuth, team_payload_1)
|
||||
assert res1["code"] == 0, res1
|
||||
team_id_1: str = res1["data"]["id"]
|
||||
|
||||
# Create second team
|
||||
team_name_2: str = f"Team 2 {uuid.uuid4().hex[:8]}"
|
||||
team_payload_2: dict[str, str] = {
|
||||
"name": team_name_2,
|
||||
}
|
||||
res2: dict[str, Any] = create_team(WebApiAuth, team_payload_2)
|
||||
assert res2["code"] == 0, res2
|
||||
team_id_2: str = res2["data"]["id"]
|
||||
|
||||
# Verify teams are different
|
||||
assert team_id_1 != team_id_2
|
||||
assert res1["data"]["name"] == team_name_1
|
||||
assert res2["data"]["name"] == team_name_2
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_create_team_with_whitespace_name(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test creating a team with whitespace-only name."""
|
||||
# Try to create team with whitespace-only name
|
||||
team_payload: dict[str, str] = {
|
||||
"name": " ",
|
||||
}
|
||||
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
|
||||
# Should fail validation
|
||||
assert res["code"] == 101
|
||||
assert "name" in res["message"].lower() or "required" in res[
|
||||
"message"
|
||||
].lower()
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_create_team_special_characters_in_name(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test creating a team with special characters in name."""
|
||||
# Create team with special characters
|
||||
team_name: str = f"Team-{uuid.uuid4().hex[:8]}_Test!"
|
||||
team_payload: dict[str, str] = {
|
||||
"name": team_name,
|
||||
}
|
||||
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
|
||||
# Should succeed if special chars are allowed
|
||||
assert res["code"] in (0, 101)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_create_team_empty_payload(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test creating a team with empty payload."""
|
||||
team_payload: dict[str, Any] = {}
|
||||
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
|
||||
assert res["code"] == 101
|
||||
assert "required" in res["message"].lower() or "name" in res[
|
||||
"message"
|
||||
].lower()
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_create_team_unicode_name(
|
||||
self, WebApiAuth: RAGFlowWebApiAuth
|
||||
) -> None:
|
||||
"""Test creating a team with unicode characters in name."""
|
||||
# Create team with unicode name
|
||||
team_name: str = f"团队{uuid.uuid4().hex[:8]}"
|
||||
team_payload: dict[str, str] = {
|
||||
"name": team_name,
|
||||
}
|
||||
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
|
||||
# Should succeed if unicode is supported
|
||||
assert res["code"] in (0, 101)
|
||||
|
||||
Loading…
Add table
Reference in a new issue