From b11236f592ce8acb2388e2d8da2493a7cec453e0 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Wed, 21 May 2025 02:02:34 +0200 Subject: [PATCH] feat: Add ability to filter search of datasets by dataset permissions --- .../v1/search/routers/get_search_router.py | 7 +++- cognee/modules/search/methods/search.py | 13 +++++--- .../users/authentication/get_auth_backend.py | 12 +++++-- .../users/permissions/methods/__init__.py | 2 ++ .../get_all_user_permission_datasets.py | 32 +++++++++++-------- .../methods/get_principal_datasets.py | 3 -- .../users/permissions/methods/get_role.py | 17 ++++++++++ .../users/permissions/methods/get_tenant.py | 14 ++++++++ .../users/roles/methods/add_user_to_role.py | 1 + 9 files changed, 77 insertions(+), 24 deletions(-) create mode 100644 cognee/modules/users/permissions/methods/get_role.py create mode 100644 cognee/modules/users/permissions/methods/get_tenant.py diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index cb3ef38a8..a5bb1f3da 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -1,4 +1,5 @@ from uuid import UUID +from typing import Optional from datetime import datetime from fastapi import Depends, APIRouter from fastapi.responses import JSONResponse @@ -11,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user class SearchPayloadDTO(InDTO): search_type: SearchType + datasets: Optional[list[str]] = None query: str @@ -39,7 +41,10 @@ def get_search_router() -> APIRouter: try: results = await cognee_search( - query_text=payload.query, query_type=payload.search_type, user=user + query_text=payload.query, + query_type=payload.search_type, + user=user, + datasets=payload.datasets, ) return results diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index a614bd58b..b9633dac4 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -114,17 +114,20 @@ async def permissions_search( query = await log_query(query_text, query_type.value, user.id) # Find all datasets user has read access for - # TODO: get_all_user_permission_datasets needs to be expanded to handle roles and tenants user_read_access_datasets = await get_all_user_permission_datasets(user, "read") - # if datasets are provided to search filter out non provided datasets - # TODO: Make sure dataset comparison is between objects of same type, - # user_read_access_datasets will be the Dataset objects and datasets will be strings + # if specific datasets are provided to search filter out non provided datasets if datasets: - search_datasets = [dataset for dataset in user_read_access_datasets if dataset in datasets] + search_datasets = [ + dataset for dataset in user_read_access_datasets if dataset.name in datasets + ] else: search_datasets = user_read_access_datasets + # TODO: If there are no datasets the user has access to do we raise an error? How do we handle informing him? + if not search_datasets: + pass + # Set context for database for each dataset user has access for async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k): await set_database_global_context_variables(dataset.id, user) diff --git a/cognee/modules/users/authentication/get_auth_backend.py b/cognee/modules/users/authentication/get_auth_backend.py index 9b02b84b1..fa6932f7f 100644 --- a/cognee/modules/users/authentication/get_auth_backend.py +++ b/cognee/modules/users/authentication/get_auth_backend.py @@ -20,10 +20,18 @@ class CustomJWTStrategy(JWTStrategy): user = await get_user(user.id) if user.tenant: - data = {"user_id": str(user.id), "tenant_id": str(user.tenant.id), "roles": user.roles} + data = { + "user_id": str(user.id), + "tenant_id": str(user.tenant.id), + "roles": [role.name for role in user.roles], + } else: # The default tenant is None - data = {"user_id": str(user.id), "tenant_id": None, "roles": user.roles} + data = { + "user_id": str(user.id), + "tenant_id": None, + "roles": [role.name for role in user.roles], + } return generate_jwt(data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm) diff --git a/cognee/modules/users/permissions/methods/__init__.py b/cognee/modules/users/permissions/methods/__init__.py index e0c0cf983..86faba14f 100644 --- a/cognee/modules/users/permissions/methods/__init__.py +++ b/cognee/modules/users/permissions/methods/__init__.py @@ -2,6 +2,8 @@ from .check_permission_on_dataset import check_permission_on_dataset from .give_permission_on_dataset import give_permission_on_dataset from .get_document_ids_for_user import get_document_ids_for_user from .get_principal_datasets import get_principal_datasets +from .get_role import get_role +from .get_tenant import get_tenant from .get_all_user_permission_datasets import get_all_user_permission_datasets from .give_default_permission_to_tenant import give_default_permission_to_tenant from .give_default_permission_to_role import give_default_permission_to_role diff --git a/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py b/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py index 5d67387f4..9f0ac6290 100644 --- a/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +++ b/cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py @@ -1,12 +1,9 @@ from cognee.shared.logging_utils import get_logger -from sqlalchemy import select -from sqlalchemy.orm import joinedload - -from cognee.infrastructure.databases.relational import get_relational_engine from ...models.User import User from cognee.modules.data.models.Dataset import Dataset from cognee.modules.users.permissions.methods import get_principal_datasets +from cognee.modules.users.permissions.methods import get_role, get_tenant logger = get_logger() @@ -15,13 +12,22 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) -> datasets = list() # Get all datasets User has explicit access to datasets.extend(await get_principal_datasets(user, permission_type)) - # Get all datasets Users roles have access to - # TODO: Expand to get all user role accessible datasets - for role in user.roles: - datasets.extend(await get_principal_datasets(role, permission_type)) - # Get all datasets Users tenant allows access for - # TODO: Expand to get all user tenant accessible datasets + if user.tenant_id: - datasets.extend(await get_principal_datasets(user.tenant, permission_type)) - # TODO: Make sure result does not contain duplicate datasets - return datasets + # Get all datasets all tenants have access to + tenant = await get_tenant(user.tenant_id) + datasets.extend(await get_principal_datasets(tenant, permission_type)) + # Get all datasets Users roles have access to + for role_name in user.roles: + # TODO: user.roles in pydantic is mapped to Role objects, but in our backend it's used by role name only + # Make user.roles uniform in usage across cognee lib + backend + role = await get_role(user.tenant_id, role_name) + datasets.extend(await get_principal_datasets(role, permission_type)) + + # Deduplicate datasets with same ID + unique = {} + for dataset in datasets: + # If the dataset id key already exists, leave the dictionary unchanged. + unique.setdefault(dataset.id, dataset) + + return list(unique.values()) diff --git a/cognee/modules/users/permissions/methods/get_principal_datasets.py b/cognee/modules/users/permissions/methods/get_principal_datasets.py index 6fae77a92..b2385182f 100644 --- a/cognee/modules/users/permissions/methods/get_principal_datasets.py +++ b/cognee/modules/users/permissions/methods/get_principal_datasets.py @@ -1,4 +1,3 @@ -from cognee.shared.logging_utils import get_logger from sqlalchemy import select from sqlalchemy.orm import joinedload @@ -8,8 +7,6 @@ from ...models.Principal import Principal from cognee.modules.data.models.Dataset import Dataset from ...models.ACL import ACL -logger = get_logger() - async def get_principal_datasets(principal: Principal, permission_type: str) -> list[Dataset]: db_engine = get_relational_engine() diff --git a/cognee/modules/users/permissions/methods/get_role.py b/cognee/modules/users/permissions/methods/get_role.py new file mode 100644 index 000000000..83952077a --- /dev/null +++ b/cognee/modules/users/permissions/methods/get_role.py @@ -0,0 +1,17 @@ +from sqlalchemy import select +from uuid import UUID + +from cognee.infrastructure.databases.relational import get_relational_engine + +from ...models.Role import Role + + +async def get_role(tenant_id: UUID, role_name: str): + db_engine = get_relational_engine() + + async with db_engine.get_async_session() as session: + result = await session.execute( + select(Role).where(Role.name == role_name).where(Role.tenant_id == tenant_id) + ) + role = result.unique().scalar_one() + return role diff --git a/cognee/modules/users/permissions/methods/get_tenant.py b/cognee/modules/users/permissions/methods/get_tenant.py new file mode 100644 index 000000000..b6813a344 --- /dev/null +++ b/cognee/modules/users/permissions/methods/get_tenant.py @@ -0,0 +1,14 @@ +from sqlalchemy import select +from uuid import UUID + +from cognee.infrastructure.databases.relational import get_relational_engine +from ...models.Tenant import Tenant + + +async def get_tenant(tenant_id: UUID): + db_engine = get_relational_engine() + + async with db_engine.get_async_session() as session: + result = await session.execute(select(Tenant).where(Tenant.id == tenant_id)) + tenant = result.unique().scalar_one() + return tenant diff --git a/cognee/modules/users/roles/methods/add_user_to_role.py b/cognee/modules/users/roles/methods/add_user_to_role.py index dac72174b..d3f74bf86 100644 --- a/cognee/modules/users/roles/methods/add_user_to_role.py +++ b/cognee/modules/users/roles/methods/add_user_to_role.py @@ -20,6 +20,7 @@ from cognee.modules.users.models import ( async def add_user_to_role(user_id: UUID, role_id: UUID): db_engine = get_relational_engine() async with db_engine.get_async_session() as session: + # TODO: Add check to verify role tenant and user tenant are the same before adding user to role user = (await session.execute(select(User).where(User.id == user_id))).scalars().first() role = (await session.execute(select(Role).where(Role.id == role_id))).scalars().first()