Merge cf68cdfe3a into 9562a974d2
This commit is contained in:
commit
4e5351de63
18 changed files with 5500 additions and 162 deletions
190
.github/workflows/e2e-tests.yml
vendored
Normal file
190
.github/workflows/e2e-tests.yml
vendored
Normal file
|
|
@ -0,0 +1,190 @@
|
||||||
|
name: E2E Tests (Real Databases)
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch: # Manual trigger only for E2E tests
|
||||||
|
pull_request:
|
||||||
|
branches: [ main, dev ]
|
||||||
|
paths:
|
||||||
|
- 'lightrag/kg/postgres_impl.py'
|
||||||
|
- 'lightrag/kg/qdrant_impl.py'
|
||||||
|
- 'tests/test_e2e_*.py'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
e2e-postgres:
|
||||||
|
name: E2E PostgreSQL Tests
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: ankane/pgvector:latest
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: lightrag
|
||||||
|
POSTGRES_PASSWORD: lightrag_test_password
|
||||||
|
POSTGRES_DB: lightrag_test
|
||||||
|
ports:
|
||||||
|
- 5432:5432
|
||||||
|
options: >-
|
||||||
|
--health-cmd "pg_isready -U lightrag"
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ['3.10', '3.12']
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Cache pip packages
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pip
|
||||||
|
key: ${{ runner.os }}-pip-e2e-${{ hashFiles('**/pyproject.toml') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pip-e2e-
|
||||||
|
${{ runner.os }}-pip-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -e ".[api]"
|
||||||
|
pip install pytest pytest-asyncio asyncpg numpy qdrant-client
|
||||||
|
|
||||||
|
- name: Wait for PostgreSQL
|
||||||
|
run: |
|
||||||
|
timeout 30 bash -c 'until pg_isready -h localhost -p 5432 -U lightrag; do sleep 1; done'
|
||||||
|
|
||||||
|
- name: Setup pgvector extension
|
||||||
|
env:
|
||||||
|
PGPASSWORD: lightrag_test_password
|
||||||
|
run: |
|
||||||
|
psql -h localhost -U lightrag -d lightrag_test -c "CREATE EXTENSION IF NOT EXISTS vector;"
|
||||||
|
psql -h localhost -U lightrag -d lightrag_test -c "SELECT extname, extversion FROM pg_extension WHERE extname = 'vector';"
|
||||||
|
|
||||||
|
- name: Run PostgreSQL E2E tests
|
||||||
|
env:
|
||||||
|
POSTGRES_HOST: localhost
|
||||||
|
POSTGRES_PORT: 5432
|
||||||
|
POSTGRES_USER: lightrag
|
||||||
|
POSTGRES_PASSWORD: lightrag_test_password
|
||||||
|
POSTGRES_DATABASE: lightrag_test
|
||||||
|
run: |
|
||||||
|
pytest tests/test_e2e_multi_instance.py -k "postgres" -v --tb=short -s
|
||||||
|
timeout-minutes: 20
|
||||||
|
|
||||||
|
- name: Upload PostgreSQL test results
|
||||||
|
if: always()
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: e2e-postgres-results-py${{ matrix.python-version }}
|
||||||
|
path: |
|
||||||
|
.pytest_cache/
|
||||||
|
test-results.xml
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
e2e-qdrant:
|
||||||
|
name: E2E Qdrant Tests
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
services:
|
||||||
|
qdrant:
|
||||||
|
image: qdrant/qdrant:latest
|
||||||
|
ports:
|
||||||
|
- 6333:6333
|
||||||
|
- 6334:6334
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ['3.10', '3.12']
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Cache pip packages
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pip
|
||||||
|
key: ${{ runner.os }}-pip-e2e-${{ hashFiles('**/pyproject.toml') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pip-e2e-
|
||||||
|
${{ runner.os }}-pip-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -e ".[api]"
|
||||||
|
pip install pytest pytest-asyncio qdrant-client numpy
|
||||||
|
|
||||||
|
- name: Wait for Qdrant
|
||||||
|
run: |
|
||||||
|
echo "Waiting for Qdrant to be ready..."
|
||||||
|
for i in {1..60}; do
|
||||||
|
if curl -s http://localhost:6333 > /dev/null 2>&1; then
|
||||||
|
echo "Qdrant is ready!"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
echo "Attempt $i/60: Qdrant not ready yet, waiting..."
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
# Final check
|
||||||
|
if ! curl -s http://localhost:6333 > /dev/null 2>&1; then
|
||||||
|
echo "ERROR: Qdrant failed to start after 60 seconds"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Verify Qdrant connection
|
||||||
|
run: |
|
||||||
|
echo "Verifying Qdrant API..."
|
||||||
|
curl -X GET "http://localhost:6333/collections" -H "Content-Type: application/json"
|
||||||
|
echo ""
|
||||||
|
echo "Qdrant is accessible and ready for testing"
|
||||||
|
|
||||||
|
- name: Run Qdrant E2E tests
|
||||||
|
env:
|
||||||
|
QDRANT_URL: http://localhost:6333
|
||||||
|
QDRANT_API_KEY: ""
|
||||||
|
run: |
|
||||||
|
pytest tests/test_e2e_multi_instance.py -k "qdrant" -v --tb=short -s
|
||||||
|
timeout-minutes: 15
|
||||||
|
|
||||||
|
- name: Upload Qdrant test results
|
||||||
|
if: always()
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: e2e-qdrant-results-py${{ matrix.python-version }}
|
||||||
|
path: |
|
||||||
|
.pytest_cache/
|
||||||
|
test-results.xml
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
e2e-summary:
|
||||||
|
name: E2E Test Summary
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [e2e-postgres, e2e-qdrant]
|
||||||
|
if: always()
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Check test results
|
||||||
|
run: |
|
||||||
|
echo "## E2E Test Summary" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "### PostgreSQL E2E Tests" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "Status: ${{ needs.e2e-postgres.result }}" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "### Qdrant E2E Tests" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "Status: ${{ needs.e2e-qdrant.result }}" >> $GITHUB_STEP_SUMMARY
|
||||||
|
|
||||||
|
- name: Fail if any test failed
|
||||||
|
if: needs.e2e-postgres.result != 'success' || needs.e2e-qdrant.result != 'success'
|
||||||
|
run: exit 1
|
||||||
74
.github/workflows/feature-tests.yml
vendored
Normal file
74
.github/workflows/feature-tests.yml
vendored
Normal file
|
|
@ -0,0 +1,74 @@
|
||||||
|
name: Feature Branch Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch: # Allow manual trigger
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- 'feature/**'
|
||||||
|
pull_request:
|
||||||
|
branches: [ main, dev ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
migration-tests:
|
||||||
|
name: Vector Storage Migration Tests
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ['3.10', '3.11', '3.12']
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Cache pip packages
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pip
|
||||||
|
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt', '**/pyproject.toml') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pip-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -e ".[api]"
|
||||||
|
pip install pytest pytest-asyncio
|
||||||
|
|
||||||
|
- name: Run Qdrant migration tests
|
||||||
|
run: |
|
||||||
|
pytest tests/test_qdrant_migration.py -v --tb=short
|
||||||
|
continue-on-error: false
|
||||||
|
|
||||||
|
- name: Run PostgreSQL migration tests
|
||||||
|
run: |
|
||||||
|
pytest tests/test_postgres_migration.py -v --tb=short
|
||||||
|
continue-on-error: false
|
||||||
|
|
||||||
|
- name: Run all unit tests (if exists)
|
||||||
|
run: |
|
||||||
|
# Run EmbeddingFunc tests
|
||||||
|
pytest tests/ -k "embedding" -v --tb=short || true
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
|
- name: Upload test results
|
||||||
|
if: always()
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: migration-test-results-py${{ matrix.python-version }}
|
||||||
|
path: |
|
||||||
|
.pytest_cache/
|
||||||
|
test-results.xml
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
- name: Test Summary
|
||||||
|
if: always()
|
||||||
|
run: |
|
||||||
|
echo "## Test Summary" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "- Python: ${{ matrix.python-version }}" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "- Branch: ${{ github.ref_name }}" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "- Commit: ${{ github.sha }}" >> $GITHUB_STEP_SUMMARY
|
||||||
271
examples/multi_model_demo.py
Normal file
271
examples/multi_model_demo.py
Normal file
|
|
@ -0,0 +1,271 @@
|
||||||
|
"""
|
||||||
|
Multi-Model Vector Storage Isolation Demo
|
||||||
|
|
||||||
|
This example demonstrates LightRAG's automatic model isolation feature for vector storage.
|
||||||
|
When using different embedding models, LightRAG automatically creates separate collections/tables,
|
||||||
|
preventing dimension mismatches and data pollution.
|
||||||
|
|
||||||
|
Key Features:
|
||||||
|
- Automatic model suffix generation: {model_name}_{dim}d
|
||||||
|
- Seamless migration from legacy (no-suffix) to new (with-suffix) collections
|
||||||
|
- Support for multiple workspaces with different embedding models
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- OpenAI API key (or any OpenAI-compatible API)
|
||||||
|
- Qdrant or PostgreSQL for vector storage (optional, defaults to NanoVectorDB)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from lightrag import LightRAG, QueryParam
|
||||||
|
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
|
||||||
|
# Set your API key
|
||||||
|
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"
|
||||||
|
|
||||||
|
|
||||||
|
async def scenario_1_new_workspace_with_explicit_model():
|
||||||
|
"""
|
||||||
|
Scenario 1: Creating a new workspace with explicit model name
|
||||||
|
|
||||||
|
Result: Creates collection/table with name like:
|
||||||
|
- Qdrant: lightrag_vdb_chunks_text_embedding_3_large_3072d
|
||||||
|
- PostgreSQL: LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d
|
||||||
|
"""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("Scenario 1: New Workspace with Explicit Model Name")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Define custom embedding function with explicit model name
|
||||||
|
async def my_embedding_func(texts: list[str]):
|
||||||
|
return await openai_embed(texts, model="text-embedding-3-large")
|
||||||
|
|
||||||
|
# Create EmbeddingFunc with model_name specified
|
||||||
|
embedding_func = EmbeddingFunc(
|
||||||
|
embedding_dim=3072,
|
||||||
|
func=my_embedding_func,
|
||||||
|
model_name="text-embedding-3-large", # Explicit model name
|
||||||
|
)
|
||||||
|
|
||||||
|
rag = LightRAG(
|
||||||
|
working_dir="./workspace_large_model",
|
||||||
|
llm_model_func=gpt_4o_mini_complete,
|
||||||
|
embedding_func=embedding_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
await rag.initialize_storages()
|
||||||
|
|
||||||
|
# Insert sample data
|
||||||
|
await rag.ainsert("LightRAG supports automatic model isolation for vector storage.")
|
||||||
|
|
||||||
|
# Query
|
||||||
|
result = await rag.aquery(
|
||||||
|
"What does LightRAG support?", param=QueryParam(mode="hybrid")
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nQuery Result: {result[:200]}...")
|
||||||
|
print("\n✅ Collection/table created with suffix: text_embedding_3_large_3072d")
|
||||||
|
|
||||||
|
await rag.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def scenario_2_legacy_migration():
|
||||||
|
"""
|
||||||
|
Scenario 2: Upgrading from legacy version (without model_name)
|
||||||
|
|
||||||
|
If you previously used LightRAG without specifying model_name,
|
||||||
|
the first run with model_name will automatically migrate your data.
|
||||||
|
|
||||||
|
Result: Data is migrated from:
|
||||||
|
- Old: lightrag_vdb_chunks (no suffix)
|
||||||
|
- New: lightrag_vdb_chunks_text_embedding_ada_002_1536d (with suffix)
|
||||||
|
"""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("Scenario 2: Automatic Migration from Legacy Format")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Step 1: Simulate legacy workspace (no model_name)
|
||||||
|
print("\n[Step 1] Creating legacy workspace without model_name...")
|
||||||
|
|
||||||
|
async def legacy_embedding_func(texts: list[str]):
|
||||||
|
return await openai_embed(texts, model="text-embedding-ada-002")
|
||||||
|
|
||||||
|
# Legacy: No model_name specified
|
||||||
|
legacy_embedding = EmbeddingFunc(
|
||||||
|
embedding_dim=1536,
|
||||||
|
func=legacy_embedding_func,
|
||||||
|
# model_name not specified → uses "unknown" as fallback
|
||||||
|
)
|
||||||
|
|
||||||
|
rag_legacy = LightRAG(
|
||||||
|
working_dir="./workspace_legacy",
|
||||||
|
llm_model_func=gpt_4o_mini_complete,
|
||||||
|
embedding_func=legacy_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
await rag_legacy.initialize_storages()
|
||||||
|
await rag_legacy.ainsert("Legacy data without model isolation.")
|
||||||
|
await rag_legacy.close()
|
||||||
|
|
||||||
|
print("✅ Legacy workspace created with suffix: unknown_1536d")
|
||||||
|
|
||||||
|
# Step 2: Upgrade to new version with model_name
|
||||||
|
print("\n[Step 2] Upgrading to new version with explicit model_name...")
|
||||||
|
|
||||||
|
# New: With model_name specified
|
||||||
|
new_embedding = EmbeddingFunc(
|
||||||
|
embedding_dim=1536,
|
||||||
|
func=legacy_embedding_func,
|
||||||
|
model_name="text-embedding-ada-002", # Now explicitly specified
|
||||||
|
)
|
||||||
|
|
||||||
|
rag_new = LightRAG(
|
||||||
|
working_dir="./workspace_legacy", # Same working directory
|
||||||
|
llm_model_func=gpt_4o_mini_complete,
|
||||||
|
embedding_func=new_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
# On first initialization, LightRAG will:
|
||||||
|
# 1. Detect legacy collection exists
|
||||||
|
# 2. Automatically migrate data to new collection with model suffix
|
||||||
|
# 3. Legacy collection remains but can be deleted after verification
|
||||||
|
await rag_new.initialize_storages()
|
||||||
|
|
||||||
|
# Verify data is still accessible
|
||||||
|
result = await rag_new.aquery(
|
||||||
|
"What is the legacy data?", param=QueryParam(mode="hybrid")
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nQuery Result: {result[:200] if result else 'No results'}...")
|
||||||
|
print("\n✅ Data migrated to: text_embedding_ada_002_1536d")
|
||||||
|
print("ℹ️ Legacy collection can be manually deleted after verification")
|
||||||
|
|
||||||
|
await rag_new.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def scenario_3_multiple_models_coexistence():
|
||||||
|
"""
|
||||||
|
Scenario 3: Multiple workspaces with different embedding models
|
||||||
|
|
||||||
|
Different embedding models create completely isolated collections/tables,
|
||||||
|
allowing safe coexistence without dimension conflicts or data pollution.
|
||||||
|
|
||||||
|
Result:
|
||||||
|
- Workspace A: lightrag_vdb_chunks_bge_small_768d
|
||||||
|
- Workspace B: lightrag_vdb_chunks_bge_large_1024d
|
||||||
|
"""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("Scenario 3: Multiple Models Coexistence")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Workspace A: Small embedding model (768 dimensions)
|
||||||
|
print("\n[Workspace A] Using bge-small model (768d)...")
|
||||||
|
|
||||||
|
async def embedding_func_small(texts: list[str]):
|
||||||
|
# Simulate small embedding model
|
||||||
|
# In real usage, replace with actual model call
|
||||||
|
return await openai_embed(texts, model="text-embedding-3-small")
|
||||||
|
|
||||||
|
embedding_a = EmbeddingFunc(
|
||||||
|
embedding_dim=1536, # text-embedding-3-small dimension
|
||||||
|
func=embedding_func_small,
|
||||||
|
model_name="text-embedding-3-small",
|
||||||
|
)
|
||||||
|
|
||||||
|
rag_a = LightRAG(
|
||||||
|
working_dir="./workspace_a",
|
||||||
|
llm_model_func=gpt_4o_mini_complete,
|
||||||
|
embedding_func=embedding_a,
|
||||||
|
)
|
||||||
|
|
||||||
|
await rag_a.initialize_storages()
|
||||||
|
await rag_a.ainsert("Workspace A uses small embedding model for efficiency.")
|
||||||
|
|
||||||
|
print("✅ Workspace A created with suffix: text_embedding_3_small_1536d")
|
||||||
|
|
||||||
|
# Workspace B: Large embedding model (3072 dimensions)
|
||||||
|
print("\n[Workspace B] Using text-embedding-3-large model (3072d)...")
|
||||||
|
|
||||||
|
async def embedding_func_large(texts: list[str]):
|
||||||
|
# Simulate large embedding model
|
||||||
|
return await openai_embed(texts, model="text-embedding-3-large")
|
||||||
|
|
||||||
|
embedding_b = EmbeddingFunc(
|
||||||
|
embedding_dim=3072, # text-embedding-3-large dimension
|
||||||
|
func=embedding_func_large,
|
||||||
|
model_name="text-embedding-3-large",
|
||||||
|
)
|
||||||
|
|
||||||
|
rag_b = LightRAG(
|
||||||
|
working_dir="./workspace_b",
|
||||||
|
llm_model_func=gpt_4o_mini_complete,
|
||||||
|
embedding_func=embedding_b,
|
||||||
|
)
|
||||||
|
|
||||||
|
await rag_b.initialize_storages()
|
||||||
|
await rag_b.ainsert("Workspace B uses large embedding model for better accuracy.")
|
||||||
|
|
||||||
|
print("✅ Workspace B created with suffix: text_embedding_3_large_3072d")
|
||||||
|
|
||||||
|
# Verify isolation: Query each workspace
|
||||||
|
print("\n[Verification] Querying both workspaces...")
|
||||||
|
|
||||||
|
result_a = await rag_a.aquery(
|
||||||
|
"What model does workspace use?", param=QueryParam(mode="hybrid")
|
||||||
|
)
|
||||||
|
result_b = await rag_b.aquery(
|
||||||
|
"What model does workspace use?", param=QueryParam(mode="hybrid")
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nWorkspace A Result: {result_a[:100] if result_a else 'No results'}...")
|
||||||
|
print(f"Workspace B Result: {result_b[:100] if result_b else 'No results'}...")
|
||||||
|
|
||||||
|
print("\n✅ Both workspaces operate independently without interference")
|
||||||
|
|
||||||
|
await rag_a.close()
|
||||||
|
await rag_b.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""
|
||||||
|
Run all scenarios to demonstrate model isolation features
|
||||||
|
"""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("LightRAG Multi-Model Vector Storage Isolation Demo")
|
||||||
|
print("=" * 80)
|
||||||
|
print("\nThis demo shows how LightRAG automatically handles:")
|
||||||
|
print("1. ✅ Automatic model suffix generation")
|
||||||
|
print("2. ✅ Seamless data migration from legacy format")
|
||||||
|
print("3. ✅ Multiple embedding models coexistence")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Scenario 1: New workspace with explicit model
|
||||||
|
await scenario_1_new_workspace_with_explicit_model()
|
||||||
|
|
||||||
|
# Scenario 2: Legacy migration
|
||||||
|
await scenario_2_legacy_migration()
|
||||||
|
|
||||||
|
# Scenario 3: Multiple models coexistence
|
||||||
|
await scenario_3_multiple_models_coexistence()
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("✅ All scenarios completed successfully!")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
print("\n📝 Key Takeaways:")
|
||||||
|
print("- Always specify `model_name` in EmbeddingFunc for clear model tracking")
|
||||||
|
print("- LightRAG automatically migrates legacy data on first run")
|
||||||
|
print("- Different embedding models create isolated collections/tables")
|
||||||
|
print("- Collection names follow pattern: {base_name}_{model_name}_{dim}d")
|
||||||
|
print("\n📚 See the plan document for more details:")
|
||||||
|
print(" .claude/plan/PR-vector-model-isolation.md")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Error: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
|
@ -220,6 +220,45 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
||||||
cosine_better_than_threshold: float = field(default=0.2)
|
cosine_better_than_threshold: float = field(default=0.2)
|
||||||
meta_fields: set[str] = field(default_factory=set)
|
meta_fields: set[str] = field(default_factory=set)
|
||||||
|
|
||||||
|
def _generate_collection_suffix(self) -> str:
|
||||||
|
"""Generates collection/table suffix from embedding_func.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Suffix string, e.g. "text_embedding_3_large_3072d"
|
||||||
|
"""
|
||||||
|
# Try to get model identifier from the embedding function
|
||||||
|
# If it's a wrapped function (doesn't have get_model_identifier),
|
||||||
|
# fallback to the original embedding_func from global_config
|
||||||
|
if hasattr(self.embedding_func, "get_model_identifier"):
|
||||||
|
return self.embedding_func.get_model_identifier()
|
||||||
|
elif "embedding_func" in self.global_config:
|
||||||
|
original_embedding_func = self.global_config["embedding_func"]
|
||||||
|
if original_embedding_func is not None and hasattr(
|
||||||
|
original_embedding_func, "get_model_identifier"
|
||||||
|
):
|
||||||
|
return original_embedding_func.get_model_identifier()
|
||||||
|
else:
|
||||||
|
# Debug: log why we couldn't get model identifier
|
||||||
|
from lightrag.utils import logger
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Could not get model_identifier: embedding_func is {type(original_embedding_func)}, has method={hasattr(original_embedding_func, 'get_model_identifier') if original_embedding_func else False}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback: no model identifier available
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _get_legacy_collection_name(self) -> str:
|
||||||
|
"""Get legacy collection/table name (without suffix).
|
||||||
|
|
||||||
|
Used for data migration detection.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Subclasses must implement this method")
|
||||||
|
|
||||||
|
def _get_new_collection_name(self) -> str:
|
||||||
|
"""Get new collection/table name (with suffix)."""
|
||||||
|
raise NotImplementedError("Subclasses must implement this method")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def query(
|
async def query(
|
||||||
self, query: str, top_k: int, query_embedding: list[float] = None
|
self, query: str, top_k: int, query_embedding: list[float] = None
|
||||||
|
|
|
||||||
|
|
@ -1163,23 +1163,9 @@ class PostgreSQLDB:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"PostgreSQL, Failed to batch check/create indexes: {e}")
|
logger.error(f"PostgreSQL, Failed to batch check/create indexes: {e}")
|
||||||
|
|
||||||
# Create vector indexs
|
# NOTE: Vector index creation moved to PGVectorStorage.setup_table()
|
||||||
if self.vector_index_type:
|
# Each vector storage instance creates its own index with correct embedding_dim
|
||||||
logger.info(
|
|
||||||
f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]:
|
|
||||||
await self._create_vector_indexes()
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Doesn't support this vector index type: {self.vector_index_type}. "
|
|
||||||
"Supported types: HNSW, IVFFLAT, VCHORDRQ"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}"
|
|
||||||
)
|
|
||||||
# After all tables are created, attempt to migrate timestamp fields
|
# After all tables are created, attempt to migrate timestamp fields
|
||||||
try:
|
try:
|
||||||
await self._migrate_timestamp_columns()
|
await self._migrate_timestamp_columns()
|
||||||
|
|
@ -1381,64 +1367,72 @@ class PostgreSQLDB:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to create index {index['name']}: {e}")
|
logger.warning(f"Failed to create index {index['name']}: {e}")
|
||||||
|
|
||||||
async def _create_vector_indexes(self):
|
async def _create_vector_index(self, table_name: str, embedding_dim: int):
|
||||||
vdb_tables = [
|
"""
|
||||||
"LIGHTRAG_VDB_CHUNKS",
|
Create vector index for a specific table.
|
||||||
"LIGHTRAG_VDB_ENTITY",
|
|
||||||
"LIGHTRAG_VDB_RELATION",
|
Args:
|
||||||
]
|
table_name: Name of the table to create index on
|
||||||
|
embedding_dim: Embedding dimension for the vector column
|
||||||
|
"""
|
||||||
|
if not self.vector_index_type:
|
||||||
|
return
|
||||||
|
|
||||||
create_sql = {
|
create_sql = {
|
||||||
"HNSW": f"""
|
"HNSW": f"""
|
||||||
CREATE INDEX {{vector_index_name}}
|
CREATE INDEX {{vector_index_name}}
|
||||||
ON {{k}} USING hnsw (content_vector vector_cosine_ops)
|
ON {{table_name}} USING hnsw (content_vector vector_cosine_ops)
|
||||||
WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
|
WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef})
|
||||||
""",
|
""",
|
||||||
"IVFFLAT": f"""
|
"IVFFLAT": f"""
|
||||||
CREATE INDEX {{vector_index_name}}
|
CREATE INDEX {{vector_index_name}}
|
||||||
ON {{k}} USING ivfflat (content_vector vector_cosine_ops)
|
ON {{table_name}} USING ivfflat (content_vector vector_cosine_ops)
|
||||||
WITH (lists = {self.ivfflat_lists})
|
WITH (lists = {self.ivfflat_lists})
|
||||||
""",
|
""",
|
||||||
"VCHORDRQ": f"""
|
"VCHORDRQ": f"""
|
||||||
CREATE INDEX {{vector_index_name}}
|
CREATE INDEX {{vector_index_name}}
|
||||||
ON {{k}} USING vchordrq (content_vector vector_cosine_ops)
|
ON {{table_name}} USING vchordrq (content_vector vector_cosine_ops)
|
||||||
{f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''}
|
{f"WITH (options = $${self.vchordrq_build_options}$$)" if self.vchordrq_build_options else ""}
|
||||||
""",
|
""",
|
||||||
}
|
}
|
||||||
|
|
||||||
embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024))
|
if self.vector_index_type not in create_sql:
|
||||||
for k in vdb_tables:
|
logger.warning(
|
||||||
vector_index_name = (
|
f"Unsupported vector index type: {self.vector_index_type}. "
|
||||||
f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine"
|
"Supported types: HNSW, IVFFLAT, VCHORDRQ"
|
||||||
)
|
)
|
||||||
check_vector_index_sql = f"""
|
return
|
||||||
SELECT 1 FROM pg_indexes
|
|
||||||
WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
|
k = table_name
|
||||||
"""
|
vector_index_name = f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine"
|
||||||
try:
|
check_vector_index_sql = f"""
|
||||||
vector_index_exists = await self.query(check_vector_index_sql)
|
SELECT 1 FROM pg_indexes
|
||||||
if not vector_index_exists:
|
WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}'
|
||||||
# Only set vector dimension when index doesn't exist
|
"""
|
||||||
alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
|
try:
|
||||||
await self.execute(alter_sql)
|
vector_index_exists = await self.query(check_vector_index_sql)
|
||||||
logger.debug(f"Ensured vector dimension for {k}")
|
if not vector_index_exists:
|
||||||
logger.info(
|
# Only set vector dimension when index doesn't exist
|
||||||
f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
|
alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})"
|
||||||
|
await self.execute(alter_sql)
|
||||||
|
logger.debug(f"Ensured vector dimension for {k}")
|
||||||
|
logger.info(
|
||||||
|
f"Creating {self.vector_index_type} index {vector_index_name} on table {k}"
|
||||||
|
)
|
||||||
|
await self.execute(
|
||||||
|
create_sql[self.vector_index_type].format(
|
||||||
|
vector_index_name=vector_index_name, table_name=k
|
||||||
)
|
)
|
||||||
await self.execute(
|
)
|
||||||
create_sql[self.vector_index_type].format(
|
logger.info(
|
||||||
vector_index_name=vector_index_name, k=k
|
f"Successfully created vector index {vector_index_name} on table {k}"
|
||||||
)
|
)
|
||||||
)
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Successfully created vector index {vector_index_name} on table {k}"
|
f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
|
||||||
)
|
)
|
||||||
else:
|
except Exception as e:
|
||||||
logger.info(
|
logger.error(f"Failed to create vector index on table {k}, Got: {e}")
|
||||||
f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to create vector index on table {k}, Got: {e}")
|
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
self,
|
self,
|
||||||
|
|
@ -2175,6 +2169,90 @@ class PGKVStorage(BaseKVStorage):
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
async def _pg_table_exists(db: PostgreSQLDB, table_name: str) -> bool:
|
||||||
|
"""Check if a table exists in PostgreSQL database"""
|
||||||
|
query = """
|
||||||
|
SELECT EXISTS (
|
||||||
|
SELECT FROM information_schema.tables
|
||||||
|
WHERE table_name = $1
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
result = await db.query(query, [table_name.lower()])
|
||||||
|
return result.get("exists", False) if result else False
|
||||||
|
|
||||||
|
|
||||||
|
async def _pg_create_table(
|
||||||
|
db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
|
||||||
|
) -> None:
|
||||||
|
"""Create a new vector table by replacing the table name in DDL template"""
|
||||||
|
if base_table not in TABLES:
|
||||||
|
raise ValueError(f"No DDL template found for table: {base_table}")
|
||||||
|
|
||||||
|
ddl_template = TABLES[base_table]["ddl"]
|
||||||
|
|
||||||
|
# Replace embedding dimension placeholder if exists
|
||||||
|
ddl = ddl_template.replace(
|
||||||
|
f"VECTOR({os.environ.get('EMBEDDING_DIM', 1024)})", f"VECTOR({embedding_dim})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace table name
|
||||||
|
ddl = ddl.replace(base_table, table_name)
|
||||||
|
|
||||||
|
await db.execute(ddl)
|
||||||
|
|
||||||
|
|
||||||
|
async def _pg_migrate_workspace_data(
|
||||||
|
db: PostgreSQLDB,
|
||||||
|
legacy_table_name: str,
|
||||||
|
new_table_name: str,
|
||||||
|
workspace: str,
|
||||||
|
expected_count: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
) -> int:
|
||||||
|
"""Migrate workspace data from legacy table to new table"""
|
||||||
|
migrated_count = 0
|
||||||
|
offset = 0
|
||||||
|
batch_size = 500
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if workspace:
|
||||||
|
select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 OFFSET $2 LIMIT $3"
|
||||||
|
rows = await db.query(
|
||||||
|
select_query, [workspace, offset, batch_size], multirows=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
select_query = f"SELECT * FROM {legacy_table_name} OFFSET $1 LIMIT $2"
|
||||||
|
rows = await db.query(select_query, [offset, batch_size], multirows=True)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
break
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
row_dict = dict(row)
|
||||||
|
columns = list(row_dict.keys())
|
||||||
|
columns_str = ", ".join(columns)
|
||||||
|
placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))])
|
||||||
|
insert_query = f"""
|
||||||
|
INSERT INTO {new_table_name} ({columns_str})
|
||||||
|
VALUES ({placeholders})
|
||||||
|
ON CONFLICT (workspace, id) DO NOTHING
|
||||||
|
"""
|
||||||
|
# Rebuild dict in columns order to ensure values() matches placeholders order
|
||||||
|
# Python 3.7+ dicts maintain insertion order, and execute() uses tuple(data.values())
|
||||||
|
values = {col: row_dict[col] for col in columns}
|
||||||
|
await db.execute(insert_query, values)
|
||||||
|
|
||||||
|
migrated_count += len(rows)
|
||||||
|
workspace_info = f" for workspace '{workspace}'" if workspace else ""
|
||||||
|
logger.info(
|
||||||
|
f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}"
|
||||||
|
)
|
||||||
|
|
||||||
|
offset += batch_size
|
||||||
|
|
||||||
|
return migrated_count
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGVectorStorage(BaseVectorStorage):
|
class PGVectorStorage(BaseVectorStorage):
|
||||||
|
|
@ -2190,6 +2268,412 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
|
# Generate model suffix for table isolation
|
||||||
|
self.model_suffix = self._generate_collection_suffix()
|
||||||
|
|
||||||
|
# Get base table name
|
||||||
|
base_table = namespace_to_table_name(self.namespace)
|
||||||
|
if not base_table:
|
||||||
|
raise ValueError(f"Unknown namespace: {self.namespace}")
|
||||||
|
|
||||||
|
# New table name (with suffix)
|
||||||
|
# Ensure model_suffix is not empty before appending
|
||||||
|
if self.model_suffix:
|
||||||
|
self.table_name = f"{base_table}_{self.model_suffix}"
|
||||||
|
else:
|
||||||
|
# Fallback: use base table name if model_suffix is unavailable
|
||||||
|
self.table_name = base_table
|
||||||
|
logger.warning(
|
||||||
|
f"Model suffix unavailable, using base table name '{base_table}'. "
|
||||||
|
f"Ensure embedding_func has model_name for proper model isolation."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Legacy table name (without suffix, for migration)
|
||||||
|
self.legacy_table_name = base_table
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"PostgreSQL table naming: "
|
||||||
|
f"new='{self.table_name}', "
|
||||||
|
f"legacy='{self.legacy_table_name}', "
|
||||||
|
f"model_suffix='{self.model_suffix}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
    @staticmethod
    async def setup_table(
        db: PostgreSQLDB,
        table_name: str,
        legacy_table_name: str = None,
        base_table: str = None,
        embedding_dim: int = None,
        workspace: str = None,
    ):
        """
        Setup PostgreSQL table with migration support from legacy tables.

        This method mirrors Qdrant's setup_collection approach to maintain consistency.

        Four cases are handled, keyed on which tables already exist:
          1. Both new and legacy exist  -> migrate remaining legacy rows, then
             clean up the legacy table once it is empty.
          2. Only the new table exists  -> nothing to migrate; ensure the index.
          3. Neither exists             -> create the new table.
          4. Only the legacy one exists -> create the new table and migrate.

        Args:
            db: PostgreSQLDB instance
            table_name: Name of the new table
            legacy_table_name: Name of the legacy table (if exists)
            base_table: Base table name for DDL template lookup
            embedding_dim: Embedding dimension for vector column
            workspace: When set, migration and legacy cleanup are restricted
                to this workspace's rows; when falsy the whole legacy table
                is migrated (and eventually dropped).
        """
        new_table_exists = await _pg_table_exists(db, table_name)
        # Short-circuits: legacy check is skipped entirely when no legacy
        # table name was supplied.
        legacy_exists = legacy_table_name and await _pg_table_exists(
            db, legacy_table_name
        )

        # Case 1: Both new and legacy tables exist
        if new_table_exists and legacy_exists:
            # Same physical table (no model suffix was applied): nothing to
            # migrate from, nothing to clean up.
            if table_name.lower() == legacy_table_name.lower():
                logger.debug(
                    f"PostgreSQL: Table '{table_name}' already exists (no model suffix). Skipping Case 1 cleanup."
                )
                return

            try:
                workspace_info = f" for workspace '{workspace}'" if workspace else ""

                # Count legacy rows still pending migration (optionally
                # filtered to the current workspace).
                if workspace:
                    count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
                    count_result = await db.query(count_query, [workspace])
                else:
                    count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
                    count_result = await db.query(count_query, [])

                workspace_count = count_result.get("count", 0) if count_result else 0

                if workspace_count > 0:
                    logger.info(
                        f"PostgreSQL: Found {workspace_count} records in legacy table{workspace_info}. Migrating..."
                    )

                    # Best-effort dimension check: refuse to copy vectors of a
                    # different dimensionality into the new table.
                    legacy_dim = None
                    try:
                        # pgvector columns carry their declared dimension in
                        # pg_attribute.atttypmod; -1 means "unknown" here.
                        dim_query = """
                            SELECT
                                CASE
                                    WHEN typname = 'vector' THEN
                                        COALESCE(atttypmod, -1)
                                    ELSE -1
                                END as vector_dim
                            FROM pg_attribute a
                            JOIN pg_type t ON a.atttypid = t.oid
                            WHERE a.attrelid = $1::regclass
                            AND a.attname = 'content_vector'
                        """
                        dim_result = await db.query(dim_query, [legacy_table_name])
                        legacy_dim = (
                            dim_result.get("vector_dim", -1) if dim_result else -1
                        )

                        # Fallback: sample one stored vector and measure it.
                        if legacy_dim <= 0:
                            sample_query = f"SELECT content_vector FROM {legacy_table_name} LIMIT 1"
                            sample_result = await db.query(sample_query, [])
                            if sample_result and sample_result.get("content_vector"):
                                vector_data = sample_result["content_vector"]
                                if isinstance(vector_data, (list, tuple)):
                                    legacy_dim = len(vector_data)
                                elif isinstance(vector_data, str):
                                    # Some drivers return the vector as a
                                    # JSON-style string; parse it to count.
                                    import json

                                    vector_list = json.loads(vector_data)
                                    legacy_dim = len(vector_list)

                        if (
                            legacy_dim > 0
                            and embedding_dim
                            and legacy_dim != embedding_dim
                        ):
                            # Dimensions differ: keep legacy data untouched,
                            # just make sure the new table is indexed.
                            logger.warning(
                                f"PostgreSQL: Dimension mismatch - "
                                f"legacy table has {legacy_dim}d vectors, "
                                f"new embedding model expects {embedding_dim}d. "
                                f"Skipping migration{workspace_info}."
                            )
                            await db._create_vector_index(table_name, embedding_dim)
                            return

                    except Exception as e:
                        # Dimension probing is advisory only; migration still
                        # proceeds if it fails.
                        logger.warning(
                            f"PostgreSQL: Could not verify vector dimension: {e}. Proceeding with caution..."
                        )

                    migrated_count = await _pg_migrate_workspace_data(
                        db,
                        legacy_table_name,
                        table_name,
                        workspace,
                        workspace_count,
                        embedding_dim,
                    )

                    # Verify how many rows landed in the new table.
                    if workspace:
                        new_count_query = f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1"
                        new_count_result = await db.query(new_count_query, [workspace])
                    else:
                        new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
                        new_count_result = await db.query(new_count_query, [])

                    new_count = (
                        new_count_result.get("count", 0) if new_count_result else 0
                    )

                    if new_count < workspace_count:
                        # ON CONFLICT DO NOTHING can legitimately drop
                        # duplicates; warn rather than fail in Case 1.
                        logger.warning(
                            f"PostgreSQL: Expected {workspace_count} records, found {new_count}{workspace_info}. "
                            f"Some records may have been skipped due to conflicts."
                        )
                    else:
                        logger.info(
                            f"PostgreSQL: Migration completed: {migrated_count} records migrated{workspace_info}"
                        )

                    # Remove the migrated workspace's rows from the legacy
                    # table so re-runs don't re-copy them.
                    if workspace:
                        delete_query = (
                            f"DELETE FROM {legacy_table_name} WHERE workspace = $1"
                        )
                        await db.execute(delete_query, {"workspace": workspace})
                        logger.info(
                            f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table"
                        )

                # Regardless of whether this workspace had rows: drop the
                # legacy table once no workspace's data remains in it.
                total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
                total_count_result = await db.query(total_count_query, [])
                total_count = (
                    total_count_result.get("count", 0) if total_count_result else 0
                )

                if total_count == 0:
                    logger.info(
                        f"PostgreSQL: Legacy table '{legacy_table_name}' is empty. Deleting..."
                    )
                    drop_query = f"DROP TABLE {legacy_table_name}"
                    await db.execute(drop_query, None)
                    logger.info(
                        f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully"
                    )
                else:
                    logger.info(
                        f"PostgreSQL: Legacy table '{legacy_table_name}' preserved "
                        f"({total_count} records from other workspaces remain)"
                    )

            except Exception as e:
                # Case 1 migration is best-effort: failures are logged but the
                # new table is still indexed and usable below.
                logger.warning(
                    f"PostgreSQL: Error during Case 1 migration: {e}. Vector index will still be ensured."
                )

            await db._create_vector_index(table_name, embedding_dim)
            return

        # Case 2: Only new table exists - Already migrated or newly created
        if new_table_exists:
            logger.debug(f"PostgreSQL: Table '{table_name}' already exists")
            # Ensure vector index exists with correct embedding dimension
            await db._create_vector_index(table_name, embedding_dim)
            return

        # Case 3: Neither exists - Create new table
        if not legacy_exists:
            logger.info(f"PostgreSQL: Creating new table '{table_name}'")
            await _pg_create_table(db, table_name, base_table, embedding_dim)
            logger.info(f"PostgreSQL: Table '{table_name}' created successfully")
            # Create vector index with correct embedding dimension
            await db._create_vector_index(table_name, embedding_dim)
            return

        # Case 4: Only legacy exists - Migrate data
        logger.info(
            f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}'"
        )

        try:
            # Get legacy table count (with workspace filtering)
            if workspace:
                count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1"
                count_result = await db.query(count_query, [workspace])
            else:
                count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
                count_result = await db.query(count_query, [])
                logger.warning(
                    "PostgreSQL: Migration without workspace filter - this may copy data from all workspaces!"
                )

            legacy_count = count_result.get("count", 0) if count_result else 0
            workspace_info = f" for workspace '{workspace}'" if workspace else ""
            logger.info(
                f"PostgreSQL: Found {legacy_count} records in legacy table{workspace_info}"
            )

            if legacy_count == 0:
                logger.info("PostgreSQL: Legacy table is empty, skipping migration")
                await _pg_create_table(db, table_name, base_table, embedding_dim)
                # Create vector index with correct embedding dimension
                await db._create_vector_index(table_name, embedding_dim)
                return

            # Check vector dimension compatibility before migration
            legacy_dim = None
            try:
                # Try to get vector dimension from pg_attribute metadata
                dim_query = """
                    SELECT
                        CASE
                            WHEN typname = 'vector' THEN
                                COALESCE(atttypmod, -1)
                            ELSE -1
                        END as vector_dim
                    FROM pg_attribute a
                    JOIN pg_type t ON a.atttypid = t.oid
                    WHERE a.attrelid = $1::regclass
                    AND a.attname = 'content_vector'
                """
                dim_result = await db.query(dim_query, [legacy_table_name])
                legacy_dim = dim_result.get("vector_dim", -1) if dim_result else -1

                if legacy_dim <= 0:
                    # Alternative: Try to detect by sampling a vector
                    logger.info(
                        "PostgreSQL: Metadata dimension check failed, trying vector sampling..."
                    )
                    sample_query = (
                        f"SELECT content_vector FROM {legacy_table_name} LIMIT 1"
                    )
                    sample_result = await db.query(sample_query, [])
                    if sample_result and sample_result.get("content_vector"):
                        vector_data = sample_result["content_vector"]
                        # pgvector returns list directly
                        if isinstance(vector_data, (list, tuple)):
                            legacy_dim = len(vector_data)
                        elif isinstance(vector_data, str):
                            import json

                            vector_list = json.loads(vector_data)
                            legacy_dim = len(vector_list)

                if legacy_dim > 0 and embedding_dim and legacy_dim != embedding_dim:
                    logger.warning(
                        f"PostgreSQL: Dimension mismatch detected! "
                        f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, "
                        f"but new embedding model expects {embedding_dim}d. "
                        f"Migration skipped to prevent data loss. "
                        f"Legacy table preserved as '{legacy_table_name}'. "
                        f"Creating new empty table '{table_name}' for new data."
                    )

                    # Create new table but skip migration
                    await _pg_create_table(db, table_name, base_table, embedding_dim)
                    await db._create_vector_index(table_name, embedding_dim)

                    logger.info(
                        f"PostgreSQL: New table '{table_name}' created. "
                        f"To query legacy data, please use a {legacy_dim}d embedding model."
                    )
                    return

            except Exception as e:
                # Dimension probing is advisory only.
                logger.warning(
                    f"PostgreSQL: Could not verify legacy table vector dimension: {e}. "
                    f"Proceeding with caution..."
                )

            logger.info(f"PostgreSQL: Creating new table '{table_name}'")
            await _pg_create_table(db, table_name, base_table, embedding_dim)

            migrated_count = await _pg_migrate_workspace_data(
                db,
                legacy_table_name,
                table_name,
                workspace,
                legacy_count,
                embedding_dim,
            )

            # Unlike Case 1, Case 4 migrates into a freshly created table, so
            # every legacy row is expected to arrive; a shortfall is an error.
            logger.info("PostgreSQL: Verifying migration...")
            new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
            new_count_result = await db.query(new_count_query, [])
            new_count = new_count_result.get("count", 0) if new_count_result else 0

            if new_count != legacy_count:
                error_msg = (
                    f"PostgreSQL: Migration verification failed, "
                    f"expected {legacy_count} records, got {new_count} in new table"
                )
                logger.error(error_msg)
                raise PostgreSQLMigrationError(error_msg)

            logger.info(
                f"PostgreSQL: Migration completed successfully: {migrated_count} records migrated"
            )
            logger.info(
                f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' completed successfully"
            )

            await db._create_vector_index(table_name, embedding_dim)

            # Legacy cleanup: failure here must not undo a successful
            # migration, hence the dedicated try/except below.
            try:
                if workspace:
                    logger.info(
                        f"PostgreSQL: Deleting migrated workspace '{workspace}' data from legacy table '{legacy_table_name}'..."
                    )
                    delete_query = (
                        f"DELETE FROM {legacy_table_name} WHERE workspace = $1"
                    )
                    await db.execute(delete_query, {"workspace": workspace})
                    logger.info(
                        f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table"
                    )

                    # Drop the legacy table only once no other workspace's
                    # rows remain in it.
                    remaining_query = (
                        f"SELECT COUNT(*) as count FROM {legacy_table_name}"
                    )
                    remaining_result = await db.query(remaining_query, [])
                    remaining_count = (
                        remaining_result.get("count", 0) if remaining_result else 0
                    )

                    if remaining_count == 0:
                        logger.info(
                            f"PostgreSQL: Legacy table '{legacy_table_name}' is empty, deleting..."
                        )
                        drop_query = f"DROP TABLE {legacy_table_name}"
                        await db.execute(drop_query, None)
                        logger.info(
                            f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully"
                        )
                    else:
                        logger.info(
                            f"PostgreSQL: Legacy table '{legacy_table_name}' preserved ({remaining_count} records from other workspaces remain)"
                        )
                else:
                    # No workspace filter: everything was migrated, so the
                    # whole legacy table can go.
                    logger.warning(
                        f"PostgreSQL: No workspace specified, deleting entire legacy table '{legacy_table_name}'..."
                    )
                    drop_query = f"DROP TABLE {legacy_table_name}"
                    await db.execute(drop_query, None)
                    logger.info(
                        f"PostgreSQL: Legacy table '{legacy_table_name}' deleted"
                    )

            except Exception as delete_error:
                # If cleanup fails, log warning but don't fail migration
                logger.warning(
                    f"PostgreSQL: Failed to clean up legacy table '{legacy_table_name}': {delete_error}. "
                    "Migration succeeded, but manual cleanup may be needed."
                )

        except PostgreSQLMigrationError:
            # Re-raise migration errors without wrapping
            raise
        except Exception as e:
            error_msg = f"PostgreSQL: Migration failed with error: {e}"
            logger.error(error_msg)
            # Mirror Qdrant behavior: no automatic rollback
            # Reason: partial data can be continued by re-running migration
            raise PostgreSQLMigrationError(error_msg) from e
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
async with get_data_init_lock():
|
async with get_data_init_lock():
|
||||||
if self.db is None:
|
if self.db is None:
|
||||||
|
|
@ -2206,6 +2690,16 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
# Use "default" for compatibility (lowest priority)
|
# Use "default" for compatibility (lowest priority)
|
||||||
self.workspace = "default"
|
self.workspace = "default"
|
||||||
|
|
||||||
|
# Setup table (create if not exists and handle migration)
|
||||||
|
await PGVectorStorage.setup_table(
|
||||||
|
self.db,
|
||||||
|
self.table_name,
|
||||||
|
legacy_table_name=self.legacy_table_name,
|
||||||
|
base_table=self.legacy_table_name, # base_table for DDL template lookup
|
||||||
|
embedding_dim=self.embedding_func.embedding_dim,
|
||||||
|
workspace=self.workspace, # CRITICAL: Filter migration by workspace
|
||||||
|
)
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
|
|
@ -2215,7 +2709,9 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
self, item: dict[str, Any], current_time: datetime.datetime
|
self, item: dict[str, Any], current_time: datetime.datetime
|
||||||
) -> tuple[str, dict[str, Any]]:
|
) -> tuple[str, dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
|
upsert_sql = SQL_TEMPLATES["upsert_chunk"].format(
|
||||||
|
table_name=self.table_name
|
||||||
|
)
|
||||||
data: dict[str, Any] = {
|
data: dict[str, Any] = {
|
||||||
"workspace": self.workspace,
|
"workspace": self.workspace,
|
||||||
"id": item["__id__"],
|
"id": item["__id__"],
|
||||||
|
|
@ -2239,7 +2735,7 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
def _upsert_entities(
|
def _upsert_entities(
|
||||||
self, item: dict[str, Any], current_time: datetime.datetime
|
self, item: dict[str, Any], current_time: datetime.datetime
|
||||||
) -> tuple[str, dict[str, Any]]:
|
) -> tuple[str, dict[str, Any]]:
|
||||||
upsert_sql = SQL_TEMPLATES["upsert_entity"]
|
upsert_sql = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name)
|
||||||
source_id = item["source_id"]
|
source_id = item["source_id"]
|
||||||
if isinstance(source_id, str) and "<SEP>" in source_id:
|
if isinstance(source_id, str) and "<SEP>" in source_id:
|
||||||
chunk_ids = source_id.split("<SEP>")
|
chunk_ids = source_id.split("<SEP>")
|
||||||
|
|
@ -2262,7 +2758,9 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
def _upsert_relationships(
|
def _upsert_relationships(
|
||||||
self, item: dict[str, Any], current_time: datetime.datetime
|
self, item: dict[str, Any], current_time: datetime.datetime
|
||||||
) -> tuple[str, dict[str, Any]]:
|
) -> tuple[str, dict[str, Any]]:
|
||||||
upsert_sql = SQL_TEMPLATES["upsert_relationship"]
|
upsert_sql = SQL_TEMPLATES["upsert_relationship"].format(
|
||||||
|
table_name=self.table_name
|
||||||
|
)
|
||||||
source_id = item["source_id"]
|
source_id = item["source_id"]
|
||||||
if isinstance(source_id, str) and "<SEP>" in source_id:
|
if isinstance(source_id, str) and "<SEP>" in source_id:
|
||||||
chunk_ids = source_id.split("<SEP>")
|
chunk_ids = source_id.split("<SEP>")
|
||||||
|
|
@ -2335,7 +2833,9 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
|
|
||||||
embedding_string = ",".join(map(str, embedding))
|
embedding_string = ",".join(map(str, embedding))
|
||||||
|
|
||||||
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
|
sql = SQL_TEMPLATES[self.namespace].format(
|
||||||
|
embedding_string=embedding_string, table_name=self.table_name
|
||||||
|
)
|
||||||
params = {
|
params = {
|
||||||
"workspace": self.workspace,
|
"workspace": self.workspace,
|
||||||
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
|
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
|
||||||
|
|
@ -2357,14 +2857,9 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
if not ids:
|
if not ids:
|
||||||
return
|
return
|
||||||
|
|
||||||
table_name = namespace_to_table_name(self.namespace)
|
delete_sql = (
|
||||||
if not table_name:
|
f"DELETE FROM {self.table_name} WHERE workspace=$1 AND id = ANY($2)"
|
||||||
logger.error(
|
)
|
||||||
f"[{self.workspace}] Unknown namespace for vector deletion: {self.namespace}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
|
await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
|
||||||
|
|
@ -2383,8 +2878,8 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
entity_name: The name of the entity to delete
|
entity_name: The name of the entity to delete
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Construct SQL to delete the entity
|
# Construct SQL to delete the entity using dynamic table name
|
||||||
delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
|
delete_sql = f"""DELETE FROM {self.table_name}
|
||||||
WHERE workspace=$1 AND entity_name=$2"""
|
WHERE workspace=$1 AND entity_name=$2"""
|
||||||
|
|
||||||
await self.db.execute(
|
await self.db.execute(
|
||||||
|
|
@ -2404,7 +2899,7 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Delete relations where the entity is either the source or target
|
# Delete relations where the entity is either the source or target
|
||||||
delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
|
delete_sql = f"""DELETE FROM {self.table_name}
|
||||||
WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
|
WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
|
||||||
|
|
||||||
await self.db.execute(
|
await self.db.execute(
|
||||||
|
|
@ -2427,14 +2922,7 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
Returns:
|
Returns:
|
||||||
The vector data if found, or None if not found
|
The vector data if found, or None if not found
|
||||||
"""
|
"""
|
||||||
table_name = namespace_to_table_name(self.namespace)
|
query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {self.table_name} WHERE workspace=$1 AND id=$2"
|
||||||
if not table_name:
|
|
||||||
logger.error(
|
|
||||||
f"[{self.workspace}] Unknown namespace for ID lookup: {self.namespace}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id=$2"
|
|
||||||
params = {"workspace": self.workspace, "id": id}
|
params = {"workspace": self.workspace, "id": id}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -2460,15 +2948,8 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
if not ids:
|
if not ids:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
table_name = namespace_to_table_name(self.namespace)
|
|
||||||
if not table_name:
|
|
||||||
logger.error(
|
|
||||||
f"[{self.workspace}] Unknown namespace for IDs lookup: {self.namespace}"
|
|
||||||
)
|
|
||||||
return []
|
|
||||||
|
|
||||||
ids_str = ",".join([f"'{id}'" for id in ids])
|
ids_str = ",".join([f"'{id}'" for id in ids])
|
||||||
query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
|
query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {self.table_name} WHERE workspace=$1 AND id IN ({ids_str})"
|
||||||
params = {"workspace": self.workspace}
|
params = {"workspace": self.workspace}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -2509,15 +2990,8 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
if not ids:
|
if not ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
table_name = namespace_to_table_name(self.namespace)
|
|
||||||
if not table_name:
|
|
||||||
logger.error(
|
|
||||||
f"[{self.workspace}] Unknown namespace for vector lookup: {self.namespace}"
|
|
||||||
)
|
|
||||||
return {}
|
|
||||||
|
|
||||||
ids_str = ",".join([f"'{id}'" for id in ids])
|
ids_str = ",".join([f"'{id}'" for id in ids])
|
||||||
query = f"SELECT id, content_vector FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})"
|
query = f"SELECT id, content_vector FROM {self.table_name} WHERE workspace=$1 AND id IN ({ids_str})"
|
||||||
params = {"workspace": self.workspace}
|
params = {"workspace": self.workspace}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -2546,15 +3020,8 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop the storage"""
|
"""Drop the storage"""
|
||||||
try:
|
try:
|
||||||
table_name = namespace_to_table_name(self.namespace)
|
|
||||||
if not table_name:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"message": f"Unknown namespace: {self.namespace}",
|
|
||||||
}
|
|
||||||
|
|
||||||
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
||||||
table_name=table_name
|
table_name=self.table_name
|
||||||
)
|
)
|
||||||
await self.db.execute(drop_sql, {"workspace": self.workspace})
|
await self.db.execute(drop_sql, {"workspace": self.workspace})
|
||||||
return {"status": "success", "message": "data dropped"}
|
return {"status": "success", "message": "data dropped"}
|
||||||
|
|
@ -2593,6 +3060,9 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
# Use "default" for compatibility (lowest priority)
|
# Use "default" for compatibility (lowest priority)
|
||||||
self.workspace = "default"
|
self.workspace = "default"
|
||||||
|
|
||||||
|
# NOTE: Table creation is handled by PostgreSQLDB.initdb() during initialization
|
||||||
|
# No need to create table here as it's already created in the TABLES dict
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self.db is not None:
|
if self.db is not None:
|
||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
|
|
@ -3188,6 +3658,12 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
class PostgreSQLMigrationError(Exception):
    """Raised when a PostgreSQL legacy-table migration fails or cannot be verified."""
    # NOTE: the redundant `pass` was removed — a docstring alone is a valid body.
|
||||||
|
|
||||||
|
|
||||||
class PGGraphQueryException(Exception):
|
class PGGraphQueryException(Exception):
|
||||||
"""Exception for the AGE queries."""
|
"""Exception for the AGE queries."""
|
||||||
|
|
||||||
|
|
@ -5047,7 +5523,7 @@ SQL_TEMPLATES = {
|
||||||
update_time = EXCLUDED.update_time
|
update_time = EXCLUDED.update_time
|
||||||
""",
|
""",
|
||||||
# SQL for VectorStorage
|
# SQL for VectorStorage
|
||||||
"upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
|
"upsert_chunk": """INSERT INTO {table_name} (workspace, id, tokens,
|
||||||
chunk_order_index, full_doc_id, content, content_vector, file_path,
|
chunk_order_index, full_doc_id, content, content_vector, file_path,
|
||||||
create_time, update_time)
|
create_time, update_time)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||||
|
|
@ -5060,7 +5536,7 @@ SQL_TEMPLATES = {
|
||||||
file_path=EXCLUDED.file_path,
|
file_path=EXCLUDED.file_path,
|
||||||
update_time = EXCLUDED.update_time
|
update_time = EXCLUDED.update_time
|
||||||
""",
|
""",
|
||||||
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
"upsert_entity": """INSERT INTO {table_name} (workspace, id, entity_name, content,
|
||||||
content_vector, chunk_ids, file_path, create_time, update_time)
|
content_vector, chunk_ids, file_path, create_time, update_time)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
|
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
|
||||||
ON CONFLICT (workspace,id) DO UPDATE
|
ON CONFLICT (workspace,id) DO UPDATE
|
||||||
|
|
@ -5071,7 +5547,7 @@ SQL_TEMPLATES = {
|
||||||
file_path=EXCLUDED.file_path,
|
file_path=EXCLUDED.file_path,
|
||||||
update_time=EXCLUDED.update_time
|
update_time=EXCLUDED.update_time
|
||||||
""",
|
""",
|
||||||
"upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
|
"upsert_relationship": """INSERT INTO {table_name} (workspace, id, source_id,
|
||||||
target_id, content, content_vector, chunk_ids, file_path, create_time, update_time)
|
target_id, content, content_vector, chunk_ids, file_path, create_time, update_time)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10)
|
VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10)
|
||||||
ON CONFLICT (workspace,id) DO UPDATE
|
ON CONFLICT (workspace,id) DO UPDATE
|
||||||
|
|
@ -5087,7 +5563,7 @@ SQL_TEMPLATES = {
|
||||||
SELECT r.source_id AS src_id,
|
SELECT r.source_id AS src_id,
|
||||||
r.target_id AS tgt_id,
|
r.target_id AS tgt_id,
|
||||||
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
|
EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
|
||||||
FROM LIGHTRAG_VDB_RELATION r
|
FROM {table_name} r
|
||||||
WHERE r.workspace = $1
|
WHERE r.workspace = $1
|
||||||
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
|
AND r.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||||
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
|
ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
|
||||||
|
|
@ -5096,7 +5572,7 @@ SQL_TEMPLATES = {
|
||||||
"entities": """
|
"entities": """
|
||||||
SELECT e.entity_name,
|
SELECT e.entity_name,
|
||||||
EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
|
EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
|
||||||
FROM LIGHTRAG_VDB_ENTITY e
|
FROM {table_name} e
|
||||||
WHERE e.workspace = $1
|
WHERE e.workspace = $1
|
||||||
AND e.content_vector <=> '[{embedding_string}]'::vector < $2
|
AND e.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||||
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
|
ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
|
||||||
|
|
@ -5107,7 +5583,7 @@ SQL_TEMPLATES = {
|
||||||
c.content,
|
c.content,
|
||||||
c.file_path,
|
c.file_path,
|
||||||
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
|
EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
|
||||||
FROM LIGHTRAG_VDB_CHUNKS c
|
FROM {table_name} c
|
||||||
WHERE c.workspace = $1
|
WHERE c.workspace = $1
|
||||||
AND c.content_vector <=> '[{embedding_string}]'::vector < $2
|
AND c.content_vector <=> '[{embedding_string}]'::vector < $2
|
||||||
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
|
ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,48 @@ def workspace_filter_condition(workspace: str) -> models.FieldCondition:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_legacy_collection(
|
||||||
|
client: QdrantClient, namespace: str, workspace: str = None
|
||||||
|
) -> str | None:
|
||||||
|
"""
|
||||||
|
Find legacy collection with backward compatibility support.
|
||||||
|
|
||||||
|
This function tries multiple naming patterns to locate legacy collections
|
||||||
|
created by older versions of LightRAG:
|
||||||
|
|
||||||
|
1. {workspace}_{namespace} - Old format with workspace (pre-model-isolation) - HIGHEST PRIORITY
|
||||||
|
2. lightrag_vdb_{namespace} - Current legacy format
|
||||||
|
3. {namespace} - Old format without workspace (pre-model-isolation)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: QdrantClient instance
|
||||||
|
namespace: Base namespace (e.g., "chunks", "entities")
|
||||||
|
workspace: Optional workspace identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Collection name if found, None otherwise
|
||||||
|
"""
|
||||||
|
# Try multiple naming patterns for backward compatibility
|
||||||
|
# More specific names (with workspace) have higher priority
|
||||||
|
candidates = [
|
||||||
|
f"{workspace}_{namespace}"
|
||||||
|
if workspace
|
||||||
|
else None, # Old format with workspace - most specific
|
||||||
|
f"lightrag_vdb_{namespace}", # New legacy format
|
||||||
|
namespace, # Old format without workspace - most generic
|
||||||
|
]
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
if candidate and client.collection_exists(candidate):
|
||||||
|
logger.info(
|
||||||
|
f"Qdrant: Found legacy collection '{candidate}' "
|
||||||
|
f"(namespace={namespace}, workspace={workspace or 'none'})"
|
||||||
|
)
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class QdrantVectorDBStorage(BaseVectorStorage):
|
class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
|
|
@ -85,28 +127,73 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
def setup_collection(
|
def setup_collection(
|
||||||
client: QdrantClient,
|
client: QdrantClient,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
legacy_namespace: str = None,
|
namespace: str = None,
|
||||||
workspace: str = None,
|
workspace: str = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Setup Qdrant collection with migration support from legacy collections.
|
Setup Qdrant collection with migration support from legacy collections.
|
||||||
|
|
||||||
|
This method now supports backward compatibility by automatically detecting
|
||||||
|
legacy collections created by older versions of LightRAG using multiple
|
||||||
|
naming patterns.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client: QdrantClient instance
|
client: QdrantClient instance
|
||||||
collection_name: Name of the new collection
|
collection_name: Name of the new collection
|
||||||
legacy_namespace: Name of the legacy collection (if exists)
|
namespace: Base namespace (e.g., "chunks", "entities")
|
||||||
workspace: Workspace identifier for data isolation
|
workspace: Workspace identifier for data isolation
|
||||||
**kwargs: Additional arguments for collection creation (vectors_config, hnsw_config, etc.)
|
**kwargs: Additional arguments for collection creation (vectors_config, hnsw_config, etc.)
|
||||||
"""
|
"""
|
||||||
new_collection_exists = client.collection_exists(collection_name)
|
new_collection_exists = client.collection_exists(collection_name)
|
||||||
legacy_exists = legacy_namespace and client.collection_exists(legacy_namespace)
|
|
||||||
|
|
||||||
# Case 1: Both new and legacy collections exist - Warning only (no migration)
|
# Try to find legacy collection with backward compatibility
|
||||||
|
legacy_collection = (
|
||||||
|
_find_legacy_collection(client, namespace, workspace) if namespace else None
|
||||||
|
)
|
||||||
|
legacy_exists = legacy_collection is not None
|
||||||
|
|
||||||
|
# Case 1: Both new and legacy collections exist
|
||||||
|
# This can happen if:
|
||||||
|
# 1. Previous migration failed to delete the legacy collection
|
||||||
|
# 2. User manually created both collections
|
||||||
|
# 3. No model suffix (collection_name == legacy_collection)
|
||||||
|
# Strategy: Only delete legacy if it's empty (safe cleanup) and it's not the same as new collection
|
||||||
if new_collection_exists and legacy_exists:
|
if new_collection_exists and legacy_exists:
|
||||||
logger.warning(
|
# CRITICAL: Check if new and legacy are the same collection
|
||||||
f"Qdrant: Legacy collection '{legacy_namespace}' still exist. Remove it if migration is complete."
|
# This happens when model_suffix is empty (no model_name provided)
|
||||||
)
|
if collection_name == legacy_collection:
|
||||||
|
logger.debug(
|
||||||
|
f"Qdrant: Collection '{collection_name}' already exists (no model suffix). Skipping Case 1 cleanup."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if legacy collection is empty
|
||||||
|
legacy_count = client.count(
|
||||||
|
collection_name=legacy_collection, exact=True
|
||||||
|
).count
|
||||||
|
|
||||||
|
if legacy_count == 0:
|
||||||
|
# Legacy collection is empty, safe to delete without data loss
|
||||||
|
logger.info(
|
||||||
|
f"Qdrant: Legacy collection '{legacy_collection}' is empty. Deleting..."
|
||||||
|
)
|
||||||
|
client.delete_collection(collection_name=legacy_collection)
|
||||||
|
logger.info(
|
||||||
|
f"Qdrant: Legacy collection '{legacy_collection}' deleted successfully"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Legacy collection still has data - don't risk deleting it
|
||||||
|
logger.warning(
|
||||||
|
f"Qdrant: Legacy collection '{legacy_collection}' still contains {legacy_count} records. "
|
||||||
|
f"Manual intervention required to verify and delete."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Qdrant: Could not check or cleanup legacy collection '{legacy_collection}': {e}. "
|
||||||
|
"You may need to delete it manually."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Case 2: Only new collection exists - Ensure index exists
|
# Case 2: Only new collection exists - Ensure index exists
|
||||||
|
|
@ -149,13 +236,13 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
|
|
||||||
# Case 4: Only legacy exists - Migrate data
|
# Case 4: Only legacy exists - Migrate data
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Qdrant: Migrating data from legacy collection '{legacy_namespace}'"
|
f"Qdrant: Migrating data from legacy collection '{legacy_collection}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get legacy collection count
|
# Get legacy collection count
|
||||||
legacy_count = client.count(
|
legacy_count = client.count(
|
||||||
collection_name=legacy_namespace, exact=True
|
collection_name=legacy_collection, exact=True
|
||||||
).count
|
).count
|
||||||
logger.info(f"Qdrant: Found {legacy_count} records in legacy collection")
|
logger.info(f"Qdrant: Found {legacy_count} records in legacy collection")
|
||||||
|
|
||||||
|
|
@ -173,6 +260,51 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Check vector dimension compatibility before migration
|
||||||
|
try:
|
||||||
|
legacy_info = client.get_collection(legacy_collection)
|
||||||
|
legacy_dim = legacy_info.config.params.vectors.size
|
||||||
|
|
||||||
|
# Get expected dimension from kwargs
|
||||||
|
new_dim = (
|
||||||
|
kwargs.get("vectors_config").size
|
||||||
|
if "vectors_config" in kwargs
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_dim and legacy_dim != new_dim:
|
||||||
|
logger.warning(
|
||||||
|
f"Qdrant: Dimension mismatch detected! "
|
||||||
|
f"Legacy collection '{legacy_collection}' has {legacy_dim}d vectors, "
|
||||||
|
f"but new embedding model expects {new_dim}d. "
|
||||||
|
f"Migration skipped to prevent data loss. "
|
||||||
|
f"Legacy collection preserved as '{legacy_collection}'. "
|
||||||
|
f"Creating new empty collection '{collection_name}' for new data."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create new collection but skip migration
|
||||||
|
client.create_collection(collection_name, **kwargs)
|
||||||
|
client.create_payload_index(
|
||||||
|
collection_name=collection_name,
|
||||||
|
field_name=WORKSPACE_ID_FIELD,
|
||||||
|
field_schema=models.KeywordIndexParams(
|
||||||
|
type=models.KeywordIndexType.KEYWORD,
|
||||||
|
is_tenant=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Qdrant: New collection '{collection_name}' created. "
|
||||||
|
f"To query legacy data, please use a {legacy_dim}d embedding model."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Qdrant: Could not verify legacy collection dimension: {e}. "
|
||||||
|
f"Proceeding with caution..."
|
||||||
|
)
|
||||||
|
|
||||||
# Create new collection first
|
# Create new collection first
|
||||||
logger.info(f"Qdrant: Creating new collection '{collection_name}'")
|
logger.info(f"Qdrant: Creating new collection '{collection_name}'")
|
||||||
client.create_collection(collection_name, **kwargs)
|
client.create_collection(collection_name, **kwargs)
|
||||||
|
|
@ -185,7 +317,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
while True:
|
while True:
|
||||||
# Scroll through legacy data
|
# Scroll through legacy data
|
||||||
result = client.scroll(
|
result = client.scroll(
|
||||||
collection_name=legacy_namespace,
|
collection_name=legacy_collection,
|
||||||
limit=batch_size,
|
limit=batch_size,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
with_vectors=True,
|
with_vectors=True,
|
||||||
|
|
@ -258,9 +390,27 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Qdrant: Migration from '{legacy_namespace}' to '{collection_name}' completed successfully"
|
f"Qdrant: Migration from '{legacy_collection}' to '{collection_name}' completed successfully"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Delete legacy collection after successful migration
|
||||||
|
# Data has been verified to match, so legacy collection is no longer needed
|
||||||
|
# and keeping it would cause Case 1 warnings on next startup
|
||||||
|
try:
|
||||||
|
logger.info(
|
||||||
|
f"Qdrant: Deleting legacy collection '{legacy_collection}'..."
|
||||||
|
)
|
||||||
|
client.delete_collection(collection_name=legacy_collection)
|
||||||
|
logger.info(
|
||||||
|
f"Qdrant: Legacy collection '{legacy_collection}' deleted successfully"
|
||||||
|
)
|
||||||
|
except Exception as delete_error:
|
||||||
|
# If deletion fails, user will see Case 1 warning on next startup
|
||||||
|
logger.warning(
|
||||||
|
f"Qdrant: Failed to delete legacy collection '{legacy_collection}': {delete_error}. "
|
||||||
|
"You may need to delete it manually."
|
||||||
|
)
|
||||||
|
|
||||||
except QdrantMigrationError:
|
except QdrantMigrationError:
|
||||||
# Re-raise migration errors without wrapping
|
# Re-raise migration errors without wrapping
|
||||||
raise
|
raise
|
||||||
|
|
@ -287,19 +437,34 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
f"Using passed workspace parameter: '{effective_workspace}'"
|
f"Using passed workspace parameter: '{effective_workspace}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get legacy namespace for data migration from old version
|
|
||||||
if effective_workspace:
|
|
||||||
self.legacy_namespace = f"{effective_workspace}_{self.namespace}"
|
|
||||||
else:
|
|
||||||
self.legacy_namespace = self.namespace
|
|
||||||
|
|
||||||
self.effective_workspace = effective_workspace or DEFAULT_WORKSPACE
|
self.effective_workspace = effective_workspace or DEFAULT_WORKSPACE
|
||||||
|
|
||||||
# Use a shared collection with payload-based partitioning (Qdrant's recommended approach)
|
# Generate model suffix
|
||||||
# Ref: https://qdrant.tech/documentation/guides/multiple-partitions/
|
model_suffix = self._generate_collection_suffix()
|
||||||
self.final_namespace = f"lightrag_vdb_{self.namespace}"
|
|
||||||
logger.debug(
|
# Legacy collection name (without model suffix, for migration)
|
||||||
f"Using shared collection '{self.final_namespace}' with workspace '{self.effective_workspace}' for payload-based partitioning"
|
# This matches the old naming scheme before model isolation was implemented
|
||||||
|
# Example: "lightrag_vdb_chunks" (without model suffix)
|
||||||
|
self.legacy_namespace = f"lightrag_vdb_{self.namespace}"
|
||||||
|
|
||||||
|
# New naming scheme with model isolation
|
||||||
|
# Example: "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
|
||||||
|
# Ensure model_suffix is not empty before appending
|
||||||
|
if model_suffix:
|
||||||
|
self.final_namespace = f"lightrag_vdb_{self.namespace}_{model_suffix}"
|
||||||
|
else:
|
||||||
|
# Fallback: use legacy namespace if model_suffix is unavailable
|
||||||
|
self.final_namespace = self.legacy_namespace
|
||||||
|
logger.warning(
|
||||||
|
f"Model suffix unavailable, using legacy collection name '{self.legacy_namespace}'. "
|
||||||
|
f"Ensure embedding_func has model_name for proper model isolation."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Qdrant collection naming: "
|
||||||
|
f"new='{self.final_namespace}', "
|
||||||
|
f"legacy='{self.legacy_namespace}', "
|
||||||
|
f"model_suffix='{model_suffix}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
|
|
@ -315,6 +480,12 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
|
def _get_legacy_collection_name(self) -> str:
|
||||||
|
return self.legacy_namespace
|
||||||
|
|
||||||
|
def _get_new_collection_name(self) -> str:
|
||||||
|
return self.final_namespace
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize Qdrant collection"""
|
"""Initialize Qdrant collection"""
|
||||||
async with get_data_init_lock():
|
async with get_data_init_lock():
|
||||||
|
|
@ -338,11 +509,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup collection (create if not exists and configure indexes)
|
# Setup collection (create if not exists and configure indexes)
|
||||||
# Pass legacy_namespace and workspace for migration support
|
# Pass namespace and workspace for backward-compatible migration support
|
||||||
QdrantVectorDBStorage.setup_collection(
|
QdrantVectorDBStorage.setup_collection(
|
||||||
self._client,
|
self._client,
|
||||||
self.final_namespace,
|
self.final_namespace,
|
||||||
legacy_namespace=self.legacy_namespace,
|
namespace=self.namespace,
|
||||||
workspace=self.effective_workspace,
|
workspace=self.effective_workspace,
|
||||||
vectors_config=models.VectorParams(
|
vectors_config=models.VectorParams(
|
||||||
size=self.embedding_func.embedding_dim,
|
size=self.embedding_func.embedding_dim,
|
||||||
|
|
@ -354,6 +525,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize max batch size from config
|
||||||
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully"
|
f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully"
|
||||||
|
|
|
||||||
|
|
@ -164,16 +164,29 @@ class UnifiedLock(Generic[T]):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Then acquire the main lock
|
# Then acquire the main lock
|
||||||
if self._is_async:
|
if self._lock is not None:
|
||||||
await self._lock.acquire()
|
if self._is_async:
|
||||||
else:
|
await self._lock.acquire()
|
||||||
self._lock.acquire()
|
else:
|
||||||
|
self._lock.acquire()
|
||||||
|
|
||||||
direct_log(
|
direct_log(
|
||||||
f"== Lock == Process {self._pid}: Acquired lock {self._name} (async={self._is_async})",
|
f"== Lock == Process {self._pid}: Acquired lock {self._name} (async={self._is_async})",
|
||||||
level="INFO",
|
level="INFO",
|
||||||
enable_output=self._enable_logging,
|
enable_output=self._enable_logging,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# CRITICAL: Raise exception instead of allowing unprotected execution
|
||||||
|
error_msg = (
|
||||||
|
f"CRITICAL: Lock '{self._name}' is None - shared data not initialized. "
|
||||||
|
f"Call initialize_share_data() before using locks!"
|
||||||
|
)
|
||||||
|
direct_log(
|
||||||
|
f"== Lock == Process {self._pid}: {error_msg}",
|
||||||
|
level="ERROR",
|
||||||
|
enable_output=True,
|
||||||
|
)
|
||||||
|
raise RuntimeError(error_msg)
|
||||||
return self
|
return self
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If main lock acquisition fails, release the async lock if it was acquired
|
# If main lock acquisition fails, release the async lock if it was acquired
|
||||||
|
|
@ -193,19 +206,21 @@ class UnifiedLock(Generic[T]):
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
main_lock_released = False
|
main_lock_released = False
|
||||||
|
async_lock_released = False
|
||||||
try:
|
try:
|
||||||
# Release main lock first
|
# Release main lock first
|
||||||
if self._is_async:
|
if self._lock is not None:
|
||||||
self._lock.release()
|
if self._is_async:
|
||||||
else:
|
self._lock.release()
|
||||||
self._lock.release()
|
else:
|
||||||
main_lock_released = True
|
self._lock.release()
|
||||||
|
|
||||||
direct_log(
|
direct_log(
|
||||||
f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})",
|
f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})",
|
||||||
level="INFO",
|
level="INFO",
|
||||||
enable_output=self._enable_logging,
|
enable_output=self._enable_logging,
|
||||||
)
|
)
|
||||||
|
main_lock_released = True
|
||||||
|
|
||||||
# Then release async lock if in multiprocess mode
|
# Then release async lock if in multiprocess mode
|
||||||
if not self._is_async and self._async_lock is not None:
|
if not self._is_async and self._async_lock is not None:
|
||||||
|
|
@ -215,6 +230,7 @@ class UnifiedLock(Generic[T]):
|
||||||
level="DEBUG",
|
level="DEBUG",
|
||||||
enable_output=self._enable_logging,
|
enable_output=self._enable_logging,
|
||||||
)
|
)
|
||||||
|
async_lock_released = True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
direct_log(
|
direct_log(
|
||||||
|
|
@ -223,9 +239,10 @@ class UnifiedLock(Generic[T]):
|
||||||
enable_output=True,
|
enable_output=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If main lock release failed but async lock hasn't been released, try to release it
|
# If main lock release failed but async lock hasn't been attempted yet, try to release it
|
||||||
if (
|
if (
|
||||||
not main_lock_released
|
not main_lock_released
|
||||||
|
and not async_lock_released
|
||||||
and not self._is_async
|
and not self._is_async
|
||||||
and self._async_lock is not None
|
and self._async_lock is not None
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -518,14 +518,10 @@ class LightRAG:
|
||||||
f"max_total_tokens({self.summary_max_tokens}) should greater than summary_length_recommended({self.summary_length_recommended})"
|
f"max_total_tokens({self.summary_max_tokens}) should greater than summary_length_recommended({self.summary_length_recommended})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fix global_config now
|
|
||||||
global_config = asdict(self)
|
|
||||||
|
|
||||||
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
|
|
||||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
|
||||||
|
|
||||||
# Init Embedding
|
# Init Embedding
|
||||||
# Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
|
# Step 1: Capture embedding_func and max_token_size before applying decorator
|
||||||
|
# (decorator strips dataclass attributes, and asdict() converts EmbeddingFunc to dict)
|
||||||
|
original_embedding_func = self.embedding_func
|
||||||
embedding_max_token_size = None
|
embedding_max_token_size = None
|
||||||
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
|
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
|
||||||
embedding_max_token_size = self.embedding_func.max_token_size
|
embedding_max_token_size = self.embedding_func.max_token_size
|
||||||
|
|
@ -534,6 +530,14 @@ class LightRAG:
|
||||||
)
|
)
|
||||||
self.embedding_token_limit = embedding_max_token_size
|
self.embedding_token_limit = embedding_max_token_size
|
||||||
|
|
||||||
|
# Fix global_config now
|
||||||
|
global_config = asdict(self)
|
||||||
|
# Restore original EmbeddingFunc object (asdict converts it to dict)
|
||||||
|
global_config["embedding_func"] = original_embedding_func
|
||||||
|
|
||||||
|
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
|
||||||
|
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||||
|
|
||||||
# Step 2: Apply priority wrapper decorator
|
# Step 2: Apply priority wrapper decorator
|
||||||
self.embedding_func = priority_limit_async_func_call(
|
self.embedding_func = priority_limit_async_func_call(
|
||||||
self.embedding_func_max_async,
|
self.embedding_func_max_async,
|
||||||
|
|
|
||||||
|
|
@ -425,6 +425,19 @@ class EmbeddingFunc:
|
||||||
send_dimensions: bool = (
|
send_dimensions: bool = (
|
||||||
False # Control whether to send embedding_dim to the function
|
False # Control whether to send embedding_dim to the function
|
||||||
)
|
)
|
||||||
|
model_name: str | None = None
|
||||||
|
|
||||||
|
def get_model_identifier(self) -> str:
|
||||||
|
"""Generates model identifier for collection/table suffix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Format "{model_name}_{dim}d", e.g. "text_embedding_3_large_3072d"
|
||||||
|
If model_name is not specified, returns "unknown_{dim}d"
|
||||||
|
"""
|
||||||
|
model_part = self.model_name if self.model_name else "unknown"
|
||||||
|
# Clean model name: remove special chars, convert to lower, replace - with _
|
||||||
|
safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_part.lower())
|
||||||
|
return f"{safe_model_name}_{self.embedding_dim}d"
|
||||||
|
|
||||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
||||||
# Only inject embedding_dim when send_dimensions is True
|
# Only inject embedding_dim when send_dimensions is True
|
||||||
|
|
|
||||||
55
tests/test_base_storage_integrity.py
Normal file
55
tests/test_base_storage_integrity.py
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
import pytest
|
||||||
|
from lightrag.base import BaseVectorStorage
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_vector_storage_integrity():
|
||||||
|
# Just checking if we can import and inspect the class
|
||||||
|
assert hasattr(BaseVectorStorage, "_generate_collection_suffix")
|
||||||
|
assert hasattr(BaseVectorStorage, "_get_legacy_collection_name")
|
||||||
|
assert hasattr(BaseVectorStorage, "_get_new_collection_name")
|
||||||
|
|
||||||
|
# Verify methods raise NotImplementedError
|
||||||
|
class ConcreteStorage(BaseVectorStorage):
|
||||||
|
async def query(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def upsert(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def delete_entity(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def delete_entity_relation(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def get_by_id(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def get_by_ids(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def delete(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def get_vectors_by_ids(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def index_done_callback(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def drop(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
func = EmbeddingFunc(embedding_dim=128, func=lambda x: x)
|
||||||
|
storage = ConcreteStorage(
|
||||||
|
namespace="test", workspace="test", global_config={}, embedding_func=func
|
||||||
|
)
|
||||||
|
|
||||||
|
assert storage._generate_collection_suffix() == "unknown_128d"
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
storage._get_legacy_collection_name()
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
storage._get_new_collection_name()
|
||||||
316
tests/test_dimension_mismatch.py
Normal file
316
tests/test_dimension_mismatch.py
Normal file
|
|
@ -0,0 +1,316 @@
|
||||||
|
"""
|
||||||
|
Tests for dimension mismatch handling during migration.
|
||||||
|
|
||||||
|
This test module verifies that both PostgreSQL and Qdrant storage backends
|
||||||
|
properly detect and handle vector dimension mismatches when migrating from
|
||||||
|
legacy collections/tables to new ones with different embedding models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
|
||||||
|
from lightrag.kg.postgres_impl import PGVectorStorage
|
||||||
|
|
||||||
|
|
||||||
|
# Note: Tests should use proper table names that have DDL templates
|
||||||
|
# Valid base tables: LIGHTRAG_VDB_CHUNKS, LIGHTRAG_VDB_ENTITIES, LIGHTRAG_VDB_RELATIONSHIPS,
|
||||||
|
# LIGHTRAG_DOC_CHUNKS, LIGHTRAG_DOC_FULL_DOCS, LIGHTRAG_DOC_TEXT_CHUNKS
|
||||||
|
|
||||||
|
|
||||||
|
class TestQdrantDimensionMismatch:
|
||||||
|
"""Test suite for Qdrant dimension mismatch handling."""
|
||||||
|
|
||||||
|
def test_qdrant_dimension_mismatch_skip_migration(self):
|
||||||
|
"""
|
||||||
|
Test that Qdrant skips migration when dimensions don't match.
|
||||||
|
|
||||||
|
Scenario: Legacy collection has 1536d vectors, new model expects 3072d.
|
||||||
|
Expected: Migration skipped, new empty collection created, legacy preserved.
|
||||||
|
"""
|
||||||
|
from qdrant_client import models
|
||||||
|
|
||||||
|
# Setup mock client
|
||||||
|
client = MagicMock()
|
||||||
|
|
||||||
|
# Mock legacy collection with 1536d vectors
|
||||||
|
legacy_collection_info = MagicMock()
|
||||||
|
legacy_collection_info.config.params.vectors.size = 1536
|
||||||
|
|
||||||
|
# Setup collection existence checks
|
||||||
|
def collection_exists_side_effect(name):
|
||||||
|
if name == "lightrag_chunks": # legacy
|
||||||
|
return True
|
||||||
|
elif name == "lightrag_chunks_model_3072d": # new
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
client.collection_exists.side_effect = collection_exists_side_effect
|
||||||
|
client.get_collection.return_value = legacy_collection_info
|
||||||
|
client.count.return_value.count = 100 # Legacy has data
|
||||||
|
|
||||||
|
# Call setup_collection with 3072d (different from legacy 1536d)
|
||||||
|
QdrantVectorDBStorage.setup_collection(
|
||||||
|
client,
|
||||||
|
"lightrag_chunks_model_3072d",
|
||||||
|
namespace="chunks",
|
||||||
|
workspace="test",
|
||||||
|
vectors_config=models.VectorParams(
|
||||||
|
size=3072, distance=models.Distance.COSINE
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify new collection was created
|
||||||
|
client.create_collection.assert_called_once()
|
||||||
|
|
||||||
|
# Verify migration was NOT attempted (no scroll/upsert calls)
|
||||||
|
client.scroll.assert_not_called()
|
||||||
|
client.upsert.assert_not_called()
|
||||||
|
|
||||||
|
def test_qdrant_dimension_match_proceed_migration(self):
|
||||||
|
"""
|
||||||
|
Test that Qdrant proceeds with migration when dimensions match.
|
||||||
|
|
||||||
|
Scenario: Legacy collection has 1536d vectors, new model also expects 1536d.
|
||||||
|
Expected: Migration proceeds normally.
|
||||||
|
"""
|
||||||
|
from qdrant_client import models
|
||||||
|
|
||||||
|
client = MagicMock()
|
||||||
|
|
||||||
|
# Mock legacy collection with 1536d vectors (matching new)
|
||||||
|
legacy_collection_info = MagicMock()
|
||||||
|
legacy_collection_info.config.params.vectors.size = 1536
|
||||||
|
|
||||||
|
def collection_exists_side_effect(name):
|
||||||
|
if name == "lightrag_chunks": # legacy
|
||||||
|
return True
|
||||||
|
elif name == "lightrag_chunks_model_1536d": # new
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
client.collection_exists.side_effect = collection_exists_side_effect
|
||||||
|
client.get_collection.return_value = legacy_collection_info
|
||||||
|
client.count.return_value.count = 100 # Legacy has data
|
||||||
|
|
||||||
|
# Mock scroll to return sample data
|
||||||
|
sample_point = MagicMock()
|
||||||
|
sample_point.id = "test_id"
|
||||||
|
sample_point.vector = [0.1] * 1536
|
||||||
|
sample_point.payload = {"id": "test"}
|
||||||
|
client.scroll.return_value = ([sample_point], None)
|
||||||
|
|
||||||
|
# Mock _find_legacy_collection to return the legacy collection name
|
||||||
|
with patch(
|
||||||
|
"lightrag.kg.qdrant_impl._find_legacy_collection",
|
||||||
|
return_value="lightrag_chunks",
|
||||||
|
):
|
||||||
|
# Call setup_collection with matching 1536d
|
||||||
|
QdrantVectorDBStorage.setup_collection(
|
||||||
|
client,
|
||||||
|
"lightrag_chunks_model_1536d",
|
||||||
|
namespace="chunks",
|
||||||
|
workspace="test",
|
||||||
|
vectors_config=models.VectorParams(
|
||||||
|
size=1536, distance=models.Distance.COSINE
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify migration WAS attempted
|
||||||
|
client.create_collection.assert_called_once()
|
||||||
|
client.scroll.assert_called()
|
||||||
|
client.upsert.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPostgresDimensionMismatch:
|
||||||
|
"""Test suite for PostgreSQL dimension mismatch handling."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_postgres_dimension_mismatch_skip_migration_metadata(self):
|
||||||
|
"""
|
||||||
|
Test that PostgreSQL skips migration when dimensions don't match (via metadata).
|
||||||
|
|
||||||
|
Scenario: Legacy table has 1536d vectors (detected via pg_attribute),
|
||||||
|
new model expects 3072d.
|
||||||
|
Expected: Migration skipped, new empty table created, legacy preserved.
|
||||||
|
"""
|
||||||
|
# Setup mock database
|
||||||
|
db = AsyncMock()
|
||||||
|
|
||||||
|
# Mock table existence and dimension checks
|
||||||
|
async def query_side_effect(query, params, **kwargs):
|
||||||
|
if "information_schema.tables" in query:
|
||||||
|
if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy
|
||||||
|
return {"exists": True}
|
||||||
|
elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new
|
||||||
|
return {"exists": False}
|
||||||
|
elif "COUNT(*)" in query:
|
||||||
|
return {"count": 100} # Legacy has data
|
||||||
|
elif "pg_attribute" in query:
|
||||||
|
return {"vector_dim": 1536} # Legacy has 1536d vectors
|
||||||
|
return {}
|
||||||
|
|
||||||
|
db.query.side_effect = query_side_effect
|
||||||
|
db.execute = AsyncMock()
|
||||||
|
db._create_vector_index = AsyncMock()
|
||||||
|
|
||||||
|
# Call setup_table with 3072d (different from legacy 1536d)
|
||||||
|
await PGVectorStorage.setup_table(
|
||||||
|
db,
|
||||||
|
"LIGHTRAG_DOC_CHUNKS_model_3072d",
|
||||||
|
legacy_table_name="LIGHTRAG_DOC_CHUNKS",
|
||||||
|
base_table="LIGHTRAG_DOC_CHUNKS",
|
||||||
|
embedding_dim=3072,
|
||||||
|
workspace="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify migration was NOT attempted (no INSERT calls)
|
||||||
|
# Note: _pg_create_table is mocked, so we check INSERT calls to verify migration was skipped
|
||||||
|
insert_calls = [
|
||||||
|
call
|
||||||
|
for call in db.execute.call_args_list
|
||||||
|
if call[0][0] and "INSERT INTO" in call[0][0]
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
len(insert_calls) == 0
|
||||||
|
), "Migration should be skipped due to dimension mismatch"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_postgres_dimension_mismatch_skip_migration_sampling(self):
    """
    Test that PostgreSQL skips migration when dimensions don't match (via sampling).

    Scenario: Legacy table dimension detection fails via metadata,
    falls back to vector sampling, detects 1536d vs expected 3072d.
    Expected: Migration skipped, new empty table created, legacy preserved.
    """
    db = AsyncMock()

    # Mock table existence and dimension checks.
    # vector_dim == -1 signals that metadata detection failed, which forces
    # setup_table down the sampling fallback path.
    async def query_side_effect(query, params, **kwargs):
        if "information_schema.tables" in query:
            if params[0] == "LIGHTRAG_DOC_CHUNKS":  # legacy
                return {"exists": True}
            elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_3072d":  # new
                return {"exists": False}
        elif "COUNT(*)" in query:
            return {"count": 100}  # Legacy has data
        elif "pg_attribute" in query:
            return {"vector_dim": -1}  # Metadata check fails
        elif "SELECT content_vector FROM" in query:
            # Return sample vector with 1536 dimensions
            return {"content_vector": [0.1] * 1536}
        return {}

    db.query.side_effect = query_side_effect
    db.execute = AsyncMock()
    db._create_vector_index = AsyncMock()

    # Call setup_table with 3072d (different from legacy 1536d)
    await PGVectorStorage.setup_table(
        db,
        "LIGHTRAG_DOC_CHUNKS_model_3072d",
        legacy_table_name="LIGHTRAG_DOC_CHUNKS",
        base_table="LIGHTRAG_DOC_CHUNKS",
        embedding_dim=3072,
        workspace="test",
    )

    # Verify new table was created
    create_table_calls = [
        call
        for call in db.execute.call_args_list
        if call[0][0] and "CREATE TABLE" in call[0][0]
    ]
    assert len(create_table_calls) > 0, "New table should be created"

    # Verify migration was NOT attempted
    insert_calls = [
        call
        for call in db.execute.call_args_list
        if call[0][0] and "INSERT INTO" in call[0][0]
    ]
    assert len(insert_calls) == 0, "Migration should be skipped"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_postgres_dimension_match_proceed_migration(self):
    """
    Test that PostgreSQL proceeds with migration when dimensions match.

    Scenario: Legacy table has 1536d vectors, new model also expects 1536d.
    Expected: Migration proceeds normally.
    """
    db = AsyncMock()

    async def query_side_effect(query, params, **kwargs):
        # setup_table passes multirows=True when fetching migration batches.
        multirows = kwargs.get("multirows", False)

        if "information_schema.tables" in query:
            if params[0] == "LIGHTRAG_DOC_CHUNKS":  # legacy
                return {"exists": True}
            elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_1536d":  # new
                return {"exists": False}
        elif "COUNT(*)" in query:
            return {"count": 100}  # Legacy has data
        elif "pg_attribute" in query:
            return {"vector_dim": 1536}  # Legacy has matching 1536d
        elif "SELECT * FROM" in query and multirows:
            # Return sample data for migration (first batch)
            # Handle workspace filtering: params = [workspace, offset, limit]
            if "WHERE workspace" in query:
                offset = params[1] if len(params) > 1 else 0
            else:
                offset = params[0] if params else 0

            if offset == 0:  # First batch
                return [
                    {
                        "id": "test1",
                        "content_vector": [0.1] * 1536,
                        "workspace": "test",
                    },
                    {
                        "id": "test2",
                        "content_vector": [0.2] * 1536,
                        "workspace": "test",
                    },
                ]
            else:  # offset > 0
                return []  # No more data
        return {}

    db.query.side_effect = query_side_effect
    db.execute = AsyncMock()
    db._create_vector_index = AsyncMock()

    # Mock _pg_table_exists
    async def mock_table_exists(db_inst, name):
        if name == "LIGHTRAG_DOC_CHUNKS":  # legacy exists
            return True
        elif name == "LIGHTRAG_DOC_CHUNKS_model_1536d":  # new doesn't exist
            return False
        return False

    with patch(
        "lightrag.kg.postgres_impl._pg_table_exists",
        side_effect=mock_table_exists,
    ):
        # Call setup_table with matching 1536d
        await PGVectorStorage.setup_table(
            db,
            "LIGHTRAG_DOC_CHUNKS_model_1536d",
            legacy_table_name="LIGHTRAG_DOC_CHUNKS",
            base_table="LIGHTRAG_DOC_CHUNKS",
            embedding_dim=1536,
            workspace="test",
        )

    # Verify migration WAS attempted (INSERT calls made)
    insert_calls = [
        call
        for call in db.execute.call_args_list
        if call[0][0] and "INSERT INTO" in call[0][0]
    ]
    assert (
        len(insert_calls) > 0
    ), "Migration should proceed with matching dimensions"
|
||||||
1639
tests/test_e2e_multi_instance.py
Normal file
1639
tests/test_e2e_multi_instance.py
Normal file
File diff suppressed because it is too large
Load diff
31
tests/test_embedding_func.py
Normal file
31
tests/test_embedding_func.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_func(*args, **kwargs):
    """No-op stand-in for an embedding callable; ignores all arguments."""
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_func_with_model_name():
    """A named model should yield `<sanitized_name>_<dim>d` as its identifier."""
    embedder = EmbeddingFunc(
        embedding_dim=1536,
        func=dummy_func,
        model_name="text-embedding-ada-002",
    )
    expected = "text_embedding_ada_002_1536d"
    assert embedder.get_model_identifier() == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_func_without_model_name():
    """When no model_name is given, the identifier falls back to 'unknown'."""
    embedder = EmbeddingFunc(embedding_dim=768, func=dummy_func)
    assert embedder.get_model_identifier() == "unknown_768d"
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_name_sanitization():
    """Special characters (e.g. '/') in the model name are replaced with '_'."""
    raw_name = "models/text-embedding-004"  # Contains special chars
    embedder = EmbeddingFunc(
        embedding_dim=1024,
        func=dummy_func,
        model_name=raw_name,
    )
    assert embedder.get_model_identifier() == "models_text_embedding_004_1024d"
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_name_with_uppercase():
    """Uppercase letters in the model name are lowercased in the identifier."""
    embedder = EmbeddingFunc(
        embedding_dim=512, func=dummy_func, model_name="My-Model-V1"
    )
    assert embedder.get_model_identifier() == "my_model_v1_512d"
|
||||||
213
tests/test_no_model_suffix_safety.py
Normal file
213
tests/test_no_model_suffix_safety.py
Normal file
|
|
@ -0,0 +1,213 @@
|
||||||
|
"""
|
||||||
|
Tests for safety when model suffix is absent (no model_name provided).
|
||||||
|
|
||||||
|
This test module verifies that the system correctly handles the case when
|
||||||
|
no model_name is provided, preventing accidental deletion of the only table/collection
|
||||||
|
on restart.
|
||||||
|
|
||||||
|
Critical Bug: When model_suffix is empty, table_name == legacy_table_name.
|
||||||
|
On second startup, Case 1 logic would delete the only table/collection thinking
|
||||||
|
it's "legacy", causing all subsequent operations to fail.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
|
||||||
|
from lightrag.kg.postgres_impl import PGVectorStorage
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoModelSuffixSafety:
    """Test suite for preventing data loss when model_suffix is absent."""

    def test_qdrant_no_suffix_second_startup(self):
        """
        Test Qdrant doesn't delete collection on second startup when no model_name.

        Scenario:
        1. First startup: Creates collection without suffix
        2. Collection is empty
        3. Second startup: Should NOT delete the collection

        Bug: Without fix, Case 1 would delete the only collection.
        """
        from qdrant_client import models

        client = MagicMock()

        # Simulate second startup: collection already exists and is empty
        # IMPORTANT: Without suffix, collection_name == legacy collection name
        collection_name = "lightrag_vdb_chunks"  # No suffix, same as legacy

        # Both exist (they're the same collection)
        client.collection_exists.return_value = True

        # Collection is empty
        client.count.return_value.count = 0

        # Call setup_collection
        # This should detect that new == legacy and skip deletion
        QdrantVectorDBStorage.setup_collection(
            client,
            collection_name,
            namespace="chunks",
            workspace=None,
            vectors_config=models.VectorParams(
                size=1536, distance=models.Distance.COSINE
            ),
        )

        # CRITICAL: Collection should NOT be deleted
        client.delete_collection.assert_not_called()

        # Verify we returned early (skipped Case 1 cleanup)
        # The collection_exists was checked, but we didn't proceed to count
        # because we detected same name
        assert client.collection_exists.call_count >= 1

    @pytest.mark.asyncio
    async def test_postgres_no_suffix_second_startup(self):
        """
        Test PostgreSQL doesn't delete table on second startup when no model_name.

        Scenario:
        1. First startup: Creates table without suffix
        2. Table is empty
        3. Second startup: Should NOT delete the table

        Bug: Without fix, Case 1 would delete the only table.
        """
        db = AsyncMock()

        # Simulate second startup: table already exists and is empty
        # IMPORTANT: table_name and legacy_table_name are THE SAME
        table_name = "LIGHTRAG_VDB_CHUNKS"  # No suffix
        legacy_table_name = "LIGHTRAG_VDB_CHUNKS"  # Same as new

        # Setup mock responses
        async def table_exists_side_effect(db_instance, name):
            # Both tables exist (they're the same)
            return True

        # Mock _pg_table_exists function
        with patch(
            "lightrag.kg.postgres_impl._pg_table_exists",
            side_effect=table_exists_side_effect,
        ):
            # Call setup_table
            # This should detect that new == legacy and skip deletion
            await PGVectorStorage.setup_table(
                db,
                table_name,
                legacy_table_name=legacy_table_name,
                base_table="LIGHTRAG_VDB_CHUNKS",
                embedding_dim=1536,
            )

        # CRITICAL: Table should NOT be deleted (no DROP TABLE)
        drop_calls = [
            call
            for call in db.execute.call_args_list
            if call[0][0] and "DROP TABLE" in call[0][0]
        ]
        assert (
            len(drop_calls) == 0
        ), "Should not drop table when new and legacy are the same"

        # Also should not try to count (we returned early)
        count_calls = [
            call
            for call in db.query.call_args_list
            if call[0][0] and "COUNT(*)" in call[0][0]
        ]
        assert (
            len(count_calls) == 0
        ), "Should not check count when new and legacy are the same"

    def test_qdrant_with_suffix_case1_still_works(self):
        """
        Test that Case 1 cleanup still works when there IS a suffix.

        This ensures our fix doesn't break the normal Case 1 scenario.
        """
        from qdrant_client import models

        client = MagicMock()

        # Different names (normal case)
        collection_name = "lightrag_vdb_chunks_ada_002_1536d"  # With suffix
        legacy_collection = "lightrag_vdb_chunks"  # Without suffix

        # Setup: both exist
        def collection_exists_side_effect(name):
            return name in [collection_name, legacy_collection]

        client.collection_exists.side_effect = collection_exists_side_effect

        # Legacy is empty
        client.count.return_value.count = 0

        # Call setup_collection
        QdrantVectorDBStorage.setup_collection(
            client,
            collection_name,
            namespace="chunks",
            workspace=None,
            vectors_config=models.VectorParams(
                size=1536, distance=models.Distance.COSINE
            ),
        )

        # SHOULD delete legacy (normal Case 1 behavior)
        client.delete_collection.assert_called_once_with(
            collection_name=legacy_collection
        )

    @pytest.mark.asyncio
    async def test_postgres_with_suffix_case1_still_works(self):
        """
        Test that Case 1 cleanup still works when there IS a suffix.

        This ensures our fix doesn't break the normal Case 1 scenario.
        """
        db = AsyncMock()

        # Different names (normal case)
        table_name = "LIGHTRAG_VDB_CHUNKS_ADA_002_1536D"  # With suffix
        legacy_table_name = "LIGHTRAG_VDB_CHUNKS"  # Without suffix

        # Setup mock responses
        async def table_exists_side_effect(db_instance, name):
            # Both tables exist
            return True

        # Mock empty table
        async def query_side_effect(sql, params, **kwargs):
            if "COUNT(*)" in sql:
                return {"count": 0}
            return {}

        db.query.side_effect = query_side_effect

        # Mock _pg_table_exists function
        with patch(
            "lightrag.kg.postgres_impl._pg_table_exists",
            side_effect=table_exists_side_effect,
        ):
            # Call setup_table
            await PGVectorStorage.setup_table(
                db,
                table_name,
                legacy_table_name=legacy_table_name,
                base_table="LIGHTRAG_VDB_CHUNKS",
                embedding_dim=1536,
            )

        # SHOULD delete legacy (normal Case 1 behavior)
        drop_calls = [
            call
            for call in db.execute.call_args_list
            if call[0][0] and "DROP TABLE" in call[0][0]
        ]
        assert len(drop_calls) == 1, "Should drop legacy table in normal Case 1"
        assert legacy_table_name in drop_calls[0][0][0]
|
||||||
805
tests/test_postgres_migration.py
Normal file
805
tests/test_postgres_migration.py
Normal file
|
|
@ -0,0 +1,805 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, AsyncMock
|
||||||
|
import numpy as np
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
from lightrag.kg.postgres_impl import (
|
||||||
|
PGVectorStorage,
|
||||||
|
)
|
||||||
|
from lightrag.namespace import NameSpace
|
||||||
|
|
||||||
|
|
||||||
|
# Mock PostgreSQLDB
|
||||||
|
@pytest.fixture
def mock_pg_db():
    """Mock PostgreSQL database connection.

    Returns an AsyncMock whose query() mimics the real driver's multirows
    convention (list for multirows, single dict otherwise) and whose execute()
    enforces the real parameter contract (data must be a dict or None).
    """
    db = AsyncMock()
    db.workspace = "test_workspace"

    # Mock query responses with multirows support
    async def mock_query(sql, params=None, multirows=False, **kwargs):
        # Default return value
        if multirows:
            return []  # Return empty list for multirows
        return {"exists": False, "count": 0}

    # Mock for execute that mimics PostgreSQLDB.execute() behavior
    async def mock_execute(sql, data=None, **kwargs):
        """
        Mock that mimics PostgreSQLDB.execute() behavior:
        - Accepts data as dict[str, Any] | None (second parameter)
        - Internally converts dict.values() to tuple for AsyncPG
        """
        # Mimic real execute() which accepts dict and converts to tuple
        if data is not None and not isinstance(data, dict):
            raise TypeError(
                f"PostgreSQLDB.execute() expects data as dict, got {type(data).__name__}"
            )
        return None

    db.query = AsyncMock(side_effect=mock_query)
    db.execute = AsyncMock(side_effect=mock_execute)

    return db
|
||||||
|
|
||||||
|
|
||||||
|
# Mock get_data_init_lock to avoid async lock issues in tests
|
||||||
|
@pytest.fixture(autouse=True)
def mock_data_init_lock():
    """Replace get_data_init_lock with an AsyncMock for every test in this module.

    autouse=True means no test needs to request it explicitly; AsyncMock
    transparently supports being used as an async context manager.
    """
    with patch("lightrag.kg.postgres_impl.get_data_init_lock") as mock_lock:
        mock_lock_ctx = AsyncMock()
        mock_lock.return_value = mock_lock_ctx
        yield mock_lock
|
||||||
|
|
||||||
|
|
||||||
|
# Mock ClientManager
|
||||||
|
@pytest.fixture
def mock_client_manager(mock_pg_db):
    """Patch ClientManager so storage objects receive the mocked DB connection."""
    with patch("lightrag.kg.postgres_impl.ClientManager") as mock_manager:
        mock_manager.get_client = AsyncMock(return_value=mock_pg_db)
        mock_manager.release_client = AsyncMock()
        yield mock_manager
|
||||||
|
|
||||||
|
|
||||||
|
# Mock Embedding function
|
||||||
|
@pytest.fixture
def mock_embedding_func():
    """Provide an EmbeddingFunc (768d, model_name='test_model') with a stub embedder."""

    async def embed_func(texts, **kwargs):
        # Constant vectors — the tests only care about dimensions, not values.
        return np.array([[0.1] * 768 for _ in texts])

    func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test_model")
    return func
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_postgres_table_naming(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """Test if table name is correctly generated with model suffix."""
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Verify table name contains model suffix
    expected_suffix = "test_model_768d"
    assert expected_suffix in storage.table_name
    assert storage.table_name == f"LIGHTRAG_VDB_CHUNKS_{expected_suffix}"

    # Verify legacy table name (no suffix — pre-migration naming scheme)
    assert storage.legacy_table_name == "LIGHTRAG_VDB_CHUNKS"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_postgres_migration_trigger(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """Test if migration logic is triggered correctly."""
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Setup mocks for migration scenario
    # 1. New table does not exist, legacy table exists
    async def mock_table_exists(db, table_name):
        return table_name == storage.legacy_table_name

    # 2. Legacy table has 100 records
    mock_rows = [
        {"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"}
        for i in range(100)
    ]

    async def mock_query(sql, params=None, multirows=False, **kwargs):
        if "COUNT(*)" in sql:
            return {"count": 100}
        elif multirows and "SELECT *" in sql:
            # Mock batch fetch for migration
            # Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit]
            if "WHERE workspace" in sql:
                # With workspace filter: params[0]=workspace, params[1]=offset, params[2]=limit
                offset = params[1] if len(params) > 1 else 0
                limit = params[2] if len(params) > 2 else 500
            else:
                # No workspace filter: params[0]=offset, params[1]=limit
                offset = params[0] if params else 0
                limit = params[1] if len(params) > 1 else 500
            start = offset
            end = min(offset + limit, len(mock_rows))
            return mock_rows[start:end]
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query)

    with (
        patch(
            "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
        ),
        patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()),
    ):
        # Initialize storage (should trigger migration)
        await storage.initialize()

    # Verify migration was executed
    # Check that execute was called for inserting rows
    assert mock_pg_db.execute.call_count > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_postgres_no_migration_needed(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """Test scenario where new table already exists (no migration needed)."""
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Mock: new table already exists
    async def mock_table_exists(db, table_name):
        return table_name == storage.table_name

    with (
        patch(
            "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
        ),
        patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create,
    ):
        await storage.initialize()

    # Verify no table creation was attempted
    mock_create.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_scenario_1_new_workspace_creation(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Scenario 1: New workspace creation

    Expected behavior:
    - No legacy table exists
    - Directly create new table with model suffix
    - No migration needed
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    # Reuse the fixture's stub embedder but present it as a 3072d model.
    embedding_func = EmbeddingFunc(
        embedding_dim=3072,
        func=mock_embedding_func.func,
        model_name="text-embedding-3-large",
    )

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="new_workspace",
    )

    # Mock: neither table exists
    async def mock_table_exists(db, table_name):
        return False

    with (
        patch(
            "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
        ),
        patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create,
    ):
        await storage.initialize()

    # Verify table name format
    assert "text_embedding_3_large_3072d" in storage.table_name

    # Verify new table creation was called
    mock_create.assert_called_once()
    call_args = mock_create.call_args
    assert (
        call_args[0][1] == storage.table_name
    )  # table_name is second positional arg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_scenario_2_legacy_upgrade_migration(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Scenario 2: Upgrade from legacy version

    Expected behavior:
    - Legacy table exists (without model suffix)
    - New table doesn't exist
    - Automatically migrate data to new table with suffix
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    embedding_func = EmbeddingFunc(
        embedding_dim=1536,
        func=mock_embedding_func.func,
        model_name="text-embedding-ada-002",
    )

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="legacy_workspace",
    )

    # Mock: only legacy table exists
    async def mock_table_exists(db, table_name):
        return table_name == storage.legacy_table_name

    # Mock: legacy table has 50 records
    mock_rows = [
        {
            "id": f"legacy_id_{i}",
            "content": f"legacy_content_{i}",
            "workspace": "legacy_workspace",
        }
        for i in range(50)
    ]

    # Track which queries have been made for proper response
    query_history = []

    async def mock_query(sql, params=None, multirows=False, **kwargs):
        query_history.append(sql)

        if "COUNT(*)" in sql:
            # Determine table type:
            # - Legacy: contains base name but NOT model suffix
            # - New: contains model suffix (e.g., text_embedding_ada_002_1536d)
            sql_upper = sql.upper()
            base_name = storage.legacy_table_name.upper()

            # Check if this is querying the new table (has model suffix)
            has_model_suffix = any(
                suffix in sql_upper
                for suffix in ["TEXT_EMBEDDING", "_1536D", "_768D", "_1024D", "_3072D"]
            )

            is_legacy_table = base_name in sql_upper and not has_model_suffix
            is_new_table = has_model_suffix
            has_workspace_filter = "WHERE workspace" in sql

            if is_legacy_table and has_workspace_filter:
                # Count for legacy table with workspace filter (before migration)
                return {"count": 50}
            elif is_legacy_table and not has_workspace_filter:
                # Total count for legacy table (after deletion, checking remaining)
                return {"count": 0}
            elif is_new_table:
                # Count for new table (verification after migration)
                return {"count": 50}
            else:
                # Fallback
                return {"count": 0}
        elif multirows and "SELECT *" in sql:
            # Mock batch fetch for migration
            # Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit]
            if "WHERE workspace" in sql:
                # With workspace filter: params[0]=workspace, params[1]=offset, params[2]=limit
                offset = params[1] if len(params) > 1 else 0
                limit = params[2] if len(params) > 2 else 500
            else:
                # No workspace filter: params[0]=offset, params[1]=limit
                offset = params[0] if params else 0
                limit = params[1] if len(params) > 1 else 500
            start = offset
            end = min(offset + limit, len(mock_rows))
            return mock_rows[start:end]
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query)

    with (
        patch(
            "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
        ),
        patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create,
    ):
        await storage.initialize()

    # Verify table name contains ada-002
    assert "text_embedding_ada_002_1536d" in storage.table_name

    # Verify migration was executed
    assert mock_pg_db.execute.call_count >= 50  # At least one execute per row
    mock_create.assert_called_once()

    # Verify legacy table was automatically deleted after successful migration
    # This prevents Case 1 warnings on next startup
    delete_calls = [
        call
        for call in mock_pg_db.execute.call_args_list
        if call[0][0] and "DROP TABLE" in call[0][0]
    ]
    assert (
        len(delete_calls) >= 1
    ), "Legacy table should be deleted after successful migration"
    # Check if legacy table was dropped
    dropped_table = storage.legacy_table_name
    assert any(
        dropped_table in str(call) for call in delete_calls
    ), f"Expected to drop '{dropped_table}'"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_scenario_3_multi_model_coexistence(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Scenario 3: Multiple embedding models coexist

    Expected behavior:
    - Different embedding models create separate tables
    - Tables are isolated by model suffix
    - No interference between different models
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    # Workspace A: uses bge-small (768d)
    embedding_func_a = EmbeddingFunc(
        embedding_dim=768, func=mock_embedding_func.func, model_name="bge-small"
    )

    storage_a = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func_a,
        workspace="workspace_a",
    )

    # Workspace B: uses bge-large (1024d)
    async def embed_func_b(texts, **kwargs):
        return np.array([[0.1] * 1024 for _ in texts])

    embedding_func_b = EmbeddingFunc(
        embedding_dim=1024, func=embed_func_b, model_name="bge-large"
    )

    storage_b = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func_b,
        workspace="workspace_b",
    )

    # Verify different table names
    assert storage_a.table_name != storage_b.table_name
    assert "bge_small_768d" in storage_a.table_name
    assert "bge_large_1024d" in storage_b.table_name

    # Mock: both tables don't exist yet
    async def mock_table_exists(db, table_name):
        return False

    with (
        patch(
            "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
        ),
        patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create,
    ):
        # Initialize both storages
        await storage_a.initialize()
        await storage_b.initialize()

    # Verify two separate tables were created
    assert mock_create.call_count == 2

    # Verify table names are different
    call_args_list = mock_create.call_args_list
    table_names = [call[0][1] for call in call_args_list]  # Second positional arg
    assert len(set(table_names)) == 2  # Two unique table names
    assert storage_a.table_name in table_names
    assert storage_b.table_name in table_names
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_case1_empty_legacy_auto_cleanup(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Case 1a: Both new and legacy tables exist, but legacy is EMPTY.

    Expected: Automatically delete empty legacy table (safe cleanup —
    an empty table can be dropped with no data-loss risk).
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    embedding_func = EmbeddingFunc(
        embedding_dim=1536,
        func=mock_embedding_func.func,
        model_name="test-model",
    )

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="test_ws",
    )

    # Mock: Both tables exist
    async def mock_table_exists(db, table_name):
        return True  # Both new and legacy exist

    # Mock: Legacy table is empty (0 records); routing is done by matching
    # the legacy table name inside the SQL text of COUNT(*) queries.
    async def mock_query(sql, params=None, multirows=False, **kwargs):
        if "COUNT(*)" in sql:
            if storage.legacy_table_name in sql:
                return {"count": 0}  # Empty legacy table
            else:
                return {"count": 100}  # New table has data
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query)

    with patch(
        "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
    ):
        await storage.initialize()

    # Verify: Empty legacy table should be automatically cleaned up
    # Empty tables are safe to delete without data loss risk
    delete_calls = [
        call
        for call in mock_pg_db.execute.call_args_list
        if call[0][0] and "DROP TABLE" in call[0][0]
    ]
    assert len(delete_calls) >= 1, "Empty legacy table should be auto-deleted"
    # Check if legacy table was dropped
    dropped_table = storage.legacy_table_name
    assert any(
        dropped_table in str(call) for call in delete_calls
    ), f"Expected to drop empty legacy table '{dropped_table}'"

    print(
        f"✅ Case 1a: Empty legacy table '{dropped_table}' auto-deleted successfully"
    )
@pytest.mark.asyncio
async def test_case1_nonempty_legacy_warning(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Case 1b: Both new and legacy tables exist, and legacy HAS DATA.

    Expected: Log warning, do not delete legacy (preserve data —
    tables containing rows are never auto-dropped).
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    embedding_func = EmbeddingFunc(
        embedding_dim=1536,
        func=mock_embedding_func.func,
        model_name="test-model",
    )

    storage = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="test_ws",
    )

    # Mock: Both tables exist
    async def mock_table_exists(db, table_name):
        return True  # Both new and legacy exist

    # Mock: Legacy table has data (50 records)
    async def mock_query(sql, params=None, multirows=False, **kwargs):
        if "COUNT(*)" in sql:
            if storage.legacy_table_name in sql:
                return {"count": 50}  # Legacy has data
            else:
                return {"count": 100}  # New table has data
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query)

    with patch(
        "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists
    ):
        await storage.initialize()

    # Verify: Legacy table with data should be preserved
    # We never auto-delete tables that contain data to prevent accidental data loss
    delete_calls = [
        call
        for call in mock_pg_db.execute.call_args_list
        if call[0][0] and "DROP TABLE" in call[0][0]
    ]
    # Check if legacy table was deleted (it should not be)
    dropped_table = storage.legacy_table_name
    legacy_deleted = any(dropped_table in str(call) for call in delete_calls)
    assert not legacy_deleted, "Legacy table with data should NOT be auto-deleted"

    print(
        f"✅ Case 1b: Legacy table '{dropped_table}' with data preserved (warning only)"
    )
@pytest.mark.asyncio
async def test_case1_sequential_workspace_migration(
    mock_client_manager, mock_pg_db, mock_embedding_func
):
    """
    Case 1c: Sequential workspace migration (Multi-tenant scenario).

    Critical bug fix verification:
    Timeline:
    1. Legacy table has workspace_a (3 records) + workspace_b (3 records)
    2. Workspace A initializes first → Case 4 (only legacy exists) → migrates A's data
    3. Workspace B initializes later → Case 1 (both tables exist) → should migrate B's data
    4. Verify workspace B's data is correctly migrated to new table
    5. Verify legacy table is cleaned up after both workspaces migrate

    This test verifies the fix where Case 1 now checks and migrates current
    workspace's data instead of just checking if legacy table is empty globally.
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    embedding_func = EmbeddingFunc(
        embedding_dim=1536,
        func=mock_embedding_func.func,
        model_name="test-model",
    )

    # Mock data: Legacy table has 6 records total (3 from workspace_a, 3 from workspace_b)
    mock_rows_a = [
        {"id": f"a_{i}", "content": f"A content {i}", "workspace": "workspace_a"}
        for i in range(3)
    ]
    mock_rows_b = [
        {"id": f"b_{i}", "content": f"B content {i}", "workspace": "workspace_b"}
        for i in range(3)
    ]

    # Track migration state (shared across both steps of this test)
    migration_state = {"new_table_exists": False, "workspace_a_migrated": False}

    # Step 1: Simulate workspace_a initialization (Case 4)
    # CRITICAL: Set db.workspace to workspace_a
    mock_pg_db.workspace = "workspace_a"

    storage_a = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="workspace_a",
    )

    # Mock table_exists for workspace_a: only the legacy table exists until
    # migration_state flips new_table_exists to True.
    async def mock_table_exists_a(db, table_name):
        if table_name == storage_a.legacy_table_name:
            return True
        if table_name == storage_a.table_name:
            return migration_state["new_table_exists"]
        return False

    # Track inserted records count for verification
    inserted_count = {"workspace_a": 0}

    # Mock execute to track inserts
    async def mock_execute_a(sql, data=None, **kwargs):
        if sql and "INSERT INTO" in sql.upper():
            inserted_count["workspace_a"] += 1
        return None

    # Mock query for workspace_a (Case 4). Routing is by SQL-text matching:
    # the model suffix distinguishes the new table from the legacy one, and
    # "WHERE workspace" distinguishes per-workspace counts from global ones.
    async def mock_query_a(sql, params=None, multirows=False, **kwargs):
        sql_upper = sql.upper()
        base_name = storage_a.legacy_table_name.upper()

        if "COUNT(*)" in sql:
            has_model_suffix = "TEST_MODEL_1536D" in sql_upper
            is_legacy = base_name in sql_upper and not has_model_suffix
            has_workspace_filter = "WHERE workspace" in sql

            if is_legacy and has_workspace_filter:
                workspace = params[0] if params and len(params) > 0 else None
                if workspace == "workspace_a":
                    # After migration starts, pretend legacy is empty for this workspace
                    return {"count": 3 - inserted_count["workspace_a"]}
                elif workspace == "workspace_b":
                    return {"count": 3}
            elif is_legacy and not has_workspace_filter:
                # Global count in legacy table
                remaining = 6 - inserted_count["workspace_a"]
                return {"count": remaining}
            elif has_model_suffix:
                # New table count (for verification)
                return {"count": inserted_count["workspace_a"]}
        elif multirows and "SELECT *" in sql:
            if "WHERE workspace" in sql:
                workspace = params[0] if params and len(params) > 0 else None
                if workspace == "workspace_a":
                    offset = params[1] if len(params) > 1 else 0
                    limit = params[2] if len(params) > 2 else 500
                    return mock_rows_a[offset : offset + limit]
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query_a)
    mock_pg_db.execute = AsyncMock(side_effect=mock_execute_a)

    # Initialize workspace_a (Case 4)
    with (
        patch(
            "lightrag.kg.postgres_impl._pg_table_exists",
            side_effect=mock_table_exists_a,
        ),
        patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()),
    ):
        await storage_a.initialize()
        migration_state["new_table_exists"] = True
        migration_state["workspace_a_migrated"] = True

    print("✅ Step 1: Workspace A initialized (Case 4)")
    assert mock_pg_db.execute.call_count >= 3
    print(f"✅ Step 1: {mock_pg_db.execute.call_count} execute calls")

    # Step 2: Simulate workspace_b initialization (Case 1)
    # CRITICAL: Set db.workspace to workspace_b
    mock_pg_db.workspace = "workspace_b"

    storage_b = PGVectorStorage(
        namespace=NameSpace.VECTOR_STORE_CHUNKS,
        global_config=config,
        embedding_func=embedding_func,
        workspace="workspace_b",
    )

    # Reset call history so Step 2 assertions only see workspace_b activity.
    mock_pg_db.reset_mock()
    migration_state["workspace_b_migrated"] = False

    # Mock table_exists for workspace_b (both exist)
    async def mock_table_exists_b(db, table_name):
        return True

    # Track inserted records count for workspace_b
    inserted_count["workspace_b"] = 0

    # Mock execute for workspace_b to track inserts
    async def mock_execute_b(sql, data=None, **kwargs):
        if sql and "INSERT INTO" in sql.upper():
            inserted_count["workspace_b"] += 1
        return None

    # Mock query for workspace_b (Case 1)
    async def mock_query_b(sql, params=None, multirows=False, **kwargs):
        sql_upper = sql.upper()
        base_name = storage_b.legacy_table_name.upper()

        if "COUNT(*)" in sql:
            has_model_suffix = "TEST_MODEL_1536D" in sql_upper
            is_legacy = base_name in sql_upper and not has_model_suffix
            has_workspace_filter = "WHERE workspace" in sql

            if is_legacy and has_workspace_filter:
                workspace = params[0] if params and len(params) > 0 else None
                if workspace == "workspace_b":
                    # After migration starts, pretend legacy is empty for this workspace
                    return {"count": 3 - inserted_count["workspace_b"]}
                elif workspace == "workspace_a":
                    return {"count": 0}  # Already migrated
            elif is_legacy and not has_workspace_filter:
                # Global count: only workspace_b data remains
                return {"count": 3 - inserted_count["workspace_b"]}
            elif has_model_suffix:
                # New table total count (workspace_a: 3 + workspace_b: inserted)
                if has_workspace_filter:
                    workspace = params[0] if params and len(params) > 0 else None
                    if workspace == "workspace_b":
                        return {"count": inserted_count["workspace_b"]}
                    elif workspace == "workspace_a":
                        return {"count": 3}
                else:
                    # Total count in new table (for verification)
                    return {"count": 3 + inserted_count["workspace_b"]}
        elif multirows and "SELECT *" in sql:
            if "WHERE workspace" in sql:
                workspace = params[0] if params and len(params) > 0 else None
                if workspace == "workspace_b":
                    offset = params[1] if len(params) > 1 else 0
                    limit = params[2] if len(params) > 2 else 500
                    return mock_rows_b[offset : offset + limit]
        return {}

    mock_pg_db.query = AsyncMock(side_effect=mock_query_b)
    mock_pg_db.execute = AsyncMock(side_effect=mock_execute_b)

    # Initialize workspace_b (Case 1)
    with patch(
        "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists_b
    ):
        await storage_b.initialize()
        migration_state["workspace_b_migrated"] = True

    print("✅ Step 2: Workspace B initialized (Case 1)")

    # Verify workspace_b migration happened
    execute_calls = mock_pg_db.execute.call_args_list
    insert_calls = [
        call for call in execute_calls if call[0][0] and "INSERT INTO" in call[0][0]
    ]
    assert len(insert_calls) >= 3, f"Expected >= 3 inserts, got {len(insert_calls)}"
    print(f"✅ Step 2: {len(insert_calls)} insert calls")

    # Verify DELETE and DROP TABLE
    delete_calls = [
        call
        for call in execute_calls
        if call[0][0]
        and "DELETE FROM" in call[0][0]
        and "WHERE workspace" in call[0][0]
    ]
    assert len(delete_calls) >= 1, "Expected DELETE workspace_b data"
    print("✅ Step 2: DELETE workspace_b from legacy")

    drop_calls = [
        call for call in execute_calls if call[0][0] and "DROP TABLE" in call[0][0]
    ]
    assert len(drop_calls) >= 1, "Expected DROP TABLE"
    print("✅ Step 2: Legacy table dropped")

    print("\n🎉 Case 1c: Sequential workspace migration verified!")
||||||
522
tests/test_qdrant_migration.py
Normal file
522
tests/test_qdrant_migration.py
Normal file
|
|
@ -0,0 +1,522 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
import numpy as np
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
from lightrag.kg.qdrant_impl import QdrantVectorDBStorage
|
||||||
|
|
||||||
|
|
||||||
|
# Mock QdrantClient: patches the client class used by qdrant_impl so that no
# real Qdrant server is required; individual tests override these defaults.
@pytest.fixture
def mock_qdrant_client():
    with patch("lightrag.kg.qdrant_impl.QdrantClient") as mock_client_cls:
        client = mock_client_cls.return_value
        client.collection_exists.return_value = False
        client.count.return_value.count = 0
        # Mock payload schema and vector config for get_collection
        collection_info = MagicMock()
        collection_info.payload_schema = {}
        # Mock vector dimension to match mock_embedding_func (768d)
        collection_info.config.params.vectors.size = 768
        client.get_collection.return_value = collection_info
        yield client
# Autouse fixture: stub out get_data_init_lock so tests never touch the
# real async initialization lock.
@pytest.fixture(autouse=True)
def mock_data_init_lock():
    with patch("lightrag.kg.qdrant_impl.get_data_init_lock") as lock_factory:
        # Any call to the lock factory yields an awaitable-friendly mock.
        lock_factory.return_value = AsyncMock()
        yield lock_factory
# Fixture providing a deterministic fake embedding function (768-d,
# model name "test-model") — actual vector values are irrelevant here.
@pytest.fixture
def mock_embedding_func():
    async def _fake_embed(texts, **kwargs):
        # One constant 768-d vector per input text.
        return np.array([[0.1] * 768 for _ in texts])

    return EmbeddingFunc(embedding_dim=768, func=_fake_embed, model_name="test-model")
@pytest.mark.asyncio
async def test_qdrant_collection_naming(mock_qdrant_client, mock_embedding_func):
    """Test if collection name is correctly generated with model suffix."""
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Verify collection name contains model suffix (model name + dimension)
    expected_suffix = "test_model_768d"
    assert expected_suffix in storage.final_namespace
    assert storage.final_namespace == f"lightrag_vdb_chunks_{expected_suffix}"

    # Verify legacy namespace (should not include workspace, just the base collection name)
    assert storage.legacy_namespace == "lightrag_vdb_chunks"
@pytest.mark.asyncio
async def test_qdrant_migration_trigger(mock_qdrant_client, mock_embedding_func):
    """Test if migration logic is triggered correctly.

    Setup: only the legacy collection exists and it contains data,
    so initialize() must create the new collection and copy the data over.
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # Setup mocks for migration scenario
    # 1. New collection does not exist
    mock_qdrant_client.collection_exists.side_effect = (
        lambda name: name == storage.legacy_namespace
    )

    # 2. Legacy collection exists and has data
    mock_qdrant_client.count.return_value.count = 100

    # 3. Mock scroll for data migration

    mock_point = MagicMock()
    mock_point.id = "old_id"
    mock_point.vector = [0.1] * 768
    mock_point.payload = {"content": "test"}

    # First call returns points, second call returns empty (end of scroll)
    mock_qdrant_client.scroll.side_effect = [([mock_point], "next_offset"), ([], None)]

    # Initialize storage (triggers migration)
    await storage.initialize()

    # Verify migration steps
    # 1. Legacy count checked
    mock_qdrant_client.count.assert_any_call(
        collection_name=storage.legacy_namespace, exact=True
    )

    # 2. New collection created
    mock_qdrant_client.create_collection.assert_called()

    # 3. Data scrolled from legacy
    assert mock_qdrant_client.scroll.call_count >= 1
    call_args = mock_qdrant_client.scroll.call_args_list[0]
    assert call_args.kwargs["collection_name"] == storage.legacy_namespace
    assert call_args.kwargs["limit"] == 500

    # 4. Data upserted to new
    mock_qdrant_client.upsert.assert_called()

    # 5. Payload index created
    mock_qdrant_client.create_payload_index.assert_called()
@pytest.mark.asyncio
async def test_qdrant_no_migration_needed(mock_qdrant_client, mock_embedding_func):
    """Test scenario where new collection already exists (no migration)."""
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    # New collection exists and Legacy exists (warning case)
    # or New collection exists and Legacy does not exist (normal case)
    # Mocking case where both exist to test logic flow but without migration

    # Logic in code:
    # Case 1: Both exist -> Warning only
    # Case 2: Only new exists -> Ensure index

    # Let's test Case 2: Only new collection exists
    mock_qdrant_client.collection_exists.side_effect = (
        lambda name: name == storage.final_namespace
    )

    # Initialize
    await storage.initialize()

    # Should check index but NOT migrate
    # In Qdrant implementation, Case 2 calls get_collection
    mock_qdrant_client.get_collection.assert_called_with(storage.final_namespace)
    mock_qdrant_client.scroll.assert_not_called()
# ============================================================================
|
||||||
|
# Tests for scenarios described in design document (Lines 606-649)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_scenario_1_new_workspace_creation(
    mock_qdrant_client, mock_embedding_func
):
    """
    Scenario 1: Create a brand-new workspace.

    Expected: directly create ``lightrag_vdb_chunks_text_embedding_3_large_3072d``
    (no legacy collection, so no migration should be attempted).
    """
    # Use a large embedding model
    large_model_func = EmbeddingFunc(
        embedding_dim=3072,
        func=mock_embedding_func.func,
        model_name="text-embedding-3-large",
    )

    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=large_model_func,
        workspace="test_new",
    )

    # Case 3: Neither legacy nor new collection exists
    mock_qdrant_client.collection_exists.return_value = False

    # Initialize storage
    await storage.initialize()

    # Verify: Should create new collection with model suffix
    expected_collection = "lightrag_vdb_chunks_text_embedding_3_large_3072d"
    assert storage.final_namespace == expected_collection

    # Verify create_collection was called with correct name
    # (accept either positional or keyword `collection_name` argument)
    create_calls = [
        call for call in mock_qdrant_client.create_collection.call_args_list
    ]
    assert len(create_calls) > 0
    assert (
        create_calls[0][0][0] == expected_collection
        or create_calls[0].kwargs.get("collection_name") == expected_collection
    )

    # Verify no migration was attempted
    mock_qdrant_client.scroll.assert_not_called()

    print(
        f"✅ Scenario 1: New workspace created with collection '{expected_collection}'"
    )
@pytest.mark.asyncio
async def test_scenario_2_legacy_upgrade_migration(
    mock_qdrant_client, mock_embedding_func
):
    """
    Scenario 2: Upgrade from an old version.

    ``lightrag_vdb_chunks`` (no suffix) already exists.
    Expected: data is automatically migrated to
    ``lightrag_vdb_chunks_text_embedding_ada_002_1536d``, and the legacy
    collection is deleted afterwards.
    """
    # Use ada-002 model
    ada_func = EmbeddingFunc(
        embedding_dim=1536,
        func=mock_embedding_func.func,
        model_name="text-embedding-ada-002",
    )

    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=ada_func,
        workspace="test_legacy",
    )

    legacy_collection = storage.legacy_namespace
    new_collection = storage.final_namespace

    # Case 4: Only legacy collection exists
    mock_qdrant_client.collection_exists.side_effect = (
        lambda name: name == legacy_collection
    )

    # Mock legacy collection info with 1536d vectors
    legacy_collection_info = MagicMock()
    legacy_collection_info.payload_schema = {}
    legacy_collection_info.config.params.vectors.size = 1536
    mock_qdrant_client.get_collection.return_value = legacy_collection_info

    # Mock legacy data
    mock_qdrant_client.count.return_value.count = 150

    # Mock scroll results (simulate migration in batches)

    mock_points = []
    for i in range(10):
        point = MagicMock()
        point.id = f"legacy-{i}"
        point.vector = [0.1] * 1536
        point.payload = {"content": f"Legacy document {i}", "id": f"doc-{i}"}
        mock_points.append(point)

    # First batch returns points, second batch returns empty
    mock_qdrant_client.scroll.side_effect = [(mock_points, "offset1"), ([], None)]

    # Initialize (triggers migration)
    await storage.initialize()

    # Verify: New collection should be created
    expected_new_collection = "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
    assert storage.final_namespace == expected_new_collection

    # Verify migration steps
    # 1. Check legacy count
    mock_qdrant_client.count.assert_any_call(
        collection_name=legacy_collection, exact=True
    )

    # 2. Create new collection
    mock_qdrant_client.create_collection.assert_called()

    # 3. Scroll legacy data
    scroll_calls = [call for call in mock_qdrant_client.scroll.call_args_list]
    assert len(scroll_calls) >= 1
    assert scroll_calls[0].kwargs["collection_name"] == legacy_collection

    # 4. Upsert to new collection
    upsert_calls = [call for call in mock_qdrant_client.upsert.call_args_list]
    assert len(upsert_calls) >= 1
    assert upsert_calls[0].kwargs["collection_name"] == new_collection

    # 5. Verify legacy collection was automatically deleted after successful migration
    # This prevents Case 1 warnings on next startup
    delete_calls = [
        call for call in mock_qdrant_client.delete_collection.call_args_list
    ]
    assert (
        len(delete_calls) >= 1
    ), "Legacy collection should be deleted after successful migration"
    # Check if legacy_collection was passed to delete_collection
    # (accept either positional or keyword `collection_name` argument)
    deleted_collection = (
        delete_calls[0][0][0]
        if delete_calls[0][0]
        else delete_calls[0].kwargs.get("collection_name")
    )
    assert (
        deleted_collection == legacy_collection
    ), f"Expected to delete '{legacy_collection}', but deleted '{deleted_collection}'"

    print(
        f"✅ Scenario 2: Legacy data migrated from '{legacy_collection}' to '{expected_new_collection}' and legacy collection deleted"
    )
@pytest.mark.asyncio
async def test_scenario_3_multi_model_coexistence(mock_qdrant_client):
    """
    Scenario 3: Multiple embedding models coexist.

    Expected: two independent collections that do not interfere with
    each other (distinct model-suffixed names, distinct dimensions).
    """

    # Model A: bge-small with 768d
    async def embed_func_a(texts, **kwargs):
        return np.array([[0.1] * 768 for _ in texts])

    model_a_func = EmbeddingFunc(
        embedding_dim=768, func=embed_func_a, model_name="bge-small"
    )

    # Model B: bge-large with 1024d
    async def embed_func_b(texts, **kwargs):
        return np.array([[0.2] * 1024 for _ in texts])

    model_b_func = EmbeddingFunc(
        embedding_dim=1024, func=embed_func_b, model_name="bge-large"
    )

    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    # Create storage for workspace A with model A
    storage_a = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=model_a_func,
        workspace="workspace_a",
    )

    # Create storage for workspace B with model B
    storage_b = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=model_b_func,
        workspace="workspace_b",
    )

    # Verify: Collection names are different
    assert storage_a.final_namespace != storage_b.final_namespace

    # Verify: Model A collection
    expected_collection_a = "lightrag_vdb_chunks_bge_small_768d"
    assert storage_a.final_namespace == expected_collection_a

    # Verify: Model B collection
    expected_collection_b = "lightrag_vdb_chunks_bge_large_1024d"
    assert storage_b.final_namespace == expected_collection_b

    # Verify: Different embedding dimensions are preserved
    assert storage_a.embedding_func.embedding_dim == 768
    assert storage_b.embedding_func.embedding_dim == 1024

    print("✅ Scenario 3: Multi-model coexistence verified")
    print(f" - Workspace A: {expected_collection_a} (768d)")
    print(f" - Workspace B: {expected_collection_b} (1024d)")
    print(" - Collections are independent")
@pytest.mark.asyncio
async def test_case1_empty_legacy_auto_cleanup(mock_qdrant_client, mock_embedding_func):
    """
    Case 1a: Both the new and legacy collections exist, and the legacy one is empty.

    Expected: The empty legacy collection is deleted automatically during
    ``initialize()`` — removing an empty collection carries no data-loss risk.
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    legacy_collection = storage.legacy_namespace
    new_collection = storage.final_namespace

    # Mock: both collections exist.
    mock_qdrant_client.collection_exists.side_effect = lambda name: name in (
        legacy_collection,
        new_collection,
    )

    # Mock: the legacy collection is empty (0 records); the new one has data.
    def count_mock(collection_name, exact=True):
        mock_result = MagicMock()
        mock_result.count = 0 if collection_name == legacy_collection else 100
        return mock_result

    mock_qdrant_client.count.side_effect = count_mock

    # Mock get_collection for the payload-schema (Case 2) check.
    collection_info = MagicMock()
    collection_info.payload_schema = {"workspace_id": True}
    mock_qdrant_client.get_collection.return_value = collection_info

    # Initialize storage; legacy-collection cleanup runs here.
    await storage.initialize()

    # Verify: the empty legacy collection was automatically cleaned up.
    # call_args_list is already a list — no need to copy it via a comprehension.
    delete_calls = mock_qdrant_client.delete_collection.call_args_list
    assert len(delete_calls) >= 1, "Empty legacy collection should be auto-deleted"
    # delete_collection may have been invoked positionally or by keyword.
    deleted_collection = (
        delete_calls[0][0][0]
        if delete_calls[0][0]
        else delete_calls[0].kwargs.get("collection_name")
    )
    assert (
        deleted_collection == legacy_collection
    ), f"Expected to delete '{legacy_collection}', but deleted '{deleted_collection}'"

    print(
        f"✅ Case 1a: Empty legacy collection '{legacy_collection}' auto-deleted successfully"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_case1_nonempty_legacy_warning(mock_qdrant_client, mock_embedding_func):
    """
    Case 1b: Both the new and legacy collections exist, and the legacy one
    still contains data.

    Expected: A warning is logged but the legacy collection is NOT deleted —
    collections that contain data are never auto-deleted, to prevent
    accidental data loss.
    """
    config = {
        "embedding_batch_num": 10,
        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
    }

    storage = QdrantVectorDBStorage(
        namespace="chunks",
        global_config=config,
        embedding_func=mock_embedding_func,
        workspace="test_ws",
    )

    legacy_collection = storage.legacy_namespace
    new_collection = storage.final_namespace

    # Mock: both collections exist.
    mock_qdrant_client.collection_exists.side_effect = lambda name: name in (
        legacy_collection,
        new_collection,
    )

    # Mock: the legacy collection has data (50 records); the new one also has data.
    def count_mock(collection_name, exact=True):
        mock_result = MagicMock()
        mock_result.count = 50 if collection_name == legacy_collection else 100
        return mock_result

    mock_qdrant_client.count.side_effect = count_mock

    # Mock get_collection for the payload-schema (Case 2) check.
    collection_info = MagicMock()
    collection_info.payload_schema = {"workspace_id": True}
    mock_qdrant_client.get_collection.return_value = collection_info

    # Initialize storage; legacy-collection handling runs here.
    await storage.initialize()

    # Verify: the legacy collection with data was preserved.
    # call_args_list is already a list — no need to copy it via a comprehension.
    delete_calls = mock_qdrant_client.delete_collection.call_args_list
    # Check whether the legacy collection was deleted (it must not be);
    # handle both positional and keyword invocations of delete_collection.
    legacy_deleted = any(
        (call[0][0] if call[0] else call.kwargs.get("collection_name"))
        == legacy_collection
        for call in delete_calls
    )
    assert not legacy_deleted, "Legacy collection with data should NOT be auto-deleted"

    print(
        f"✅ Case 1b: Legacy collection '{legacy_collection}' with data preserved (warning only)"
    )
|
||||||
191
tests/test_unified_lock_safety.py
Normal file
191
tests/test_unified_lock_safety.py
Normal file
|
|
@ -0,0 +1,191 @@
|
||||||
|
"""
|
||||||
|
Tests for UnifiedLock safety when lock is None.
|
||||||
|
|
||||||
|
This test module verifies that UnifiedLock raises RuntimeError instead of
|
||||||
|
allowing unprotected execution when the underlying lock is None, preventing
|
||||||
|
false security and potential race conditions.
|
||||||
|
|
||||||
|
Critical Bug 1: When self._lock is None, __aenter__ used to log WARNING but
|
||||||
|
still return successfully, allowing critical sections to run without lock
|
||||||
|
protection, causing race conditions and data corruption.
|
||||||
|
|
||||||
|
Critical Bug 2: In __aexit__, when async_lock.release() fails, the error
|
||||||
|
recovery logic would attempt to release it again, causing double-release issues.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock
|
||||||
|
from lightrag.kg.shared_storage import UnifiedLock
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnifiedLockSafety:
    """Test suite for UnifiedLock None safety checks."""

    @pytest.mark.asyncio
    async def test_unified_lock_raises_on_none_async(self):
        """
        Test that UnifiedLock raises RuntimeError when lock is None (async mode).

        Scenario: Attempt to use UnifiedLock before initialize_share_data() is called.
        Expected: RuntimeError raised, preventing unprotected critical section execution.
        """
        lock = UnifiedLock(
            lock=None, is_async=True, name="test_async_lock", enable_logging=False
        )

        # The match pattern accepts either of the two error wordings the
        # implementation may raise for an uninitialized lock.
        with pytest.raises(
            RuntimeError, match="shared data not initialized|Lock.*is None"
        ):
            async with lock:
                # This code should NEVER execute
                pytest.fail(
                    "Code inside lock context should not execute when lock is None"
                )

    @pytest.mark.asyncio
    async def test_unified_lock_raises_on_none_sync(self):
        """
        Test that UnifiedLock raises RuntimeError when lock is None (sync mode).

        Scenario: Attempt to use UnifiedLock with None lock in sync mode.
        Expected: RuntimeError raised with clear error message.
        """
        lock = UnifiedLock(
            lock=None, is_async=False, name="test_sync_lock", enable_logging=False
        )

        # Even in sync mode, UnifiedLock is entered via the async context
        # protocol, so the None check must fire in __aenter__ here too.
        with pytest.raises(
            RuntimeError, match="shared data not initialized|Lock.*is None"
        ):
            async with lock:
                # This code should NEVER execute
                pytest.fail(
                    "Code inside lock context should not execute when lock is None"
                )

    @pytest.mark.asyncio
    async def test_error_message_clarity(self):
        """
        Test that the error message clearly indicates the problem and solution.

        Scenario: Lock is None and user tries to acquire it.
        Expected: Error message mentions 'shared data not initialized' and
        'initialize_share_data()'.
        """
        lock = UnifiedLock(
            lock=None,
            is_async=True,
            name="test_error_message",
            enable_logging=False,
        )

        with pytest.raises(RuntimeError) as exc_info:
            async with lock:
                pass

        error_message = str(exc_info.value)
        # Verify error message contains helpful information.
        # Loose substring checks keep the test robust to minor rewording.
        assert (
            "shared data not initialized" in error_message.lower()
            or "lock" in error_message.lower()
        )
        assert "initialize_share_data" in error_message or "None" in error_message

    @pytest.mark.asyncio
    async def test_aexit_no_double_release_on_async_lock_failure(self):
        """
        Test that __aexit__ doesn't attempt to release async_lock twice when it fails.

        Scenario: async_lock.release() fails during normal release.
        Expected: Recovery logic should NOT attempt to release async_lock again,
        preventing double-release issues.

        This tests Bug 2 fix: async_lock_released tracking prevents double release.
        """
        # Create mock locks: the main lock succeeds on acquire/release.
        main_lock = MagicMock()
        main_lock.acquire = MagicMock()
        main_lock.release = MagicMock()

        async_lock = AsyncMock()
        async_lock.acquire = AsyncMock()

        # Make async_lock.release() fail, counting invocations so we can
        # assert exactly how many times the recovery path touched it.
        release_call_count = 0

        def mock_release_fail():
            nonlocal release_call_count
            release_call_count += 1
            raise RuntimeError("Async lock release failed")

        # MagicMock (not AsyncMock) here: release is expected to be called
        # synchronously by __aexit__.
        async_lock.release = MagicMock(side_effect=mock_release_fail)

        # Create UnifiedLock with both locks (sync mode with async_lock)
        lock = UnifiedLock(
            lock=main_lock,
            is_async=False,
            name="test_double_release",
            enable_logging=False,
        )
        # Inject the internal async lock directly — there is no public setter.
        lock._async_lock = async_lock

        # Try to use the lock - should fail during __aexit__
        try:
            async with lock:
                pass
        except RuntimeError as e:
            # Should get the async lock release error
            assert "Async lock release failed" in str(e)

        # Verify async_lock.release() was called only ONCE, not twice
        assert (
            release_call_count == 1
        ), f"async_lock.release() should be called only once, but was called {release_call_count} times"

        # Main lock should have been released successfully
        main_lock.release.assert_called_once()

    @pytest.mark.asyncio
    async def test_aexit_recovery_on_main_lock_failure(self):
        """
        Test that __aexit__ recovery logic works when main lock release fails.

        Scenario: main_lock.release() fails before async_lock is attempted.
        Expected: Recovery logic should attempt to release async_lock to prevent
        resource leaks.

        This verifies the recovery logic still works correctly with async_lock_released tracking.
        """
        # Create mock locks
        main_lock = MagicMock()
        main_lock.acquire = MagicMock()

        # Make main_lock.release() fail
        def mock_main_release_fail():
            raise RuntimeError("Main lock release failed")

        main_lock.release = MagicMock(side_effect=mock_main_release_fail)

        # The async lock releases cleanly; we only assert it was attempted.
        async_lock = AsyncMock()
        async_lock.acquire = AsyncMock()
        async_lock.release = MagicMock()

        # Create UnifiedLock with both locks (sync mode with async_lock)
        lock = UnifiedLock(
            lock=main_lock, is_async=False, name="test_recovery", enable_logging=False
        )
        # Inject the internal async lock directly — there is no public setter.
        lock._async_lock = async_lock

        # Try to use the lock - should fail during __aexit__
        try:
            async with lock:
                pass
        except RuntimeError as e:
            # Should get the main lock release error
            assert "Main lock release failed" in str(e)

        # Main lock release should have been attempted
        main_lock.release.assert_called_once()

        # Recovery logic should have attempted to release async_lock
        async_lock.release.assert_called_once()
||||||
308
tests/test_workspace_migration_isolation.py
Normal file
308
tests/test_workspace_migration_isolation.py
Normal file
|
|
@ -0,0 +1,308 @@
|
||||||
|
"""
|
||||||
|
Tests for workspace isolation during PostgreSQL migration.
|
||||||
|
|
||||||
|
This test module verifies that setup_table() properly filters migration data
|
||||||
|
by workspace, preventing cross-workspace data leakage during legacy table migration.
|
||||||
|
|
||||||
|
Critical Bug: Migration copied ALL records from legacy table regardless of workspace,
|
||||||
|
causing workspace A to receive workspace B's data, violating multi-tenant isolation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from lightrag.kg.postgres_impl import PGVectorStorage
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkspaceMigrationIsolation:
    """Test suite for workspace-scoped migration in PostgreSQL."""

    @pytest.mark.asyncio
    async def test_migration_filters_by_workspace(self):
        """
        Test that migration only copies data from the specified workspace.

        Scenario: Legacy table contains data from multiple workspaces.
        Migrate only workspace_a's data to new table.
        Expected: New table contains only workspace_a data, workspace_b data excluded.
        """
        db = AsyncMock()

        # Mock table existence checks
        async def table_exists_side_effect(db_instance, name):
            if name == "lightrag_doc_chunks":  # legacy
                return True
            elif name == "lightrag_doc_chunks_model_1536d":  # new
                return False
            return False

        # Mock query responses.
        # NOTE: branch order matters — the workspace-filtered COUNT branch
        # must be tested before the bare legacy-table COUNT branch, since
        # both substrings appear in the filtered SQL.
        async def query_side_effect(sql, params, **kwargs):
            multirows = kwargs.get("multirows", False)

            # Table existence check
            if "information_schema.tables" in sql:
                if params[0] == "lightrag_doc_chunks":
                    return {"exists": True}
                elif params[0] == "lightrag_doc_chunks_model_1536d":
                    return {"exists": False}

            # Count query with workspace filter (legacy table)
            elif "COUNT(*)" in sql and "WHERE workspace" in sql:
                if params[0] == "workspace_a":
                    return {"count": 2}  # workspace_a has 2 records
                elif params[0] == "workspace_b":
                    return {"count": 3}  # workspace_b has 3 records
                return {"count": 0}

            # Count query for new table (verification)
            elif "COUNT(*)" in sql and "lightrag_doc_chunks_model_1536d" in sql:
                return {"count": 2}  # Verification: 2 records migrated

            # Count query for legacy table (no filter)
            elif "COUNT(*)" in sql and "lightrag_doc_chunks" in sql:
                return {"count": 5}  # Total records in legacy

            # Dimension check
            elif "pg_attribute" in sql:
                return {"vector_dim": 1536}

            # SELECT with workspace filter
            elif "SELECT * FROM" in sql and "WHERE workspace" in sql and multirows:
                workspace = params[0]
                if workspace == "workspace_a" and params[1] == 0:  # offset = 0
                    # Return only workspace_a data
                    return [
                        {
                            "id": "a1",
                            "workspace": "workspace_a",
                            "content": "content_a1",
                            "content_vector": [0.1] * 1536,
                        },
                        {
                            "id": "a2",
                            "workspace": "workspace_a",
                            "content": "content_a2",
                            "content_vector": [0.2] * 1536,
                        },
                    ]
                else:
                    return []  # No more data

            return {}

        db.query.side_effect = query_side_effect
        db.execute = AsyncMock()
        db._create_vector_index = AsyncMock()

        # Mock _pg_table_exists and _pg_create_table
        from unittest.mock import patch

        with (
            patch(
                "lightrag.kg.postgres_impl._pg_table_exists",
                side_effect=table_exists_side_effect,
            ),
            patch("lightrag.kg.postgres_impl._pg_create_table", new=AsyncMock()),
        ):
            # Migrate for workspace_a only
            await PGVectorStorage.setup_table(
                db,
                "lightrag_doc_chunks_model_1536d",
                legacy_table_name="lightrag_doc_chunks",
                base_table="lightrag_doc_chunks",
                embedding_dim=1536,
                workspace="workspace_a",  # CRITICAL: Only migrate workspace_a
            )

        # Verify workspace filter was used in queries.
        # call[0][0] is the SQL string, call[0][1] the parameter tuple.
        count_calls = [
            call
            for call in db.query.call_args_list
            if call[0][0]
            and "COUNT(*)" in call[0][0]
            and "WHERE workspace" in call[0][0]
        ]
        assert len(count_calls) > 0, "Count query should use workspace filter"
        assert (
            count_calls[0][0][1][0] == "workspace_a"
        ), "Count should filter by workspace_a"

        select_calls = [
            call
            for call in db.query.call_args_list
            if call[0][0]
            and "SELECT * FROM" in call[0][0]
            and "WHERE workspace" in call[0][0]
        ]
        assert len(select_calls) > 0, "Select query should use workspace filter"
        assert (
            select_calls[0][0][1][0] == "workspace_a"
        ), "Select should filter by workspace_a"

        # Verify INSERT was called (migration happened)
        insert_calls = [
            call
            for call in db.execute.call_args_list
            if call[0][0] and "INSERT INTO" in call[0][0]
        ]
        assert len(insert_calls) == 2, "Should insert 2 records from workspace_a"

    @pytest.mark.asyncio
    async def test_migration_without_workspace_warns(self):
        """
        Test that migration without workspace parameter logs a warning.

        Scenario: setup_table called without workspace parameter.
        Expected: Warning logged about potential cross-workspace data copying.
        """
        db = AsyncMock()

        async def table_exists_side_effect(db_instance, name):
            if name == "lightrag_doc_chunks":
                return True
            elif name == "lightrag_doc_chunks_model_1536d":
                return False
            return False

        # Unfiltered variant of the query mock: no workspace branches, so
        # the legacy table appears to hold rows from both workspaces.
        async def query_side_effect(sql, params, **kwargs):
            if "information_schema.tables" in sql:
                return {"exists": params[0] == "lightrag_doc_chunks"}
            elif "COUNT(*)" in sql:
                return {"count": 5}  # 5 records total
            elif "pg_attribute" in sql:
                return {"vector_dim": 1536}
            elif "SELECT * FROM" in sql and kwargs.get("multirows"):
                if params[0] == 0:  # offset = 0
                    return [
                        {
                            "id": "1",
                            "workspace": "workspace_a",
                            "content_vector": [0.1] * 1536,
                        },
                        {
                            "id": "2",
                            "workspace": "workspace_b",
                            "content_vector": [0.2] * 1536,
                        },
                    ]
                else:
                    return []
            return {}

        db.query.side_effect = query_side_effect
        db.execute = AsyncMock()
        db._create_vector_index = AsyncMock()

        from unittest.mock import patch

        with (
            patch(
                "lightrag.kg.postgres_impl._pg_table_exists",
                side_effect=table_exists_side_effect,
            ),
            patch("lightrag.kg.postgres_impl._pg_create_table", new=AsyncMock()),
        ):
            # Migrate WITHOUT workspace parameter (dangerous!)
            await PGVectorStorage.setup_table(
                db,
                "lightrag_doc_chunks_model_1536d",
                legacy_table_name="lightrag_doc_chunks",
                base_table="lightrag_doc_chunks",
                embedding_dim=1536,
                workspace=None,  # No workspace filter!
            )

        # Verify queries do NOT use workspace filter
        count_calls = [
            call
            for call in db.query.call_args_list
            if call[0][0] and "COUNT(*)" in call[0][0]
        ]
        assert len(count_calls) > 0, "Count query should be executed"
        # Check that workspace filter was NOT used
        has_workspace_filter = any(
            "WHERE workspace" in call[0][0] for call in count_calls
        )
        assert (
            not has_workspace_filter
        ), "Count should NOT filter by workspace when workspace=None"

    @pytest.mark.asyncio
    async def test_no_cross_workspace_contamination(self):
        """
        Test that workspace B's migration doesn't include workspace A's data.

        Scenario: Two separate migrations for workspace_a and workspace_b.
        Expected: Each workspace only gets its own data.
        """
        db = AsyncMock()

        # Track which workspace is being queried
        queried_workspace = None

        async def table_exists_side_effect(db_instance, name):
            # Only the legacy (non-"model") table exists in this scenario.
            return "lightrag_doc_chunks" in name and "model" not in name

        async def query_side_effect(sql, params, **kwargs):
            nonlocal queried_workspace
            multirows = kwargs.get("multirows", False)

            if "information_schema.tables" in sql:
                return {"exists": "lightrag_doc_chunks" in params[0]}
            elif "COUNT(*)" in sql and "WHERE workspace" in sql:
                # Record the workspace used in the filtered count so the
                # test can assert which workspace was migrated.
                queried_workspace = params[0]
                return {"count": 1}
            elif "COUNT(*)" in sql and "lightrag_doc_chunks_model_1536d" in sql:
                return {"count": 1}  # Verification count
            elif "pg_attribute" in sql:
                return {"vector_dim": 1536}
            elif "SELECT * FROM" in sql and "WHERE workspace" in sql and multirows:
                workspace = params[0]
                if params[1] == 0:  # offset = 0
                    # Return data ONLY for the queried workspace
                    return [
                        {
                            "id": f"{workspace}_1",
                            "workspace": workspace,
                            "content": f"content_{workspace}",
                            "content_vector": [0.1] * 1536,
                        }
                    ]
                else:
                    return []
            return {}

        db.query.side_effect = query_side_effect
        db.execute = AsyncMock()
        db._create_vector_index = AsyncMock()

        from unittest.mock import patch

        with (
            patch(
                "lightrag.kg.postgres_impl._pg_table_exists",
                side_effect=table_exists_side_effect,
            ),
            patch("lightrag.kg.postgres_impl._pg_create_table", new=AsyncMock()),
        ):
            # Migrate workspace_b
            await PGVectorStorage.setup_table(
                db,
                "lightrag_doc_chunks_model_1536d",
                legacy_table_name="lightrag_doc_chunks",
                base_table="lightrag_doc_chunks",
                embedding_dim=1536,
                workspace="workspace_b",
            )

        # Verify only workspace_b was queried
        assert queried_workspace == "workspace_b", "Should only query workspace_b"

        # Verify INSERT contains workspace_b data only
        insert_calls = [
            call
            for call in db.execute.call_args_list
            if call[0][0] and "INSERT INTO" in call[0][0]
        ]
        assert len(insert_calls) > 0, "Should have INSERT calls"
|
||||||
Loading…
Add table
Reference in a new issue