feat: Initial multi-tenancy commit

This commit is contained in:
Igor Ilic 2025-10-17 18:09:01 +02:00
parent f0c332928d
commit a8ff50ceae
14 changed files with 82 additions and 72 deletions

View file

@ -22,8 +22,9 @@ async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -
if dataset is None: if dataset is None:
# Dataset id should be generated based on dataset_name and owner_id/user so multiple users can use the same dataset_name # Dataset id should be generated based on dataset_name and owner_id/user so multiple users can use the same dataset_name
dataset_id = await get_unique_dataset_id(dataset_name=dataset_name, user=user) dataset_id = await get_unique_dataset_id(dataset_name=dataset_name, user=user)
dataset = Dataset(id=dataset_id, name=dataset_name, data=[]) dataset = Dataset(
dataset.owner_id = owner_id id=dataset_id, name=dataset_name, data=[], owner_id=owner_id, tenant_id=user.tenant_id
)
session.add(dataset) session.add(dataset)

View file

@ -18,6 +18,7 @@ class Dataset(Base):
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
owner_id = Column(UUID, index=True) owner_id = Column(UUID, index=True)
tenant_id = Column(UUID, index=True, nullable=True)
acls = relationship("ACL", back_populates="dataset", cascade="all, delete-orphan") acls = relationship("ACL", back_populates="dataset", cascade="all, delete-orphan")

View file

@ -18,7 +18,6 @@ from typing import Optional
async def create_user( async def create_user(
email: str, email: str,
password: str, password: str,
tenant_id: Optional[str] = None,
is_superuser: bool = False, is_superuser: bool = False,
is_active: bool = True, is_active: bool = True,
is_verified: bool = False, is_verified: bool = False,
@ -30,33 +29,15 @@ async def create_user(
async with relational_engine.get_async_session() as session: async with relational_engine.get_async_session() as session:
async with get_user_db_context(session) as user_db: async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager: async with get_user_manager_context(user_db) as user_manager:
if tenant_id: user = await user_manager.create(
# Check if the tenant already exists UserCreate(
result = await session.execute(select(Tenant).where(Tenant.id == tenant_id)) email=email,
tenant = result.scalars().first() password=password,
if not tenant: is_superuser=is_superuser,
raise TenantNotFoundError is_active=is_active,
is_verified=is_verified,
user = await user_manager.create(
UserCreate(
email=email,
password=password,
tenant_id=tenant.id,
is_superuser=is_superuser,
is_active=is_active,
is_verified=is_verified,
)
)
else:
user = await user_manager.create(
UserCreate(
email=email,
password=password,
is_superuser=is_superuser,
is_active=is_active,
is_verified=is_verified,
)
) )
)
if auto_login: if auto_login:
await session.refresh(user) await session.refresh(user)

View file

@ -27,12 +27,7 @@ async def get_default_user() -> SimpleNamespace:
if user is None: if user is None:
return await create_default_user() return await create_default_user()
# We return a SimpleNamespace to have the same user type as our SaaS return user
# SimpleNamespace is just a dictionary which can be accessed through attributes
auth_data = SimpleNamespace(
id=user.id, email=user.email, tenant_id=user.tenant_id, roles=[]
)
return auth_data
except Exception as error: except Exception as error:
if "principals" in str(error.args): if "principals" in str(error.args):
raise DatabaseNotCreatedError() from error raise DatabaseNotCreatedError() from error

View file

@ -14,7 +14,7 @@ async def get_user(user_id: UUID):
user = ( user = (
await session.execute( await session.execute(
select(User) select(User)
.options(selectinload(User.roles), selectinload(User.tenant)) .options(selectinload(User.roles), selectinload(User.tenants))
.where(User.id == user_id) .where(User.id == user_id)
) )
).scalar() ).scalar()

View file

@ -13,7 +13,7 @@ async def get_user_by_email(user_email: str):
user = ( user = (
await session.execute( await session.execute(
select(User) select(User)
.options(joinedload(User.roles), joinedload(User.tenant)) .options(joinedload(User.roles), joinedload(User.tenants))
.where(User.email == user_email) .where(User.email == user_email)
) )
).scalar() ).scalar()

View file

@ -1,7 +1,7 @@
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship, Mapped
from sqlalchemy import Column, String, ForeignKey, UUID from sqlalchemy import Column, String, ForeignKey, UUID
from .Principal import Principal from .Principal import Principal
from .User import User from .UserTenant import UserTenant
from .Role import Role from .Role import Role
@ -13,14 +13,13 @@ class Tenant(Principal):
owner_id = Column(UUID, index=True) owner_id = Column(UUID, index=True)
# One-to-Many relationship with User; specify the join via User.tenant_id users: Mapped[list["User"]] = relationship( # noqa: F821
users = relationship(
"User", "User",
back_populates="tenant", secondary=UserTenant.__tablename__,
foreign_keys=lambda: [User.tenant_id], back_populates="tenants",
) )
# One-to-Many relationship with Role (if needed; similar fix) # One-to-Many relationship with Role
roles = relationship( roles = relationship(
"Role", "Role",
back_populates="tenant", back_populates="tenant",

View file

@ -6,8 +6,10 @@ from sqlalchemy import ForeignKey, Column, UUID
from sqlalchemy.orm import relationship, Mapped from sqlalchemy.orm import relationship, Mapped
from .Principal import Principal from .Principal import Principal
from .UserTenant import UserTenant
from .UserRole import UserRole from .UserRole import UserRole
from .Role import Role from .Role import Role
from .Tenant import Tenant
class User(SQLAlchemyBaseUserTableUUID, Principal): class User(SQLAlchemyBaseUserTableUUID, Principal):
@ -15,7 +17,7 @@ class User(SQLAlchemyBaseUserTableUUID, Principal):
id = Column(UUID, ForeignKey("principals.id", ondelete="CASCADE"), primary_key=True) id = Column(UUID, ForeignKey("principals.id", ondelete="CASCADE"), primary_key=True)
# Foreign key to Tenant (Many-to-One relationship) # Foreign key to current Tenant (Many-to-One relationship)
tenant_id = Column(UUID, ForeignKey("tenants.id")) tenant_id = Column(UUID, ForeignKey("tenants.id"))
# Many-to-Many Relationship with Roles # Many-to-Many Relationship with Roles
@ -25,11 +27,11 @@ class User(SQLAlchemyBaseUserTableUUID, Principal):
back_populates="users", back_populates="users",
) )
# Relationship to Tenant # Many-to-Many Relationship with Tenants user is a part of
tenant = relationship( tenants: Mapped[list["Tenant"]] = relationship(
"Tenant", "Tenant",
secondary=UserTenant.__tablename__,
back_populates="users", back_populates="users",
foreign_keys=[tenant_id],
) )
# ACL Relationship (One-to-Many) # ACL Relationship (One-to-Many)
@ -46,7 +48,6 @@ class UserRead(schemas.BaseUser[uuid_UUID]):
class UserCreate(schemas.BaseUserCreate): class UserCreate(schemas.BaseUserCreate):
tenant_id: Optional[uuid_UUID] = None
is_verified: bool = True is_verified: bool = True

View file

@ -0,0 +1,12 @@
from datetime import datetime, timezone
from sqlalchemy import Column, ForeignKey, DateTime, UUID
from cognee.infrastructure.databases.relational import Base
class UserTenant(Base):
__tablename__ = "user_tenants"
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
user_id = Column(UUID, ForeignKey("users.id"), primary_key=True)
tenant_id = Column(UUID, ForeignKey("tenants.id"), primary_key=True)

View file

@ -1,6 +1,7 @@
from .User import User from .User import User
from .Role import Role from .Role import Role
from .UserRole import UserRole from .UserRole import UserRole
from .UserTenant import UserTenant
from .DatasetDatabase import DatasetDatabase from .DatasetDatabase import DatasetDatabase
from .RoleDefaultPermissions import RoleDefaultPermissions from .RoleDefaultPermissions import RoleDefaultPermissions
from .UserDefaultPermissions import UserDefaultPermissions from .UserDefaultPermissions import UserDefaultPermissions

View file

@ -1,11 +1,8 @@
from types import SimpleNamespace
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from ...models.User import User from ...models.User import User
from cognee.modules.data.models.Dataset import Dataset from cognee.modules.data.models.Dataset import Dataset
from cognee.modules.users.permissions.methods import get_principal_datasets from cognee.modules.users.permissions.methods import get_principal_datasets
from cognee.modules.users.permissions.methods import get_role, get_tenant
logger = get_logger() logger = get_logger()
@ -25,17 +22,15 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
# Get all datasets User has explicit access to # Get all datasets User has explicit access to
datasets.extend(await get_principal_datasets(user, permission_type)) datasets.extend(await get_principal_datasets(user, permission_type))
if user.tenant_id: # Get all tenants user is a part of
# Get all datasets all tenants have access to tenants = await user.awaitable_attrs.tenants
tenant = await get_tenant(user.tenant_id)
for tenant in tenants:
# Get all datasets all tenant members have access to
datasets.extend(await get_principal_datasets(tenant, permission_type)) datasets.extend(await get_principal_datasets(tenant, permission_type))
# Get all datasets Users roles have access to # Get all datasets accessible by roles user is a part of
if isinstance(user, SimpleNamespace): roles = await user.awaitable_attrs.roles
# If simple namespace use roles defined in user
roles = user.roles
else:
roles = await user.awaitable_attrs.roles
for role in roles: for role in roles:
datasets.extend(await get_principal_datasets(role, permission_type)) datasets.extend(await get_principal_datasets(role, permission_type))
@ -45,4 +40,5 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
# If the dataset id key already exists, leave the dictionary unchanged. # If the dataset id key already exists, leave the dictionary unchanged.
unique.setdefault(dataset.id, dataset) unique.setdefault(dataset.id, dataset)
# TODO: Add filtering out of datasets that aren't currently selected tenant of user
return list(unique.values()) return list(unique.values())

View file

@ -1,8 +1,11 @@
from typing import Optional
from uuid import UUID from uuid import UUID
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy import insert
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.models.UserTenant import UserTenant
from cognee.modules.users.methods import get_user from cognee.modules.users.methods import get_user
from cognee.modules.users.permissions.methods import get_tenant from cognee.modules.users.permissions.methods import get_tenant
from cognee.modules.users.exceptions import ( from cognee.modules.users.exceptions import (
@ -12,14 +15,19 @@ from cognee.modules.users.exceptions import (
) )
async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID): async def add_user_to_tenant(
user_id: UUID, tenant_id: UUID, owner_id: UUID, set_active_tenant: Optional[bool] = True
):
""" """
Add a user with the given id to the tenant with the given id. Add a user with the given id to the tenant with the given id.
This can only be successful if the request owner with the given id is the tenant owner. This can only be successful if the request owner with the given id is the tenant owner.
If set_active_tenant is true it will automatically set the users active tenant to provided tenant.
Args: Args:
user_id: Id of the user. user_id: Id of the user.
tenant_id: Id of the tenant. tenant_id: Id of the tenant.
owner_id: Id of the request owner. owner_id: Id of the request owner.
set_active_tenant: If set_active_tenant is true it will automatically set the users active tenant to provided tenant.
Returns: Returns:
None None
@ -41,12 +49,17 @@ async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID):
) )
try: try:
if user.tenant_id is None: try:
# Add association directly to the association table
create_user_tenant_statement = insert(UserTenant).values(
user_id=user_id, tenant_id=tenant_id
)
await session.execute(create_user_tenant_statement)
except IntegrityError:
raise EntityAlreadyExistsError(message="User is already part of group.")
if set_active_tenant:
user.tenant_id = tenant_id user.tenant_id = tenant_id
elif user.tenant_id == tenant_id:
return
else:
raise IntegrityError
await session.merge(user) await session.merge(user)
await session.commit() await session.commit()

View file

@ -1,6 +1,8 @@
from uuid import UUID from uuid import UUID
from sqlalchemy import insert
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from cognee.modules.users.models.UserTenant import UserTenant
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.models import Tenant from cognee.modules.users.models import Tenant
@ -22,16 +24,22 @@ async def create_tenant(tenant_name: str, user_id: UUID) -> UUID:
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
try: try:
user = await get_user(user_id) user = await get_user(user_id)
if user.tenant_id:
raise EntityAlreadyExistsError(
message="User already has a tenant. New tenant cannot be created."
)
tenant = Tenant(name=tenant_name, owner_id=user_id) tenant = Tenant(name=tenant_name, owner_id=user_id)
session.add(tenant) session.add(tenant)
await session.flush() await session.flush()
user.tenant_id = tenant.id user.tenant_id = tenant.id
try:
# Add association directly to the association table
create_user_tenant_statement = insert(UserTenant).values(
user_id=user_id, tenant_id=tenant.id
)
await session.execute(create_user_tenant_statement)
except IntegrityError:
raise EntityAlreadyExistsError(message="User is already part of group.")
await session.merge(user) await session.merge(user)
await session.commit() await session.commit()
return tenant.id return tenant.id

View file

@ -150,7 +150,9 @@ async def main():
# To add a user to a role he must be part of the same tenant/organization # To add a user to a role he must be part of the same tenant/organization
print("\nOperation started as user_2 to add user_3 to CogneeLab tenant/organization") print("\nOperation started as user_2 to add user_3 to CogneeLab tenant/organization")
await add_user_to_tenant(user_id=user_3.id, tenant_id=tenant_id, owner_id=user_2.id) await add_user_to_tenant(
user_id=user_3.id, tenant_id=tenant_id, owner_id=user_2.id, set_active_tenant=True
)
print( print(
"\nOperation started by user_2, as tenant owner, to add user_3 to Researcher role inside the tenant/organization" "\nOperation started by user_2, as tenant owner, to add user_3 to Researcher role inside the tenant/organization"