fix: correct AsyncPG parameter passing in PostgreSQL migration to prevent data corruption

Why this change is needed:
The migration code at line 2351 was passing a dictionary (row_dict) as parameters
to a SQL query that used positional placeholders ($1, $2, etc.). AsyncPG strictly
requires positional parameters to be passed as a list/tuple of values in the exact
order matching the placeholders. Using a dictionary would cause parameter mismatches
and migration failures, potentially corrupting migrated data or causing the entire
migration to fail silently.

How it solves it:
- Extract values from row_dict in the exact order defined by the columns list
- Pass the values as separate positional arguments via *values unpacking
- Add clear comments explaining AsyncPG's requirements
- Update the comment from "named parameters" to "positional parameters" for accuracy
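The ordering logic above can be sketched in isolation. This is a minimal, hypothetical illustration (table and column names are invented, and `db.execute` stands in for an AsyncPG connection or pool): the key point is that the value list must be built in the same order as the `$1, $2, ...` placeholders, because AsyncPG does not accept a dict for positional parameters.

```python
def build_insert(table: str, row_dict: dict) -> tuple[str, list]:
    """Build an INSERT query with $1, $2, ... placeholders and a
    value list ordered to match them, as AsyncPG requires."""
    columns = list(row_dict.keys())
    columns_str = ", ".join(columns)
    placeholders = ", ".join(f"${i + 1}" for i in range(len(columns)))
    query = (
        f"INSERT INTO {table} ({columns_str}) "
        f"VALUES ({placeholders}) ON CONFLICT DO NOTHING"
    )
    # Values must line up positionally with the placeholders; passing
    # row_dict itself would not satisfy AsyncPG's positional API.
    values = [row_dict[col] for col in columns]
    return query, values

query, values = build_insert("new_table", {"id": 1, "content": "doc"})
# At the call site: await db.execute(query, *values)
```

Because `columns` is derived once from `row_dict` and reused for both the placeholders and the value extraction, the mapping stays consistent even if the dict's insertion order varies between rows.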

Impact:
- Migration now correctly maps values to SQL placeholders
- Prevents data corruption during legacy table migration
- Ensures reliable data transfer from old to new table schemas
- All PostgreSQL migration tests pass (6/6)

Testing:
- Verified with `uv run pytest tests/test_postgres_migration.py -v` - all tests pass
- Pre-commit hooks pass (ruff-format, ruff)
- Verified that the parameter-ordering logic matches AsyncPG's requirements
BukeLy 2025-11-20 01:59:34 +08:00
parent 7d0c356702
commit 982b63c9be

@@ -2335,10 +2335,10 @@ class PGVectorStorage(BaseVectorStorage):
             # Insert batch into new table
             for row in rows:
-                # Get column names and values as dictionary (execute expects dict)
+                # Get column names and values as dictionary
                 row_dict = dict(row)
-                # Build insert query with named parameters
+                # Build insert query with positional parameters
                 columns = list(row_dict.keys())
                 columns_str = ", ".join(columns)
                 placeholders = ", ".join([f"${i+1}" for i in range(len(columns))])
@@ -2348,7 +2348,9 @@ class PGVectorStorage(BaseVectorStorage):
                     ON CONFLICT DO NOTHING
                 """
-                await db.execute(insert_query, row_dict)
+                # AsyncPG requires positional parameters as a list in order
+                values = [row_dict[col] for col in columns]
+                await db.execute(insert_query, *values)
             migrated_count += len(rows)
             logger.info(