diff --git a/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py b/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py
index cd19d09c8..7e13898ae 100644
--- a/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py
+++ b/alembic/versions/76625596c5c3_expand_dataset_database_for_multi_user.py
@@ -14,7 +14,7 @@ import sqlalchemy as sa
 
 # revision identifiers, used by Alembic.
 revision: str = "76625596c5c3"
-down_revision: Union[str, None] = "211ab850ef3d"
+down_revision: Union[str, None] = "c946955da633"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
 
@@ -30,40 +30,20 @@ def upgrade() -> None:
     conn = op.get_bind()
     insp = sa.inspect(conn)
 
-    data = sa.table(
-        "dataset_database",
-        sa.Column("dataset_id", sa.UUID, primary_key=True, index=True),  # Critical for SQLite
-        sa.Column("owner_id", sa.UUID, index=True),
-        sa.Column("vector_database_name", sa.String(), unique=True, nullable=False),
-        sa.Column("graph_database_name", sa.String(), unique=True, nullable=False),
-        sa.Column("vector_database_provider", sa.String(), unique=False, nullable=False),
-        sa.Column("graph_database_provider", sa.String(), unique=False, nullable=False),
-        sa.Column("vector_database_url", sa.String(), unique=False, nullable=True),
-        sa.Column("graph_database_url", sa.String(), unique=False, nullable=True),
-        sa.Column("vector_database_key", sa.String(), unique=False, nullable=True),
-        sa.Column("graph_database_key", sa.String(), unique=False, nullable=True),
-        sa.Column("created_at", sa.DateTime(timezone=True)),
-        sa.Column("updated_at", sa.DateTime(timezone=True)),
-    )
-
     vector_database_provider_column = _get_column(
         insp, "dataset_database", "vector_database_provider"
     )
     if not vector_database_provider_column:
         op.add_column(
             "dataset_database",
-            sa.Column("vector_database_provider", sa.String(), unique=False, nullable=False),
+            sa.Column(
+                "vector_database_provider",
+                sa.String(),
+                unique=False,
+                nullable=False,
+                server_default="lancedb",
+            ),
         )
-        if op.get_context().dialect.name == "sqlite":
-            with op.batch_alter_table("dataset_database") as batch_op:
-                batch_op.execute(
-                    data.update().values(
-                        vector_database_provider="lancedb",
-                    )
-                )
-        else:
-            conn = op.get_bind()
-            conn.execute(data.update().values(vector_database_provider="lancedb"))
 
     graph_database_provider_column = _get_column(
         insp, "dataset_database", "graph_database_provider"
@@ -71,18 +51,14 @@ def upgrade() -> None:
     if not graph_database_provider_column:
         op.add_column(
             "dataset_database",
-            sa.Column("graph_database_provider", sa.String(), unique=False, nullable=False),
+            sa.Column(
+                "graph_database_provider",
+                sa.String(),
+                unique=False,
+                nullable=False,
+                server_default="kuzu",
+            ),
         )
-        if op.get_context().dialect.name == "sqlite":
-            with op.batch_alter_table("dataset_database") as batch_op:
-                batch_op.execute(
-                    data.update().values(
-                        graph_database_provider="kuzu",
-                    )
-                )
-        else:
-            conn = op.get_bind()
-            conn.execute(data.update().values(graph_database_provider="kuzu"))
 
     vector_database_url_column = _get_column(insp, "dataset_database", "vector_database_url")
     if not vector_database_url_column: