Fix linting

yangdx 2025-11-08 18:16:03 +08:00
parent 55274dde59
commit 0f2c0de8df
2 changed files with 194 additions and 152 deletions


@@ -58,6 +58,7 @@ DEFAULT_BATCH_SIZE = 1000
 @dataclass
 class MigrationStats:
     """Migration statistics and error tracking"""
+
     total_source_records: int = 0
     total_batches: int = 0
     successful_batches: int = 0
@@ -68,13 +69,15 @@ class MigrationStats:
     def add_error(self, batch_idx: int, error: Exception, batch_size: int):
         """Record batch error"""
-        self.errors.append({
-            'batch': batch_idx,
-            'error_type': type(error).__name__,
-            'error_msg': str(error),
-            'records_lost': batch_size,
-            'timestamp': time.time()
-        })
+        self.errors.append(
+            {
+                "batch": batch_idx,
+                "error_type": type(error).__name__,
+                "error_msg": str(error),
+                "records_lost": batch_size,
+                "timestamp": time.time(),
+            }
+        )
         self.failed_batches += 1
         self.failed_records += batch_size
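
Note: a self-contained sketch of how the reformatted add_error is used; the dataclass below keeps only the fields add_error touches (the real class has more, per the hunk above):

import time
from dataclasses import dataclass, field


@dataclass
class MigrationStats:
    """Migration statistics and error tracking (abbreviated)"""

    failed_batches: int = 0
    failed_records: int = 0
    errors: list = field(default_factory=list)

    def add_error(self, batch_idx: int, error: Exception, batch_size: int):
        """Record one failed batch without aborting the migration"""
        self.errors.append(
            {
                "batch": batch_idx,
                "error_type": type(error).__name__,
                "error_msg": str(error),
                "records_lost": batch_size,
                "timestamp": time.time(),
            }
        )
        self.failed_batches += 1
        self.failed_records += batch_size


# Two failed batches of 1000 records each are tallied, not raised
stats = MigrationStats()
stats.add_error(3, TimeoutError("redis timeout"), 1000)
stats.add_error(7, ValueError("bad json"), 1000)
assert stats.failed_records == 2000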
@@ -123,7 +126,9 @@ class MigrationTool:
         missing_vars = [var for var in required_vars if var not in os.environ]
         if missing_vars:
-            print(f"✗ Missing required environment variables: {', '.join(missing_vars)}")
+            print(
+                f"✗ Missing required environment variables: {', '.join(missing_vars)}"
+            )
             return False
 
         print("✓ All required environment variables are set")
@@ -140,15 +145,19 @@ class MigrationTool:
         """
         if storage_name == "JsonKVStorage":
             from lightrag.kg.json_kv_impl import JsonKVStorage
+
             return JsonKVStorage
         elif storage_name == "RedisKVStorage":
             from lightrag.kg.redis_impl import RedisKVStorage
+
             return RedisKVStorage
         elif storage_name == "PGKVStorage":
             from lightrag.kg.postgres_impl import PGKVStorage
+
             return PGKVStorage
         elif storage_name == "MongoKVStorage":
             from lightrag.kg.mongo_impl import MongoKVStorage
+
             return MongoKVStorage
         else:
             raise ValueError(f"Unsupported storage type: {storage_name}")
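
Design note: the if/elif chain keeps each backend's driver import lazy, so only the chosen storage pulls in its dependency. A table-driven equivalent (module paths taken from the imports above; this helper name is hypothetical):

import importlib

# storage name -> (module path, class name); the import is deferred until lookup
_STORAGE_IMPLS = {
    "JsonKVStorage": ("lightrag.kg.json_kv_impl", "JsonKVStorage"),
    "RedisKVStorage": ("lightrag.kg.redis_impl", "RedisKVStorage"),
    "PGKVStorage": ("lightrag.kg.postgres_impl", "PGKVStorage"),
    "MongoKVStorage": ("lightrag.kg.mongo_impl", "MongoKVStorage"),
}


def get_storage_class(storage_name: str):
    try:
        module_path, class_name = _STORAGE_IMPLS[storage_name]
    except KeyError:
        raise ValueError(f"Unsupported storage type: {storage_name}") from None
    return getattr(importlib.import_module(module_path), class_name)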
@@ -197,11 +206,15 @@ class MigrationTool:
         async with storage._storage_lock:
             filtered = {}
             for key, value in storage._data.items():
-                if key.startswith("default:extract:") or key.startswith("default:summary:"):
+                if key.startswith("default:extract:") or key.startswith(
+                    "default:summary:"
+                ):
                     filtered[key] = value
             return filtered
 
-    async def get_default_caches_redis(self, storage, batch_size: int = 1000) -> Dict[str, Any]:
+    async def get_default_caches_redis(
+        self, storage, batch_size: int = 1000
+    ) -> Dict[str, Any]:
         """Get default caches from RedisKVStorage with pagination
 
         Args:
@@ -225,9 +238,7 @@ class MigrationTool:
             while True:
                 # SCAN already implements cursor-based pagination
                 cursor, keys = await redis.scan(
-                    cursor,
-                    match=prefixed_pattern,
-                    count=batch_size
+                    cursor, match=prefixed_pattern, count=batch_size
                 )
 
                 if keys:
@@ -240,23 +251,37 @@ class MigrationTool:
                             for key, value in zip(keys, values):
                                 if value:
-                                    key_str = key.decode() if isinstance(key, bytes) else key
+                                    key_str = (
+                                        key.decode() if isinstance(key, bytes) else key
+                                    )
                                     # Remove namespace prefix to get original key
-                                    original_key = key_str.replace(f"{storage.final_namespace}:", "", 1)
+                                    original_key = key_str.replace(
+                                        f"{storage.final_namespace}:", "", 1
+                                    )
                                     cache_data[original_key] = json.loads(value)
                     except Exception as e:
                         # Pipeline execution failed, fall back to individual gets
-                        print(f"⚠️ Pipeline execution failed for batch, using individual gets: {e}")
+                        print(
+                            f"⚠️ Pipeline execution failed for batch, using individual gets: {e}"
+                        )
                         for key in keys:
                             try:
                                 value = await redis.get(key)
                                 if value:
-                                    key_str = key.decode() if isinstance(key, bytes) else key
-                                    original_key = key_str.replace(f"{storage.final_namespace}:", "", 1)
+                                    key_str = (
+                                        key.decode()
+                                        if isinstance(key, bytes)
+                                        else key
+                                    )
+                                    original_key = key_str.replace(
+                                        f"{storage.final_namespace}:", "", 1
+                                    )
                                     cache_data[original_key] = json.loads(value)
                             except Exception as individual_error:
-                                print(f"⚠️ Failed to get individual key {key}: {individual_error}")
+                                print(
+                                    f"⚠️ Failed to get individual key {key}: {individual_error}"
+                                )
                                 continue
 
                 if cursor == 0:
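
Aside from line wrapping, the logic above is unchanged. For reference, a standalone sketch of the SCAN-plus-pipeline pattern this method implements, using redis.asyncio; the prefix and key pattern here are illustrative, not the tool's exact ones:

import json

import redis.asyncio as aioredis


async def scan_default_caches(
    redis_client: "aioredis.Redis", prefix: str, batch_size: int = 1000
) -> dict:
    """Cursor-based SCAN with one pipelined round trip of GETs per batch."""
    cache_data: dict = {}
    cursor = 0
    while True:
        # COUNT is a hint, not a guarantee, of how many keys SCAN returns
        cursor, keys = await redis_client.scan(
            cursor, match=f"{prefix}:default:*", count=batch_size
        )
        if keys:
            pipe = redis_client.pipeline()
            for key in keys:
                pipe.get(key)
            values = await pipe.execute()
            for key, value in zip(keys, values):
                if value:
                    key_str = key.decode() if isinstance(key, bytes) else key
                    cache_data[key_str.replace(f"{prefix}:", "", 1)] = json.loads(value)
        if cursor == 0:  # SCAN reports completion by returning cursor 0
            break
    return cache_data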
@@ -267,7 +292,9 @@ class MigrationTool:
         return cache_data
 
-    async def get_default_caches_pg(self, storage, batch_size: int = 1000) -> Dict[str, Any]:
+    async def get_default_caches_pg(
+        self, storage, batch_size: int = 1000
+    ) -> Dict[str, Any]:
         """Get default caches from PGKVStorage with pagination
 
         Args:
@@ -297,9 +324,7 @@ class MigrationTool:
             """
             results = await storage.db.query(
-                query,
-                [storage.workspace, batch_size, offset],
-                multirows=True
+                query, [storage.workspace, batch_size, offset], multirows=True
             )
 
             if not results:
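
The PG path drains rows LIMIT/OFFSET-style. A minimal sketch of that loop shape, assuming only the db.query call signature visible above; the SQL text and table name here are hypothetical, not the tool's actual schema:

async def fetch_paged(db, workspace: str, batch_size: int = 1000) -> list:
    """Collect all rows for a workspace in LIMIT/OFFSET pages."""
    # Hypothetical SQL; the real query text lives in the migration tool
    query = "SELECT id, value FROM kv_cache WHERE workspace = $1 LIMIT $2 OFFSET $3"
    rows: list = []
    offset = 0
    while True:
        results = await db.query(query, [workspace, batch_size, offset], multirows=True)
        if not results:  # an empty page means the table is exhausted
            break
        rows.extend(results)
        offset += batch_size
    return rows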
@@ -329,7 +354,9 @@ class MigrationTool:
         return cache_data
 
-    async def get_default_caches_mongo(self, storage, batch_size: int = 1000) -> Dict[str, Any]:
+    async def get_default_caches_mongo(
+        self, storage, batch_size: int = 1000
+    ) -> Dict[str, Any]:
         """Get default caches from MongoKVStorage with cursor-based pagination
 
         Args:
@@ -449,8 +476,7 @@ class MigrationTool:
         # Get storage type choice
         choice = self.get_user_choice(
-            f"Select {storage_type} storage type (1-4)",
-            list(STORAGE_TYPES.keys())
+            f"Select {storage_type} storage type (1-4)", list(STORAGE_TYPES.keys())
         )
 
         storage_name = STORAGE_TYPES[choice]
@@ -489,10 +515,7 @@ class MigrationTool:
         return storage, storage_name, workspace, cache_data
 
     async def migrate_caches(
-        self,
-        source_data: Dict[str, Any],
-        target_storage,
-        target_storage_name: str
+        self, source_data: Dict[str, Any], target_storage, target_storage_name: str
     ) -> MigrationStats:
         """Migrate caches in batches with error tracking
@@ -513,7 +536,9 @@ class MigrationTool:
         # Convert to list for batching
         items = list(source_data.items())
-        stats.total_batches = (stats.total_source_records + self.batch_size - 1) // self.batch_size
+        stats.total_batches = (
+            stats.total_source_records + self.batch_size - 1
+        ) // self.batch_size
 
         print("\n=== Starting Migration ===")
@@ -541,16 +566,20 @@ class MigrationTool:
                 filled_length = int(bar_length * end_idx // stats.total_source_records)
                 bar = "█" * filled_length + "░" * (bar_length - filled_length)
 
-                print(f"Batch {batch_idx + 1}/{stats.total_batches}: {bar} "
-                      f"{end_idx:,}/{stats.total_source_records:,} ({progress:.0f}%) - "
-                      f"default:{cache_type}")
+                print(
+                    f"Batch {batch_idx + 1}/{stats.total_batches}: {bar} "
+                    f"{end_idx:,}/{stats.total_source_records:,} ({progress:.0f}%) - "
+                    f"default:{cache_type}"
+                )
 
             except Exception as e:
                 # Error - record and continue
                 stats.add_error(batch_idx + 1, e, len(batch_data))
-                print(f"Batch {batch_idx + 1}/{stats.total_batches}: ✗ FAILED - "
-                      f"{type(e).__name__}: {str(e)}")
+                print(
+                    f"Batch {batch_idx + 1}/{stats.total_batches}: ✗ FAILED - "
+                    f"{type(e).__name__}: {str(e)}"
+                )
 
         # Final persist
         print("\nPersisting data to disk...")
@@ -583,7 +612,11 @@ class MigrationTool:
             print(f"  Failed to migrate: {stats.failed_records:,}")
 
         # Success rate
-        success_rate = (stats.successful_records / stats.total_source_records * 100) if stats.total_source_records > 0 else 0
+        success_rate = (
+            (stats.successful_records / stats.total_source_records * 100)
+            if stats.total_source_records > 0
+            else 0
+        )
         print(f"  Success rate: {success_rate:.2f}%")
 
         # Error details
@@ -595,7 +628,7 @@ class MigrationTool:
             # Group errors by type
             error_types = {}
             for error in stats.errors:
-                err_type = error['error_type']
+                err_type = error["error_type"]
                 error_types[err_type] = error_types.get(err_type, 0) + 1
 
             print("\nError Summary:")
@@ -633,7 +666,7 @@ class MigrationTool:
                 self.source_storage,
                 source_storage_name,
                 self.source_workspace,
-                source_data
+                source_data,
             ) = await self.setup_storage("Source")
 
             if not self.source_storage:
@@ -651,7 +684,7 @@ class MigrationTool:
                 self.target_storage,
                 target_storage_name,
                 self.target_workspace,
-                target_data
+                target_data,
             ) = await self.setup_storage("Target")
 
             if not self.target_storage:
@@ -664,17 +697,23 @@ class MigrationTool:
             print("\n" + "=" * 50)
             print("Migration Confirmation")
             print("=" * 50)
-            print(f"Source: {source_storage_name} (workspace: {self.source_workspace if self.source_workspace else '(default)'}) - {len(source_data):,} records")
-            print(f"Target: {target_storage_name} (workspace: {self.target_workspace if self.target_workspace else '(default)'}) - {len(target_data):,} records")
+            print(
+                f"Source: {source_storage_name} (workspace: {self.source_workspace if self.source_workspace else '(default)'}) - {len(source_data):,} records"
+            )
+            print(
+                f"Target: {target_storage_name} (workspace: {self.target_workspace if self.target_workspace else '(default)'}) - {len(target_data):,} records"
+            )
             print(f"Batch Size: {self.batch_size:,} records/batch")
 
             if target_data:
-                print(f"\n⚠ Warning: Target storage already has {len(target_data):,} records")
+                print(
+                    f"\n⚠ Warning: Target storage already has {len(target_data):,} records"
+                )
                 print("Migration will overwrite records with the same keys")
 
             # Confirm migration
             confirm = input("\nContinue? (y/n): ").strip().lower()
-            if confirm != 'y':
+            if confirm != "y":
                 print("\n✗ Migration cancelled")
                 # Cleanup
                 await self.source_storage.finalize()
@@ -682,7 +721,9 @@ class MigrationTool:
                 return
 
             # Perform migration with error tracking
-            stats = await self.migrate_caches(source_data, self.target_storage, target_storage_name)
+            stats = await self.migrate_caches(
+                source_data, self.target_storage, target_storage_name
+            )
 
             # Print comprehensive migration report
             self.print_migration_report(stats)
@@ -696,6 +737,7 @@ class MigrationTool:
         except Exception as e:
             print(f"\n✗ Migration failed: {e}")
             import traceback
+
             traceback.print_exc()
         finally:
             # Ensure cleanup