feat: PostgreSQL model isolation and auto-migration
Why this change is needed:
PostgreSQL vector storage needs model isolation to prevent dimension
conflicts when different workspaces use different embedding models.
Without this, the first workspace locks the vector dimension for all
subsequent workspaces, causing failures.
How it solves it:
- Implements dynamic table naming with model suffix: {table}_{model}_{dim}d
- Adds setup_table() method mirroring Qdrant's approach for consistency
- Implements 4-branch migration logic: both exist -> warn, only new -> use,
neither -> create, only legacy -> migrate
- Batch migration: 500 records/batch (same as Qdrant)
- No automatic rollback to support idempotent re-runs
Impact:
- PostgreSQL tables now isolated by embedding model and dimension
- Automatic data migration from legacy tables on startup
- Backward compatible: model_name=None defaults to "unknown"
- All SQL operations use dynamic table names
Testing:
- 6 new tests for PostgreSQL migration (100% pass)
- Tests cover: naming, migration trigger, scenarios 1-3
- 3 additional scenario tests added for Qdrant completeness
Co-Authored-By: Claude <noreply@anthropic.com>
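
For illustration, a minimal sketch of the naming scheme described above. The actual _generate_collection_suffix() implementation is not part of this diff; the sanitization rule below is an assumption inferred from the table names exercised in the tests (e.g. bge-small with 768 dimensions maps to bge_small_768d):

import re

def generate_model_suffix(model_name, embedding_dim):
    # Hypothetical helper mirroring _generate_collection_suffix();
    # model_name=None falls back to "unknown", as noted above.
    model = model_name or "unknown"
    # Assumed rule: replace characters that are not legal in PostgreSQL identifiers
    model = re.sub(r"[^a-zA-Z0-9]", "_", model)
    return f"{model}_{embedding_dim}d"

# generate_model_suffix("text-embedding-3-large", 3072)
# -> "text_embedding_3_large_3072d"
# giving a table name such as LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d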
parent df5aacb545
commit ad68624d02

4 changed files with 798 additions and 23 deletions
lightrag/kg/postgres_impl.py
@@ -2175,6 +2175,38 @@ class PGKVStorage(BaseKVStorage):
             return {"status": "error", "message": str(e)}
 
 
+async def _pg_table_exists(db: PostgreSQLDB, table_name: str) -> bool:
+    """Check if a table exists in the PostgreSQL database."""
+    query = """
+    SELECT EXISTS (
+        SELECT FROM information_schema.tables
+        WHERE table_name = $1
+    )
+    """
+    result = await db.query(query, [table_name.lower()])
+    return result.get("exists", False) if result else False
+
+
+async def _pg_create_table(
+    db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
+) -> None:
+    """Create a new vector table by replacing the table name in the DDL template."""
+    if base_table not in TABLES:
+        raise ValueError(f"No DDL template found for table: {base_table}")
+
+    ddl_template = TABLES[base_table]["ddl"]
+
+    # Replace the embedding dimension placeholder if present
+    ddl = ddl_template.replace(
+        f"VECTOR({os.environ.get('EMBEDDING_DIM', 1024)})", f"VECTOR({embedding_dim})"
+    )
+
+    # Replace the table name
+    ddl = ddl.replace(base_table, table_name)
+
+    await db.execute(ddl)
+
+
 @final
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
@@ -2190,6 +2222,163 @@ class PGVectorStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold
 
+        # Generate the model suffix for table isolation
+        self.model_suffix = self._generate_collection_suffix()
+
+        # Get the base table name
+        base_table = namespace_to_table_name(self.namespace)
+        if not base_table:
+            raise ValueError(f"Unknown namespace: {self.namespace}")
+
+        # New table name (with suffix)
+        self.table_name = f"{base_table}_{self.model_suffix}"
+
+        # Legacy table name (without suffix, kept for migration)
+        self.legacy_table_name = base_table
+
+        logger.debug(
+            f"PostgreSQL table naming: "
+            f"new='{self.table_name}', "
+            f"legacy='{self.legacy_table_name}', "
+            f"model_suffix='{self.model_suffix}'"
+        )
+
+    @staticmethod
+    async def setup_table(
+        db: PostgreSQLDB,
+        table_name: str,
+        legacy_table_name: str = None,
+        base_table: str = None,
+        embedding_dim: int = None,
+    ):
+        """
+        Set up a PostgreSQL table, migrating data from a legacy table when needed.
+
+        This method mirrors Qdrant's setup_collection approach to maintain consistency.
+
+        Args:
+            db: PostgreSQLDB instance
+            table_name: Name of the new table
+            legacy_table_name: Name of the legacy table (if it exists)
+            base_table: Base table name for the DDL template lookup
+            embedding_dim: Embedding dimension for the vector column
+        """
+        new_table_exists = await _pg_table_exists(db, table_name)
+        legacy_exists = legacy_table_name and await _pg_table_exists(
+            db, legacy_table_name
+        )
+
+        # Case 1: Both new and legacy tables exist - warn only (no migration)
+        if new_table_exists and legacy_exists:
+            logger.warning(
+                f"PostgreSQL: Legacy table '{legacy_table_name}' still exists. "
+                f"Remove it if migration is complete."
+            )
+            return
+
+        # Case 2: Only the new table exists - already migrated or newly created
+        if new_table_exists:
+            logger.debug(f"PostgreSQL: Table '{table_name}' already exists")
+            return
+
+        # Case 3: Neither exists - create the new table
+        if not legacy_exists:
+            logger.info(f"PostgreSQL: Creating new table '{table_name}'")
+            await _pg_create_table(db, table_name, base_table, embedding_dim)
+            logger.info(f"PostgreSQL: Table '{table_name}' created successfully")
+            return
+
+        # Case 4: Only the legacy table exists - migrate its data
+        logger.info(
+            f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}'"
+        )
+
+        try:
+            # Get the legacy table row count
+            count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
+            count_result = await db.query(count_query, [])
+            legacy_count = count_result.get("count", 0) if count_result else 0
+            logger.info(f"PostgreSQL: Found {legacy_count} records in legacy table")
+
+            if legacy_count == 0:
+                logger.info("PostgreSQL: Legacy table is empty, skipping migration")
+                await _pg_create_table(db, table_name, base_table, embedding_dim)
+                return
+
+            # Create the new table first
+            logger.info(f"PostgreSQL: Creating new table '{table_name}'")
+            await _pg_create_table(db, table_name, base_table, embedding_dim)
+
+            # Batch migration (500 records per batch, same as Qdrant)
+            migrated_count = 0
+            offset = 0
+            batch_size = 500  # Mirror Qdrant's batch size
+
+            while True:
+                # Fetch one batch of rows
+                select_query = (
+                    f"SELECT * FROM {legacy_table_name} OFFSET $1 LIMIT $2"
+                )
+                rows = await db.fetch(select_query, [offset, batch_size])
+
+                if not rows:
+                    break
+
+                # Insert the batch into the new table
+                for row in rows:
+                    # Get column names and values
+                    columns = list(row.keys())
+                    values = list(row.values())
+
+                    # Build the insert query
+                    placeholders = ", ".join([f"${i+1}" for i in range(len(columns))])
+                    columns_str = ", ".join(columns)
+                    insert_query = f"""
+                        INSERT INTO {table_name} ({columns_str})
+                        VALUES ({placeholders})
+                        ON CONFLICT DO NOTHING
+                    """
+
+                    await db.execute(insert_query, values)
+
+                migrated_count += len(rows)
+                logger.info(
+                    f"PostgreSQL: {migrated_count}/{legacy_count} records migrated"
+                )
+
+                offset += batch_size
+
+            # Verify the migration by comparing row counts
+            logger.info("Verifying migration...")
+            new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
+            new_count_result = await db.query(new_count_query, [])
+            new_count = new_count_result.get("count", 0) if new_count_result else 0
+
+            if new_count != legacy_count:
+                error_msg = (
+                    f"PostgreSQL: Migration verification failed, "
+                    f"expected {legacy_count} records, got {new_count} in new table"
+                )
+                logger.error(error_msg)
+                raise PostgreSQLMigrationError(error_msg)
+
+            logger.info(
+                f"PostgreSQL: Migration from '{legacy_table_name}' to "
+                f"'{table_name}' completed successfully: {migrated_count} records migrated"
+            )
+
+        except PostgreSQLMigrationError:
+            # Re-raise migration errors without wrapping
+            raise
+        except Exception as e:
+            error_msg = f"PostgreSQL: Migration failed with error: {e}"
+            logger.error(error_msg)
+            # Mirror Qdrant behavior: no automatic rollback, so a partially
+            # migrated table can be completed by re-running the migration
+            raise PostgreSQLMigrationError(error_msg) from e
+
     async def initialize(self):
         async with get_data_init_lock():
             if self.db is None:
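As a usage illustration, a minimal sketch of calling the method above directly, mirroring how initialize() invokes it later in this diff (the table name and dimension here are hypothetical):

# Hypothetical standalone call, assuming an already-connected PostgreSQLDB `db`:
# creates LIGHTRAG_VDB_CHUNKS_bge_m3_1024d if nothing exists, reuses it if it
# already exists, or batch-migrates rows out of a legacy LIGHTRAG_VDB_CHUNKS.
await PGVectorStorage.setup_table(
    db,
    "LIGHTRAG_VDB_CHUNKS_bge_m3_1024d",
    legacy_table_name="LIGHTRAG_VDB_CHUNKS",
    base_table="LIGHTRAG_VDB_CHUNKS",  # key for the DDL template lookup
    embedding_dim=1024,
)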
@@ -2206,6 +2395,15 @@ class PGVectorStorage(BaseVectorStorage):
                 # Use "default" for compatibility (lowest priority)
                 self.workspace = "default"
 
+            # Setup table (create if not exists and handle migration)
+            await PGVectorStorage.setup_table(
+                self.db,
+                self.table_name,
+                legacy_table_name=self.legacy_table_name,
+                base_table=self.legacy_table_name,  # base_table for DDL template lookup
+                embedding_dim=self.embedding_func.embedding_dim,
+            )
+
     async def finalize(self):
         if self.db is not None:
             await ClientManager.release_client(self.db)
@@ -2215,7 +2413,9 @@ class PGVectorStorage(BaseVectorStorage):
         self, item: dict[str, Any], current_time: datetime.datetime
     ) -> tuple[str, dict[str, Any]]:
         try:
-            upsert_sql = SQL_TEMPLATES["upsert_chunk"]
+            upsert_sql = SQL_TEMPLATES["upsert_chunk"].format(
+                table_name=self.table_name
+            )
             data: dict[str, Any] = {
                 "workspace": self.workspace,
                 "id": item["__id__"],
@@ -2239,7 +2439,7 @@ class PGVectorStorage(BaseVectorStorage):
     def _upsert_entities(
         self, item: dict[str, Any], current_time: datetime.datetime
     ) -> tuple[str, dict[str, Any]]:
-        upsert_sql = SQL_TEMPLATES["upsert_entity"]
+        upsert_sql = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name)
         source_id = item["source_id"]
         if isinstance(source_id, str) and "<SEP>" in source_id:
             chunk_ids = source_id.split("<SEP>")
@@ -2262,7 +2462,9 @@ class PGVectorStorage(BaseVectorStorage):
     def _upsert_relationships(
         self, item: dict[str, Any], current_time: datetime.datetime
     ) -> tuple[str, dict[str, Any]]:
-        upsert_sql = SQL_TEMPLATES["upsert_relationship"]
+        upsert_sql = SQL_TEMPLATES["upsert_relationship"].format(
+            table_name=self.table_name
+        )
         source_id = item["source_id"]
         if isinstance(source_id, str) and "<SEP>" in source_id:
             chunk_ids = source_id.split("<SEP>")
@@ -2335,7 +2537,9 @@ class PGVectorStorage(BaseVectorStorage):
 
         embedding_string = ",".join(map(str, embedding))
 
-        sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
+        sql = SQL_TEMPLATES[self.namespace].format(
+            embedding_string=embedding_string, table_name=self.table_name
+        )
         params = {
             "workspace": self.workspace,
             "closer_than_threshold": 1 - self.cosine_better_than_threshold,
@@ -2357,14 +2561,7 @@ class PGVectorStorage(BaseVectorStorage):
         if not ids:
             return
 
-        table_name = namespace_to_table_name(self.namespace)
-        if not table_name:
-            logger.error(
-                f"[{self.workspace}] Unknown namespace for vector deletion: {self.namespace}"
-            )
-            return
-
-        delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
+        delete_sql = f"DELETE FROM {self.table_name} WHERE workspace=$1 AND id = ANY($2)"
 
         try:
             await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
@@ -2383,8 +2580,8 @@ class PGVectorStorage(BaseVectorStorage):
             entity_name: The name of the entity to delete
         """
         try:
-            # Construct SQL to delete the entity
-            delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
+            # Construct SQL to delete the entity using dynamic table name
+            delete_sql = f"""DELETE FROM {self.table_name}
                             WHERE workspace=$1 AND entity_name=$2"""
 
             await self.db.execute(
@@ -2404,7 +2601,7 @@ class PGVectorStorage(BaseVectorStorage):
         """
         try:
             # Delete relations where the entity is either the source or target
-            delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
+            delete_sql = f"""DELETE FROM {self.table_name}
                             WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
 
             await self.db.execute(
@@ -3188,6 +3385,11 @@ class PGDocStatusStorage(DocStatusStorage):
             return {"status": "error", "message": str(e)}
 
 
+class PostgreSQLMigrationError(Exception):
+    """Exception for PostgreSQL table migration errors."""
+    pass
+
+
 class PGGraphQueryException(Exception):
     """Exception for the AGE queries."""
 
@@ -5047,7 +5249,7 @@ SQL_TEMPLATES = {
         update_time = EXCLUDED.update_time
     """,
     # SQL for VectorStorage
-    "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
+    "upsert_chunk": """INSERT INTO {table_name} (workspace, id, tokens,
     chunk_order_index, full_doc_id, content, content_vector, file_path,
     create_time, update_time)
     VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
@@ -5060,7 +5262,7 @@ SQL_TEMPLATES = {
         file_path=EXCLUDED.file_path,
         update_time = EXCLUDED.update_time
     """,
-    "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
+    "upsert_entity": """INSERT INTO {table_name} (workspace, id, entity_name, content,
     content_vector, chunk_ids, file_path, create_time, update_time)
     VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
     ON CONFLICT (workspace,id) DO UPDATE
@@ -5071,7 +5273,7 @@ SQL_TEMPLATES = {
         file_path=EXCLUDED.file_path,
         update_time=EXCLUDED.update_time
     """,
-    "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
+    "upsert_relationship": """INSERT INTO {table_name} (workspace, id, source_id,
     target_id, content, content_vector, chunk_ids, file_path, create_time, update_time)
     VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10)
     ON CONFLICT (workspace,id) DO UPDATE
@@ -5087,7 +5289,7 @@ SQL_TEMPLATES = {
     SELECT r.source_id AS src_id,
            r.target_id AS tgt_id,
            EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
-    FROM LIGHTRAG_VDB_RELATION r
+    FROM {table_name} r
     WHERE r.workspace = $1
     AND r.content_vector <=> '[{embedding_string}]'::vector < $2
     ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
@@ -5096,7 +5298,7 @@ SQL_TEMPLATES = {
     "entities": """
     SELECT e.entity_name,
            EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
-    FROM LIGHTRAG_VDB_ENTITY e
+    FROM {table_name} e
     WHERE e.workspace = $1
     AND e.content_vector <=> '[{embedding_string}]'::vector < $2
     ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
@@ -5107,7 +5309,7 @@ SQL_TEMPLATES = {
            c.content,
            c.file_path,
            EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
-    FROM LIGHTRAG_VDB_CHUNKS c
+    FROM {table_name} c
     WHERE c.workspace = $1
     AND c.content_vector <=> '[{embedding_string}]'::vector < $2
     ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
tests/test_postgres_migration.py (new file, +366)
@@ -0,0 +1,366 @@
+import os
+import pytest
+from unittest.mock import MagicMock, patch, AsyncMock, call
+import numpy as np
+from lightrag.utils import EmbeddingFunc
+from lightrag.kg.postgres_impl import (
+    PGVectorStorage,
+    _pg_table_exists,
+    _pg_create_table,
+    PostgreSQLMigrationError,
+)
+from lightrag.namespace import NameSpace
+
+
+# Mock PostgreSQLDB
+@pytest.fixture
+def mock_pg_db():
+    """Mock PostgreSQL database connection"""
+    db = AsyncMock()
+    db.workspace = "test_workspace"
+
+    # Mock query responses
+    db.query = AsyncMock(return_value={"exists": False, "count": 0})
+    db.execute = AsyncMock()
+    db.fetch = AsyncMock(return_value=[])
+
+    return db
+
+
+# Mock get_data_init_lock to avoid async lock issues in tests
+@pytest.fixture(autouse=True)
+def mock_data_init_lock():
+    with patch("lightrag.kg.postgres_impl.get_data_init_lock") as mock_lock:
+        mock_lock_ctx = AsyncMock()
+        mock_lock.return_value = mock_lock_ctx
+        yield mock_lock
+
+
+# Mock ClientManager
+@pytest.fixture
+def mock_client_manager(mock_pg_db):
+    with patch("lightrag.kg.postgres_impl.ClientManager") as mock_manager:
+        mock_manager.get_client = AsyncMock(return_value=mock_pg_db)
+        mock_manager.release_client = AsyncMock()
+        yield mock_manager
+
+
+# Mock embedding function
+@pytest.fixture
+def mock_embedding_func():
+    async def embed_func(texts, **kwargs):
+        return np.array([[0.1] * 768 for _ in texts])
+
+    func = EmbeddingFunc(
+        embedding_dim=768,
+        func=embed_func,
+        model_name="test_model"
+    )
+    return func
+
+
+@pytest.mark.asyncio
+async def test_postgres_table_naming(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """Test that the table name is generated with the model suffix"""
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    storage = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=mock_embedding_func,
+        workspace="test_ws"
+    )
+
+    # Verify the table name contains the model suffix
+    expected_suffix = "test_model_768d"
+    assert expected_suffix in storage.table_name
+    assert storage.table_name == f"LIGHTRAG_VDB_CHUNKS_{expected_suffix}"
+
+    # Verify the legacy table name
+    assert storage.legacy_table_name == "LIGHTRAG_VDB_CHUNKS"
+
+
+@pytest.mark.asyncio
+async def test_postgres_migration_trigger(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """Test that the migration logic is triggered correctly"""
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    storage = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=mock_embedding_func,
+        workspace="test_ws"
+    )
+
+    # Set up mocks for the migration scenario
+    # 1. The new table does not exist, the legacy table does
+    async def mock_table_exists(db, table_name):
+        return table_name == storage.legacy_table_name
+
+    # 2. The legacy table has 100 records
+    async def mock_query(sql, params):
+        if "COUNT(*)" in sql:
+            return {"count": 100}
+        return {}
+
+    # 3. Mock fetch for the batch migration
+    mock_rows = [
+        {"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"}
+        for i in range(100)
+    ]
+
+    async def mock_fetch(sql, params):
+        offset = params[0] if params else 0
+        limit = params[1] if len(params) > 1 else 500
+        start = offset
+        end = min(offset + limit, len(mock_rows))
+        return mock_rows[start:end]
+
+    mock_pg_db.query = AsyncMock(side_effect=mock_query)
+    mock_pg_db.fetch = AsyncMock(side_effect=mock_fetch)
+
+    with patch("lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists), \
+         patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()):
+
+        # Initialize storage (should trigger the migration)
+        await storage.initialize()
+
+        # Verify the migration was executed:
+        # execute must have been called to insert rows
+        assert mock_pg_db.execute.call_count > 0
+
+
+@pytest.mark.asyncio
+async def test_postgres_no_migration_needed(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """Test the scenario where the new table already exists (no migration needed)"""
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    storage = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=mock_embedding_func,
+        workspace="test_ws"
+    )
+
+    # Mock: the new table already exists
+    async def mock_table_exists(db, table_name):
+        return table_name == storage.table_name
+
+    with patch("lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists), \
+         patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create:
+
+        await storage.initialize()
+
+        # Verify no table creation was attempted
+        mock_create.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_scenario_1_new_workspace_creation(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """
+    Scenario 1: New workspace creation
+
+    Expected behavior:
+    - No legacy table exists
+    - Directly create the new table with the model suffix
+    - No migration needed
+    """
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    embedding_func = EmbeddingFunc(
+        embedding_dim=3072,
+        func=mock_embedding_func.func,
+        model_name="text-embedding-3-large"
+    )
+
+    storage = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=embedding_func,
+        workspace="new_workspace"
+    )
+
+    # Mock: neither table exists
+    async def mock_table_exists(db, table_name):
+        return False
+
+    with patch("lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists), \
+         patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create:
+
+        await storage.initialize()
+
+        # Verify the table name format
+        assert "text_embedding_3_large_3072d" in storage.table_name
+
+        # Verify the new table creation was called
+        mock_create.assert_called_once()
+        call_args = mock_create.call_args
+        assert call_args[0][1] == storage.table_name  # table_name is the second positional arg
+
+
+@pytest.mark.asyncio
+async def test_scenario_2_legacy_upgrade_migration(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """
+    Scenario 2: Upgrade from a legacy version
+
+    Expected behavior:
+    - The legacy table exists (without a model suffix)
+    - The new table does not exist
+    - Data is automatically migrated to the new suffixed table
+    """
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    embedding_func = EmbeddingFunc(
+        embedding_dim=1536,
+        func=mock_embedding_func.func,
+        model_name="text-embedding-ada-002"
+    )
+
+    storage = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=embedding_func,
+        workspace="legacy_workspace"
+    )
+
+    # Mock: only the legacy table exists
+    async def mock_table_exists(db, table_name):
+        return table_name == storage.legacy_table_name
+
+    # Mock: the legacy table has 50 records
+    # (the same count serves the legacy read and the post-migration verification)
+    async def mock_query(sql, params):
+        if "COUNT(*)" in sql:
+            return {"count": 50}
+        return {}
+
+    # Mock fetch for the migration
+    mock_rows = [
+        {"id": f"legacy_id_{i}", "content": f"legacy_content_{i}", "workspace": "legacy_workspace"}
+        for i in range(50)
+    ]
+
+    async def mock_fetch(sql, params):
+        offset = params[0] if params else 0
+        limit = params[1] if len(params) > 1 else 500
+        start = offset
+        end = min(offset + limit, len(mock_rows))
+        return mock_rows[start:end]
+
+    mock_pg_db.query = AsyncMock(side_effect=mock_query)
+    mock_pg_db.fetch = AsyncMock(side_effect=mock_fetch)
+
+    with patch("lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists), \
+         patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create:
+
+        await storage.initialize()
+
+        # Verify the table name contains ada-002
+        assert "text_embedding_ada_002_1536d" in storage.table_name
+
+        # Verify the migration was executed
+        assert mock_pg_db.execute.call_count >= 50  # at least one execute per row
+        mock_create.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_scenario_3_multi_model_coexistence(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """
+    Scenario 3: Multiple embedding models coexist
+
+    Expected behavior:
+    - Different embedding models create separate tables
+    - Tables are isolated by the model suffix
+    - No interference between different models
+    """
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    # Workspace A: uses bge-small (768d)
+    embedding_func_a = EmbeddingFunc(
+        embedding_dim=768,
+        func=mock_embedding_func.func,
+        model_name="bge-small"
+    )
+
+    storage_a = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=embedding_func_a,
+        workspace="workspace_a"
+    )
+
+    # Workspace B: uses bge-large (1024d)
+    async def embed_func_b(texts, **kwargs):
+        return np.array([[0.1] * 1024 for _ in texts])
+
+    embedding_func_b = EmbeddingFunc(
+        embedding_dim=1024,
+        func=embed_func_b,
+        model_name="bge-large"
+    )
+
+    storage_b = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=embedding_func_b,
+        workspace="workspace_b"
+    )
+
+    # Verify the table names differ
+    assert storage_a.table_name != storage_b.table_name
+    assert "bge_small_768d" in storage_a.table_name
+    assert "bge_large_1024d" in storage_b.table_name
+
+    # Mock: neither table exists yet
+    async def mock_table_exists(db, table_name):
+        return False
+
+    with patch("lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists), \
+         patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create:
+
+        # Initialize both storages
+        await storage_a.initialize()
+        await storage_b.initialize()
+
+        # Verify two separate tables were created
+        assert mock_create.call_count == 2
+
+        # Verify the table names are different
+        call_args_list = mock_create.call_args_list
+        table_names = [call[0][1] for call in call_args_list]  # second positional arg
+        assert len(set(table_names)) == 2  # two unique table names
+        assert storage_a.table_name in table_names
+        assert storage_b.table_name in table_names
@@ -161,3 +161,208 @@ async def test_qdrant_no_migration_needed(mock_qdrant_client, mock_embedding_func):
     # In the Qdrant implementation, Case 2 calls get_collection
     mock_qdrant_client.get_collection.assert_called_with(storage.final_namespace)
     mock_qdrant_client.scroll.assert_not_called()
+
+
+# ============================================================================
+# Tests for scenarios described in the design document (lines 606-649)
+# ============================================================================
+
+@pytest.mark.asyncio
+async def test_scenario_1_new_workspace_creation(mock_qdrant_client, mock_embedding_func):
+    """
+    Scenario 1: New workspace creation.
+    Expected: directly create lightrag_vdb_chunks_text_embedding_3_large_3072d.
+    """
+    # Use a large embedding model
+    large_model_func = EmbeddingFunc(
+        embedding_dim=3072,
+        func=mock_embedding_func.func,
+        model_name="text-embedding-3-large"
+    )
+
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    storage = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=large_model_func,
+        workspace="test_new"
+    )
+
+    # Case 3: Neither the legacy nor the new collection exists
+    mock_qdrant_client.collection_exists.return_value = False
+
+    # Initialize storage
+    await storage.initialize()
+
+    # Verify: should create a new collection with the model suffix
+    expected_collection = "lightrag_vdb_chunks_text_embedding_3_large_3072d"
+    assert storage.final_namespace == expected_collection
+
+    # Verify create_collection was called with the correct name
+    create_calls = [call for call in mock_qdrant_client.create_collection.call_args_list]
+    assert len(create_calls) > 0
+    assert create_calls[0][0][0] == expected_collection or create_calls[0].kwargs.get('collection_name') == expected_collection
+
+    # Verify no migration was attempted
+    mock_qdrant_client.scroll.assert_not_called()
+
+    print(f"✅ Scenario 1: New workspace created with collection '{expected_collection}'")
+
+
+@pytest.mark.asyncio
+async def test_scenario_2_legacy_upgrade_migration(mock_qdrant_client, mock_embedding_func):
+    """
+    Scenario 2: Upgrade from a legacy version.
+    lightrag_vdb_chunks (no suffix) already exists.
+    Expected: data is automatically migrated to lightrag_vdb_chunks_text_embedding_ada_002_1536d.
+    """
+    # Use the ada-002 model
+    ada_func = EmbeddingFunc(
+        embedding_dim=1536,
+        func=mock_embedding_func.func,
+        model_name="text-embedding-ada-002"
+    )
+
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    storage = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=ada_func,
+        workspace="test_legacy"
+    )
+
+    legacy_collection = storage.legacy_namespace
+    new_collection = storage.final_namespace
+
+    # Case 4: Only the legacy collection exists
+    mock_qdrant_client.collection_exists.side_effect = lambda name: name == legacy_collection
+
+    # Mock legacy data
+    mock_qdrant_client.count.return_value.count = 150
+
+    # Mock scroll results (simulate migration in batches)
+    mock_points = []
+    for i in range(10):
+        point = MagicMock()
+        point.id = f"legacy-{i}"
+        point.vector = [0.1] * 1536
+        point.payload = {"content": f"Legacy document {i}", "id": f"doc-{i}"}
+        mock_points.append(point)
+
+    # The first batch returns points, the second returns empty
+    mock_qdrant_client.scroll.side_effect = [
+        (mock_points, "offset1"),
+        ([], None)
+    ]
+
+    # Initialize (triggers migration)
+    await storage.initialize()
+
+    # Verify: the new collection should be created
+    expected_new_collection = "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
+    assert storage.final_namespace == expected_new_collection
+
+    # Verify the migration steps
+    # 1. Check the legacy count
+    mock_qdrant_client.count.assert_any_call(
+        collection_name=legacy_collection,
+        exact=True
+    )
+
+    # 2. Create the new collection
+    mock_qdrant_client.create_collection.assert_called()
+
+    # 3. Scroll the legacy data
+    scroll_calls = [call for call in mock_qdrant_client.scroll.call_args_list]
+    assert len(scroll_calls) >= 1
+    assert scroll_calls[0].kwargs['collection_name'] == legacy_collection
+
+    # 4. Upsert into the new collection
+    upsert_calls = [call for call in mock_qdrant_client.upsert.call_args_list]
+    assert len(upsert_calls) >= 1
+    assert upsert_calls[0].kwargs['collection_name'] == new_collection
+
+    print(f"✅ Scenario 2: Legacy data migrated from '{legacy_collection}' to '{expected_new_collection}'")
+
+
+@pytest.mark.asyncio
+async def test_scenario_3_multi_model_coexistence(mock_qdrant_client):
+    """
+    Scenario 3: Multiple embedding models coexist.
+    Expected: two independent collections that do not interfere with each other.
+    """
+    # Model A: bge-small with 768d
+    async def embed_func_a(texts, **kwargs):
+        return np.array([[0.1] * 768 for _ in texts])
+
+    model_a_func = EmbeddingFunc(
+        embedding_dim=768,
+        func=embed_func_a,
+        model_name="bge-small"
+    )
+
+    # Model B: bge-large with 1024d
+    async def embed_func_b(texts, **kwargs):
+        return np.array([[0.2] * 1024 for _ in texts])
+
+    model_b_func = EmbeddingFunc(
+        embedding_dim=1024,
+        func=embed_func_b,
+        model_name="bge-large"
+    )
+
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    # Create storage for workspace A with model A
+    storage_a = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=model_a_func,
+        workspace="workspace_a"
+    )
+
+    # Create storage for workspace B with model B
+    storage_b = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=model_b_func,
+        workspace="workspace_b"
+    )
+
+    # Verify: the collection names differ
+    assert storage_a.final_namespace != storage_b.final_namespace
+
+    # Verify: model A collection
+    expected_collection_a = "lightrag_vdb_chunks_bge_small_768d"
+    assert storage_a.final_namespace == expected_collection_a
+
+    # Verify: model B collection
+    expected_collection_b = "lightrag_vdb_chunks_bge_large_1024d"
+    assert storage_b.final_namespace == expected_collection_b
+
+    # Verify: the different embedding dimensions are preserved
+    assert storage_a.embedding_func.embedding_dim == 768
+    assert storage_b.embedding_func.embedding_dim == 1024
+
+    print("✅ Scenario 3: Multi-model coexistence verified")
+    print(f"   - Workspace A: {expected_collection_a} (768d)")
+    print(f"   - Workspace B: {expected_collection_b} (1024d)")
+    print("   - Collections are independent")
uv.lock (generated, 6 changed lines)
@@ -1,5 +1,5 @@
 version = 1
-revision = 3
+revision = 2
 requires-python = ">=3.10"
 resolution-markers = [
     "python_full_version >= '3.14' and python_full_version < '4' and platform_machine == 'x86_64' and sys_platform == 'darwin'",
@@ -2735,7 +2735,6 @@ requires-dist = [
    { name = "json-repair", marker = "extra == 'api'" },
    { name = "langfuse", marker = "extra == 'observability'", specifier = ">=3.8.1" },
    { name = "lightrag-hku", extras = ["api", "offline-llm", "offline-storage"], marker = "extra == 'offline'" },
    { name = "lightrag-hku", extras = ["pytest"], marker = "extra == 'evaluation'" },
    { name = "llama-index", marker = "extra == 'offline-llm'", specifier = ">=0.9.0,<1.0.0" },
    { name = "nano-vectordb" },
    { name = "nano-vectordb", marker = "extra == 'api'" },
@@ -2753,6 +2752,7 @@ requires-dist = [
    { name = "passlib", extras = ["bcrypt"], marker = "extra == 'api'" },
    { name = "pipmaster" },
    { name = "pipmaster", marker = "extra == 'api'" },
    { name = "pre-commit", marker = "extra == 'evaluation'" },
    { name = "pre-commit", marker = "extra == 'pytest'" },
    { name = "psutil", marker = "extra == 'api'" },
    { name = "pycryptodome", marker = "extra == 'api'", specifier = ">=3.0.0,<4.0.0" },
@@ -2764,7 +2764,9 @@ requires-dist = [
    { name = "pypdf", marker = "extra == 'api'", specifier = ">=6.1.0" },
    { name = "pypinyin" },
    { name = "pypinyin", marker = "extra == 'api'" },
    { name = "pytest", marker = "extra == 'evaluation'", specifier = ">=8.4.2" },
    { name = "pytest", marker = "extra == 'pytest'", specifier = ">=8.4.2" },
    { name = "pytest-asyncio", marker = "extra == 'evaluation'", specifier = ">=1.2.0" },
    { name = "pytest-asyncio", marker = "extra == 'pytest'", specifier = ">=1.2.0" },
    { name = "python-docx", marker = "extra == 'api'", specifier = ">=0.8.11,<2.0.0" },
    { name = "python-dotenv" },