feat(core): implement safety check for embedding dimension mismatch

This commit is contained in:
captainmirk 2025-12-03 00:20:40 +00:00
parent 6476021619
commit 896e203574

View file

@ -5,6 +5,7 @@ import asyncio
import configparser
import inspect
import os
import json
import time
import warnings
from dataclasses import asdict, dataclass, field
@ -524,8 +525,8 @@ class LightRAG:
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init Embedding
# Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
# Init Embedding
# Step 1: Capture max_token_size and embedding_dim before applying decorator
embedding_max_token_size = None
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
embedding_max_token_size = self.embedding_func.max_token_size
@ -534,6 +535,12 @@ class LightRAG:
)
self.embedding_token_limit = embedding_max_token_size
# --- CAPTURE EMBEDDING DIMENSION (NEW) ---
self.embedding_dim = None
if self.embedding_func and hasattr(self.embedding_func, "embedding_dim"):
self.embedding_dim = self.embedding_func.embedding_dim
# -----------------------------------------
# Step 2: Apply priority wrapper decorator
self.embedding_func = priority_limit_async_func_call(
self.embedding_func_max_async,
@ -658,8 +665,53 @@ class LightRAG:
self._storages_status = StoragesStatus.CREATED
def _check_embedding_config(self):
"""
Validates that the current embedding dimension matches the existing data.
Prevents data corruption when switching models without clearing storage.
"""
if self.embedding_dim is None:
# If we couldn't capture dimensions, skip the check to avoid blocking valid custom models
return
meta_file = os.path.join(self.working_dir, "lightrag_meta.json")
if os.path.exists(meta_file):
with open(meta_file, "r", encoding="utf-8") as f:
try:
meta_data = json.load(f)
saved_dim = meta_data.get("embedding_dim")
saved_model = meta_data.get("embedding_model_func", "unknown")
if saved_dim and saved_dim != self.embedding_dim:
raise ValueError(
f"Embedding dimension mismatch! "
f"Existing data uses dimension {saved_dim} (Model: {saved_model}), "
f"but current configuration uses {self.embedding_dim}. "
f"Please clear the '{self.working_dir}' directory or switch back to the original model."
)
except json.JSONDecodeError:
logger.warning(f"Could not parse {meta_file}. Skipping dimension check.")
else:
# First run: Save the metadata
meta_data = {
"embedding_dim": self.embedding_dim,
"embedding_model_func": self.embedding_func.__class__.__name__ if self.embedding_func else "unknown",
"created_at": str(os.path.abspath(self.working_dir))
}
# Ensure directory exists
if not os.path.exists(self.working_dir):
os.makedirs(self.working_dir)
with open(meta_file, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=4)
async def initialize_storages(self):
"""Storage initialization must be called one by one to prevent deadlock"""
# --- NEW SAFETY CHECK CALL ---
self._check_embedding_config()
# -----------------------------
if self._storages_status == StoragesStatus.CREATED:
# Set the first initialized workspace will set the default workspace
# Allows namespace operation without specifying workspace for backward compatibility