diff --git a/lightrag/utils.py b/lightrag/utils.py
index 65c1e4bc..064e4804 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1,8 +1,6 @@
 from __future__ import annotations
 
 import weakref
-import sys
-
 import asyncio
 import html
 import csv
@@ -42,35 +40,6 @@ from lightrag.constants import (
     SOURCE_IDS_LIMIT_METHOD_FIFO,
 )
 
-# Precompile regex pattern for JSON sanitization (module-level, compiled once)
-_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
-
-
-class SafeStreamHandler(logging.StreamHandler):
-    """StreamHandler that gracefully handles closed streams during shutdown.
-
-    This handler prevents "ValueError: I/O operation on closed file" errors
-    that can occur when pytest or other test frameworks close stdout/stderr
-    before Python's logging cleanup runs.
-    """
-
-    def flush(self):
-        """Flush the stream, ignoring errors if the stream is closed."""
-        try:
-            super().flush()
-        except (ValueError, OSError):
-            # Stream is closed or otherwise unavailable, silently ignore
-            pass
-
-    def close(self):
-        """Close the handler, ignoring errors if the stream is already closed."""
-        try:
-            super().close()
-        except (ValueError, OSError):
-            # Stream is closed or otherwise unavailable, silently ignore
-            pass
-
-
 # Initialize logger with basic configuration
 logger = logging.getLogger("lightrag")
 logger.propagate = False  # prevent log message send to root logger
@@ -78,7 +47,7 @@ logger.setLevel(logging.INFO)
 
 # Add console handler if no handlers exist
 if not logger.handlers:
-    console_handler = SafeStreamHandler()
+    console_handler = logging.StreamHandler()
     console_handler.setLevel(logging.INFO)
     formatter = logging.Formatter("%(levelname)s: %(message)s")
     console_handler.setFormatter(formatter)
@@ -87,33 +56,6 @@ if not logger.handlers:
 # Set httpx logging level to WARNING
 logging.getLogger("httpx").setLevel(logging.WARNING)
 
-
-def _patch_ascii_colors_console_handler() -> None:
-    """Prevent ascii_colors from printing flush errors during interpreter exit."""
-
-    try:
-        from ascii_colors import ConsoleHandler
-    except ImportError:
-        return
-
-    if getattr(ConsoleHandler, "_lightrag_patched", False):
-        return
-
-    original_handle_error = ConsoleHandler.handle_error
-
-    def _safe_handle_error(self, message: str) -> None:  # type: ignore[override]
-        exc_type, _, _ = sys.exc_info()
-        if exc_type in (ValueError, OSError) and "close" in message.lower():
-            return
-        original_handle_error(self, message)
-
-    ConsoleHandler.handle_error = _safe_handle_error  # type: ignore[assignment]
-    ConsoleHandler._lightrag_patched = True  # type: ignore[attr-defined]
-
-
-_patch_ascii_colors_console_handler()
-
-
 # Global import for pypinyin with startup-time logging
 try:
     import pypinyin
@@ -341,8 +283,8 @@ def setup_logger(
     logger_instance.handlers = []  # Clear existing handlers
     logger_instance.propagate = False
 
-    # Add console handler with safe stream handling
-    console_handler = SafeStreamHandler()
+    # Add console handler
+    console_handler = logging.StreamHandler()
     console_handler.setFormatter(simple_formatter)
     console_handler.setLevel(level)
     logger_instance.addHandler(console_handler)
@@ -408,20 +350,9 @@ class TaskState:
 
 @dataclass
 class EmbeddingFunc:
-    """Embedding function wrapper with dimension validation
-    This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
-    This class should not be wrapped multiple times.
-
-    Args:
-        embedding_dim: Expected dimension of the embeddings
-        func: The actual embedding function to wrap
-        max_token_size: Optional token limit for the embedding model
-        send_dimensions: Whether to inject embedding_dim as a keyword argument
-    """
-
     embedding_dim: int
     func: callable
-    max_token_size: int | None = None  # Token limit for the embedding model
+    max_token_size: int | None = None  # deprecated, kept for compatibility only
     send_dimensions: bool = (
        False  # Control whether to send embedding_dim to the function
    )
@@ -445,32 +376,7 @@ class EmbeddingFunc:
                 # Inject embedding_dim from decorator
                 kwargs["embedding_dim"] = self.embedding_dim
 
-        # Call the actual embedding function
-        result = await self.func(*args, **kwargs)
-
-        # Validate embedding dimensions using total element count
-        total_elements = result.size  # Total number of elements in the numpy array
-        expected_dim = self.embedding_dim
-
-        # Check if total elements can be evenly divided by embedding_dim
-        if total_elements % expected_dim != 0:
-            raise ValueError(
-                f"Embedding dimension mismatch detected: "
-                f"total elements ({total_elements}) cannot be evenly divided by "
-                f"expected dimension ({expected_dim}). "
-            )
-
-        # Optional: Verify vector count matches input text count
-        actual_vectors = total_elements // expected_dim
-        if args and isinstance(args[0], (list, tuple)):
-            expected_vectors = len(args[0])
-            if actual_vectors != expected_vectors:
-                raise ValueError(
-                    f"Vector count mismatch: "
-                    f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
-                )
-
-        return result
+        return await self.func(*args, **kwargs)
 
 
 def compute_args_hash(*args: Any) -> str:
@@ -1005,76 +911,7 @@ def priority_limit_async_func_call(
 
 
 def wrap_embedding_func_with_attrs(**kwargs):
-    """Decorator to add embedding dimension and token limit attributes to embedding functions.
-
-    This decorator wraps an async embedding function and returns an EmbeddingFunc instance
-    that automatically handles dimension parameter injection and attribute management.
-
-    WARNING: DO NOT apply this decorator to wrapper functions that call other
-    decorated embedding functions. This will cause double decoration and parameter
-    injection conflicts.
-
-    Correct usage patterns:
-
-    1. Direct implementation (decorated):
-    ```python
-    @wrap_embedding_func_with_attrs(embedding_dim=1536)
-    async def my_embed(texts, embedding_dim=None):
-        # Direct implementation
-        return embeddings
-    ```
-
-    2. Wrapper calling decorated function (DO NOT decorate wrapper):
-    ```python
-    # my_embed is already decorated above
-
-    async def my_wrapper(texts, **kwargs):  # ❌ DO NOT decorate this!
-        # Must call .func to access unwrapped implementation
-        return await my_embed.func(texts, **kwargs)
-    ```
-
-    3. Wrapper calling decorated function (properly decorated):
-    ```python
-    @wrap_embedding_func_with_attrs(embedding_dim=1536)
-    async def my_wrapper(texts, **kwargs):  # ✅ Can decorate if calling .func
-        # Calling .func avoids double decoration
-        return await my_embed.func(texts, **kwargs)
-    ```
-
-    The decorated function becomes an EmbeddingFunc instance with:
-    - embedding_dim: The embedding dimension
-    - max_token_size: Maximum token limit (optional)
-    - func: The original unwrapped function (access via .func)
-    - __call__: Wrapper that injects embedding_dim parameter
-
-    Double decoration causes:
-    - Double injection of embedding_dim parameter
-    - Incorrect parameter passing to the underlying implementation
-    - Runtime errors due to parameter conflicts
-
-    Args:
-        embedding_dim: The dimension of embedding vectors
-        max_token_size: Maximum number of tokens (optional)
-        send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
-
-    Returns:
-        A decorator that wraps the function as an EmbeddingFunc instance
-
-    Example of correct wrapper implementation:
-    ```python
-    @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
-    @retry(...)
-    async def openai_embed(texts, ...):
-        # Base implementation
-        pass
-
-    @wrap_embedding_func_with_attrs(embedding_dim=1536)  # Note: No @retry here!
-    async def azure_openai_embed(texts, ...):
-        # CRITICAL: Call .func to access unwrapped function
-        return await openai_embed.func(texts, ...)  # ✅ Correct
-        # return await openai_embed(texts, ...)  # ❌ Wrong - double decoration!
-    ```
-    """
+    """Wrap a function with attributes"""
 
     def final_decro(func) -> EmbeddingFunc:
         new_func = EmbeddingFunc(**kwargs, func=func)
@@ -1090,123 +927,30 @@ def load_json(file_name):
         return json.load(f)
 
 
-def _sanitize_string_for_json(text: str) -> str:
-    """Remove characters that cannot be encoded in UTF-8 for JSON serialization.
-
-    Uses regex for optimal performance with zero-copy optimization for clean strings.
-    Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
+def _sanitize_json_data(data: Any) -> Any:
+    """Recursively sanitize all string values in data structure for safe UTF-8 encoding
 
     Args:
-        text: String to sanitize
+        data: Data to sanitize (dict, list, str, or other types)
 
     Returns:
-        Original string if clean (zero-copy), sanitized string if dirty
+        Sanitized data with all strings cleaned of problematic characters
     """
-    if not text:
-        return text
-
-    # Fast path: Check if sanitization is needed using C-level regex search
-    if not _SURROGATE_PATTERN.search(text):
-        return text  # Zero-copy for clean strings - most common case
-
-    # Slow path: Remove problematic characters using C-level regex substitution
-    return _SURROGATE_PATTERN.sub("", text)
-
-
-class SanitizingJSONEncoder(json.JSONEncoder):
-    """
-    Custom JSON encoder that sanitizes data during serialization.
-
-    This encoder cleans strings during the encoding process without creating
-    a full copy of the data structure, making it memory-efficient for large datasets.
-    """
-
-    def encode(self, o):
-        """Override encode method to handle simple string cases"""
-        if isinstance(o, str):
-            return json.encoder.encode_basestring(_sanitize_string_for_json(o))
-        return super().encode(o)
-
-    def iterencode(self, o, _one_shot=False):
-        """
-        Override iterencode to sanitize strings during serialization.
-        This is the core method that handles complex nested structures.
- """ - # Preprocess: sanitize all strings in the object - sanitized = self._sanitize_for_encoding(o) - - # Call parent's iterencode with sanitized data - for chunk in super().iterencode(sanitized, _one_shot): - yield chunk - - def _sanitize_for_encoding(self, obj): - """ - Recursively sanitize strings in an object. - Creates new objects only when necessary to avoid deep copies. - - Args: - obj: Object to sanitize - - Returns: - Sanitized object with cleaned strings - """ - if isinstance(obj, str): - return _sanitize_string_for_json(obj) - - elif isinstance(obj, dict): - # Create new dict with sanitized keys and values - new_dict = {} - for k, v in obj.items(): - clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k - clean_v = self._sanitize_for_encoding(v) - new_dict[clean_k] = clean_v - return new_dict - - elif isinstance(obj, (list, tuple)): - # Sanitize list/tuple elements - cleaned = [self._sanitize_for_encoding(item) for item in obj] - return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned - - else: - # Numbers, booleans, None, etc. remain unchanged - return obj + if isinstance(data, dict): + return {k: _sanitize_json_data(v) for k, v in data.items()} + elif isinstance(data, list): + return [_sanitize_json_data(item) for item in data] + elif isinstance(data, str): + return sanitize_text_for_encoding(data, replacement_char="") + else: + return data def write_json(json_obj, file_name): - """ - Write JSON data to file with optimized sanitization strategy. - - This function uses a two-stage approach: - 1. Fast path: Try direct serialization (works for clean data ~99% of time) - 2. Slow path: Use custom encoder that sanitizes during serialization - - The custom encoder approach avoids creating a deep copy of the data, - making it memory-efficient. When sanitization occurs, the caller should - reload the cleaned data from the file to update shared memory. - - Args: - json_obj: Object to serialize (may be a shallow copy from shared memory) - file_name: Output file path - - Returns: - bool: True if sanitization was applied (caller should reload data), - False if direct write succeeded (no reload needed) - """ - try: - # Strategy 1: Fast path - try direct serialization - with open(file_name, "w", encoding="utf-8") as f: - json.dump(json_obj, f, indent=2, ensure_ascii=False) - return False # No sanitization needed, no reload required - - except (UnicodeEncodeError, UnicodeDecodeError) as e: - logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}") - - # Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy) + # Sanitize data before writing to prevent UTF-8 encoding errors + sanitized_obj = _sanitize_json_data(json_obj) with open(file_name, "w", encoding="utf-8") as f: - json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder) - - logger.info(f"JSON sanitization applied during write: {file_name}") - return True # Sanitization applied, reload recommended + json.dump(sanitized_obj, f, indent=2, ensure_ascii=False) class TokenizerInterface(Protocol):