From 0c463709407d06eaac9bea704b2f63deef98f7b9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?=
Date: Thu, 4 Dec 2025 19:19:04 +0800
Subject: [PATCH] cherry-pick 90f52acf

---
 lightrag/utils.py | 164 +++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 160 insertions(+), 4 deletions(-)

diff --git a/lightrag/utils.py b/lightrag/utils.py
index 460ede3c..8c9b7776 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -56,6 +56,9 @@ if not logger.handlers:
 # Set httpx logging level to WARNING
 logging.getLogger("httpx").setLevel(logging.WARNING)
 
+# Precompile regex pattern for JSON sanitization (module-level, compiled once)
+_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
+
 # Global import for pypinyin with startup-time logging
 try:
     import pypinyin
@@ -350,9 +353,20 @@ class TaskState:
 
 @dataclass
 class EmbeddingFunc:
+    """Embedding function wrapper with dimension validation.
+
+    This class wraps an embedding function to ensure that the output embeddings
+    have the correct dimension. It should not be wrapped multiple times.
+
+    Args:
+        embedding_dim: Expected dimension of the embeddings
+        func: The actual embedding function to wrap
+        max_token_size: Optional token limit for the embedding model
+        send_dimensions: Whether to inject embedding_dim as a keyword argument
+    """
     embedding_dim: int
     func: callable
-    max_token_size: int | None = None  # deprecated keep it for compatible only
+    max_token_size: int | None = None  # Token limit for the embedding model
     send_dimensions: bool = (
         False  # Control whether to send embedding_dim to the function
     )
@@ -376,7 +390,32 @@ class EmbeddingFunc:
             # Inject embedding_dim from decorator
             kwargs["embedding_dim"] = self.embedding_dim
 
-        return await self.func(*args, **kwargs)
+        # Call the actual embedding function
+        result = await self.func(*args, **kwargs)
+
+        # Validate embedding dimensions using the total element count
+        total_elements = result.size  # Total number of elements in the numpy array
+        expected_dim = self.embedding_dim
+
+        # Check whether the total element count is evenly divisible by embedding_dim
+        if total_elements % expected_dim != 0:
+            raise ValueError(
+                f"Embedding dimension mismatch detected: "
+                f"total elements ({total_elements}) cannot be evenly divided by "
+                f"expected dimension ({expected_dim})."
+            )
+
+        # Optional: verify the vector count matches the input text count
+        actual_vectors = total_elements // expected_dim
+        if args and isinstance(args[0], (list, tuple)):
+            expected_vectors = len(args[0])
+            if actual_vectors != expected_vectors:
+                raise ValueError(
+                    f"Vector count mismatch: expected {expected_vectors} vectors "
+                    f"but got {actual_vectors} from the embedding result."
+                )
+
+        return result
 
 
 def compute_args_hash(*args: Any) -> str:
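To illustrate what the new check in `EmbeddingFunc.__call__` rejects, here is a minimal sketch, assuming the patched module is importable as `lightrag.utils`; the `fake_embed` function, `main`, and the dimensions are hypothetical. A backend that returns 768-dim vectors while the wrapper declares `embedding_dim=1024` fails the divisibility test:

```python
import asyncio

import numpy as np

from lightrag.utils import EmbeddingFunc  # assumes the patched module is installed


# Hypothetical misconfigured backend: returns 768-dim vectors
# while the wrapper below is declared with embedding_dim=1024.
async def fake_embed(texts: list[str], **kwargs) -> np.ndarray:
    return np.random.rand(len(texts), 768).astype(np.float32)


async def main() -> None:
    wrapped = EmbeddingFunc(embedding_dim=1024, func=fake_embed)
    try:
        # 2 * 768 = 1536 elements, not divisible by 1024
        await wrapped(["hello", "world"])
    except ValueError as e:
        print(f"caught: {e}")  # "Embedding dimension mismatch detected: ..."


asyncio.run(main())
```

Because the check works on `result.size` rather than the array shape, it also covers backends that return a flattened array; the follow-up vector-count comparison then reports the separate case where the backend silently drops or duplicates an input text.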
@@ -927,9 +966,126 @@ def load_json(file_name):
         return json.load(f)
 
 
+def _sanitize_string_for_json(text: str) -> str:
+    """Remove characters that cannot be encoded in UTF-8 for JSON serialization.
+
+    Uses a precompiled regex with a fast, zero-copy detection path for clean
+    strings (the vast majority of cases) and efficient removal for dirty ones.
+
+    Args:
+        text: String to sanitize
+
+    Returns:
+        Original string if clean (zero-copy), sanitized string if dirty
+    """
+    if not text:
+        return text
+
+    # Fast path: check if sanitization is needed using a C-level regex search
+    if not _SURROGATE_PATTERN.search(text):
+        return text  # Zero-copy for clean strings - the most common case
+
+    # Slow path: remove problematic characters using C-level regex substitution
+    return _SURROGATE_PATTERN.sub("", text)
+
+
+class SanitizingJSONEncoder(json.JSONEncoder):
+    """
+    Custom JSON encoder that sanitizes data during serialization.
+
+    This encoder cleans strings during the encoding process without creating
+    a full copy of the data structure, making it memory-efficient for large
+    datasets.
+    """
+
+    def encode(self, o):
+        """Override encode to handle the simple top-level string case"""
+        if isinstance(o, str):
+            return json.encoder.encode_basestring(_sanitize_string_for_json(o))
+        return super().encode(o)
+
+    def iterencode(self, o, _one_shot=False):
+        """
+        Override iterencode to sanitize strings during serialization.
+        This is the core method that handles complex nested structures.
+        """
+        # Preprocess: sanitize all strings in the object
+        sanitized = self._sanitize_for_encoding(o)
+
+        # Call the parent's iterencode with the sanitized data
+        yield from super().iterencode(sanitized, _one_shot)
+
+    def _sanitize_for_encoding(self, obj):
+        """
+        Recursively sanitize strings in an object.
+        Creates new objects only when necessary to avoid deep copies.
+
+        Args:
+            obj: Object to sanitize
+
+        Returns:
+            Sanitized object with cleaned strings
+        """
+        if isinstance(obj, str):
+            return _sanitize_string_for_json(obj)
+
+        elif isinstance(obj, dict):
+            # Create a new dict with sanitized keys and values
+            new_dict = {}
+            for k, v in obj.items():
+                clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
+                clean_v = self._sanitize_for_encoding(v)
+                new_dict[clean_k] = clean_v
+            return new_dict
+
+        elif isinstance(obj, (list, tuple)):
+            # Sanitize list/tuple elements
+            cleaned = [self._sanitize_for_encoding(item) for item in obj]
+            return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
+
+        else:
+            # Numbers, booleans, None, etc. remain unchanged
+            return obj
+
+
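The character class behind `_SURROGATE_PATTERN` covers lone UTF-16 surrogates (U+D800–U+DFFF) plus the non-characters U+FFFE and U+FFFF, which a Python `str` can hold but which break UTF-8 output. A small self-contained illustration of the fast and slow paths, reusing the same character class as the patch (the local name `pattern` is mine):

```python
import re

# Same character class as _SURROGATE_PATTERN in the patch.
pattern = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")

clean = "plain text"
dirty = "ok\ud800text\uffff"  # lone surrogate + non-character

# Fast path: clean strings never match, so they are returned unchanged.
print(pattern.search(clean) is None)  # True

# Slow path: offending characters are stripped.
print(pattern.sub("", dirty))  # "oktext"

# Why this matters: a lone surrogate cannot be encoded to UTF-8 at all,
# which is exactly the failure write_json's fast path runs into.
try:
    dirty.encode("utf-8")
except UnicodeEncodeError as e:
    print(f"unencodable: {e.reason}")  # "surrogates not allowed"
```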
 def write_json(json_obj, file_name):
-    with open(file_name, "w", encoding="utf-8") as f:
-        json.dump(json_obj, f, indent=2, ensure_ascii=False)
+    """
+    Write JSON data to a file with an optimized sanitization strategy.
+
+    This function uses a two-stage approach:
+    1. Fast path: try direct serialization (works for clean data ~99% of the time)
+    2. Slow path: use the custom encoder that sanitizes during serialization
+
+    The custom-encoder approach avoids creating a deep copy of the data,
+    making it memory-efficient. When sanitization occurs, the caller should
+    reload the cleaned data from the file to update shared memory.
+
+    Args:
+        json_obj: Object to serialize (may be a shallow copy from shared memory)
+        file_name: Output file path
+
+    Returns:
+        bool: True if sanitization was applied (caller should reload data),
+            False if the direct write succeeded (no reload needed)
+    """
+    try:
+        # Strategy 1: Fast path - try direct serialization
+        with open(file_name, "w", encoding="utf-8") as f:
+            json.dump(json_obj, f, indent=2, ensure_ascii=False)
+        return False  # No sanitization needed, no reload required
+
+    except (UnicodeEncodeError, UnicodeDecodeError) as e:
+        logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
+
+        # Strategy 2: sanitize during serialization with the custom encoder
+        # (no deep copy of the data is made)
+        with open(file_name, "w", encoding="utf-8") as f:
+            json.dump(
+                json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder
+            )
+
+        logger.info(f"JSON sanitization applied during write: {file_name}")
+        return True  # Sanitization applied, reload recommended
 
 
 class TokenizerInterface(Protocol):
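Downstream, callers are expected to honor the boolean contract of `write_json`. A minimal sketch of that flow, assuming `write_json` and `load_json` are imported from the patched `lightrag.utils`; the file path and data are hypothetical:

```python
from lightrag.utils import load_json, write_json

# Hypothetical KV-store snapshot; the second value carries a lone surrogate,
# so the fast path raises UnicodeEncodeError and the sanitizing path runs.
data = {"entity": "café", "note": "ok\ud800"}

needs_reload = write_json(data, "kv_store.json")
if needs_reload:
    # Sanitization stripped characters during the write, so the in-memory
    # copy no longer matches the file; reload to stay consistent.
    data = load_json("kv_store.json")
```

Reopening the file with mode "w" on the slow path also discards any partial output the failed fast-path write may have left behind, so the file ends up containing only the sanitized document.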