cherry-pick 23cbb9c9
This commit is contained in:
parent
93778770ab
commit
4ec5073aaa
1 changed file with 21 additions and 277 deletions
|
|
@ -1,8 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import html
|
import html
|
||||||
import csv
|
import csv
|
||||||
|
|
@ -42,35 +40,6 @@ from lightrag.constants import (
|
||||||
SOURCE_IDS_LIMIT_METHOD_FIFO,
|
SOURCE_IDS_LIMIT_METHOD_FIFO,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
|
|
||||||
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
|
|
||||||
|
|
||||||
|
|
||||||
class SafeStreamHandler(logging.StreamHandler):
|
|
||||||
"""StreamHandler that gracefully handles closed streams during shutdown.
|
|
||||||
|
|
||||||
This handler prevents "ValueError: I/O operation on closed file" errors
|
|
||||||
that can occur when pytest or other test frameworks close stdout/stderr
|
|
||||||
before Python's logging cleanup runs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def flush(self):
|
|
||||||
"""Flush the stream, ignoring errors if the stream is closed."""
|
|
||||||
try:
|
|
||||||
super().flush()
|
|
||||||
except (ValueError, OSError):
|
|
||||||
# Stream is closed or otherwise unavailable, silently ignore
|
|
||||||
pass
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""Close the handler, ignoring errors if the stream is already closed."""
|
|
||||||
try:
|
|
||||||
super().close()
|
|
||||||
except (ValueError, OSError):
|
|
||||||
# Stream is closed or otherwise unavailable, silently ignore
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# Initialize logger with basic configuration
|
# Initialize logger with basic configuration
|
||||||
logger = logging.getLogger("lightrag")
|
logger = logging.getLogger("lightrag")
|
||||||
logger.propagate = False # prevent log messages from being sent to the root logger
|
logger.propagate = False # prevent log messages from being sent to the root logger
|
||||||
|
|
@ -78,7 +47,7 @@ logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
# Add console handler if no handlers exist
|
# Add console handler if no handlers exist
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
console_handler = SafeStreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
console_handler.setLevel(logging.INFO)
|
console_handler.setLevel(logging.INFO)
|
||||||
formatter = logging.Formatter("%(levelname)s: %(message)s")
|
formatter = logging.Formatter("%(levelname)s: %(message)s")
|
||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
|
|
@ -87,33 +56,6 @@ if not logger.handlers:
|
||||||
# Set httpx logging level to WARNING
|
# Set httpx logging level to WARNING
|
||||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
def _patch_ascii_colors_console_handler() -> None:
|
|
||||||
"""Prevent ascii_colors from printing flush errors during interpreter exit."""
|
|
||||||
|
|
||||||
try:
|
|
||||||
from ascii_colors import ConsoleHandler
|
|
||||||
except ImportError:
|
|
||||||
return
|
|
||||||
|
|
||||||
if getattr(ConsoleHandler, "_lightrag_patched", False):
|
|
||||||
return
|
|
||||||
|
|
||||||
original_handle_error = ConsoleHandler.handle_error
|
|
||||||
|
|
||||||
def _safe_handle_error(self, message: str) -> None: # type: ignore[override]
|
|
||||||
exc_type, _, _ = sys.exc_info()
|
|
||||||
if exc_type in (ValueError, OSError) and "close" in message.lower():
|
|
||||||
return
|
|
||||||
original_handle_error(self, message)
|
|
||||||
|
|
||||||
ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment]
|
|
||||||
ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
|
|
||||||
_patch_ascii_colors_console_handler()
|
|
||||||
|
|
||||||
|
|
||||||
# Global import for pypinyin with startup-time logging
|
# Global import for pypinyin with startup-time logging
|
||||||
try:
|
try:
|
||||||
import pypinyin
|
import pypinyin
|
||||||
|
|
@ -341,8 +283,8 @@ def setup_logger(
|
||||||
logger_instance.handlers = [] # Clear existing handlers
|
logger_instance.handlers = [] # Clear existing handlers
|
||||||
logger_instance.propagate = False
|
logger_instance.propagate = False
|
||||||
|
|
||||||
# Add console handler with safe stream handling
|
# Add console handler
|
||||||
console_handler = SafeStreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
console_handler.setFormatter(simple_formatter)
|
console_handler.setFormatter(simple_formatter)
|
||||||
console_handler.setLevel(level)
|
console_handler.setLevel(level)
|
||||||
logger_instance.addHandler(console_handler)
|
logger_instance.addHandler(console_handler)
|
||||||
|
|
@ -408,20 +350,9 @@ class TaskState:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingFunc:
|
class EmbeddingFunc:
|
||||||
"""Embedding function wrapper with dimension validation
|
|
||||||
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
|
|
||||||
This class should not be wrapped multiple times.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embedding_dim: Expected dimension of the embeddings
|
|
||||||
func: The actual embedding function to wrap
|
|
||||||
max_token_size: Optional token limit for the embedding model
|
|
||||||
send_dimensions: Whether to inject embedding_dim as a keyword argument
|
|
||||||
"""
|
|
||||||
|
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
func: callable
|
func: callable
|
||||||
max_token_size: int | None = None # Token limit for the embedding model
|
max_token_size: int | None = None # deprecated; kept only for backward compatibility
|
||||||
send_dimensions: bool = (
|
send_dimensions: bool = (
|
||||||
False # Control whether to send embedding_dim to the function
|
False # Control whether to send embedding_dim to the function
|
||||||
)
|
)
|
||||||
|
|
@ -445,32 +376,7 @@ class EmbeddingFunc:
|
||||||
# Inject embedding_dim from decorator
|
# Inject embedding_dim from decorator
|
||||||
kwargs["embedding_dim"] = self.embedding_dim
|
kwargs["embedding_dim"] = self.embedding_dim
|
||||||
|
|
||||||
# Call the actual embedding function
|
return await self.func(*args, **kwargs)
|
||||||
result = await self.func(*args, **kwargs)
|
|
||||||
|
|
||||||
# Validate embedding dimensions using total element count
|
|
||||||
total_elements = result.size # Total number of elements in the numpy array
|
|
||||||
expected_dim = self.embedding_dim
|
|
||||||
|
|
||||||
# Check if total elements can be evenly divided by embedding_dim
|
|
||||||
if total_elements % expected_dim != 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"Embedding dimension mismatch detected: "
|
|
||||||
f"total elements ({total_elements}) cannot be evenly divided by "
|
|
||||||
f"expected dimension ({expected_dim}). "
|
|
||||||
)
|
|
||||||
|
|
||||||
# Optional: Verify vector count matches input text count
|
|
||||||
actual_vectors = total_elements // expected_dim
|
|
||||||
if args and isinstance(args[0], (list, tuple)):
|
|
||||||
expected_vectors = len(args[0])
|
|
||||||
if actual_vectors != expected_vectors:
|
|
||||||
raise ValueError(
|
|
||||||
f"Vector count mismatch: "
|
|
||||||
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def compute_args_hash(*args: Any) -> str:
|
def compute_args_hash(*args: Any) -> str:
|
||||||
|
|
@ -1005,76 +911,7 @@ def priority_limit_async_func_call(
|
||||||
|
|
||||||
|
|
||||||
def wrap_embedding_func_with_attrs(**kwargs):
|
def wrap_embedding_func_with_attrs(**kwargs):
|
||||||
"""Decorator to add embedding dimension and token limit attributes to embedding functions.
|
"""Wrap a function with attributes"""
|
||||||
|
|
||||||
This decorator wraps an async embedding function and returns an EmbeddingFunc instance
|
|
||||||
that automatically handles dimension parameter injection and attribute management.
|
|
||||||
|
|
||||||
WARNING: DO NOT apply this decorator to wrapper functions that call other
|
|
||||||
decorated embedding functions. This will cause double decoration and parameter
|
|
||||||
injection conflicts.
|
|
||||||
|
|
||||||
Correct usage patterns:
|
|
||||||
|
|
||||||
1. Direct implementation (decorated):
|
|
||||||
```python
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
|
||||||
async def my_embed(texts, embedding_dim=None):
|
|
||||||
# Direct implementation
|
|
||||||
return embeddings
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Wrapper calling decorated function (DO NOT decorate wrapper):
|
|
||||||
```python
|
|
||||||
# my_embed is already decorated above
|
|
||||||
|
|
||||||
async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this!
|
|
||||||
# Must call .func to access unwrapped implementation
|
|
||||||
return await my_embed.func(texts, **kwargs)
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Wrapper calling decorated function (properly decorated):
|
|
||||||
```python
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
|
||||||
async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func
|
|
||||||
# Calling .func avoids double decoration
|
|
||||||
return await my_embed.func(texts, **kwargs)
|
|
||||||
```
|
|
||||||
|
|
||||||
The decorated function becomes an EmbeddingFunc instance with:
|
|
||||||
- embedding_dim: The embedding dimension
|
|
||||||
- max_token_size: Maximum token limit (optional)
|
|
||||||
- func: The original unwrapped function (access via .func)
|
|
||||||
- __call__: Wrapper that injects embedding_dim parameter
|
|
||||||
|
|
||||||
Double decoration causes:
|
|
||||||
- Double injection of embedding_dim parameter
|
|
||||||
- Incorrect parameter passing to the underlying implementation
|
|
||||||
- Runtime errors due to parameter conflicts
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embedding_dim: The dimension of embedding vectors
|
|
||||||
max_token_size: Maximum number of tokens (optional)
|
|
||||||
send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A decorator that wraps the function as an EmbeddingFunc instance
|
|
||||||
|
|
||||||
Example of correct wrapper implementation:
|
|
||||||
```python
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
|
||||||
@retry(...)
|
|
||||||
async def openai_embed(texts, ...):
|
|
||||||
# Base implementation
|
|
||||||
pass
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here!
|
|
||||||
async def azure_openai_embed(texts, ...):
|
|
||||||
# CRITICAL: Call .func to access unwrapped function
|
|
||||||
return await openai_embed.func(texts, ...) # ✅ Correct
|
|
||||||
# return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def final_decro(func) -> EmbeddingFunc:
|
def final_decro(func) -> EmbeddingFunc:
|
||||||
new_func = EmbeddingFunc(**kwargs, func=func)
|
new_func = EmbeddingFunc(**kwargs, func=func)
|
||||||
|
|
@ -1090,123 +927,30 @@ def load_json(file_name):
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_string_for_json(text: str) -> str:
|
def _sanitize_json_data(data: Any) -> Any:
|
||||||
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization.
|
"""Recursively sanitize all string values in data structure for safe UTF-8 encoding
|
||||||
|
|
||||||
Uses regex for optimal performance with zero-copy optimization for clean strings.
|
|
||||||
Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: String to sanitize
|
data: Data to sanitize (dict, list, str, or other types)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Original string if clean (zero-copy), sanitized string if dirty
|
Sanitized data with all strings cleaned of problematic characters
|
||||||
"""
|
"""
|
||||||
if not text:
|
if isinstance(data, dict):
|
||||||
return text
|
return {k: _sanitize_json_data(v) for k, v in data.items()}
|
||||||
|
elif isinstance(data, list):
|
||||||
# Fast path: Check if sanitization is needed using C-level regex search
|
return [_sanitize_json_data(item) for item in data]
|
||||||
if not _SURROGATE_PATTERN.search(text):
|
elif isinstance(data, str):
|
||||||
return text # Zero-copy for clean strings - most common case
|
return sanitize_text_for_encoding(data, replacement_char="")
|
||||||
|
else:
|
||||||
# Slow path: Remove problematic characters using C-level regex substitution
|
return data
|
||||||
return _SURROGATE_PATTERN.sub("", text)
|
|
||||||
|
|
||||||
|
|
||||||
class SanitizingJSONEncoder(json.JSONEncoder):
|
|
||||||
"""
|
|
||||||
Custom JSON encoder that sanitizes data during serialization.
|
|
||||||
|
|
||||||
This encoder cleans strings during the encoding process without creating
|
|
||||||
a full copy of the data structure, making it memory-efficient for large datasets.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def encode(self, o):
|
|
||||||
"""Override encode method to handle simple string cases"""
|
|
||||||
if isinstance(o, str):
|
|
||||||
return json.encoder.encode_basestring(_sanitize_string_for_json(o))
|
|
||||||
return super().encode(o)
|
|
||||||
|
|
||||||
def iterencode(self, o, _one_shot=False):
|
|
||||||
"""
|
|
||||||
Override iterencode to sanitize strings during serialization.
|
|
||||||
This is the core method that handles complex nested structures.
|
|
||||||
"""
|
|
||||||
# Preprocess: sanitize all strings in the object
|
|
||||||
sanitized = self._sanitize_for_encoding(o)
|
|
||||||
|
|
||||||
# Call parent's iterencode with sanitized data
|
|
||||||
for chunk in super().iterencode(sanitized, _one_shot):
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
def _sanitize_for_encoding(self, obj):
|
|
||||||
"""
|
|
||||||
Recursively sanitize strings in an object.
|
|
||||||
Creates new objects only when necessary to avoid deep copies.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obj: Object to sanitize
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized object with cleaned strings
|
|
||||||
"""
|
|
||||||
if isinstance(obj, str):
|
|
||||||
return _sanitize_string_for_json(obj)
|
|
||||||
|
|
||||||
elif isinstance(obj, dict):
|
|
||||||
# Create new dict with sanitized keys and values
|
|
||||||
new_dict = {}
|
|
||||||
for k, v in obj.items():
|
|
||||||
clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
|
|
||||||
clean_v = self._sanitize_for_encoding(v)
|
|
||||||
new_dict[clean_k] = clean_v
|
|
||||||
return new_dict
|
|
||||||
|
|
||||||
elif isinstance(obj, (list, tuple)):
|
|
||||||
# Sanitize list/tuple elements
|
|
||||||
cleaned = [self._sanitize_for_encoding(item) for item in obj]
|
|
||||||
return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Numbers, booleans, None, etc. remain unchanged
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
def write_json(json_obj, file_name):
|
def write_json(json_obj, file_name):
|
||||||
"""
|
# Sanitize data before writing to prevent UTF-8 encoding errors
|
||||||
Write JSON data to file with optimized sanitization strategy.
|
sanitized_obj = _sanitize_json_data(json_obj)
|
||||||
|
|
||||||
This function uses a two-stage approach:
|
|
||||||
1. Fast path: Try direct serialization (works for clean data ~99% of time)
|
|
||||||
2. Slow path: Use custom encoder that sanitizes during serialization
|
|
||||||
|
|
||||||
The custom encoder approach avoids creating a deep copy of the data,
|
|
||||||
making it memory-efficient. When sanitization occurs, the caller should
|
|
||||||
reload the cleaned data from the file to update shared memory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
json_obj: Object to serialize (may be a shallow copy from shared memory)
|
|
||||||
file_name: Output file path
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if sanitization was applied (caller should reload data),
|
|
||||||
False if direct write succeeded (no reload needed)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Strategy 1: Fast path - try direct serialization
|
|
||||||
with open(file_name, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
|
||||||
return False # No sanitization needed, no reload required
|
|
||||||
|
|
||||||
except (UnicodeEncodeError, UnicodeDecodeError) as e:
|
|
||||||
logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
|
|
||||||
|
|
||||||
# Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy)
|
|
||||||
with open(file_name, "w", encoding="utf-8") as f:
|
with open(file_name, "w", encoding="utf-8") as f:
|
||||||
json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder)
|
json.dump(sanitized_obj, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
logger.info(f"JSON sanitization applied during write: {file_name}")
|
|
||||||
return True # Sanitization applied, reload recommended
|
|
||||||
|
|
||||||
|
|
||||||
class TokenizerInterface(Protocol):
|
class TokenizerInterface(Protocol):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue