diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index af700393..e2f9a3d7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -85,7 +85,7 @@ from .utils import ( lazy_external_import, priority_limit_async_func_call, get_content_summary, - clean_text, + sanitize_text_for_encoding, check_storage_env_vars, generate_track_id, logger, @@ -908,8 +908,8 @@ class LightRAG: update_storage = False try: # Clean input texts - full_text = clean_text(full_text) - text_chunks = [clean_text(chunk) for chunk in text_chunks] + full_text = sanitize_text_for_encoding(full_text) + text_chunks = [sanitize_text_for_encoding(chunk) for chunk in text_chunks] file_path = "" # Process cleaned texts @@ -1020,7 +1020,7 @@ class LightRAG: # Generate contents dict and remove duplicates in one pass unique_contents = {} for id_, doc, path in zip(ids, input, file_paths): - cleaned_content = clean_text(doc) + cleaned_content = sanitize_text_for_encoding(doc) if cleaned_content not in unique_contents: unique_contents[cleaned_content] = (id_, path) @@ -1033,7 +1033,7 @@ class LightRAG: # Clean input text and remove duplicates in one pass unique_content_with_paths = {} for doc, path in zip(input, file_paths): - cleaned_content = clean_text(doc) + cleaned_content = sanitize_text_for_encoding(doc) if cleaned_content not in unique_content_with_paths: unique_content_with_paths[cleaned_content] = path @@ -1817,7 +1817,7 @@ class LightRAG: all_chunks_data: dict[str, dict[str, str]] = {} chunk_to_source_map: dict[str, str] = {} for chunk_data in custom_kg.get("chunks", []): - chunk_content = clean_text(chunk_data["content"]) + chunk_content = sanitize_text_for_encoding(chunk_data["content"]) source_id = chunk_data["source_id"] file_path = chunk_data.get("file_path", "custom_kg") tokens = len(self.tokenizer.encode(chunk_content)) diff --git a/lightrag/utils.py b/lightrag/utils.py index bbafd9f1..979517b5 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -272,19 +272,26 @@ def compute_args_hash(*args: Any) -> str: Returns: str: Hash string """ - import hashlib - # Convert all arguments to strings and join them args_str = "".join([str(arg) for arg in args]) # Use 'replace' error handling to safely encode problematic Unicode characters # This replaces invalid characters with Unicode replacement character (U+FFFD) try: - return hashlib.md5(args_str.encode("utf-8")).hexdigest() + return md5(args_str.encode("utf-8")).hexdigest() except UnicodeEncodeError: # Handle surrogate characters and other encoding issues safe_bytes = args_str.encode("utf-8", errors="replace") - return hashlib.md5(safe_bytes).hexdigest() + return md5(safe_bytes).hexdigest() + + +def compute_mdhash_id(content: str, prefix: str = "") -> str: + """ + Compute a unique ID for a given content string. + + The ID is a combination of the given prefix and the MD5 hash of the content string. + """ + return prefix + compute_args_hash(content) def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str: @@ -316,15 +323,6 @@ def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None: return None -def compute_mdhash_id(content: str, prefix: str = "") -> str: - """ - Compute a unique ID for a given content string. - - The ID is a combination of the given prefix and the MD5 hash of the content string. - """ - return prefix + md5(content.encode()).hexdigest() - - # Custom exception class class QueueFullError(Exception): """Raised when the queue is full and the wait times out""" @@ -1402,11 +1400,13 @@ async def use_llm_func_with_cache( chunk_id: str | None = None, cache_keys_collector: list = None, ) -> str: - """Call LLM function with cache support + """Call LLM function with cache support and text sanitization If cache is available and enabled (determined by handle_cache based on mode), retrieve result from cache; otherwise call LLM function and save result to cache. + This function applies text sanitization to prevent UTF-8 encoding errors for all LLM providers. + Args: input_text: Input text to send to LLM use_llm_func: LLM function with higher priority @@ -1421,12 +1421,25 @@ async def use_llm_func_with_cache( Returns: LLM response text """ + # Sanitize input text to prevent UTF-8 encoding errors for all LLM providers + safe_input_text = sanitize_text_for_encoding(input_text) + + # Sanitize history messages if provided + safe_history_messages = None + if history_messages: + safe_history_messages = [] + for i, msg in enumerate(history_messages): + safe_msg = msg.copy() + if "content" in safe_msg: + safe_msg["content"] = sanitize_text_for_encoding(safe_msg["content"]) + safe_history_messages.append(safe_msg) + if llm_response_cache: - if history_messages: - history = json.dumps(history_messages, ensure_ascii=False) - _prompt = history + "\n" + input_text + if safe_history_messages: + history = json.dumps(safe_history_messages, ensure_ascii=False) + _prompt = history + "\n" + safe_input_text else: - _prompt = input_text + _prompt = safe_input_text arg_hash = compute_args_hash(_prompt) # Generate cache key for this LLM call @@ -1450,14 +1463,14 @@ async def use_llm_func_with_cache( return cached_return statistic_data["llm_call"] += 1 - # Call LLM + # Call LLM with sanitized input kwargs = {} - if history_messages: - kwargs["history_messages"] = history_messages + if safe_history_messages: + kwargs["history_messages"] = safe_history_messages if max_tokens is not None: kwargs["max_tokens"] = max_tokens - res: str = await use_llm_func(input_text, **kwargs) + res: str = await use_llm_func(safe_input_text, **kwargs) res = remove_think_tags(res) if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"): @@ -1478,15 +1491,15 @@ async def use_llm_func_with_cache( return res - # When cache is disabled, directly call LLM + # When cache is disabled, directly call LLM with sanitized input kwargs = {} - if history_messages: - kwargs["history_messages"] = history_messages + if safe_history_messages: + kwargs["history_messages"] = safe_history_messages if max_tokens is not None: kwargs["max_tokens"] = max_tokens - logger.info(f"Call LLM function with query text length: {len(input_text)}") - res = await use_llm_func(input_text, **kwargs) + logger.info(f"Call LLM function with query text length: {len(safe_input_text)}") + res = await use_llm_func(safe_input_text, **kwargs) return remove_think_tags(res) @@ -1560,16 +1573,97 @@ def normalize_extracted_info(name: str, is_entity=False) -> str: return name -def clean_text(text: str) -> str: - """Clean text by removing null bytes (0x00) and whitespace +def sanitize_text_for_encoding(text: str, replacement_char: str = "") -> str: + """Sanitize text to ensure safe UTF-8 encoding by removing or replacing problematic characters. + + This function handles: + - Surrogate characters (the main cause of the encoding error) + - Other invalid Unicode sequences + - Control characters that might cause issues + - Whitespace trimming Args: - text: Input text to clean + text: Input text to sanitize + replacement_char: Character to use for replacing invalid sequences Returns: - Cleaned text + Sanitized text that can be safely encoded as UTF-8 """ - return text.strip().replace("\x00", "") + if not isinstance(text, str): + return str(text) + + if not text: + return text + + try: + # First, strip whitespace + text = text.strip() + + # Early return if text is empty after basic cleaning + if not text: + return text + + # Try to encode/decode to catch any encoding issues early + text.encode("utf-8") + + # Remove or replace surrogate characters (U+D800 to U+DFFF) + # These are the main cause of the encoding error + sanitized = "" + for char in text: + code_point = ord(char) + # Check for surrogate characters + if 0xD800 <= code_point <= 0xDFFF: + # Replace surrogate with replacement character + sanitized += replacement_char + continue + # Check for other problematic characters + elif code_point == 0xFFFE or code_point == 0xFFFF: + # These are non-characters in Unicode + sanitized += replacement_char + continue + else: + sanitized += char + + # Additional cleanup: remove null bytes and other control characters that might cause issues + # (but preserve common whitespace like \t, \n, \r) + sanitized = re.sub( + r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", replacement_char, sanitized + ) + + # Test final encoding to ensure it's safe + sanitized.encode("utf-8") + + return sanitized + + except UnicodeEncodeError as e: + logger.warning( + f"Text sanitization: UnicodeEncodeError encountered, applying aggressive cleaning: {str(e)[:100]}" + ) + + # Aggressive fallback: encode with error handling + try: + # Use 'replace' error handling to substitute problematic characters + safe_bytes = text.encode("utf-8", errors="replace") + sanitized = safe_bytes.decode("utf-8") + + # Additional cleanup + sanitized = re.sub( + r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", replacement_char, sanitized + ) + + return sanitized + + except Exception as fallback_error: + logger.error( + f"Text sanitization: Aggressive fallback failed: {str(fallback_error)}" + ) + # Last resort: return a safe placeholder + return f"[TEXT_ENCODING_ERROR: {len(text)} characters]" + + except Exception as e: + logger.error(f"Text sanitization: Unexpected error: {str(e)}") + # Return original text if no encoding issues detected + return text def check_storage_env_vars(storage_name: str) -> None: