LightRAG/tests/test_postgres_migration_params.py
BukeLy 0508ad7a15 fix: prevent offline tests from failing due to missing E2E dependencies
Why this change is needed:
Offline tests were failing with "ModuleNotFoundError: No module named 'qdrant_client'"
because test_e2e_multi_instance.py was being imported during test collection, even
though it's an E2E test that shouldn't run in offline mode. Pytest imports every test
file during the collection phase regardless of marks, so missing E2E dependencies
(qdrant_client, asyncpg, etc.) caused import errors.

Additionally, the test mocks for PostgreSQL migration were too permissive: they
accepted any parameter format without validation, which allowed bugs (such as passing
a dict instead of positional args to AsyncPG execute()) to slip through undetected.

How it solves it:
1. E2E Import Fix:
   - Use pytest.importorskip() to conditionally import qdrant_client
   - E2E tests are now skipped cleanly when dependencies are missing
   - Offline tests can collect and run without E2E dependencies

2. Stricter Test Mocks:
   - Enhanced mock_pg_db fixture to validate AsyncPG parameter format
   - Mock execute() now raises TypeError if dict/list passed as single argument
   - Ensures tests catch parameter passing bugs that would fail in production

3. Parameter Validation Test:
   - Added test_postgres_migration_params.py for explicit parameter validation
   - Verifies migration passes positional args correctly to AsyncPG
   - Provides detailed output for debugging parameter issues

Impact:
- Offline tests no longer fail due to missing E2E dependencies
- Future bugs in AsyncPG parameter passing will be caught by tests
- Better test isolation between offline and E2E test suites
- Improved test coverage for migration parameter handling

Testing:
- Verified with `pytest tests/ -m offline -v` - no import errors
- All PostgreSQL migration tests pass (6/6 unit + 1 strict validation)
- Pre-commit hooks pass (ruff-format, ruff)
2025-11-20 02:03:48 +08:00

168 lines
5.7 KiB
Python

"""
Strict test to verify PostgreSQL migration parameter passing.
This test specifically validates that the migration code passes parameters
to AsyncPG execute() in the correct format (positional args, not dict).
"""
import pytest
from unittest.mock import patch, AsyncMock
from lightrag.utils import EmbeddingFunc
from lightrag.kg.postgres_impl import PGVectorStorage
from lightrag.namespace import NameSpace
@pytest.mark.asyncio
async def test_migration_parameter_passing():
    """
    Verify that migration passes positional parameters correctly to execute().

    This test specifically checks that execute() is called with:
    - SQL query as first argument
    - Values as separate positional arguments (*values)
    NOT as a dictionary or list

    All database access is mocked. The legacy table is reported as existing
    (two rows), and any table whose name contains "test_model_1536d" is
    reported as missing. This is expected to drive
    ``PGVectorStorage.initialize()`` down the migration path.
    """
    # Every execute() call is recorded here. The post-run assertions can then
    # inspect the exact argument shape even if the migration code catches the
    # TypeError raised by strict_execute and carries on.
    execute_calls = []

    async def strict_execute(sql, *args, **kwargs):
        """Record the call, then reject a lone dict/list used as the parameters."""
        execute_calls.append(
            {
                "sql": sql,
                "args": args,  # Should be tuple of values
                "kwargs": kwargs,
            }
        )
        # A single dict/list argument means the caller forgot to unpack its
        # values: execute(sql, values) instead of execute(sql, *values).
        if len(args) == 1 and isinstance(args[0], (dict, list)):
            raise TypeError(
                f"BUG DETECTED: execute() called with {type(args[0]).__name__} "
                "instead of positional parameters! "
                f"Got: execute(sql, {args[0]!r})"
            )
        return None

    # Mocked database client handed out by the patched ClientManager.
    mock_db = AsyncMock()
    mock_db.workspace = "test_workspace"
    mock_db.execute = AsyncMock(side_effect=strict_execute)

    # Rows the legacy table "contains"; served to the migration's SELECT.
    mock_rows = [
        {
            "id": "row1",
            "content": "content1",
            "workspace": "test",
            "vector": [0.1] * 1536,
        },
        {
            "id": "row2",
            "content": "content2",
            "workspace": "test",
            "vector": [0.2] * 1536,
        },
    ]

    async def mock_query(sql, params=None, multirows=False, **kwargs):
        # Row-count probe.
        if "COUNT(*)" in sql:
            return {"count": len(mock_rows)}
        # Bulk read of the legacy rows.
        if multirows and "SELECT *" in sql:
            return mock_rows
        return {}

    mock_db.query = AsyncMock(side_effect=mock_query)

    async def mock_table_exists(db, table_name):
        # Legacy table exists; the model-suffixed (new) table does not.
        return "test_model_1536d" not in table_name

    async def embed_func(texts, **kwargs):
        import numpy as np

        return np.array([[0.1] * 1536 for _ in texts])

    embedding_func = EmbeddingFunc(
        embedding_dim=1536, func=embed_func, model_name="test-model"
    )
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }
    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="test",
    )

    with (
        patch("lightrag.kg.postgres_impl.get_data_init_lock") as mock_lock,
        patch("lightrag.kg.postgres_impl.ClientManager") as mock_manager,
        patch(
            "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
        ),
        patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()),
    ):
        # AsyncMock supports `async with`, standing in for the init lock.
        mock_lock.return_value = AsyncMock()
        mock_manager.get_client = AsyncMock(return_value=mock_db)
        mock_manager.release_client = AsyncMock()

        # This should trigger migration
        await storage.initialize()

    # Verify execute was called (migration happened)
    assert len(execute_calls) > 0, "Migration should have called execute()"

    # Verify parameter format for INSERT statements
    insert_calls = [c for c in execute_calls if "INSERT INTO" in c["sql"]]
    assert len(insert_calls) > 0, "Should have INSERT statements from migration"

    print(f"\n✓ Migration executed {len(insert_calls)} INSERT statements")

    for i, call_info in enumerate(insert_calls, start=1):
        args = call_info["args"]
        sql = call_info["sql"]

        print(f"\n INSERT #{i}:")
        print(f" SQL: {sql[:100]}...")
        print(f" Args count: {len(args)}")
        print(f" Args types: {[type(arg).__name__ for arg in args]}")

        # Key validation: a lone dict/list means the values were not unpacked.
        # strict_execute already raises for this, but the migration code may
        # swallow that error, so re-check the recorded call here.
        if len(args) == 1 and isinstance(args[0], (dict, list)):
            pytest.fail(
                f"BUG: execute() called with {type(args[0]).__name__} instead of "
                f"positional parameters!\n"
                f" SQL: {sql}\n"
                f" Args: {args[0]}\n"
                f"Expected: execute(sql, val1, val2, val3, ...)\n"
                f"Got: execute(sql, {type(args[0]).__name__})"
            )

        # Exception: vector columns might be lists, that's OK. Any statement
        # mentioning "vector" is exempt from the per-parameter check.
        if "vector" in sql:
            continue

        # Every individual parameter must be a scalar, not a collection.
        for j, arg in enumerate(args):
            if isinstance(arg, (dict, list)):
                pytest.fail(
                    f"BUG: Parameter #{j} is {type(arg).__name__}, "
                    f"expected primitive type"
                )

    print(
        f"\n✅ All {len(insert_calls)} INSERT statements use correct parameter format"
    )
if __name__ == "__main__":
    # Propagate pytest's exit status. Otherwise running this file directly
    # would exit 0 even when the test fails.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))