cherry-pick 90f52acf
This commit is contained in:
parent
5f36666ac1
commit
0c46370940
1 changed files with 156 additions and 3 deletions
|
|
@ -56,6 +56,9 @@ if not logger.handlers:
|
||||||
# Set httpx logging level to WARNING
|
# Set httpx logging level to WARNING
|
||||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
|
||||||
|
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
|
||||||
|
|
||||||
# Global import for pypinyin with startup-time logging
|
# Global import for pypinyin with startup-time logging
|
||||||
try:
|
try:
|
||||||
import pypinyin
|
import pypinyin
|
||||||
|
|
@ -350,9 +353,20 @@ class TaskState:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingFunc:
|
class EmbeddingFunc:
|
||||||
|
"""Embedding function wrapper with dimension validation
|
||||||
|
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
|
||||||
|
This class should not be wrapped multiple times.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_dim: Expected dimension of the embeddings
|
||||||
|
func: The actual embedding function to wrap
|
||||||
|
max_token_size: Optional token limit for the embedding model
|
||||||
|
send_dimensions: Whether to inject embedding_dim as a keyword argument
|
||||||
|
"""
|
||||||
|
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
func: callable
|
func: callable
|
||||||
max_token_size: int | None = None # deprecated keep it for compatible only
|
max_token_size: int | None = None # Token limit for the embedding model
|
||||||
send_dimensions: bool = (
|
send_dimensions: bool = (
|
||||||
False # Control whether to send embedding_dim to the function
|
False # Control whether to send embedding_dim to the function
|
||||||
)
|
)
|
||||||
|
|
@ -376,7 +390,32 @@ class EmbeddingFunc:
|
||||||
# Inject embedding_dim from decorator
|
# Inject embedding_dim from decorator
|
||||||
kwargs["embedding_dim"] = self.embedding_dim
|
kwargs["embedding_dim"] = self.embedding_dim
|
||||||
|
|
||||||
return await self.func(*args, **kwargs)
|
# Call the actual embedding function
|
||||||
|
result = await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Validate embedding dimensions using total element count
|
||||||
|
total_elements = result.size # Total number of elements in the numpy array
|
||||||
|
expected_dim = self.embedding_dim
|
||||||
|
|
||||||
|
# Check if total elements can be evenly divided by embedding_dim
|
||||||
|
if total_elements % expected_dim != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Embedding dimension mismatch detected: "
|
||||||
|
f"total elements ({total_elements}) cannot be evenly divided by "
|
||||||
|
f"expected dimension ({expected_dim}). "
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optional: Verify vector count matches input text count
|
||||||
|
actual_vectors = total_elements // expected_dim
|
||||||
|
if args and isinstance(args[0], (list, tuple)):
|
||||||
|
expected_vectors = len(args[0])
|
||||||
|
if actual_vectors != expected_vectors:
|
||||||
|
raise ValueError(
|
||||||
|
f"Vector count mismatch: "
|
||||||
|
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def compute_args_hash(*args: Any) -> str:
|
def compute_args_hash(*args: Any) -> str:
|
||||||
|
|
@ -927,9 +966,123 @@ def load_json(file_name):
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_string_for_json(text: str) -> str:
|
||||||
|
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization.
|
||||||
|
|
||||||
|
Uses regex for optimal performance with zero-copy optimization for clean strings.
|
||||||
|
Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: String to sanitize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Original string if clean (zero-copy), sanitized string if dirty
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
# Fast path: Check if sanitization is needed using C-level regex search
|
||||||
|
if not _SURROGATE_PATTERN.search(text):
|
||||||
|
return text # Zero-copy for clean strings - most common case
|
||||||
|
|
||||||
|
# Slow path: Remove problematic characters using C-level regex substitution
|
||||||
|
return _SURROGATE_PATTERN.sub("", text)
|
||||||
|
|
||||||
|
|
||||||
|
class SanitizingJSONEncoder(json.JSONEncoder):
|
||||||
|
"""
|
||||||
|
Custom JSON encoder that sanitizes data during serialization.
|
||||||
|
|
||||||
|
This encoder cleans strings during the encoding process without creating
|
||||||
|
a full copy of the data structure, making it memory-efficient for large datasets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def encode(self, o):
|
||||||
|
"""Override encode method to handle simple string cases"""
|
||||||
|
if isinstance(o, str):
|
||||||
|
return json.encoder.encode_basestring(_sanitize_string_for_json(o))
|
||||||
|
return super().encode(o)
|
||||||
|
|
||||||
|
def iterencode(self, o, _one_shot=False):
|
||||||
|
"""
|
||||||
|
Override iterencode to sanitize strings during serialization.
|
||||||
|
This is the core method that handles complex nested structures.
|
||||||
|
"""
|
||||||
|
# Preprocess: sanitize all strings in the object
|
||||||
|
sanitized = self._sanitize_for_encoding(o)
|
||||||
|
|
||||||
|
# Call parent's iterencode with sanitized data
|
||||||
|
for chunk in super().iterencode(sanitized, _one_shot):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
def _sanitize_for_encoding(self, obj):
|
||||||
|
"""
|
||||||
|
Recursively sanitize strings in an object.
|
||||||
|
Creates new objects only when necessary to avoid deep copies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: Object to sanitize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized object with cleaned strings
|
||||||
|
"""
|
||||||
|
if isinstance(obj, str):
|
||||||
|
return _sanitize_string_for_json(obj)
|
||||||
|
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
# Create new dict with sanitized keys and values
|
||||||
|
new_dict = {}
|
||||||
|
for k, v in obj.items():
|
||||||
|
clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
|
||||||
|
clean_v = self._sanitize_for_encoding(v)
|
||||||
|
new_dict[clean_k] = clean_v
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
# Sanitize list/tuple elements
|
||||||
|
cleaned = [self._sanitize_for_encoding(item) for item in obj]
|
||||||
|
return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Numbers, booleans, None, etc. remain unchanged
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def write_json(json_obj, file_name):
|
def write_json(json_obj, file_name):
|
||||||
|
"""
|
||||||
|
Write JSON data to file with optimized sanitization strategy.
|
||||||
|
|
||||||
|
This function uses a two-stage approach:
|
||||||
|
1. Fast path: Try direct serialization (works for clean data ~99% of time)
|
||||||
|
2. Slow path: Use custom encoder that sanitizes during serialization
|
||||||
|
|
||||||
|
The custom encoder approach avoids creating a deep copy of the data,
|
||||||
|
making it memory-efficient. When sanitization occurs, the caller should
|
||||||
|
reload the cleaned data from the file to update shared memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_obj: Object to serialize (may be a shallow copy from shared memory)
|
||||||
|
file_name: Output file path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if sanitization was applied (caller should reload data),
|
||||||
|
False if direct write succeeded (no reload needed)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Strategy 1: Fast path - try direct serialization
|
||||||
|
with open(file_name, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
||||||
|
return False # No sanitization needed, no reload required
|
||||||
|
|
||||||
|
except (UnicodeEncodeError, UnicodeDecodeError) as e:
|
||||||
|
logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
|
||||||
|
|
||||||
|
# Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy)
|
||||||
with open(file_name, "w", encoding="utf-8") as f:
|
with open(file_name, "w", encoding="utf-8") as f:
|
||||||
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder)
|
||||||
|
|
||||||
|
logger.info(f"JSON sanitization applied during write: {file_name}")
|
||||||
|
return True # Sanitization applied, reload recommended
|
||||||
|
|
||||||
|
|
||||||
class TokenizerInterface(Protocol):
|
class TokenizerInterface(Protocol):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue