diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 1447a79e..7ad4ed1f 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -2175,6 +2175,38 @@ class PGKVStorage(BaseKVStorage):
             return {"status": "error", "message": str(e)}
 
 
+async def _pg_table_exists(db: PostgreSQLDB, table_name: str) -> bool:
+    """Check if a table exists in the PostgreSQL database."""
+    query = """
+        SELECT EXISTS (
+            SELECT FROM information_schema.tables
+            WHERE table_name = $1
+        )
+    """
+    result = await db.query(query, [table_name.lower()])
+    return result.get("exists", False) if result else False
+
+
+async def _pg_create_table(
+    db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int
+) -> None:
+    """Create a new vector table by replacing the table name in a DDL template."""
+    if base_table not in TABLES:
+        raise ValueError(f"No DDL template found for table: {base_table}")
+
+    ddl_template = TABLES[base_table]["ddl"]
+
+    # Replace the embedding dimension placeholder if present
+    ddl = ddl_template.replace(
+        f"VECTOR({os.environ.get('EMBEDDING_DIM', 1024)})", f"VECTOR({embedding_dim})"
+    )
+
+    # Replace the table name
+    ddl = ddl.replace(base_table, table_name)
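+
+    # Illustrative example (hypothetical values): with base_table="LIGHTRAG_VDB_CHUNKS",
+    # embedding_dim=768 and table_name="LIGHTRAG_VDB_CHUNKS_bge_small_768d", a template
+    # column "content_vector VECTOR(1024)" is rewritten to "content_vector VECTOR(768)"
+    # and the CREATE TABLE statement now targets the suffixed table.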
+
+    await db.execute(ddl)
+
+
 @final
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
@@ -2190,6 +2222,163 @@ class PGVectorStorage(BaseVectorStorage):
         )
         self.cosine_better_than_threshold = cosine_threshold
 
+        # Generate model suffix for table isolation
+        self.model_suffix = self._generate_collection_suffix()
+
+        # Get base table name
+        base_table = namespace_to_table_name(self.namespace)
+        if not base_table:
+            raise ValueError(f"Unknown namespace: {self.namespace}")
+
+        # New table name (with suffix)
+        self.table_name = f"{base_table}_{self.model_suffix}"
+
+        # Legacy table name (without suffix, used for migration)
+        self.legacy_table_name = base_table
+
+        logger.debug(
+            f"PostgreSQL table naming: "
+            f"new='{self.table_name}', "
+            f"legacy='{self.legacy_table_name}', "
+            f"model_suffix='{self.model_suffix}'"
+        )
+
+    @staticmethod
+    async def setup_table(
+        db: PostgreSQLDB,
+        table_name: str,
+        legacy_table_name: str | None = None,
+        base_table: str | None = None,
+        embedding_dim: int | None = None,
+    ):
+        """
+        Set up a PostgreSQL table, migrating data from a legacy table if needed.
+
+        This method mirrors Qdrant's setup_collection approach to maintain consistency.
+
+        Args:
+            db: PostgreSQLDB instance
+            table_name: Name of the new table
+            legacy_table_name: Name of the legacy table (if it exists)
+            base_table: Base table name for DDL template lookup
+            embedding_dim: Embedding dimension for the vector column
+        """
+        new_table_exists = await _pg_table_exists(db, table_name)
+        legacy_exists = legacy_table_name and await _pg_table_exists(
+            db, legacy_table_name
+        )
+
+        # Case 1: Both new and legacy tables exist - warning only (no migration)
+        if new_table_exists and legacy_exists:
+            logger.warning(
+                f"PostgreSQL: Legacy table '{legacy_table_name}' still exists. "
+                f"Remove it if migration is complete."
+            )
+            return
+
+        # Case 2: Only the new table exists - already migrated or newly created
+        if new_table_exists:
+            logger.debug(f"PostgreSQL: Table '{table_name}' already exists")
+            return
+
+        # Case 3: Neither exists - create the new table
+        if not legacy_exists:
+            logger.info(f"PostgreSQL: Creating new table '{table_name}'")
+            await _pg_create_table(db, table_name, base_table, embedding_dim)
+            logger.info(f"PostgreSQL: Table '{table_name}' created successfully")
+            return
+
+        # Case 4: Only the legacy table exists - migrate its data
+        logger.info(
+            f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}'"
+        )
+
+        try:
+            # Get the legacy table count
+            count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}"
+            count_result = await db.query(count_query, [])
+            legacy_count = count_result.get("count", 0) if count_result else 0
+            logger.info(f"PostgreSQL: Found {legacy_count} records in legacy table")
+
+            if legacy_count == 0:
+                logger.info("PostgreSQL: Legacy table is empty, skipping migration")
+                await _pg_create_table(db, table_name, base_table, embedding_dim)
+                return
+
+            # Create the new table first
+            logger.info(f"PostgreSQL: Creating new table '{table_name}'")
+            await _pg_create_table(db, table_name, base_table, embedding_dim)
+
+            # Batch migration (500 records per batch, same as Qdrant)
+            migrated_count = 0
+            offset = 0
+            batch_size = 500  # Mirror Qdrant batch size
+
+            while True:
+                # Fetch a batch of rows; ORDER BY keeps OFFSET pagination
+                # deterministic so no row is skipped or copied twice
+                select_query = (
+                    f"SELECT * FROM {legacy_table_name} ORDER BY id OFFSET $1 LIMIT $2"
+                )
+                rows = await db.fetch(select_query, [offset, batch_size])
+
+                if not rows:
+                    break
+
+                # Insert the batch into the new table
+                for row in rows:
+                    # Get column names and values
+                    columns = list(row.keys())
+                    values = list(row.values())
+
+                    # Build the insert query
+                    placeholders = ", ".join([f"${i+1}" for i in range(len(columns))])
+                    columns_str = ", ".join(columns)
+                    insert_query = f"""
+                        INSERT INTO {table_name} ({columns_str})
+                        VALUES ({placeholders})
+                        ON CONFLICT DO NOTHING
+                    """
+
+                    await db.execute(insert_query, values)
+
+                migrated_count += len(rows)
+                logger.info(
+                    f"PostgreSQL: {migrated_count}/{legacy_count} records migrated"
+                )
+
+                offset += batch_size
+
+            # Verify the migration by comparing counts
+            logger.info("PostgreSQL: Verifying migration...")
+            new_count_query = f"SELECT COUNT(*) as count FROM {table_name}"
+            new_count_result = await db.query(new_count_query, [])
+            new_count = new_count_result.get("count", 0) if new_count_result else 0
+
+            if new_count != legacy_count:
+                error_msg = (
+                    f"PostgreSQL: Migration verification failed, "
+                    f"expected {legacy_count} records, got {new_count} in new table"
+                )
+                logger.error(error_msg)
+                raise PostgreSQLMigrationError(error_msg)
+
+            logger.info(
+                f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' "
+                f"completed successfully: {migrated_count} records migrated"
+            )
+
+        except PostgreSQLMigrationError:
+            # Re-raise migration errors without wrapping
+            raise
+        except Exception as e:
+            error_msg = f"PostgreSQL: Migration failed with error: {e}"
+            logger.error(error_msg)
+            # Mirror Qdrant behavior: no automatic rollback. A failed run can leave a
+            # partially populated new table; since Case 1 only warns when both tables
+            # exist, drop the new table before re-running to retry the migration.
+            raise PostgreSQLMigrationError(error_msg) from e
+
     async def initialize(self):
         async with get_data_init_lock():
             if self.db is None:
@@ -2206,6 +2395,15 @@ class PGVectorStorage(BaseVectorStorage):
                 # Use "default" for compatibility (lowest priority)
                 self.workspace = "default"
 
+            # Set up the table (create if missing and handle migration)
+            await PGVectorStorage.setup_table(
+                self.db,
+                self.table_name,
+                legacy_table_name=self.legacy_table_name,
+                base_table=self.legacy_table_name,  # base_table for DDL template lookup
+                embedding_dim=self.embedding_func.embedding_dim,
+            )
+
     async def finalize(self):
         if self.db is not None:
             await ClientManager.release_client(self.db)
@@ -2215,7 +2413,9 @@ class PGVectorStorage(BaseVectorStorage):
         self, item: dict[str, Any], current_time: datetime.datetime
     ) -> tuple[str, dict[str, Any]]:
         try:
-            upsert_sql = SQL_TEMPLATES["upsert_chunk"]
+            upsert_sql = SQL_TEMPLATES["upsert_chunk"].format(
+                table_name=self.table_name
+            )
             data: dict[str, Any] = {
                 "workspace": self.workspace,
                 "id": item["__id__"],
@@ -2239,7 +2439,7 @@ class PGVectorStorage(BaseVectorStorage):
     def _upsert_entities(
         self, item: dict[str, Any], current_time: datetime.datetime
     ) -> tuple[str, dict[str, Any]]:
-        upsert_sql = SQL_TEMPLATES["upsert_entity"]
+        upsert_sql = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name)
         source_id = item["source_id"]
         if isinstance(source_id, str) and "<SEP>" in source_id:
             chunk_ids = source_id.split("<SEP>")
@@ -2262,7 +2462,9 @@ class PGVectorStorage(BaseVectorStorage):
     def _upsert_relationships(
        self, item: dict[str, Any], current_time: datetime.datetime
    ) -> tuple[str, dict[str, Any]]:
-        upsert_sql = SQL_TEMPLATES["upsert_relationship"]
+        upsert_sql = SQL_TEMPLATES["upsert_relationship"].format(
+            table_name=self.table_name
+        )
         source_id = item["source_id"]
         if isinstance(source_id, str) and "<SEP>" in source_id:
             chunk_ids = source_id.split("<SEP>")
@@ -2335,7 +2537,9 @@ class PGVectorStorage(BaseVectorStorage):
 
         embedding_string = ",".join(map(str, embedding))
 
-        sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
+        sql = SQL_TEMPLATES[self.namespace].format(
+            embedding_string=embedding_string, table_name=self.table_name
+        )
         params = {
             "workspace": self.workspace,
             "closer_than_threshold": 1 - self.cosine_better_than_threshold,
@@ -2357,14 +2561,7 @@ class PGVectorStorage(BaseVectorStorage):
         if not ids:
             return
 
-        table_name = namespace_to_table_name(self.namespace)
-        if not table_name:
-            logger.error(
-                f"[{self.workspace}] Unknown namespace for vector deletion: {self.namespace}"
-            )
-            return
-
-        delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
+        delete_sql = f"DELETE FROM {self.table_name} WHERE workspace=$1 AND id = ANY($2)"
 
         try:
             await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids})
@@ -2383,8 +2580,8 @@ class PGVectorStorage(BaseVectorStorage):
             entity_name: The name of the entity to delete
         """
         try:
-            # Construct SQL to delete the entity
-            delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY
+            # Construct SQL to delete the entity using the dynamic table name
+            delete_sql = f"""DELETE FROM {self.table_name}
                             WHERE workspace=$1 AND entity_name=$2"""
 
             await self.db.execute(
@@ -2404,7 +2601,7 @@ class PGVectorStorage(BaseVectorStorage):
         """
         try:
             # Delete relations where the entity is either the source or target
-            delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION
+            delete_sql = f"""DELETE FROM {self.table_name}
                             WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)"""
 
             await self.db.execute(
@@ -3188,6 +3385,11 @@ class PGDocStatusStorage(DocStatusStorage):
             return {"status": "error", "message": str(e)}
 
 
+class PostgreSQLMigrationError(Exception):
+    """Exception for PostgreSQL table migration errors."""
+
+
 class PGGraphQueryException(Exception):
     """Exception for the AGE queries."""
 
@@ -5047,7 +5249,7 @@ SQL_TEMPLATES = {
         update_time = EXCLUDED.update_time
     """,
     # SQL for VectorStorage
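+    # NOTE: the {table_name} placeholder in the templates below is filled in at
+    # call time via str.format() with the storage's model-suffixed table name.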
-    "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens,
+    "upsert_chunk": """INSERT INTO {table_name} (workspace, id, tokens,
       chunk_order_index, full_doc_id, content, content_vector, file_path,
       create_time, update_time)
       VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
@@ -5060,7 +5262,7 @@ SQL_TEMPLATES = {
       file_path=EXCLUDED.file_path,
       update_time = EXCLUDED.update_time
     """,
-    "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
+    "upsert_entity": """INSERT INTO {table_name} (workspace, id, entity_name, content,
       content_vector, chunk_ids, file_path, create_time, update_time)
       VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9)
       ON CONFLICT (workspace,id) DO UPDATE
@@ -5071,7 +5273,7 @@ SQL_TEMPLATES = {
       file_path=EXCLUDED.file_path,
       update_time=EXCLUDED.update_time
     """,
-    "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
+    "upsert_relationship": """INSERT INTO {table_name} (workspace, id, source_id,
       target_id, content, content_vector, chunk_ids, file_path, create_time, update_time)
       VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10)
       ON CONFLICT (workspace,id) DO UPDATE
@@ -5087,7 +5289,7 @@ SQL_TEMPLATES = {
         SELECT r.source_id AS src_id, r.target_id AS tgt_id,
                EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at
-        FROM LIGHTRAG_VDB_RELATION r
+        FROM {table_name} r
         WHERE r.workspace = $1
         AND r.content_vector <=> '[{embedding_string}]'::vector < $2
         ORDER BY r.content_vector <=> '[{embedding_string}]'::vector
@@ -5096,7 +5298,7 @@ SQL_TEMPLATES = {
     "entities": """
         SELECT e.entity_name,
               EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at
-        FROM LIGHTRAG_VDB_ENTITY e
+        FROM {table_name} e
         WHERE e.workspace = $1
         AND e.content_vector <=> '[{embedding_string}]'::vector < $2
         ORDER BY e.content_vector <=> '[{embedding_string}]'::vector
@@ -5107,7 +5309,7 @@ SQL_TEMPLATES = {
               c.content,
               c.file_path,
               EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at
-        FROM LIGHTRAG_VDB_CHUNKS c
+        FROM {table_name} c
         WHERE c.workspace = $1
         AND c.content_vector <=> '[{embedding_string}]'::vector < $2
         ORDER BY c.content_vector <=> '[{embedding_string}]'::vector
diff --git a/tests/test_postgres_migration.py b/tests/test_postgres_migration.py
new file mode 100644
index 00000000..2ca6c770
--- /dev/null
+++ b/tests/test_postgres_migration.py
@@ -0,0 +1,366 @@
+import pytest
+from unittest.mock import patch, AsyncMock
+import numpy as np
+from lightrag.utils import EmbeddingFunc
+from lightrag.kg.postgres_impl import (
+    PGVectorStorage,
+    _pg_table_exists,
+    _pg_create_table,
+    PostgreSQLMigrationError,
+)
+from lightrag.namespace import NameSpace
+
+
+# Mock PostgreSQLDB
+@pytest.fixture
+def mock_pg_db():
+    """Mock PostgreSQL database connection."""
+    db = AsyncMock()
+    db.workspace = "test_workspace"
+
+    # Mock query responses
+    db.query = AsyncMock(return_value={"exists": False, "count": 0})
+    db.execute = AsyncMock()
+    db.fetch = AsyncMock(return_value=[])
+
+    return db
+
+
+# Mock get_data_init_lock to avoid async lock issues in tests
+@pytest.fixture(autouse=True)
+def mock_data_init_lock():
+    with patch("lightrag.kg.postgres_impl.get_data_init_lock") as mock_lock:
+        mock_lock_ctx = AsyncMock()
+        mock_lock.return_value = mock_lock_ctx
+        yield mock_lock
+
+
+# Mock ClientManager
+@pytest.fixture
+def mock_client_manager(mock_pg_db):
+    with patch("lightrag.kg.postgres_impl.ClientManager") as mock_manager:
+        mock_manager.get_client = AsyncMock(return_value=mock_pg_db)
+        mock_manager.release_client = AsyncMock()
+        yield mock_manager
+
+
+# Mock embedding function
+@pytest.fixture
+def mock_embedding_func():
+    async def embed_func(texts, **kwargs):
+        return np.array([[0.1] * 768 for _ in texts])
+
+    func = EmbeddingFunc(
+        embedding_dim=768,
+        func=embed_func,
+        model_name="test_model"
+    )
+    return func
+
+
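+# NOTE: the tests below assume _generate_collection_suffix() yields
+# "<model_name>_<embedding_dim>d" with non-alphanumeric characters mapped to
+# underscores (e.g. "bge-small" + 768 -> "bge_small_768d"), matching the
+# Qdrant collection suffix behavior.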
+@pytest.mark.asyncio
+async def test_postgres_table_naming(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """Test that the table name is generated with the model suffix."""
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    storage = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=mock_embedding_func,
+        workspace="test_ws"
+    )
+
+    # Verify the table name contains the model suffix
+    expected_suffix = "test_model_768d"
+    assert expected_suffix in storage.table_name
+    assert storage.table_name == f"LIGHTRAG_VDB_CHUNKS_{expected_suffix}"
+
+    # Verify the legacy table name
+    assert storage.legacy_table_name == "LIGHTRAG_VDB_CHUNKS"
+
+
+@pytest.mark.asyncio
+async def test_postgres_migration_trigger(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """Test that the migration logic is triggered correctly."""
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    storage = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=mock_embedding_func,
+        workspace="test_ws"
+    )
+
+    # Set up mocks for the migration scenario
+    # 1. The new table does not exist, the legacy table does
+    async def mock_table_exists(db, table_name):
+        return table_name == storage.legacy_table_name
+
+    # 2. The legacy table has 100 records
+    async def mock_query(sql, params):
+        if "COUNT(*)" in sql:
+            return {"count": 100}
+        return {}
+
+    # 3. Mock fetch for batch migration
+    mock_rows = [
+        {"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"}
+        for i in range(100)
+    ]
+
+    async def mock_fetch(sql, params):
+        offset = params[0] if params else 0
+        limit = params[1] if len(params) > 1 else 500
+        start = offset
+        end = min(offset + limit, len(mock_rows))
+        return mock_rows[start:end]
+
+    mock_pg_db.query = AsyncMock(side_effect=mock_query)
+    mock_pg_db.fetch = AsyncMock(side_effect=mock_fetch)
+
+    with patch("lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists), \
+         patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()):
+
+        # Initialize storage (should trigger migration)
+        await storage.initialize()
+
+        # Verify the migration was executed:
+        # execute must have been called to insert rows
+        assert mock_pg_db.execute.call_count > 0
+
+
+@pytest.mark.asyncio
+async def test_postgres_no_migration_needed(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """Test the scenario where the new table already exists (no migration needed)."""
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    storage = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=mock_embedding_func,
+        workspace="test_ws"
+    )
+
+    # Mock: the new table already exists
+    async def mock_table_exists(db, table_name):
+        return table_name == storage.table_name
+
+    with patch("lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists), \
+         patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create:
+
+        await storage.initialize()
+
+        # Verify no table creation was attempted
+        mock_create.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_scenario_1_new_workspace_creation(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """
+    Scenario 1: New workspace creation.
+
+    Expected behavior:
+    - No legacy table exists
+    - Directly create the new table with the model suffix
+    - No migration needed
+    """
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    embedding_func = EmbeddingFunc(
+        embedding_dim=3072,
+        func=mock_embedding_func.func,
+        model_name="text-embedding-3-large"
+    )
+
+    storage = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=embedding_func,
+        workspace="new_workspace"
+    )
+
+    # Mock: neither table exists
+    async def mock_table_exists(db, table_name):
+        return False
+
+    with patch("lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists), \
+         patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create:
+
+        await storage.initialize()
+
+        # Verify the table name format
+        assert "text_embedding_3_large_3072d" in storage.table_name
+
+        # Verify new table creation was requested
+        mock_create.assert_called_once()
+        call_args = mock_create.call_args
+        assert call_args[0][1] == storage.table_name  # table_name is the second positional arg
+
+
+@pytest.mark.asyncio
+async def test_scenario_2_legacy_upgrade_migration(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """
+    Scenario 2: Upgrade from a legacy version.
+
+    Expected behavior:
+    - The legacy table exists (without model suffix)
+    - The new table doesn't exist
+    - Data is automatically migrated to the new suffixed table
+    """
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    embedding_func = EmbeddingFunc(
+        embedding_dim=1536,
+        func=mock_embedding_func.func,
+        model_name="text-embedding-ada-002"
+    )
+
+    storage = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=embedding_func,
+        workspace="legacy_workspace"
+    )
+
+    # Mock: only the legacy table exists
+    async def mock_table_exists(db, table_name):
+        return table_name == storage.legacy_table_name
+
+    # Mock: the legacy table has 50 records (the same count is returned for
+    # both the initial legacy count and the post-migration verification query)
+    async def mock_query(sql, params):
+        if "COUNT(*)" in sql:
+            return {"count": 50}
+        return {}
+
+    # Mock fetch for migration
+    mock_rows = [
+        {"id": f"legacy_id_{i}", "content": f"legacy_content_{i}", "workspace": "legacy_workspace"}
+        for i in range(50)
+    ]
+
+    async def mock_fetch(sql, params):
+        offset = params[0] if params else 0
+        limit = params[1] if len(params) > 1 else 500
+        start = offset
+        end = min(offset + limit, len(mock_rows))
+        return mock_rows[start:end]
+
+    mock_pg_db.query = AsyncMock(side_effect=mock_query)
+    mock_pg_db.fetch = AsyncMock(side_effect=mock_fetch)
+
+    with patch("lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists), \
+         patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create:
+
+        await storage.initialize()
+
+        # Verify the table name contains ada-002
+        assert "text_embedding_ada_002_1536d" in storage.table_name
+
+        # Verify the migration was executed
+        assert mock_pg_db.execute.call_count >= 50  # at least one execute per row
+        mock_create.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_scenario_3_multi_model_coexistence(mock_client_manager, mock_pg_db, mock_embedding_func):
+    """
+    Scenario 3: Multiple embedding models coexist.
+
+    Expected behavior:
+    - Different embedding models create separate tables
+    - Tables are isolated by model suffix
+    - No interference between different models
+    """
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    # Workspace A: uses bge-small (768d)
+    embedding_func_a = EmbeddingFunc(
+        embedding_dim=768,
+        func=mock_embedding_func.func,
+        model_name="bge-small"
+    )
+
+    storage_a = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=embedding_func_a,
+        workspace="workspace_a"
+    )
+
+    # Workspace B: uses bge-large (1024d)
+    async def embed_func_b(texts, **kwargs):
+        return np.array([[0.1] * 1024 for _ in texts])
+
+    embedding_func_b = EmbeddingFunc(
+        embedding_dim=1024,
+        func=embed_func_b,
+        model_name="bge-large"
+    )
+
+    storage_b = PGVectorStorage(
+        namespace=NameSpace.VECTOR_STORE_CHUNKS,
+        global_config=config,
+        embedding_func=embedding_func_b,
+        workspace="workspace_b"
+    )
+
+    # Verify the table names differ
+    assert storage_a.table_name != storage_b.table_name
+    assert "bge_small_768d" in storage_a.table_name
+    assert "bge_large_1024d" in storage_b.table_name
+
+    # Mock: neither table exists yet
+    async def mock_table_exists(db, table_name):
+        return False
+
+    with patch("lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists), \
+         patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create:
+
+        # Initialize both storages
+        await storage_a.initialize()
+        await storage_b.initialize()
+
+        # Verify two separate tables were created
+        assert mock_create.call_count == 2
+
+        # Verify the table names differ
+        call_args_list = mock_create.call_args_list
+        table_names = [call[0][1] for call in call_args_list]  # second positional arg
+        assert len(set(table_names)) == 2  # two unique table names
+        assert storage_a.table_name in table_names
+        assert storage_b.table_name in table_names
diff --git a/tests/test_qdrant_migration.py b/tests/test_qdrant_migration.py
index 0b49163a..2a343012 100644
--- a/tests/test_qdrant_migration.py
+++ b/tests/test_qdrant_migration.py
@@ -161,3 +161,208 @@ async def test_qdrant_no_migration_needed(mock_qdrant_client, mock_embedding_fun
     # In Qdrant implementation, Case 2 calls get_collection
     mock_qdrant_client.get_collection.assert_called_with(storage.final_namespace)
     mock_qdrant_client.scroll.assert_not_called()
+
+
+# ============================================================================
+# Tests for scenarios described in the design document (lines 606-649)
+# ============================================================================
+
+@pytest.mark.asyncio
+async def test_scenario_1_new_workspace_creation(mock_qdrant_client, mock_embedding_func):
+    """
+    Scenario 1: New workspace creation.
+    Expected: creates lightrag_vdb_chunks_text_embedding_3_large_3072d directly.
+    """
+    # Use a large embedding model
+    large_model_func = EmbeddingFunc(
+        embedding_dim=3072,
+        func=mock_embedding_func.func,
+        model_name="text-embedding-3-large"
+    )
+
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    storage = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=large_model_func,
+        workspace="test_new"
+    )
+
+    # Case 3: Neither the legacy nor the new collection exists
+    mock_qdrant_client.collection_exists.return_value = False
+
+    # Initialize storage
+    await storage.initialize()
+
+    # Verify: should create a new collection with the model suffix
+    expected_collection = "lightrag_vdb_chunks_text_embedding_3_large_3072d"
+    assert storage.final_namespace == expected_collection
+
+    # Verify create_collection was called with the correct name
+    create_calls = mock_qdrant_client.create_collection.call_args_list
+    assert len(create_calls) > 0
+    assert (
+        create_calls[0][0][0] == expected_collection
+        or create_calls[0].kwargs.get("collection_name") == expected_collection
+    )
+
+    # Verify no migration was attempted
+    mock_qdrant_client.scroll.assert_not_called()
+
+    print(f"✅ Scenario 1: New workspace created with collection '{expected_collection}'")
+
+
+@pytest.mark.asyncio
+async def test_scenario_2_legacy_upgrade_migration(mock_qdrant_client, mock_embedding_func):
+    """
+    Scenario 2: Upgrade from a legacy version.
+    lightrag_vdb_chunks (no suffix) already exists.
+    Expected: data is migrated automatically to
+    lightrag_vdb_chunks_text_embedding_ada_002_1536d.
+    """
+    # Use the ada-002 model
+    ada_func = EmbeddingFunc(
+        embedding_dim=1536,
+        func=mock_embedding_func.func,
+        model_name="text-embedding-ada-002"
+    )
+
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    storage = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=ada_func,
+        workspace="test_legacy"
+    )
+
+    legacy_collection = storage.legacy_namespace
+    new_collection = storage.final_namespace
+
+    # Case 4: Only the legacy collection exists
+    mock_qdrant_client.collection_exists.side_effect = lambda name: name == legacy_collection
+
+    # Mock legacy data
+    mock_qdrant_client.count.return_value.count = 150
+
+    # Mock scroll results (simulate migration in batches)
+    mock_points = []
+    for i in range(10):
+        point = MagicMock()
+        point.id = f"legacy-{i}"
+        point.vector = [0.1] * 1536
+        point.payload = {"content": f"Legacy document {i}", "id": f"doc-{i}"}
+        mock_points.append(point)
+
+    # First batch returns points, second batch returns empty
+    mock_qdrant_client.scroll.side_effect = [
+        (mock_points, "offset1"),
+        ([], None)
+    ]
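+    # qdrant_client's scroll() returns a (points, next_page_offset) tuple; the
+    # None offset in the second batch signals that the collection is exhausted.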
+
+    # Initialize (triggers migration)
+    await storage.initialize()
+
+    # Verify: the new collection should be created
+    expected_new_collection = "lightrag_vdb_chunks_text_embedding_ada_002_1536d"
+    assert storage.final_namespace == expected_new_collection
+
+    # Verify the migration steps
+    # 1. Check the legacy count
+    mock_qdrant_client.count.assert_any_call(
+        collection_name=legacy_collection,
+        exact=True
+    )
+
+    # 2. Create the new collection
+    mock_qdrant_client.create_collection.assert_called()
+
+    # 3. Scroll the legacy data
+    scroll_calls = mock_qdrant_client.scroll.call_args_list
+    assert len(scroll_calls) >= 1
+    assert scroll_calls[0].kwargs["collection_name"] == legacy_collection
+
+    # 4. Upsert into the new collection
+    upsert_calls = mock_qdrant_client.upsert.call_args_list
+    assert len(upsert_calls) >= 1
+    assert upsert_calls[0].kwargs["collection_name"] == new_collection
+
+    print(f"✅ Scenario 2: Legacy data migrated from '{legacy_collection}' to '{expected_new_collection}'")
+
+
+@pytest.mark.asyncio
+async def test_scenario_3_multi_model_coexistence(mock_qdrant_client):
+    """
+    Scenario 3: Multiple embedding models coexist.
+    Expected: two independent collections that do not interfere with each other.
+    """
+    # Model A: bge-small with 768d
+    async def embed_func_a(texts, **kwargs):
+        return np.array([[0.1] * 768 for _ in texts])
+
+    model_a_func = EmbeddingFunc(
+        embedding_dim=768,
+        func=embed_func_a,
+        model_name="bge-small"
+    )
+
+    # Model B: bge-large with 1024d
+    async def embed_func_b(texts, **kwargs):
+        return np.array([[0.2] * 1024 for _ in texts])
+
+    model_b_func = EmbeddingFunc(
+        embedding_dim=1024,
+        func=embed_func_b,
+        model_name="bge-large"
+    )
+
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {
+            "cosine_better_than_threshold": 0.8
+        }
+    }
+
+    # Create storage for workspace A with model A
+    storage_a = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=model_a_func,
+        workspace="workspace_a"
+    )
+
+    # Create storage for workspace B with model B
+    storage_b = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=model_b_func,
+        workspace="workspace_b"
+    )
+
+    # Verify: the collection names differ
+    assert storage_a.final_namespace != storage_b.final_namespace
+
+    # Verify: model A collection
+    expected_collection_a = "lightrag_vdb_chunks_bge_small_768d"
+    assert storage_a.final_namespace == expected_collection_a
+
+    # Verify: model B collection
+    expected_collection_b = "lightrag_vdb_chunks_bge_large_1024d"
+    assert storage_b.final_namespace == expected_collection_b
+
+    # Verify: the different embedding dimensions are preserved
+    assert storage_a.embedding_func.embedding_dim == 768
+    assert storage_b.embedding_func.embedding_dim == 1024
+
+    print("✅ Scenario 3: Multi-model coexistence verified")
+    print(f"   - Workspace A: {expected_collection_a} (768d)")
+    print(f"   - Workspace B: {expected_collection_b} (1024d)")
+    print("   - Collections are independent")
diff --git a/uv.lock b/uv.lock
index 97703af0..019f7539 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 3
+revision = 2
 requires-python = ">=3.10"
 resolution-markers = [
     "python_full_version >= '3.14' and python_full_version < '4' and platform_machine == 'x86_64' and sys_platform == 'darwin'",
@@ -2735,7 +2735,6 @@ requires-dist = [
     { name = "json-repair", marker = "extra == 'api'" },
     { name = "langfuse", marker = "extra == 'observability'", specifier = ">=3.8.1" },
     { name = "lightrag-hku", extras = ["api", "offline-llm", "offline-storage"], marker = "extra == 'offline'" },
-    { name = "lightrag-hku", extras = ["pytest"], marker = "extra == 'evaluation'" },
     { name = "llama-index", marker = "extra == 'offline-llm'", specifier = ">=0.9.0,<1.0.0" },
     { name = "nano-vectordb" },
     { name = "nano-vectordb", marker = "extra == 'api'" },
@@ -2753,6 +2752,7 @@ requires-dist = [
     { name = "passlib", extras = ["bcrypt"], marker = "extra == 'api'" },
     { name = "pipmaster" },
     { name = "pipmaster", marker = "extra == 'api'" },
+    { name = "pre-commit", marker = "extra == 'evaluation'" },
     { name = "pre-commit", marker = "extra == 'pytest'" },
     { name = "psutil", marker = "extra == 'api'" },
     { name = "pycryptodome", marker = "extra == 'api'", specifier = ">=3.0.0,<4.0.0" },
@@ -2764,7 +2764,9 @@ requires-dist = [
     { name = "pypdf", marker = "extra == 'api'", specifier = ">=6.1.0" },
     { name = "pypinyin" },
     { name = "pypinyin", marker = "extra == 'api'" },
+    { name = "pytest", marker = "extra == 'evaluation'", specifier = ">=8.4.2" },
     { name = "pytest", marker = "extra == 'pytest'", specifier = ">=8.4.2" },
+    { name = "pytest-asyncio", marker = "extra == 'evaluation'", specifier = ">=1.2.0" },
     { name = "pytest-asyncio", marker = "extra == 'pytest'", specifier = ">=1.2.0" },
     { name = "python-docx", marker = "extra == 'api'", specifier = ">=0.8.11,<2.0.0" },
     { name = "python-dotenv" },