From 5cbf8566862526958667a51b01e5a14742ef1abd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?=
Date: Thu, 4 Dec 2025 19:19:02 +0800
Subject: [PATCH] cherry-pick 1864b282

---
 lightrag/tools/migrate_llm_cache.py | 658 +++++++++++++++++++++++++---
 1 file changed, 607 insertions(+), 51 deletions(-)

diff --git a/lightrag/tools/migrate_llm_cache.py b/lightrag/tools/migrate_llm_cache.py
index 26b7c81c..db6933b2 100644
--- a/lightrag/tools/migrate_llm_cache.py
+++ b/lightrag/tools/migrate_llm_cache.py
@@ -6,7 +6,9 @@ This tool migrates LLM response cache (default:extract:* and default:summary:*)
 between different KV storage implementations while preserving workspace isolation.
 
 Usage:
-    python tools/migrate_llm_cache.py
+    python -m lightrag.tools.migrate_llm_cache
+    # or
+    python lightrag/tools/migrate_llm_cache.py
 
 Supported KV Storage Types:
 - JsonKVStorage
@@ -23,14 +25,19 @@ from typing import Any, Dict, List
 from dataclasses import dataclass, field
 from dotenv import load_dotenv
 
-# Add parent directory to path for imports
-sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+# Add project root to path for imports
+sys.path.insert(
+    0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
 
 from lightrag.kg import STORAGE_ENV_REQUIREMENTS
 from lightrag.namespace import NameSpace
 from lightrag.utils import setup_logger
 
 # Load environment variables
+# Use the .env file in the current working directory.
+# This allows each LightRAG instance to use its own .env file.
+# OS environment variables take precedence over values in the .env file.
load_dotenv(dotenv_path=".env", override=False)
 
 # Setup logger
@@ -55,6 +62,14 @@ WORKSPACE_ENV_MAP = {
 
 DEFAULT_BATCH_SIZE = 1000
 
+# Default count batch size for efficient counting
+DEFAULT_COUNT_BATCH_SIZE = 1000
+
+# ANSI color codes for terminal output
+BOLD_CYAN = "\033[1;36m"
+RESET = "\033[0m"
+
+
 @dataclass
 class MigrationStats:
     """Migration statistics and error tracking"""
@@ -413,6 +428,364 @@ class MigrationTool:
         else:
             raise ValueError(f"Unsupported storage type: {storage_name}")
 
+    async def count_default_caches_json(self, storage) -> int:
+        """Count default caches in JsonKVStorage - O(N), but fast since data is in memory
+
+        Args:
+            storage: JsonKVStorage instance
+
+        Returns:
+            Total count of cache records
+        """
+        async with storage._storage_lock:
+            return sum(
+                1
+                for key in storage._data.keys()
+                if key.startswith("default:extract:")
+                or key.startswith("default:summary:")
+            )
+
+    async def count_default_caches_redis(self, storage) -> int:
+        """Count default caches in RedisKVStorage using SCAN with progress display
+
+        Args:
+            storage: RedisKVStorage instance
+
+        Returns:
+            Total count of cache records
+        """
+        count = 0
+        print("Scanning Redis keys...", end="", flush=True)
+
+        async with storage._get_redis_connection() as redis:
+            for pattern in ["default:extract:*", "default:summary:*"]:
+                prefixed_pattern = f"{storage.final_namespace}:{pattern}"
+                cursor = 0
+                while True:
+                    cursor, keys = await redis.scan(
+                        cursor, match=prefixed_pattern, count=DEFAULT_COUNT_BATCH_SIZE
+                    )
+                    count += len(keys)
+
+                    # Show progress
+                    print(
+                        f"\rScanning Redis keys... found {count:,} records",
+                        end="",
+                        flush=True,
+                    )
+
+                    if cursor == 0:
+                        break
+
+        print()  # New line after progress
+        return count
+
+    async def count_default_caches_pg(self, storage) -> int:
+        """Count default caches in PostgreSQL using COUNT(*) with progress indicator
+
+        Args:
+            storage: PGKVStorage instance
+
+        Returns:
+            Total count of cache records
+        """
+        from lightrag.kg.postgres_impl import namespace_to_table_name
+
+        table_name = namespace_to_table_name(storage.namespace)
+
+        query = f"""
+            SELECT COUNT(*) as count
+            FROM {table_name}
+            WHERE workspace = $1
+              AND (id LIKE 'default:extract:%' OR id LIKE 'default:summary:%')
+        """
+
+        print("Counting PostgreSQL records...", end="", flush=True)
+        start_time = time.time()
+
+        result = await storage.db.query(query, [storage.workspace])
+
+        elapsed = time.time() - start_time
+        if elapsed > 1:
+            print(f" (took {elapsed:.1f}s)", end="")
+        print()  # New line
+
+        return result["count"] if result else 0
+
+    async def count_default_caches_mongo(self, storage) -> int:
+        """Count default caches in MongoDB using count_documents with progress indicator
+
+        Args:
+            storage: MongoKVStorage instance
+
+        Returns:
+            Total count of cache records
+        """
+        query = {"_id": {"$regex": "^default:(extract|summary):"}}
+
+        print("Counting MongoDB documents...", end="", flush=True)
+        start_time = time.time()
+
+        count = await storage._data.count_documents(query)
+
+        elapsed = time.time() - start_time
+        if elapsed > 1:
+            print(f" (took {elapsed:.1f}s)", end="")
+        print()  # New line
+
+        return count
+
+    async def count_default_caches(self, storage, storage_name: str) -> int:
+        """Count default caches from any storage type efficiently
+
+        Args:
+            storage: Storage instance
+            storage_name: Storage type name
+
+        Returns:
+            Total count of cache records
+        """
+        if storage_name == "JsonKVStorage":
+            return await self.count_default_caches_json(storage)
+        elif storage_name == "RedisKVStorage":
+            return await self.count_default_caches_redis(storage)
+        elif storage_name == "PGKVStorage":
+            return await self.count_default_caches_pg(storage)
+        elif storage_name == "MongoKVStorage":
+            return await self.count_default_caches_mongo(storage)
+        else:
+            raise ValueError(f"Unsupported storage type: {storage_name}")
+
+    async def stream_default_caches_json(self, storage, batch_size: int):
+        """Stream default caches from JsonKVStorage - yields batches
+
+        Args:
+            storage: JsonKVStorage instance
+            batch_size: Size of each batch to yield
+
+        Yields:
+            Dictionary batches of cache entries
+
+        Note:
+            This method creates a snapshot of matching items while holding the lock,
+            then releases the lock before yielding batches. This prevents deadlock
+            when the target storage (also JsonKVStorage) tries to acquire the same
+            lock during upsert operations.
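+
+        Example:
+            Minimal usage sketch (illustrative only; ``tool`` is assumed to be
+            an initialized MigrationTool, and ``source``/``target`` initialized
+            JsonKVStorage instances):
+
+                async for batch in tool.stream_default_caches_json(source, 500):
+                    await target.upsert(batch)
+                await target.index_done_callback()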
+        """
+        # Create a snapshot of matching items while holding the lock
+        async with storage._storage_lock:
+            matching_items = [
+                (key, value)
+                for key, value in storage._data.items()
+                if key.startswith("default:extract:")
+                or key.startswith("default:summary:")
+            ]
+
+        # Now iterate over snapshot without holding lock
+        batch = {}
+        for key, value in matching_items:
+            batch[key] = value
+            if len(batch) >= batch_size:
+                yield batch
+                batch = {}
+
+        # Yield remaining items
+        if batch:
+            yield batch
+
+    async def stream_default_caches_redis(self, storage, batch_size: int):
+        """Stream default caches from RedisKVStorage - yields batches
+
+        Args:
+            storage: RedisKVStorage instance
+            batch_size: Size of each batch to yield
+
+        Yields:
+            Dictionary batches of cache entries
+        """
+        import json
+
+        async with storage._get_redis_connection() as redis:
+            for pattern in ["default:extract:*", "default:summary:*"]:
+                prefixed_pattern = f"{storage.final_namespace}:{pattern}"
+                cursor = 0
+
+                while True:
+                    cursor, keys = await redis.scan(
+                        cursor, match=prefixed_pattern, count=batch_size
+                    )
+
+                    if keys:
+                        try:
+                            pipe = redis.pipeline()
+                            for key in keys:
+                                pipe.get(key)
+                            values = await pipe.execute()
+
+                            batch = {}
+                            for key, value in zip(keys, values):
+                                if value:
+                                    key_str = (
+                                        key.decode() if isinstance(key, bytes) else key
+                                    )
+                                    original_key = key_str.replace(
+                                        f"{storage.final_namespace}:", "", 1
+                                    )
+                                    batch[original_key] = json.loads(value)
+
+                            if batch:
+                                yield batch
+
+                        except Exception as e:
+                            print(f"⚠️ Pipeline execution failed for batch: {e}")
+                            # Fall back to individual gets
+                            batch = {}
+                            for key in keys:
+                                try:
+                                    value = await redis.get(key)
+                                    if value:
+                                        key_str = (
+                                            key.decode()
+                                            if isinstance(key, bytes)
+                                            else key
+                                        )
+                                        original_key = key_str.replace(
+                                            f"{storage.final_namespace}:", "", 1
+                                        )
+                                        batch[original_key] = json.loads(value)
+                                except Exception as individual_error:
+                                    print(
+                                        f"⚠️ Failed to get individual key {key}: {individual_error}"
+                                    )
+                                    continue
+
+                            if batch:
+                                yield batch
+
+                    if cursor == 0:
+                        break
+
+                    await asyncio.sleep(0)
+
+    async def stream_default_caches_pg(self, storage, batch_size: int):
+        """Stream default caches from PostgreSQL - yields batches
+
+        Args:
+            storage: PGKVStorage instance
+            batch_size: Size of each batch to yield
+
+        Yields:
+            Dictionary batches of cache entries
+        """
+        from lightrag.kg.postgres_impl import namespace_to_table_name
+
+        table_name = namespace_to_table_name(storage.namespace)
+        offset = 0
+
+        while True:
+            query = f"""
+                SELECT id as key, original_prompt, return_value, chunk_id, cache_type, queryparam,
+                       EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
+                       EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
+                FROM {table_name}
+                WHERE workspace = $1
+                  AND (id LIKE 'default:extract:%' OR id LIKE 'default:summary:%')
+                ORDER BY id
+                LIMIT $2 OFFSET $3
+            """
+
+            results = await storage.db.query(
+                query, [storage.workspace, batch_size, offset], multirows=True
+            )
+
+            if not results:
+                break
+
+            batch = {}
+            for row in results:
+                cache_entry = {
+                    "return": row.get("return_value", ""),
+                    "cache_type": row.get("cache_type"),
+                    "original_prompt": row.get("original_prompt", ""),
+                    "chunk_id": row.get("chunk_id"),
+                    "queryparam": row.get("queryparam"),
+                    "create_time": row.get("create_time", 0),
+                    "update_time": row.get("update_time", 0),
+                }
+                batch[row["key"]] = cache_entry
+
+            if batch:
+                yield batch
+
+            if len(results) < batch_size:
+                break
+
+            offset += batch_size
+            await asyncio.sleep(0)
+
+    async def stream_default_caches_mongo(self, storage, batch_size: int):
+        """Stream default caches from MongoDB - yields batches
+
+        Args:
+            storage: MongoKVStorage instance
+            batch_size: Size of each batch to yield
+
+        Yields:
+            Dictionary batches of cache entries
+        """
+        query = {"_id": {"$regex": "^default:(extract|summary):"}}
+        cursor = storage._data.find(query).batch_size(batch_size)
+
+        batch = {}
+        async for doc in cursor:
+            doc_copy = doc.copy()
+            key = doc_copy.pop("_id")
+
+            # Filter MongoDB/database-specific fields
+            for field_name in ["namespace", "workspace", "_id", "content"]:
+                doc_copy.pop(field_name, None)
+
+            batch[key] = doc_copy
+
+            if len(batch) >= batch_size:
+                yield batch
+                batch = {}
+
+        # Yield remaining items
+        if batch:
+            yield batch
+
+    async def stream_default_caches(
+        self, storage, storage_name: str, batch_size: int = None
+    ):
+        """Stream default caches from any storage type - unified interface
+
+        Args:
+            storage: Storage instance
+            storage_name: Storage type name
+            batch_size: Size of each batch to yield (defaults to self.batch_size)
+
+        Yields:
+            Dictionary batches of cache entries
+        """
+        if batch_size is None:
+            batch_size = self.batch_size
+
+        if storage_name == "JsonKVStorage":
+            async for batch in self.stream_default_caches_json(storage, batch_size):
+                yield batch
+        elif storage_name == "RedisKVStorage":
+            async for batch in self.stream_default_caches_redis(storage, batch_size):
+                yield batch
+        elif storage_name == "PGKVStorage":
+            async for batch in self.stream_default_caches_pg(storage, batch_size):
+                yield batch
+        elif storage_name == "MongoKVStorage":
+            async for batch in self.stream_default_caches_mongo(storage, batch_size):
+                yield batch
+        else:
+            raise ValueError(f"Unsupported storage type: {storage_name}")
+
     async def count_cache_types(self, cache_data: Dict[str, Any]) -> Dict[str, int]:
         """Count cache entries by type
@@ -447,43 +820,107 @@
         for key, value in STORAGE_TYPES.items():
             print(f"[{key}] {value}")
 
-    def get_user_choice(self, prompt: str, valid_choices: list) -> str:
-        """Get user choice with validation
+    def format_workspace(self, workspace: str) -> str:
+        """Format workspace name with highlighting
 
         Args:
-            prompt: Prompt message
-            valid_choices: List of valid choices
+            workspace: Workspace name (may be empty)
 
         Returns:
-            User's choice
+            Formatted workspace string with ANSI color codes
         """
-        while True:
-            choice = input(f"\n{prompt}: ").strip()
-            if choice in valid_choices:
-                return choice
-            print(f"✗ Invalid choice, please enter one of: {', '.join(valid_choices)}")
+        if workspace:
+            return f"{BOLD_CYAN}{workspace}{RESET}"
+        else:
+            return f"{BOLD_CYAN}(default){RESET}"
 
-    async def setup_storage(self, storage_type: str) -> tuple:
+    def format_storage_name(self, storage_name: str) -> str:
+        """Format storage type name with highlighting
+
+        Args:
+            storage_name: Storage type name
+
+        Returns:
+            Formatted storage name string with ANSI color codes
+        """
+        return f"{BOLD_CYAN}{storage_name}{RESET}"
+
+    async def setup_storage(
+        self,
+        storage_type: str,
+        use_streaming: bool = False,
+        exclude_storage_name: str = None,
+    ) -> tuple:
         """Setup and initialize storage
 
         Args:
             storage_type: Type label (source/target)
+            use_streaming: If True, only count records without loading.
+                If False, load all data (legacy mode).
+            exclude_storage_name: Storage type to exclude from selection (e.g., to prevent selecting the same type as the source)
 
         Returns:
-            Tuple of (storage_instance, storage_name, workspace, cache_data)
+            Tuple of (storage_instance, storage_name, workspace, total_count)
+            Returns (None, None, None, 0) if the user chooses to exit
         """
         print(f"\n=== {storage_type} Storage Setup ===")
 
-        # Get storage type choice
-        choice = self.get_user_choice(
-            f"Select {storage_type} storage type (1-4)", list(STORAGE_TYPES.keys())
-        )
-        storage_name = STORAGE_TYPES[choice]
+        # Filter and remap available storage types if exclusion is specified
+        if exclude_storage_name:
+            # Get available storage types (excluding source)
+            available_list = [
+                (k, v) for k, v in STORAGE_TYPES.items() if v != exclude_storage_name
+            ]
+
+            # Remap to sequential numbering (1, 2, 3...)
+            remapped_types = {
+                str(i + 1): name for i, (_, name) in enumerate(available_list)
+            }
+
+            # Print available types with new sequential numbers
+            print(
+                f"\nAvailable Storage Types for Target (source: {exclude_storage_name} excluded):"
+            )
+            for key, value in remapped_types.items():
+                print(f"[{key}] {value}")
+
+            available_types = remapped_types
+        else:
+            # For source storage, use original numbering
+            available_types = STORAGE_TYPES.copy()
+            self.print_storage_types()
+
+        # Generate dynamic prompt based on number of options
+        num_options = len(available_types)
+        if num_options == 1:
+            prompt_range = "1"
+        else:
+            prompt_range = f"1-{num_options}"
+
+        # Custom input handling with exit support
+        while True:
+            choice = input(
+                f"\nSelect {storage_type} storage type ({prompt_range}) (Press Enter or 0 to exit): "
+            ).strip()
+
+            # Check for exit
+            if choice == "" or choice == "0":
+                print("\n✓ Migration cancelled by user")
+                return None, None, None, 0
+
+            # Check if choice is valid
+            if choice in available_types:
+                break
+
+            print(
+                f"✗ Invalid choice. Please enter one of: {', '.join(available_types.keys())}"
+            )
+
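+        # `choice` was validated against available_types above, so this lookup
+        # cannot raise KeyError even when the numbering was remapped to exclude
+        # the source storage type.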
+        storage_name = available_types[choice]
 
         # Check environment variables
         print("\nChecking environment variables...")
         if not self.check_env_vars(storage_name):
-            return None, None, None, None
+            return None, None, None, 0
 
         # Get workspace
         workspace = self.get_workspace_for_storage(storage_name)
@@ -497,27 +934,34 @@
             print("- Connection Status: ✓ Success")
         except Exception as e:
             print(f"✗ Initialization failed: {e}")
-            return None, None, None, None
+            return None, None, None, 0
 
-        # Get cache data
-        print("\nCounting cache records...")
+        # Count cache records efficiently
+        print(f"\n{'Counting' if use_streaming else 'Loading'} cache records...")
         try:
-            cache_data = await self.get_default_caches(storage, storage_name)
-            counts = await self.count_cache_types(cache_data)
+            if use_streaming:
+                # Use efficient counting without loading data
+                total_count = await self.count_default_caches(storage, storage_name)
+                print(f"- Total: {total_count:,} records")
+            else:
+                # Legacy mode: load all data
+                cache_data = await self.get_default_caches(storage, storage_name)
+                counts = await self.count_cache_types(cache_data)
+                total_count = len(cache_data)
 
-            print(f"- default:extract: {counts['extract']:,} records")
-            print(f"- default:summary: {counts['summary']:,} records")
-            print(f"- Total: {len(cache_data):,} records")
+                print(f"- default:extract: {counts['extract']:,} records")
+                print(f"- default:summary: {counts['summary']:,} records")
+                print(f"- Total: {total_count:,} records")
         except Exception as e:
-            print(f"✗ Counting failed: {e}")
-            return None, None, None, None
+            print(f"✗ {'Counting' if use_streaming else 'Loading'} failed: {e}")
+            return None, None, None, 0
 
-        return storage, storage_name, workspace, cache_data
+        return storage, storage_name, workspace, total_count
 
     async def migrate_caches(
         self, source_data: Dict[str, Any], target_storage, target_storage_name: str
     ) -> MigrationStats:
-        """Migrate caches in batches with error tracking
+        """Migrate caches in batches with error tracking (legacy mode: loads all data into memory)
 
         Args:
             source_data: Source cache data
@@ -592,6 +1036,98 @@
 
         return stats
 
+    async def migrate_caches_streaming(
+        self,
+        source_storage,
+        source_storage_name: str,
+        target_storage,
+        target_storage_name: str,
+        total_records: int,
+    ) -> MigrationStats:
+        """Migrate caches using a streaming approach with minimal memory footprint
+
+        Args:
+            source_storage: Source storage instance
+            source_storage_name: Source storage type name
+            target_storage: Target storage instance
+            target_storage_name: Target storage type name
+            total_records: Total number of records to migrate
+
+        Returns:
+            MigrationStats object with migration results and errors
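+
+        Example:
+            Minimal usage sketch (illustrative only; assumes ``src``/``dst`` are
+            initialized storage instances and ``total`` was obtained from
+            count_default_caches):
+
+                stats = await tool.migrate_caches_streaming(
+                    src, "JsonKVStorage", dst, "RedisKVStorage", total
+                )
+                tool.print_migration_report(stats)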
else "summary" + else: + cache_type = "unknown" + + try: + # Write batch to target storage + await target_storage.upsert(batch) + + # Success - update stats + stats.successful_batches += 1 + stats.successful_records += len(batch) + + # Calculate progress with known total + progress = (stats.successful_records / total_records) * 100 + bar_length = 20 + filled_length = int( + bar_length * stats.successful_records // total_records + ) + bar = "█" * filled_length + "░" * (bar_length - filled_length) + + print( + f"Batch {batch_idx}/{stats.total_batches}: {bar} " + f"{stats.successful_records:,}/{total_records:,} ({progress:.1f}%) - " + f"default:{cache_type} ✓" + ) + + except Exception as e: + # Error - record and continue + stats.add_error(batch_idx, e, len(batch)) + + print( + f"Batch {batch_idx}/{stats.total_batches}: ✗ FAILED - " + f"{type(e).__name__}: {str(e)}" + ) + + # Final persist + print("\nPersisting data to disk...") + try: + await target_storage.index_done_callback() + print("✓ Data persisted successfully") + except Exception as e: + print(f"✗ Persist failed: {e}") + stats.add_error(0, e, 0) # batch 0 = persist error + + return stats + def print_migration_report(self, stats: MigrationStats): """Print comprehensive migration report @@ -655,37 +1191,44 @@ class MigrationTool: print("=" * 60) async def run(self): - """Run the migration tool""" + """Run the migration tool with streaming approach""" try: + # Initialize shared storage (REQUIRED for storage classes to work) + from lightrag.kg.shared_storage import initialize_share_data + + initialize_share_data(workers=1) + # Print header self.print_header() - self.print_storage_types() - # Setup source storage + # Setup source storage with streaming (only count, don't load all data) ( self.source_storage, source_storage_name, self.source_workspace, - source_data, - ) = await self.setup_storage("Source") + source_count, + ) = await self.setup_storage("Source", use_streaming=True) - if not self.source_storage: - print("\n✗ Source storage setup failed") + # Check if user cancelled (setup_storage returns None for all fields) + if self.source_storage is None: return - if not source_data: + if source_count == 0: print("\n⚠ Source storage has no cache records to migrate") # Cleanup await self.source_storage.finalize() return - # Setup target storage + # Setup target storage with streaming (only count, don't load all data) + # Exclude source storage type from target selection ( self.target_storage, target_storage_name, self.target_workspace, - target_data, - ) = await self.setup_storage("Target") + target_count, + ) = await self.setup_storage( + "Target", use_streaming=True, exclude_storage_name=source_storage_name + ) if not self.target_storage: print("\n✗ Target storage setup failed") @@ -698,16 +1241,17 @@ class MigrationTool: print("Migration Confirmation") print("=" * 50) print( - f"Source: {source_storage_name} (workspace: {self.source_workspace if self.source_workspace else '(default)'}) - {len(source_data):,} records" + f"Source: {self.format_storage_name(source_storage_name)} (workspace: {self.format_workspace(self.source_workspace)}) - {source_count:,} records" ) print( - f"Target: {target_storage_name} (workspace: {self.target_workspace if self.target_workspace else '(default)'}) - {len(target_data):,} records" + f"Target: {self.format_storage_name(target_storage_name)} (workspace: {self.format_workspace(self.target_workspace)}) - {target_count:,} records" ) print(f"Batch Size: {self.batch_size:,} records/batch") + print("Memory 
 
-            if target_data:
+            if target_count > 0:
                 print(
-                    f"\n⚠ Warning: Target storage already has {len(target_data):,} records"
+                    f"\n⚠️ Warning: Target storage already has {target_count:,} records"
                 )
                 print("Migration will overwrite records with the same keys")
 
@@ -720,9 +1264,13 @@
                 await self.target_storage.finalize()
                 return
 
-            # Perform migration with error tracking
-            stats = await self.migrate_caches(
-                source_data, self.target_storage, target_storage_name
+            # Perform streaming migration with error tracking
+            stats = await self.migrate_caches_streaming(
+                self.source_storage,
+                source_storage_name,
+                self.target_storage,
+                target_storage_name,
+                source_count,
             )
 
             # Print comprehensive migration report
@@ -752,6 +1300,14 @@
             except Exception:
                 pass
 
+        # Finalize shared storage
+        try:
+            from lightrag.kg.shared_storage import finalize_share_data
+
+            finalize_share_data()
+        except Exception:
+            pass
+
 
 async def main():
     """Main entry point"""