fix: migrate workspace data in PostgreSQL Case 1 to prevent data loss

Why this change is needed:
In multi-tenant deployments, when workspace A migrates first (creating
the new model-suffixed table), subsequent workspace B initialization
enters Case 1 (both tables exist). The original Case 1 logic only
checked whether the legacy table was empty globally, without checking
whether the current workspace still had unmigrated data. As a result,
workspace B's data remained in the legacy table while the application
queried the new table, causing data loss for workspace B.

How it solves the problem:
1. Extracted migration logic into _pg_migrate_workspace_data() helper
   function to avoid code duplication
2. Modified Case 1 to check whether the current workspace has data in
   the legacy table, and to migrate that data when it is found
3. Both Case 1 and Case 4 now use the same migration helper, ensuring
   consistent behavior
4. After migration, only delete the current workspace's data from
   legacy table, preserving other workspaces' data

Impact:
- Prevents data loss in multi-tenant PostgreSQL deployments
- Maintains backward compatibility with single-tenant setups
- Reduces code duplication between Case 1 and Case 4

Testing:
All PostgreSQL migration tests pass (8/8)
This commit is contained in:
BukeLy 2025-11-26 01:16:57 +08:00
parent 3b8a1e64b7
commit a8f5c9bd33

View file

@ -2201,6 +2201,56 @@ async def _pg_create_table(
await db.execute(ddl)
async def _pg_migrate_workspace_data(
db: PostgreSQLDB,
legacy_table_name: str,
new_table_name: str,
workspace: str,
expected_count: int,
embedding_dim: int,
) -> int:
"""Migrate workspace data from legacy table to new table"""
migrated_count = 0
offset = 0
batch_size = 500
while True:
if workspace:
select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 OFFSET $2 LIMIT $3"
rows = await db.query(
select_query, [workspace, offset, batch_size], multirows=True
)
else:
select_query = f"SELECT * FROM {legacy_table_name} OFFSET $1 LIMIT $2"
rows = await db.query(select_query, [offset, batch_size], multirows=True)
if not rows:
break
for row in rows:
row_dict = dict(row)
columns = list(row_dict.keys())
columns_str = ", ".join(columns)
placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
insert_query = f"""
INSERT INTO {new_table_name} ({columns_str})
VALUES ({placeholders})
ON CONFLICT (workspace, id) DO NOTHING
"""
values = {col: row_dict[col] for col in columns}
await db.execute(insert_query, values)
migrated_count += len(rows)
workspace_info = f" for workspace '{workspace}'" if workspace else ""
logger.info(
f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}"
)
offset += batch_size
return migrated_count
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
@ -2273,14 +2323,7 @@ class PGVectorStorage(BaseVectorStorage):
)
# Case 1: Both new and legacy tables exist
# This can happen if:
# 1. Previous migration failed to delete the legacy table
# 2. User manually created both tables
# 3. No model suffix (table_name == legacy_table_name)
# Strategy: Only delete legacy if it's empty (safe cleanup) and it's not the same as new table
if new_table_exists and legacy_exists:
# CRITICAL: Check if new and legacy are the same table
# This happens when model_suffix is empty (no model_name provided)
if table_name.lower() == legacy_table_name.lower():
logger.debug(
f"PostgreSQL: Table '{table_name}' already exists (no model suffix). Skipping Case 1 cleanup."
@ -2288,13 +2331,119 @@ class PGVectorStorage(BaseVectorStorage):
return
try:
# Check if legacy table is empty
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
count_result = await db.query(count_query, [])
legacy_count = count_result.get("count", 0) if count_result else 0
workspace_info = f" for workspace '{workspace}'" if workspace else ""
if legacy_count == 0:
# Legacy table is empty, safe to delete without data loss
if workspace:
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
count_result = await db.query(count_query, [workspace])
else:
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
count_result = await db.query(count_query, [])
workspace_count = count_result.get("count", 0) if count_result else 0
if workspace_count > 0:
logger.info(
f"PostgreSQL: Found {workspace_count} records in legacy table{workspace_info}. Migrating..."
)
legacy_dim = None
try:
dim_query = """
SELECT
CASE
WHEN typname = 'vector' THEN
COALESCE(atttypmod, -1)
ELSE -1
END as vector_dim
FROM pg_attribute a
JOIN pg_type t ON a.atttypid = t.oid
WHERE a.attrelid = $1::regclass
AND a.attname = 'content_vector'
"""
dim_result = await db.query(dim_query, [legacy_table_name])
legacy_dim = (
dim_result.get("vector_dim", -1) if dim_result else -1
)
if legacy_dim <= 0:
sample_query = f"SELECT content_vector FROM {legacy_table_name} LIMIT 1"
sample_result = await db.query(sample_query, [])
if sample_result and sample_result.get("content_vector"):
vector_data = sample_result["content_vector"]
if isinstance(vector_data, (list, tuple)):
legacy_dim = len(vector_data)
elif isinstance(vector_data, str):
import json
vector_list = json.loads(vector_data)
legacy_dim = len(vector_list)
if (
legacy_dim > 0
and embedding_dim
and legacy_dim != embedding_dim
):
logger.warning(
f"PostgreSQL: Dimension mismatch - "
f"legacy table has {legacy_dim}d vectors, "
f"new embedding model expects {embedding_dim}d. "
f"Skipping migration{workspace_info}."
)
await db._create_vector_index(table_name, embedding_dim)
return
except Exception as e:
logger.warning(
f"PostgreSQL: Could not verify vector dimension: {e}. Proceeding with caution..."
)
migrated_count = await _pg_migrate_workspace_data(
db,
legacy_table_name,
table_name,
workspace,
workspace_count,
embedding_dim,
)
if workspace:
new_count_query = f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
new_count_result = await db.query(new_count_query, [workspace])
else:
new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
new_count_result = await db.query(new_count_query, [])
new_count = (
new_count_result.get("count", 0) if new_count_result else 0
)
if new_count < workspace_count:
logger.warning(
f"PostgreSQL: Expected {workspace_count} records, found {new_count}{workspace_info}. "
f"Some records may have been skipped due to conflicts."
)
else:
logger.info(
f"PostgreSQL: Migration completed: {migrated_count} records migrated{workspace_info}"
)
if workspace:
delete_query = (
f"DELETE FROM {legacy_table_name} WHERE workspace = $1"
)
await db.execute(delete_query, {"workspace": workspace})
logger.info(
f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table"
)
total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
total_count_result = await db.query(total_count_query, [])
total_count = (
total_count_result.get("count", 0) if total_count_result else 0
)
if total_count == 0:
logger.info(
f"PostgreSQL: Legacy table '{legacy_table_name}' is empty. Deleting..."
)
@ -2304,18 +2453,16 @@ class PGVectorStorage(BaseVectorStorage):
f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully"
)
else:
# Legacy table still has data - don't risk deleting it
logger.warning(
f"PostgreSQL: Legacy table '{legacy_table_name}' still contains {legacy_count} records. "
f"Manual intervention required to verify and delete."
logger.info(
f"PostgreSQL: Legacy table '{legacy_table_name}' preserved "
f"({total_count} records from other workspaces remain)"
)
except Exception as e:
logger.warning(
f"PostgreSQL: Could not check or cleanup legacy table '{legacy_table_name}': {e}. "
"You may need to delete it manually."
f"PostgreSQL: Error during Case 1 migration: {e}. Vector index will still be ensured."
)
# Ensure vector index exists even if cleanup was not performed
await db._create_vector_index(table_name, embedding_dim)
return
@ -2430,61 +2577,19 @@ class PGVectorStorage(BaseVectorStorage):
f"Proceeding with caution..."
)
# Create new table first
logger.info(f"PostgreSQL: Creating new table '{table_name}'")
await _pg_create_table(db, table_name, base_table, embedding_dim)
# Batch migration (500 records per batch, same as Qdrant)
migrated_count = 0
offset = 0
batch_size = 500 # Mirror Qdrant batch size
migrated_count = await _pg_migrate_workspace_data(
db,
legacy_table_name,
table_name,
workspace,
legacy_count,
embedding_dim,
)
while True:
# Fetch a batch of rows (with workspace filtering)
if workspace:
select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 OFFSET $2 LIMIT $3"
rows = await db.query(
select_query, [workspace, offset, batch_size], multirows=True
)
else:
select_query = (
f"SELECT * FROM {legacy_table_name} OFFSET $1 LIMIT $2"
)
rows = await db.query(
select_query, [offset, batch_size], multirows=True
)
if not rows:
break
# Insert batch into new table
for row in rows:
# Get column names and values as dictionary
row_dict = dict(row)
# Build insert query with positional parameters
columns = list(row_dict.keys())
columns_str = ", ".join(columns)
placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
insert_query = f"""
INSERT INTO {table_name} ({columns_str})
VALUES ({placeholders})
ON CONFLICT (workspace, id) DO NOTHING
"""
# Construct dict for execute() method
values = {col: row_dict[col] for col in columns}
await db.execute(insert_query, values)
migrated_count += len(rows)
logger.info(
f"PostgreSQL: {migrated_count}/{legacy_count} records migrated"
)
offset += batch_size
# Verify migration by comparing counts
logger.info("Verifying migration...")
logger.info("PostgreSQL: Verifying migration...")
new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
new_count_result = await db.query(new_count_query, [])
new_count = new_count_result.get("count", 0) if new_count_result else 0
@ -2504,15 +2609,10 @@ class PGVectorStorage(BaseVectorStorage):
f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' completed successfully"
)
# Create vector index after successful migration
await db._create_vector_index(table_name, embedding_dim)
# Clean up migrated data from legacy table
# CRITICAL: Only delete current workspace's data, not the entire table!
# Other workspaces may still have data in the legacy table.
try:
if workspace:
# Delete only current workspace's migrated data
logger.info(
f"PostgreSQL: Deleting migrated workspace '{workspace}' data from legacy table '{legacy_table_name}'..."
)
@ -2524,7 +2624,6 @@ class PGVectorStorage(BaseVectorStorage):
f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table"
)
# Check if legacy table still has data from other workspaces
remaining_query = (
f"SELECT COUNT(*) as count FROM {legacy_table_name}"
)
@ -2534,7 +2633,6 @@ class PGVectorStorage(BaseVectorStorage):
)
if remaining_count == 0:
# Table is now empty, safe to drop
logger.info(
f"PostgreSQL: Legacy table '{legacy_table_name}' is empty, deleting..."
)
@ -2544,12 +2642,10 @@ class PGVectorStorage(BaseVectorStorage):
f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully"
)
else:
# Table still has data from other workspaces, preserve it
logger.info(
f"PostgreSQL: Legacy table '{legacy_table_name}' preserved ({remaining_count} records from other workspaces remain)"
)
else:
# No workspace specified - delete entire table (legacy behavior for backward compatibility)
logger.warning(
f"PostgreSQL: No workspace specified, deleting entire legacy table '{legacy_table_name}'..."
)