This commit is contained in:
Safi 2025-12-12 10:15:33 +08:00 committed by GitHub
commit cc70c1d9fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -5,6 +5,7 @@ import asyncio
import configparser import configparser
import inspect import inspect
import os import os
import json
import time import time
import warnings import warnings
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
@ -117,6 +118,7 @@ from dotenv import load_dotenv
# use the .env that is inside the current folder # use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance # allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file # the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False) load_dotenv(dotenv_path=".env", override=False)
@ -525,7 +527,7 @@ class LightRAG:
logger.debug(f"LightRAG init with param:\n {_print_config}\n") logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init Embedding # Init Embedding
# Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes) # Step 1: Capture max_token_size and embedding_dim before applying decorator
embedding_max_token_size = None embedding_max_token_size = None
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"): if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
embedding_max_token_size = self.embedding_func.max_token_size embedding_max_token_size = self.embedding_func.max_token_size
@ -534,6 +536,17 @@ class LightRAG:
) )
self.embedding_token_limit = embedding_max_token_size self.embedding_token_limit = embedding_max_token_size
# Capture embedding model name before decoration so we don't lose it to wrappers
self.embedding_model_name = (
self.embedding_func.__class__.__name__ if self.embedding_func else "unknown"
)
# --- CAPTURE EMBEDDING DIMENSION (NEW) ---
self.embedding_dim = None
if self.embedding_func and hasattr(self.embedding_func, "embedding_dim"):
self.embedding_dim = self.embedding_func.embedding_dim
# -----------------------------------------
# Step 2: Apply priority wrapper decorator # Step 2: Apply priority wrapper decorator
self.embedding_func = priority_limit_async_func_call( self.embedding_func = priority_limit_async_func_call(
self.embedding_func_max_async, self.embedding_func_max_async,
@ -658,8 +671,56 @@ class LightRAG:
self._storages_status = StoragesStatus.CREATED self._storages_status = StoragesStatus.CREATED
def _check_embedding_config(self):
    """Validate the embedding dimension against the one recorded on disk.

    Reads ``lightrag_meta.json`` from ``self.working_dir``. If the file
    records an ``embedding_dim`` that differs from ``self.embedding_dim``,
    a ``ValueError`` is raised so that data produced with one dimension is
    not mixed with another. If the file does not exist yet, it is created
    with the current dimension, the embedding model name and a UTC
    creation timestamp.

    The check is skipped (nothing is read or written) when
    ``self.embedding_dim`` is ``None``, and skipped with a warning when the
    metadata file cannot be parsed or does not contain a JSON object.

    Raises:
        ValueError: The recorded embedding dimension does not match
            ``self.embedding_dim``.
    """
    if self.embedding_dim is None:
        # Dimension could not be captured; do not block custom models.
        return

    meta_file = os.path.join(self.working_dir, "lightrag_meta.json")

    # EAFP: open directly instead of exists()+open() so a file appearing or
    # disappearing between the two calls cannot cause a spurious failure.
    try:
        with open(meta_file, "r", encoding="utf-8") as f:
            meta_data = json.load(f)
    except FileNotFoundError:
        meta_data = None
    except json.JSONDecodeError:
        logger.warning(f"Could not parse {meta_file}. Skipping dimension check.")
        return

    if meta_data is not None:
        if not isinstance(meta_data, dict):
            # Valid JSON but not an object (e.g. a list or null): there is
            # no recorded dimension to compare against.
            logger.warning(
                f"Unexpected content in {meta_file}. Skipping dimension check."
            )
            return

        saved_dim = meta_data.get("embedding_dim")
        saved_model = meta_data.get("embedding_model_func", "unknown")
        # A missing/empty recorded dimension is treated as "nothing to check".
        if saved_dim and saved_dim != self.embedding_dim:
            raise ValueError(
                f"Embedding dimension mismatch! "
                f"Existing data uses dimension {saved_dim} (Model: {saved_model}), "
                f"but current configuration uses {self.embedding_dim}. "
                f"Please clear the '{self.working_dir}' directory or switch back to the original model."
            )
        return

    # First run: record the current configuration.
    new_meta = {
        "embedding_dim": self.embedding_dim,
        "embedding_model_func": self.embedding_model_name,
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    # exist_ok avoids failing if the directory is created concurrently.
    os.makedirs(self.working_dir, exist_ok=True)

    # Write to a temporary file and rename it into place so an interrupted
    # write cannot leave a truncated metadata file, which would otherwise
    # make every later run skip the check.
    tmp_file = f"{meta_file}.{os.getpid()}.tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(new_meta, f, indent=4)
    os.replace(tmp_file, meta_file)
async def initialize_storages(self): async def initialize_storages(self):
"""Storage initialization must be called one by one to prevent deadlock""" """Storage initialization must be called one by one to prevent deadlock"""
# --- NEW SAFETY CHECK CALL ---
self._check_embedding_config()
# -----------------------------
if self._storages_status == StoragesStatus.CREATED: if self._storages_status == StoragesStatus.CREATED:
# Set the first initialized workspace will set the default workspace # Set the first initialized workspace will set the default workspace
# Allows namespace operation without specifying workspace for backward compatibility # Allows namespace operation without specifying workspace for backward compatibility