diff --git a/alembic/versions/ab7e313804ae_permission_system_rework.py b/alembic/versions/ab7e313804ae_permission_system_rework.py index bd69b9b41..d83f946a6 100644 --- a/alembic/versions/ab7e313804ae_permission_system_rework.py +++ b/alembic/versions/ab7e313804ae_permission_system_rework.py @@ -144,44 +144,58 @@ def _create_data_permission(conn, user_id, data_id, permission_name): ) +def _get_column(inspector, table, name, schema=None): + for col in inspector.get_columns(table, schema=schema): + if col["name"] == name: + return col + return None + + def upgrade() -> None: conn = op.get_bind() + insp = sa.inspect(conn) - # Recreate ACLs table with default permissions set to datasets instead of documents - op.drop_table("acls") + dataset_id_column = _get_column(insp, "acls", "dataset_id") + if not dataset_id_column: + # Recreate ACLs table with default permissions set to datasets instead of documents + op.drop_table("acls") - acls_table = op.create_table( - "acls", - sa.Column("id", UUID, primary_key=True, default=uuid4), - sa.Column( - "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) - ), - sa.Column( - "updated_at", sa.DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc) - ), - sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")), - sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")), - sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")), - ) + acls_table = op.create_table( + "acls", + sa.Column("id", UUID, primary_key=True, default=uuid4), + sa.Column( + "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + onupdate=lambda: datetime.now(timezone.utc), + ), + sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")), + sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")), + sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")), + ) - # Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table - # definition or load what is in the database - dataset_table = _define_dataset_table() - datasets = conn.execute(sa.select(dataset_table)).fetchall() + # Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table + # definition or load what is in the database + dataset_table = _define_dataset_table() + datasets = conn.execute(sa.select(dataset_table)).fetchall() - if not datasets: - return + if not datasets: + return - acl_list = [] + acl_list = [] - for dataset in datasets: - acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read")) - acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write")) - acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share")) - acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete")) + for dataset in datasets: + acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read")) + acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write")) + acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share")) + acl_list.append( + _create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete") + ) - if acl_list: - op.bulk_insert(acls_table, acl_list) + if acl_list: + op.bulk_insert(acls_table, acl_list) def downgrade() -> None: