diff --git a/cognee/modules/users/methods/get_default_user.py b/cognee/modules/users/methods/get_default_user.py index fd384e50d..8ab330d4c 100644 --- a/cognee/modules/users/methods/get_default_user.py +++ b/cognee/modules/users/methods/get_default_user.py @@ -1,11 +1,12 @@ +from sqlalchemy.orm import joinedload + from cognee.modules.users.models import User from cognee.infrastructure.databases.relational import get_relational_engine from sqlalchemy.future import select async def get_default_user(session): - - stmt = select(User).where(User.email == "default_user@example.com") + stmt = select(User).options(joinedload(User.groups)).where(User.email == "default_user@example.com") result = await session.execute(stmt) user = result.scalars().first() return user \ No newline at end of file diff --git a/cognee/modules/users/models/User.py b/cognee/modules/users/models/User.py index fba04e6a0..4e792bbd5 100644 --- a/cognee/modules/users/models/User.py +++ b/cognee/modules/users/models/User.py @@ -24,7 +24,7 @@ class User(SQLAlchemyBaseUserTableUUID, Principal): from fastapi_users import schemas class UserRead(schemas.BaseUser[uuid_UUID]): - pass + groups: list[uuid_UUID] # Add groups attribute class UserCreate(schemas.BaseUserCreate): pass diff --git a/cognee/modules/users/permissions/methods/check_permissions_on_documents.py b/cognee/modules/users/permissions/methods/check_permissions_on_documents.py index 474855114..3f2724ac8 100644 --- a/cognee/modules/users/permissions/methods/check_permissions_on_documents.py +++ b/cognee/modules/users/permissions/methods/check_permissions_on_documents.py @@ -15,18 +15,18 @@ class PermissionDeniedException(Exception): async def check_permissions_on_documents(user: User, permission_type: str, document_ids: list[str], session): + + logging.info("This is the user: %s", user.__dict__) try: user_group_ids = [group.id for group in user.groups] - result = await session.execute( - select(ACL).filter( - ACL.principal_id.in_([user.id, *user_group_ids]), - ACL.permission.name == permission_type - ) + acls = await session.execute( + select(ACL) + .join(ACL.permission) + .where(ACL.principal_id.in_([user.id, *user_group_ids])) + .where(ACL.permission.has(name=permission_type)) ) - acls = result.scalars().all() - - resource_ids = [resource.resource_id for acl in acls for resource in acl.resources] + resource_ids = [resource.resource_id for acl in acls.scalars().all() for resource in acl.resources] has_permissions = all(document_id in resource_ids for document_id in document_ids) if not has_permissions: