"""Multi Tenant Support Revision ID: c946955da633 Revises: 211ab850ef3d Create Date: 2025-11-04 18:11:09.325158 """ from typing import Sequence, Union from datetime import datetime, timezone from uuid import uuid4 from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision: str = "c946955da633" down_revision: Union[str, None] = "211ab850ef3d" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def _now(): return datetime.now(timezone.utc) def _define_user_table() -> sa.Table: table = sa.Table( "users", sa.MetaData(), sa.Column( "id", sa.UUID, sa.ForeignKey("principals.id", ondelete="CASCADE"), primary_key=True, nullable=False, ), sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), index=True, nullable=True), ) return table def _define_dataset_table() -> sa.Table: # Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table # definition or load what is in the database table = sa.Table( "datasets", sa.MetaData(), sa.Column("id", sa.UUID, primary_key=True, default=uuid4), sa.Column("name", sa.Text), sa.Column( "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), ), sa.Column( "updated_at", sa.DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc), ), sa.Column("owner_id", sa.UUID(), sa.ForeignKey("principals.id"), index=True), sa.Column("tenant_id", sa.UUID(), sa.ForeignKey("tenants.id"), index=True, nullable=True), ) return table def _get_column(inspector, table, name, schema=None): for col in inspector.get_columns(table, schema=schema): if col["name"] == name: return col return None def upgrade() -> None: conn = op.get_bind() insp = sa.inspect(conn) dataset = _define_dataset_table() user = _define_user_table() if "user_tenants" not in insp.get_table_names(): # Define table with all necessary columns including primary key user_tenants = op.create_table( "user_tenants", sa.Column("user_id", sa.UUID, sa.ForeignKey("users.id"), primary_key=True), sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), primary_key=True), sa.Column( "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) ), ) # Get all users with their tenant_id user_data = conn.execute( sa.select(user.c.id, user.c.tenant_id).where(user.c.tenant_id.isnot(None)) ).fetchall() # Insert into user_tenants table if user_data: op.bulk_insert( user_tenants, [ {"user_id": user_id, "tenant_id": tenant_id, "created_at": _now()} for user_id, tenant_id in user_data ], ) tenant_id_column = _get_column(insp, "datasets", "tenant_id") if not tenant_id_column: op.add_column("datasets", sa.Column("tenant_id", sa.UUID(), nullable=True)) # Build subquery, select users.tenant_id for each dataset.owner_id tenant_id_from_dataset_owner = ( sa.select(user.c.tenant_id).where(user.c.id == dataset.c.owner_id).scalar_subquery() ) if op.get_context().dialect.name == "sqlite": # If column doesn't exist create new original_extension column and update from values of extension column with op.batch_alter_table("datasets") as batch_op: batch_op.execute( dataset.update().values( tenant_id=tenant_id_from_dataset_owner, ) ) else: conn = op.get_bind() conn.execute(dataset.update().values(tenant_id=tenant_id_from_dataset_owner)) op.create_index(op.f("ix_datasets_tenant_id"), "datasets", ["tenant_id"]) def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! 
    op.drop_table("user_tenants")
    op.drop_index(op.f("ix_datasets_tenant_id"), table_name="datasets")
    op.drop_column("datasets", "tenant_id")
    # ### end Alembic commands ###