cherry-pick 23cbb9c9

Parent: 93778770ab
Commit: 4ec5073aaa
1 changed file with 21 additions and 277 deletions
@@ -1,8 +1,6 @@
 from __future__ import annotations
-import weakref

-import sys

 import asyncio
 import html
 import csv
@@ -42,35 +40,6 @@ from lightrag.constants import (
     SOURCE_IDS_LIMIT_METHOD_FIFO,
 )

-# Precompile regex pattern for JSON sanitization (module-level, compiled once)
-_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
-
-
-class SafeStreamHandler(logging.StreamHandler):
-    """StreamHandler that gracefully handles closed streams during shutdown.
-
-    This handler prevents "ValueError: I/O operation on closed file" errors
-    that can occur when pytest or other test frameworks close stdout/stderr
-    before Python's logging cleanup runs.
-    """
-
-    def flush(self):
-        """Flush the stream, ignoring errors if the stream is closed."""
-        try:
-            super().flush()
-        except (ValueError, OSError):
-            # Stream is closed or otherwise unavailable; silently ignore
-            pass
-
-    def close(self):
-        """Close the handler, ignoring errors if the stream is already closed."""
-        try:
-            super().close()
-        except (ValueError, OSError):
-            # Stream is closed or otherwise unavailable; silently ignore
-            pass
-
-
 # Initialize logger with basic configuration
 logger = logging.getLogger("lightrag")
 logger.propagate = False  # prevent log messages from reaching the root logger
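The handler deleted above guarded against flushing a stream that has already been closed. A minimal sketch of that failure mode, using a hypothetical in-memory stream (not part of this diff):

```python
import io
import logging

stream = io.StringIO()
handler = logging.StreamHandler(stream)
stream.close()

try:
    # With a plain StreamHandler, flushing a closed stream raises
    # "ValueError: I/O operation on closed file" -- the error that
    # SafeStreamHandler.flush() silently swallowed.
    handler.flush()
except ValueError as exc:
    print(f"plain StreamHandler: {exc}")
```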
@@ -78,7 +47,7 @@ logger.setLevel(logging.INFO)

 # Add console handler if no handlers exist
 if not logger.handlers:
-    console_handler = SafeStreamHandler()
+    console_handler = logging.StreamHandler()
     console_handler.setLevel(logging.INFO)
     formatter = logging.Formatter("%(levelname)s: %(message)s")
     console_handler.setFormatter(formatter)
@@ -87,33 +56,6 @@ if not logger.handlers:
 # Set httpx logging level to WARNING
 logging.getLogger("httpx").setLevel(logging.WARNING)

-
-def _patch_ascii_colors_console_handler() -> None:
-    """Prevent ascii_colors from printing flush errors during interpreter exit."""
-
-    try:
-        from ascii_colors import ConsoleHandler
-    except ImportError:
-        return
-
-    if getattr(ConsoleHandler, "_lightrag_patched", False):
-        return
-
-    original_handle_error = ConsoleHandler.handle_error
-
-    def _safe_handle_error(self, message: str) -> None:  # type: ignore[override]
-        exc_type, _, _ = sys.exc_info()
-        if exc_type in (ValueError, OSError) and "close" in message.lower():
-            return
-        original_handle_error(self, message)
-
-    ConsoleHandler.handle_error = _safe_handle_error  # type: ignore[assignment]
-    ConsoleHandler._lightrag_patched = True  # type: ignore[attr-defined]
-
-
-_patch_ascii_colors_console_handler()
-
-
 # Global import for pypinyin with startup-time logging
 try:
     import pypinyin
@@ -341,8 +283,8 @@ def setup_logger(
     logger_instance.handlers = []  # Clear existing handlers
     logger_instance.propagate = False

-    # Add console handler with safe stream handling
-    console_handler = SafeStreamHandler()
+    # Add console handler
+    console_handler = logging.StreamHandler()
     console_handler.setFormatter(simple_formatter)
     console_handler.setLevel(level)
     logger_instance.addHandler(console_handler)
@@ -408,20 +350,9 @@ class TaskState:

 @dataclass
 class EmbeddingFunc:
     """Embedding function wrapper with dimension validation
-    This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
-    This class should not be wrapped multiple times.
-
-    Args:
-        embedding_dim: Expected dimension of the embeddings
-        func: The actual embedding function to wrap
-        max_token_size: Optional token limit for the embedding model
-        send_dimensions: Whether to inject embedding_dim as a keyword argument
     """

     embedding_dim: int
     func: callable
-    max_token_size: int | None = None  # Token limit for the embedding model
-    send_dimensions: bool = (
-        False  # Control whether to send embedding_dim to the function
-    )
+    max_token_size: int | None = None  # deprecated, kept for compatibility only
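With `send_dimensions` gone, `EmbeddingFunc` is reduced to a plain wrapper around the async embedding callable. A minimal usage sketch, assuming numpy and a hypothetical `fake_embed` stand-in:

```python
import asyncio
import numpy as np

async def fake_embed(texts: list[str]) -> np.ndarray:
    # Hypothetical stand-in for a real embedding model
    return np.zeros((len(texts), 8))

embed = EmbeddingFunc(embedding_dim=8, func=fake_embed)
vectors = asyncio.run(embed(["hello", "world"]))  # __call__ delegates to func
assert vectors.shape == (2, 8)
```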
@@ -445,32 +376,7 @@ class EmbeddingFunc:
-        # Inject embedding_dim from decorator
-        kwargs["embedding_dim"] = self.embedding_dim

-        # Call the actual embedding function
-        result = await self.func(*args, **kwargs)

-        # Validate embedding dimensions using total element count
-        total_elements = result.size  # Total number of elements in the numpy array
-        expected_dim = self.embedding_dim
-
-        # Check if total elements can be evenly divided by embedding_dim
-        if total_elements % expected_dim != 0:
-            raise ValueError(
-                f"Embedding dimension mismatch detected: "
-                f"total elements ({total_elements}) cannot be evenly divided by "
-                f"expected dimension ({expected_dim})."
-            )

-        # Optional: Verify vector count matches input text count
-        actual_vectors = total_elements // expected_dim
-        if args and isinstance(args[0], (list, tuple)):
-            expected_vectors = len(args[0])
-            if actual_vectors != expected_vectors:
-                raise ValueError(
-                    f"Vector count mismatch: "
-                    f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
-                )
-
-        return result
+        return await self.func(*args, **kwargs)


 def compute_args_hash(*args: Any) -> str:
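The deleted validation relied on numpy's `ndarray.size` (total element count), which makes the check shape-agnostic. A sketch of the arithmetic it performed, with hypothetical values:

```python
import numpy as np

result = np.zeros((3, 1536))  # 3 vectors of dimension 1536
expected_dim = 1536

# Total elements must divide evenly by the expected dimension...
assert result.size % expected_dim == 0
# ...and the implied vector count must match the number of input texts.
assert result.size // expected_dim == 3
```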
@@ -1005,76 +911,7 @@ def priority_limit_async_func_call(


 def wrap_embedding_func_with_attrs(**kwargs):
-    """Decorator to add embedding dimension and token limit attributes to embedding functions.
-
-    This decorator wraps an async embedding function and returns an EmbeddingFunc instance
-    that automatically handles dimension parameter injection and attribute management.
-
-    WARNING: DO NOT apply this decorator to wrapper functions that call other
-    decorated embedding functions. This will cause double decoration and parameter
-    injection conflicts.
-
-    Correct usage patterns:
-
-    1. Direct implementation (decorated):
-    ```python
-    @wrap_embedding_func_with_attrs(embedding_dim=1536)
-    async def my_embed(texts, embedding_dim=None):
-        # Direct implementation
-        return embeddings
-    ```
-
-    2. Wrapper calling a decorated function (DO NOT decorate the wrapper):
-    ```python
-    # my_embed is already decorated above
-
-    async def my_wrapper(texts, **kwargs):  # ❌ DO NOT decorate this!
-        # Must call .func to access the unwrapped implementation
-        return await my_embed.func(texts, **kwargs)
-    ```
-
-    3. Wrapper calling a decorated function (properly decorated):
-    ```python
-    @wrap_embedding_func_with_attrs(embedding_dim=1536)
-    async def my_wrapper(texts, **kwargs):  # ✅ Can decorate if calling .func
-        # Calling .func avoids double decoration
-        return await my_embed.func(texts, **kwargs)
-    ```
-
-    The decorated function becomes an EmbeddingFunc instance with:
-    - embedding_dim: The embedding dimension
-    - max_token_size: Maximum token limit (optional)
-    - func: The original unwrapped function (access via .func)
-    - __call__: Wrapper that injects the embedding_dim parameter
-
-    Double decoration causes:
-    - Double injection of the embedding_dim parameter
-    - Incorrect parameter passing to the underlying implementation
-    - Runtime errors due to parameter conflicts
-
-    Args:
-        embedding_dim: The dimension of embedding vectors
-        max_token_size: Maximum number of tokens (optional)
-        send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
-
-    Returns:
-        A decorator that wraps the function as an EmbeddingFunc instance
-
-    Example of a correct wrapper implementation:
-    ```python
-    @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
-    @retry(...)
-    async def openai_embed(texts, ...):
-        # Base implementation
-        pass
-
-    @wrap_embedding_func_with_attrs(embedding_dim=1536)  # Note: No @retry here!
-    async def azure_openai_embed(texts, ...):
-        # CRITICAL: Call .func to access the unwrapped function
-        return await openai_embed.func(texts, ...)  # ✅ Correct
-        # return await openai_embed(texts, ...)  # ❌ Wrong - double decoration!
-    ```
-    """
+    """Wrap a function with attributes"""

     def final_decro(func) -> EmbeddingFunc:
         new_func = EmbeddingFunc(**kwargs, func=func)
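The surviving one-line docstring undersells what the decorator still does: it returns an `EmbeddingFunc` instance whose attributes remain inspectable. A small sketch with a hypothetical embedder:

```python
import numpy as np

@wrap_embedding_func_with_attrs(embedding_dim=4)
async def tiny_embed(texts):
    # Hypothetical embedder used only for illustration
    return np.zeros((len(texts), 4))

print(tiny_embed.embedding_dim)  # 4
print(tiny_embed.func)           # the original, unwrapped coroutine function
```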
@@ -1090,123 +927,30 @@ def load_json(file_name):
         return json.load(f)


-def _sanitize_string_for_json(text: str) -> str:
-    """Remove characters that cannot be encoded in UTF-8 for JSON serialization.
-
-    Uses regex for optimal performance with zero-copy optimization for clean strings.
-    Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
+def _sanitize_json_data(data: Any) -> Any:
+    """Recursively sanitize all string values in a data structure for safe UTF-8 encoding

     Args:
-        text: String to sanitize
+        data: Data to sanitize (dict, list, str, or other types)

     Returns:
-        Original string if clean (zero-copy), sanitized string if dirty
+        Sanitized data with all strings cleaned of problematic characters
     """
-    if not text:
-        return text
-
-    # Fast path: Check if sanitization is needed using C-level regex search
-    if not _SURROGATE_PATTERN.search(text):
-        return text  # Zero-copy for clean strings - most common case
-
-    # Slow path: Remove problematic characters using C-level regex substitution
-    return _SURROGATE_PATTERN.sub("", text)
-
-
-class SanitizingJSONEncoder(json.JSONEncoder):
-    """
-    Custom JSON encoder that sanitizes data during serialization.
-
-    This encoder cleans strings during the encoding process without creating
-    a full copy of the data structure, making it memory-efficient for large datasets.
-    """
-
-    def encode(self, o):
-        """Override encode method to handle simple string cases"""
-        if isinstance(o, str):
-            return json.encoder.encode_basestring(_sanitize_string_for_json(o))
-        return super().encode(o)
-
-    def iterencode(self, o, _one_shot=False):
-        """
-        Override iterencode to sanitize strings during serialization.
-        This is the core method that handles complex nested structures.
-        """
-        # Preprocess: sanitize all strings in the object
-        sanitized = self._sanitize_for_encoding(o)
-
-        # Call the parent's iterencode with the sanitized data
-        for chunk in super().iterencode(sanitized, _one_shot):
-            yield chunk
-
-    def _sanitize_for_encoding(self, obj):
-        """
-        Recursively sanitize strings in an object.
-        Creates new objects only when necessary to avoid deep copies.
-
-        Args:
-            obj: Object to sanitize
-
-        Returns:
-            Sanitized object with cleaned strings
-        """
-        if isinstance(obj, str):
-            return _sanitize_string_for_json(obj)
-
-        elif isinstance(obj, dict):
-            # Create a new dict with sanitized keys and values
-            new_dict = {}
-            for k, v in obj.items():
-                clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
-                clean_v = self._sanitize_for_encoding(v)
-                new_dict[clean_k] = clean_v
-            return new_dict
-
-        elif isinstance(obj, (list, tuple)):
-            # Sanitize list/tuple elements
-            cleaned = [self._sanitize_for_encoding(item) for item in obj]
-            return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
-
-        else:
-            # Numbers, booleans, None, etc. remain unchanged
-            return obj
+    if isinstance(data, dict):
+        return {k: _sanitize_json_data(v) for k, v in data.items()}
+    elif isinstance(data, list):
+        return [_sanitize_json_data(item) for item in data]
+    elif isinstance(data, str):
+        return sanitize_text_for_encoding(data, replacement_char="")
+    else:
+        return data


 def write_json(json_obj, file_name):
-    """
-    Write JSON data to file with optimized sanitization strategy.
-
-    This function uses a two-stage approach:
-    1. Fast path: Try direct serialization (works for clean data ~99% of the time)
-    2. Slow path: Use a custom encoder that sanitizes during serialization
-
-    The custom encoder approach avoids creating a deep copy of the data,
-    making it memory-efficient. When sanitization occurs, the caller should
-    reload the cleaned data from the file to update shared memory.
-
-    Args:
-        json_obj: Object to serialize (may be a shallow copy from shared memory)
-        file_name: Output file path
-
-    Returns:
-        bool: True if sanitization was applied (caller should reload data),
-              False if direct write succeeded (no reload needed)
-    """
-    try:
-        # Strategy 1: Fast path - try direct serialization
-        with open(file_name, "w", encoding="utf-8") as f:
-            json.dump(json_obj, f, indent=2, ensure_ascii=False)
-        return False  # No sanitization needed, no reload required
-
-    except (UnicodeEncodeError, UnicodeDecodeError) as e:
-        logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
-
-        # Strategy 2: Use the custom encoder (sanitizes during serialization, zero memory copy)
-        with open(file_name, "w", encoding="utf-8") as f:
-            json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder)
-
-        logger.info(f"JSON sanitization applied during write: {file_name}")
-        return True  # Sanitization applied, reload recommended
+    # Sanitize data before writing to prevent UTF-8 encoding errors
+    sanitized_obj = _sanitize_json_data(json_obj)
+    with open(file_name, "w", encoding="utf-8") as f:
+        json.dump(sanitized_obj, f, indent=2, ensure_ascii=False)


 class TokenizerInterface(Protocol):
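Two behavioral notes on the restored write_json: it now sanitizes unconditionally before writing (rather than only after an encode failure), and it no longer returns the bool reload signal, so any caller that checked the old return value needs review. A behavior sketch of the sanitization path, assuming sanitize_text_for_encoding drops invalid characters when replacement_char is empty:

```python
# A lone surrogate (\ud800) cannot be encoded as UTF-8; the recursive
# sanitizer should strip it from every string in the structure.
dirty = {"note": "ok\ud800", "items": ["fine", "bad\udfff"]}
clean = _sanitize_json_data(dirty)
print(clean)  # expected: {'note': 'ok', 'items': ['fine', 'bad']}
```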