diff --git a/cognee/get_token.py b/cognee/get_token.py index f4426b180..62879477a 100644 --- a/cognee/get_token.py +++ b/cognee/get_token.py @@ -5,10 +5,10 @@ import datetime SECRET_KEY = os.getenv("FASTAPI_USERS_JWT_SECRET", "super_secret") -def create_jwt(user_id: str, tenant: str, roles: list[str]): +def create_jwt(user_id: str, tenant_id: str, roles: list[str]): payload = { "user_id": user_id, - "tenant_id": tenant, + "tenant_id": tenant_id, "roles": roles, "exp": datetime.datetime.utcnow() + datetime.timedelta(hours=1), # 1 hour expiry } @@ -17,5 +17,7 @@ def create_jwt(user_id: str, tenant: str, roles: list[str]): if __name__ == "__main__": # Example token generation - token = create_jwt("6763554c-91bd-432c-aba8-d42cd72ed659", "tenant_456", ["admin"]) + token = create_jwt( + "6763554c-91bd-432c-aba8-d42cd72ed659", "4523544d-82bd-432c-aca7-d42cd72ed651", ["admin"] + ) print(token) diff --git a/cognee/modules/data/methods/create_dataset.py b/cognee/modules/data/methods/create_dataset.py index be4ea8792..906c03c24 100644 --- a/cognee/modules/data/methods/create_dataset.py +++ b/cognee/modules/data/methods/create_dataset.py @@ -16,7 +16,10 @@ async def create_dataset(dataset_name: str, owner_id: UUID, session: AsyncSessio ).first() if dataset is None: - dataset = Dataset(id=uuid5(NAMESPACE_OID, dataset_name), name=dataset_name, data=[]) + # Dataset id should be generated based on dataset_name and owner_id so multiple users can use the same dataset_name + dataset = Dataset( + id=uuid5(NAMESPACE_OID, f"{dataset_name}{str(owner_id)}"), name=dataset_name, data=[] + ) dataset.owner_id = owner_id session.add(dataset) diff --git a/cognee/modules/users/methods/get_authenticated_user.py b/cognee/modules/users/methods/get_authenticated_user.py index 178fcc788..ae7825202 100644 --- a/cognee/modules/users/methods/get_authenticated_user.py +++ b/cognee/modules/users/methods/get_authenticated_user.py @@ -5,6 +5,8 @@ from fastapi import HTTPException, Header import os import jwt +from uuid import UUID + fastapi_users = get_fastapi_users() @@ -19,10 +21,18 @@ async def get_authenticated_user(authorization: str = Header(...)) -> SimpleName token, os.getenv("FASTAPI_USERS_JWT_SECRET", "super_secret"), algorithms=["HS256"] ) - # SimpleNamespace lets us access dictionary elements like attributes - auth_data = SimpleNamespace( - id=payload["user_id"], tenant_id=payload["tenant_id"], roles=payload["roles"] - ) + if payload["tenant_id"]: + # SimpleNamespace lets us access dictionary elements like attributes + auth_data = SimpleNamespace( + id=UUID(payload["user_id"]), + tenant_id=UUID(payload["tenant_id"]), + roles=payload["roles"], + ) + else: + auth_data = SimpleNamespace( + id=UUID(payload["user_id"]), tenant_id=None, roles=payload["roles"] + ) + return auth_data except jwt.ExpiredSignatureError: diff --git a/cognee/modules/users/methods/get_default_user.py b/cognee/modules/users/methods/get_default_user.py index 36daa0d93..bfaa82acc 100644 --- a/cognee/modules/users/methods/get_default_user.py +++ b/cognee/modules/users/methods/get_default_user.py @@ -1,8 +1,7 @@ from types import SimpleNamespace - from sqlalchemy.orm import selectinload from sqlalchemy.future import select -from cognee.modules.users.models import User, Tenant +from cognee.modules.users.models import User from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.users.methods.create_default_user import create_default_user diff --git a/cognee/modules/users/permissions/methods/check_permission_on_documents.py b/cognee/modules/users/permissions/methods/check_permission_on_documents.py index 262af17bc..d1b6f866b 100644 --- a/cognee/modules/users/permissions/methods/check_permission_on_documents.py +++ b/cognee/modules/users/permissions/methods/check_permission_on_documents.py @@ -13,7 +13,9 @@ logger = logging.getLogger(__name__) async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]): - user_roles_ids = [role.id for role in user.roles] + # TODO: Enable user role permissions again. Temporarily disabled during rework. + # user_roles_ids = [role.id for role in user.roles] + user_roles_ids = [] db_engine = get_relational_engine()