Refactor test configuration to use pytest fixtures and CLI options
• Add pytest command-line options
• Create session-scoped fixtures
• Remove hardcoded environment vars
• Update test function signatures
• Improve configuration priority
(cherry picked from commit 1fe05df211)
This commit is contained in:
parent
97cf689dfb
commit
d011a1c0e7
2 changed files with 729 additions and 1111 deletions
|
|
@ -1,351 +1,85 @@
|
|||
"""
|
||||
Pytest configuration and fixtures for multi-tenant testing.
|
||||
Pytest configuration for LightRAG tests.
|
||||
|
||||
Provides:
|
||||
- Database fixtures for different testing modes
|
||||
- Tenant and KB context fixtures
|
||||
- Mock LLM and embedding services
|
||||
- Multi-tenant test utilities
|
||||
This file provides command-line options and fixtures for test configuration.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import asyncio
|
||||
import psycopg2
|
||||
import json
|
||||
from typing import Dict, List, Optional, Generator
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
import uuid
|
||||
|
||||
# ============================================================================
|
||||
# Environment and Mode Detection
|
||||
# ============================================================================
|
||||
|
||||
MULTITENANT_MODE = os.getenv("MULTITENANT_MODE", "demo")
|
||||
POSTGRES_HOST = os.getenv("POSTGRES_HOST", "localhost")
|
||||
POSTGRES_PORT = int(os.getenv("POSTGRES_PORT", "5432"))
|
||||
POSTGRES_USER = os.getenv("POSTGRES_USER", "lightrag")
|
||||
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "lightrag_secure_password")
|
||||
POSTGRES_DATABASE = os.getenv("POSTGRES_DATABASE", "lightrag_multitenant")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Database Connection Management
|
||||
# ============================================================================
|
||||
def pytest_addoption(parser):
|
||||
"""Add custom command-line options for LightRAG tests."""
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def db_connection_string():
|
||||
"""Generate PostgreSQL connection string."""
|
||||
return f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DATABASE}"
|
||||
parser.addoption(
|
||||
"--keep-artifacts",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Keep test artifacts (temporary directories and files) after test completion for inspection",
|
||||
)
|
||||
|
||||
parser.addoption(
|
||||
"--stress-test",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable stress test mode with more intensive workloads",
|
||||
)
|
||||
|
||||
parser.addoption(
|
||||
"--test-workers",
|
||||
action="store",
|
||||
default=3,
|
||||
type=int,
|
||||
help="Number of parallel workers for stress tests (default: 3)",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def postgres_connection():
|
||||
"""Create persistent PostgreSQL connection for session."""
|
||||
try:
|
||||
conn = psycopg2.connect(
|
||||
host=POSTGRES_HOST,
|
||||
port=POSTGRES_PORT,
|
||||
user=POSTGRES_USER,
|
||||
password=POSTGRES_PASSWORD,
|
||||
database=POSTGRES_DATABASE
|
||||
)
|
||||
conn.autocommit = False
|
||||
yield conn
|
||||
conn.close()
|
||||
except psycopg2.Error as e:
|
||||
pytest.skip(f"PostgreSQL not available: {e}")
|
||||
def keep_test_artifacts(request):
|
||||
"""
|
||||
Fixture to determine whether to keep test artifacts.
|
||||
|
||||
Priority: CLI option > Environment variable > Default (False)
|
||||
"""
|
||||
import os
|
||||
|
||||
@contextmanager
|
||||
def database_transaction(postgres_connection):
|
||||
"""Context manager for database transactions with rollback."""
|
||||
cursor = postgres_connection.cursor()
|
||||
try:
|
||||
yield cursor
|
||||
postgres_connection.commit()
|
||||
except Exception as e:
|
||||
postgres_connection.rollback()
|
||||
raise e
|
||||
finally:
|
||||
cursor.close()
|
||||
# Check CLI option first
|
||||
if request.config.getoption("--keep-artifacts"):
|
||||
return True
|
||||
|
||||
# Fall back to environment variable
|
||||
return os.getenv("LIGHTRAG_KEEP_ARTIFACTS", "false").lower() == "true"
|
||||
|
||||
# ============================================================================
|
||||
# Mode-Specific Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def testing_mode():
|
||||
"""Return current testing mode."""
|
||||
return MULTITENANT_MODE
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_compatibility_mode():
|
||||
"""Check if running in compatibility mode (MULTITENANT_MODE=off)."""
|
||||
return MULTITENANT_MODE == "off"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_single_tenant_mode():
|
||||
"""Check if running in single-tenant mode (MULTITENANT_MODE=on)."""
|
||||
return MULTITENANT_MODE == "on"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_demo_mode():
|
||||
"""Check if running in demo mode (MULTITENANT_MODE=demo)."""
|
||||
return MULTITENANT_MODE == "demo"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tenant and KB Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def demo_tenant_acme():
|
||||
"""Acme Corp tenant for demo mode."""
|
||||
return {
|
||||
"tenant_id": "acme-corp",
|
||||
"name": "Acme Corporation",
|
||||
"kbs": ["kb-prod", "kb-dev"]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def demo_tenant_techstart():
|
||||
"""TechStart tenant for demo mode."""
|
||||
return {
|
||||
"tenant_id": "techstart",
|
||||
"name": "TechStart Inc",
|
||||
"kbs": ["kb-main", "kb-backup"]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_tenant():
|
||||
"""Default tenant for compatibility and on modes."""
|
||||
return {
|
||||
"tenant_id": "default",
|
||||
"name": "Default Tenant",
|
||||
"kbs": ["default"]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_tenant_1():
|
||||
"""Test tenant 1 for single-tenant mode."""
|
||||
return {
|
||||
"tenant_id": "tenant-1",
|
||||
"name": "Test Tenant 1",
|
||||
"kbs": ["kb-default", "kb-secondary", "kb-experimental"]
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Data Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def sample_document():
|
||||
"""Sample document for testing."""
|
||||
return {
|
||||
"title": "Test Document",
|
||||
"content": "This is a test document for LightRAG multi-tenant testing.",
|
||||
"file_type": "text",
|
||||
"metadata": {
|
||||
"source": "test",
|
||||
"version": "1.0"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_entity():
|
||||
"""Sample entity for testing."""
|
||||
return {
|
||||
"name": "TestEntity",
|
||||
"type": "Person",
|
||||
"description": "A test entity for multi-tenant isolation testing",
|
||||
"metadata": {
|
||||
"test": True,
|
||||
"created_by": "pytest"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_relation():
|
||||
"""Sample relation for testing."""
|
||||
return {
|
||||
"source_entity": "Entity1",
|
||||
"target_entity": "Entity2",
|
||||
"relation_type": "knows",
|
||||
"description": "Test relationship between entities",
|
||||
"weight": 0.8
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Database Query Helpers
|
||||
# ============================================================================
|
||||
|
||||
class DatabaseHelper:
|
||||
"""Helper class for database operations in tests."""
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def execute_query(self, query: str, params: tuple = ()) -> List[Dict]:
|
||||
"""Execute a SELECT query and return results."""
|
||||
with database_transaction(self.connection) as cursor:
|
||||
cursor.execute(query, params)
|
||||
columns = [desc[0] for desc in cursor.description]
|
||||
return [dict(zip(columns, row)) for row in cursor.fetchall()]
|
||||
|
||||
def execute_insert(self, table: str, data: Dict) -> None:
|
||||
"""Insert a row into a table."""
|
||||
columns = ", ".join(data.keys())
|
||||
placeholders = ", ".join(["%s"] * len(data))
|
||||
query = f"INSERT INTO {table} ({columns}) VALUES ({placeholders})"
|
||||
with database_transaction(self.connection) as cursor:
|
||||
cursor.execute(query, tuple(data.values()))
|
||||
|
||||
def execute_delete(self, table: str, where: Dict) -> int:
|
||||
"""Delete rows from a table."""
|
||||
where_clause = " AND ".join([f"{k} = %s" for k in where.keys()])
|
||||
query = f"DELETE FROM {table} WHERE {where_clause}"
|
||||
with database_transaction(self.connection) as cursor:
|
||||
cursor.execute(query, tuple(where.values()))
|
||||
return cursor.rowcount
|
||||
|
||||
def count_documents(self, tenant_id: str, kb_id: str) -> int:
|
||||
"""Count documents for a tenant/KB."""
|
||||
query = "SELECT COUNT(*) as count FROM documents WHERE tenant_id = %s AND kb_id = %s"
|
||||
result = self.execute_query(query, (tenant_id, kb_id))
|
||||
return result[0]["count"] if result else 0
|
||||
|
||||
def count_entities(self, tenant_id: str, kb_id: str) -> int:
|
||||
"""Count entities for a tenant/KB."""
|
||||
query = "SELECT COUNT(*) as count FROM entities WHERE tenant_id = %s AND kb_id = %s"
|
||||
result = self.execute_query(query, (tenant_id, kb_id))
|
||||
return result[0]["count"] if result else 0
|
||||
|
||||
def get_all_documents(self, tenant_id: str, kb_id: str) -> List[Dict]:
|
||||
"""Get all documents for a tenant/KB."""
|
||||
query = "SELECT * FROM documents WHERE tenant_id = %s AND kb_id = %s ORDER BY created_at DESC"
|
||||
return self.execute_query(query, (tenant_id, kb_id))
|
||||
|
||||
def get_all_entities(self, tenant_id: str, kb_id: str) -> List[Dict]:
|
||||
"""Get all entities for a tenant/KB."""
|
||||
query = "SELECT * FROM entities WHERE tenant_id = %s AND kb_id = %s ORDER BY created_at DESC"
|
||||
return self.execute_query(query, (tenant_id, kb_id))
|
||||
|
||||
def verify_tenant_isolation(self, tenant_id: str) -> bool:
|
||||
"""Verify that no cross-tenant data exists when querying this tenant."""
|
||||
# Check that all documents belong to this tenant
|
||||
query = """
|
||||
SELECT COUNT(*) as count FROM documents
|
||||
WHERE tenant_id != %s AND EXISTS (
|
||||
SELECT 1 FROM documents d2
|
||||
WHERE d2.tenant_id = %s AND d2.id = documents.id
|
||||
)
|
||||
"""
|
||||
result = self.execute_query(query, (tenant_id, tenant_id))
|
||||
return result[0]["count"] == 0 if result else True
|
||||
|
||||
def clear_tenant_data(self, tenant_id: str, kb_id: Optional[str] = None) -> None:
|
||||
"""Clear all data for a tenant/KB."""
|
||||
tables = ["document_status", "embeddings", "documents", "entities", "relations"]
|
||||
|
||||
for table in tables:
|
||||
if kb_id:
|
||||
where = {"tenant_id": tenant_id, "kb_id": kb_id}
|
||||
else:
|
||||
where = {"tenant_id": tenant_id}
|
||||
self.execute_delete(table, where)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_helper(postgres_connection):
|
||||
"""Provide database helper for tests."""
|
||||
return DatabaseHelper(postgres_connection)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mock Services
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_service():
|
||||
"""Mock LLM service for testing."""
|
||||
mock = MagicMock()
|
||||
mock.generate = MagicMock(return_value="Mock LLM response")
|
||||
mock.extract_entities = MagicMock(return_value=["Entity1", "Entity2"])
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding_service():
|
||||
"""Mock embedding service for testing."""
|
||||
mock = MagicMock()
|
||||
mock.embed_text = MagicMock(return_value=[0.1] * 1024) # 1024-dim vector
|
||||
mock.embed_batch = MagicMock(return_value=[[0.1] * 1024 for _ in range(10)])
|
||||
return mock
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Async Event Loop
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create event loop for async tests."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
def stress_test_mode(request):
|
||||
"""
|
||||
Fixture to determine whether stress test mode is enabled.
|
||||
|
||||
Priority: CLI option > Environment variable > Default (False)
|
||||
"""
|
||||
import os
|
||||
|
||||
# Check CLI option first
|
||||
if request.config.getoption("--stress-test"):
|
||||
return True
|
||||
|
||||
# Fall back to environment variable
|
||||
return os.getenv("LIGHTRAG_STRESS_TEST", "false").lower() == "true"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Markers and Parametrization
|
||||
# ============================================================================
|
||||
@pytest.fixture(scope="session")
|
||||
def parallel_workers(request):
|
||||
"""
|
||||
Fixture to determine the number of parallel workers for stress tests.
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Register custom pytest markers."""
|
||||
config.addinivalue_line(
|
||||
"markers", "compatibility: mark test to run only in compatibility mode"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "single_tenant: mark test to run only in single-tenant mode"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "multi_tenant: mark test to run only in demo/multi-tenant mode"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "database: mark test that requires database connection"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "isolation: mark test that verifies data isolation"
|
||||
)
|
||||
Priority: CLI option > Environment variable > Default (3)
|
||||
"""
|
||||
import os
|
||||
|
||||
# Check CLI option first
|
||||
cli_workers = request.config.getoption("--test-workers")
|
||||
if cli_workers != 3: # Non-default value provided
|
||||
return cli_workers
|
||||
|
||||
# ============================================================================
|
||||
# Test Collection Hooks
|
||||
# ============================================================================
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
"""Skip tests based on testing mode."""
|
||||
skip_compatibility = pytest.mark.skip(reason="Not in compatibility mode")
|
||||
skip_single_tenant = pytest.mark.skip(reason="Not in single-tenant mode")
|
||||
skip_multi_tenant = pytest.mark.skip(reason="Not in multi-tenant mode")
|
||||
|
||||
for item in items:
|
||||
if "compatibility" in item.keywords and MULTITENANT_MODE != "off":
|
||||
item.add_marker(skip_compatibility)
|
||||
if "single_tenant" in item.keywords and MULTITENANT_MODE != "on":
|
||||
item.add_marker(skip_single_tenant)
|
||||
if "multi_tenant" in item.keywords and MULTITENANT_MODE != "demo":
|
||||
item.add_marker(skip_multi_tenant)
|
||||
# Fall back to environment variable
|
||||
return int(os.getenv("LIGHTRAG_TEST_WORKERS", "3"))
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue