This commit is contained in:
Raphaël MANSUY 2025-12-04 19:15:02 +08:00
parent 93778770ab
commit 4ec5073aaa

View file

@ -1,8 +1,6 @@
from __future__ import annotations from __future__ import annotations
import weakref import weakref
import sys
import asyncio import asyncio
import html import html
import csv import csv
@ -42,35 +40,6 @@ from lightrag.constants import (
SOURCE_IDS_LIMIT_METHOD_FIFO, SOURCE_IDS_LIMIT_METHOD_FIFO,
) )
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
class SafeStreamHandler(logging.StreamHandler):
"""StreamHandler that gracefully handles closed streams during shutdown.
This handler prevents "ValueError: I/O operation on closed file" errors
that can occur when pytest or other test frameworks close stdout/stderr
before Python's logging cleanup runs.
"""
def flush(self):
"""Flush the stream, ignoring errors if the stream is closed."""
try:
super().flush()
except (ValueError, OSError):
# Stream is closed or otherwise unavailable, silently ignore
pass
def close(self):
"""Close the handler, ignoring errors if the stream is already closed."""
try:
super().close()
except (ValueError, OSError):
# Stream is closed or otherwise unavailable, silently ignore
pass
# Initialize logger with basic configuration # Initialize logger with basic configuration
logger = logging.getLogger("lightrag") logger = logging.getLogger("lightrag")
logger.propagate = False # prevent log message send to root logger logger.propagate = False # prevent log message send to root logger
@ -78,7 +47,7 @@ logger.setLevel(logging.INFO)
# Add console handler if no handlers exist # Add console handler if no handlers exist
if not logger.handlers: if not logger.handlers:
console_handler = SafeStreamHandler() console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO) console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(levelname)s: %(message)s") formatter = logging.Formatter("%(levelname)s: %(message)s")
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
@ -87,33 +56,6 @@ if not logger.handlers:
# Set httpx logging level to WARNING # Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING)
def _patch_ascii_colors_console_handler() -> None:
"""Prevent ascii_colors from printing flush errors during interpreter exit."""
try:
from ascii_colors import ConsoleHandler
except ImportError:
return
if getattr(ConsoleHandler, "_lightrag_patched", False):
return
original_handle_error = ConsoleHandler.handle_error
def _safe_handle_error(self, message: str) -> None: # type: ignore[override]
exc_type, _, _ = sys.exc_info()
if exc_type in (ValueError, OSError) and "close" in message.lower():
return
original_handle_error(self, message)
ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment]
ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined]
_patch_ascii_colors_console_handler()
# Global import for pypinyin with startup-time logging # Global import for pypinyin with startup-time logging
try: try:
import pypinyin import pypinyin
@ -341,8 +283,8 @@ def setup_logger(
logger_instance.handlers = [] # Clear existing handlers logger_instance.handlers = [] # Clear existing handlers
logger_instance.propagate = False logger_instance.propagate = False
# Add console handler with safe stream handling # Add console handler
console_handler = SafeStreamHandler() console_handler = logging.StreamHandler()
console_handler.setFormatter(simple_formatter) console_handler.setFormatter(simple_formatter)
console_handler.setLevel(level) console_handler.setLevel(level)
logger_instance.addHandler(console_handler) logger_instance.addHandler(console_handler)
@ -408,20 +350,9 @@ class TaskState:
@dataclass @dataclass
class EmbeddingFunc: class EmbeddingFunc:
"""Embedding function wrapper with dimension validation
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
This class should not be wrapped multiple times.
Args:
embedding_dim: Expected dimension of the embeddings
func: The actual embedding function to wrap
max_token_size: Optional token limit for the embedding model
send_dimensions: Whether to inject embedding_dim as a keyword argument
"""
embedding_dim: int embedding_dim: int
func: callable func: callable
max_token_size: int | None = None # Token limit for the embedding model max_token_size: int | None = None # deprecated keep it for compatible only
send_dimensions: bool = ( send_dimensions: bool = (
False # Control whether to send embedding_dim to the function False # Control whether to send embedding_dim to the function
) )
@ -445,32 +376,7 @@ class EmbeddingFunc:
# Inject embedding_dim from decorator # Inject embedding_dim from decorator
kwargs["embedding_dim"] = self.embedding_dim kwargs["embedding_dim"] = self.embedding_dim
# Call the actual embedding function return await self.func(*args, **kwargs)
result = await self.func(*args, **kwargs)
# Validate embedding dimensions using total element count
total_elements = result.size # Total number of elements in the numpy array
expected_dim = self.embedding_dim
# Check if total elements can be evenly divided by embedding_dim
if total_elements % expected_dim != 0:
raise ValueError(
f"Embedding dimension mismatch detected: "
f"total elements ({total_elements}) cannot be evenly divided by "
f"expected dimension ({expected_dim}). "
)
# Optional: Verify vector count matches input text count
actual_vectors = total_elements // expected_dim
if args and isinstance(args[0], (list, tuple)):
expected_vectors = len(args[0])
if actual_vectors != expected_vectors:
raise ValueError(
f"Vector count mismatch: "
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
)
return result
def compute_args_hash(*args: Any) -> str: def compute_args_hash(*args: Any) -> str:
@ -1005,76 +911,7 @@ def priority_limit_async_func_call(
def wrap_embedding_func_with_attrs(**kwargs): def wrap_embedding_func_with_attrs(**kwargs):
"""Decorator to add embedding dimension and token limit attributes to embedding functions. """Wrap a function with attributes"""
This decorator wraps an async embedding function and returns an EmbeddingFunc instance
that automatically handles dimension parameter injection and attribute management.
WARNING: DO NOT apply this decorator to wrapper functions that call other
decorated embedding functions. This will cause double decoration and parameter
injection conflicts.
Correct usage patterns:
1. Direct implementation (decorated):
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def my_embed(texts, embedding_dim=None):
# Direct implementation
return embeddings
```
2. Wrapper calling decorated function (DO NOT decorate wrapper):
```python
# my_embed is already decorated above
async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this!
# Must call .func to access unwrapped implementation
return await my_embed.func(texts, **kwargs)
```
3. Wrapper calling decorated function (properly decorated):
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func
# Calling .func avoids double decoration
return await my_embed.func(texts, **kwargs)
```
The decorated function becomes an EmbeddingFunc instance with:
- embedding_dim: The embedding dimension
- max_token_size: Maximum token limit (optional)
- func: The original unwrapped function (access via .func)
- __call__: Wrapper that injects embedding_dim parameter
Double decoration causes:
- Double injection of embedding_dim parameter
- Incorrect parameter passing to the underlying implementation
- Runtime errors due to parameter conflicts
Args:
embedding_dim: The dimension of embedding vectors
max_token_size: Maximum number of tokens (optional)
send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
Returns:
A decorator that wraps the function as an EmbeddingFunc instance
Example of correct wrapper implementation:
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(...)
async def openai_embed(texts, ...):
# Base implementation
pass
@wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here!
async def azure_openai_embed(texts, ...):
# CRITICAL: Call .func to access unwrapped function
return await openai_embed.func(texts, ...) # ✅ Correct
# return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
```
"""
def final_decro(func) -> EmbeddingFunc: def final_decro(func) -> EmbeddingFunc:
new_func = EmbeddingFunc(**kwargs, func=func) new_func = EmbeddingFunc(**kwargs, func=func)
@ -1090,123 +927,30 @@ def load_json(file_name):
return json.load(f) return json.load(f)
def _sanitize_string_for_json(text: str) -> str: def _sanitize_json_data(data: Any) -> Any:
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization. """Recursively sanitize all string values in data structure for safe UTF-8 encoding
Uses regex for optimal performance with zero-copy optimization for clean strings.
Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
Args: Args:
text: String to sanitize data: Data to sanitize (dict, list, str, or other types)
Returns: Returns:
Original string if clean (zero-copy), sanitized string if dirty Sanitized data with all strings cleaned of problematic characters
""" """
if not text: if isinstance(data, dict):
return text return {k: _sanitize_json_data(v) for k, v in data.items()}
elif isinstance(data, list):
# Fast path: Check if sanitization is needed using C-level regex search return [_sanitize_json_data(item) for item in data]
if not _SURROGATE_PATTERN.search(text): elif isinstance(data, str):
return text # Zero-copy for clean strings - most common case return sanitize_text_for_encoding(data, replacement_char="")
else:
# Slow path: Remove problematic characters using C-level regex substitution return data
return _SURROGATE_PATTERN.sub("", text)
class SanitizingJSONEncoder(json.JSONEncoder):
"""
Custom JSON encoder that sanitizes data during serialization.
This encoder cleans strings during the encoding process without creating
a full copy of the data structure, making it memory-efficient for large datasets.
"""
def encode(self, o):
"""Override encode method to handle simple string cases"""
if isinstance(o, str):
return json.encoder.encode_basestring(_sanitize_string_for_json(o))
return super().encode(o)
def iterencode(self, o, _one_shot=False):
"""
Override iterencode to sanitize strings during serialization.
This is the core method that handles complex nested structures.
"""
# Preprocess: sanitize all strings in the object
sanitized = self._sanitize_for_encoding(o)
# Call parent's iterencode with sanitized data
for chunk in super().iterencode(sanitized, _one_shot):
yield chunk
def _sanitize_for_encoding(self, obj):
"""
Recursively sanitize strings in an object.
Creates new objects only when necessary to avoid deep copies.
Args:
obj: Object to sanitize
Returns:
Sanitized object with cleaned strings
"""
if isinstance(obj, str):
return _sanitize_string_for_json(obj)
elif isinstance(obj, dict):
# Create new dict with sanitized keys and values
new_dict = {}
for k, v in obj.items():
clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
clean_v = self._sanitize_for_encoding(v)
new_dict[clean_k] = clean_v
return new_dict
elif isinstance(obj, (list, tuple)):
# Sanitize list/tuple elements
cleaned = [self._sanitize_for_encoding(item) for item in obj]
return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
else:
# Numbers, booleans, None, etc. remain unchanged
return obj
def write_json(json_obj, file_name): def write_json(json_obj, file_name):
""" # Sanitize data before writing to prevent UTF-8 encoding errors
Write JSON data to file with optimized sanitization strategy. sanitized_obj = _sanitize_json_data(json_obj)
This function uses a two-stage approach:
1. Fast path: Try direct serialization (works for clean data ~99% of time)
2. Slow path: Use custom encoder that sanitizes during serialization
The custom encoder approach avoids creating a deep copy of the data,
making it memory-efficient. When sanitization occurs, the caller should
reload the cleaned data from the file to update shared memory.
Args:
json_obj: Object to serialize (may be a shallow copy from shared memory)
file_name: Output file path
Returns:
bool: True if sanitization was applied (caller should reload data),
False if direct write succeeded (no reload needed)
"""
try:
# Strategy 1: Fast path - try direct serialization
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)
return False # No sanitization needed, no reload required
except (UnicodeEncodeError, UnicodeDecodeError) as e:
logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
# Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy)
with open(file_name, "w", encoding="utf-8") as f: with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder) json.dump(sanitized_obj, f, indent=2, ensure_ascii=False)
logger.info(f"JSON sanitization applied during write: {file_name}")
return True # Sanitization applied, reload recommended
class TokenizerInterface(Protocol): class TokenizerInterface(Protocol):