diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 86ffaedb1..b7d206ab0 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -18,6 +18,7 @@ import logging import re import sys from functools import partial +from typing import Dict, Any, Optional import trio from quart import request, Response, make_response from agent.component import LLM @@ -44,7 +45,8 @@ from rag.nlp import search from rag.utils.redis_conn import REDIS_CONN from common import settings from api.apps import login_required, current_user - +from api.db.services.user_service import UserTenantService +from common.constants import StatusEnum @manager.route('/templates', methods=['GET']) # noqa: F821 @login_required @@ -69,11 +71,31 @@ async def rm(): @manager.route('/set', methods=['POST']) # noqa: F821 @validate_request("dsl", "title") @login_required -async def save(): - req = await request_json() +async def save() -> Any: + req: Dict[str, Any] = await request_json() if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) req["dsl"] = json.loads(req["dsl"]) + + # Validate shared_tenant_id if provided + shared_tenant_id: Optional[str] = req.get("shared_tenant_id") + if shared_tenant_id: + if req.get("permission") != "team": + return get_json_result( + data=False, + message="shared_tenant_id can only be set when permission is 'team'", + code=RetCode.ARGUMENT_ERROR + ) + # Verify user is a member of the shared tenant + + user_tenant = UserTenantService.filter_by_tenant_and_user_id(shared_tenant_id, current_user.id) + if not user_tenant or user_tenant.status != StatusEnum.VALID.value: + return get_json_result( + data=False, + message=f"You are not a member of the selected team", + code=RetCode.PERMISSION_ERROR + ) + cate = req.get("canvas_category", CanvasCategory.Agent) if "id" not in req: req["user_id"] = current_user.id diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 4e8015d7f..0375601ab 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -17,6 +17,7 @@ import json import logging import random import re +from typing import Dict, Any, Optional from quart import request import numpy as np @@ -42,13 +43,35 @@ from rag.utils.doc_store_conn import OrderByExpr from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD from common import settings from api.apps import login_required, current_user +from common.constants import StatusEnum @manager.route('/create', methods=['post']) # noqa: F821 @login_required @validate_request("name") -async def create(): - req = await request_json() +async def create() -> Any: + req: Dict[str, Any] = await request_json() + + # Validate shared_tenant_id if provided + shared_tenant_id: Optional[str] = req.get("shared_tenant_id") + if shared_tenant_id: + if req.get("permission") != "team": + return get_json_result( + data=False, + message="shared_tenant_id can only be set when permission is 'team'", + code=RetCode.ARGUMENT_ERROR + ) + # Verify user is a member of the shared tenant + user_tenant = UserTenantService.filter_by_tenant_and_user_id(shared_tenant_id, current_user.id) + if not user_tenant or user_tenant.status != StatusEnum.VALID.value: + return get_json_result( + data=False, + message=f"You are not a member of the selected team", + code=RetCode.PERMISSION_ERROR + ) + + e: bool + res: Any e, res = KnowledgebaseService.create_with_name( name = req.pop("name", None), tenant_id = current_user.id, diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index ff860324c..487a7e4fc 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -18,6 +18,7 @@ import logging import os import json +from typing import Dict, Any, Optional from quart import request from peewee import OperationalError from api.db.db_models import File @@ -50,11 +51,12 @@ from api.utils.validation_utils import ( from rag.nlp import search from common.constants import PAGERANK_FLD from common import settings - +from api.db.services.user_service import UserTenantService +from common.constants import StatusEnum @manager.route("/datasets", methods=["POST"]) # noqa: F821 @token_required -async def create(tenant_id): +async def create(tenant_id: str) -> Any: """ Create a new dataset. --- @@ -116,9 +118,24 @@ async def create(tenant_id): # | embedding_model| embd_id | # | chunk_method | parser_id | + req: Dict[str, Any] + err: Optional[Any] req, err = await validate_and_parse_json_request(request, CreateDatasetReq) if err is not None: return get_error_argument_result(err) + + # Validate shared_tenant_id if provided + shared_tenant_id: Optional[str] = req.get("shared_tenant_id") + if shared_tenant_id: + if req.get("permission") != "team": + return get_error_argument_result("shared_tenant_id can only be set when permission is 'team'") + # Verify user is a member of the shared tenant + + user_tenant = UserTenantService.filter_by_tenant_and_user_id(shared_tenant_id, tenant_id) + if not user_tenant or user_tenant.status != StatusEnum.VALID.value: + return get_error_permission_result(message=f"User is not a member of tenant '{shared_tenant_id}'") + + e: bool e, req = KnowledgebaseService.create_with_name( name = req.pop("name", None), tenant_id = tenant_id, @@ -130,6 +147,8 @@ async def create(tenant_id): return req # Insert embedding model(embd id) + ok: bool + t: Any ok, t = TenantService.get_by_id(tenant_id) if not ok: return get_error_permission_result(message="Tenant not found") diff --git a/api/common/check_team_permission.py b/api/common/check_team_permission.py index c8e04d34b..ff5a2e761 100644 --- a/api/common/check_team_permission.py +++ b/api/common/check_team_permission.py @@ -14,43 +14,57 @@ # limitations under the License. # +from typing import Dict, Any, Optional, List from api.db import TenantPermission from api.db.db_models import File, Knowledgebase from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.user_service import TenantService +from api.db.services.user_service import TenantService, UserTenantService +from common.constants import StatusEnum -def check_kb_team_permission(kb: dict | Knowledgebase, other: str) -> bool: +def check_kb_team_permission(kb: Dict[str, Any] | Knowledgebase, other: str) -> bool: kb = kb.to_dict() if isinstance(kb, Knowledgebase) else kb kb_tenant_id = kb["tenant_id"] + # If user owns the tenant where the KB was created, always allow if kb_tenant_id == other: return True + # If permission is not "team", deny access if kb["permission"] != TenantPermission.TEAM: return False - joined_tenants = TenantService.get_joined_tenants_by_user_id(other) + # If shared_tenant_id is specified, check if user is a member of that specific tenant + shared_tenant_id: Optional[str] = kb.get("shared_tenant_id") + if shared_tenant_id: + # Check if user is a member of the shared tenant + user_tenant = UserTenantService.filter_by_tenant_and_user_id(shared_tenant_id, other) + return user_tenant is not None and user_tenant.status == StatusEnum.VALID.value + + # Legacy behavior: if no shared_tenant_id, check if user is a member of the KB's tenant + joined_tenants: List[Dict[str, Any]] = TenantService.get_joined_tenants_by_user_id(other) return any(tenant["tenant_id"] == kb_tenant_id for tenant in joined_tenants) -def check_file_team_permission(file: dict | File, other: str) -> bool: +def check_file_team_permission(file: Dict[str, Any] | File, other: str) -> bool: file = file.to_dict() if isinstance(file, File) else file file_tenant_id = file["tenant_id"] if file_tenant_id == other: return True - file_id = file["id"] + file_id: str = file["id"] - kb_ids = [kb_info["kb_id"] for kb_info in FileService.get_kb_id_by_file_id(file_id)] + kb_ids: List[str] = [kb_info["kb_id"] for kb_info in FileService.get_kb_id_by_file_id(file_id)] for kb_id in kb_ids: + ok: bool + kb: Optional[Knowledgebase] ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok: + if not ok or kb is None: continue if check_kb_team_permission(kb, other): diff --git a/api/db/db_models.py b/api/db/db_models.py index bd3feea64..17677a2c3 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -740,6 +740,7 @@ class Knowledgebase(DataBaseModel): description = TextField(null=True, help_text="KB description") embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID", index=True) permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True) + shared_tenant_id = CharField(max_length=32, null=True, help_text="Specific tenant ID to share with when permission is 'team'", index=True) created_by = CharField(max_length=32, null=False, index=True) doc_num = IntegerField(default=0, index=True) token_num = IntegerField(default=0, index=True) @@ -923,6 +924,7 @@ class UserCanvas(DataBaseModel): title = CharField(max_length=255, null=True, help_text="Canvas title") permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True) + shared_tenant_id = CharField(max_length=32, null=True, help_text="Specific tenant ID to share with when permission is 'team'", index=True) description = TextField(null=True, help_text="Canvas description") canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True) canvas_category = CharField(max_length=32, null=False, default="agent_canvas", help_text="Canvas category: agent_canvas|dataflow_canvas", index=True) @@ -1201,6 +1203,14 @@ def migrate_db(): migrate(migrator.add_column("user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True))) except Exception: pass + try: + migrate(migrator.add_column("knowledgebase", "shared_tenant_id", CharField(max_length=32, null=True, help_text="Specific tenant ID to share with when permission is 'team'", index=True))) + except Exception: + pass + try: + migrate(migrator.add_column("user_canvas", "shared_tenant_id", CharField(max_length=32, null=True, help_text="Specific tenant ID to share with when permission is 'team'", index=True))) + except Exception: + pass try: migrate(migrator.add_column("llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False))) except Exception: diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 57b4b5c2a..dd72152df 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -16,6 +16,7 @@ import json import logging import time +from typing import List, Dict, Any, Optional, Tuple from uuid import uuid4 from agent.canvas import Canvas from api.db import CanvasCategory, TenantPermission @@ -26,7 +27,11 @@ from common.misc_utils import get_uuid from api.utils.api_utils import get_data_openai import tiktoken from peewee import fn - +from api.db.services.user_service import UserTenantService +from common.constants import StatusEnum +from api.db.services.user_service import UserTenantService +from api.db import TenantPermission +from common.constants import StatusEnum class CanvasTemplateService(CommonService): model = CanvasTemplate @@ -95,7 +100,7 @@ class UserCanvasService(CommonService): @classmethod @DB.connection_context() - def get_by_canvas_id(cls, pid): + def get_by_canvas_id(cls, pid: str) -> Tuple[bool, Optional[Dict[str, Any]]]: try: fields = [ @@ -105,6 +110,7 @@ class UserCanvasService(CommonService): cls.model.dsl, cls.model.description, cls.model.permission, + cls.model.shared_tenant_id, cls.model.update_time, cls.model.user_id, cls.model.create_time, @@ -125,10 +131,17 @@ class UserCanvasService(CommonService): @classmethod @DB.connection_context() - def get_by_tenant_ids(cls, joined_tenant_ids, user_id, - page_number, items_per_page, - orderby, desc, keywords, canvas_category=None - ): + def get_by_tenant_ids( + cls, + joined_tenant_ids: List[str], + user_id: str, + page_number: Optional[int], + items_per_page: Optional[int], + orderby: str, + desc: bool, + keywords: Optional[str], + canvas_category: Optional[str] = None + ) -> Tuple[List[Dict[str, Any]], int]: fields = [ cls.model.id, cls.model.avatar, @@ -136,20 +149,41 @@ class UserCanvasService(CommonService): cls.model.dsl, cls.model.description, cls.model.permission, + cls.model.shared_tenant_id, cls.model.user_id.alias("tenant_id"), User.nickname, User.avatar.alias('tenant_avatar'), cls.model.update_time, cls.model.canvas_category, ] + # Build permission conditions: user's own canvases OR team canvases they have access to + # For team canvases: check if shared_tenant_id matches user's tenant membership, or legacy behavior (user_id in joined_tenants) + + + # Get all tenant IDs where user is a member (for checking shared_tenant_id) + user_tenant_relations: List[Any] = UserTenantService.query(user_id=user_id, status=StatusEnum.VALID.value) + user_tenant_ids: List[str] = [str(ut.tenant_id) for ut in user_tenant_relations] + + # Condition: user's own canvases OR (team permission AND (shared_tenant_id in user's tenants OR legacy: user_id in joined_tenant_ids)) + permission_condition = ( + (cls.model.user_id == user_id) | + ( + (cls.model.permission == TenantPermission.TEAM.value) & + ( + (cls.model.shared_tenant_id.in_(user_tenant_ids)) | + ((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.shared_tenant_id.is_null())) + ) + ) + ) + if keywords: agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( - (((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)), + permission_condition, (fn.LOWER(cls.model.title).contains(keywords.lower())) ) else: agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( - (((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)) + permission_condition ) if canvas_category: agents = agents.where(cls.model.canvas_category == canvas_category) @@ -165,16 +199,32 @@ class UserCanvasService(CommonService): @classmethod @DB.connection_context() - def accessible(cls, canvas_id, tenant_id): - from api.db.services.user_service import UserTenantService + def accessible(cls, canvas_id: str, tenant_id: str) -> bool: + + + e: bool + c: Optional[Dict[str, Any]] e, c = UserCanvasService.get_by_canvas_id(canvas_id) - if not e: + if not e or c is None: return False - tids = [t.tenant_id for t in UserTenantService.query(user_id=tenant_id)] - if c["user_id"] != canvas_id and c["user_id"] not in tids: + # If user owns the canvas, always allow + if c["user_id"] == tenant_id: + return True + + # If permission is not "team", deny access + if c.get("permission") != TenantPermission.TEAM.value: return False - return True + + # If shared_tenant_id is specified, check if user is a member of that specific tenant + shared_tenant_id: Optional[str] = c.get("shared_tenant_id") + if shared_tenant_id: + user_tenant = UserTenantService.filter_by_tenant_and_user_id(shared_tenant_id, tenant_id) + return user_tenant is not None and user_tenant.status == StatusEnum.VALID.value + + # Legacy behavior: check if user is a member of the canvas owner's tenant + tids: List[str] = [str(t.tenant_id) for t in UserTenantService.query(user_id=tenant_id, status=StatusEnum.VALID.value)] + return str(c["user_id"]) in tids async def completion(tenant_id, agent_id, session_id=None, **kwargs): diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index 630b64feb..a6edadaf2 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -361,11 +361,21 @@ class CreateDatasetReq(Base): description: Annotated[str | None, Field(default=None, max_length=65535)] embedding_model: Annotated[str | None, Field(default=None, max_length=255, serialization_alias="embd_id")] permission: Annotated[Literal["me", "team"], Field(default="me", min_length=1, max_length=16)] + shared_tenant_id: Annotated[str | None, Field(default=None, max_length=32, description="Specific tenant ID to share with when permission is 'team'")] chunk_method: Annotated[ Literal["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], Field(default="naive", min_length=1, max_length=32, serialization_alias="parser_id"), ] parser_config: Annotated[ParserConfig | None, Field(default=None)] + + @field_validator("shared_tenant_id", mode="after") + @classmethod + def validate_shared_tenant_id(cls, v: str | None, info: Any) -> str | None: + """Validate that shared_tenant_id is only set when permission is 'team'.""" + permission: str = info.data.get("permission", "me") + if v is not None and permission != "team": + raise ValueError("shared_tenant_id can only be set when permission is 'team'") + return v @field_validator("avatar", mode="after") @classmethod