This commit is contained in:
Raphaël MANSUY 2025-12-04 19:15:02 +08:00
parent 93778770ab
commit 4ec5073aaa

View file

@ -1,8 +1,6 @@
from __future__ import annotations
import weakref
import sys
import asyncio
import html
import csv
@ -42,35 +40,6 @@ from lightrag.constants import (
SOURCE_IDS_LIMIT_METHOD_FIFO,
)
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
class SafeStreamHandler(logging.StreamHandler):
"""StreamHandler that gracefully handles closed streams during shutdown.
This handler prevents "ValueError: I/O operation on closed file" errors
that can occur when pytest or other test frameworks close stdout/stderr
before Python's logging cleanup runs.
"""
def flush(self):
"""Flush the stream, ignoring errors if the stream is closed."""
try:
super().flush()
except (ValueError, OSError):
# Stream is closed or otherwise unavailable, silently ignore
pass
def close(self):
"""Close the handler, ignoring errors if the stream is already closed."""
try:
super().close()
except (ValueError, OSError):
# Stream is closed or otherwise unavailable, silently ignore
pass
# Initialize logger with basic configuration
logger = logging.getLogger("lightrag")
logger.propagate = False # prevent log message send to root logger
@ -78,7 +47,7 @@ logger.setLevel(logging.INFO)
# Add console handler if no handlers exist
if not logger.handlers:
console_handler = SafeStreamHandler()
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(levelname)s: %(message)s")
console_handler.setFormatter(formatter)
@ -87,33 +56,6 @@ if not logger.handlers:
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
def _patch_ascii_colors_console_handler() -> None:
"""Prevent ascii_colors from printing flush errors during interpreter exit."""
try:
from ascii_colors import ConsoleHandler
except ImportError:
return
if getattr(ConsoleHandler, "_lightrag_patched", False):
return
original_handle_error = ConsoleHandler.handle_error
def _safe_handle_error(self, message: str) -> None: # type: ignore[override]
exc_type, _, _ = sys.exc_info()
if exc_type in (ValueError, OSError) and "close" in message.lower():
return
original_handle_error(self, message)
ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment]
ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined]
_patch_ascii_colors_console_handler()
# Global import for pypinyin with startup-time logging
try:
import pypinyin
@ -341,8 +283,8 @@ def setup_logger(
logger_instance.handlers = [] # Clear existing handlers
logger_instance.propagate = False
# Add console handler with safe stream handling
console_handler = SafeStreamHandler()
# Add console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(simple_formatter)
console_handler.setLevel(level)
logger_instance.addHandler(console_handler)
@ -408,20 +350,9 @@ class TaskState:
@dataclass
class EmbeddingFunc:
"""Embedding function wrapper with dimension validation
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
This class should not be wrapped multiple times.
Args:
embedding_dim: Expected dimension of the embeddings
func: The actual embedding function to wrap
max_token_size: Optional token limit for the embedding model
send_dimensions: Whether to inject embedding_dim as a keyword argument
"""
embedding_dim: int
func: callable
max_token_size: int | None = None # Token limit for the embedding model
max_token_size: int | None = None # deprecated keep it for compatible only
send_dimensions: bool = (
False # Control whether to send embedding_dim to the function
)
@ -445,32 +376,7 @@ class EmbeddingFunc:
# Inject embedding_dim from decorator
kwargs["embedding_dim"] = self.embedding_dim
# Call the actual embedding function
result = await self.func(*args, **kwargs)
# Validate embedding dimensions using total element count
total_elements = result.size # Total number of elements in the numpy array
expected_dim = self.embedding_dim
# Check if total elements can be evenly divided by embedding_dim
if total_elements % expected_dim != 0:
raise ValueError(
f"Embedding dimension mismatch detected: "
f"total elements ({total_elements}) cannot be evenly divided by "
f"expected dimension ({expected_dim}). "
)
# Optional: Verify vector count matches input text count
actual_vectors = total_elements // expected_dim
if args and isinstance(args[0], (list, tuple)):
expected_vectors = len(args[0])
if actual_vectors != expected_vectors:
raise ValueError(
f"Vector count mismatch: "
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
)
return result
return await self.func(*args, **kwargs)
def compute_args_hash(*args: Any) -> str:
@ -1005,76 +911,7 @@ def priority_limit_async_func_call(
def wrap_embedding_func_with_attrs(**kwargs):
"""Decorator to add embedding dimension and token limit attributes to embedding functions.
This decorator wraps an async embedding function and returns an EmbeddingFunc instance
that automatically handles dimension parameter injection and attribute management.
WARNING: DO NOT apply this decorator to wrapper functions that call other
decorated embedding functions. This will cause double decoration and parameter
injection conflicts.
Correct usage patterns:
1. Direct implementation (decorated):
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def my_embed(texts, embedding_dim=None):
# Direct implementation
return embeddings
```
2. Wrapper calling decorated function (DO NOT decorate wrapper):
```python
# my_embed is already decorated above
async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this!
# Must call .func to access unwrapped implementation
return await my_embed.func(texts, **kwargs)
```
3. Wrapper calling decorated function (properly decorated):
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func
# Calling .func avoids double decoration
return await my_embed.func(texts, **kwargs)
```
The decorated function becomes an EmbeddingFunc instance with:
- embedding_dim: The embedding dimension
- max_token_size: Maximum token limit (optional)
- func: The original unwrapped function (access via .func)
- __call__: Wrapper that injects embedding_dim parameter
Double decoration causes:
- Double injection of embedding_dim parameter
- Incorrect parameter passing to the underlying implementation
- Runtime errors due to parameter conflicts
Args:
embedding_dim: The dimension of embedding vectors
max_token_size: Maximum number of tokens (optional)
send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
Returns:
A decorator that wraps the function as an EmbeddingFunc instance
Example of correct wrapper implementation:
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(...)
async def openai_embed(texts, ...):
# Base implementation
pass
@wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here!
async def azure_openai_embed(texts, ...):
# CRITICAL: Call .func to access unwrapped function
return await openai_embed.func(texts, ...) # ✅ Correct
# return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
```
"""
"""Wrap a function with attributes"""
def final_decro(func) -> EmbeddingFunc:
new_func = EmbeddingFunc(**kwargs, func=func)
@ -1090,123 +927,30 @@ def load_json(file_name):
return json.load(f)
def _sanitize_string_for_json(text: str) -> str:
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization.
Uses regex for optimal performance with zero-copy optimization for clean strings.
Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
def _sanitize_json_data(data: Any) -> Any:
"""Recursively sanitize all string values in data structure for safe UTF-8 encoding
Args:
text: String to sanitize
data: Data to sanitize (dict, list, str, or other types)
Returns:
Original string if clean (zero-copy), sanitized string if dirty
Sanitized data with all strings cleaned of problematic characters
"""
if not text:
return text
# Fast path: Check if sanitization is needed using C-level regex search
if not _SURROGATE_PATTERN.search(text):
return text # Zero-copy for clean strings - most common case
# Slow path: Remove problematic characters using C-level regex substitution
return _SURROGATE_PATTERN.sub("", text)
class SanitizingJSONEncoder(json.JSONEncoder):
"""
Custom JSON encoder that sanitizes data during serialization.
This encoder cleans strings during the encoding process without creating
a full copy of the data structure, making it memory-efficient for large datasets.
"""
def encode(self, o):
"""Override encode method to handle simple string cases"""
if isinstance(o, str):
return json.encoder.encode_basestring(_sanitize_string_for_json(o))
return super().encode(o)
def iterencode(self, o, _one_shot=False):
"""
Override iterencode to sanitize strings during serialization.
This is the core method that handles complex nested structures.
"""
# Preprocess: sanitize all strings in the object
sanitized = self._sanitize_for_encoding(o)
# Call parent's iterencode with sanitized data
for chunk in super().iterencode(sanitized, _one_shot):
yield chunk
def _sanitize_for_encoding(self, obj):
"""
Recursively sanitize strings in an object.
Creates new objects only when necessary to avoid deep copies.
Args:
obj: Object to sanitize
Returns:
Sanitized object with cleaned strings
"""
if isinstance(obj, str):
return _sanitize_string_for_json(obj)
elif isinstance(obj, dict):
# Create new dict with sanitized keys and values
new_dict = {}
for k, v in obj.items():
clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
clean_v = self._sanitize_for_encoding(v)
new_dict[clean_k] = clean_v
return new_dict
elif isinstance(obj, (list, tuple)):
# Sanitize list/tuple elements
cleaned = [self._sanitize_for_encoding(item) for item in obj]
return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
else:
# Numbers, booleans, None, etc. remain unchanged
return obj
if isinstance(data, dict):
return {k: _sanitize_json_data(v) for k, v in data.items()}
elif isinstance(data, list):
return [_sanitize_json_data(item) for item in data]
elif isinstance(data, str):
return sanitize_text_for_encoding(data, replacement_char="")
else:
return data
def write_json(json_obj, file_name):
"""
Write JSON data to file with optimized sanitization strategy.
This function uses a two-stage approach:
1. Fast path: Try direct serialization (works for clean data ~99% of time)
2. Slow path: Use custom encoder that sanitizes during serialization
The custom encoder approach avoids creating a deep copy of the data,
making it memory-efficient. When sanitization occurs, the caller should
reload the cleaned data from the file to update shared memory.
Args:
json_obj: Object to serialize (may be a shallow copy from shared memory)
file_name: Output file path
Returns:
bool: True if sanitization was applied (caller should reload data),
False if direct write succeeded (no reload needed)
"""
try:
# Strategy 1: Fast path - try direct serialization
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)
return False # No sanitization needed, no reload required
except (UnicodeEncodeError, UnicodeDecodeError) as e:
logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
# Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy)
# Sanitize data before writing to prevent UTF-8 encoding errors
sanitized_obj = _sanitize_json_data(json_obj)
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder)
logger.info(f"JSON sanitization applied during write: {file_name}")
return True # Sanitization applied, reload recommended
json.dump(sanitized_obj, f, indent=2, ensure_ascii=False)
class TokenizerInterface(Protocol):