#!/usr/bin/env python3
"""
LLM Cache Migration Tool for LightRAG

This tool migrates LLM response cache (default:extract:* and default:summary:*)
between different KV storage implementations while preserving workspace isolation.

Usage:
    python -m lightrag.tools.migrate_llm_cache
    # or
    python lightrag/tools/migrate_llm_cache.py

Supported KV Storage Types:
- JsonKVStorage
- RedisKVStorage
- PGKVStorage
- MongoKVStorage
"""

import asyncio
import contextlib
import os
import sys
import time
from dataclasses import dataclass, field
from typing import Any

from dotenv import load_dotenv

# Add project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from lightrag.kg import STORAGE_ENV_REQUIREMENTS
from lightrag.namespace import NameSpace
from lightrag.utils import setup_logger

# Load environment variables from the .env file in the current working directory.
# This allows each LightRAG instance to use its own .env file.
# OS environment variables take precedence over values from the .env file.
load_dotenv(dotenv_path='.env', override=False)

# Setup logger
setup_logger('lightrag', level='INFO')

# Storage type configurations
STORAGE_TYPES = {
    '1': 'JsonKVStorage',
    '2': 'RedisKVStorage',
    '3': 'PGKVStorage',
    '4': 'MongoKVStorage',
}

# Workspace environment variable mapping
WORKSPACE_ENV_MAP = {
    'PGKVStorage': 'POSTGRES_WORKSPACE',
    'MongoKVStorage': 'MONGODB_WORKSPACE',
    'RedisKVStorage': 'REDIS_WORKSPACE',
}

# Default batch size for migration
DEFAULT_BATCH_SIZE = 1000

# Default count batch size for efficient counting
DEFAULT_COUNT_BATCH_SIZE = 1000

# ANSI color codes for terminal output
BOLD_CYAN = '\033[1;36m'
RESET = '\033[0m'


@dataclass
class MigrationStats:
    """Migration statistics and error tracking"""

    total_source_records: int = 0
    total_batches: int = 0
    successful_batches: int = 0
    failed_batches: int = 0
    successful_records: int = 0
    failed_records: int = 0
    errors: list[dict[str, Any]] = field(default_factory=list)

    def add_error(self, batch_idx: int, error: Exception, batch_size: int):
        """Record batch error"""
        self.errors.append(
            {
                'batch': batch_idx,
                'error_type': type(error).__name__,
                'error_msg': str(error),
                'records_lost': batch_size,
                'timestamp': time.time(),
            }
        )
        self.failed_batches += 1
        self.failed_records += batch_size
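
    # Illustrative example (hypothetical values): after
    #   stats = MigrationStats()
    #   stats.add_error(batch_idx=3, error=TimeoutError('timed out'), batch_size=500)
    # the tracker holds failed_batches == 1, failed_records == 500, and errors[0]
    # records the batch number, error type/message, and a timestamp.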


class MigrationTool:
    """LLM Cache Migration Tool"""

    def __init__(self):
        self.source_storage = None
        self.target_storage = None
        self.source_workspace = ''
        self.target_workspace = ''
        self.batch_size = DEFAULT_BATCH_SIZE

    def get_workspace_for_storage(self, storage_name: str) -> str:
        """Get workspace for a specific storage type

        Priority: Storage-specific env var > WORKSPACE env var > empty string

        Args:
            storage_name: Storage implementation name

        Returns:
            Workspace name
        """
        # Check storage-specific workspace
        if storage_name in WORKSPACE_ENV_MAP:
            specific_workspace = os.getenv(WORKSPACE_ENV_MAP[storage_name])
            if specific_workspace:
                return specific_workspace

        # Check generic WORKSPACE
        workspace = os.getenv('WORKSPACE', '')
        return workspace
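
    # Resolution example (hypothetical values): with POSTGRES_WORKSPACE=prod and
    # WORKSPACE=shared set, PGKVStorage resolves to 'prod', while JsonKVStorage
    # (not listed in WORKSPACE_ENV_MAP) falls back to 'shared'; with neither set,
    # the empty string selects the default workspace.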

    def check_config_ini_for_storage(self, storage_name: str) -> bool:
        """Check if config.ini has configuration for the storage type

        Args:
            storage_name: Storage implementation name

        Returns:
            True if config.ini has the necessary configuration
        """
        try:
            import configparser

            config = configparser.ConfigParser()
            config.read('config.ini', encoding='utf-8')

            if storage_name == 'RedisKVStorage':
                return config.has_option('redis', 'uri')
            elif storage_name == 'PGKVStorage':
                return (
                    config.has_option('postgres', 'user')
                    and config.has_option('postgres', 'password')
                    and config.has_option('postgres', 'database')
                )
            elif storage_name == 'MongoKVStorage':
                return config.has_option('mongodb', 'uri') and config.has_option('mongodb', 'database')

            return False
        except Exception:
            return False
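
    # A minimal config.ini sketch matching the options checked above (values are
    # placeholders, taken from the per-backend hints printed in setup_storage on
    # initialization failure):
    #
    #   [redis]
    #   uri = redis://localhost:6379
    #
    #   [postgres]
    #   user = postgres
    #   password = yourpassword
    #   database = lightrag
    #
    #   [mongodb]
    #   uri = mongodb://root:root@localhost:27017/
    #   database = LightRAG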

    def check_env_vars(self, storage_name: str) -> bool:
        """Check environment variables, show warnings if missing but don't fail

        Args:
            storage_name: Storage implementation name

        Returns:
            Always returns True (warnings only, no hard failure)
        """
        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])

        if not required_vars:
            print('✓ No environment variables required')
            return True

        missing_vars = [var for var in required_vars if var not in os.environ]

        if missing_vars:
            print(f'⚠️ Warning: Missing environment variables: {", ".join(missing_vars)}')

            # Check if config.ini has configuration
            has_config = self.check_config_ini_for_storage(storage_name)
            if has_config:
                print(' ✓ Found configuration in config.ini')
            else:
                print(f' Will attempt to use defaults for {storage_name}')

            return True

        print('✓ All required environment variables are set')
        return True

    def count_available_storage_types(self) -> int:
        """Count available storage types (with env vars, config.ini, or defaults)

        Returns:
            Number of available storage types
        """
        available_count = 0

        for storage_name in STORAGE_TYPES.values():
            # Check if storage requires configuration
            required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])

            if not required_vars:
                # JsonKVStorage, MongoKVStorage etc. - no config needed
                available_count += 1
            else:
                # Check whether the required environment variables are set
                has_env = all(var in os.environ for var in required_vars)
                if has_env:
                    available_count += 1
                else:
                    # Check whether config.ini provides the configuration
                    has_config = self.check_config_ini_for_storage(storage_name)
                    if has_config:
                        available_count += 1

        return available_count

    def get_storage_class(self, storage_name: str):
        """Dynamically import and return storage class

        Args:
            storage_name: Storage implementation name

        Returns:
            Storage class
        """
        if storage_name == 'JsonKVStorage':
            from lightrag.kg.json_kv_impl import JsonKVStorage

            return JsonKVStorage
        elif storage_name == 'RedisKVStorage':
            from lightrag.kg.redis_impl import RedisKVStorage

            return RedisKVStorage
        elif storage_name == 'PGKVStorage':
            from lightrag.kg.postgres_impl import PGKVStorage

            return PGKVStorage
        elif storage_name == 'MongoKVStorage':
            from lightrag.kg.mongo_impl import MongoKVStorage

            return MongoKVStorage
        else:
            raise ValueError(f'Unsupported storage type: {storage_name}')

    async def initialize_storage(self, storage_name: str, workspace: str):
        """Initialize storage instance with fallback to config.ini and defaults

        Args:
            storage_name: Storage implementation name
            workspace: Workspace name

        Returns:
            Initialized storage instance

        Raises:
            Exception: If initialization fails
        """
        storage_class = self.get_storage_class(storage_name)

        # Create global config
        global_config = {
            'working_dir': os.getenv('WORKING_DIR', './rag_storage'),
            'embedding_batch_num': 10,
        }

        # Initialize storage
        storage = storage_class(
            namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
            workspace=workspace,
            global_config=global_config,
            embedding_func=None,
        )

        # Initialize the storage (may raise exception if connection fails)
        await storage.initialize()

        return storage
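
    # Usage sketch (values illustrative): `await self.initialize_storage('JsonKVStorage', '')`
    # returns a JsonKVStorage bound to the LLM response cache namespace under the
    # default workspace; connection problems surface from `await storage.initialize()`.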

    async def get_default_caches_json(self, storage) -> dict[str, Any]:
        """Get default caches from JsonKVStorage

        Args:
            storage: JsonKVStorage instance

        Returns:
            Dictionary of cache entries with default:extract:* or default:summary:* keys
        """
        # Access _data directly - it's a dict from shared_storage
        async with storage._storage_lock:
            filtered = {}
            for key, value in storage._data.items():
                if key.startswith('default:extract:') or key.startswith('default:summary:'):
                    filtered[key] = value.copy()
            return filtered

    async def get_default_caches_redis(self, storage, batch_size: int = 1000) -> dict[str, Any]:
        """Get default caches from RedisKVStorage with pagination

        Args:
            storage: RedisKVStorage instance
            batch_size: Number of keys to process per batch

        Returns:
            Dictionary of cache entries with default:extract:* or default:summary:* keys
        """
        import json

        cache_data = {}

        # Use _get_redis_connection() context manager
        async with storage._get_redis_connection() as redis:
            for pattern in ['default:extract:*', 'default:summary:*']:
                # Add namespace prefix to pattern
                prefixed_pattern = f'{storage.final_namespace}:{pattern}'
                cursor = 0

                while True:
                    # SCAN already implements cursor-based pagination
                    cursor, keys = await redis.scan(cursor, match=prefixed_pattern, count=batch_size)

                    if keys:
                        # Process this batch using pipeline with error handling
                        try:
                            pipe = redis.pipeline()
                            for key in keys:
                                pipe.get(key)
                            values = await pipe.execute()

                            for key, value in zip(keys, values, strict=False):
                                if value:
                                    key_str = key.decode() if isinstance(key, bytes) else key
                                    # Remove namespace prefix to get original key
                                    original_key = key_str.replace(f'{storage.final_namespace}:', '', 1)
                                    cache_data[original_key] = json.loads(value)

                        except Exception as e:
                            # Pipeline execution failed, fall back to individual gets
                            print(f'⚠️ Pipeline execution failed for batch, using individual gets: {e}')
                            for key in keys:
                                try:
                                    value = await redis.get(key)
                                    if value:
                                        key_str = key.decode() if isinstance(key, bytes) else key
                                        original_key = key_str.replace(f'{storage.final_namespace}:', '', 1)
                                        cache_data[original_key] = json.loads(value)
                                except Exception as individual_error:
                                    print(f'⚠️ Failed to get individual key {key}: {individual_error}')
                                    continue

                    if cursor == 0:
                        break

                    # Yield control periodically to avoid blocking
                    await asyncio.sleep(0)

        return cache_data
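
    # Key shape note: Redis keys are scanned with the storage's namespace prefix,
    # e.g. f'{final_namespace}:default:extract:*' (the prefix value depends on the
    # configured namespace/workspace), and that prefix is stripped again so the
    # returned keys match the other backends.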

    async def get_default_caches_pg(self, storage, batch_size: int = 1000) -> dict[str, Any]:
        """Get default caches from PGKVStorage with pagination

        Args:
            storage: PGKVStorage instance
            batch_size: Number of records to fetch per batch

        Returns:
            Dictionary of cache entries with default:extract:* or default:summary:* keys
        """
        from lightrag.kg.postgres_impl import namespace_to_table_name

        cache_data = {}
        table_name = namespace_to_table_name(storage.namespace)
        offset = 0

        while True:
            # Use LIMIT and OFFSET for pagination
            query = f"""
                SELECT id as key, original_prompt, return_value, chunk_id, cache_type, queryparam,
                       EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                       EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                FROM {table_name}
                WHERE workspace = $1
                  AND (id LIKE 'default:extract:%' OR id LIKE 'default:summary:%')
                ORDER BY id
                LIMIT $2 OFFSET $3
            """

            results = await storage.db.query(query, [storage.workspace, batch_size, offset], multirows=True)

            if not results:
                break

            for row in results:
                # Map PostgreSQL fields to cache format
                cache_entry = {
                    'return': row.get('return_value', ''),
                    'cache_type': row.get('cache_type'),
                    'original_prompt': row.get('original_prompt', ''),
                    'chunk_id': row.get('chunk_id'),
                    'queryparam': row.get('queryparam'),
                    'create_time': row.get('create_time', 0),
                    'update_time': row.get('update_time', 0),
                }
                cache_data[row['key']] = cache_entry

            # If we got fewer results than batch_size, we're done
            if len(results) < batch_size:
                break

            offset += batch_size

            # Yield control periodically
            await asyncio.sleep(0)

        return cache_data
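
    # Row-to-cache mapping sketch (illustrative values): a row with
    # return_value='...', cache_type='extract', create_time=1700000000 becomes
    # {'return': '...', 'cache_type': 'extract', ..., 'create_time': 1700000000},
    # i.e. the PostgreSQL column names are renamed to the generic cache fields.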

    async def get_default_caches_mongo(self, storage, batch_size: int = 1000) -> dict[str, Any]:
        """Get default caches from MongoKVStorage with cursor-based pagination

        Args:
            storage: MongoKVStorage instance
            batch_size: Number of documents to process per batch

        Returns:
            Dictionary of cache entries with default:extract:* or default:summary:* keys
        """
        cache_data = {}

        # MongoDB query with regex - use _data not collection
        query = {'_id': {'$regex': '^default:(extract|summary):'}}

        # Use cursor without to_list() - process in batches
        cursor = storage._data.find(query).batch_size(batch_size)

        async for doc in cursor:
            # Process each document as it comes
            doc_copy = doc.copy()
            key = doc_copy.pop('_id')

            # Filter ALL MongoDB/database-specific fields
            # Following .clinerules: "Always filter deprecated/incompatible fields during deserialization"
            for field_name in ['namespace', 'workspace', '_id', 'content']:
                doc_copy.pop(field_name, None)

            cache_data[key] = doc_copy.copy()

            # Periodically yield control (every batch_size documents)
            if len(cache_data) % batch_size == 0:
                await asyncio.sleep(0)

        return cache_data
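
    # Regex note: '^default:(extract|summary):' anchors on the same two key
    # prefixes used by the other backends, so a document whose _id is, say,
    # 'default:extract:<hash>' (hash illustrative) is included while other
    # documents are skipped.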

    async def get_default_caches(self, storage, storage_name: str) -> dict[str, Any]:
        """Get default caches from any storage type

        Args:
            storage: Storage instance
            storage_name: Storage type name

        Returns:
            Dictionary of cache entries
        """
        if storage_name == 'JsonKVStorage':
            return await self.get_default_caches_json(storage)
        elif storage_name == 'RedisKVStorage':
            return await self.get_default_caches_redis(storage)
        elif storage_name == 'PGKVStorage':
            return await self.get_default_caches_pg(storage)
        elif storage_name == 'MongoKVStorage':
            return await self.get_default_caches_mongo(storage)
        else:
            raise ValueError(f'Unsupported storage type: {storage_name}')

    async def count_default_caches_json(self, storage) -> int:
        """Count default caches in JsonKVStorage - O(N) but very fast in-memory

        Args:
            storage: JsonKVStorage instance

        Returns:
            Total count of cache records
        """
        async with storage._storage_lock:
            return sum(
                1 for key in storage._data if key.startswith('default:extract:') or key.startswith('default:summary:')
            )

    async def count_default_caches_redis(self, storage) -> int:
        """Count default caches in RedisKVStorage using SCAN with progress display

        Args:
            storage: RedisKVStorage instance

        Returns:
            Total count of cache records
        """
        count = 0
        print('Scanning Redis keys...', end='', flush=True)

        async with storage._get_redis_connection() as redis:
            for pattern in ['default:extract:*', 'default:summary:*']:
                prefixed_pattern = f'{storage.final_namespace}:{pattern}'
                cursor = 0
                while True:
                    cursor, keys = await redis.scan(cursor, match=prefixed_pattern, count=DEFAULT_COUNT_BATCH_SIZE)
                    count += len(keys)

                    # Show progress
                    print(
                        f'\rScanning Redis keys... found {count:,} records',
                        end='',
                        flush=True,
                    )

                    if cursor == 0:
                        break

        print()  # New line after progress
        return count

    async def count_default_caches_pg(self, storage) -> int:
        """Count default caches in PostgreSQL using COUNT(*) with progress indicator

        Args:
            storage: PGKVStorage instance

        Returns:
            Total count of cache records
        """
        from lightrag.kg.postgres_impl import namespace_to_table_name

        table_name = namespace_to_table_name(storage.namespace)

        query = f"""
            SELECT COUNT(*) as count
            FROM {table_name}
            WHERE workspace = $1
              AND (id LIKE 'default:extract:%' OR id LIKE 'default:summary:%')
        """

        print('Counting PostgreSQL records...', end='', flush=True)
        start_time = time.time()

        result = await storage.db.query(query, [storage.workspace])

        elapsed = time.time() - start_time
        if elapsed > 1:
            print(f' (took {elapsed:.1f}s)', end='')
        print()  # New line

        return result['count'] if result else 0

    async def count_default_caches_mongo(self, storage) -> int:
        """Count default caches in MongoDB using count_documents with progress indicator

        Args:
            storage: MongoKVStorage instance

        Returns:
            Total count of cache records
        """
        query = {'_id': {'$regex': '^default:(extract|summary):'}}

        print('Counting MongoDB documents...', end='', flush=True)
        start_time = time.time()

        count = await storage._data.count_documents(query)

        elapsed = time.time() - start_time
        if elapsed > 1:
            print(f' (took {elapsed:.1f}s)', end='')
        print()  # New line

        return count

    async def count_default_caches(self, storage, storage_name: str) -> int:
        """Count default caches from any storage type efficiently

        Args:
            storage: Storage instance
            storage_name: Storage type name

        Returns:
            Total count of cache records
        """
        if storage_name == 'JsonKVStorage':
            return await self.count_default_caches_json(storage)
        elif storage_name == 'RedisKVStorage':
            return await self.count_default_caches_redis(storage)
        elif storage_name == 'PGKVStorage':
            return await self.count_default_caches_pg(storage)
        elif storage_name == 'MongoKVStorage':
            return await self.count_default_caches_mongo(storage)
        else:
            raise ValueError(f'Unsupported storage type: {storage_name}')

    async def stream_default_caches_json(self, storage, batch_size: int):
        """Stream default caches from JsonKVStorage - yields batches

        Args:
            storage: JsonKVStorage instance
            batch_size: Size of each batch to yield

        Yields:
            Dictionary batches of cache entries

        Note:
            This method creates a snapshot of matching items while holding the lock,
            then releases the lock before yielding batches. This prevents deadlock
            when the target storage (also JsonKVStorage) tries to acquire the same
            lock during upsert operations.
        """
        # Create a snapshot of matching items while holding the lock
        async with storage._storage_lock:
            matching_items = [
                (key, value)
                for key, value in storage._data.items()
                if key.startswith('default:extract:') or key.startswith('default:summary:')
            ]

        # Now iterate over snapshot without holding lock
        batch = {}
        for key, value in matching_items:
            batch[key] = value.copy()
            if len(batch) >= batch_size:
                yield batch
                batch = {}

        # Yield remaining items
        if batch:
            yield batch
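
    # Consumption sketch: `async for batch in self.stream_default_caches_json(storage, 1000): ...`
    # yields dicts of at most 1,000 entries each, so the caller never holds the
    # full cache in memory at once.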

    async def stream_default_caches_redis(self, storage, batch_size: int):
        """Stream default caches from RedisKVStorage - yields batches

        Args:
            storage: RedisKVStorage instance
            batch_size: Size of each batch to yield

        Yields:
            Dictionary batches of cache entries
        """
        import json

        async with storage._get_redis_connection() as redis:
            for pattern in ['default:extract:*', 'default:summary:*']:
                prefixed_pattern = f'{storage.final_namespace}:{pattern}'
                cursor = 0

                while True:
                    cursor, keys = await redis.scan(cursor, match=prefixed_pattern, count=batch_size)

                    if keys:
                        try:
                            pipe = redis.pipeline()
                            for key in keys:
                                pipe.get(key)
                            values = await pipe.execute()

                            batch = {}
                            for key, value in zip(keys, values, strict=False):
                                if value:
                                    key_str = key.decode() if isinstance(key, bytes) else key
                                    original_key = key_str.replace(f'{storage.final_namespace}:', '', 1)
                                    batch[original_key] = json.loads(value)

                            if batch:
                                yield batch

                        except Exception as e:
                            print(f'⚠️ Pipeline execution failed for batch: {e}')
                            # Fall back to individual gets
                            batch = {}
                            for key in keys:
                                try:
                                    value = await redis.get(key)
                                    if value:
                                        key_str = key.decode() if isinstance(key, bytes) else key
                                        original_key = key_str.replace(f'{storage.final_namespace}:', '', 1)
                                        batch[original_key] = json.loads(value)
                                except Exception as individual_error:
                                    print(f'⚠️ Failed to get individual key {key}: {individual_error}')
                                    continue

                            if batch:
                                yield batch

                    if cursor == 0:
                        break

                    await asyncio.sleep(0)

    async def stream_default_caches_pg(self, storage, batch_size: int):
        """Stream default caches from PostgreSQL - yields batches

        Args:
            storage: PGKVStorage instance
            batch_size: Size of each batch to yield

        Yields:
            Dictionary batches of cache entries
        """
        from lightrag.kg.postgres_impl import namespace_to_table_name

        table_name = namespace_to_table_name(storage.namespace)
        offset = 0

        while True:
            query = f"""
                SELECT id as key, original_prompt, return_value, chunk_id, cache_type, queryparam,
                       EXTRACT(EPOCH FROM create_time)::BIGINT as create_time,
                       EXTRACT(EPOCH FROM update_time)::BIGINT as update_time
                FROM {table_name}
                WHERE workspace = $1
                  AND (id LIKE 'default:extract:%' OR id LIKE 'default:summary:%')
                ORDER BY id
                LIMIT $2 OFFSET $3
            """

            results = await storage.db.query(query, [storage.workspace, batch_size, offset], multirows=True)

            if not results:
                break

            batch = {}
            for row in results:
                cache_entry = {
                    'return': row.get('return_value', ''),
                    'cache_type': row.get('cache_type'),
                    'original_prompt': row.get('original_prompt', ''),
                    'chunk_id': row.get('chunk_id'),
                    'queryparam': row.get('queryparam'),
                    'create_time': row.get('create_time', 0),
                    'update_time': row.get('update_time', 0),
                }
                batch[row['key']] = cache_entry

            if batch:
                yield batch

            if len(results) < batch_size:
                break

            offset += batch_size
            await asyncio.sleep(0)

    async def stream_default_caches_mongo(self, storage, batch_size: int):
        """Stream default caches from MongoDB - yields batches

        Args:
            storage: MongoKVStorage instance
            batch_size: Size of each batch to yield

        Yields:
            Dictionary batches of cache entries
        """
        query = {'_id': {'$regex': '^default:(extract|summary):'}}
        cursor = storage._data.find(query).batch_size(batch_size)

        batch = {}
        async for doc in cursor:
            doc_copy = doc.copy()
            key = doc_copy.pop('_id')

            # Filter MongoDB/database-specific fields
            for field_name in ['namespace', 'workspace', '_id', 'content']:
                doc_copy.pop(field_name, None)

            batch[key] = doc_copy.copy()

            if len(batch) >= batch_size:
                yield batch
                batch = {}

        # Yield remaining items
        if batch:
            yield batch

    async def stream_default_caches(self, storage, storage_name: str, batch_size: int | None = None):
        """Stream default caches from any storage type - unified interface

        Args:
            storage: Storage instance
            storage_name: Storage type name
            batch_size: Size of each batch to yield (defaults to self.batch_size)

        Yields:
            Dictionary batches of cache entries
        """
        if batch_size is None:
            batch_size = self.batch_size

        if storage_name == 'JsonKVStorage':
            async for batch in self.stream_default_caches_json(storage, batch_size):
                yield batch
        elif storage_name == 'RedisKVStorage':
            async for batch in self.stream_default_caches_redis(storage, batch_size):
                yield batch
        elif storage_name == 'PGKVStorage':
            async for batch in self.stream_default_caches_pg(storage, batch_size):
                yield batch
        elif storage_name == 'MongoKVStorage':
            async for batch in self.stream_default_caches_mongo(storage, batch_size):
                yield batch
        else:
            raise ValueError(f'Unsupported storage type: {storage_name}')

    async def count_cache_types(self, cache_data: dict[str, Any]) -> dict[str, int]:
        """Count cache entries by type

        Args:
            cache_data: Dictionary of cache entries

        Returns:
            Dictionary with counts for each cache type
        """
        counts = {
            'extract': 0,
            'summary': 0,
        }

        for key in cache_data:
            if key.startswith('default:extract:'):
                counts['extract'] += 1
            elif key.startswith('default:summary:'):
                counts['summary'] += 1

        return counts
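
    # Example (illustrative keys): {'default:extract:a': {}, 'default:extract:b': {},
    # 'default:summary:c': {}} -> {'extract': 2, 'summary': 1}.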

    def print_header(self):
        """Print tool header"""
        print('\n' + '=' * 50)
        print('LLM Cache Migration Tool - LightRAG')
        print('=' * 50)

    def print_storage_types(self):
        """Print available storage types"""
        print('\nSupported KV Storage Types:')
        for key, value in STORAGE_TYPES.items():
            print(f'[{key}] {value}')

    def format_workspace(self, workspace: str) -> str:
        """Format workspace name with highlighting

        Args:
            workspace: Workspace name (may be empty)

        Returns:
            Formatted workspace string with ANSI color codes
        """
        if workspace:
            return f'{BOLD_CYAN}{workspace}{RESET}'
        else:
            return f'{BOLD_CYAN}(default){RESET}'

    def format_storage_name(self, storage_name: str) -> str:
        """Format storage type name with highlighting

        Args:
            storage_name: Storage type name

        Returns:
            Formatted storage name string with ANSI color codes
        """
        return f'{BOLD_CYAN}{storage_name}{RESET}'

    async def setup_storage(
        self,
        storage_type: str,
        use_streaming: bool = False,
        exclude_storage_name: str | None = None,
    ) -> tuple:
        """Setup and initialize storage with config.ini fallback support

        Args:
            storage_type: Type label (source/target)
            use_streaming: If True, only count records without loading. If False, load all data (legacy mode)
            exclude_storage_name: Storage type to exclude from selection (e.g., to prevent selecting same as source)

        Returns:
            Tuple of (storage_instance, storage_name, workspace, total_count)
            Returns (None, None, None, 0) if user chooses to exit
        """
        print(f'\n=== {storage_type} Storage Setup ===')

        # Filter and remap available storage types if exclusion is specified
        if exclude_storage_name:
            # Get available storage types (excluding source)
            available_list = [(k, v) for k, v in STORAGE_TYPES.items() if v != exclude_storage_name]

            # Remap to sequential numbering (1, 2, 3...)
            remapped_types = {str(i + 1): name for i, (_, name) in enumerate(available_list)}

            # Print available types with new sequential numbers
            print(f'\nAvailable Storage Types for Target (source: {exclude_storage_name} excluded):')
            for key, value in remapped_types.items():
                print(f'[{key}] {value}')

            available_types = remapped_types
        else:
            # For source storage, use original numbering
            available_types = STORAGE_TYPES.copy()
            self.print_storage_types()

        # Generate dynamic prompt based on number of options
        num_options = len(available_types)
        prompt_range = '1' if num_options == 1 else f'1-{num_options}'

        # Custom input handling with exit support
        while True:
            choice = input(f'\nSelect {storage_type} storage type ({prompt_range}) (Press Enter to exit): ').strip()

            # Check for exit
            if choice == '' or choice == '0':
                print('\n✓ Migration cancelled by user')
                return None, None, None, 0

            # Check if choice is valid
            if choice in available_types:
                break

            print(f'✗ Invalid choice. Please enter one of: {", ".join(available_types.keys())}')

        storage_name = available_types[choice]

        # Check configuration (warnings only, doesn't block)
        print('\nChecking configuration...')
        self.check_env_vars(storage_name)

        # Get workspace
        workspace = self.get_workspace_for_storage(storage_name)

        # Initialize storage (real validation point)
        print(f'\nInitializing {storage_type} storage...')
        try:
            storage = await self.initialize_storage(storage_name, workspace)
            print(f'- Storage Type: {storage_name}')
            print(f'- Workspace: {workspace if workspace else "(default)"}')
            print('- Connection Status: ✓ Success')

            # Show configuration source for transparency
            if storage_name == 'RedisKVStorage':
                config_source = 'environment variable' if 'REDIS_URI' in os.environ else 'config.ini or default'
                print(f'- Configuration Source: {config_source}')
            elif storage_name == 'PGKVStorage' or storage_name == 'MongoKVStorage':
                config_source = (
                    'environment variables'
                    if all(var in os.environ for var in STORAGE_ENV_REQUIREMENTS[storage_name])
                    else 'config.ini or defaults'
                )
                print(f'- Configuration Source: {config_source}')

        except Exception as e:
            print(f'✗ Initialization failed: {e}')
            print(f'\nFor {storage_name}, you can configure using:')
            print(' 1. Environment variables (highest priority)')

            # Show specific environment variable requirements
            if storage_name in STORAGE_ENV_REQUIREMENTS:
                for var in STORAGE_ENV_REQUIREMENTS[storage_name]:
                    print(f' - {var}')

            print(' 2. config.ini file (medium priority)')
            if storage_name == 'RedisKVStorage':
                print(' [redis]')
                print(' uri = redis://localhost:6379')
            elif storage_name == 'PGKVStorage':
                print(' [postgres]')
                print(' host = localhost')
                print(' port = 5432')
                print(' user = postgres')
                print(' password = yourpassword')
                print(' database = lightrag')
            elif storage_name == 'MongoKVStorage':
                print(' [mongodb]')
                print(' uri = mongodb://root:root@localhost:27017/')
                print(' database = LightRAG')

            return None, None, None, 0

        # Count cache records efficiently
        print(f'\n{"Counting" if use_streaming else "Loading"} cache records...')
        try:
            if use_streaming:
                # Use efficient counting without loading data
                total_count = await self.count_default_caches(storage, storage_name)
                print(f'- Total: {total_count:,} records')
            else:
                # Legacy mode: load all data
                cache_data = await self.get_default_caches(storage, storage_name)
                counts = await self.count_cache_types(cache_data)
                total_count = len(cache_data)

                print(f'- default:extract: {counts["extract"]:,} records')
                print(f'- default:summary: {counts["summary"]:,} records')
                print(f'- Total: {total_count:,} records')
        except Exception as e:
            print(f'✗ {"Counting" if use_streaming else "Loading"} failed: {e}')
            return None, None, None, 0

        return storage, storage_name, workspace, total_count

    async def migrate_caches(
        self, source_data: dict[str, Any], target_storage, target_storage_name: str
    ) -> MigrationStats:
        """Migrate caches in batches with error tracking (Legacy mode - loads all data)

        Args:
            source_data: Source cache data
            target_storage: Target storage instance
            target_storage_name: Target storage type name

        Returns:
            MigrationStats object with migration results and errors
        """
        stats = MigrationStats()
        stats.total_source_records = len(source_data)

        if stats.total_source_records == 0:
            print('\nNo records to migrate')
            return stats

        # Convert to list for batching
        items = list(source_data.items())
        stats.total_batches = (stats.total_source_records + self.batch_size - 1) // self.batch_size

        print('\n=== Starting Migration ===')

        for batch_idx in range(stats.total_batches):
            start_idx = batch_idx * self.batch_size
            end_idx = min((batch_idx + 1) * self.batch_size, stats.total_source_records)
            batch_items = items[start_idx:end_idx]
            batch_data = dict(batch_items)

            # Determine current cache type for display
            current_key = batch_items[0][0]
            cache_type = 'extract' if 'extract' in current_key else 'summary'

            try:
                # Attempt to write batch
                await target_storage.upsert(batch_data)

                # Success - update stats
                stats.successful_batches += 1
                stats.successful_records += len(batch_data)

                # Calculate progress
                progress = (end_idx / stats.total_source_records) * 100
                bar_length = 20
                filled_length = int(bar_length * end_idx // stats.total_source_records)
                bar = '█' * filled_length + '░' * (bar_length - filled_length)

                print(
                    f'Batch {batch_idx + 1}/{stats.total_batches}: {bar} '
                    f'{end_idx:,}/{stats.total_source_records:,} ({progress:.0f}%) - '
                    f'default:{cache_type} ✓'
                )

            except Exception as e:
                # Error - record and continue
                stats.add_error(batch_idx + 1, e, len(batch_data))

                print(f'Batch {batch_idx + 1}/{stats.total_batches}: ✗ FAILED - {type(e).__name__}: {e!s}')

        # Final persist
        print('\nPersisting data to disk...')
        try:
            await target_storage.index_done_callback()
            print('✓ Data persisted successfully')
        except Exception as e:
            print(f'✗ Persist failed: {e}')
            stats.add_error(0, e, 0)  # batch 0 = persist error

        return stats
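
    # Batch math example: with batch_size=1,000, 2,500 source records give
    # total_batches == (2500 + 999) // 1000 == 3 (two full batches plus one
    # partial batch of 500).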

    async def migrate_caches_streaming(
        self,
        source_storage,
        source_storage_name: str,
        target_storage,
        target_storage_name: str,
        total_records: int,
    ) -> MigrationStats:
        """Migrate caches using streaming approach - minimal memory footprint

        Args:
            source_storage: Source storage instance
            source_storage_name: Source storage type name
            target_storage: Target storage instance
            target_storage_name: Target storage type name
            total_records: Total number of records to migrate

        Returns:
            MigrationStats object with migration results and errors
        """
        stats = MigrationStats()
        stats.total_source_records = total_records

        if stats.total_source_records == 0:
            print('\nNo records to migrate')
            return stats

        # Calculate total batches
        stats.total_batches = (total_records + self.batch_size - 1) // self.batch_size

        print('\n=== Starting Streaming Migration ===')
        print(f'💡 Memory-optimized mode: Processing {self.batch_size:,} records at a time\n')

        batch_idx = 0

        # Stream batches from source and write to target immediately
        async for batch in self.stream_default_caches(source_storage, source_storage_name):
            batch_idx += 1

            # Determine current cache type for display
            if batch:
                first_key = next(iter(batch.keys()))
                cache_type = 'extract' if 'extract' in first_key else 'summary'
            else:
                cache_type = 'unknown'

            try:
                # Write batch to target storage
                await target_storage.upsert(batch)

                # Success - update stats
                stats.successful_batches += 1
                stats.successful_records += len(batch)

                # Calculate progress with known total
                progress = (stats.successful_records / total_records) * 100
                bar_length = 20
                filled_length = int(bar_length * stats.successful_records // total_records)
                bar = '█' * filled_length + '░' * (bar_length - filled_length)

                print(
                    f'Batch {batch_idx}/{stats.total_batches}: {bar} '
                    f'{stats.successful_records:,}/{total_records:,} ({progress:.1f}%) - '
                    f'default:{cache_type} ✓'
                )

            except Exception as e:
                # Error - record and continue
                stats.add_error(batch_idx, e, len(batch))

                print(f'Batch {batch_idx}/{stats.total_batches}: ✗ FAILED - {type(e).__name__}: {e!s}')

        # Final persist
        print('\nPersisting data to disk...')
        try:
            await target_storage.index_done_callback()
            print('✓ Data persisted successfully')
        except Exception as e:
            print(f'✗ Persist failed: {e}')
            stats.add_error(0, e, 0)  # batch 0 = persist error

        return stats

    def print_migration_report(self, stats: MigrationStats):
        """Print comprehensive migration report

        Args:
            stats: MigrationStats object with migration results
        """
        print('\n' + '=' * 60)
        print('Migration Complete - Final Report')
        print('=' * 60)

        # Overall statistics
        print('\n📊 Statistics:')
        print(f' Total source records: {stats.total_source_records:,}')
        print(f' Total batches: {stats.total_batches:,}')
        print(f' Successful batches: {stats.successful_batches:,}')
        print(f' Failed batches: {stats.failed_batches:,}')
        print(f' Successfully migrated: {stats.successful_records:,}')
        print(f' Failed to migrate: {stats.failed_records:,}')

        # Success rate
        success_rate = (
            (stats.successful_records / stats.total_source_records * 100) if stats.total_source_records > 0 else 0
        )
        print(f' Success rate: {success_rate:.2f}%')

        # Error details
        if stats.errors:
            print(f'\n⚠️ Errors encountered: {len(stats.errors)}')
            print('\nError Details:')
            print('-' * 60)

            # Group errors by type
            error_types = {}
            for error in stats.errors:
                err_type = error['error_type']
                error_types[err_type] = error_types.get(err_type, 0) + 1

            print('\nError Summary:')
            for err_type, count in sorted(error_types.items(), key=lambda x: -x[1]):
                print(f' - {err_type}: {count} occurrence(s)')

            print('\nFirst 5 errors:')
            for i, error in enumerate(stats.errors[:5], 1):
                print(f'\n {i}. Batch {error["batch"]}')
                print(f' Type: {error["error_type"]}')
                print(f' Message: {error["error_msg"]}')
                print(f' Records lost: {error["records_lost"]:,}')

            if len(stats.errors) > 5:
                print(f'\n ... and {len(stats.errors) - 5} more errors')

            print('\n' + '=' * 60)
            print('⚠️ WARNING: Migration completed with errors!')
            print(' Please review the error details above.')
            print('=' * 60)
        else:
            print('\n' + '=' * 60)
            print('✓ SUCCESS: All records migrated successfully!')
            print('=' * 60)

    async def run(self):
        """Run the migration tool with streaming approach and early validation"""
        try:
            # Initialize shared storage (REQUIRED for storage classes to work)
            from lightrag.kg.shared_storage import initialize_share_data

            initialize_share_data(workers=1)

            # Print header
            self.print_header()

            # Setup source storage with streaming (only count, don't load all data)
            (
                self.source_storage,
                source_storage_name,
                self.source_workspace,
                source_count,
            ) = await self.setup_storage('Source', use_streaming=True)

            # Check if user cancelled (setup_storage returns None for all fields)
            if self.source_storage is None:
                return

            # Check if there are at least 2 storage types available
            available_count = self.count_available_storage_types()
            if available_count <= 1:
                print('\n' + '=' * 60)
                print('⚠️ Warning: Migration Not Possible')
                print('=' * 60)
                print(f'Only {available_count} storage type(s) available.')
                print('Migration requires at least 2 different storage types.')
                print('\nTo enable migration, configure additional storage:')
                print(' 1. Set environment variables, OR')
                print(' 2. Update config.ini file')
                print('\nSupported storage types:')
                for name in STORAGE_TYPES.values():
                    if name != source_storage_name:
                        print(f' - {name}')
                        if name in STORAGE_ENV_REQUIREMENTS:
                            for var in STORAGE_ENV_REQUIREMENTS[name]:
                                print(f' Required: {var}')
                print('=' * 60)

                # Cleanup
                await self.source_storage.finalize()
                return

            if source_count == 0:
                print('\n⚠️ Source storage has no cache records to migrate')
                # Cleanup
                await self.source_storage.finalize()
                return

            # Setup target storage with streaming (only count, don't load all data)
            # Exclude source storage type from target selection
            (
                self.target_storage,
                target_storage_name,
                self.target_workspace,
                target_count,
            ) = await self.setup_storage('Target', use_streaming=True, exclude_storage_name=source_storage_name)

            if not self.target_storage:
                print('\n✗ Target storage setup failed')
                # Cleanup source
                await self.source_storage.finalize()
                return

            # Show migration summary
            print('\n' + '=' * 50)
            print('Migration Confirmation')
            print('=' * 50)
            print(
                f'Source: {self.format_storage_name(source_storage_name)} (workspace: {self.format_workspace(self.source_workspace)}) - {source_count:,} records'
            )
            print(
                f'Target: {self.format_storage_name(target_storage_name)} (workspace: {self.format_workspace(self.target_workspace)}) - {target_count:,} records'
            )
            print(f'Batch Size: {self.batch_size:,} records/batch')
            print('Memory Mode: Streaming (memory-optimized)')

            if target_count > 0:
                print(f'\n⚠️ Warning: Target storage already has {target_count:,} records')
                print('Migration will overwrite records with the same keys')

            # Confirm migration
            confirm = input('\nContinue? (y/n): ').strip().lower()
            if confirm != 'y':
                print('\n✗ Migration cancelled')
                # Cleanup
                await self.source_storage.finalize()
                await self.target_storage.finalize()
                return

            # Perform streaming migration with error tracking
            stats = await self.migrate_caches_streaming(
                self.source_storage,
                source_storage_name,
                self.target_storage,
                target_storage_name,
                source_count,
            )

            # Print comprehensive migration report
            self.print_migration_report(stats)

            # Cleanup
            await self.source_storage.finalize()
            await self.target_storage.finalize()

        except KeyboardInterrupt:
            print('\n\n✗ Migration interrupted by user')
        except Exception as e:
            print(f'\n✗ Migration failed: {e}')
            import traceback

            traceback.print_exc()
        finally:
            # Ensure cleanup
            if self.source_storage:
                with contextlib.suppress(Exception):
                    await self.source_storage.finalize()
            if self.target_storage:
                with contextlib.suppress(Exception):
                    await self.target_storage.finalize()

            # Finalize shared storage
            try:
                from lightrag.kg.shared_storage import finalize_share_data

                finalize_share_data()
            except Exception:
                pass


async def main():
    """Main entry point"""
    tool = MigrationTool()
    await tool.run()


if __name__ == '__main__':
    asyncio.run(main())