feat: Add ability to filter search of datasets by dataset permissions
This commit is contained in:
parent
c383253195
commit
b11236f592
9 changed files with 77 additions and 24 deletions
|
|
@ -1,4 +1,5 @@
|
|||
from uuid import UUID
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from fastapi import Depends, APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
|
|
@ -11,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user
|
|||
|
||||
class SearchPayloadDTO(InDTO):
    """Request payload accepted by the search route."""

    # Forwarded to the search call as `query_type`.
    search_type: SearchType
    # Optional list of dataset identifiers used to narrow the search; None (the default)
    # applies no dataset filter. Forwarded to the search call as `datasets`.
    # NOTE(review): the permission-aware search appears to match these against
    # dataset names rather than ids - confirm with the caller.
    datasets: Optional[list[str]] = None
    # Forwarded to the search call as `query_text`.
    query: str
|
||||
|
||||
|
||||
|
|
@ -39,7 +41,10 @@ def get_search_router() -> APIRouter:
|
|||
|
||||
try:
|
||||
results = await cognee_search(
|
||||
query_text=payload.query, query_type=payload.search_type, user=user
|
||||
query_text=payload.query,
|
||||
query_type=payload.search_type,
|
||||
user=user,
|
||||
datasets=payload.datasets,
|
||||
)
|
||||
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -114,17 +114,20 @@ async def permissions_search(
|
|||
query = await log_query(query_text, query_type.value, user.id)
|
||||
|
||||
# Find all datasets user has read access for
|
||||
# TODO: get_all_user_permission_datasets needs to be expanded to handle roles and tenants
|
||||
user_read_access_datasets = await get_all_user_permission_datasets(user, "read")
|
||||
|
||||
# if datasets are provided to search filter out non provided datasets
|
||||
# TODO: Make sure dataset comparison is between objects of same type,
|
||||
# user_read_access_datasets will be the Dataset objects and datasets will be strings
|
||||
# If specific datasets were requested, keep only the readable datasets whose name was requested
|
||||
if datasets:
|
||||
search_datasets = [dataset for dataset in user_read_access_datasets if dataset in datasets]
|
||||
search_datasets = [
|
||||
dataset for dataset in user_read_access_datasets if dataset.name in datasets
|
||||
]
|
||||
else:
|
||||
search_datasets = user_read_access_datasets
|
||||
|
||||
# TODO: If the user has no accessible datasets to search, decide whether to raise an error and how to inform them.
|
||||
if not search_datasets:
|
||||
pass
|
||||
|
||||
# Set context for database for each dataset user has access for
|
||||
async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k):
|
||||
await set_database_global_context_variables(dataset.id, user)
|
||||
|
|
|
|||
|
|
@ -20,10 +20,18 @@ class CustomJWTStrategy(JWTStrategy):
|
|||
user = await get_user(user.id)
|
||||
|
||||
if user.tenant:
|
||||
data = {"user_id": str(user.id), "tenant_id": str(user.tenant.id), "roles": user.roles}
|
||||
data = {
|
||||
"user_id": str(user.id),
|
||||
"tenant_id": str(user.tenant.id),
|
||||
"roles": [role.name for role in user.roles],
|
||||
}
|
||||
else:
|
||||
# The default tenant is None
|
||||
data = {"user_id": str(user.id), "tenant_id": None, "roles": user.roles}
|
||||
data = {
|
||||
"user_id": str(user.id),
|
||||
"tenant_id": None,
|
||||
"roles": [role.name for role in user.roles],
|
||||
}
|
||||
return generate_jwt(data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from .check_permission_on_dataset import check_permission_on_dataset
|
|||
from .give_permission_on_dataset import give_permission_on_dataset
|
||||
from .get_document_ids_for_user import get_document_ids_for_user
|
||||
from .get_principal_datasets import get_principal_datasets
|
||||
from .get_role import get_role
|
||||
from .get_tenant import get_tenant
|
||||
from .get_all_user_permission_datasets import get_all_user_permission_datasets
|
||||
from .give_default_permission_to_tenant import give_default_permission_to_tenant
|
||||
from .give_default_permission_to_role import give_default_permission_to_role
|
||||
|
|
|
|||
|
|
@ -1,12 +1,9 @@
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
from ...models.User import User
|
||||
from cognee.modules.data.models.Dataset import Dataset
|
||||
from cognee.modules.users.permissions.methods import get_principal_datasets
|
||||
from cognee.modules.users.permissions.methods import get_role, get_tenant
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -15,13 +12,22 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
|
|||
datasets = list()
|
||||
# Get all datasets User has explicit access to
|
||||
datasets.extend(await get_principal_datasets(user, permission_type))
|
||||
# Get all datasets Users roles have access to
|
||||
# TODO: Expand to get all user role accessible datasets
|
||||
for role in user.roles:
|
||||
datasets.extend(await get_principal_datasets(role, permission_type))
|
||||
# Get all datasets Users tenant allows access for
|
||||
# TODO: Expand to get all user tenant accessible datasets
|
||||
|
||||
if user.tenant_id:
|
||||
datasets.extend(await get_principal_datasets(user.tenant, permission_type))
|
||||
# TODO: Make sure result does not contain duplicate datasets
|
||||
return datasets
|
||||
# Get all datasets all tenants have access to
|
||||
tenant = await get_tenant(user.tenant_id)
|
||||
datasets.extend(await get_principal_datasets(tenant, permission_type))
|
||||
# Get all datasets Users roles have access to
|
||||
for role_name in user.roles:
|
||||
# TODO: user.roles in pydantic is mapped to Role objects, but in our backend it's used by role name only
|
||||
# Make user.roles uniform in usage across cognee lib + backend
|
||||
role = await get_role(user.tenant_id, role_name)
|
||||
datasets.extend(await get_principal_datasets(role, permission_type))
|
||||
|
||||
# Deduplicate datasets with same ID
|
||||
unique = {}
|
||||
for dataset in datasets:
|
||||
# If the dataset id key already exists, leave the dictionary unchanged.
|
||||
unique.setdefault(dataset.id, dataset)
|
||||
|
||||
return list(unique.values())
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
|
|
@ -8,8 +7,6 @@ from ...models.Principal import Principal
|
|||
from cognee.modules.data.models.Dataset import Dataset
|
||||
from ...models.ACL import ACL
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def get_principal_datasets(principal: Principal, permission_type: str) -> list[Dataset]:
|
||||
db_engine = get_relational_engine()
|
||||
|
|
|
|||
17
cognee/modules/users/permissions/methods/get_role.py
Normal file
17
cognee/modules/users/permissions/methods/get_role.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
from sqlalchemy import select
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
from ...models.Role import Role
|
||||
|
||||
|
||||
async def get_role(tenant_id: UUID, role_name: str):
    """Fetch the Role called ``role_name`` that belongs to tenant ``tenant_id``.

    The lookup goes through ``scalar_one()``, so exactly one matching row is
    expected; SQLAlchemy raises when the query matches no row or several rows.
    """
    # Both criteria passed to a single where() are combined with AND.
    statement = select(Role).where(Role.name == role_name, Role.tenant_id == tenant_id)

    async with get_relational_engine().get_async_session() as session:
        query_result = await session.execute(statement)
        return query_result.unique().scalar_one()
|
||||
14
cognee/modules/users/permissions/methods/get_tenant.py
Normal file
14
cognee/modules/users/permissions/methods/get_tenant.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from sqlalchemy import select
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from ...models.Tenant import Tenant
|
||||
|
||||
|
||||
async def get_tenant(tenant_id: UUID):
    """Fetch the Tenant whose primary key equals ``tenant_id``.

    The lookup goes through ``scalar_one()``, so exactly one matching row is
    expected; SQLAlchemy raises when the query matches no row or several rows.
    """
    statement = select(Tenant).where(Tenant.id == tenant_id)

    async with get_relational_engine().get_async_session() as session:
        query_result = await session.execute(statement)
        return query_result.unique().scalar_one()
|
||||
|
|
@ -20,6 +20,7 @@ from cognee.modules.users.models import (
|
|||
async def add_user_to_role(user_id: UUID, role_id: UUID):
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
# TODO: Add check to verify role tenant and user tenant are the same before adding user to role
|
||||
user = (await session.execute(select(User).where(User.id == user_id))).scalars().first()
|
||||
role = (await session.execute(select(Role).where(Role.id == role_id))).scalars().first()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue