[OND211-2329]: Added create department API and tests.

This commit is contained in:
Hetavi Shah 2025-11-12 19:01:27 +05:30
parent 043b06a24d
commit 00b476783c
3 changed files with 511 additions and 5 deletions

View file

@ -13,19 +13,33 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import logging
from typing import Any, Dict, Optional
from flask import request from flask import Response, request
from flask_login import login_required, current_user from flask_login import current_user, login_required
from api.apps import smtp_mail_server from api.apps import smtp_mail_server
from api.db import UserTenantRole from api.db import FileType, UserTenantRole
from api.db.db_models import UserTenant from api.db.db_models import UserTenant
from api.db.services.user_service import UserTenantService, UserService from api.db.services.file_service import FileService
from api.db.services.llm_service import get_init_tenant_llm
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import (
TenantService,
UserService,
UserTenantService,
)
from common.constants import RetCode, StatusEnum from common.constants import RetCode, StatusEnum
from common.misc_utils import get_uuid from common.misc_utils import get_uuid
from common.time_utils import delta_seconds from common.time_utils import delta_seconds
from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result from api.utils.api_utils import (
get_data_error_result,
get_json_result,
server_error_response,
validate_request,
)
from api.utils.web_utils import send_invite_email from api.utils.web_utils import send_invite_email
from common import settings from common import settings
@ -119,6 +133,234 @@ def rm(tenant_id, user_id):
return server_error_response(e) return server_error_response(e)
@manager.route("/create", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name")
def create_team() -> Response:
"""
Create a new team (tenant). Requires authentication - any registered user can create a team.
---
tags:
- Team
security:
- ApiKeyAuth: []
parameters:
- in: body
name: body
description: Team creation details.
required: true
schema:
type: object
required:
- name
properties:
name:
type: string
description: Team name.
user_id:
type: string
description: User ID to set as team owner (optional, defaults to
current authenticated user).
llm_id:
type: string
description: LLM model ID (optional, defaults to system default).
embd_id:
type: string
description: Embedding model ID (optional, defaults to system default).
asr_id:
type: string
description: ASR model ID (optional, defaults to system default).
parser_ids:
type: string
description: Document parser IDs (optional, defaults to system default).
img2txt_id:
type: string
description: Image-to-text model ID (optional, defaults to system default).
rerank_id:
type: string
description: Rerank model ID (optional, defaults to system default).
credit:
type: integer
description: Initial credit amount (optional, defaults to 512).
responses:
200:
description: Team created successfully.
schema:
type: object
properties:
data:
type: object
description: Created team information.
message:
type: string
description: Success message.
401:
description: Unauthorized - authentication required.
schema:
type: object
400:
description: Invalid request or user not found.
schema:
type: object
500:
description: Server error during team creation.
schema:
type: object
"""
# Explicitly check authentication status
if not current_user.is_authenticated:
return get_json_result(
data=False,
message="Unauthorized",
code=RetCode.UNAUTHORIZED,
)
if request.json is None:
return get_json_result(
data=False,
message="Request body is required!",
code=RetCode.ARGUMENT_ERROR,
)
req: Dict[str, Any] = request.json
team_name: str = req.get("name", "").strip()
user_id: Optional[str] = req.get("user_id")
# Optional configuration parameters (use defaults from settings if not provided)
llm_id: Optional[str] = req.get("llm_id")
embd_id: Optional[str] = req.get("embd_id")
asr_id: Optional[str] = req.get("asr_id")
parser_ids: Optional[str] = req.get("parser_ids")
img2txt_id: Optional[str] = req.get("img2txt_id")
rerank_id: Optional[str] = req.get("rerank_id")
credit: Optional[int] = req.get("credit")
# Validate team name
if not team_name:
return get_json_result(
data=False,
message="Team name is required!",
code=RetCode.ARGUMENT_ERROR,
)
if len(team_name) > 100:
return get_json_result(
data=False,
message="Team name must be 100 characters or less!",
code=RetCode.ARGUMENT_ERROR,
)
# Determine user_id (use provided or current_user as default)
owner_user_id: Optional[str] = user_id
if not owner_user_id:
# Use current authenticated user as default
owner_user_id = current_user.id
# Verify user exists
user: Optional[Any] = UserService.filter_by_id(owner_user_id)
if not user:
return get_json_result(
data=False,
message=f"User with ID {owner_user_id} not found!",
code=RetCode.DATA_ERROR,
)
# Generate tenant ID
tenant_id: str = get_uuid()
# Create tenant with optional parameters (use defaults from settings if not provided)
tenant: Dict[str, Any] = {
"id": tenant_id,
"name": team_name,
"llm_id": llm_id if llm_id is not None else settings.CHAT_MDL,
"embd_id": embd_id if embd_id is not None else settings.EMBEDDING_MDL,
"asr_id": asr_id if asr_id is not None else settings.ASR_MDL,
"parser_ids": parser_ids if parser_ids is not None else settings.PARSERS,
"img2txt_id": img2txt_id if img2txt_id is not None else settings.IMAGE2TEXT_MDL,
"rerank_id": rerank_id if rerank_id is not None else settings.RERANK_MDL,
"credit": credit if credit is not None else 512,
"status": StatusEnum.VALID.value,
}
# Create user-tenant relationship
usr_tenant: Dict[str, Any] = {
"id": get_uuid(),
"tenant_id": tenant_id,
"user_id": owner_user_id,
"invited_by": owner_user_id,
"role": UserTenantRole.OWNER,
"status": StatusEnum.VALID.value,
}
# Create root file folder
file_id: str = get_uuid()
file: Dict[str, Any] = {
"id": file_id,
"parent_id": file_id,
"tenant_id": tenant_id,
"created_by": owner_user_id,
"name": "/",
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
}
try:
# Get tenant LLM configurations
tenant_llm: list[Dict[str, Any]] = get_init_tenant_llm(tenant_id)
# Insert all records
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
FileService.insert(file)
# Return created team info
team_data: Dict[str, Any] = {
"id": tenant_id,
"name": team_name,
"owner_id": owner_user_id,
"llm_id": tenant["llm_id"],
"embd_id": tenant["embd_id"],
}
return get_json_result(
data=team_data,
message=f"Team '{team_name}' created successfully!",
)
except Exception as e:
logging.exception(e)
# Rollback on error
try:
TenantService.delete_by_id(tenant_id)
except Exception:
pass
try:
UserTenantService.filter_delete(
[
UserTenant.tenant_id == tenant_id,
UserTenant.user_id == owner_user_id,
]
)
except Exception:
pass
try:
TenantLLMService.delete_by_tenant_id(tenant_id)
except Exception:
pass
try:
FileService.delete_by_id(file_id)
except Exception:
pass
return get_json_result(
data=False,
message=f"Team creation failure, error: {str(e)}",
code=RetCode.EXCEPTION_ERROR,
)
@manager.route("/list", methods=["GET"]) # noqa: F821 @manager.route("/list", methods=["GET"]) # noqa: F821
@login_required @login_required
def tenant_list(): def tenant_list():

View file

@ -275,3 +275,13 @@ def delete_user(auth, payload=None, *, headers=HEADERS):
url = f"{HOST_ADDRESS}{USER_API_URL}/delete" url = f"{HOST_ADDRESS}{USER_API_URL}/delete"
res = requests.delete(url=url, headers=headers, auth=auth, json=payload) res = requests.delete(url=url, headers=headers, auth=auth, json=payload)
return res.json() return res.json()
# TEAM MANAGEMENT
TEAM_API_URL = f"/{VERSION}/tenant"
def create_team(auth, payload=None, *, headers=HEADERS):
url = f"{HOST_ADDRESS}{TEAM_API_URL}/create"
res = requests.post(url=url, headers=headers, auth=auth, json=payload)
return res.json()

View file

@ -0,0 +1,254 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
import uuid
from typing import Any
import pytest
from common import create_team
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
# ---------------------------------------------------------------------------
# Test Classes
# ---------------------------------------------------------------------------
@pytest.mark.p1
class TestAuthorization:
"""Tests for authentication behavior during team creation."""
@pytest.mark.parametrize(
("invalid_auth", "expected_code", "expected_message"),
[
# Endpoint now requires @login_required (JWT token auth)
(None, 401, "Unauthorized"),
(RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "Unauthorized"),
],
)
def test_invalid_auth(
self,
invalid_auth: RAGFlowWebApiAuth | None,
expected_code: int,
expected_message: str,
WebApiAuth: RAGFlowWebApiAuth,
) -> None:
"""Test team creation with invalid or missing authentication."""
# Try to create team with invalid auth
team_payload: dict[str, str] = {
"name": "Test Team Auth",
}
res: dict[str, Any] = create_team(invalid_auth, team_payload)
assert res["code"] == expected_code, res
if expected_message:
assert expected_message in res["message"]
@pytest.mark.p1
class TestTeamCreate:
"""Comprehensive tests for team creation API."""
@pytest.mark.p1
def test_create_team_with_name_and_user_id(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test creating a team with name and user_id."""
# Create team (user_id is optional, defaults to current authenticated user)
team_name: str = f"Test Team {uuid.uuid4().hex[:8]}"
team_payload: dict[str, str] = {
"name": team_name,
}
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
assert res["code"] == 0, res
assert "data" in res
assert res["data"]["name"] == team_name
assert "owner_id" in res["data"]
assert "id" in res["data"]
assert "deleted successfully" not in res["message"].lower()
assert "created successfully" in res["message"].lower()
@pytest.mark.p1
def test_create_team_missing_name(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test creating a team without name."""
# Try to create team without name
team_payload: dict[str, str] = {}
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
assert res["code"] == 101
assert "name" in res["message"].lower() or "required" in res[
"message"
].lower()
@pytest.mark.p1
def test_create_team_empty_name(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test creating a team with empty name."""
# Try to create team with empty name
team_payload: dict[str, str] = {"name": ""}
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
assert res["code"] == 101
assert "name" in res["message"].lower() or "required" in res[
"message"
].lower()
@pytest.mark.p1
def test_create_team_name_too_long(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test creating a team with name exceeding 100 characters."""
# Try to create team with name too long
long_name: str = "A" * 101
team_payload: dict[str, str] = {"name": long_name}
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
assert res["code"] == 101
assert "100" in res["message"] or "length" in res["message"].lower()
@pytest.mark.p1
def test_create_team_invalid_user_id(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test creating a team with non-existent user_id."""
team_payload: dict[str, str] = {
"name": "Test Team Invalid User",
"user_id": "non_existent_user_id_12345",
}
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
assert res["code"] == 102
assert "not found" in res["message"].lower()
@pytest.mark.p1
def test_create_team_missing_user_id(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test creating a team without user_id (should use current authenticated user)."""
team_payload: dict[str, str] = {"name": "Test Team No User"}
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
# Should succeed since user_id defaults to current authenticated user
assert res["code"] == 0
assert "data" in res
assert "owner_id" in res["data"]
assert "created successfully" in res["message"].lower()
@pytest.mark.p1
def test_create_team_response_structure(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test that team creation returns the expected response structure."""
# Create team
team_name: str = f"Test Team Structure {uuid.uuid4().hex[:8]}"
team_payload: dict[str, str] = {
"name": team_name,
}
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
assert res["code"] == 0
assert "data" in res
assert isinstance(res["data"], dict)
assert "id" in res["data"]
assert "name" in res["data"]
assert "owner_id" in res["data"]
assert res["data"]["name"] == team_name
assert "message" in res
assert "created successfully" in res["message"].lower()
@pytest.mark.p1
def test_create_multiple_teams_same_user(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test creating multiple teams for the same user."""
# Create first team
team_name_1: str = f"Team 1 {uuid.uuid4().hex[:8]}"
team_payload_1: dict[str, str] = {
"name": team_name_1,
}
res1: dict[str, Any] = create_team(WebApiAuth, team_payload_1)
assert res1["code"] == 0, res1
team_id_1: str = res1["data"]["id"]
# Create second team
team_name_2: str = f"Team 2 {uuid.uuid4().hex[:8]}"
team_payload_2: dict[str, str] = {
"name": team_name_2,
}
res2: dict[str, Any] = create_team(WebApiAuth, team_payload_2)
assert res2["code"] == 0, res2
team_id_2: str = res2["data"]["id"]
# Verify teams are different
assert team_id_1 != team_id_2
assert res1["data"]["name"] == team_name_1
assert res2["data"]["name"] == team_name_2
@pytest.mark.p2
def test_create_team_with_whitespace_name(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test creating a team with whitespace-only name."""
# Try to create team with whitespace-only name
team_payload: dict[str, str] = {
"name": " ",
}
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
# Should fail validation
assert res["code"] == 101
assert "name" in res["message"].lower() or "required" in res[
"message"
].lower()
@pytest.mark.p2
def test_create_team_special_characters_in_name(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test creating a team with special characters in name."""
# Create team with special characters
team_name: str = f"Team-{uuid.uuid4().hex[:8]}_Test!"
team_payload: dict[str, str] = {
"name": team_name,
}
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
# Should succeed if special chars are allowed
assert res["code"] in (0, 101)
@pytest.mark.p2
def test_create_team_empty_payload(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test creating a team with empty payload."""
team_payload: dict[str, Any] = {}
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
assert res["code"] == 101
assert "required" in res["message"].lower() or "name" in res[
"message"
].lower()
@pytest.mark.p3
def test_create_team_unicode_name(
self, WebApiAuth: RAGFlowWebApiAuth
) -> None:
"""Test creating a team with unicode characters in name."""
# Create team with unicode name
team_name: str = f"团队{uuid.uuid4().hex[:8]}"
team_payload: dict[str, str] = {
"name": team_name,
}
res: dict[str, Any] = create_team(WebApiAuth, team_payload)
# Should succeed if unicode is supported
assert res["code"] in (0, 101)