diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py
index 16737341..ccfbb4f7 100644
--- a/lightrag/llm/bedrock.py
+++ b/lightrag/llm/bedrock.py
@@ -16,6 +16,7 @@ from tenacity import (
 )
 
 import sys
+from lightrag.utils import wrap_embedding_func_with_attrs
 
 if sys.version_info < (3, 9):
     from typing import AsyncIterator
@@ -253,7 +254,7 @@ async def bedrock_complete(
     return result
 
 
-# @wrap_embedding_func_with_attrs(embedding_dim=1024)
+@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
 # @retry(
 #     stop=stop_after_attempt(3),
 #     wait=wait_exponential(multiplier=1, min=4, max=10),
diff --git a/lightrag/llm/siliconcloud.py b/lightrag/llm/deprecated/siliconcloud.py
similarity index 100%
rename from lightrag/llm/siliconcloud.py
rename to lightrag/llm/deprecated/siliconcloud.py
diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
index 3954e814..5372307e 100644
--- a/lightrag/llm/gemini.py
+++ b/lightrag/llm/gemini.py
@@ -429,7 +429,7 @@ async def gemini_model_complete(
     )
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=1536)
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/hf.py b/lightrag/llm/hf.py
index c33b1c7f..447f95c3 100644
--- a/lightrag/llm/hf.py
+++ b/lightrag/llm/hf.py
@@ -26,6 +26,7 @@ from lightrag.exceptions import (
 )
 import torch
 import numpy as np
+from lightrag.utils import wrap_embedding_func_with_attrs
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -141,6 +142,7 @@ async def hf_model_complete(
     return result
 
 
+@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
 async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     # Detect the appropriate device
     if torch.cuda.is_available():
diff --git a/lightrag/llm/jina.py b/lightrag/llm/jina.py
index 70de5995..f61faadd 100644
--- a/lightrag/llm/jina.py
+++ b/lightrag/llm/jina.py
@@ -58,7 +58,7 @@ async def fetch_data(url, headers, data):
     return data_list
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=2048)
+@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/llama_index_impl.py b/lightrag/llm/llama_index_impl.py
index 38ec7cd1..c44e6c7a 100644
--- a/lightrag/llm/llama_index_impl.py
+++ b/lightrag/llm/llama_index_impl.py
@@ -174,7 +174,7 @@ async def llama_index_complete(
     return result
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=1536)
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/lollms.py b/lightrag/llm/lollms.py
index 9274dbfc..2f2a1dbf 100644
--- a/lightrag/llm/lollms.py
+++ b/lightrag/llm/lollms.py
@@ -26,6 +26,10 @@ from lightrag.exceptions import (
 from typing import Union, List
 import numpy as np
 
+from lightrag.utils import (
+    wrap_embedding_func_with_attrs,
+)
+
 
 @retry(
     stop=stop_after_attempt(3),
@@ -134,6 +138,7 @@ async def lollms_model_complete(
     )
 
 
+@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
 async def lollms_embed(
     texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
 ) -> np.ndarray:
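The decorator additions above all follow one pattern: a plain async embedding function is wrapped so callers can read its dimension and token limit as attributes. A minimal sketch of the expected usage (the toy function and values are illustrative; it assumes wrap_embedding_func_with_attrs wraps the function in the EmbeddingFunc dataclass shown in the lightrag/utils.py hunk further down):

    import asyncio
    import numpy as np
    from lightrag.utils import EmbeddingFunc, wrap_embedding_func_with_attrs

    @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
    async def toy_embed(texts: list[str]) -> np.ndarray:
        # Stand-in embedder: returns zero vectors of the declared dimension
        return np.zeros((len(texts), 1024))

    assert isinstance(toy_embed, EmbeddingFunc)  # assumption: the decorator returns EmbeddingFunc
    assert toy_embed.max_token_size == 8192      # chunking code can now respect the model's limit
    vectors = asyncio.run(toy_embed(["hello", "world"]))
    assert vectors.shape == (2, 1024)
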
diff --git a/lightrag/llm/nvidia_openai.py b/lightrag/llm/nvidia_openai.py
index 1cbab380..1ebaf3a6 100644
--- a/lightrag/llm/nvidia_openai.py
+++ b/lightrag/llm/nvidia_openai.py
@@ -33,7 +33,7 @@ from lightrag.utils import (
 import numpy as np
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=2048)
+@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
index ffe0f133..e35dc293 100644
--- a/lightrag/llm/ollama.py
+++ b/lightrag/llm/ollama.py
@@ -1,11 +1,8 @@
-import sys
+from collections.abc import AsyncIterator
+import os
+import re
 
-if sys.version_info < (3, 9):
-    from typing import AsyncIterator
-else:
-    from collections.abc import AsyncIterator
-
-import pipmaster as pm  # Pipmaster for dynamic library install
+import pipmaster as pm  # install missing modules on demand
 
 if not pm.is_installed("ollama"):
@@ -27,8 +24,31 @@ from lightrag.exceptions import (
 from lightrag.api import __api_version__
 
 import numpy as np
-from typing import Union
-from lightrag.utils import logger
+from typing import Optional, Union
+from lightrag.utils import (
+    wrap_embedding_func_with_attrs,
+    logger,
+)
+
+
+_OLLAMA_CLOUD_HOST = "https://ollama.com"
+_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")
+
+
+def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
+    if host:
+        return host
+    try:
+        model_name_str = str(model) if model is not None else ""
+    except (TypeError, ValueError, AttributeError) as e:
+        logger.warning(f"Failed to convert model to string: {e}, using empty string")
+        model_name_str = ""
+    if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str):
+        logger.debug(
+            f"Detected cloud model '{model_name_str}', using Ollama Cloud host"
+        )
+        return _OLLAMA_CLOUD_HOST
+    return host
 
 
 @retry(
@@ -58,6 +78,9 @@ async def _ollama_model_if_cache(
         timeout = None
     kwargs.pop("hashing_kv", None)
     api_key = kwargs.pop("api_key", None)
+    # fall back to the environment variable when no key is provided explicitly
+    if not api_key:
+        api_key = os.getenv("OLLAMA_API_KEY")
     headers = {
         "Content-Type": "application/json",
         "User-Agent": f"LightRAG/{__api_version__}",
@@ -65,6 +88,8 @@ async def _ollama_model_if_cache(
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"
 
+    host = _coerce_host_for_cloud_model(host, model)
+
     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
 
     try:
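The helper's behavior in a nutshell: an explicit host always wins; otherwise model names ending in "-cloud" or ":cloud" are routed to Ollama Cloud. A quick check against the code above (the model names are made-up examples):

    from lightrag.llm.ollama import _coerce_host_for_cloud_model

    # Explicit host wins, even for a cloud-suffixed model
    assert _coerce_host_for_cloud_model("http://localhost:11434", "qwen3:cloud") == "http://localhost:11434"
    # No host plus a cloud suffix routes to https://ollama.com
    assert _coerce_host_for_cloud_model(None, "gpt-oss:120b-cloud") == "https://ollama.com"
    # Plain local models keep host as-is (the client's own default applies later)
    assert _coerce_host_for_cloud_model(None, "llama3") is None
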
- """ - import httpx - import json - api_key = kwargs.pop("api_key", None) + if not api_key: + api_key = os.getenv("OLLAMA_API_KEY") headers = { "Content-Type": "application/json", "User-Agent": f"LightRAG/{__api_version__}", @@ -167,64 +186,29 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: host = kwargs.pop("host", None) timeout = kwargs.pop("timeout", None) - - # Ensure host has proper format - if host and not host.startswith("http"): - host = f"http://{host}" - if not host: - host = "http://localhost:11434" - - # Validate host format to catch any corruption - if not isinstance(host, str) or not host.startswith("http"): - logger.error(f"Invalid host format for Ollama embed: {host} (type: {type(host).__name__})") - raise ValueError(f"Invalid host format for Ollama: {host}") - logger.info(f"Ollama embed called with host: {host}, model: {embed_model}") + host = _coerce_host_for_cloud_model(host, embed_model) - # Use httpx directly to avoid ollama SDK bug with embed endpoint - async with httpx.AsyncClient(timeout=timeout if timeout else 120.0) as client: + ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) + try: + options = kwargs.pop("options", {}) + data = await ollama_client.embed( + model=embed_model, input=texts, options=options + ) + return np.array(data["embeddings"]) + except Exception as e: + logger.error(f"Error in ollama_embed: {str(e)}") try: - options = kwargs.pop("options", {}) - - # Construct the embed API endpoint - embed_url = f"{host}/api/embed" - - # Prepare request payload - payload = { - "model": embed_model, - "input": texts, - } - if options: - payload["options"] = options - - logger.debug(f"Sending embed request to {embed_url}") - - # Make the request - response = await client.post( - embed_url, - json=payload, - headers=headers + await ollama_client._client.aclose() + logger.debug("Successfully closed Ollama client after exception in embed") + except Exception as close_error: + logger.warning( + f"Failed to close Ollama client after exception in embed: {close_error}" ) - - # Check for errors - response.raise_for_status() - - # Parse response - data = response.json() - - if "embeddings" not in data: - raise ValueError(f"Invalid response from Ollama: {data}") - - return np.array(data["embeddings"]) - - except httpx.HTTPStatusError as e: - error_msg = f"HTTP error from Ollama: {e.response.status_code} - {e.response.text}" - logger.error(error_msg) - raise Exception(error_msg) from e - except httpx.RequestError as e: - error_msg = f"Connection error to Ollama at {host}: {str(e)}" - logger.error(error_msg) - raise Exception(error_msg) from e - except Exception as e: - logger.error(f"Error in ollama_embed: {str(e)}") - raise + raise e + finally: + try: + await ollama_client._client.aclose() + logger.debug("Successfully closed Ollama client after embed") + except Exception as close_error: + logger.warning(f"Failed to close Ollama client after embed: {close_error}") diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 155fd3e9..61db965f 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -47,7 +47,7 @@ try: # Only enable Langfuse if both keys are configured if langfuse_public_key and langfuse_secret_key: - from langfuse.openai import AsyncOpenAI + from langfuse.openai import AsyncOpenAI # type: ignore[import-untyped] LANGFUSE_ENABLED = True logger.info("Langfuse observability enabled for OpenAI client") @@ -594,7 +594,7 @@ async def nvidia_openai_complete( return result 
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 155fd3e9..61db965f 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -47,7 +47,7 @@ try:
 
     # Only enable Langfuse if both keys are configured
     if langfuse_public_key and langfuse_secret_key:
-        from langfuse.openai import AsyncOpenAI
+        from langfuse.openai import AsyncOpenAI  # type: ignore[import-untyped]
 
         LANGFUSE_ENABLED = True
         logger.info("Langfuse observability enabled for OpenAI client")
@@ -594,7 +594,7 @@ async def nvidia_openai_complete(
     return result
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=1536)
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9ac82b1c..d653c1e3 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -56,6 +56,9 @@ if not logger.handlers:
 # Set httpx logging level to WARNING
 logging.getLogger("httpx").setLevel(logging.WARNING)
 
+# Precompile regex pattern for JSON sanitization (module-level, compiled once)
+_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
+
 # Global import for pypinyin with startup-time logging
 try:
     import pypinyin
@@ -352,25 +355,30 @@ class TaskState:
 class EmbeddingFunc:
     embedding_dim: int
     func: callable
-    max_token_size: int | None = None  # deprecated keep it for compatible only
-    send_dimensions: bool = False  # Control whether to send embedding_dim to the function
+    max_token_size: int | None = None  # Token limit for the embedding model
+    send_dimensions: bool = (
+        False  # Control whether to send embedding_dim to the function
+    )
 
     async def __call__(self, *args, **kwargs) -> np.ndarray:
         # Only inject embedding_dim when send_dimensions is True
         if self.send_dimensions:
             # Check if user provided embedding_dim parameter
-            if 'embedding_dim' in kwargs:
-                user_provided_dim = kwargs['embedding_dim']
+            if "embedding_dim" in kwargs:
+                user_provided_dim = kwargs["embedding_dim"]
                 # If user's value differs from class attribute, output warning
-                if user_provided_dim is not None and user_provided_dim != self.embedding_dim:
+                if (
+                    user_provided_dim is not None
+                    and user_provided_dim != self.embedding_dim
+                ):
                     logger.warning(
                         f"Ignoring user-provided embedding_dim={user_provided_dim}, "
                         f"using declared embedding_dim={self.embedding_dim} from decorator"
                     )
-
+
             # Inject embedding_dim from decorator
-            kwargs['embedding_dim'] = self.embedding_dim
-
+            kwargs["embedding_dim"] = self.embedding_dim
+
         return await self.func(*args, **kwargs)
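The injection above only fires when send_dimensions is set; a caller-supplied embedding_dim that disagrees with the declared one is warned about and overridden. A behavior sketch (the toy embed function and its dimension handling are illustrative):

    import asyncio
    import numpy as np
    from lightrag.utils import EmbeddingFunc

    async def embed(texts, embedding_dim=None):
        # Toy embedder: honors whatever dimension is injected by the wrapper
        return np.zeros((len(texts), embedding_dim or 1536))

    fn = EmbeddingFunc(embedding_dim=1536, func=embed, send_dimensions=True)
    vecs = asyncio.run(fn(["a", "b"], embedding_dim=999))  # logs a warning, uses 1536
    assert vecs.shape == (2, 1536)
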
+ """ + # Preprocess: sanitize all strings in the object + sanitized = self._sanitize_for_encoding(o) + + # Call parent's iterencode with sanitized data + for chunk in super().iterencode(sanitized, _one_shot): + yield chunk + + def _sanitize_for_encoding(self, obj): + """ + Recursively sanitize strings in an object. + Creates new objects only when necessary to avoid deep copies. + + Args: + obj: Object to sanitize + + Returns: + Sanitized object with cleaned strings + """ + if isinstance(obj, str): + return _sanitize_string_for_json(obj) + + elif isinstance(obj, dict): + # Create new dict with sanitized keys and values + new_dict = {} + for k, v in obj.items(): + clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k + clean_v = self._sanitize_for_encoding(v) + new_dict[clean_k] = clean_v + return new_dict + + elif isinstance(obj, (list, tuple)): + # Sanitize list/tuple elements + cleaned = [self._sanitize_for_encoding(item) for item in obj] + return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned + + else: + # Numbers, booleans, None, etc. remain unchanged + return obj + + def write_json(json_obj, file_name): + """ + Write JSON data to file with optimized sanitization strategy. + + This function uses a two-stage approach: + 1. Fast path: Try direct serialization (works for clean data ~99% of time) + 2. Slow path: Use custom encoder that sanitizes during serialization + + The custom encoder approach avoids creating a deep copy of the data, + making it memory-efficient. When sanitization occurs, the caller should + reload the cleaned data from the file to update shared memory. + + Args: + json_obj: Object to serialize (may be a shallow copy from shared memory) + file_name: Output file path + + Returns: + bool: True if sanitization was applied (caller should reload data), + False if direct write succeeded (no reload needed) + """ + try: + # Strategy 1: Fast path - try direct serialization + with open(file_name, "w", encoding="utf-8") as f: + json.dump(json_obj, f, indent=2, ensure_ascii=False) + return False # No sanitization needed, no reload required + + except (UnicodeEncodeError, UnicodeDecodeError) as e: + logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}") + + # Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy) with open(file_name, "w", encoding="utf-8") as f: - json.dump(json_obj, f, indent=2, ensure_ascii=False) + json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder) + + logger.info(f"JSON sanitization applied during write: {file_name}") + return True # Sanitization applied, reload recommended class TokenizerInterface(Protocol):