feat: Add ability to filter search of datasets by dataset permissions

This commit is contained in:
Igor Ilic 2025-05-21 02:02:34 +02:00
parent c383253195
commit b11236f592
9 changed files with 77 additions and 24 deletions

View file

@ -1,4 +1,5 @@
from uuid import UUID from uuid import UUID
from typing import Optional
from datetime import datetime from datetime import datetime
from fastapi import Depends, APIRouter from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@ -11,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user
class SearchPayloadDTO(InDTO): class SearchPayloadDTO(InDTO):
search_type: SearchType search_type: SearchType
datasets: Optional[list[str]] = None
query: str query: str
@ -39,7 +41,10 @@ def get_search_router() -> APIRouter:
try: try:
results = await cognee_search( results = await cognee_search(
query_text=payload.query, query_type=payload.search_type, user=user query_text=payload.query,
query_type=payload.search_type,
user=user,
datasets=payload.datasets,
) )
return results return results

View file

@ -114,17 +114,20 @@ async def permissions_search(
query = await log_query(query_text, query_type.value, user.id) query = await log_query(query_text, query_type.value, user.id)
# Find all datasets user has read access for # Find all datasets user has read access for
# TODO: get_all_user_permission_datasets needs to be expanded to handle roles and tenants
user_read_access_datasets = await get_all_user_permission_datasets(user, "read") user_read_access_datasets = await get_all_user_permission_datasets(user, "read")
# if datasets are provided to search filter out non provided datasets # if specific datasets are provided to search filter out non provided datasets
# TODO: Make sure dataset comparison is between objects of same type,
# user_read_access_datasets will be the Dataset objects and datasets will be strings
if datasets: if datasets:
search_datasets = [dataset for dataset in user_read_access_datasets if dataset in datasets] search_datasets = [
dataset for dataset in user_read_access_datasets if dataset.name in datasets
]
else: else:
search_datasets = user_read_access_datasets search_datasets = user_read_access_datasets
# TODO: If the user has no accessible datasets, should we raise an error? How do we inform the user?
if not search_datasets:
pass
# Set context for database for each dataset user has access for # Set context for database for each dataset user has access for
async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k): async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k):
await set_database_global_context_variables(dataset.id, user) await set_database_global_context_variables(dataset.id, user)

View file

@ -20,10 +20,18 @@ class CustomJWTStrategy(JWTStrategy):
user = await get_user(user.id) user = await get_user(user.id)
if user.tenant: if user.tenant:
data = {"user_id": str(user.id), "tenant_id": str(user.tenant.id), "roles": user.roles} data = {
"user_id": str(user.id),
"tenant_id": str(user.tenant.id),
"roles": [role.name for role in user.roles],
}
else: else:
# The default tenant is None # The default tenant is None
data = {"user_id": str(user.id), "tenant_id": None, "roles": user.roles} data = {
"user_id": str(user.id),
"tenant_id": None,
"roles": [role.name for role in user.roles],
}
return generate_jwt(data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm) return generate_jwt(data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm)

View file

@ -2,6 +2,8 @@ from .check_permission_on_dataset import check_permission_on_dataset
from .give_permission_on_dataset import give_permission_on_dataset from .give_permission_on_dataset import give_permission_on_dataset
from .get_document_ids_for_user import get_document_ids_for_user from .get_document_ids_for_user import get_document_ids_for_user
from .get_principal_datasets import get_principal_datasets from .get_principal_datasets import get_principal_datasets
from .get_role import get_role
from .get_tenant import get_tenant
from .get_all_user_permission_datasets import get_all_user_permission_datasets from .get_all_user_permission_datasets import get_all_user_permission_datasets
from .give_default_permission_to_tenant import give_default_permission_to_tenant from .give_default_permission_to_tenant import give_default_permission_to_tenant
from .give_default_permission_to_role import give_default_permission_to_role from .give_default_permission_to_role import give_default_permission_to_role

View file

@ -1,12 +1,9 @@
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from cognee.infrastructure.databases.relational import get_relational_engine
from ...models.User import User from ...models.User import User
from cognee.modules.data.models.Dataset import Dataset from cognee.modules.data.models.Dataset import Dataset
from cognee.modules.users.permissions.methods import get_principal_datasets from cognee.modules.users.permissions.methods import get_principal_datasets
from cognee.modules.users.permissions.methods import get_role, get_tenant
logger = get_logger() logger = get_logger()
@ -15,13 +12,22 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
datasets = list() datasets = list()
# Get all datasets User has explicit access to # Get all datasets User has explicit access to
datasets.extend(await get_principal_datasets(user, permission_type)) datasets.extend(await get_principal_datasets(user, permission_type))
# Get all datasets Users roles have access to
# TODO: Expand to get all user role accessible datasets
for role in user.roles:
datasets.extend(await get_principal_datasets(role, permission_type))
# Get all datasets Users tenant allows access for
# TODO: Expand to get all user tenant accessible datasets
if user.tenant_id: if user.tenant_id:
datasets.extend(await get_principal_datasets(user.tenant, permission_type)) # Get all datasets all tenants have access to
# TODO: Make sure result does not contain duplicate datasets tenant = await get_tenant(user.tenant_id)
return datasets datasets.extend(await get_principal_datasets(tenant, permission_type))
# Get all datasets Users roles have access to
for role_name in user.roles:
# TODO: user.roles is mapped to Role objects in pydantic, but our backend uses it as a list of role names only.
# Make the usage of user.roles uniform across the cognee lib and the backend.
role = await get_role(user.tenant_id, role_name)
datasets.extend(await get_principal_datasets(role, permission_type))
# Deduplicate datasets with same ID
unique = {}
for dataset in datasets:
# If the dataset id key already exists, leave the dictionary unchanged.
unique.setdefault(dataset.id, dataset)
return list(unique.values())

View file

@ -1,4 +1,3 @@
from cognee.shared.logging_utils import get_logger
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
@ -8,8 +7,6 @@ from ...models.Principal import Principal
from cognee.modules.data.models.Dataset import Dataset from cognee.modules.data.models.Dataset import Dataset
from ...models.ACL import ACL from ...models.ACL import ACL
logger = get_logger()
async def get_principal_datasets(principal: Principal, permission_type: str) -> list[Dataset]: async def get_principal_datasets(principal: Principal, permission_type: str) -> list[Dataset]:
db_engine = get_relational_engine() db_engine = get_relational_engine()

View file

@ -0,0 +1,17 @@
from sqlalchemy import select
from uuid import UUID
from cognee.infrastructure.databases.relational import get_relational_engine
from ...models.Role import Role
async def get_role(tenant_id: UUID, role_name: str):
    """Fetch the role called ``role_name`` that belongs to tenant ``tenant_id``.

    Args:
        tenant_id: Id of the tenant the role must belong to.
        role_name: Name of the role to look up.

    Returns:
        The single ``Role`` row matching both the name and the tenant id.

    Raises:
        Whatever ``scalar_one()`` raises when the query does not yield
        exactly one row (SQLAlchemy documents ``NoResultFound`` and
        ``MultipleResultsFound`` for that case).
    """
    engine = get_relational_engine()

    # Both filters are required: the lookup is by name within one tenant.
    role_query = select(Role).where(Role.name == role_name).where(Role.tenant_id == tenant_id)

    async with engine.get_async_session() as session:
        rows = await session.execute(role_query)
        # unique() collapses duplicate rows before scalar_one() enforces
        # that exactly one role is left.
        return rows.unique().scalar_one()

View file

@ -0,0 +1,14 @@
from sqlalchemy import select
from uuid import UUID
from cognee.infrastructure.databases.relational import get_relational_engine
from ...models.Tenant import Tenant
async def get_tenant(tenant_id: UUID):
    """Fetch the tenant whose primary id equals ``tenant_id``.

    Args:
        tenant_id: Id of the tenant to load.

    Returns:
        The single ``Tenant`` row with that id.

    Raises:
        Whatever ``scalar_one()`` raises when the query does not yield
        exactly one row (SQLAlchemy documents ``NoResultFound`` and
        ``MultipleResultsFound`` for that case).
    """
    engine = get_relational_engine()
    tenant_query = select(Tenant).where(Tenant.id == tenant_id)

    async with engine.get_async_session() as session:
        rows = await session.execute(tenant_query)
        # unique() collapses duplicate rows before scalar_one() enforces
        # that exactly one tenant is left.
        return rows.unique().scalar_one()

View file

@ -20,6 +20,7 @@ from cognee.modules.users.models import (
async def add_user_to_role(user_id: UUID, role_id: UUID): async def add_user_to_role(user_id: UUID, role_id: UUID):
db_engine = get_relational_engine() db_engine = get_relational_engine()
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
# TODO: Add check to verify role tenant and user tenant are the same before adding user to role
user = (await session.execute(select(User).where(User.id == user_id))).scalars().first() user = (await session.execute(select(User).where(User.id == user_id))).scalars().first()
role = (await session.execute(select(Role).where(Role.id == role_id))).scalars().first() role = (await session.execute(select(Role).where(Role.id == role_id))).scalars().first()