From cacea8ab560899ddb77e11a57ab4c2670109db6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?= Date: Thu, 4 Dec 2025 19:14:26 +0800 Subject: [PATCH] cherry-pick 33a1482f --- env.example | 7 + lightrag/api/lightrag_server.py | 342 ++++++-------------------------- lightrag/llm/jina.py | 17 +- lightrag/llm/openai.py | 84 ++++++-- lightrag/utils.py | 310 ++--------------------------- 5 files changed, 149 insertions(+), 611 deletions(-) diff --git a/env.example b/env.example index 43bc759b..8fca2b5e 100644 --- a/env.example +++ b/env.example @@ -233,6 +233,13 @@ OLLAMA_LLM_NUM_CTX=32768 ### EMBEDDING_BINDING_HOST: host only for Ollama, endpoint for other Embedding service ####################################################################################### # EMBEDDING_TIMEOUT=30 + +### Control whether to send embedding_dim parameter to embedding API +### Set to 'true' to enable dynamic dimension adjustment (only works for OpenAI and Jina) +### Set to 'false' (default) to disable sending dimension parameter +### Note: This is automatically ignored for backends that don't support dimension parameter +# EMBEDDING_SEND_DIM=false + EMBEDDING_BINDING=ollama EMBEDDING_MODEL=bge-m3:latest EMBEDDING_DIM=1024 diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 7396261a..a66d5d3c 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -56,8 +56,6 @@ from lightrag.api.routers.ollama_api import OllamaAPI from lightrag.utils import logger, set_verbose_debug from lightrag.kg.shared_storage import ( get_namespace_data, - get_default_workspace, - # set_default_workspace, initialize_pipeline_status, cleanup_keyed_lock, finalize_share_data, @@ -91,7 +89,6 @@ class LLMConfigCache: # Initialize configurations based on binding conditions self.openai_llm_options = None self.gemini_llm_options = None - self.gemini_embedding_options = None self.ollama_llm_options = None self.ollama_embedding_options = None @@ -138,23 +135,6 @@ class LLMConfigCache: ) self.ollama_embedding_options = {} - # Only initialize and log Gemini Embedding options when using Gemini Embedding binding - if args.embedding_binding == "gemini": - try: - from lightrag.llm.binding_options import GeminiEmbeddingOptions - - self.gemini_embedding_options = GeminiEmbeddingOptions.options_dict( - args - ) - logger.info( - f"Gemini Embedding Options: {self.gemini_embedding_options}" - ) - except ImportError: - logger.warning( - "GeminiEmbeddingOptions not available, using default configuration" - ) - self.gemini_embedding_options = {} - def check_frontend_build(): """Check if frontend is built and optionally check if source is up-to-date @@ -316,7 +296,6 @@ def create_app(args): "azure_openai", "aws_bedrock", "jina", - "gemini", ]: raise Exception("embedding binding not supported") @@ -352,9 +331,8 @@ def create_app(args): try: # Initialize database connections - # set_default_workspace(rag.workspace) # comment this line to test auto default workspace setting in initialize_storages await rag.initialize_storages() - await initialize_pipeline_status() # with default workspace + await initialize_pipeline_status() # Data migration regardless of storage implementation await rag.check_and_migrate_data() @@ -455,29 +433,6 @@ def create_app(args): # Create combined auth dependency for all endpoints combined_auth = get_combined_auth_dependency(api_key) - def get_workspace_from_request(request: Request) -> str: - """ - Extract workspace from HTTP request header or use default. - - This enables multi-workspace API support by checking the custom - 'LIGHTRAG-WORKSPACE' header. If not present, falls back to the - server's default workspace configuration. - - Args: - request: FastAPI Request object - - Returns: - Workspace identifier (may be empty string for global namespace) - """ - # Check custom header first - workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() - - # Fall back to server default if header not provided - if not workspace: - workspace = args.workspace - - return workspace - # Create working directory if it doesn't exist Path(args.working_dir).mkdir(parents=True, exist_ok=True) @@ -556,9 +511,7 @@ def create_app(args): return optimized_azure_openai_model_complete - def create_optimized_gemini_llm_func( - config_cache: LLMConfigCache, args, llm_timeout: int - ): + def create_optimized_gemini_llm_func(config_cache: LLMConfigCache, args): """Create optimized Gemini LLM function with cached configuration""" async def optimized_gemini_model_complete( @@ -573,8 +526,6 @@ def create_app(args): if history_messages is None: history_messages = [] - # Use pre-processed configuration to avoid repeated parsing - kwargs["timeout"] = llm_timeout if ( config_cache.gemini_llm_options is not None and "generation_config" not in kwargs @@ -616,7 +567,7 @@ def create_app(args): config_cache, args, llm_timeout ) elif binding == "gemini": - return create_optimized_gemini_llm_func(config_cache, args, llm_timeout) + return create_optimized_gemini_llm_func(config_cache, args) else: # openai and compatible # Use optimized function with pre-processed configuration return create_optimized_openai_llm_func(config_cache, args, llm_timeout) @@ -644,108 +595,33 @@ def create_app(args): def create_optimized_embedding_function( config_cache: LLMConfigCache, binding, model, host, api_key, args - ) -> EmbeddingFunc: + ): """ - Create optimized embedding function and return an EmbeddingFunc instance - with proper max_token_size inheritance from provider defaults. - - This function: - 1. Imports the provider embedding function - 2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc - 3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping) - 4. Returns a properly configured EmbeddingFunc instance + Create optimized embedding function with pre-processed configuration for applicable bindings. + Uses lazy imports for all bindings and avoids repeated configuration parsing. """ - # Step 1: Import provider function and extract default attributes - provider_func = None - provider_max_token_size = None - provider_embedding_dim = None - - try: - if binding == "openai": - from lightrag.llm.openai import openai_embed - - provider_func = openai_embed - elif binding == "ollama": - from lightrag.llm.ollama import ollama_embed - - provider_func = ollama_embed - elif binding == "gemini": - from lightrag.llm.gemini import gemini_embed - - provider_func = gemini_embed - elif binding == "jina": - from lightrag.llm.jina import jina_embed - - provider_func = jina_embed - elif binding == "azure_openai": - from lightrag.llm.azure_openai import azure_openai_embed - - provider_func = azure_openai_embed - elif binding == "aws_bedrock": - from lightrag.llm.bedrock import bedrock_embed - - provider_func = bedrock_embed - elif binding == "lollms": - from lightrag.llm.lollms import lollms_embed - - provider_func = lollms_embed - - # Extract attributes if provider is an EmbeddingFunc - if provider_func and isinstance(provider_func, EmbeddingFunc): - provider_max_token_size = provider_func.max_token_size - provider_embedding_dim = provider_func.embedding_dim - logger.debug( - f"Extracted from {binding} provider: " - f"max_token_size={provider_max_token_size}, " - f"embedding_dim={provider_embedding_dim}" - ) - except ImportError as e: - logger.warning(f"Could not import provider function for {binding}: {e}") - - # Step 2: Apply priority (user config > provider default) - # For max_token_size: explicit env var > provider default > None - final_max_token_size = args.embedding_token_limit or provider_max_token_size - # For embedding_dim: user config (always has value) takes priority - # Only use provider default if user config is explicitly None (which shouldn't happen) - final_embedding_dim = ( - args.embedding_dim if args.embedding_dim else provider_embedding_dim - ) - - # Step 3: Create optimized embedding function (calls underlying function directly) - async def optimized_embedding_function(texts, embedding_dim=None): + async def optimized_embedding_function(texts): try: if binding == "lollms": from lightrag.llm.lollms import lollms_embed - # Get real function, skip EmbeddingFunc wrapper if present - actual_func = ( - lollms_embed.func - if isinstance(lollms_embed, EmbeddingFunc) - else lollms_embed - ) - return await actual_func( + return await lollms_embed( texts, embed_model=model, host=host, api_key=api_key ) elif binding == "ollama": from lightrag.llm.ollama import ollama_embed - # Get real function, skip EmbeddingFunc wrapper if present - actual_func = ( - ollama_embed.func - if isinstance(ollama_embed, EmbeddingFunc) - else ollama_embed - ) - - # Use pre-processed configuration if available + # Use pre-processed configuration if available, otherwise fallback to dynamic parsing if config_cache.ollama_embedding_options is not None: ollama_options = config_cache.ollama_embedding_options else: + # Fallback for cases where config cache wasn't initialized properly from lightrag.llm.binding_options import OllamaEmbeddingOptions ollama_options = OllamaEmbeddingOptions.options_dict(args) - return await actual_func( + return await ollama_embed( texts, embed_model=model, host=host, @@ -755,93 +631,27 @@ def create_app(args): elif binding == "azure_openai": from lightrag.llm.azure_openai import azure_openai_embed - actual_func = ( - azure_openai_embed.func - if isinstance(azure_openai_embed, EmbeddingFunc) - else azure_openai_embed - ) - return await actual_func(texts, model=model, api_key=api_key) + return await azure_openai_embed(texts, model=model, api_key=api_key) elif binding == "aws_bedrock": from lightrag.llm.bedrock import bedrock_embed - actual_func = ( - bedrock_embed.func - if isinstance(bedrock_embed, EmbeddingFunc) - else bedrock_embed - ) - return await actual_func(texts, model=model) + return await bedrock_embed(texts, model=model) elif binding == "jina": from lightrag.llm.jina import jina_embed - actual_func = ( - jina_embed.func - if isinstance(jina_embed, EmbeddingFunc) - else jina_embed - ) - return await actual_func( - texts, - embedding_dim=embedding_dim, - base_url=host, - api_key=api_key, - ) - elif binding == "gemini": - from lightrag.llm.gemini import gemini_embed - - actual_func = ( - gemini_embed.func - if isinstance(gemini_embed, EmbeddingFunc) - else gemini_embed - ) - - # Use pre-processed configuration if available - if config_cache.gemini_embedding_options is not None: - gemini_options = config_cache.gemini_embedding_options - else: - from lightrag.llm.binding_options import GeminiEmbeddingOptions - - gemini_options = GeminiEmbeddingOptions.options_dict(args) - - return await actual_func( - texts, - model=model, - base_url=host, - api_key=api_key, - embedding_dim=embedding_dim, - task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"), + return await jina_embed( + texts, base_url=host, api_key=api_key ) else: # openai and compatible from lightrag.llm.openai import openai_embed - actual_func = ( - openai_embed.func - if isinstance(openai_embed, EmbeddingFunc) - else openai_embed - ) - return await actual_func( - texts, - model=model, - base_url=host, - api_key=api_key, - embedding_dim=embedding_dim, + return await openai_embed( + texts, model=model, base_url=host, api_key=api_key ) except ImportError as e: raise Exception(f"Failed to import {binding} embedding: {e}") - # Step 4: Wrap in EmbeddingFunc and return - embedding_func_instance = EmbeddingFunc( - embedding_dim=final_embedding_dim, - func=optimized_embedding_function, - max_token_size=final_max_token_size, - send_dimensions=False, # Will be set later based on binding requirements - ) - - # Log final embedding configuration - logger.info( - f"Embedding config: binding={binding} model={model} " - f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}" - ) - - return embedding_func_instance + return optimized_embedding_function llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int) embedding_timeout = get_env_value( @@ -875,62 +685,45 @@ def create_app(args): **kwargs, ) - # Create embedding function with optimized configuration and max_token_size inheritance + # Create embedding function with optimized configuration import inspect - - # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size) - embedding_func = create_optimized_embedding_function( + + # Create the optimized embedding function + optimized_embedding_func = create_optimized_embedding_function( config_cache=config_cache, binding=args.embedding_binding, model=args.embedding_model, host=args.embedding_binding_host, api_key=args.embedding_binding_api_key, - args=args, + args=args, # Pass args object for fallback option generation ) - - # Get embedding_send_dim from centralized configuration - embedding_send_dim = args.embedding_send_dim - - # Check if the underlying function signature has embedding_dim parameter - sig = inspect.signature(embedding_func.func) - has_embedding_dim_param = "embedding_dim" in sig.parameters - - # Determine send_dimensions value based on binding type - # Jina and Gemini REQUIRE dimension parameter (forced to True) - # OpenAI and others: controlled by EMBEDDING_SEND_DIM environment variable - if args.embedding_binding in ["jina", "gemini"]: - # Jina and Gemini APIs require dimension parameter - always send it - send_dimensions = has_embedding_dim_param - dimension_control = f"forced by {args.embedding_binding.title()} API" - else: - # For OpenAI and other bindings, respect EMBEDDING_SEND_DIM setting - send_dimensions = embedding_send_dim and has_embedding_dim_param - if send_dimensions or not embedding_send_dim: - dimension_control = "by env var" - else: - dimension_control = "by not hasparam" - - # Set send_dimensions on the EmbeddingFunc instance - embedding_func.send_dimensions = send_dimensions - + + # Check environment variable for sending dimensions + embedding_send_dim = os.getenv("EMBEDDING_SEND_DIM", "false").lower() == "true" + + # Check if the function signature has embedding_dim parameter + # Note: Since optimized_embedding_func is an async function, inspect its signature + sig = inspect.signature(optimized_embedding_func) + has_embedding_dim_param = 'embedding_dim' in sig.parameters + + # Determine send_dimensions value + # Only send dimensions if both conditions are met: + # 1. EMBEDDING_SEND_DIM environment variable is true + # 2. The function has embedding_dim parameter + send_dimensions = embedding_send_dim and has_embedding_dim_param + logger.info( - f"Send embedding dimension: {send_dimensions} {dimension_control} " - f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, " + f"Embedding configuration: send_dimensions={send_dimensions} " + f"(env_var={embedding_send_dim}, has_param={has_embedding_dim_param}, " f"binding={args.embedding_binding})" ) - - # Log max_token_size source - if embedding_func.max_token_size: - source = ( - "env variable" - if args.embedding_token_limit - else f"{args.embedding_binding} provider default" - ) - logger.info( - f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})" - ) - else: - logger.info("Embedding max_token_size: not set (90% token warning disabled)") + + # Create EmbeddingFunc with send_dimensions attribute + embedding_func = EmbeddingFunc( + embedding_dim=args.embedding_dim, + func=optimized_embedding_func, + send_dimensions=send_dimensions, + ) # Configure rerank function based on args.rerank_bindingparameter rerank_model_func = None @@ -970,27 +763,15 @@ def create_app(args): query: str, documents: list, top_n: int = None, extra_body: dict = None ): """Server rerank function with configuration from environment variables""" - # Prepare kwargs for rerank function - kwargs = { - "query": query, - "documents": documents, - "top_n": top_n, - "api_key": args.rerank_binding_api_key, - "model": args.rerank_model, - "base_url": args.rerank_binding_host, - } - - # Add Cohere-specific parameters if using cohere binding - if args.rerank_binding == "cohere": - # Enable chunking if configured (useful for models with token limits like ColBERT) - kwargs["enable_chunking"] = ( - os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true" - ) - kwargs["max_tokens_per_doc"] = int( - os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096") - ) - - return await selected_rerank_func(**kwargs, extra_body=extra_body) + return await selected_rerank_func( + query=query, + documents=documents, + top_n=top_n, + api_key=args.rerank_binding_api_key, + model=args.rerank_model, + base_url=args.rerank_binding_host, + extra_body=extra_body, + ) rerank_model_func = server_rerank_func logger.info( @@ -1151,10 +932,9 @@ def create_app(args): } @app.get("/health", dependencies=[Depends(combined_auth)]) - async def get_status(request: Request): + async def get_status(): """Get current system status""" try: - default_workspace = get_default_workspace() pipeline_status = await get_namespace_data("pipeline_status") if not auth_configured: @@ -1186,7 +966,7 @@ def create_app(args): "vector_storage": args.vector_storage, "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, "enable_llm_cache": args.enable_llm_cache, - "workspace": default_workspace, + "workspace": args.workspace, "max_graph_nodes": args.max_graph_nodes, # Rerank configuration "enable_rerank": rerank_model_func is not None, @@ -1375,12 +1155,6 @@ def check_and_install_dependencies(): def main(): - # Explicitly initialize configuration for clarity - # (The proxy will auto-initialize anyway, but this makes intent clear) - from .config import initialize_config - - initialize_config() - # Check if running under Gunicorn if "GUNICORN_CMD_ARGS" in os.environ: # If started with Gunicorn, return directly as Gunicorn will call get_application diff --git a/lightrag/llm/jina.py b/lightrag/llm/jina.py index f3c89228..70de5995 100644 --- a/lightrag/llm/jina.py +++ b/lightrag/llm/jina.py @@ -1,6 +1,4 @@ import os -from typing import Final - import pipmaster as pm # Pipmaster for dynamic library install # install specific modules @@ -21,9 +19,6 @@ from tenacity import ( from lightrag.utils import wrap_embedding_func_with_attrs, logger -DEFAULT_JINA_EMBED_DIM: Final[int] = 2048 - - async def fetch_data(url, headers, data): async with aiohttp.ClientSession() as session: async with session.post(url, headers=headers, json=data) as response: @@ -63,7 +58,7 @@ async def fetch_data(url, headers, data): return data_list -@wrap_embedding_func_with_attrs(embedding_dim=DEFAULT_JINA_EMBED_DIM) +@wrap_embedding_func_with_attrs(embedding_dim=2048) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=60), @@ -74,7 +69,7 @@ async def fetch_data(url, headers, data): ) async def jina_embed( texts: list[str], - embedding_dim: int | None = DEFAULT_JINA_EMBED_DIM, + embedding_dim: int = 2048, late_chunking: bool = False, base_url: str = None, api_key: str = None, @@ -100,10 +95,6 @@ async def jina_embed( aiohttp.ClientError: If there is a connection error with the Jina API. aiohttp.ClientResponseError: If the Jina API returns an error response. """ - resolved_embedding_dim = ( - embedding_dim if embedding_dim is not None else DEFAULT_JINA_EMBED_DIM - ) - if api_key: os.environ["JINA_API_KEY"] = api_key @@ -118,7 +109,7 @@ async def jina_embed( data = { "model": "jina-embeddings-v4", "task": "text-matching", - "dimensions": resolved_embedding_dim, + "dimensions": embedding_dim, "embedding_type": "base64", "input": texts, } @@ -128,7 +119,7 @@ async def jina_embed( data["late_chunking"] = late_chunking logger.debug( - f"Jina embedding request: {len(texts)} texts, dimensions: {resolved_embedding_dim}" + f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}" ) try: diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index f7b759ad..fce33cac 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -11,7 +11,6 @@ if not pm.is_installed("openai"): pm.install("openai") from openai import ( - AsyncOpenAI, APIConnectionError, RateLimitError, APITimeoutError, @@ -27,6 +26,7 @@ from lightrag.utils import ( safe_unicode_decode, logger, ) + from lightrag.types import GPTKeywordExtractionFormat from lightrag.api import __api_version__ @@ -36,6 +36,32 @@ from typing import Any, Union from dotenv import load_dotenv +# Try to import Langfuse for LLM observability (optional) +# Falls back to standard OpenAI client if not available +# Langfuse requires proper configuration to work correctly +LANGFUSE_ENABLED = False +try: + # Check if required Langfuse environment variables are set + langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY") + langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY") + + # Only enable Langfuse if both keys are configured + if langfuse_public_key and langfuse_secret_key: + from langfuse.openai import AsyncOpenAI + + LANGFUSE_ENABLED = True + logger.info("Langfuse observability enabled for OpenAI client") + else: + from openai import AsyncOpenAI + + logger.debug( + "Langfuse environment variables not configured, using standard OpenAI client" + ) +except ImportError: + from openai import AsyncOpenAI + + logger.debug("Langfuse not available, using standard OpenAI client") + # use the .env that is inside the current folder # allows to use different .env file for each lightrag instance # the OS environment variables take precedence over the .env file @@ -370,18 +396,23 @@ async def openai_complete_if_cache( ) # Ensure resources are released even if no exception occurs - if ( - iteration_started - and hasattr(response, "aclose") - and callable(getattr(response, "aclose", None)) - ): - try: - await response.aclose() - logger.debug("Successfully closed stream response") - except Exception as close_error: - logger.warning( - f"Failed to close stream response in finally block: {close_error}" - ) + # Note: Some wrapped clients (e.g., Langfuse) may not implement aclose() properly + if iteration_started and hasattr(response, "aclose"): + aclose_method = getattr(response, "aclose", None) + if callable(aclose_method): + try: + await response.aclose() + logger.debug("Successfully closed stream response") + except (AttributeError, TypeError) as close_error: + # Some wrapper objects may report hasattr(aclose) but fail when called + # This is expected behavior for certain client wrappers + logger.debug( + f"Stream response cleanup not supported by client wrapper: {close_error}" + ) + except Exception as close_error: + logger.warning( + f"Unexpected error during stream response cleanup: {close_error}" + ) # This prevents resource leaks since the caller doesn't handle closing try: @@ -578,6 +609,7 @@ async def openai_embed( model: str = "text-embedding-3-small", base_url: str | None = None, api_key: str | None = None, + embedding_dim: int | None = None, client_configs: dict[str, Any] | None = None, token_tracker: Any | None = None, ) -> np.ndarray: @@ -588,6 +620,12 @@ async def openai_embed( model: The OpenAI embedding model to use. base_url: Optional base URL for the OpenAI API. api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. + embedding_dim: Optional embedding dimension for dynamic dimension reduction. + **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper. + Do NOT manually pass this parameter when calling the function directly. + The dimension is controlled by the @wrap_embedding_func_with_attrs decorator. + Manually passing a different value will trigger a warning and be ignored. + When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction. client_configs: Additional configuration options for the AsyncOpenAI client. These will override any default configurations but will be overridden by explicit parameters (api_key, base_url). @@ -607,17 +645,27 @@ async def openai_embed( ) async with openai_async_client: - response = await openai_async_client.embeddings.create( - model=model, input=texts, encoding_format="base64" - ) - + # Prepare API call parameters + api_params = { + "model": model, + "input": texts, + "encoding_format": "base64", + } + + # Add dimensions parameter only if embedding_dim is provided + if embedding_dim is not None: + api_params["dimensions"] = embedding_dim + + # Make API call + response = await openai_async_client.embeddings.create(**api_params) + if token_tracker and hasattr(response, "usage"): token_counts = { "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), "total_tokens": getattr(response.usage, "total_tokens", 0), } token_tracker.add_usage(token_counts) - + return np.array( [ np.array(dp.embedding, dtype=np.float32) diff --git a/lightrag/utils.py b/lightrag/utils.py index 65c1e4bc..9ac82b1c 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1,8 +1,6 @@ from __future__ import annotations import weakref -import sys - import asyncio import html import csv @@ -42,35 +40,6 @@ from lightrag.constants import ( SOURCE_IDS_LIMIT_METHOD_FIFO, ) -# Precompile regex pattern for JSON sanitization (module-level, compiled once) -_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]") - - -class SafeStreamHandler(logging.StreamHandler): - """StreamHandler that gracefully handles closed streams during shutdown. - - This handler prevents "ValueError: I/O operation on closed file" errors - that can occur when pytest or other test frameworks close stdout/stderr - before Python's logging cleanup runs. - """ - - def flush(self): - """Flush the stream, ignoring errors if the stream is closed.""" - try: - super().flush() - except (ValueError, OSError): - # Stream is closed or otherwise unavailable, silently ignore - pass - - def close(self): - """Close the handler, ignoring errors if the stream is already closed.""" - try: - super().close() - except (ValueError, OSError): - # Stream is closed or otherwise unavailable, silently ignore - pass - - # Initialize logger with basic configuration logger = logging.getLogger("lightrag") logger.propagate = False # prevent log message send to root logger @@ -78,7 +47,7 @@ logger.setLevel(logging.INFO) # Add console handler if no handlers exist if not logger.handlers: - console_handler = SafeStreamHandler() + console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) formatter = logging.Formatter("%(levelname)s: %(message)s") console_handler.setFormatter(formatter) @@ -87,33 +56,6 @@ if not logger.handlers: # Set httpx logging level to WARNING logging.getLogger("httpx").setLevel(logging.WARNING) - -def _patch_ascii_colors_console_handler() -> None: - """Prevent ascii_colors from printing flush errors during interpreter exit.""" - - try: - from ascii_colors import ConsoleHandler - except ImportError: - return - - if getattr(ConsoleHandler, "_lightrag_patched", False): - return - - original_handle_error = ConsoleHandler.handle_error - - def _safe_handle_error(self, message: str) -> None: # type: ignore[override] - exc_type, _, _ = sys.exc_info() - if exc_type in (ValueError, OSError) and "close" in message.lower(): - return - original_handle_error(self, message) - - ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment] - ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined] - - -_patch_ascii_colors_console_handler() - - # Global import for pypinyin with startup-time logging try: import pypinyin @@ -341,8 +283,8 @@ def setup_logger( logger_instance.handlers = [] # Clear existing handlers logger_instance.propagate = False - # Add console handler with safe stream handling - console_handler = SafeStreamHandler() + # Add console handler + console_handler = logging.StreamHandler() console_handler.setFormatter(simple_formatter) console_handler.setLevel(level) logger_instance.addHandler(console_handler) @@ -408,69 +350,28 @@ class TaskState: @dataclass class EmbeddingFunc: - """Embedding function wrapper with dimension validation - This class wraps an embedding function to ensure that the output embeddings have the correct dimension. - This class should not be wrapped multiple times. - - Args: - embedding_dim: Expected dimension of the embeddings - func: The actual embedding function to wrap - max_token_size: Optional token limit for the embedding model - send_dimensions: Whether to inject embedding_dim as a keyword argument - """ - embedding_dim: int func: callable - max_token_size: int | None = None # Token limit for the embedding model - send_dimensions: bool = ( - False # Control whether to send embedding_dim to the function - ) + max_token_size: int | None = None # deprecated keep it for compatible only + send_dimensions: bool = False # Control whether to send embedding_dim to the function async def __call__(self, *args, **kwargs) -> np.ndarray: # Only inject embedding_dim when send_dimensions is True if self.send_dimensions: # Check if user provided embedding_dim parameter - if "embedding_dim" in kwargs: - user_provided_dim = kwargs["embedding_dim"] + if 'embedding_dim' in kwargs: + user_provided_dim = kwargs['embedding_dim'] # If user's value differs from class attribute, output warning - if ( - user_provided_dim is not None - and user_provided_dim != self.embedding_dim - ): + if user_provided_dim is not None and user_provided_dim != self.embedding_dim: logger.warning( f"Ignoring user-provided embedding_dim={user_provided_dim}, " f"using declared embedding_dim={self.embedding_dim} from decorator" ) - + # Inject embedding_dim from decorator - kwargs["embedding_dim"] = self.embedding_dim - - # Call the actual embedding function - result = await self.func(*args, **kwargs) - - # Validate embedding dimensions using total element count - total_elements = result.size # Total number of elements in the numpy array - expected_dim = self.embedding_dim - - # Check if total elements can be evenly divided by embedding_dim - if total_elements % expected_dim != 0: - raise ValueError( - f"Embedding dimension mismatch detected: " - f"total elements ({total_elements}) cannot be evenly divided by " - f"expected dimension ({expected_dim}). " - ) - - # Optional: Verify vector count matches input text count - actual_vectors = total_elements // expected_dim - if args and isinstance(args[0], (list, tuple)): - expected_vectors = len(args[0]) - if actual_vectors != expected_vectors: - raise ValueError( - f"Vector count mismatch: " - f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)." - ) - - return result + kwargs['embedding_dim'] = self.embedding_dim + + return await self.func(*args, **kwargs) def compute_args_hash(*args: Any) -> str: @@ -1005,76 +906,7 @@ def priority_limit_async_func_call( def wrap_embedding_func_with_attrs(**kwargs): - """Decorator to add embedding dimension and token limit attributes to embedding functions. - - This decorator wraps an async embedding function and returns an EmbeddingFunc instance - that automatically handles dimension parameter injection and attribute management. - - WARNING: DO NOT apply this decorator to wrapper functions that call other - decorated embedding functions. This will cause double decoration and parameter - injection conflicts. - - Correct usage patterns: - - 1. Direct implementation (decorated): - ```python - @wrap_embedding_func_with_attrs(embedding_dim=1536) - async def my_embed(texts, embedding_dim=None): - # Direct implementation - return embeddings - ``` - - 2. Wrapper calling decorated function (DO NOT decorate wrapper): - ```python - # my_embed is already decorated above - - async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this! - # Must call .func to access unwrapped implementation - return await my_embed.func(texts, **kwargs) - ``` - - 3. Wrapper calling decorated function (properly decorated): - ```python - @wrap_embedding_func_with_attrs(embedding_dim=1536) - async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func - # Calling .func avoids double decoration - return await my_embed.func(texts, **kwargs) - ``` - - The decorated function becomes an EmbeddingFunc instance with: - - embedding_dim: The embedding dimension - - max_token_size: Maximum token limit (optional) - - func: The original unwrapped function (access via .func) - - __call__: Wrapper that injects embedding_dim parameter - - Double decoration causes: - - Double injection of embedding_dim parameter - - Incorrect parameter passing to the underlying implementation - - Runtime errors due to parameter conflicts - - Args: - embedding_dim: The dimension of embedding vectors - max_token_size: Maximum number of tokens (optional) - send_dimensions: Whether to inject embedding_dim as a keyword argument (optional) - - Returns: - A decorator that wraps the function as an EmbeddingFunc instance - - Example of correct wrapper implementation: - ```python - @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) - @retry(...) - async def openai_embed(texts, ...): - # Base implementation - pass - - @wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here! - async def azure_openai_embed(texts, ...): - # CRITICAL: Call .func to access unwrapped function - return await openai_embed.func(texts, ...) # ✅ Correct - # return await openai_embed(texts, ...) # ❌ Wrong - double decoration! - ``` - """ + """Wrap a function with attributes""" def final_decro(func) -> EmbeddingFunc: new_func = EmbeddingFunc(**kwargs, func=func) @@ -1090,123 +922,9 @@ def load_json(file_name): return json.load(f) -def _sanitize_string_for_json(text: str) -> str: - """Remove characters that cannot be encoded in UTF-8 for JSON serialization. - - Uses regex for optimal performance with zero-copy optimization for clean strings. - Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings. - - Args: - text: String to sanitize - - Returns: - Original string if clean (zero-copy), sanitized string if dirty - """ - if not text: - return text - - # Fast path: Check if sanitization is needed using C-level regex search - if not _SURROGATE_PATTERN.search(text): - return text # Zero-copy for clean strings - most common case - - # Slow path: Remove problematic characters using C-level regex substitution - return _SURROGATE_PATTERN.sub("", text) - - -class SanitizingJSONEncoder(json.JSONEncoder): - """ - Custom JSON encoder that sanitizes data during serialization. - - This encoder cleans strings during the encoding process without creating - a full copy of the data structure, making it memory-efficient for large datasets. - """ - - def encode(self, o): - """Override encode method to handle simple string cases""" - if isinstance(o, str): - return json.encoder.encode_basestring(_sanitize_string_for_json(o)) - return super().encode(o) - - def iterencode(self, o, _one_shot=False): - """ - Override iterencode to sanitize strings during serialization. - This is the core method that handles complex nested structures. - """ - # Preprocess: sanitize all strings in the object - sanitized = self._sanitize_for_encoding(o) - - # Call parent's iterencode with sanitized data - for chunk in super().iterencode(sanitized, _one_shot): - yield chunk - - def _sanitize_for_encoding(self, obj): - """ - Recursively sanitize strings in an object. - Creates new objects only when necessary to avoid deep copies. - - Args: - obj: Object to sanitize - - Returns: - Sanitized object with cleaned strings - """ - if isinstance(obj, str): - return _sanitize_string_for_json(obj) - - elif isinstance(obj, dict): - # Create new dict with sanitized keys and values - new_dict = {} - for k, v in obj.items(): - clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k - clean_v = self._sanitize_for_encoding(v) - new_dict[clean_k] = clean_v - return new_dict - - elif isinstance(obj, (list, tuple)): - # Sanitize list/tuple elements - cleaned = [self._sanitize_for_encoding(item) for item in obj] - return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned - - else: - # Numbers, booleans, None, etc. remain unchanged - return obj - - def write_json(json_obj, file_name): - """ - Write JSON data to file with optimized sanitization strategy. - - This function uses a two-stage approach: - 1. Fast path: Try direct serialization (works for clean data ~99% of time) - 2. Slow path: Use custom encoder that sanitizes during serialization - - The custom encoder approach avoids creating a deep copy of the data, - making it memory-efficient. When sanitization occurs, the caller should - reload the cleaned data from the file to update shared memory. - - Args: - json_obj: Object to serialize (may be a shallow copy from shared memory) - file_name: Output file path - - Returns: - bool: True if sanitization was applied (caller should reload data), - False if direct write succeeded (no reload needed) - """ - try: - # Strategy 1: Fast path - try direct serialization - with open(file_name, "w", encoding="utf-8") as f: - json.dump(json_obj, f, indent=2, ensure_ascii=False) - return False # No sanitization needed, no reload required - - except (UnicodeEncodeError, UnicodeDecodeError) as e: - logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}") - - # Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy) with open(file_name, "w", encoding="utf-8") as f: - json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder) - - logger.info(f"JSON sanitization applied during write: {file_name}") - return True # Sanitization applied, reload recommended + json.dump(json_obj, f, indent=2, ensure_ascii=False) class TokenizerInterface(Protocol):