This commit is contained in:
Bukely_ 2025-12-12 10:18:32 +08:00 committed by GitHub
commit 4e5351de63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 5500 additions and 162 deletions

190
.github/workflows/e2e-tests.yml vendored Normal file
View file

@ -0,0 +1,190 @@
name: E2E Tests (Real Databases)
on:
workflow_dispatch: # Allow manual trigger of E2E tests
pull_request:
branches: [ main, dev ]
paths:
- 'lightrag/kg/postgres_impl.py'
- 'lightrag/kg/qdrant_impl.py'
- 'tests/test_e2e_*.py'
jobs:
e2e-postgres:
name: E2E PostgreSQL Tests
runs-on: ubuntu-latest
services:
postgres:
image: ankane/pgvector:latest
env:
POSTGRES_USER: lightrag
POSTGRES_PASSWORD: lightrag_test_password
POSTGRES_DB: lightrag_test
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U lightrag"
--health-interval 10s
--health-timeout 5s
--health-retries 5
strategy:
matrix:
python-version: ['3.10', '3.12']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip packages
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-e2e-${{ hashFiles('**/pyproject.toml') }}
restore-keys: |
${{ runner.os }}-pip-e2e-
${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[api]"
pip install pytest pytest-asyncio asyncpg numpy qdrant-client
- name: Wait for PostgreSQL
run: |
timeout 30 bash -c 'until pg_isready -h localhost -p 5432 -U lightrag; do sleep 1; done'
- name: Setup pgvector extension
env:
PGPASSWORD: lightrag_test_password
run: |
psql -h localhost -U lightrag -d lightrag_test -c "CREATE EXTENSION IF NOT EXISTS vector;"
psql -h localhost -U lightrag -d lightrag_test -c "SELECT extname, extversion FROM pg_extension WHERE extname = 'vector';"
- name: Run PostgreSQL E2E tests
env:
POSTGRES_HOST: localhost
POSTGRES_PORT: 5432
POSTGRES_USER: lightrag
POSTGRES_PASSWORD: lightrag_test_password
POSTGRES_DATABASE: lightrag_test
run: |
pytest tests/test_e2e_multi_instance.py -k "postgres" -v --tb=short -s
timeout-minutes: 20
- name: Upload PostgreSQL test results
if: always()
uses: actions/upload-artifact@v4
with:
name: e2e-postgres-results-py${{ matrix.python-version }}
path: |
.pytest_cache/
test-results.xml
retention-days: 7
e2e-qdrant:
name: E2E Qdrant Tests
runs-on: ubuntu-latest
services:
qdrant:
image: qdrant/qdrant:latest
ports:
- 6333:6333
- 6334:6334
strategy:
matrix:
python-version: ['3.10', '3.12']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip packages
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-e2e-${{ hashFiles('**/pyproject.toml') }}
restore-keys: |
${{ runner.os }}-pip-e2e-
${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[api]"
pip install pytest pytest-asyncio qdrant-client numpy
- name: Wait for Qdrant
run: |
echo "Waiting for Qdrant to be ready..."
for i in {1..60}; do
if curl -s http://localhost:6333 > /dev/null 2>&1; then
echo "Qdrant is ready!"
break
fi
echo "Attempt $i/60: Qdrant not ready yet, waiting..."
sleep 1
done
# Final check
if ! curl -s http://localhost:6333 > /dev/null 2>&1; then
echo "ERROR: Qdrant failed to start after 60 seconds"
exit 1
fi
- name: Verify Qdrant connection
run: |
echo "Verifying Qdrant API..."
curl -X GET "http://localhost:6333/collections" -H "Content-Type: application/json"
echo ""
echo "Qdrant is accessible and ready for testing"
- name: Run Qdrant E2E tests
env:
QDRANT_URL: http://localhost:6333
QDRANT_API_KEY: ""
run: |
pytest tests/test_e2e_multi_instance.py -k "qdrant" -v --tb=short -s
timeout-minutes: 15
- name: Upload Qdrant test results
if: always()
uses: actions/upload-artifact@v4
with:
name: e2e-qdrant-results-py${{ matrix.python-version }}
path: |
.pytest_cache/
test-results.xml
retention-days: 7
e2e-summary:
name: E2E Test Summary
runs-on: ubuntu-latest
needs: [e2e-postgres, e2e-qdrant]
if: always()
steps:
- name: Check test results
run: |
echo "## E2E Test Summary" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "### PostgreSQL E2E Tests" >> $GITHUB_STEP_SUMMARY
echo "Status: ${{ needs.e2e-postgres.result }}" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "### Qdrant E2E Tests" >> $GITHUB_STEP_SUMMARY
echo "Status: ${{ needs.e2e-qdrant.result }}" >> $GITHUB_STEP_SUMMARY
- name: Fail if any test failed
if: needs.e2e-postgres.result != 'success' || needs.e2e-qdrant.result != 'success'
run: exit 1
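For local debugging of this job, a minimal connectivity check can mirror the service configuration above. The following sketch assumes asyncpg is installed and a pgvector-enabled PostgreSQL is listening on localhost:5432 with the credentials the workflow uses:

# Local sanity check (sketch) mirroring the workflow's Postgres service and test env vars.
import asyncio
import os

import asyncpg

async def check_postgres() -> None:
    conn = await asyncpg.connect(
        host=os.getenv("POSTGRES_HOST", "localhost"),
        port=int(os.getenv("POSTGRES_PORT", "5432")),
        user=os.getenv("POSTGRES_USER", "lightrag"),
        password=os.getenv("POSTGRES_PASSWORD", "lightrag_test_password"),
        database=os.getenv("POSTGRES_DATABASE", "lightrag_test"),
    )
    # Same check the workflow performs with psql: confirm the vector extension is installed
    version = await conn.fetchval(
        "SELECT extversion FROM pg_extension WHERE extname = 'vector'"
    )
    print(f"pgvector version: {version}")
    await conn.close()

if __name__ == "__main__":
    asyncio.run(check_postgres())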

74
.github/workflows/feature-tests.yml vendored Normal file
View file

@ -0,0 +1,74 @@
name: Feature Branch Tests
on:
workflow_dispatch: # Allow manual trigger
push:
branches:
- 'feature/**'
pull_request:
branches: [ main, dev ]
jobs:
migration-tests:
name: Vector Storage Migration Tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip packages
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt', '**/pyproject.toml') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[api]"
pip install pytest pytest-asyncio
- name: Run Qdrant migration tests
run: |
pytest tests/test_qdrant_migration.py -v --tb=short
continue-on-error: false
- name: Run PostgreSQL migration tests
run: |
pytest tests/test_postgres_migration.py -v --tb=short
continue-on-error: false
- name: Run all unit tests (if present)
run: |
# Run EmbeddingFunc tests
pytest tests/ -k "embedding" -v --tb=short || true
continue-on-error: true
- name: Upload test results
if: always()
uses: actions/upload-artifact@v4
with:
name: migration-test-results-py${{ matrix.python-version }}
path: |
.pytest_cache/
test-results.xml
retention-days: 7
- name: Test Summary
if: always()
run: |
echo "## Test Summary" >> $GITHUB_STEP_SUMMARY
echo "- Python: ${{ matrix.python-version }}" >> $GITHUB_STEP_SUMMARY
echo "- Branch: ${{ github.ref_name }}" >> $GITHUB_STEP_SUMMARY
echo "- Commit: ${{ github.sha }}" >> $GITHUB_STEP_SUMMARY

View file

@ -0,0 +1,271 @@
"""
Multi-Model Vector Storage Isolation Demo
This example demonstrates LightRAG's automatic model isolation feature for vector storage.
When using different embedding models, LightRAG automatically creates separate collections/tables,
preventing dimension mismatches and data pollution.
Key Features:
- Automatic model suffix generation: {model_name}_{dim}d
- Seamless migration from legacy (no-suffix) to new (with-suffix) collections
- Support for multiple workspaces with different embedding models
Requirements:
- OpenAI API key (or any OpenAI-compatible API)
- Qdrant or PostgreSQL for vector storage (optional, defaults to NanoVectorDB)
"""
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.utils import EmbeddingFunc
# Set your API key
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"
async def scenario_1_new_workspace_with_explicit_model():
"""
Scenario 1: Creating a new workspace with explicit model name
Result: Creates collection/table with name like:
- Qdrant: lightrag_vdb_chunks_text_embedding_3_large_3072d
- PostgreSQL: LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d
"""
print("\n" + "=" * 80)
print("Scenario 1: New Workspace with Explicit Model Name")
print("=" * 80)
# Define custom embedding function with explicit model name
async def my_embedding_func(texts: list[str]):
return await openai_embed(texts, model="text-embedding-3-large")
# Create EmbeddingFunc with model_name specified
embedding_func = EmbeddingFunc(
embedding_dim=3072,
func=my_embedding_func,
model_name="text-embedding-3-large", # Explicit model name
)
rag = LightRAG(
working_dir="./workspace_large_model",
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func,
)
await rag.initialize_storages()
# Insert sample data
await rag.ainsert("LightRAG supports automatic model isolation for vector storage.")
# Query
result = await rag.aquery(
"What does LightRAG support?", param=QueryParam(mode="hybrid")
)
print(f"\nQuery Result: {result[:200]}...")
print("\n✅ Collection/table created with suffix: text_embedding_3_large_3072d")
await rag.close()
async def scenario_2_legacy_migration():
"""
Scenario 2: Upgrading from legacy version (without model_name)
If you previously used LightRAG without specifying model_name,
the first run with model_name will automatically migrate your data.
Result: Data is migrated from:
- Old: lightrag_vdb_chunks (no suffix)
- New: lightrag_vdb_chunks_text_embedding_ada_002_1536d (with suffix)
"""
print("\n" + "=" * 80)
print("Scenario 2: Automatic Migration from Legacy Format")
print("=" * 80)
# Step 1: Simulate legacy workspace (no model_name)
print("\n[Step 1] Creating legacy workspace without model_name...")
async def legacy_embedding_func(texts: list[str]):
return await openai_embed(texts, model="text-embedding-ada-002")
# Legacy: No model_name specified
legacy_embedding = EmbeddingFunc(
embedding_dim=1536,
func=legacy_embedding_func,
# model_name not specified → uses "unknown" as fallback
)
rag_legacy = LightRAG(
working_dir="./workspace_legacy",
llm_model_func=gpt_4o_mini_complete,
embedding_func=legacy_embedding,
)
await rag_legacy.initialize_storages()
await rag_legacy.ainsert("Legacy data without model isolation.")
await rag_legacy.close()
print("✅ Legacy workspace created with suffix: unknown_1536d")
# Step 2: Upgrade to new version with model_name
print("\n[Step 2] Upgrading to new version with explicit model_name...")
# New: With model_name specified
new_embedding = EmbeddingFunc(
embedding_dim=1536,
func=legacy_embedding_func,
model_name="text-embedding-ada-002", # Now explicitly specified
)
rag_new = LightRAG(
working_dir="./workspace_legacy", # Same working directory
llm_model_func=gpt_4o_mini_complete,
embedding_func=new_embedding,
)
# On first initialization, LightRAG will:
# 1. Detect legacy collection exists
# 2. Automatically migrate data to new collection with model suffix
# 3. Legacy collection remains but can be deleted after verification
await rag_new.initialize_storages()
# Verify data is still accessible
result = await rag_new.aquery(
"What is the legacy data?", param=QueryParam(mode="hybrid")
)
print(f"\nQuery Result: {result[:200] if result else 'No results'}...")
print("\n✅ Data migrated to: text_embedding_ada_002_1536d")
print(" Legacy collection can be manually deleted after verification")
await rag_new.close()
async def scenario_3_multiple_models_coexistence():
"""
Scenario 3: Multiple workspaces with different embedding models
Different embedding models create completely isolated collections/tables,
allowing safe coexistence without dimension conflicts or data pollution.
Result:
- Workspace A: lightrag_vdb_chunks_text_embedding_3_small_1536d
- Workspace B: lightrag_vdb_chunks_text_embedding_3_large_3072d
"""
print("\n" + "=" * 80)
print("Scenario 3: Multiple Models Coexistence")
print("=" * 80)
# Workspace A: Small embedding model (1536 dimensions)
print("\n[Workspace A] Using text-embedding-3-small model (1536d)...")
async def embedding_func_small(texts: list[str]):
# Simulate small embedding model
# In real usage, replace with actual model call
return await openai_embed(texts, model="text-embedding-3-small")
embedding_a = EmbeddingFunc(
embedding_dim=1536, # text-embedding-3-small dimension
func=embedding_func_small,
model_name="text-embedding-3-small",
)
rag_a = LightRAG(
working_dir="./workspace_a",
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_a,
)
await rag_a.initialize_storages()
await rag_a.ainsert("Workspace A uses small embedding model for efficiency.")
print("✅ Workspace A created with suffix: text_embedding_3_small_1536d")
# Workspace B: Large embedding model (3072 dimensions)
print("\n[Workspace B] Using text-embedding-3-large model (3072d)...")
async def embedding_func_large(texts: list[str]):
# Simulate large embedding model
return await openai_embed(texts, model="text-embedding-3-large")
embedding_b = EmbeddingFunc(
embedding_dim=3072, # text-embedding-3-large dimension
func=embedding_func_large,
model_name="text-embedding-3-large",
)
rag_b = LightRAG(
working_dir="./workspace_b",
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_b,
)
await rag_b.initialize_storages()
await rag_b.ainsert("Workspace B uses large embedding model for better accuracy.")
print("✅ Workspace B created with suffix: text_embedding_3_large_3072d")
# Verify isolation: Query each workspace
print("\n[Verification] Querying both workspaces...")
result_a = await rag_a.aquery(
"What model does workspace use?", param=QueryParam(mode="hybrid")
)
result_b = await rag_b.aquery(
"What model does workspace use?", param=QueryParam(mode="hybrid")
)
print(f"\nWorkspace A Result: {result_a[:100] if result_a else 'No results'}...")
print(f"Workspace B Result: {result_b[:100] if result_b else 'No results'}...")
print("\n✅ Both workspaces operate independently without interference")
await rag_a.close()
await rag_b.close()
async def main():
"""
Run all scenarios to demonstrate model isolation features
"""
print("\n" + "=" * 80)
print("LightRAG Multi-Model Vector Storage Isolation Demo")
print("=" * 80)
print("\nThis demo shows how LightRAG automatically handles:")
print("1. ✅ Automatic model suffix generation")
print("2. ✅ Seamless data migration from legacy format")
print("3. ✅ Multiple embedding models coexistence")
try:
# Scenario 1: New workspace with explicit model
await scenario_1_new_workspace_with_explicit_model()
# Scenario 2: Legacy migration
await scenario_2_legacy_migration()
# Scenario 3: Multiple models coexistence
await scenario_3_multiple_models_coexistence()
print("\n" + "=" * 80)
print("✅ All scenarios completed successfully!")
print("=" * 80)
print("\n📝 Key Takeaways:")
print("- Always specify `model_name` in EmbeddingFunc for clear model tracking")
print("- LightRAG automatically migrates legacy data on first run")
print("- Different embedding models create isolated collections/tables")
print("- Collection names follow pattern: {base_name}_{model_name}_{dim}d")
print("\n📚 See the plan document for more details:")
print(" .claude/plan/PR-vector-model-isolation.md")
except Exception as e:
print(f"\n❌ Error: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
asyncio.run(main())

lightrag/base.py
View file

@ -220,6 +220,45 @@ class BaseVectorStorage(StorageNameSpace, ABC):
cosine_better_than_threshold: float = field(default=0.2)
meta_fields: set[str] = field(default_factory=set)
def _generate_collection_suffix(self) -> str:
"""Generates collection/table suffix from embedding_func.
Returns:
str: Suffix string, e.g. "text_embedding_3_large_3072d"
"""
# Try to get model identifier from the embedding function
# If it's a wrapped function (doesn't have get_model_identifier),
# fallback to the original embedding_func from global_config
if hasattr(self.embedding_func, "get_model_identifier"):
return self.embedding_func.get_model_identifier()
elif "embedding_func" in self.global_config:
original_embedding_func = self.global_config["embedding_func"]
if original_embedding_func is not None and hasattr(
original_embedding_func, "get_model_identifier"
):
return original_embedding_func.get_model_identifier()
else:
# Debug: log why we couldn't get model identifier
from lightrag.utils import logger
logger.debug(
f"Could not get model_identifier: embedding_func is {type(original_embedding_func)}, has method={hasattr(original_embedding_func, 'get_model_identifier') if original_embedding_func else False}"
)
# Fallback: no model identifier available
return ""
def _get_legacy_collection_name(self) -> str:
"""Get legacy collection/table name (without suffix).
Used for data migration detection.
"""
raise NotImplementedError("Subclasses must implement this method")
def _get_new_collection_name(self) -> str:
"""Get new collection/table name (with suffix)."""
raise NotImplementedError("Subclasses must implement this method")
@abstractmethod
async def query(
self, query: str, top_k: int, query_embedding: list[float] = None
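To make the naming contract concrete, a small illustrative helper (not part of the diff) shows how a backend combines the suffix from _generate_collection_suffix() with its base name, following the lightrag_vdb_* convention the Qdrant backend uses later in this commit (PostgreSQL does the analogous thing with its upper-case table names):

# Illustrative only: derive legacy vs. model-suffixed names the way the storage backends do.
def derive_collection_names(namespace: str, suffix: str) -> tuple[str, str]:
    legacy = f"lightrag_vdb_{namespace}"              # pre-isolation name (no model suffix)
    new = f"{legacy}_{suffix}" if suffix else legacy  # fall back to legacy when suffix is empty
    return legacy, new

# derive_collection_names("chunks", "text_embedding_3_large_3072d")
# -> ("lightrag_vdb_chunks", "lightrag_vdb_chunks_text_embedding_3_large_3072d")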

lightrag/kg/postgres_impl.py
View file

@ -1163,23 +1163,9 @@ class PostgreSQLDB:
except Exception as e:
logger.error(f"PostgreSQL, Failed to batch check/create indexes: {e}")
# Create vector indexs
if self.vector_index_type:
logger.info(
f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}"
)
try:
if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]:
await self._create_vector_indexes()
else:
logger.warning(
"Doesn't support this vector index type: {self.vector_index_type}. "
"Supported types: HNSW, IVFFLAT, VCHORDRQ"
)
except Exception as e:
logger.error(
f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}"
)
# NOTE: Vector index creation moved to PGVectorStorage.setup_table()
# Each vector storage instance creates its own index with correct embedding_dim
# After all tables are created, attempt to migrate timestamp fields
try:
await self._migrate_timestamp_columns()
@ -1381,64 +1367,72 @@ class PostgreSQLDB:
except Exception as e:
logger.warning(f"Failed to create index {index['name']}: {e}")
async def _create_vector_indexes(self):
vdb_tables = [
"LIGHTRAG_VDB_CHUNKS",
"LIGHTRAG_VDB_ENTITY",
"LIGHTRAG_VDB_RELATION",
]
async def _create_vector_index(self, table_name: str, embedding_dim: int):
"""
Create vector index for a specific table.
Args:
table_name: Name of the table to create index on
embedding_dim: Embedding dimension for the vector column
"""
if not self.vector_index_type:
return
create_sql = {
"HNSW": f"""
CREATE INDEX {{vector_index_name}}
ON {{k}} USING hnsw (content_vector vector_cosine_ops)
ON {{table_name}} USING hnsw (content_vector vector_cosine_ops)
WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
""",
"IVFFLAT": f"""
CREATE INDEX {{vector_index_name}}
ON {{k}} USING ivfflat (content_vector vector_cosine_ops)
ON {{table_name}} USING ivfflat (content_vector vector_cosine_ops)
WITH (lists = {self.ivfflat_lists})
""",
"VCHORDRQ": f"""
CREATE INDEX {{vector_index_name}}
ON {{k}} USING vchordrq (content_vector vector_cosine_ops)
{f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''}
ON {{table_name}} USING vchordrq (content_vector vector_cosine_ops)
{f"WITH (options = $${self.vchordrq_build_options}$$)" if self.vchordrq_build_options else ""}
""",
}
embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
for k in vdb_tables:
vector_index_name = (
f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine"
if self.vector_index_type not in create_sql:
logger.warning(
f"Unsupported vector index type: {self.vector_index_type}. "
"Supported types: HNSW, IVFFLAT, VCHORDRQ"
)
check_vector_index_sql = f"""
SELECT 1 FROM pg_indexes
WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
"""
try:
vector_index_exists = await self.query(check_vector_index_sql)
if not vector_index_exists:
# Only set vector dimension when index doesn't exist
alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
await self.execute(alter_sql)
logger.debug(f"Ensured vector dimension for {k}")
logger.info(
f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
return
k = table_name
vector_index_name = f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine"
check_vector_index_sql = f"""
SELECT 1 FROM pg_indexes
WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
"""
try:
vector_index_exists = await self.query(check_vector_index_sql)
if not vector_index_exists:
# Only set vector dimension when index doesn't exist
alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
await self.execute(alter_sql)
logger.debug(f"Ensured vector dimension for {k}")
logger.info(
f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
)
await self.execute(
create_sql[self.vector_index_type].format(
vector_index_name=vector_index_name, table_name=k
)
await self.execute(
create_sql[self.vector_index_type].format(
vector_index_name=vector_index_name, k=k
)
)
logger.info(
f"Successfully created vector index {vector_index_name} on table {k}"
)
else:
logger.info(
f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
)
except Exception as e:
logger.error(f"Failed to create vector index on table {k}, Got: {e}")
)
logger.info(
f"Successfully created vector index {vector_index_name} on table {k}"
)
else:
logger.info(
f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
)
except Exception as e:
logger.error(f"Failed to create vector index on table {k}, Got: {e}")
async def query(
self,
@ -2175,6 +2169,90 @@ class PGKVStorage(BaseKVStorage):
return {"status": "error", "message": str(e)}
async def _pg_table_exists(db: PostgreSQLDB, table_name: str) -> bool:
"""Check if a table exists in PostgreSQL database"""
query = """
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = $1
)
"""
result = await db.query(query, [table_name.lower()])
return result.get("exists", False) if result else False
async def _pg_create_table(
db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
) -> None:
"""Create a new vector table by replacing the table name in DDL template"""
if base_table not in TABLES:
raise ValueError(f"No DDL template found for table: {base_table}")
ddl_template = TABLES[base_table]["ddl"]
# Replace embedding dimension placeholder if exists
ddl = ddl_template.replace(
f"VECTOR({os.environ.get('EMBEDDING_DIM', 1024)})", f"VECTOR({embedding_dim})"
)
# Replace table name
ddl = ddl.replace(base_table, table_name)
await db.execute(ddl)
async def _pg_migrate_workspace_data(
db: PostgreSQLDB,
legacy_table_name: str,
new_table_name: str,
workspace: str,
expected_count: int,
embedding_dim: int,
) -> int:
"""Migrate workspace data from legacy table to new table"""
migrated_count = 0
offset = 0
batch_size = 500
while True:
if workspace:
select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 OFFSET $2 LIMIT $3"
rows = await db.query(
select_query, [workspace, offset, batch_size], multirows=True
)
else:
select_query = f"SELECT * FROM {legacy_table_name} OFFSET $1 LIMIT $2"
rows = await db.query(select_query, [offset, batch_size], multirows=True)
if not rows:
break
for row in rows:
row_dict = dict(row)
columns = list(row_dict.keys())
columns_str = ", ".join(columns)
placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
insert_query = f"""
INSERT INTO {new_table_name} ({columns_str})
VALUES ({placeholders})
ON CONFLICT (workspace, id) DO NOTHING
"""
# Rebuild dict in columns order to ensure values() matches placeholders order
# Python 3.7+ dicts maintain insertion order, and execute() uses tuple(data.values())
values = {col: row_dict[col] for col in columns}
await db.execute(insert_query, values)
migrated_count += len(rows)
workspace_info = f" for workspace '{workspace}'" if workspace else ""
logger.info(
f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}"
)
offset += batch_size
return migrated_count
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
@ -2190,6 +2268,412 @@ class PGVectorStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
# Generate model suffix for table isolation
self.model_suffix = self._generate_collection_suffix()
# Get base table name
base_table = namespace_to_table_name(self.namespace)
if not base_table:
raise ValueError(f"Unknown namespace: {self.namespace}")
# New table name (with suffix)
# Ensure model_suffix is not empty before appending
if self.model_suffix:
self.table_name = f"{base_table}_{self.model_suffix}"
else:
# Fallback: use base table name if model_suffix is unavailable
self.table_name = base_table
logger.warning(
f"Model suffix unavailable, using base table name '{base_table}'. "
f"Ensure embedding_func has model_name for proper model isolation."
)
# Legacy table name (without suffix, for migration)
self.legacy_table_name = base_table
logger.debug(
f"PostgreSQL table naming: "
f"new='{self.table_name}', "
f"legacy='{self.legacy_table_name}', "
f"model_suffix='{self.model_suffix}'"
)
@staticmethod
async def setup_table(
db: PostgreSQLDB,
table_name: str,
legacy_table_name: str = None,
base_table: str = None,
embedding_dim: int = None,
workspace: str = None,
):
"""
Setup PostgreSQL table with migration support from legacy tables.
This method mirrors Qdrant's setup_collection approach to maintain consistency.
Args:
db: PostgreSQLDB instance
table_name: Name of the new table
legacy_table_name: Name of the legacy table (if exists)
base_table: Base table name for DDL template lookup
embedding_dim: Embedding dimension for the vector column
workspace: Workspace identifier; migration only moves rows belonging to this workspace
"""
new_table_exists = await _pg_table_exists(db, table_name)
legacy_exists = legacy_table_name and await _pg_table_exists(
db, legacy_table_name
)
# Case 1: Both new and legacy tables exist
if new_table_exists and legacy_exists:
if table_name.lower() == legacy_table_name.lower():
logger.debug(
f"PostgreSQL: Table '{table_name}' already exists (no model suffix). Skipping Case 1 cleanup."
)
return
try:
workspace_info = f" for workspace '{workspace}'" if workspace else ""
if workspace:
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
count_result = await db.query(count_query, [workspace])
else:
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
count_result = await db.query(count_query, [])
workspace_count = count_result.get("count", 0) if count_result else 0
if workspace_count > 0:
logger.info(
f"PostgreSQL: Found {workspace_count} records in legacy table{workspace_info}. Migrating..."
)
legacy_dim = None
try:
dim_query = """
SELECT
CASE
WHEN typname = 'vector' THEN
COALESCE(atttypmod, -1)
ELSE -1
END as vector_dim
FROM pg_attribute a
JOIN pg_type t ON a.atttypid = t.oid
WHERE a.attrelid = $1::regclass
AND a.attname = 'content_vector'
"""
dim_result = await db.query(dim_query, [legacy_table_name])
legacy_dim = (
dim_result.get("vector_dim", -1) if dim_result else -1
)
if legacy_dim <= 0:
sample_query = f"SELECT content_vector FROM {legacy_table_name} LIMIT 1"
sample_result = await db.query(sample_query, [])
if sample_result and sample_result.get("content_vector"):
vector_data = sample_result["content_vector"]
if isinstance(vector_data, (list, tuple)):
legacy_dim = len(vector_data)
elif isinstance(vector_data, str):
import json
vector_list = json.loads(vector_data)
legacy_dim = len(vector_list)
if (
legacy_dim > 0
and embedding_dim
and legacy_dim != embedding_dim
):
logger.warning(
f"PostgreSQL: Dimension mismatch - "
f"legacy table has {legacy_dim}d vectors, "
f"new embedding model expects {embedding_dim}d. "
f"Skipping migration{workspace_info}."
)
await db._create_vector_index(table_name, embedding_dim)
return
except Exception as e:
logger.warning(
f"PostgreSQL: Could not verify vector dimension: {e}. Proceeding with caution..."
)
migrated_count = await _pg_migrate_workspace_data(
db,
legacy_table_name,
table_name,
workspace,
workspace_count,
embedding_dim,
)
if workspace:
new_count_query = f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
new_count_result = await db.query(new_count_query, [workspace])
else:
new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
new_count_result = await db.query(new_count_query, [])
new_count = (
new_count_result.get("count", 0) if new_count_result else 0
)
if new_count < workspace_count:
logger.warning(
f"PostgreSQL: Expected {workspace_count} records, found {new_count}{workspace_info}. "
f"Some records may have been skipped due to conflicts."
)
else:
logger.info(
f"PostgreSQL: Migration completed: {migrated_count} records migrated{workspace_info}"
)
if workspace:
delete_query = (
f"DELETE FROM {legacy_table_name} WHERE workspace = $1"
)
await db.execute(delete_query, {"workspace": workspace})
logger.info(
f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table"
)
total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
total_count_result = await db.query(total_count_query, [])
total_count = (
total_count_result.get("count", 0) if total_count_result else 0
)
if total_count == 0:
logger.info(
f"PostgreSQL: Legacy table '{legacy_table_name}' is empty. Deleting..."
)
drop_query = f"DROP TABLE {legacy_table_name}"
await db.execute(drop_query, None)
logger.info(
f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully"
)
else:
logger.info(
f"PostgreSQL: Legacy table '{legacy_table_name}' preserved "
f"({total_count} records from other workspaces remain)"
)
except Exception as e:
logger.warning(
f"PostgreSQL: Error during Case 1 migration: {e}. Vector index will still be ensured."
)
await db._create_vector_index(table_name, embedding_dim)
return
# Case 2: Only new table exists - Already migrated or newly created
if new_table_exists:
logger.debug(f"PostgreSQL: Table '{table_name}' already exists")
# Ensure vector index exists with correct embedding dimension
await db._create_vector_index(table_name, embedding_dim)
return
# Case 3: Neither exists - Create new table
if not legacy_exists:
logger.info(f"PostgreSQL: Creating new table '{table_name}'")
await _pg_create_table(db, table_name, base_table, embedding_dim)
logger.info(f"PostgreSQL: Table '{table_name}' created successfully")
# Create vector index with correct embedding dimension
await db._create_vector_index(table_name, embedding_dim)
return
# Case 4: Only legacy exists - Migrate data
logger.info(
f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}'"
)
try:
# Get legacy table count (with workspace filtering)
if workspace:
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
count_result = await db.query(count_query, [workspace])
else:
count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
count_result = await db.query(count_query, [])
logger.warning(
"PostgreSQL: Migration without workspace filter - this may copy data from all workspaces!"
)
legacy_count = count_result.get("count", 0) if count_result else 0
workspace_info = f" for workspace '{workspace}'" if workspace else ""
logger.info(
f"PostgreSQL: Found {legacy_count} records in legacy table{workspace_info}"
)
if legacy_count == 0:
logger.info("PostgreSQL: Legacy table is empty, skipping migration")
await _pg_create_table(db, table_name, base_table, embedding_dim)
# Create vector index with correct embedding dimension
await db._create_vector_index(table_name, embedding_dim)
return
# Check vector dimension compatibility before migration
legacy_dim = None
try:
# Try to get vector dimension from pg_attribute metadata
dim_query = """
SELECT
CASE
WHEN typname = 'vector' THEN
COALESCE(atttypmod, -1)
ELSE -1
END as vector_dim
FROM pg_attribute a
JOIN pg_type t ON a.atttypid = t.oid
WHERE a.attrelid = $1::regclass
AND a.attname = 'content_vector'
"""
dim_result = await db.query(dim_query, [legacy_table_name])
legacy_dim = dim_result.get("vector_dim", -1) if dim_result else -1
if legacy_dim <= 0:
# Alternative: Try to detect by sampling a vector
logger.info(
"PostgreSQL: Metadata dimension check failed, trying vector sampling..."
)
sample_query = (
f"SELECT content_vector FROM {legacy_table_name} LIMIT 1"
)
sample_result = await db.query(sample_query, [])
if sample_result and sample_result.get("content_vector"):
vector_data = sample_result["content_vector"]
# pgvector returns list directly
if isinstance(vector_data, (list, tuple)):
legacy_dim = len(vector_data)
elif isinstance(vector_data, str):
import json
vector_list = json.loads(vector_data)
legacy_dim = len(vector_list)
if legacy_dim > 0 and embedding_dim and legacy_dim != embedding_dim:
logger.warning(
f"PostgreSQL: Dimension mismatch detected! "
f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, "
f"but new embedding model expects {embedding_dim}d. "
f"Migration skipped to prevent data loss. "
f"Legacy table preserved as '{legacy_table_name}'. "
f"Creating new empty table '{table_name}' for new data."
)
# Create new table but skip migration
await _pg_create_table(db, table_name, base_table, embedding_dim)
await db._create_vector_index(table_name, embedding_dim)
logger.info(
f"PostgreSQL: New table '{table_name}' created. "
f"To query legacy data, please use a {legacy_dim}d embedding model."
)
return
except Exception as e:
logger.warning(
f"PostgreSQL: Could not verify legacy table vector dimension: {e}. "
f"Proceeding with caution..."
)
logger.info(f"PostgreSQL: Creating new table '{table_name}'")
await _pg_create_table(db, table_name, base_table, embedding_dim)
migrated_count = await _pg_migrate_workspace_data(
db,
legacy_table_name,
table_name,
workspace,
legacy_count,
embedding_dim,
)
logger.info("PostgreSQL: Verifying migration...")
new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
new_count_result = await db.query(new_count_query, [])
new_count = new_count_result.get("count", 0) if new_count_result else 0
if new_count != legacy_count:
error_msg = (
f"PostgreSQL: Migration verification failed, "
f"expected {legacy_count} records, got {new_count} in new table"
)
logger.error(error_msg)
raise PostgreSQLMigrationError(error_msg)
logger.info(
f"PostgreSQL: Migration completed successfully: {migrated_count} records migrated"
)
logger.info(
f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' completed successfully"
)
await db._create_vector_index(table_name, embedding_dim)
try:
if workspace:
logger.info(
f"PostgreSQL: Deleting migrated workspace '{workspace}' data from legacy table '{legacy_table_name}'..."
)
delete_query = (
f"DELETE FROM {legacy_table_name} WHERE workspace = $1"
)
await db.execute(delete_query, {"workspace": workspace})
logger.info(
f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table"
)
remaining_query = (
f"SELECT COUNT(*) as count FROM {legacy_table_name}"
)
remaining_result = await db.query(remaining_query, [])
remaining_count = (
remaining_result.get("count", 0) if remaining_result else 0
)
if remaining_count == 0:
logger.info(
f"PostgreSQL: Legacy table '{legacy_table_name}' is empty, deleting..."
)
drop_query = f"DROP TABLE {legacy_table_name}"
await db.execute(drop_query, None)
logger.info(
f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully"
)
else:
logger.info(
f"PostgreSQL: Legacy table '{legacy_table_name}' preserved ({remaining_count} records from other workspaces remain)"
)
else:
logger.warning(
f"PostgreSQL: No workspace specified, deleting entire legacy table '{legacy_table_name}'..."
)
drop_query = f"DROP TABLE {legacy_table_name}"
await db.execute(drop_query, None)
logger.info(
f"PostgreSQL: Legacy table '{legacy_table_name}' deleted"
)
except Exception as delete_error:
# If cleanup fails, log warning but don't fail migration
logger.warning(
f"PostgreSQL: Failed to clean up legacy table '{legacy_table_name}': {delete_error}. "
"Migration succeeded, but manual cleanup may be needed."
)
except PostgreSQLMigrationError:
# Re-raise migration errors without wrapping
raise
except Exception as e:
error_msg = f"PostgreSQL: Migration failed with error: {e}"
logger.error(error_msg)
# Mirror Qdrant behavior: no automatic rollback
# Reason: partial data can be continued by re-running migration
raise PostgreSQLMigrationError(error_msg) from e
async def initialize(self):
async with get_data_init_lock():
if self.db is None:
@ -2206,6 +2690,16 @@ class PGVectorStorage(BaseVectorStorage):
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
# Setup table (create if not exists and handle migration)
await PGVectorStorage.setup_table(
self.db,
self.table_name,
legacy_table_name=self.legacy_table_name,
base_table=self.legacy_table_name, # base_table for DDL template lookup
embedding_dim=self.embedding_func.embedding_dim,
workspace=self.workspace, # CRITICAL: Filter migration by workspace
)
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
@ -2215,7 +2709,9 @@ class PGVectorStorage(BaseVectorStorage):
self, item: dict[str, Any], current_time: datetime.datetime
) -> tuple[str, dict[str, Any]]:
try:
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
upsert_sql = SQL_TEMPLATES["upsert_chunk"].format(
table_name=self.table_name
)
data: dict[str, Any] = {
"workspace": self.workspace,
"id": item["__id__"],
@ -2239,7 +2735,7 @@ class PGVectorStorage(BaseVectorStorage):
def _upsert_entities(
self, item: dict[str, Any], current_time: datetime.datetime
) -> tuple[str, dict[str, Any]]:
upsert_sql = SQL_TEMPLATES["upsert_entity"]
upsert_sql = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name)
source_id = item["source_id"]
if isinstance(source_id, str) and "<SEP>" in source_id:
chunk_ids = source_id.split("<SEP>")
@ -2262,7 +2758,9 @@ class PGVectorStorage(BaseVectorStorage):
def _upsert_relationships(
self, item: dict[str, Any], current_time: datetime.datetime
) -> tuple[str, dict[str, Any]]:
upsert_sql = SQL_TEMPLATES["upsert_relationship"]
upsert_sql = SQL_TEMPLATES["upsert_relationship"].format(
table_name=self.table_name
)
source_id = item["source_id"]
if isinstance(source_id, str) and "<SEP>" in source_id:
chunk_ids = source_id.split("<SEP>")
@ -2335,7 +2833,9 @@ class PGVectorStorage(BaseVectorStorage):
embedding_string = ",".join(map(str, embedding))
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
sql = SQL_TEMPLATES[self.namespace].format(
embedding_string=embedding_string, table_name=self.table_name
)
params = {
"workspace": self.workspace,
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
@ -2357,14 +2857,9 @@ class PGVectorStorage(BaseVectorStorage):
if not ids:
return
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(
f"[{self.workspace}] Unknown namespace for vector deletion: {self.namespace}"
)
return
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
delete_sql = (
f"DELETE FROM {self.table_name} WHERE workspace=$1 AND id = ANY($2)"
)
try:
await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
@ -2383,8 +2878,8 @@ class PGVectorStorage(BaseVectorStorage):
entity_name: The name of the entity to delete
"""
try:
# Construct SQL to delete the entity
delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
# Construct SQL to delete the entity using dynamic table name
delete_sql = f"""DELETE FROM {self.table_name}
WHERE workspace=$1 AND entity_name=$2"""
await self.db.execute(
@ -2404,7 +2899,7 @@ class PGVectorStorage(BaseVectorStorage):
"""
try:
# Delete relations where the entity is either the source or target
delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
delete_sql = f"""DELETE FROM {self.table_name}
WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
await self.db.execute(
@ -2427,14 +2922,7 @@ class PGVectorStorage(BaseVectorStorage):
Returns:
The vector data if found, or None if not found
"""
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(
f"[{self.workspace}] Unknown namespace for ID lookup: {self.namespace}"
)
return None
query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id=$2"
query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {self.table_name} WHERE workspace=$1 AND id=$2"
params = {"workspace": self.workspace, "id": id}
try:
@ -2460,15 +2948,8 @@ class PGVectorStorage(BaseVectorStorage):
if not ids:
return []
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(
f"[{self.workspace}] Unknown namespace for IDs lookup: {self.namespace}"
)
return []
ids_str = ",".join([f"'{id}'" for id in ids])
query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {self.table_name} WHERE workspace=$1 AND id IN ({ids_str})"
params = {"workspace": self.workspace}
try:
@ -2509,15 +2990,8 @@ class PGVectorStorage(BaseVectorStorage):
if not ids:
return {}
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(
f"[{self.workspace}] Unknown namespace for vector lookup: {self.namespace}"
)
return {}
ids_str = ",".join([f"'{id}'" for id in ids])
query = f"SELECT id, content_vector FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
query = f"SELECT id, content_vector FROM {self.table_name} WHERE workspace=$1 AND id IN ({ids_str})"
params = {"workspace": self.workspace}
try:
@ -2546,15 +3020,8 @@ class PGVectorStorage(BaseVectorStorage):
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
table_name=self.table_name
)
await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"}
@ -2593,6 +3060,9 @@ class PGDocStatusStorage(DocStatusStorage):
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
# NOTE: Table creation is handled by PostgreSQLDB.initdb() during initialization
# No need to create table here as it's already created in the TABLES dict
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
@ -3188,6 +3658,12 @@ class PGDocStatusStorage(DocStatusStorage):
return {"status": "error", "message": str(e)}
class PostgreSQLMigrationError(Exception):
"""Exception for PostgreSQL table migration errors."""
pass
class PGGraphQueryException(Exception):
"""Exception for the AGE queries."""
@ -5047,7 +5523,7 @@ SQL_TEMPLATES = {
update_time = EXCLUDED.update_time
""",
# SQL for VectorStorage
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
"upsert_chunk": """INSERT INTO {table_name} (workspace, id, tokens,
chunk_order_index, full_doc_id, content, content_vector, file_path,
create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
@ -5060,7 +5536,7 @@ SQL_TEMPLATES = {
file_path=EXCLUDED.file_path,
update_time = EXCLUDED.update_time
""",
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
"upsert_entity": """INSERT INTO {table_name} (workspace, id, entity_name, content,
content_vector, chunk_ids, file_path, create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
ON CONFLICT (workspace,id) DO UPDATE
@ -5071,7 +5547,7 @@ SQL_TEMPLATES = {
file_path=EXCLUDED.file_path,
update_time=EXCLUDED.update_time
""",
"upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
"upsert_relationship": """INSERT INTO {table_name} (workspace, id, source_id,
target_id, content, content_vector, chunk_ids, file_path, create_time, update_time)
VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10)
ON CONFLICT (workspace,id) DO UPDATE
@ -5087,7 +5563,7 @@ SQL_TEMPLATES = {
SELECT r.source_id AS src_id,
r.target_id AS tgt_id,
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
FROM LIGHTRAG_VDB_RELATION r
FROM {table_name} r
WHERE r.workspace = $1
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
@ -5096,7 +5572,7 @@ SQL_TEMPLATES = {
"entities": """
SELECT e.entity_name,
EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
FROM LIGHTRAG_VDB_ENTITY e
FROM {table_name} e
WHERE e.workspace = $1
AND e.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
@ -5107,7 +5583,7 @@ SQL_TEMPLATES = {
c.content,
c.file_path,
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
FROM LIGHTRAG_VDB_CHUNKS c
FROM {table_name} c
WHERE c.workspace = $1
AND c.content_vector <=> '[{embedding_string}]'::vector < $2
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
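As a usage sketch (the suffixed table name below is hypothetical), the templates above receive the per-model table name via str.format, just as the PGVectorStorage call sites in this diff do:

# Illustrative: inject a per-model table name and a toy embedding into the "entities" template.
from lightrag.kg.postgres_impl import SQL_TEMPLATES

embedding_string = ",".join(map(str, [0.01, 0.02, 0.03]))  # toy vector
sql = SQL_TEMPLATES["entities"].format(
    embedding_string=embedding_string,
    table_name="LIGHTRAG_VDB_ENTITY_text_embedding_3_large_3072d",  # hypothetical suffixed table
)
# The rendered SQL still carries $1/$2 placeholders (workspace, distance threshold),
# which are bound at execution time by PostgreSQLDB.query().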

lightrag/kg/qdrant_impl.py
View file

@ -66,6 +66,48 @@ def workspace_filter_condition(workspace: str) -> models.FieldCondition:
)
def _find_legacy_collection(
client: QdrantClient, namespace: str, workspace: str = None
) -> str | None:
"""
Find legacy collection with backward compatibility support.
This function tries multiple naming patterns to locate legacy collections
created by older versions of LightRAG:
1. {workspace}_{namespace} - Old format with workspace (pre-model-isolation) - HIGHEST PRIORITY
2. lightrag_vdb_{namespace} - Current legacy format
3. {namespace} - Old format without workspace (pre-model-isolation)
Args:
client: QdrantClient instance
namespace: Base namespace (e.g., "chunks", "entities")
workspace: Optional workspace identifier
Returns:
Collection name if found, None otherwise
"""
# Try multiple naming patterns for backward compatibility
# More specific names (with workspace) have higher priority
candidates = [
f"{workspace}_{namespace}"
if workspace
else None, # Old format with workspace - most specific
f"lightrag_vdb_{namespace}", # New legacy format
namespace, # Old format without workspace - most generic
]
for candidate in candidates:
if candidate and client.collection_exists(candidate):
logger.info(
f"Qdrant: Found legacy collection '{candidate}' "
f"(namespace={namespace}, workspace={workspace or 'none'})"
)
return candidate
return None
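For concreteness, the probe order described in the docstring can be reproduced with plain string formatting (the workspace and namespace values are illustrative):

# Illustrative candidate order for namespace="chunks", workspace="space1".
workspace, namespace = "space1", "chunks"
candidates = [
    f"{workspace}_{namespace}" if workspace else None,  # "space1_chunks"       - old per-workspace name
    f"lightrag_vdb_{namespace}",                        # "lightrag_vdb_chunks" - current legacy name
    namespace,                                          # "chunks"              - oldest, workspace-less name
]
print([c for c in candidates if c])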
@final
@dataclass
class QdrantVectorDBStorage(BaseVectorStorage):
@ -85,28 +127,73 @@ class QdrantVectorDBStorage(BaseVectorStorage):
def setup_collection(
client: QdrantClient,
collection_name: str,
legacy_namespace: str = None,
namespace: str = None,
workspace: str = None,
**kwargs,
):
"""
Setup Qdrant collection with migration support from legacy collections.
This method now supports backward compatibility by automatically detecting
legacy collections created by older versions of LightRAG using multiple
naming patterns.
Args:
client: QdrantClient instance
collection_name: Name of the new collection
legacy_namespace: Name of the legacy collection (if exists)
namespace: Base namespace (e.g., "chunks", "entities")
workspace: Workspace identifier for data isolation
**kwargs: Additional arguments for collection creation (vectors_config, hnsw_config, etc.)
"""
new_collection_exists = client.collection_exists(collection_name)
legacy_exists = legacy_namespace and client.collection_exists(legacy_namespace)
# Case 1: Both new and legacy collections exist - Warning only (no migration)
# Try to find legacy collection with backward compatibility
legacy_collection = (
_find_legacy_collection(client, namespace, workspace) if namespace else None
)
legacy_exists = legacy_collection is not None
# Case 1: Both new and legacy collections exist
# This can happen if:
# 1. Previous migration failed to delete the legacy collection
# 2. User manually created both collections
# 3. No model suffix (collection_name == legacy_collection)
# Strategy: Only delete legacy if it's empty (safe cleanup) and it's not the same as new collection
if new_collection_exists and legacy_exists:
logger.warning(
f"Qdrant: Legacy collection '{legacy_namespace}' still exist. Remove it if migration is complete."
)
# CRITICAL: Check if new and legacy are the same collection
# This happens when model_suffix is empty (no model_name provided)
if collection_name == legacy_collection:
logger.debug(
f"Qdrant: Collection '{collection_name}' already exists (no model suffix). Skipping Case 1 cleanup."
)
return
try:
# Check if legacy collection is empty
legacy_count = client.count(
collection_name=legacy_collection, exact=True
).count
if legacy_count == 0:
# Legacy collection is empty, safe to delete without data loss
logger.info(
f"Qdrant: Legacy collection '{legacy_collection}' is empty. Deleting..."
)
client.delete_collection(collection_name=legacy_collection)
logger.info(
f"Qdrant: Legacy collection '{legacy_collection}' deleted successfully"
)
else:
# Legacy collection still has data - don't risk deleting it
logger.warning(
f"Qdrant: Legacy collection '{legacy_collection}' still contains {legacy_count} records. "
f"Manual intervention required to verify and delete."
)
except Exception as e:
logger.warning(
f"Qdrant: Could not check or cleanup legacy collection '{legacy_collection}': {e}. "
"You may need to delete it manually."
)
return
# Case 2: Only new collection exists - Ensure index exists
@ -149,13 +236,13 @@ class QdrantVectorDBStorage(BaseVectorStorage):
# Case 4: Only legacy exists - Migrate data
logger.info(
f"Qdrant: Migrating data from legacy collection '{legacy_namespace}'"
f"Qdrant: Migrating data from legacy collection '{legacy_collection}'"
)
try:
# Get legacy collection count
legacy_count = client.count(
collection_name=legacy_namespace, exact=True
collection_name=legacy_collection, exact=True
).count
logger.info(f"Qdrant: Found {legacy_count} records in legacy collection")
@ -173,6 +260,51 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
return
# Check vector dimension compatibility before migration
try:
legacy_info = client.get_collection(legacy_collection)
legacy_dim = legacy_info.config.params.vectors.size
# Get expected dimension from kwargs
new_dim = (
kwargs.get("vectors_config").size
if "vectors_config" in kwargs
else None
)
if new_dim and legacy_dim != new_dim:
logger.warning(
f"Qdrant: Dimension mismatch detected! "
f"Legacy collection '{legacy_collection}' has {legacy_dim}d vectors, "
f"but new embedding model expects {new_dim}d. "
f"Migration skipped to prevent data loss. "
f"Legacy collection preserved as '{legacy_collection}'. "
f"Creating new empty collection '{collection_name}' for new data."
)
# Create new collection but skip migration
client.create_collection(collection_name, **kwargs)
client.create_payload_index(
collection_name=collection_name,
field_name=WORKSPACE_ID_FIELD,
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=True,
),
)
logger.info(
f"Qdrant: New collection '{collection_name}' created. "
f"To query legacy data, please use a {legacy_dim}d embedding model."
)
return
except Exception as e:
logger.warning(
f"Qdrant: Could not verify legacy collection dimension: {e}. "
f"Proceeding with caution..."
)
# Create new collection first
logger.info(f"Qdrant: Creating new collection '{collection_name}'")
client.create_collection(collection_name, **kwargs)
@ -185,7 +317,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
while True:
# Scroll through legacy data
result = client.scroll(
collection_name=legacy_namespace,
collection_name=legacy_collection,
limit=batch_size,
offset=offset,
with_vectors=True,
@ -258,9 +390,27 @@ class QdrantVectorDBStorage(BaseVectorStorage):
),
)
logger.info(
f"Qdrant: Migration from '{legacy_namespace}' to '{collection_name}' completed successfully"
f"Qdrant: Migration from '{legacy_collection}' to '{collection_name}' completed successfully"
)
# Delete legacy collection after successful migration
# Data has been verified to match, so legacy collection is no longer needed
# and keeping it would cause Case 1 warnings on next startup
try:
logger.info(
f"Qdrant: Deleting legacy collection '{legacy_collection}'..."
)
client.delete_collection(collection_name=legacy_collection)
logger.info(
f"Qdrant: Legacy collection '{legacy_collection}' deleted successfully"
)
except Exception as delete_error:
# If deletion fails, user will see Case 1 warning on next startup
logger.warning(
f"Qdrant: Failed to delete legacy collection '{legacy_collection}': {delete_error}. "
"You may need to delete it manually."
)
except QdrantMigrationError:
# Re-raise migration errors without wrapping
raise
@ -287,19 +437,34 @@ class QdrantVectorDBStorage(BaseVectorStorage):
f"Using passed workspace parameter: '{effective_workspace}'"
)
# Get legacy namespace for data migration from old version
if effective_workspace:
self.legacy_namespace = f"{effective_workspace}_{self.namespace}"
else:
self.legacy_namespace = self.namespace
self.effective_workspace = effective_workspace or DEFAULT_WORKSPACE
# Use a shared collection with payload-based partitioning (Qdrant's recommended approach)
# Ref: https://qdrant.tech/documentation/guides/multiple-partitions/
self.final_namespace = f"lightrag_vdb_{self.namespace}"
logger.debug(
f"Using shared collection '{self.final_namespace}' with workspace '{self.effective_workspace}' for payload-based partitioning"
# Generate model suffix
model_suffix = self._generate_collection_suffix()
# Legacy collection name (without model suffix, for migration)
# This matches the old naming scheme before model isolation was implemented
# Example: "lightrag_vdb_chunks" (without model suffix)
self.legacy_namespace = f"lightrag_vdb_{self.namespace}"
# New naming scheme with model isolation
# Example: "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
# Ensure model_suffix is not empty before appending
if model_suffix:
self.final_namespace = f"lightrag_vdb_{self.namespace}_{model_suffix}"
else:
# Fallback: use legacy namespace if model_suffix is unavailable
self.final_namespace = self.legacy_namespace
logger.warning(
f"Model suffix unavailable, using legacy collection name '{self.legacy_namespace}'. "
f"Ensure embedding_func has model_name for proper model isolation."
)
logger.info(
f"Qdrant collection naming: "
f"new='{self.final_namespace}', "
f"legacy='{self.legacy_namespace}', "
f"model_suffix='{model_suffix}'"
)
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@ -315,6 +480,12 @@ class QdrantVectorDBStorage(BaseVectorStorage):
self._max_batch_size = self.global_config["embedding_batch_num"]
self._initialized = False
def _get_legacy_collection_name(self) -> str:
return self.legacy_namespace
def _get_new_collection_name(self) -> str:
return self.final_namespace
async def initialize(self):
"""Initialize Qdrant collection"""
async with get_data_init_lock():
@ -338,11 +509,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
# Setup collection (create if not exists and configure indexes)
# Pass legacy_namespace and workspace for migration support
# Pass namespace and workspace for backward-compatible migration support
QdrantVectorDBStorage.setup_collection(
self._client,
self.final_namespace,
legacy_namespace=self.legacy_namespace,
namespace=self.namespace,
workspace=self.effective_workspace,
vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim,
@ -354,6 +525,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
),
)
# Initialize max batch size from config
self._max_batch_size = self.global_config["embedding_batch_num"]
self._initialized = True
logger.info(
f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully"

lightrag/kg/shared_storage.py
View file

@ -164,16 +164,29 @@ class UnifiedLock(Generic[T]):
)
# Then acquire the main lock
if self._is_async:
await self._lock.acquire()
else:
self._lock.acquire()
if self._lock is not None:
if self._is_async:
await self._lock.acquire()
else:
self._lock.acquire()
direct_log(
f"== Lock == Process {self._pid}: Acquired lock {self._name} (async={self._is_async})",
level="INFO",
enable_output=self._enable_logging,
)
direct_log(
f"== Lock == Process {self._pid}: Acquired lock {self._name} (async={self._is_async})",
level="INFO",
enable_output=self._enable_logging,
)
else:
# CRITICAL: Raise exception instead of allowing unprotected execution
error_msg = (
f"CRITICAL: Lock '{self._name}' is None - shared data not initialized. "
f"Call initialize_share_data() before using locks!"
)
direct_log(
f"== Lock == Process {self._pid}: {error_msg}",
level="ERROR",
enable_output=True,
)
raise RuntimeError(error_msg)
return self
except Exception as e:
# If main lock acquisition fails, release the async lock if it was acquired
@ -193,19 +206,21 @@ class UnifiedLock(Generic[T]):
async def __aexit__(self, exc_type, exc_val, exc_tb):
main_lock_released = False
async_lock_released = False
try:
# Release main lock first
if self._is_async:
self._lock.release()
else:
self._lock.release()
main_lock_released = True
if self._lock is not None:
if self._is_async:
self._lock.release()
else:
self._lock.release()
direct_log(
f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})",
level="INFO",
enable_output=self._enable_logging,
)
direct_log(
f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})",
level="INFO",
enable_output=self._enable_logging,
)
main_lock_released = True
# Then release async lock if in multiprocess mode
if not self._is_async and self._async_lock is not None:
@ -215,6 +230,7 @@ class UnifiedLock(Generic[T]):
level="DEBUG",
enable_output=self._enable_logging,
)
async_lock_released = True
except Exception as e:
direct_log(
@ -223,9 +239,10 @@ class UnifiedLock(Generic[T]):
enable_output=True,
)
# If main lock release failed but async lock hasn't been released, try to release it
# If main lock release failed but async lock hasn't been attempted yet, try to release it
if (
not main_lock_released
and not async_lock_released
and not self._is_async
and self._async_lock is not None
):

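The practical effect of the new None check: using a UnifiedLock before shared storage is initialized now fails fast instead of silently running the critical section unprotected. A minimal sketch, mirroring the new test_unified_lock_safety tests at the end of this commit:

import asyncio
from lightrag.kg.shared_storage import UnifiedLock

async def demo():
    # The backing lock is None because initialize_share_data() was never called
    lock = UnifiedLock(lock=None, is_async=True, name="demo", enable_logging=False)
    try:
        async with lock:
            pass  # never reached
    except RuntimeError as exc:
        print(exc)  # "... Lock 'demo' is None - shared data not initialized ..."

asyncio.run(demo())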
View file

@ -518,14 +518,10 @@ class LightRAG:
f"max_total_tokens({self.summary_max_tokens}) should be greater than summary_length_recommended({self.summary_length_recommended})"
)
# Fix global_config now
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init Embedding
# Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
# Step 1: Capture embedding_func and max_token_size before applying decorator
# (decorator strips dataclass attributes, and asdict() converts EmbeddingFunc to dict)
original_embedding_func = self.embedding_func
embedding_max_token_size = None
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
embedding_max_token_size = self.embedding_func.max_token_size
@ -534,6 +530,14 @@ class LightRAG:
)
self.embedding_token_limit = embedding_max_token_size
# Fix global_config now
global_config = asdict(self)
# Restore original EmbeddingFunc object (asdict converts it to dict)
global_config["embedding_func"] = original_embedding_func
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Step 2: Apply priority wrapper decorator
self.embedding_func = priority_limit_async_func_call(
self.embedding_func_max_async,

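The restore step above is needed because dataclasses.asdict() recurses into nested dataclasses, so EmbeddingFunc (itself a dataclass) comes back as a plain dict inside global_config. A minimal illustration using simplified stand-in dataclasses (the names here are hypothetical, not the real LightRAG fields):

from dataclasses import dataclass, asdict

@dataclass
class Embedder:  # simplified stand-in for EmbeddingFunc
    embedding_dim: int

@dataclass
class Config:  # simplified stand-in for the LightRAG dataclass
    embedding_func: Embedder

cfg = Config(embedding_func=Embedder(embedding_dim=768))
snapshot = asdict(cfg)
print(type(snapshot["embedding_func"]))           # <class 'dict'>, not Embedder
snapshot["embedding_func"] = cfg.embedding_func   # restore the real object, as done above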
View file

@ -425,6 +425,19 @@ class EmbeddingFunc:
send_dimensions: bool = (
False # Control whether to send embedding_dim to the function
)
model_name: str | None = None
def get_model_identifier(self) -> str:
"""Generates model identifier for collection/table suffix.
Returns:
str: Format "{model_name}_{dim}d", e.g. "text_embedding_3_large_3072d"
If model_name is not specified, returns "unknown_{dim}d"
"""
model_part = self.model_name if self.model_name else "unknown"
# Sanitize model name: lowercase it and replace any non-alphanumeric character with "_"
safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_part.lower())
return f"{safe_model_name}_{self.embedding_dim}d"
async def __call__(self, *args, **kwargs) -> np.ndarray:
# Only inject embedding_dim when send_dimensions is True

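As a quick usage reference (the new unit tests below assert the same behaviour), slashes, dashes, and uppercase letters in model_name are all normalised to lowercase words joined by underscores:

from lightrag.utils import EmbeddingFunc

func = EmbeddingFunc(
    embedding_dim=1024,
    func=lambda texts, **kwargs: None,  # placeholder callable
    model_name="models/text-embedding-004",
)
func.get_model_identifier()  # -> "models_text_embedding_004_1024d"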
View file

@ -0,0 +1,55 @@
import pytest
from lightrag.base import BaseVectorStorage
from lightrag.utils import EmbeddingFunc
def test_base_vector_storage_integrity():
# Just checking if we can import and inspect the class
assert hasattr(BaseVectorStorage, "_generate_collection_suffix")
assert hasattr(BaseVectorStorage, "_get_legacy_collection_name")
assert hasattr(BaseVectorStorage, "_get_new_collection_name")
# Verify methods raise NotImplementedError
class ConcreteStorage(BaseVectorStorage):
async def query(self, *args, **kwargs):
pass
async def upsert(self, *args, **kwargs):
pass
async def delete_entity(self, *args, **kwargs):
pass
async def delete_entity_relation(self, *args, **kwargs):
pass
async def get_by_id(self, *args, **kwargs):
pass
async def get_by_ids(self, *args, **kwargs):
pass
async def delete(self, *args, **kwargs):
pass
async def get_vectors_by_ids(self, *args, **kwargs):
pass
async def index_done_callback(self):
pass
async def drop(self):
pass
func = EmbeddingFunc(embedding_dim=128, func=lambda x: x)
storage = ConcreteStorage(
namespace="test", workspace="test", global_config={}, embedding_func=func
)
assert storage._generate_collection_suffix() == "unknown_128d"
with pytest.raises(NotImplementedError):
storage._get_legacy_collection_name()
with pytest.raises(NotImplementedError):
storage._get_new_collection_name()

View file

@ -0,0 +1,316 @@
"""
Tests for dimension mismatch handling during migration.
This test module verifies that both PostgreSQL and Qdrant storage backends
properly detect and handle vector dimension mismatches when migrating from
legacy collections/tables to new ones with different embedding models.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
from lightrag.kg.postgres_impl import PGVectorStorage
# Note: Tests should use proper table names that have DDL templates
# Valid base tables: LIGHTRAG_VDB_CHUNKS, LIGHTRAG_VDB_ENTITIES, LIGHTRAG_VDB_RELATIONSHIPS,
# LIGHTRAG_DOC_CHUNKS, LIGHTRAG_DOC_FULL_DOCS, LIGHTRAG_DOC_TEXT_CHUNKS
class TestQdrantDimensionMismatch:
"""Test suite for Qdrant dimension mismatch handling."""
def test_qdrant_dimension_mismatch_skip_migration(self):
"""
Test that Qdrant skips migration when dimensions don't match.
Scenario: Legacy collection has 1536d vectors, new model expects 3072d.
Expected: Migration skipped, new empty collection created, legacy preserved.
"""
from qdrant_client import models
# Setup mock client
client = MagicMock()
# Mock legacy collection with 1536d vectors
legacy_collection_info = MagicMock()
legacy_collection_info.config.params.vectors.size = 1536
# Setup collection existence checks
def collection_exists_side_effect(name):
if name == "lightrag_chunks": # legacy
return True
elif name == "lightrag_chunks_model_3072d": # new
return False
return False
client.collection_exists.side_effect = collection_exists_side_effect
client.get_collection.return_value = legacy_collection_info
client.count.return_value.count = 100 # Legacy has data
# Call setup_collection with 3072d (different from legacy 1536d)
QdrantVectorDBStorage.setup_collection(
client,
"lightrag_chunks_model_3072d",
namespace="chunks",
workspace="test",
vectors_config=models.VectorParams(
size=3072, distance=models.Distance.COSINE
),
)
# Verify new collection was created
client.create_collection.assert_called_once()
# Verify migration was NOT attempted (no scroll/upsert calls)
client.scroll.assert_not_called()
client.upsert.assert_not_called()
def test_qdrant_dimension_match_proceed_migration(self):
"""
Test that Qdrant proceeds with migration when dimensions match.
Scenario: Legacy collection has 1536d vectors, new model also expects 1536d.
Expected: Migration proceeds normally.
"""
from qdrant_client import models
client = MagicMock()
# Mock legacy collection with 1536d vectors (matching new)
legacy_collection_info = MagicMock()
legacy_collection_info.config.params.vectors.size = 1536
def collection_exists_side_effect(name):
if name == "lightrag_chunks": # legacy
return True
elif name == "lightrag_chunks_model_1536d": # new
return False
return False
client.collection_exists.side_effect = collection_exists_side_effect
client.get_collection.return_value = legacy_collection_info
client.count.return_value.count = 100 # Legacy has data
# Mock scroll to return sample data
sample_point = MagicMock()
sample_point.id = "test_id"
sample_point.vector = [0.1] * 1536
sample_point.payload = {"id": "test"}
client.scroll.return_value = ([sample_point], None)
# Mock _find_legacy_collection to return the legacy collection name
with patch(
"lightrag.kg.qdrant_impl._find_legacy_collection",
return_value="lightrag_chunks",
):
# Call setup_collection with matching 1536d
QdrantVectorDBStorage.setup_collection(
client,
"lightrag_chunks_model_1536d",
namespace="chunks",
workspace="test",
vectors_config=models.VectorParams(
size=1536, distance=models.Distance.COSINE
),
)
# Verify migration WAS attempted
client.create_collection.assert_called_once()
client.scroll.assert_called()
client.upsert.assert_called()
class TestPostgresDimensionMismatch:
"""Test suite for PostgreSQL dimension mismatch handling."""
@pytest.mark.asyncio
async def test_postgres_dimension_mismatch_skip_migration_metadata(self):
"""
Test that PostgreSQL skips migration when dimensions don't match (via metadata).
Scenario: Legacy table has 1536d vectors (detected via pg_attribute),
new model expects 3072d.
Expected: Migration skipped, new empty table created, legacy preserved.
"""
# Setup mock database
db = AsyncMock()
# Mock table existence and dimension checks
async def query_side_effect(query, params, **kwargs):
if "information_schema.tables" in query:
if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
return {"exists": True}
elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
return {"exists": False}
elif "COUNT(*)" in query:
return {"count": 100} # Legacy has data
elif "pg_attribute" in query:
return {"vector_dim": 1536} # Legacy has 1536d vectors
return {}
db.query.side_effect = query_side_effect
db.execute = AsyncMock()
db._create_vector_index = AsyncMock()
# Call setup_table with 3072d (different from legacy 1536d)
await PGVectorStorage.setup_table(
db,
"LIGHTRAG_DOC_CHUNKS_model_3072d",
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
base_table="LIGHTRAG_DOC_CHUNKS",
embedding_dim=3072,
workspace="test",
)
# Verify migration was NOT attempted (no INSERT calls)
# Note: _pg_create_table is mocked, so we check INSERT calls to verify migration was skipped
insert_calls = [
call
for call in db.execute.call_args_list
if call[0][0] and "INSERT INTO" in call[0][0]
]
assert (
len(insert_calls) == 0
), "Migration should be skipped due to dimension mismatch"
@pytest.mark.asyncio
async def test_postgres_dimension_mismatch_skip_migration_sampling(self):
"""
Test that PostgreSQL skips migration when dimensions don't match (via sampling).
Scenario: Legacy table dimension detection fails via metadata,
falls back to vector sampling, detects 1536d vs expected 3072d.
Expected: Migration skipped, new empty table created, legacy preserved.
"""
db = AsyncMock()
# Mock table existence and dimension checks
async def query_side_effect(query, params, **kwargs):
if "information_schema.tables" in query:
if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
return {"exists": True}
elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
return {"exists": False}
elif "COUNT(*)" in query:
return {"count": 100} # Legacy has data
elif "pg_attribute" in query:
return {"vector_dim": -1} # Metadata check fails
elif "SELECT content_vector FROM" in query:
# Return sample vector with 1536 dimensions
return {"content_vector": [0.1] * 1536}
return {}
db.query.side_effect = query_side_effect
db.execute = AsyncMock()
db._create_vector_index = AsyncMock()
# Call setup_table with 3072d (different from legacy 1536d)
await PGVectorStorage.setup_table(
db,
"LIGHTRAG_DOC_CHUNKS_model_3072d",
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
base_table="LIGHTRAG_DOC_CHUNKS",
embedding_dim=3072,
workspace="test",
)
# Verify new table was created
create_table_calls = [
call
for call in db.execute.call_args_list
if call[0][0] and "CREATE TABLE" in call[0][0]
]
assert len(create_table_calls) > 0, "New table should be created"
# Verify migration was NOT attempted
insert_calls = [
call
for call in db.execute.call_args_list
if call[0][0] and "INSERT INTO" in call[0][0]
]
assert len(insert_calls) == 0, "Migration should be skipped"
@pytest.mark.asyncio
async def test_postgres_dimension_match_proceed_migration(self):
"""
Test that PostgreSQL proceeds with migration when dimensions match.
Scenario: Legacy table has 1536d vectors, new model also expects 1536d.
Expected: Migration proceeds normally.
"""
db = AsyncMock()
async def query_side_effect(query, params, **kwargs):
multirows = kwargs.get("multirows", False)
if "information_schema.tables" in query:
if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
return {"exists": True}
elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new
return {"exists": False}
elif "COUNT(*)" in query:
return {"count": 100} # Legacy has data
elif "pg_attribute" in query:
return {"vector_dim": 1536} # Legacy has matching 1536d
elif "SELECT * FROM" in query and multirows:
# Return sample data for migration (first batch)
# Handle workspace filtering: params = [workspace, offset, limit]
if "WHERE workspace" in query:
offset = params[1] if len(params) > 1 else 0
else:
offset = params[0] if params else 0
if offset == 0: # First batch
return [
{
"id": "test1",
"content_vector": [0.1] * 1536,
"workspace": "test",
},
{
"id": "test2",
"content_vector": [0.2] * 1536,
"workspace": "test",
},
]
else: # offset > 0
return [] # No more data
return {}
db.query.side_effect = query_side_effect
db.execute = AsyncMock()
db._create_vector_index = AsyncMock()
# Mock _pg_table_exists
async def mock_table_exists(db_inst, name):
if name == "LIGHTRAG_DOC_CHUNKS": # legacy exists
return True
elif name == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new doesn't exist
return False
return False
with patch(
"lightrag.kg.postgres_impl._pg_table_exists",
side_effect=mock_table_exists,
):
# Call setup_table with matching 1536d
await PGVectorStorage.setup_table(
db,
"LIGHTRAG_DOC_CHUNKS_model_1536d",
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
base_table="LIGHTRAG_DOC_CHUNKS",
embedding_dim=1536,
workspace="test",
)
# Verify migration WAS attempted (INSERT calls made)
insert_calls = [
call
for call in db.execute.call_args_list
if call[0][0] and "INSERT INTO" in call[0][0]
]
assert (
len(insert_calls) > 0
), "Migration should proceed with matching dimensions"

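For orientation, the dimension that the mocks above stand in for comes from the collection info returned by qdrant-client. A rough sketch of that check, assuming a single unnamed vector config on the legacy collection (host, port, collection name, and the expected dimension are placeholders):

from qdrant_client import QdrantClient

client = QdrantClient(host="localhost", port=6333)
info = client.get_collection("lightrag_chunks")
legacy_dim = info.config.params.vectors.size  # e.g. 1536
expected_dim = 3072                           # dimension of the new embedding model
if legacy_dim != expected_dim:
    # Migration is skipped: a new empty collection is created and the
    # legacy collection is left untouched.
    pass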
File diff suppressed because it is too large

View file

@ -0,0 +1,31 @@
from lightrag.utils import EmbeddingFunc
def dummy_func(*args, **kwargs):
pass
def test_embedding_func_with_model_name():
func = EmbeddingFunc(
embedding_dim=1536, func=dummy_func, model_name="text-embedding-ada-002"
)
assert func.get_model_identifier() == "text_embedding_ada_002_1536d"
def test_embedding_func_without_model_name():
func = EmbeddingFunc(embedding_dim=768, func=dummy_func)
assert func.get_model_identifier() == "unknown_768d"
def test_model_name_sanitization():
func = EmbeddingFunc(
embedding_dim=1024,
func=dummy_func,
model_name="models/text-embedding-004", # Contains special chars
)
assert func.get_model_identifier() == "models_text_embedding_004_1024d"
def test_model_name_with_uppercase():
func = EmbeddingFunc(embedding_dim=512, func=dummy_func, model_name="My-Model-V1")
assert func.get_model_identifier() == "my_model_v1_512d"

View file

@ -0,0 +1,213 @@
"""
Tests for safety when model suffix is absent (no model_name provided).
This test module verifies that the system correctly handles the case when
no model_name is provided, preventing accidental deletion of the only table/collection
on restart.
Critical Bug: When model_suffix is empty, table_name == legacy_table_name.
On second startup, Case 1 logic would delete the only table/collection thinking
it's "legacy", causing all subsequent operations to fail.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
from lightrag.kg.postgres_impl import PGVectorStorage
class TestNoModelSuffixSafety:
"""Test suite for preventing data loss when model_suffix is absent."""
def test_qdrant_no_suffix_second_startup(self):
"""
Test Qdrant doesn't delete collection on second startup when no model_name.
Scenario:
1. First startup: Creates collection without suffix
2. Collection is empty
3. Second startup: Should NOT delete the collection
Bug: Without fix, Case 1 would delete the only collection.
"""
from qdrant_client import models
client = MagicMock()
# Simulate second startup: collection already exists and is empty
# IMPORTANT: Without suffix, collection_name == legacy collection name
collection_name = "lightrag_vdb_chunks" # No suffix, same as legacy
# Both exist (they're the same collection)
client.collection_exists.return_value = True
# Collection is empty
client.count.return_value.count = 0
# Call setup_collection
# This should detect that new == legacy and skip deletion
QdrantVectorDBStorage.setup_collection(
client,
collection_name,
namespace="chunks",
workspace=None,
vectors_config=models.VectorParams(
size=1536, distance=models.Distance.COSINE
),
)
# CRITICAL: Collection should NOT be deleted
client.delete_collection.assert_not_called()
# Verify we returned early (skipped Case 1 cleanup)
# The collection_exists was checked, but we didn't proceed to count
# because we detected same name
assert client.collection_exists.call_count >= 1
@pytest.mark.asyncio
async def test_postgres_no_suffix_second_startup(self):
"""
Test PostgreSQL doesn't delete table on second startup when no model_name.
Scenario:
1. First startup: Creates table without suffix
2. Table is empty
3. Second startup: Should NOT delete the table
Bug: Without fix, Case 1 would delete the only table.
"""
db = AsyncMock()
# Simulate second startup: table already exists and is empty
# IMPORTANT: table_name and legacy_table_name are THE SAME
table_name = "LIGHTRAG_VDB_CHUNKS" # No suffix
legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Same as new
# Setup mock responses
async def table_exists_side_effect(db_instance, name):
# Both tables exist (they're the same)
return True
# Mock _pg_table_exists function
with patch(
"lightrag.kg.postgres_impl._pg_table_exists",
side_effect=table_exists_side_effect,
):
# Call setup_table
# This should detect that new == legacy and skip deletion
await PGVectorStorage.setup_table(
db,
table_name,
legacy_table_name=legacy_table_name,
base_table="LIGHTRAG_VDB_CHUNKS",
embedding_dim=1536,
)
# CRITICAL: Table should NOT be deleted (no DROP TABLE)
drop_calls = [
call
for call in db.execute.call_args_list
if call[0][0] and "DROP TABLE" in call[0][0]
]
assert (
len(drop_calls) == 0
), "Should not drop table when new and legacy are the same"
# Also should not try to count (we returned early)
count_calls = [
call
for call in db.query.call_args_list
if call[0][0] and "COUNT(*)" in call[0][0]
]
assert (
len(count_calls) == 0
), "Should not check count when new and legacy are the same"
def test_qdrant_with_suffix_case1_still_works(self):
"""
Test that Case 1 cleanup still works when there IS a suffix.
This ensures our fix doesn't break the normal Case 1 scenario.
"""
from qdrant_client import models
client = MagicMock()
# Different names (normal case)
collection_name = "lightrag_vdb_chunks_ada_002_1536d" # With suffix
legacy_collection = "lightrag_vdb_chunks" # Without suffix
# Setup: both exist
def collection_exists_side_effect(name):
return name in [collection_name, legacy_collection]
client.collection_exists.side_effect = collection_exists_side_effect
# Legacy is empty
client.count.return_value.count = 0
# Call setup_collection
QdrantVectorDBStorage.setup_collection(
client,
collection_name,
namespace="chunks",
workspace=None,
vectors_config=models.VectorParams(
size=1536, distance=models.Distance.COSINE
),
)
# SHOULD delete legacy (normal Case 1 behavior)
client.delete_collection.assert_called_once_with(
collection_name=legacy_collection
)
@pytest.mark.asyncio
async def test_postgres_with_suffix_case1_still_works(self):
"""
Test that Case 1 cleanup still works when there IS a suffix.
This ensures our fix doesn't break the normal Case 1 scenario.
"""
db = AsyncMock()
# Different names (normal case)
table_name = "LIGHTRAG_VDB_CHUNKS_ADA_002_1536D" # With suffix
legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Without suffix
# Setup mock responses
async def table_exists_side_effect(db_instance, name):
# Both tables exist
return True
# Mock empty table
async def query_side_effect(sql, params, **kwargs):
if "COUNT(*)" in sql:
return {"count": 0}
return {}
db.query.side_effect = query_side_effect
# Mock _pg_table_exists function
with patch(
"lightrag.kg.postgres_impl._pg_table_exists",
side_effect=table_exists_side_effect,
):
# Call setup_table
await PGVectorStorage.setup_table(
db,
table_name,
legacy_table_name=legacy_table_name,
base_table="LIGHTRAG_VDB_CHUNKS",
embedding_dim=1536,
)
# SHOULD delete legacy (normal Case 1 behavior)
drop_calls = [
call
for call in db.execute.call_args_list
if call[0][0] and "DROP TABLE" in call[0][0]
]
assert len(drop_calls) == 1, "Should drop legacy table in normal Case 1"
assert legacy_table_name in drop_calls[0][0][0]

View file

@ -0,0 +1,805 @@
import pytest
from unittest.mock import patch, AsyncMock
import numpy as np
from lightrag.utils import EmbeddingFunc
from lightrag.kg.postgres_impl import (
PGVectorStorage,
)
from lightrag.namespace import NameSpace
# Mock PostgreSQLDB
@pytest.fixture
def mock_pg_db():
"""Mock PostgreSQL database connection"""
db = AsyncMock()
db.workspace = "test_workspace"
# Mock query responses with multirows support
async def mock_query(sql, params=None, multirows=False, **kwargs):
# Default return value
if multirows:
return [] # Return empty list for multirows
return {"exists": False, "count": 0}
# Mock for execute that mimics PostgreSQLDB.execute() behavior
async def mock_execute(sql, data=None, **kwargs):
"""
Mock that mimics PostgreSQLDB.execute() behavior:
- Accepts data as dict[str, Any] | None (second parameter)
- Internally converts dict.values() to tuple for AsyncPG
"""
# Mimic real execute() which accepts dict and converts to tuple
if data is not None and not isinstance(data, dict):
raise TypeError(
f"PostgreSQLDB.execute() expects data as dict, got {type(data).__name__}"
)
return None
db.query = AsyncMock(side_effect=mock_query)
db.execute = AsyncMock(side_effect=mock_execute)
return db
# Mock get_data_init_lock to avoid async lock issues in tests
@pytest.fixture(autouse=True)
def mock_data_init_lock():
with patch("lightrag.kg.postgres_impl.get_data_init_lock") as mock_lock:
mock_lock_ctx = AsyncMock()
mock_lock.return_value = mock_lock_ctx
yield mock_lock
# Mock ClientManager
@pytest.fixture
def mock_client_manager(mock_pg_db):
with patch("lightrag.kg.postgres_impl.ClientManager") as mock_manager:
mock_manager.get_client = AsyncMock(return_value=mock_pg_db)
mock_manager.release_client = AsyncMock()
yield mock_manager
# Mock Embedding function
@pytest.fixture
def mock_embedding_func():
async def embed_func(texts, **kwargs):
return np.array([[0.1] * 768 for _ in texts])
func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test_model")
return func
@pytest.mark.asyncio
async def test_postgres_table_naming(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""Test if table name is correctly generated with model suffix"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Verify table name contains model suffix
expected_suffix = "test_model_768d"
assert expected_suffix in storage.table_name
assert storage.table_name == f"LIGHTRAG_VDB_CHUNKS_{expected_suffix}"
# Verify legacy table name
assert storage.legacy_table_name == "LIGHTRAG_VDB_CHUNKS"
@pytest.mark.asyncio
async def test_postgres_migration_trigger(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""Test if migration logic is triggered correctly"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Setup mocks for migration scenario
# 1. New table does not exist, legacy table exists
async def mock_table_exists(db, table_name):
return table_name == storage.legacy_table_name
# 2. Legacy table has 100 records
mock_rows = [
{"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"}
for i in range(100)
]
async def mock_query(sql, params=None, multirows=False, **kwargs):
if "COUNT(*)" in sql:
return {"count": 100}
elif multirows and "SELECT *" in sql:
# Mock batch fetch for migration
# Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit]
if "WHERE workspace" in sql:
# With workspace filter: params[0]=workspace, params[1]=offset, params[2]=limit
offset = params[1] if len(params) > 1 else 0
limit = params[2] if len(params) > 2 else 500
else:
# No workspace filter: params[0]=offset, params[1]=limit
offset = params[0] if params else 0
limit = params[1] if len(params) > 1 else 500
start = offset
end = min(offset + limit, len(mock_rows))
return mock_rows[start:end]
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query)
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
),
patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()),
):
# Initialize storage (should trigger migration)
await storage.initialize()
# Verify migration was executed
# Check that execute was called for inserting rows
assert mock_pg_db.execute.call_count > 0
@pytest.mark.asyncio
async def test_postgres_no_migration_needed(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""Test scenario where new table already exists (no migration needed)"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Mock: new table already exists
async def mock_table_exists(db, table_name):
return table_name == storage.table_name
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
),
patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create,
):
await storage.initialize()
# Verify no table creation was attempted
mock_create.assert_not_called()
@pytest.mark.asyncio
async def test_scenario_1_new_workspace_creation(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Scenario 1: New workspace creation
Expected behavior:
- No legacy table exists
- Directly create new table with model suffix
- No migration needed
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
embedding_func = EmbeddingFunc(
embedding_dim=3072,
func=mock_embedding_func.func,
model_name="text-embedding-3-large",
)
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="new_workspace",
)
# Mock: neither table exists
async def mock_table_exists(db, table_name):
return False
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
),
patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create,
):
await storage.initialize()
# Verify table name format
assert "text_embedding_3_large_3072d" in storage.table_name
# Verify new table creation was called
mock_create.assert_called_once()
call_args = mock_create.call_args
assert (
call_args[0][1] == storage.table_name
) # table_name is second positional arg
@pytest.mark.asyncio
async def test_scenario_2_legacy_upgrade_migration(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Scenario 2: Upgrade from legacy version
Expected behavior:
- Legacy table exists (without model suffix)
- New table doesn't exist
- Automatically migrate data to new table with suffix
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
embedding_func = EmbeddingFunc(
embedding_dim=1536,
func=mock_embedding_func.func,
model_name="text-embedding-ada-002",
)
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="legacy_workspace",
)
# Mock: only legacy table exists
async def mock_table_exists(db, table_name):
return table_name == storage.legacy_table_name
# Mock: legacy table has 50 records
mock_rows = [
{
"id": f"legacy_id_{i}",
"content": f"legacy_content_{i}",
"workspace": "legacy_workspace",
}
for i in range(50)
]
# Track which queries have been made for proper response
query_history = []
async def mock_query(sql, params=None, multirows=False, **kwargs):
query_history.append(sql)
if "COUNT(*)" in sql:
# Determine table type:
# - Legacy: contains base name but NOT model suffix
# - New: contains model suffix (e.g., text_embedding_ada_002_1536d)
sql_upper = sql.upper()
base_name = storage.legacy_table_name.upper()
# Check if this is querying the new table (has model suffix)
has_model_suffix = any(
suffix in sql_upper
for suffix in ["TEXT_EMBEDDING", "_1536D", "_768D", "_1024D", "_3072D"]
)
is_legacy_table = base_name in sql_upper and not has_model_suffix
is_new_table = has_model_suffix
has_workspace_filter = "WHERE workspace" in sql
if is_legacy_table and has_workspace_filter:
# Count for legacy table with workspace filter (before migration)
return {"count": 50}
elif is_legacy_table and not has_workspace_filter:
# Total count for legacy table (after deletion, checking remaining)
return {"count": 0}
elif is_new_table:
# Count for new table (verification after migration)
return {"count": 50}
else:
# Fallback
return {"count": 0}
elif multirows and "SELECT *" in sql:
# Mock batch fetch for migration
# Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit]
if "WHERE workspace" in sql:
# With workspace filter: params[0]=workspace, params[1]=offset, params[2]=limit
offset = params[1] if len(params) > 1 else 0
limit = params[2] if len(params) > 2 else 500
else:
# No workspace filter: params[0]=offset, params[1]=limit
offset = params[0] if params else 0
limit = params[1] if len(params) > 1 else 500
start = offset
end = min(offset + limit, len(mock_rows))
return mock_rows[start:end]
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query)
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
),
patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create,
):
await storage.initialize()
# Verify table name contains ada-002
assert "text_embedding_ada_002_1536d" in storage.table_name
# Verify migration was executed
assert mock_pg_db.execute.call_count >= 50 # At least one execute per row
mock_create.assert_called_once()
# Verify legacy table was automatically deleted after successful migration
# This prevents Case 1 warnings on next startup
delete_calls = [
call
for call in mock_pg_db.execute.call_args_list
if call[0][0] and "DROP TABLE" in call[0][0]
]
assert (
len(delete_calls) >= 1
), "Legacy table should be deleted after successful migration"
# Check if legacy table was dropped
dropped_table = storage.legacy_table_name
assert any(
dropped_table in str(call) for call in delete_calls
), f"Expected to drop '{dropped_table}'"
@pytest.mark.asyncio
async def test_scenario_3_multi_model_coexistence(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Scenario 3: Multiple embedding models coexist
Expected behavior:
- Different embedding models create separate tables
- Tables are isolated by model suffix
- No interference between different models
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
# Workspace A: uses bge-small (768d)
embedding_func_a = EmbeddingFunc(
embedding_dim=768, func=mock_embedding_func.func, model_name="bge-small"
)
storage_a = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func_a,
workspace="workspace_a",
)
# Workspace B: uses bge-large (1024d)
async def embed_func_b(texts, **kwargs):
return np.array([[0.1] * 1024 for _ in texts])
embedding_func_b = EmbeddingFunc(
embedding_dim=1024, func=embed_func_b, model_name="bge-large"
)
storage_b = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func_b,
workspace="workspace_b",
)
# Verify different table names
assert storage_a.table_name != storage_b.table_name
assert "bge_small_768d" in storage_a.table_name
assert "bge_large_1024d" in storage_b.table_name
# Mock: both tables don't exist yet
async def mock_table_exists(db, table_name):
return False
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
),
patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create,
):
# Initialize both storages
await storage_a.initialize()
await storage_b.initialize()
# Verify two separate tables were created
assert mock_create.call_count == 2
# Verify table names are different
call_args_list = mock_create.call_args_list
table_names = [call[0][1] for call in call_args_list] # Second positional arg
assert len(set(table_names)) == 2 # Two unique table names
assert storage_a.table_name in table_names
assert storage_b.table_name in table_names
@pytest.mark.asyncio
async def test_case1_empty_legacy_auto_cleanup(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Case 1a: Both new and legacy tables exist, but legacy is EMPTY
Expected: Automatically delete empty legacy table (safe cleanup)
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
embedding_func = EmbeddingFunc(
embedding_dim=1536,
func=mock_embedding_func.func,
model_name="test-model",
)
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="test_ws",
)
# Mock: Both tables exist
async def mock_table_exists(db, table_name):
return True # Both new and legacy exist
# Mock: Legacy table is empty (0 records)
async def mock_query(sql, params=None, multirows=False, **kwargs):
if "COUNT(*)" in sql:
if storage.legacy_table_name in sql:
return {"count": 0} # Empty legacy table
else:
return {"count": 100} # New table has data
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query)
with patch(
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
):
await storage.initialize()
# Verify: Empty legacy table should be automatically cleaned up
# Empty tables are safe to delete without data loss risk
delete_calls = [
call
for call in mock_pg_db.execute.call_args_list
if call[0][0] and "DROP TABLE" in call[0][0]
]
assert len(delete_calls) >= 1, "Empty legacy table should be auto-deleted"
# Check if legacy table was dropped
dropped_table = storage.legacy_table_name
assert any(
dropped_table in str(call) for call in delete_calls
), f"Expected to drop empty legacy table '{dropped_table}'"
print(
f"✅ Case 1a: Empty legacy table '{dropped_table}' auto-deleted successfully"
)
@pytest.mark.asyncio
async def test_case1_nonempty_legacy_warning(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Case 1b: Both new and legacy tables exist, and legacy HAS DATA
Expected: Log warning, do not delete legacy (preserve data)
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
embedding_func = EmbeddingFunc(
embedding_dim=1536,
func=mock_embedding_func.func,
model_name="test-model",
)
storage = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="test_ws",
)
# Mock: Both tables exist
async def mock_table_exists(db, table_name):
return True # Both new and legacy exist
# Mock: Legacy table has data (50 records)
async def mock_query(sql, params=None, multirows=False, **kwargs):
if "COUNT(*)" in sql:
if storage.legacy_table_name in sql:
return {"count": 50} # Legacy has data
else:
return {"count": 100} # New table has data
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query)
with patch(
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
):
await storage.initialize()
# Verify: Legacy table with data should be preserved
# We never auto-delete tables that contain data to prevent accidental data loss
delete_calls = [
call
for call in mock_pg_db.execute.call_args_list
if call[0][0] and "DROP TABLE" in call[0][0]
]
# Check if legacy table was deleted (it should not be)
dropped_table = storage.legacy_table_name
legacy_deleted = any(dropped_table in str(call) for call in delete_calls)
assert not legacy_deleted, "Legacy table with data should NOT be auto-deleted"
print(
f"✅ Case 1b: Legacy table '{dropped_table}' with data preserved (warning only)"
)
@pytest.mark.asyncio
async def test_case1_sequential_workspace_migration(
mock_client_manager, mock_pg_db, mock_embedding_func
):
"""
Case 1c: Sequential workspace migration (Multi-tenant scenario)
Critical bug fix verification:
Timeline:
1. Legacy table has workspace_a (3 records) + workspace_b (3 records)
2. Workspace A initializes first → Case 4 (only legacy exists) migrates A's data
3. Workspace B initializes later → Case 1 (both tables exist) should migrate B's data
4. Verify workspace B's data is correctly migrated to new table
5. Verify legacy table is cleaned up after both workspaces migrate
This test verifies the fix where Case 1 now checks and migrates current
workspace's data instead of just checking if legacy table is empty globally.
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
embedding_func = EmbeddingFunc(
embedding_dim=1536,
func=mock_embedding_func.func,
model_name="test-model",
)
# Mock data: Legacy table has 6 records total (3 from workspace_a, 3 from workspace_b)
mock_rows_a = [
{"id": f"a_{i}", "content": f"A content {i}", "workspace": "workspace_a"}
for i in range(3)
]
mock_rows_b = [
{"id": f"b_{i}", "content": f"B content {i}", "workspace": "workspace_b"}
for i in range(3)
]
# Track migration state
migration_state = {"new_table_exists": False, "workspace_a_migrated": False}
# Step 1: Simulate workspace_a initialization (Case 4)
# CRITICAL: Set db.workspace to workspace_a
mock_pg_db.workspace = "workspace_a"
storage_a = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="workspace_a",
)
# Mock table_exists for workspace_a
async def mock_table_exists_a(db, table_name):
if table_name == storage_a.legacy_table_name:
return True
if table_name == storage_a.table_name:
return migration_state["new_table_exists"]
return False
# Track inserted records count for verification
inserted_count = {"workspace_a": 0}
# Mock execute to track inserts
async def mock_execute_a(sql, data=None, **kwargs):
if sql and "INSERT INTO" in sql.upper():
inserted_count["workspace_a"] += 1
return None
# Mock query for workspace_a (Case 4)
async def mock_query_a(sql, params=None, multirows=False, **kwargs):
sql_upper = sql.upper()
base_name = storage_a.legacy_table_name.upper()
if "COUNT(*)" in sql:
has_model_suffix = "TEST_MODEL_1536D" in sql_upper
is_legacy = base_name in sql_upper and not has_model_suffix
has_workspace_filter = "WHERE workspace" in sql
if is_legacy and has_workspace_filter:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_a":
# After migration starts, pretend legacy is empty for this workspace
return {"count": 3 - inserted_count["workspace_a"]}
elif workspace == "workspace_b":
return {"count": 3}
elif is_legacy and not has_workspace_filter:
# Global count in legacy table
remaining = 6 - inserted_count["workspace_a"]
return {"count": remaining}
elif has_model_suffix:
# New table count (for verification)
return {"count": inserted_count["workspace_a"]}
elif multirows and "SELECT *" in sql:
if "WHERE workspace" in sql:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_a":
offset = params[1] if len(params) > 1 else 0
limit = params[2] if len(params) > 2 else 500
return mock_rows_a[offset : offset + limit]
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query_a)
mock_pg_db.execute = AsyncMock(side_effect=mock_execute_a)
# Initialize workspace_a (Case 4)
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists",
side_effect=mock_table_exists_a,
),
patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()),
):
await storage_a.initialize()
migration_state["new_table_exists"] = True
migration_state["workspace_a_migrated"] = True
print("✅ Step 1: Workspace A initialized (Case 4)")
assert mock_pg_db.execute.call_count >= 3
print(f"✅ Step 1: {mock_pg_db.execute.call_count} execute calls")
# Step 2: Simulate workspace_b initialization (Case 1)
# CRITICAL: Set db.workspace to workspace_b
mock_pg_db.workspace = "workspace_b"
storage_b = PGVectorStorage(
namespace=NameSpace.VECTOR_STORE_CHUNKS,
global_config=config,
embedding_func=embedding_func,
workspace="workspace_b",
)
mock_pg_db.reset_mock()
migration_state["workspace_b_migrated"] = False
# Mock table_exists for workspace_b (both exist)
async def mock_table_exists_b(db, table_name):
return True
# Track inserted records count for workspace_b
inserted_count["workspace_b"] = 0
# Mock execute for workspace_b to track inserts
async def mock_execute_b(sql, data=None, **kwargs):
if sql and "INSERT INTO" in sql.upper():
inserted_count["workspace_b"] += 1
return None
# Mock query for workspace_b (Case 1)
async def mock_query_b(sql, params=None, multirows=False, **kwargs):
sql_upper = sql.upper()
base_name = storage_b.legacy_table_name.upper()
if "COUNT(*)" in sql:
has_model_suffix = "TEST_MODEL_1536D" in sql_upper
is_legacy = base_name in sql_upper and not has_model_suffix
has_workspace_filter = "WHERE workspace" in sql
if is_legacy and has_workspace_filter:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_b":
# After migration starts, pretend legacy is empty for this workspace
return {"count": 3 - inserted_count["workspace_b"]}
elif workspace == "workspace_a":
return {"count": 0} # Already migrated
elif is_legacy and not has_workspace_filter:
# Global count: only workspace_b data remains
return {"count": 3 - inserted_count["workspace_b"]}
elif has_model_suffix:
# New table total count (workspace_a: 3 + workspace_b: inserted)
if has_workspace_filter:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_b":
return {"count": inserted_count["workspace_b"]}
elif workspace == "workspace_a":
return {"count": 3}
else:
# Total count in new table (for verification)
return {"count": 3 + inserted_count["workspace_b"]}
elif multirows and "SELECT *" in sql:
if "WHERE workspace" in sql:
workspace = params[0] if params and len(params) > 0 else None
if workspace == "workspace_b":
offset = params[1] if len(params) > 1 else 0
limit = params[2] if len(params) > 2 else 500
return mock_rows_b[offset : offset + limit]
return {}
mock_pg_db.query = AsyncMock(side_effect=mock_query_b)
mock_pg_db.execute = AsyncMock(side_effect=mock_execute_b)
# Initialize workspace_b (Case 1)
with patch(
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists_b
):
await storage_b.initialize()
migration_state["workspace_b_migrated"] = True
print("✅ Step 2: Workspace B initialized (Case 1)")
# Verify workspace_b migration happened
execute_calls = mock_pg_db.execute.call_args_list
insert_calls = [
call for call in execute_calls if call[0][0] and "INSERT INTO" in call[0][0]
]
assert len(insert_calls) >= 3, f"Expected >= 3 inserts, got {len(insert_calls)}"
print(f"✅ Step 2: {len(insert_calls)} insert calls")
# Verify DELETE and DROP TABLE
delete_calls = [
call
for call in execute_calls
if call[0][0]
and "DELETE FROM" in call[0][0]
and "WHERE workspace" in call[0][0]
]
assert len(delete_calls) >= 1, "Expected DELETE workspace_b data"
print("✅ Step 2: DELETE workspace_b from legacy")
drop_calls = [
call for call in execute_calls if call[0][0] and "DROP TABLE" in call[0][0]
]
assert len(drop_calls) >= 1, "Expected DROP TABLE"
print("✅ Step 2: Legacy table dropped")
print("\n🎉 Case 1c: Sequential workspace migration verified!")

View file

@ -0,0 +1,522 @@
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
import numpy as np
from lightrag.utils import EmbeddingFunc
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
# Mock QdrantClient
@pytest.fixture
def mock_qdrant_client():
with patch("lightrag.kg.qdrant_impl.QdrantClient") as mock_client_cls:
client = mock_client_cls.return_value
client.collection_exists.return_value = False
client.count.return_value.count = 0
# Mock payload schema and vector config for get_collection
collection_info = MagicMock()
collection_info.payload_schema = {}
# Mock vector dimension to match mock_embedding_func (768d)
collection_info.config.params.vectors.size = 768
client.get_collection.return_value = collection_info
yield client
# Mock get_data_init_lock to avoid async lock issues in tests
@pytest.fixture(autouse=True)
def mock_data_init_lock():
with patch("lightrag.kg.qdrant_impl.get_data_init_lock") as mock_lock:
mock_lock_ctx = AsyncMock()
mock_lock.return_value = mock_lock_ctx
yield mock_lock
# Mock Embedding function
@pytest.fixture
def mock_embedding_func():
async def embed_func(texts, **kwargs):
return np.array([[0.1] * 768 for _ in texts])
func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test-model")
return func
@pytest.mark.asyncio
async def test_qdrant_collection_naming(mock_qdrant_client, mock_embedding_func):
"""Test if collection name is correctly generated with model suffix"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Verify collection name contains model suffix
expected_suffix = "test_model_768d"
assert expected_suffix in storage.final_namespace
assert storage.final_namespace == f"lightrag_vdb_chunks_{expected_suffix}"
# Verify legacy namespace (should not include workspace, just the base collection name)
assert storage.legacy_namespace == "lightrag_vdb_chunks"
@pytest.mark.asyncio
async def test_qdrant_migration_trigger(mock_qdrant_client, mock_embedding_func):
"""Test if migration logic is triggered correctly"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# Setup mocks for migration scenario
# 1. New collection does not exist
mock_qdrant_client.collection_exists.side_effect = (
lambda name: name == storage.legacy_namespace
)
# 2. Legacy collection exists and has data
mock_qdrant_client.count.return_value.count = 100
# 3. Mock scroll for data migration
mock_point = MagicMock()
mock_point.id = "old_id"
mock_point.vector = [0.1] * 768
mock_point.payload = {"content": "test"}
# First call returns points, second call returns empty (end of scroll)
mock_qdrant_client.scroll.side_effect = [([mock_point], "next_offset"), ([], None)]
# Initialize storage (triggers migration)
await storage.initialize()
# Verify migration steps
# 1. Legacy count checked
mock_qdrant_client.count.assert_any_call(
collection_name=storage.legacy_namespace, exact=True
)
# 2. New collection created
mock_qdrant_client.create_collection.assert_called()
# 3. Data scrolled from legacy
assert mock_qdrant_client.scroll.call_count >= 1
call_args = mock_qdrant_client.scroll.call_args_list[0]
assert call_args.kwargs["collection_name"] == storage.legacy_namespace
assert call_args.kwargs["limit"] == 500
# 4. Data upserted to new
mock_qdrant_client.upsert.assert_called()
# 5. Payload index created
mock_qdrant_client.create_payload_index.assert_called()
@pytest.mark.asyncio
async def test_qdrant_no_migration_needed(mock_qdrant_client, mock_embedding_func):
"""Test scenario where new collection already exists"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
# New collection exists and Legacy exists (warning case)
# or New collection exists and Legacy does not exist (normal case)
# Mocking case where both exist to test logic flow but without migration
# Logic in code:
# Case 1: Both exist -> Warning only
# Case 2: Only new exists -> Ensure index
# Let's test Case 2: Only new collection exists
mock_qdrant_client.collection_exists.side_effect = (
lambda name: name == storage.final_namespace
)
# Initialize
await storage.initialize()
# Should check index but NOT migrate
# In Qdrant implementation, Case 2 calls get_collection
mock_qdrant_client.get_collection.assert_called_with(storage.final_namespace)
mock_qdrant_client.scroll.assert_not_called()
# ============================================================================
# Tests for scenarios described in design document (Lines 606-649)
# ============================================================================
@pytest.mark.asyncio
async def test_scenario_1_new_workspace_creation(
mock_qdrant_client, mock_embedding_func
):
"""
Scenario 1: New workspace creation
Expected: directly create lightrag_vdb_chunks_text_embedding_3_large_3072d
"""
# Use a large embedding model
large_model_func = EmbeddingFunc(
embedding_dim=3072,
func=mock_embedding_func.func,
model_name="text-embedding-3-large",
)
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=large_model_func,
workspace="test_new",
)
# Case 3: Neither legacy nor new collection exists
mock_qdrant_client.collection_exists.return_value = False
# Initialize storage
await storage.initialize()
# Verify: Should create new collection with model suffix
expected_collection = "lightrag_vdb_chunks_text_embedding_3_large_3072d"
assert storage.final_namespace == expected_collection
# Verify create_collection was called with correct name
create_calls = [
call for call in mock_qdrant_client.create_collection.call_args_list
]
assert len(create_calls) > 0
assert (
create_calls[0][0][0] == expected_collection
or create_calls[0].kwargs.get("collection_name") == expected_collection
)
# Verify no migration was attempted
mock_qdrant_client.scroll.assert_not_called()
print(
f"✅ Scenario 1: New workspace created with collection '{expected_collection}'"
)
@pytest.mark.asyncio
async def test_scenario_2_legacy_upgrade_migration(
mock_qdrant_client, mock_embedding_func
):
"""
Scenario 2: Upgrade from a legacy version
lightrag_vdb_chunks (no suffix) already exists
Expected: data is automatically migrated to lightrag_vdb_chunks_text_embedding_ada_002_1536d
"""
# Use ada-002 model
ada_func = EmbeddingFunc(
embedding_dim=1536,
func=mock_embedding_func.func,
model_name="text-embedding-ada-002",
)
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=ada_func,
workspace="test_legacy",
)
legacy_collection = storage.legacy_namespace
new_collection = storage.final_namespace
# Case 4: Only legacy collection exists
mock_qdrant_client.collection_exists.side_effect = (
lambda name: name == legacy_collection
)
# Mock legacy collection info with 1536d vectors
legacy_collection_info = MagicMock()
legacy_collection_info.payload_schema = {}
legacy_collection_info.config.params.vectors.size = 1536
mock_qdrant_client.get_collection.return_value = legacy_collection_info
# Mock legacy data
mock_qdrant_client.count.return_value.count = 150
# Mock scroll results (simulate migration in batches)
mock_points = []
for i in range(10):
point = MagicMock()
point.id = f"legacy-{i}"
point.vector = [0.1] * 1536
point.payload = {"content": f"Legacy document {i}", "id": f"doc-{i}"}
mock_points.append(point)
# First batch returns points, second batch returns empty
mock_qdrant_client.scroll.side_effect = [(mock_points, "offset1"), ([], None)]
# Initialize (triggers migration)
await storage.initialize()
# Verify: New collection should be created
expected_new_collection = "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
assert storage.final_namespace == expected_new_collection
# Verify migration steps
# 1. Check legacy count
mock_qdrant_client.count.assert_any_call(
collection_name=legacy_collection, exact=True
)
# 2. Create new collection
mock_qdrant_client.create_collection.assert_called()
# 3. Scroll legacy data
scroll_calls = [call for call in mock_qdrant_client.scroll.call_args_list]
assert len(scroll_calls) >= 1
assert scroll_calls[0].kwargs["collection_name"] == legacy_collection
# 4. Upsert to new collection
upsert_calls = [call for call in mock_qdrant_client.upsert.call_args_list]
assert len(upsert_calls) >= 1
assert upsert_calls[0].kwargs["collection_name"] == new_collection
# 5. Verify legacy collection was automatically deleted after successful migration
# This prevents Case 1 warnings on next startup
delete_calls = [
call for call in mock_qdrant_client.delete_collection.call_args_list
]
assert (
len(delete_calls) >= 1
), "Legacy collection should be deleted after successful migration"
# Check if legacy_collection was passed to delete_collection
deleted_collection = (
delete_calls[0][0][0]
if delete_calls[0][0]
else delete_calls[0].kwargs.get("collection_name")
)
assert (
deleted_collection == legacy_collection
), f"Expected to delete '{legacy_collection}', but deleted '{deleted_collection}'"
print(
f"✅ Scenario 2: Legacy data migrated from '{legacy_collection}' to '{expected_new_collection}' and legacy collection deleted"
)
@pytest.mark.asyncio
async def test_scenario_3_multi_model_coexistence(mock_qdrant_client):
"""
Scenario 3: Multiple embedding models coexist
Expected: two independent collections that do not interfere with each other
"""
# Model A: bge-small with 768d
async def embed_func_a(texts, **kwargs):
return np.array([[0.1] * 768 for _ in texts])
model_a_func = EmbeddingFunc(
embedding_dim=768, func=embed_func_a, model_name="bge-small"
)
# Model B: bge-large with 1024d
async def embed_func_b(texts, **kwargs):
return np.array([[0.2] * 1024 for _ in texts])
model_b_func = EmbeddingFunc(
embedding_dim=1024, func=embed_func_b, model_name="bge-large"
)
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
# Create storage for workspace A with model A
storage_a = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=model_a_func,
workspace="workspace_a",
)
# Create storage for workspace B with model B
storage_b = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=model_b_func,
workspace="workspace_b",
)
# Verify: Collection names are different
assert storage_a.final_namespace != storage_b.final_namespace
# Verify: Model A collection
expected_collection_a = "lightrag_vdb_chunks_bge_small_768d"
assert storage_a.final_namespace == expected_collection_a
# Verify: Model B collection
expected_collection_b = "lightrag_vdb_chunks_bge_large_1024d"
assert storage_b.final_namespace == expected_collection_b
# Verify: Different embedding dimensions are preserved
assert storage_a.embedding_func.embedding_dim == 768
assert storage_b.embedding_func.embedding_dim == 1024
print("✅ Scenario 3: Multi-model coexistence verified")
print(f" - Workspace A: {expected_collection_a} (768d)")
print(f" - Workspace B: {expected_collection_b} (1024d)")
print(" - Collections are independent")
@pytest.mark.asyncio
async def test_case1_empty_legacy_auto_cleanup(mock_qdrant_client, mock_embedding_func):
"""
Case 1a: Both new and legacy collections exist, and the legacy collection is empty
Expected: the empty legacy collection is deleted automatically
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
legacy_collection = storage.legacy_namespace
new_collection = storage.final_namespace
# Mock: Both collections exist
mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
legacy_collection,
new_collection,
]
# Mock: Legacy collection is empty (0 records)
def count_mock(collection_name, exact=True):
mock_result = MagicMock()
if collection_name == legacy_collection:
mock_result.count = 0 # Empty legacy collection
else:
mock_result.count = 100 # New collection has data
return mock_result
mock_qdrant_client.count.side_effect = count_mock
# Mock get_collection for Case 2 check
collection_info = MagicMock()
collection_info.payload_schema = {"workspace_id": True}
mock_qdrant_client.get_collection.return_value = collection_info
# Initialize storage
await storage.initialize()
# Verify: Empty legacy collection should be automatically cleaned up
# Empty collections are safe to delete without data loss risk
delete_calls = [
call for call in mock_qdrant_client.delete_collection.call_args_list
]
assert len(delete_calls) >= 1, "Empty legacy collection should be auto-deleted"
deleted_collection = (
delete_calls[0][0][0]
if delete_calls[0][0]
else delete_calls[0].kwargs.get("collection_name")
)
assert (
deleted_collection == legacy_collection
), f"Expected to delete '{legacy_collection}', but deleted '{deleted_collection}'"
print(
f"✅ Case 1a: Empty legacy collection '{legacy_collection}' auto-deleted successfully"
)
@pytest.mark.asyncio
async def test_case1_nonempty_legacy_warning(mock_qdrant_client, mock_embedding_func):
"""
Case 1b: both the new and the legacy collection exist, and the legacy collection contains data.
Expected: a warning is logged, but nothing is deleted.
"""
config = {
"embedding_batch_num": 10,
"vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
}
storage = QdrantVectorDBStorage(
namespace="chunks",
global_config=config,
embedding_func=mock_embedding_func,
workspace="test_ws",
)
legacy_collection = storage.legacy_namespace
new_collection = storage.final_namespace
# Mock: Both collections exist
mock_qdrant_client.collection_exists.side_effect = lambda name: name in [
legacy_collection,
new_collection,
]
# Mock: Legacy collection has data (50 records)
def count_mock(collection_name, exact=True):
mock_result = MagicMock()
if collection_name == legacy_collection:
mock_result.count = 50 # Legacy has data
else:
mock_result.count = 100 # New collection has data
return mock_result
mock_qdrant_client.count.side_effect = count_mock
# Mock get_collection for Case 2 check
collection_info = MagicMock()
collection_info.payload_schema = {"workspace_id": True}
mock_qdrant_client.get_collection.return_value = collection_info
# Initialize storage
await storage.initialize()
# Verify: Legacy collection with data should be preserved
# We never auto-delete collections that contain data to prevent accidental data loss
delete_calls = mock_qdrant_client.delete_collection.call_args_list
# Check if legacy collection was deleted (it should not be)
legacy_deleted = any(
(call[0][0] if call[0] else call.kwargs.get("collection_name"))
== legacy_collection
for call in delete_calls
)
assert not legacy_deleted, "Legacy collection with data should NOT be auto-deleted"
print(
f"✅ Case 1b: Legacy collection '{legacy_collection}' with data preserved (warning only)"
)
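# A minimal sketch (an assumption, not the storage's actual initialize() logic)
# of the legacy-collection handling that Case 1a and Case 1b exercise: an empty
# legacy collection is dropped automatically, while a non-empty one is preserved
# and only reported.
def _handle_legacy_collection(client, legacy_name: str) -> str:
    if not client.collection_exists(legacy_name):
        return "absent"  # nothing to migrate or clean up
    if client.count(legacy_name, exact=True).count == 0:
        client.delete_collection(legacy_name)  # Case 1a: safe auto-cleanup
        return "deleted"
    return "preserved"  # Case 1b: data present, warn but never delete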

View file

@ -0,0 +1,191 @@
"""
Tests for UnifiedLock safety when lock is None.
This test module verifies that UnifiedLock raises RuntimeError instead of
allowing unprotected execution when the underlying lock is None, preventing a
false sense of safety and the race conditions it would otherwise permit.
Critical Bug 1: When self._lock is None, __aenter__ used to log WARNING but
still return successfully, allowing critical sections to run without lock
protection, causing race conditions and data corruption.
Critical Bug 2: In __aexit__, when async_lock.release() fails, the error
recovery logic would attempt to release it again, causing double-release issues.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock
from lightrag.kg.shared_storage import UnifiedLock
class TestUnifiedLockSafety:
"""Test suite for UnifiedLock None safety checks."""
@pytest.mark.asyncio
async def test_unified_lock_raises_on_none_async(self):
"""
Test that UnifiedLock raises RuntimeError when lock is None (async mode).
Scenario: Attempt to use UnifiedLock before initialize_share_data() is called.
Expected: RuntimeError raised, preventing unprotected critical section execution.
"""
lock = UnifiedLock(
lock=None, is_async=True, name="test_async_lock", enable_logging=False
)
with pytest.raises(
RuntimeError, match="shared data not initialized|Lock.*is None"
):
async with lock:
# This code should NEVER execute
pytest.fail(
"Code inside lock context should not execute when lock is None"
)
@pytest.mark.asyncio
async def test_unified_lock_raises_on_none_sync(self):
"""
Test that UnifiedLock raises RuntimeError when lock is None (sync mode).
Scenario: Attempt to use UnifiedLock with None lock in sync mode.
Expected: RuntimeError raised with clear error message.
"""
lock = UnifiedLock(
lock=None, is_async=False, name="test_sync_lock", enable_logging=False
)
with pytest.raises(
RuntimeError, match="shared data not initialized|Lock.*is None"
):
async with lock:
# This code should NEVER execute
pytest.fail(
"Code inside lock context should not execute when lock is None"
)
@pytest.mark.asyncio
async def test_error_message_clarity(self):
"""
Test that the error message clearly indicates the problem and solution.
Scenario: Lock is None and user tries to acquire it.
Expected: Error message mentions 'shared data not initialized' and
'initialize_share_data()'.
"""
lock = UnifiedLock(
lock=None,
is_async=True,
name="test_error_message",
enable_logging=False,
)
with pytest.raises(RuntimeError) as exc_info:
async with lock:
pass
error_message = str(exc_info.value)
# Verify error message contains helpful information
assert (
"shared data not initialized" in error_message.lower()
or "lock" in error_message.lower()
)
assert "initialize_share_data" in error_message or "None" in error_message
@pytest.mark.asyncio
async def test_aexit_no_double_release_on_async_lock_failure(self):
"""
Test that __aexit__ doesn't attempt to release async_lock twice when it fails.
Scenario: async_lock.release() fails during normal release.
Expected: Recovery logic should NOT attempt to release async_lock again,
preventing double-release issues.
This tests Bug 2 fix: async_lock_released tracking prevents double release.
"""
# Create mock locks
main_lock = MagicMock()
main_lock.acquire = MagicMock()
main_lock.release = MagicMock()
async_lock = AsyncMock()
async_lock.acquire = AsyncMock()
# Make async_lock.release() fail
release_call_count = 0
def mock_release_fail():
nonlocal release_call_count
release_call_count += 1
raise RuntimeError("Async lock release failed")
async_lock.release = MagicMock(side_effect=mock_release_fail)
# Create UnifiedLock with both locks (sync mode with async_lock)
lock = UnifiedLock(
lock=main_lock,
is_async=False,
name="test_double_release",
enable_logging=False,
)
lock._async_lock = async_lock
# Try to use the lock - should fail during __aexit__
try:
async with lock:
pass
except RuntimeError as e:
# Should get the async lock release error
assert "Async lock release failed" in str(e)
# Verify async_lock.release() was called only ONCE, not twice
assert (
release_call_count == 1
), f"async_lock.release() should be called only once, but was called {release_call_count} times"
# Main lock should have been released successfully
main_lock.release.assert_called_once()
@pytest.mark.asyncio
async def test_aexit_recovery_on_main_lock_failure(self):
"""
Test that __aexit__ recovery logic works when main lock release fails.
Scenario: main_lock.release() fails before async_lock is attempted.
Expected: Recovery logic should attempt to release async_lock to prevent
resource leaks.
This verifies the recovery logic still works correctly with async_lock_released tracking.
"""
# Create mock locks
main_lock = MagicMock()
main_lock.acquire = MagicMock()
# Make main_lock.release() fail
def mock_main_release_fail():
raise RuntimeError("Main lock release failed")
main_lock.release = MagicMock(side_effect=mock_main_release_fail)
async_lock = AsyncMock()
async_lock.acquire = AsyncMock()
async_lock.release = MagicMock()
# Create UnifiedLock with both locks (sync mode with async_lock)
lock = UnifiedLock(
lock=main_lock, is_async=False, name="test_recovery", enable_logging=False
)
lock._async_lock = async_lock
# Try to use the lock - should fail during __aexit__
try:
async with lock:
pass
except RuntimeError as e:
# Should get the main lock release error
assert "Main lock release failed" in str(e)
# Main lock release should have been attempted
main_lock.release.assert_called_once()
# Recovery logic should have attempted to release async_lock
async_lock.release.assert_called_once()
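# A minimal sketch (an assumption, not lightrag.kg.shared_storage's actual code)
# of the two behaviours this suite verifies: __aenter__ fails fast when the lock
# is None, and __aexit__ tracks whether async_lock release was already attempted
# so the error-recovery path never releases it twice.
class _SketchUnifiedLock:
    def __init__(self, lock, is_async=True, async_lock=None):
        self._lock = lock
        self._is_async = is_async
        self._async_lock = async_lock

    async def __aenter__(self):
        if self._lock is None:
            raise RuntimeError(
                "Lock is None: shared data not initialized. "
                "Call initialize_share_data() before using locks."
            )
        if self._is_async:
            await self._lock.acquire()
        else:
            if self._async_lock is not None:
                await self._async_lock.acquire()
            self._lock.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        async_lock_released = False
        try:
            self._lock.release()
            if not self._is_async and self._async_lock is not None:
                # Mark the release as attempted before the call so recovery
                # never repeats it (Bug 2 fix).
                async_lock_released = True
                self._async_lock.release()
        except Exception:
            # Recovery: release async_lock only if its release was never attempted.
            if (
                not self._is_async
                and self._async_lock is not None
                and not async_lock_released
            ):
                try:
                    self._async_lock.release()
                except Exception:
                    pass
            raise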

View file

@ -0,0 +1,308 @@
"""
Tests for workspace isolation during PostgreSQL migration.
This test module verifies that setup_table() properly filters migration data
by workspace, preventing cross-workspace data leakage during legacy table migration.
Critical Bug: Migration copied ALL records from the legacy table regardless of
workspace, causing workspace A to receive workspace B's data and violating multi-tenant isolation.
"""
import pytest
from unittest.mock import AsyncMock
from lightrag.kg.postgres_impl import PGVectorStorage
class TestWorkspaceMigrationIsolation:
"""Test suite for workspace-scoped migration in PostgreSQL."""
@pytest.mark.asyncio
async def test_migration_filters_by_workspace(self):
"""
Test that migration only copies data from the specified workspace.
Scenario: Legacy table contains data from multiple workspaces.
Migrate only workspace_a's data to new table.
Expected: New table contains only workspace_a data, workspace_b data excluded.
"""
db = AsyncMock()
# Mock table existence checks
async def table_exists_side_effect(db_instance, name):
if name == "lightrag_doc_chunks": # legacy
return True
elif name == "lightrag_doc_chunks_model_1536d": # new
return False
return False
# Mock query responses
async def query_side_effect(sql, params, **kwargs):
multirows = kwargs.get("multirows", False)
# Table existence check
if "information_schema.tables" in sql:
if params[0] == "lightrag_doc_chunks":
return {"exists": True}
elif params[0] == "lightrag_doc_chunks_model_1536d":
return {"exists": False}
# Count query with workspace filter (legacy table)
elif "COUNT(*)" in sql and "WHERE workspace" in sql:
if params[0] == "workspace_a":
return {"count": 2} # workspace_a has 2 records
elif params[0] == "workspace_b":
return {"count": 3} # workspace_b has 3 records
return {"count": 0}
# Count query for new table (verification)
elif "COUNT(*)" in sql and "lightrag_doc_chunks_model_1536d" in sql:
return {"count": 2} # Verification: 2 records migrated
# Count query for legacy table (no filter)
elif "COUNT(*)" in sql and "lightrag_doc_chunks" in sql:
return {"count": 5} # Total records in legacy
# Dimension check
elif "pg_attribute" in sql:
return {"vector_dim": 1536}
# SELECT with workspace filter
elif "SELECT * FROM" in sql and "WHERE workspace" in sql and multirows:
workspace = params[0]
if workspace == "workspace_a" and params[1] == 0: # offset = 0
# Return only workspace_a data
return [
{
"id": "a1",
"workspace": "workspace_a",
"content": "content_a1",
"content_vector": [0.1] * 1536,
},
{
"id": "a2",
"workspace": "workspace_a",
"content": "content_a2",
"content_vector": [0.2] * 1536,
},
]
else:
return [] # No more data
return {}
db.query.side_effect = query_side_effect
db.execute = AsyncMock()
db._create_vector_index = AsyncMock()
# Mock _pg_table_exists and _pg_create_table
from unittest.mock import patch
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists",
side_effect=table_exists_side_effect,
),
patch("lightrag.kg.postgres_impl._pg_create_table", new=AsyncMock()),
):
# Migrate for workspace_a only
await PGVectorStorage.setup_table(
db,
"lightrag_doc_chunks_model_1536d",
legacy_table_name="lightrag_doc_chunks",
base_table="lightrag_doc_chunks",
embedding_dim=1536,
workspace="workspace_a", # CRITICAL: Only migrate workspace_a
)
# Verify workspace filter was used in queries
count_calls = [
call
for call in db.query.call_args_list
if call[0][0]
and "COUNT(*)" in call[0][0]
and "WHERE workspace" in call[0][0]
]
assert len(count_calls) > 0, "Count query should use workspace filter"
assert (
count_calls[0][0][1][0] == "workspace_a"
), "Count should filter by workspace_a"
select_calls = [
call
for call in db.query.call_args_list
if call[0][0]
and "SELECT * FROM" in call[0][0]
and "WHERE workspace" in call[0][0]
]
assert len(select_calls) > 0, "Select query should use workspace filter"
assert (
select_calls[0][0][1][0] == "workspace_a"
), "Select should filter by workspace_a"
# Verify INSERT was called (migration happened)
insert_calls = [
call
for call in db.execute.call_args_list
if call[0][0] and "INSERT INTO" in call[0][0]
]
assert len(insert_calls) == 2, "Should insert 2 records from workspace_a"
@pytest.mark.asyncio
async def test_migration_without_workspace_warns(self):
"""
Test that migration without workspace parameter logs a warning.
Scenario: setup_table called without workspace parameter.
Expected: Warning logged about potential cross-workspace data copying.
"""
db = AsyncMock()
async def table_exists_side_effect(db_instance, name):
if name == "lightrag_doc_chunks":
return True
elif name == "lightrag_doc_chunks_model_1536d":
return False
return False
async def query_side_effect(sql, params, **kwargs):
if "information_schema.tables" in sql:
return {"exists": params[0] == "lightrag_doc_chunks"}
elif "COUNT(*)" in sql:
return {"count": 5} # 5 records total
elif "pg_attribute" in sql:
return {"vector_dim": 1536}
elif "SELECT * FROM" in sql and kwargs.get("multirows"):
if params[0] == 0: # offset = 0
return [
{
"id": "1",
"workspace": "workspace_a",
"content_vector": [0.1] * 1536,
},
{
"id": "2",
"workspace": "workspace_b",
"content_vector": [0.2] * 1536,
},
]
else:
return []
return {}
db.query.side_effect = query_side_effect
db.execute = AsyncMock()
db._create_vector_index = AsyncMock()
from unittest.mock import patch
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists",
side_effect=table_exists_side_effect,
),
patch("lightrag.kg.postgres_impl._pg_create_table", new=AsyncMock()),
):
# Migrate WITHOUT workspace parameter (dangerous!)
await PGVectorStorage.setup_table(
db,
"lightrag_doc_chunks_model_1536d",
legacy_table_name="lightrag_doc_chunks",
base_table="lightrag_doc_chunks",
embedding_dim=1536,
workspace=None, # No workspace filter!
)
# Verify queries do NOT use workspace filter
count_calls = [
call
for call in db.query.call_args_list
if call[0][0] and "COUNT(*)" in call[0][0]
]
assert len(count_calls) > 0, "Count query should be executed"
# Check that workspace filter was NOT used
has_workspace_filter = any(
"WHERE workspace" in call[0][0] for call in count_calls
)
assert (
not has_workspace_filter
), "Count should NOT filter by workspace when workspace=None"
@pytest.mark.asyncio
async def test_no_cross_workspace_contamination(self):
"""
Test that workspace B's migration doesn't include workspace A's data.
Scenario: Two separate migrations for workspace_a and workspace_b.
Expected: Each workspace only gets its own data.
"""
db = AsyncMock()
# Track which workspace is being queried
queried_workspace = None
async def table_exists_side_effect(db_instance, name):
return "lightrag_doc_chunks" in name and "model" not in name
async def query_side_effect(sql, params, **kwargs):
nonlocal queried_workspace
multirows = kwargs.get("multirows", False)
if "information_schema.tables" in sql:
return {"exists": "lightrag_doc_chunks" in params[0]}
elif "COUNT(*)" in sql and "WHERE workspace" in sql:
queried_workspace = params[0]
return {"count": 1}
elif "COUNT(*)" in sql and "lightrag_doc_chunks_model_1536d" in sql:
return {"count": 1} # Verification count
elif "pg_attribute" in sql:
return {"vector_dim": 1536}
elif "SELECT * FROM" in sql and "WHERE workspace" in sql and multirows:
workspace = params[0]
if params[1] == 0: # offset = 0
# Return data ONLY for the queried workspace
return [
{
"id": f"{workspace}_1",
"workspace": workspace,
"content": f"content_{workspace}",
"content_vector": [0.1] * 1536,
}
]
else:
return []
return {}
db.query.side_effect = query_side_effect
db.execute = AsyncMock()
db._create_vector_index = AsyncMock()
from unittest.mock import patch
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists",
side_effect=table_exists_side_effect,
),
patch("lightrag.kg.postgres_impl._pg_create_table", new=AsyncMock()),
):
# Migrate workspace_b
await PGVectorStorage.setup_table(
db,
"lightrag_doc_chunks_model_1536d",
legacy_table_name="lightrag_doc_chunks",
base_table="lightrag_doc_chunks",
embedding_dim=1536,
workspace="workspace_b",
)
# Verify only workspace_b was queried
assert queried_workspace == "workspace_b", "Should only query workspace_b"
# Verify INSERT contains workspace_b data only
insert_calls = [
call
for call in db.execute.call_args_list
if call[0][0] and "INSERT INTO" in call[0][0]
]
assert len(insert_calls) > 0, "Should have INSERT calls"
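# A minimal sketch (an assumption, not PGVectorStorage.setup_table's actual code)
# of the workspace-scoped copy loop these tests exercise: count and select only
# rows belonging to the requested workspace, then insert them into the new table
# in batches. Placeholder style, column names, and the batch size are illustrative
# and follow the mocks above.
async def _migrate_workspace_rows(db, legacy_table, new_table, workspace, batch_size=500):
    total = (
        await db.query(
            f"SELECT COUNT(*) AS count FROM {legacy_table} WHERE workspace = $1",
            [workspace],
        )
    )["count"]
    offset = 0
    while offset < total:
        rows = await db.query(
            f"SELECT * FROM {legacy_table} WHERE workspace = $1 OFFSET $2 LIMIT $3",
            [workspace, offset, batch_size],
            multirows=True,
        )
        if not rows:
            break
        for row in rows:
            await db.execute(
                f"INSERT INTO {new_table} (id, workspace, content, content_vector) "
                "VALUES ($1, $2, $3, $4)",
                [row["id"], row["workspace"], row.get("content"), row["content_vector"]],
            )
        offset += len(rows)
    return total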