Raphaël MANSUY 2025-12-04 19:19:02 +08:00
parent fd739ee133
commit 5cbf856686


@@ -6,7 +6,9 @@ This tool migrates LLM response cache (default:extract:* and default:summary:*)
between different KV storage implementations while preserving workspace isolation.
Usage:
python tools/migrate_llm_cache.py
python -m lightrag.tools.migrate_llm_cache
# or
python lightrag/tools/migrate_llm_cache.py
Supported KV Storage Types:
- JsonKVStorage
@@ -23,14 +25,19 @@ from typing import Any, Dict, List
from dataclasses import dataclass, field
from dotenv import load_dotenv
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Add project root to path for imports
sys.path.insert(
0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
from lightrag.kg import STORAGE_ENV_REQUIREMENTS
from lightrag.namespace import NameSpace
from lightrag.utils import setup_logger
# Load environment variables
# Use the .env file in the current folder,
# which allows a different .env file per LightRAG instance.
# OS environment variables take precedence over values in the .env file.
load_dotenv(dotenv_path=".env", override=False)
# Setup logger
@@ -55,6 +62,14 @@ WORKSPACE_ENV_MAP = {
DEFAULT_BATCH_SIZE = 1000
# Default count batch size for efficient counting
DEFAULT_COUNT_BATCH_SIZE = 1000
# ANSI color codes for terminal output
BOLD_CYAN = "\033[1;36m"
RESET = "\033[0m"
@dataclass
class MigrationStats:
"""Migration statistics and error tracking"""
@@ -413,6 +428,364 @@ class MigrationTool:
else:
raise ValueError(f"Unsupported storage type: {storage_name}")
async def count_default_caches_json(self, storage) -> int:
"""Count default caches in JsonKVStorage - O(N) but very fast in-memory
Args:
storage: JsonKVStorage instance
Returns:
Total count of cache records
"""
async with storage._storage_lock:
return sum(
1
for key in storage._data.keys()
if key.startswith("default:extract:")
or key.startswith("default:summary:")
)
async def count_default_caches_redis(self, storage) -> int:
"""Count default caches in RedisKVStorage using SCAN with progress display
Args:
storage: RedisKVStorage instance
Returns:
Total count of cache records
"""
count = 0
print("Scanning Redis keys...", end="", flush=True)
async with storage._get_redis_connection() as redis:
for pattern in ["default:extract:*", "default:summary:*"]:
prefixed_pattern = f"{storage.final_namespace}:{pattern}"
cursor = 0
while True:
cursor, keys = await redis.scan(
cursor, match=prefixed_pattern, count=DEFAULT_COUNT_BATCH_SIZE
)
count += len(keys)
# Show progress
print(
f"\rScanning Redis keys... found {count:,} records",
end="",
flush=True,
)
if cursor == 0:
break
print() # New line after progress
return count
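# Editorial note on SCAN semantics: Redis guarantees that every key present
# for the whole scan is returned at least once, but a key may be returned
# more than once, so this count can slightly overcount while writes are in
# flight. The cursor loop above follows the standard pattern (sketch,
# assuming an async redis client):
#
#     cursor = 0
#     while True:
#         cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
#         ...  # process keys
#         if cursor == 0:  # a zero cursor marks the end of the iteration
#             break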
async def count_default_caches_pg(self, storage) -> int:
"""Count default caches in PostgreSQL using COUNT(*) with progress indicator
Args:
storage: PGKVStorage instance
Returns:
Total count of cache records
"""
from lightrag.kg.postgres_impl import namespace_to_table_name
table_name = namespace_to_table_name(storage.namespace)
query = f"""
SELECT COUNT(*) as count
FROM {table_name}
WHERE workspace = $1
AND (id LIKE 'default:extract:%' OR id LIKE 'default:summary:%')
"""
print("Counting PostgreSQL records...", end="", flush=True)
start_time = time.time()
result = await storage.db.query(query, [storage.workspace])
elapsed = time.time() - start_time
if elapsed > 1:
print(f" (took {elapsed:.1f}s)", end="")
print() # New line
return result["count"] if result else 0
async def count_default_caches_mongo(self, storage) -> int:
"""Count default caches in MongoDB using count_documents with progress indicator
Args:
storage: MongoKVStorage instance
Returns:
Total count of cache records
"""
query = {"_id": {"$regex": "^default:(extract|summary):"}}
print("Counting MongoDB documents...", end="", flush=True)
start_time = time.time()
count = await storage._data.count_documents(query)
elapsed = time.time() - start_time
if elapsed > 1:
print(f" (took {elapsed:.1f}s)", end="")
print() # New line
return count
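# Editorial note: count_documents() executes a real query, unlike the cheaper
# but approximate estimated_document_count(). Because the $regex is a
# case-sensitive prefix anchored with ^, MongoDB can satisfy it with a range
# scan on the default _id index, so the count stays efficient at scale.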
async def count_default_caches(self, storage, storage_name: str) -> int:
"""Count default caches from any storage type efficiently
Args:
storage: Storage instance
storage_name: Storage type name
Returns:
Total count of cache records
"""
if storage_name == "JsonKVStorage":
return await self.count_default_caches_json(storage)
elif storage_name == "RedisKVStorage":
return await self.count_default_caches_redis(storage)
elif storage_name == "PGKVStorage":
return await self.count_default_caches_pg(storage)
elif storage_name == "MongoKVStorage":
return await self.count_default_caches_mongo(storage)
else:
raise ValueError(f"Unsupported storage type: {storage_name}")
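# Illustrative usage of the dispatcher (not part of this commit), assuming an
# initialized MigrationTool instance `tool` and a ready storage handle:
#
#     total = await tool.count_default_caches(storage, "RedisKVStorage")
#     print(f"{total:,} cache records to migrate")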
async def stream_default_caches_json(self, storage, batch_size: int):
"""Stream default caches from JsonKVStorage - yields batches
Args:
storage: JsonKVStorage instance
batch_size: Size of each batch to yield
Yields:
Dictionary batches of cache entries
Note:
This method creates a snapshot of matching items while holding the lock,
then releases the lock before yielding batches. This prevents deadlock
when the target storage (also JsonKVStorage) tries to acquire the same
lock during upsert operations.
"""
# Create a snapshot of matching items while holding the lock
async with storage._storage_lock:
matching_items = [
(key, value)
for key, value in storage._data.items()
if key.startswith("default:extract:")
or key.startswith("default:summary:")
]
# Now iterate over snapshot without holding lock
batch = {}
for key, value in matching_items:
batch[key] = value
if len(batch) >= batch_size:
yield batch
batch = {}
# Yield remaining items
if batch:
yield batch
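# The snapshot-then-yield pattern above, in miniature (illustrative sketch,
# not part of this commit): copy under the lock, release it, then yield, so a
# consumer writing into another JsonKVStorage cannot deadlock on the same lock.
#
#     async with lock:  # hold the lock only while copying
#         items = list(data.items())
#     for i in range(0, len(items), n):
#         yield dict(items[i : i + n])  # the lock is no longer held here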
async def stream_default_caches_redis(self, storage, batch_size: int):
"""Stream default caches from RedisKVStorage - yields batches
Args:
storage: RedisKVStorage instance
batch_size: Size of each batch to yield
Yields:
Dictionary batches of cache entries
"""
import json
async with storage._get_redis_connection() as redis:
for pattern in ["default:extract:*", "default:summary:*"]:
prefixed_pattern = f"{storage.final_namespace}:{pattern}"
cursor = 0
while True:
cursor, keys = await redis.scan(
cursor, match=prefixed_pattern, count=batch_size
)
if keys:
try:
pipe = redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
batch = {}
for key, value in zip(keys, values):
if value:
key_str = (
key.decode() if isinstance(key, bytes) else key
)
original_key = key_str.replace(
f"{storage.final_namespace}:", "", 1
)
batch[original_key] = json.loads(value)
if batch:
yield batch
except Exception as e:
print(f"⚠️ Pipeline execution failed for batch: {e}")
# Fall back to individual gets
batch = {}
for key in keys:
try:
value = await redis.get(key)
if value:
key_str = (
key.decode()
if isinstance(key, bytes)
else key
)
original_key = key_str.replace(
f"{storage.final_namespace}:", "", 1
)
batch[original_key] = json.loads(value)
except Exception as individual_error:
print(
f"⚠️ Failed to get individual key {key}: {individual_error}"
)
continue
if batch:
yield batch
if cursor == 0:
break
await asyncio.sleep(0)
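# Editorial note: the `await asyncio.sleep(0)` between SCAN pages is a
# cooperative yield point; it suspends this coroutine for one event-loop turn
# so other tasks (connection upkeep, signal handlers) can run during a long
# scan. The pipeline above batches GETs into one round trip per SCAN page
# instead of one per key.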
async def stream_default_caches_pg(self, storage, batch_size: int):
"""Stream default caches from PostgreSQL - yields batches
Args:
storage: PGKVStorage instance
batch_size: Size of each batch to yield
Yields:
Dictionary batches of cache entries
"""
from lightrag.kg.postgres_impl import namespace_to_table_name
table_name = namespace_to_table_name(storage.namespace)
offset = 0
while True:
query = f"""
SELECT id as key, original_prompt, return_value, chunk_id, cache_type, queryparam,
EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
FROM {table_name}
WHERE workspace = $1
AND (id LIKE 'default:extract:%' OR id LIKE 'default:summary:%')
ORDER BY id
LIMIT $2 OFFSET $3
"""
results = await storage.db.query(
query, [storage.workspace, batch_size, offset], multirows=True
)
if not results:
break
batch = {}
for row in results:
cache_entry = {
"return": row.get("return_value", ""),
"cache_type": row.get("cache_type"),
"original_prompt": row.get("original_prompt", ""),
"chunk_id": row.get("chunk_id"),
"queryparam": row.get("queryparam"),
"create_time": row.get("create_time", 0),
"update_time": row.get("update_time", 0),
}
batch[row["key"]] = cache_entry
if batch:
yield batch
if len(results) < batch_size:
break
offset += batch_size
await asyncio.sleep(0)
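# Editorial note: LIMIT/OFFSET pagination rescans every row it skips, so the
# total cost grows quadratically with table size. A keyset variant
# (hypothetical, not part of this commit) would carry the last id forward
# instead of an offset:
#
#     SELECT ... FROM <table_name>
#     WHERE workspace = $1 AND id > $2
#       AND (id LIKE 'default:extract:%' OR id LIKE 'default:summary:%')
#     ORDER BY id LIMIT $3
#
# with $2 bound to the last key of the previous batch ('' for the first page).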
async def stream_default_caches_mongo(self, storage, batch_size: int):
"""Stream default caches from MongoDB - yields batches
Args:
storage: MongoKVStorage instance
batch_size: Size of each batch to yield
Yields:
Dictionary batches of cache entries
"""
query = {"_id": {"$regex": "^default:(extract|summary):"}}
cursor = storage._data.find(query).batch_size(batch_size)
batch = {}
async for doc in cursor:
doc_copy = doc.copy()
key = doc_copy.pop("_id")
# Filter MongoDB/database-specific fields
for field_name in ["namespace", "workspace", "_id", "content"]:
doc_copy.pop(field_name, None)
batch[key] = doc_copy
if len(batch) >= batch_size:
yield batch
batch = {}
# Yield remaining items
if batch:
yield batch
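# Editorial note: batch_size() only hints how many documents the driver
# fetches per getMore round trip; the async cursor still yields one document
# at a time. The re-chunking above caps resident data at one batch dict
# regardless of collection size.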
async def stream_default_caches(
self, storage, storage_name: str, batch_size: int = None
):
"""Stream default caches from any storage type - unified interface
Args:
storage: Storage instance
storage_name: Storage type name
batch_size: Size of each batch to yield (defaults to self.batch_size)
Yields:
Dictionary batches of cache entries
"""
if batch_size is None:
batch_size = self.batch_size
if storage_name == "JsonKVStorage":
async for batch in self.stream_default_caches_json(storage, batch_size):
yield batch
elif storage_name == "RedisKVStorage":
async for batch in self.stream_default_caches_redis(storage, batch_size):
yield batch
elif storage_name == "PGKVStorage":
async for batch in self.stream_default_caches_pg(storage, batch_size):
yield batch
elif storage_name == "MongoKVStorage":
async for batch in self.stream_default_caches_mongo(storage, batch_size):
yield batch
else:
raise ValueError(f"Unsupported storage type: {storage_name}")
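# Illustrative end-to-end use of the streaming interface (not part of this
# commit), assuming initialized `tool`, source `src`, and target `dst`:
#
#     async for batch in tool.stream_default_caches(src, "PGKVStorage"):
#         await dst.upsert(batch)      # at most batch_size records in memory
#     await dst.index_done_callback()  # persist once at the end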
async def count_cache_types(self, cache_data: Dict[str, Any]) -> Dict[str, int]:
"""Count cache entries by type
@@ -447,43 +820,107 @@ class MigrationTool:
for key, value in STORAGE_TYPES.items():
print(f"[{key}] {value}")
def get_user_choice(self, prompt: str, valid_choices: list) -> str:
"""Get user choice with validation
Args:
prompt: Prompt message
valid_choices: List of valid choices
Returns:
User's choice
"""
while True:
choice = input(f"\n{prompt}: ").strip()
if choice in valid_choices:
return choice
print(f"✗ Invalid choice, please enter one of: {', '.join(valid_choices)}")
def format_workspace(self, workspace: str) -> str:
"""Format workspace name with highlighting
Args:
workspace: Workspace name (may be empty)
Returns:
Formatted workspace string with ANSI color codes
"""
if workspace:
return f"{BOLD_CYAN}{workspace}{RESET}"
else:
return f"{BOLD_CYAN}(default){RESET}"
async def setup_storage(self, storage_type: str) -> tuple:
def format_storage_name(self, storage_name: str) -> str:
"""Format storage type name with highlighting
Args:
storage_name: Storage type name
Returns:
Formatted storage name string with ANSI color codes
"""
return f"{BOLD_CYAN}{storage_name}{RESET}"
async def setup_storage(
self,
storage_type: str,
use_streaming: bool = False,
exclude_storage_name: str = None,
) -> tuple:
"""Setup and initialize storage
Args:
storage_type: Type label (source/target)
use_streaming: If True, only count records without loading. If False, load all data (legacy mode)
exclude_storage_name: Storage type to exclude from selection (e.g., to prevent selecting same as source)
Returns:
Tuple of (storage_instance, storage_name, workspace, cache_data)
Tuple of (storage_instance, storage_name, workspace, total_count)
Returns (None, None, None, 0) if user chooses to exit
"""
print(f"\n=== {storage_type} Storage Setup ===")
# Get storage type choice
choice = self.get_user_choice(
f"Select {storage_type} storage type (1-4)", list(STORAGE_TYPES.keys())
)
storage_name = STORAGE_TYPES[choice]
# Filter and remap available storage types if exclusion is specified
if exclude_storage_name:
# Get available storage types (excluding source)
available_list = [
(k, v) for k, v in STORAGE_TYPES.items() if v != exclude_storage_name
]
# Remap to sequential numbering (1, 2, 3...)
remapped_types = {
str(i + 1): name for i, (_, name) in enumerate(available_list)
}
# Print available types with new sequential numbers
print(
f"\nAvailable Storage Types for Target (source: {exclude_storage_name} excluded):"
)
for key, value in remapped_types.items():
print(f"[{key}] {value}")
available_types = remapped_types
else:
# For source storage, use original numbering
available_types = STORAGE_TYPES.copy()
self.print_storage_types()
# Generate dynamic prompt based on number of options
num_options = len(available_types)
if num_options == 1:
prompt_range = "1"
else:
prompt_range = f"1-{num_options}"
# Custom input handling with exit support
while True:
choice = input(
f"\nSelect {storage_type} storage type ({prompt_range}) (Press Enter or 0 to exit): "
).strip()
# Check for exit
if choice == "" or choice == "0":
print("\n✓ Migration cancelled by user")
return None, None, None, 0
# Check if choice is valid
if choice in available_types:
break
print(
f"✗ Invalid choice. Please enter one of: {', '.join(available_types.keys())}"
)
storage_name = available_types[choice]
# Check environment variables
print("\nChecking environment variables...")
if not self.check_env_vars(storage_name):
return None, None, None, None
return None, None, None, 0
# Get workspace
workspace = self.get_workspace_for_storage(storage_name)
@@ -497,27 +934,34 @@ class MigrationTool:
print("- Connection Status: ✓ Success")
except Exception as e:
print(f"✗ Initialization failed: {e}")
return None, None, None, None
return None, None, None, 0
# Get cache data
print("\nCounting cache records...")
# Count cache records efficiently
print(f"\n{'Counting' if use_streaming else 'Loading'} cache records...")
try:
cache_data = await self.get_default_caches(storage, storage_name)
counts = await self.count_cache_types(cache_data)
if use_streaming:
# Use efficient counting without loading data
total_count = await self.count_default_caches(storage, storage_name)
print(f"- Total: {total_count:,} records")
else:
# Legacy mode: load all data
cache_data = await self.get_default_caches(storage, storage_name)
counts = await self.count_cache_types(cache_data)
total_count = len(cache_data)
print(f"- default:extract: {counts['extract']:,} records")
print(f"- default:summary: {counts['summary']:,} records")
print(f"- Total: {len(cache_data):,} records")
print(f"- default:extract: {counts['extract']:,} records")
print(f"- default:summary: {counts['summary']:,} records")
print(f"- Total: {total_count:,} records")
except Exception as e:
print(f"✗ Counting failed: {e}")
return None, None, None, None
print(f"{'Counting' if use_streaming else 'Loading'} failed: {e}")
return None, None, None, 0
return storage, storage_name, workspace, cache_data
return storage, storage_name, workspace, total_count
async def migrate_caches(
self, source_data: Dict[str, Any], target_storage, target_storage_name: str
) -> MigrationStats:
"""Migrate caches in batches with error tracking
"""Migrate caches in batches with error tracking (Legacy mode - loads all data)
Args:
source_data: Source cache data
@@ -592,6 +1036,98 @@ class MigrationTool:
return stats
async def migrate_caches_streaming(
self,
source_storage,
source_storage_name: str,
target_storage,
target_storage_name: str,
total_records: int,
) -> MigrationStats:
"""Migrate caches using streaming approach - minimal memory footprint
Args:
source_storage: Source storage instance
source_storage_name: Source storage type name
target_storage: Target storage instance
target_storage_name: Target storage type name
total_records: Total number of records to migrate
Returns:
MigrationStats object with migration results and errors
"""
stats = MigrationStats()
stats.total_source_records = total_records
if stats.total_source_records == 0:
print("\nNo records to migrate")
return stats
# Calculate total batches
stats.total_batches = (total_records + self.batch_size - 1) // self.batch_size
print("\n=== Starting Streaming Migration ===")
print(
f"💡 Memory-optimized mode: Processing {self.batch_size:,} records at a time\n"
)
batch_idx = 0
# Stream batches from source and write to target immediately
async for batch in self.stream_default_caches(
source_storage, source_storage_name
):
batch_idx += 1
# Determine current cache type for display
if batch:
first_key = next(iter(batch.keys()))
cache_type = "extract" if "extract" in first_key else "summary"
else:
cache_type = "unknown"
try:
# Write batch to target storage
await target_storage.upsert(batch)
# Success - update stats
stats.successful_batches += 1
stats.successful_records += len(batch)
# Calculate progress with known total
progress = (stats.successful_records / total_records) * 100
bar_length = 20
filled_length = int(
bar_length * stats.successful_records // total_records
)
bar = "" * filled_length + "" * (bar_length - filled_length)
print(
f"Batch {batch_idx}/{stats.total_batches}: {bar} "
f"{stats.successful_records:,}/{total_records:,} ({progress:.1f}%) - "
f"default:{cache_type}"
)
except Exception as e:
# Error - record and continue
stats.add_error(batch_idx, e, len(batch))
print(
f"Batch {batch_idx}/{stats.total_batches}: ✗ FAILED - "
f"{type(e).__name__}: {str(e)}"
)
# Final persist
print("\nPersisting data to disk...")
try:
await target_storage.index_done_callback()
print("✓ Data persisted successfully")
except Exception as e:
print(f"✗ Persist failed: {e}")
stats.add_error(0, e, 0) # batch 0 = persist error
return stats
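# Editorial note: peak memory for the streaming path is roughly
# batch_size * average_record_size and independent of total cache size, e.g.
# 1,000 records of ~4 KB each keep about 4 MB resident, where the legacy path
# would load the entire cache. Failed batches are recorded and skipped, not
# retried, so a rerun is needed to pick up records from batches that errored.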
def print_migration_report(self, stats: MigrationStats):
"""Print comprehensive migration report
@@ -655,37 +1191,44 @@ class MigrationTool:
print("=" * 60)
async def run(self):
"""Run the migration tool"""
"""Run the migration tool with streaming approach"""
try:
# Initialize shared storage (REQUIRED for storage classes to work)
from lightrag.kg.shared_storage import initialize_share_data
initialize_share_data(workers=1)
# Print header
self.print_header()
self.print_storage_types()
# Setup source storage
# Setup source storage with streaming (only count, don't load all data)
(
self.source_storage,
source_storage_name,
self.source_workspace,
source_data,
) = await self.setup_storage("Source")
source_count,
) = await self.setup_storage("Source", use_streaming=True)
if not self.source_storage:
print("\n✗ Source storage setup failed")
# Check if user cancelled (setup_storage returns None for all fields)
if self.source_storage is None:
return
if not source_data:
if source_count == 0:
print("\n⚠ Source storage has no cache records to migrate")
# Cleanup
await self.source_storage.finalize()
return
# Setup target storage
# Setup target storage with streaming (only count, don't load all data)
# Exclude source storage type from target selection
(
self.target_storage,
target_storage_name,
self.target_workspace,
target_data,
) = await self.setup_storage("Target")
target_count,
) = await self.setup_storage(
"Target", use_streaming=True, exclude_storage_name=source_storage_name
)
if not self.target_storage:
print("\n✗ Target storage setup failed")
@@ -698,16 +1241,17 @@ class MigrationTool:
print("Migration Confirmation")
print("=" * 50)
print(
f"Source: {source_storage_name} (workspace: {self.source_workspace if self.source_workspace else '(default)'}) - {len(source_data):,} records"
f"Source: {self.format_storage_name(source_storage_name)} (workspace: {self.format_workspace(self.source_workspace)}) - {source_count:,} records"
)
print(
f"Target: {target_storage_name} (workspace: {self.target_workspace if self.target_workspace else '(default)'}) - {len(target_data):,} records"
f"Target: {self.format_storage_name(target_storage_name)} (workspace: {self.format_workspace(self.target_workspace)}) - {target_count:,} records"
)
print(f"Batch Size: {self.batch_size:,} records/batch")
print("Memory Mode: Streaming (memory-optimized)")
if target_data:
if target_count > 0:
print(
f"\n Warning: Target storage already has {len(target_data):,} records"
f"\n Warning: Target storage already has {target_count:,} records"
)
print("Migration will overwrite records with the same keys")
@@ -720,9 +1264,13 @@ class MigrationTool:
await self.target_storage.finalize()
return
# Perform migration with error tracking
stats = await self.migrate_caches(
source_data, self.target_storage, target_storage_name
# Perform streaming migration with error tracking
stats = await self.migrate_caches_streaming(
self.source_storage,
source_storage_name,
self.target_storage,
target_storage_name,
source_count,
)
# Print comprehensive migration report
@@ -752,6 +1300,14 @@ class MigrationTool:
except Exception:
pass
# Finalize shared storage
try:
from lightrag.kg.shared_storage import finalize_share_data
finalize_share_data()
except Exception:
pass
async def main():
"""Main entry point"""