cherry-pick 05852e1a
This commit is contained in:
parent
da7683a001
commit
3558adae47
11 changed files with 205 additions and 91 deletions
|
|
@ -16,6 +16,7 @@ from tenacity import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
from lightrag.utils import wrap_embedding_func_with_attrs
|
||||||
|
|
||||||
if sys.version_info < (3, 9):
|
if sys.version_info < (3, 9):
|
||||||
from typing import AsyncIterator
|
from typing import AsyncIterator
|
||||||
|
|
@ -253,7 +254,7 @@ async def bedrock_complete(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# @wrap_embedding_func_with_attrs(embedding_dim=1024)
|
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
# @retry(
|
# @retry(
|
||||||
# stop=stop_after_attempt(3),
|
# stop=stop_after_attempt(3),
|
||||||
# wait=wait_exponential(multiplier=1, min=4, max=10),
|
# wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
|
|
||||||
|
|
@ -429,7 +429,7 @@ async def gemini_model_complete(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from lightrag.exceptions import (
|
||||||
)
|
)
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from lightrag.utils import wrap_embedding_func_with_attrs
|
||||||
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
@ -141,6 +142,7 @@ async def hf_model_complete(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
||||||
# Detect the appropriate device
|
# Detect the appropriate device
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ async def fetch_data(url, headers, data):
|
||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=2048)
|
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
|
||||||
|
|
@ -174,7 +174,7 @@ async def llama_index_complete(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,10 @@ from lightrag.exceptions import (
|
||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from lightrag.utils import (
|
||||||
|
wrap_embedding_func_with_attrs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
|
|
@ -134,6 +138,7 @@ async def lollms_model_complete(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
async def lollms_embed(
|
async def lollms_embed(
|
||||||
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
|
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ from lightrag.utils import (
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=2048)
|
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,8 @@
|
||||||
import sys
|
from collections.abc import AsyncIterator
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
if sys.version_info < (3, 9):
|
import pipmaster as pm
|
||||||
from typing import AsyncIterator
|
|
||||||
else:
|
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
|
|
||||||
import pipmaster as pm # Pipmaster for dynamic library install
|
|
||||||
|
|
||||||
# install specific modules
|
# install specific modules
|
||||||
if not pm.is_installed("ollama"):
|
if not pm.is_installed("ollama"):
|
||||||
|
|
@ -27,8 +24,31 @@ from lightrag.exceptions import (
|
||||||
from lightrag.api import __api_version__
|
from lightrag.api import __api_version__
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
from lightrag.utils import logger
|
from lightrag.utils import (
|
||||||
|
wrap_embedding_func_with_attrs,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_OLLAMA_CLOUD_HOST = "https://ollama.com"
|
||||||
|
_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
    """Choose the Ollama host, falling back to Ollama Cloud for cloud models.

    A truthy ``host`` is always returned untouched. When no host is given,
    the model name is matched against ``_CLOUD_MODEL_SUFFIX_PATTERN``
    (names ending in ``-cloud`` or ``:cloud``); on a match the Ollama Cloud
    endpoint ``_OLLAMA_CLOUD_HOST`` is returned.

    Args:
        host: Explicitly configured host, or a falsy value when unset.
        model: Model identifier; converted with ``str()`` before matching.

    Returns:
        The given ``host`` if truthy, ``_OLLAMA_CLOUD_HOST`` for a cloud
        model without a host, otherwise the original falsy ``host``.
    """
    if not host:
        name = ""
        if model is not None:
            try:
                name = str(model)
            except (TypeError, ValueError, AttributeError) as e:
                # Conversion failed: keep the empty name so nothing matches.
                logger.warning(
                    f"Failed to convert model to string: {e}, using empty string"
                )

        if _CLOUD_MODEL_SUFFIX_PATTERN.search(name) is not None:
            logger.debug(f"Detected cloud model '{name}', using Ollama Cloud host")
            return _OLLAMA_CLOUD_HOST

    return host
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
|
|
@ -58,6 +78,9 @@ async def _ollama_model_if_cache(
|
||||||
timeout = None
|
timeout = None
|
||||||
kwargs.pop("hashing_kv", None)
|
kwargs.pop("hashing_kv", None)
|
||||||
api_key = kwargs.pop("api_key", None)
|
api_key = kwargs.pop("api_key", None)
|
||||||
|
# fallback to environment variable when not provided explicitly
|
||||||
|
if not api_key:
|
||||||
|
api_key = os.getenv("OLLAMA_API_KEY")
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"User-Agent": f"LightRAG/{__api_version__}",
|
"User-Agent": f"LightRAG/{__api_version__}",
|
||||||
|
|
@ -65,6 +88,8 @@ async def _ollama_model_if_cache(
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
host = _coerce_host_for_cloud_model(host, model)
|
||||||
|
|
||||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -147,17 +172,11 @@ async def ollama_model_complete(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||||
"""
|
|
||||||
Generate embeddings using Ollama API.
|
|
||||||
|
|
||||||
Uses httpx directly instead of ollama.AsyncClient to work around a bug in ollama SDK v0.6.1
|
|
||||||
where the host parameter is not properly used for the embed endpoint.
|
|
||||||
"""
|
|
||||||
import httpx
|
|
||||||
import json
|
|
||||||
|
|
||||||
api_key = kwargs.pop("api_key", None)
|
api_key = kwargs.pop("api_key", None)
|
||||||
|
if not api_key:
|
||||||
|
api_key = os.getenv("OLLAMA_API_KEY")
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"User-Agent": f"LightRAG/{__api_version__}",
|
"User-Agent": f"LightRAG/{__api_version__}",
|
||||||
|
|
@ -168,63 +187,28 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||||
host = kwargs.pop("host", None)
|
host = kwargs.pop("host", None)
|
||||||
timeout = kwargs.pop("timeout", None)
|
timeout = kwargs.pop("timeout", None)
|
||||||
|
|
||||||
# Ensure host has proper format
|
host = _coerce_host_for_cloud_model(host, embed_model)
|
||||||
if host and not host.startswith("http"):
|
|
||||||
host = f"http://{host}"
|
|
||||||
if not host:
|
|
||||||
host = "http://localhost:11434"
|
|
||||||
|
|
||||||
# Validate host format to catch any corruption
|
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||||
if not isinstance(host, str) or not host.startswith("http"):
|
try:
|
||||||
logger.error(f"Invalid host format for Ollama embed: {host} (type: {type(host).__name__})")
|
options = kwargs.pop("options", {})
|
||||||
raise ValueError(f"Invalid host format for Ollama: {host}")
|
data = await ollama_client.embed(
|
||||||
|
model=embed_model, input=texts, options=options
|
||||||
logger.info(f"Ollama embed called with host: {host}, model: {embed_model}")
|
)
|
||||||
|
return np.array(data["embeddings"])
|
||||||
# Use httpx directly to avoid ollama SDK bug with embed endpoint
|
except Exception as e:
|
||||||
async with httpx.AsyncClient(timeout=timeout if timeout else 120.0) as client:
|
logger.error(f"Error in ollama_embed: {str(e)}")
|
||||||
try:
|
try:
|
||||||
options = kwargs.pop("options", {})
|
await ollama_client._client.aclose()
|
||||||
|
logger.debug("Successfully closed Ollama client after exception in embed")
|
||||||
# Construct the embed API endpoint
|
except Exception as close_error:
|
||||||
embed_url = f"{host}/api/embed"
|
logger.warning(
|
||||||
|
f"Failed to close Ollama client after exception in embed: {close_error}"
|
||||||
# Prepare request payload
|
|
||||||
payload = {
|
|
||||||
"model": embed_model,
|
|
||||||
"input": texts,
|
|
||||||
}
|
|
||||||
if options:
|
|
||||||
payload["options"] = options
|
|
||||||
|
|
||||||
logger.debug(f"Sending embed request to {embed_url}")
|
|
||||||
|
|
||||||
# Make the request
|
|
||||||
response = await client.post(
|
|
||||||
embed_url,
|
|
||||||
json=payload,
|
|
||||||
headers=headers
|
|
||||||
)
|
)
|
||||||
|
raise e
|
||||||
# Check for errors
|
finally:
|
||||||
response.raise_for_status()
|
try:
|
||||||
|
await ollama_client._client.aclose()
|
||||||
# Parse response
|
logger.debug("Successfully closed Ollama client after embed")
|
||||||
data = response.json()
|
except Exception as close_error:
|
||||||
|
logger.warning(f"Failed to close Ollama client after embed: {close_error}")
|
||||||
if "embeddings" not in data:
|
|
||||||
raise ValueError(f"Invalid response from Ollama: {data}")
|
|
||||||
|
|
||||||
return np.array(data["embeddings"])
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
error_msg = f"HTTP error from Ollama: {e.response.status_code} - {e.response.text}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise Exception(error_msg) from e
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
error_msg = f"Connection error to Ollama at {host}: {str(e)}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise Exception(error_msg) from e
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in ollama_embed: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ try:
|
||||||
|
|
||||||
# Only enable Langfuse if both keys are configured
|
# Only enable Langfuse if both keys are configured
|
||||||
if langfuse_public_key and langfuse_secret_key:
|
if langfuse_public_key and langfuse_secret_key:
|
||||||
from langfuse.openai import AsyncOpenAI
|
from langfuse.openai import AsyncOpenAI # type: ignore[import-untyped]
|
||||||
|
|
||||||
LANGFUSE_ENABLED = True
|
LANGFUSE_ENABLED = True
|
||||||
logger.info("Langfuse observability enabled for OpenAI client")
|
logger.info("Langfuse observability enabled for OpenAI client")
|
||||||
|
|
@ -594,7 +594,7 @@ async def nvidia_openai_complete(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,9 @@ if not logger.handlers:
|
||||||
# Set httpx logging level to WARNING
|
# Set httpx logging level to WARNING
|
||||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
|
||||||
|
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
|
||||||
|
|
||||||
# Global import for pypinyin with startup-time logging
|
# Global import for pypinyin with startup-time logging
|
||||||
try:
|
try:
|
||||||
import pypinyin
|
import pypinyin
|
||||||
|
|
@ -352,24 +355,29 @@ class TaskState:
|
||||||
class EmbeddingFunc:
|
class EmbeddingFunc:
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
func: callable
|
func: callable
|
||||||
max_token_size: int | None = None # deprecated keep it for compatible only
|
max_token_size: int | None = None # Token limit for the embedding model
|
||||||
send_dimensions: bool = False # Control whether to send embedding_dim to the function
|
send_dimensions: bool = (
|
||||||
|
False # Control whether to send embedding_dim to the function
|
||||||
|
)
|
||||||
|
|
||||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
||||||
# Only inject embedding_dim when send_dimensions is True
|
# Only inject embedding_dim when send_dimensions is True
|
||||||
if self.send_dimensions:
|
if self.send_dimensions:
|
||||||
# Check if user provided embedding_dim parameter
|
# Check if user provided embedding_dim parameter
|
||||||
if 'embedding_dim' in kwargs:
|
if "embedding_dim" in kwargs:
|
||||||
user_provided_dim = kwargs['embedding_dim']
|
user_provided_dim = kwargs["embedding_dim"]
|
||||||
# If user's value differs from class attribute, output warning
|
# If user's value differs from class attribute, output warning
|
||||||
if user_provided_dim is not None and user_provided_dim != self.embedding_dim:
|
if (
|
||||||
|
user_provided_dim is not None
|
||||||
|
and user_provided_dim != self.embedding_dim
|
||||||
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Ignoring user-provided embedding_dim={user_provided_dim}, "
|
f"Ignoring user-provided embedding_dim={user_provided_dim}, "
|
||||||
f"using declared embedding_dim={self.embedding_dim} from decorator"
|
f"using declared embedding_dim={self.embedding_dim} from decorator"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Inject embedding_dim from decorator
|
# Inject embedding_dim from decorator
|
||||||
kwargs['embedding_dim'] = self.embedding_dim
|
kwargs["embedding_dim"] = self.embedding_dim
|
||||||
|
|
||||||
return await self.func(*args, **kwargs)
|
return await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
@ -922,9 +930,123 @@ def load_json(file_name):
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_string_for_json(text: str) -> str:
    """Strip characters matched by ``_SURROGATE_PATTERN`` from ``text``.

    The module-level pattern covers the surrogate range U+D800-U+DFFF plus
    the code points U+FFFE and U+FFFF. The string is scanned first so that
    a clean (or empty) input is returned as the very same object and only
    a dirty input pays for a substitution.

    Args:
        text: String to sanitize.

    Returns:
        ``text`` itself when it is empty or contains no matching character,
        otherwise a new string with every matching character removed.
    """
    # Single guard: falsy input and clean input both fall through unchanged.
    if text and _SURROGATE_PATTERN.search(text) is not None:
        return _SURROGATE_PATTERN.sub("", text)
    return text
|
||||||
|
|
||||||
|
|
||||||
|
class SanitizingJSONEncoder(json.JSONEncoder):
    """JSON encoder that sanitizes strings while serializing.

    Every string value and every string dict key is passed through
    ``_sanitize_string_for_json`` before the standard encoder sees it.
    Dicts, lists and tuples are rebuilt during that pass (the input object
    is never mutated); all other values are handed over unchanged.
    """

    def encode(self, o):
        """Return the JSON text for ``o``.

        A top-level string is sanitized here and then delegated to the base
        class, so the encoder's ``ensure_ascii`` setting is honoured instead
        of always emitting the non-ASCII form.
        """
        if isinstance(o, str):
            o = _sanitize_string_for_json(o)
        return super().encode(o)

    def iterencode(self, o, _one_shot=False):
        """Yield JSON chunks for ``o`` after sanitizing all contained strings.

        This is the entry point used for nested structures: the object is
        sanitized first, then serialized by the base implementation.
        """
        yield from super().iterencode(self._sanitize_for_encoding(o), _one_shot)

    def _sanitize_for_encoding(self, obj):
        """Recursively build a sanitized counterpart of ``obj``.

        Args:
            obj: Object to sanitize.

        Returns:
            A sanitized string, a new dict, a new list, or a new plain
            tuple; any other value is returned as-is.
        """
        if isinstance(obj, str):
            return _sanitize_string_for_json(obj)

        if isinstance(obj, dict):
            # NOTE: two keys that become equal after sanitization collapse
            # into one entry; the later item wins.
            return {
                (
                    _sanitize_string_for_json(key) if isinstance(key, str) else key
                ): self._sanitize_for_encoding(value)
                for key, value in obj.items()
            }

        if isinstance(obj, (list, tuple)):
            cleaned = [self._sanitize_for_encoding(item) for item in obj]
            # Build a plain tuple rather than ``type(obj)(cleaned)``:
            # namedtuple constructors take positional fields, not a single
            # iterable, so re-instantiating the subclass raises TypeError.
            return tuple(cleaned) if isinstance(obj, tuple) else cleaned

        # Numbers, booleans, None, etc. remain unchanged
        return obj
|
||||||
|
|
||||||
|
|
||||||
def write_json(json_obj, file_name):
    """Write ``json_obj`` to ``file_name`` as JSON, sanitizing only if needed.

    The object is first dumped directly. If that raises a Unicode
    encode/decode error, the file is reopened (truncating any partial
    output) and written again through ``SanitizingJSONEncoder``.

    Args:
        json_obj: Object to serialize.
        file_name: Output file path.

    Returns:
        bool: False when the direct write succeeded, True when the
        sanitizing encoder had to be used (the file content may then differ
        from the in-memory object, so the caller may want to reload it).
    """
    dump_options = {"indent": 2, "ensure_ascii": False}

    try:
        # Fast path: plain serialization.
        with open(file_name, "w", encoding="utf-8") as f:
            json.dump(json_obj, f, **dump_options)
    except (UnicodeEncodeError, UnicodeDecodeError) as e:
        logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
    else:
        return False  # No sanitization needed

    # Slow path: rewrite the file, cleaning strings during serialization.
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(json_obj, f, cls=SanitizingJSONEncoder, **dump_options)

    logger.info(f"JSON sanitization applied during write: {file_name}")
    return True  # Sanitization applied, reload recommended
|
||||||
|
|
||||||
|
|
||||||
class TokenizerInterface(Protocol):
|
class TokenizerInterface(Protocol):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue