141 lines
4.5 KiB
Python
141 lines
4.5 KiB
Python
"""Multi Tenant Support

Revision ID: c946955da633
Revises: 211ab850ef3d
Create Date: 2025-11-04 18:11:09.325158

"""
|
|
|
|
from typing import Sequence, Union
|
|
from datetime import datetime, timezone
|
|
from uuid import uuid4
|
|
|
|
from alembic import op
|
|
import sqlalchemy as sa
|
|
|
|
# revision identifiers, used by Alembic.
# `revision` is this migration's own id.
revision: str = "c946955da633"
# `down_revision` is the revision this migration revises (see the module docstring).
down_revision: Union[str, None] = "211ab850ef3d"
# No branch labels and no extra dependencies are declared for this revision.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
|
|
|
|
|
def _now():
    """Return the current moment as a timezone-aware ``datetime`` in UTC."""
    current_utc = datetime.now(tz=timezone.utc)
    return current_utc
|
|
|
|
|
|
def _define_user_table() -> sa.Table:
    """Describe the ``users`` table with the two columns this migration reads.

    The table is attached to a fresh ``MetaData`` instance, so each call
    returns an independent definition.

    Returns:
        A ``sa.Table`` named ``users`` with columns ``id`` and ``tenant_id``.
    """
    id_column = sa.Column(
        "id",
        sa.UUID,
        sa.ForeignKey("principals.id", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    tenant_id_column = sa.Column(
        "tenant_id",
        sa.UUID,
        sa.ForeignKey("tenants.id"),
        index=True,
        nullable=True,
    )
    return sa.Table("users", sa.MetaData(), id_column, tenant_id_column)
|
|
|
|
|
|
def _define_dataset_table() -> sa.Table:
    """Describe the ``datasets`` table as this migration needs to see it.

    Note: the Cognee models cannot be used to read data here because they can
    change over time; the migration therefore relies on its own table
    definition (the alternative would be loading what is in the database).

    Returns:
        A ``sa.Table`` named ``datasets`` attached to a fresh ``MetaData``.
    """
    columns = [
        sa.Column("id", sa.UUID, primary_key=True, default=uuid4),
        sa.Column("name", sa.Text),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            default=lambda: datetime.now(timezone.utc),
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            onupdate=lambda: datetime.now(timezone.utc),
        ),
        sa.Column("owner_id", sa.UUID(), sa.ForeignKey("principals.id"), index=True),
        sa.Column("tenant_id", sa.UUID(), sa.ForeignKey("tenants.id"), index=True, nullable=True),
    ]
    return sa.Table("datasets", sa.MetaData(), *columns)
|
|
|
|
|
|
def _get_column(inspector, table, name, schema=None):
    """Find one column of ``table`` by name using ``inspector``.

    Args:
        inspector: Object exposing ``get_columns(table, schema=...)`` that
            yields mappings with a ``"name"`` key.
        table: Name of the table whose columns are listed.
        name: Column name to look for.
        schema: Optional schema name forwarded to ``get_columns``.

    Returns:
        The first column mapping whose ``"name"`` equals ``name``, or ``None``
        when no column matches.
    """
    reflected_columns = inspector.get_columns(table, schema=schema)
    return next((entry for entry in reflected_columns if entry["name"] == name), None)
|
|
|
|
|
|
def upgrade() -> None:
    """Apply the multi-tenant schema changes.

    Steps, each guarded by a check against the live schema so that the
    migration can run against databases that already contain part of it:

    1. Create the ``user_tenants`` association table when it is missing and
       seed it with one row per user whose ``users.tenant_id`` is set.
    2. Add the nullable ``datasets.tenant_id`` column when it is missing and
       backfill it with the ``tenant_id`` of each dataset's owner.
    3. Create the ``ix_datasets_tenant_id`` index when it is missing.
    """
    conn = op.get_bind()
    insp = sa.inspect(conn)

    dataset = _define_dataset_table()
    user = _define_user_table()

    # Check for the table before reflecting anything on it: asking the
    # inspector for the columns of a table that does not exist raises
    # NoSuchTableError on SQLAlchemy 2.x.
    if "user_tenants" not in insp.get_table_names():
        # Define table with all necessary columns including primary key
        user_tenants = op.create_table(
            "user_tenants",
            sa.Column("user_id", sa.UUID, sa.ForeignKey("users.id"), primary_key=True),
            sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), primary_key=True),
            sa.Column(
                "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
            ),
        )

        # Get all users with their tenant_id
        user_data = conn.execute(
            sa.select(user.c.id, user.c.tenant_id).where(user.c.tenant_id.isnot(None))
        ).fetchall()

        # Insert into user_tenants table; skipped when there is nothing to copy.
        if user_data:
            op.bulk_insert(
                user_tenants,
                [
                    {"user_id": user_id, "tenant_id": tenant_id, "created_at": _now()}
                    for user_id, tenant_id in user_data
                ],
            )

    if not _get_column(insp, "datasets", "tenant_id"):
        op.add_column("datasets", sa.Column("tenant_id", sa.UUID(), nullable=True))

        # Build subquery, select users.tenant_id for each dataset.owner_id
        tenant_id_from_dataset_owner = (
            sa.select(user.c.tenant_id).where(user.c.id == dataset.c.owner_id).scalar_subquery()
        )
        backfill_tenant_id = dataset.update().values(tenant_id=tenant_id_from_dataset_owner)

        if op.get_context().dialect.name == "sqlite":
            # On SQLite the backfill is issued through a batch_alter_table context.
            with op.batch_alter_table("datasets") as batch_op:
                batch_op.execute(backfill_tenant_id)
        else:
            conn.execute(backfill_tenant_id)

    # Only create the index when it is not already present, so an existing
    # index with the same name does not abort the migration.
    index_name = "ix_datasets_tenant_id"
    existing_index_names = {index["name"] for index in insp.get_indexes("datasets")}
    if index_name not in existing_index_names:
        op.create_index(op.f(index_name), "datasets", ["tenant_id"])
|
|
|
|
|
|
def downgrade() -> None:
    """Undo the schema changes made by ``upgrade``.

    Drops, in this order: the ``user_tenants`` table, the
    ``ix_datasets_tenant_id`` index, and the ``datasets.tenant_id`` column.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    datasets_table = "datasets"

    op.drop_table("user_tenants")

    tenant_index = op.f("ix_datasets_tenant_id")
    op.drop_index(tenant_index, table_name=datasets_table)

    op.drop_column(datasets_table, "tenant_id")
    # ### end Alembic commands ###
|