This commit is contained in:
Raphaël MANSUY 2025-12-04 19:14:26 +08:00
parent e11e30be0e
commit cacea8ab56
5 changed files with 149 additions and 611 deletions

View file

@@ -233,6 +233,13 @@ OLLAMA_LLM_NUM_CTX=32768
### EMBEDDING_BINDING_HOST: host only for Ollama; full endpoint for other embedding services
#######################################################################################
# EMBEDDING_TIMEOUT=30
### Control whether to send embedding_dim parameter to embedding API
### Set to 'true' to enable dynamic dimension adjustment (only works for OpenAI and Jina)
### Set to 'false' (default) to disable sending dimension parameter
### Note: This is automatically ignored for backends that don't support the dimension parameter
# EMBEDDING_SEND_DIM=false
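### For example (illustrative values, not part of this commit), an OpenAI-compatible
### backend can be asked to reduce vectors to EMBEDDING_DIM:
# EMBEDDING_BINDING=openai
# EMBEDDING_MODEL=text-embedding-3-small
# EMBEDDING_DIM=1024
# EMBEDDING_SEND_DIM=true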
EMBEDDING_BINDING=ollama
EMBEDDING_MODEL=bge-m3:latest
EMBEDDING_DIM=1024

View file

@@ -56,8 +56,6 @@ from lightrag.api.routers.ollama_api import OllamaAPI
from lightrag.utils import logger, set_verbose_debug
from lightrag.kg.shared_storage import (
get_namespace_data,
get_default_workspace,
# set_default_workspace,
initialize_pipeline_status,
cleanup_keyed_lock,
finalize_share_data,
@@ -91,7 +89,6 @@ class LLMConfigCache:
# Initialize configurations based on binding conditions
self.openai_llm_options = None
self.gemini_llm_options = None
self.gemini_embedding_options = None
self.ollama_llm_options = None
self.ollama_embedding_options = None
@@ -138,23 +135,6 @@
)
self.ollama_embedding_options = {}
# Only initialize and log Gemini Embedding options when using Gemini Embedding binding
if args.embedding_binding == "gemini":
try:
from lightrag.llm.binding_options import GeminiEmbeddingOptions
self.gemini_embedding_options = GeminiEmbeddingOptions.options_dict(
args
)
logger.info(
f"Gemini Embedding Options: {self.gemini_embedding_options}"
)
except ImportError:
logger.warning(
"GeminiEmbeddingOptions not available, using default configuration"
)
self.gemini_embedding_options = {}
def check_frontend_build():
"""Check if frontend is built and optionally check if source is up-to-date
@@ -316,7 +296,6 @@ def create_app(args):
"azure_openai",
"aws_bedrock",
"jina",
"gemini",
]:
raise Exception("embedding binding not supported")
@@ -352,9 +331,8 @@ def create_app(args):
try:
# Initialize database connections
# set_default_workspace(rag.workspace) # comment this line to test auto default workspace setting in initialize_storages
await rag.initialize_storages()
await initialize_pipeline_status() # with default workspace
await initialize_pipeline_status()
# Data migration regardless of storage implementation
await rag.check_and_migrate_data()
@@ -455,29 +433,6 @@ def create_app(args):
# Create combined auth dependency for all endpoints
combined_auth = get_combined_auth_dependency(api_key)
def get_workspace_from_request(request: Request) -> str:
"""
Extract workspace from HTTP request header or use default.
This enables multi-workspace API support by checking the custom
'LIGHTRAG-WORKSPACE' header. If not present, falls back to the
server's default workspace configuration.
Args:
request: FastAPI Request object
Returns:
Workspace identifier (may be empty string for global namespace)
"""
# Check custom header first
workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()
# Fall back to server default if header not provided
if not workspace:
workspace = args.workspace
return workspace
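A hedged client-side sketch of the header contract described above (the endpoint path, port, and payload shape are assumptions for illustration, not part of this diff):
```python
# Illustrative only: select a workspace per request via the custom header.
import httpx

resp = httpx.post(
    "http://localhost:9621/query",  # assumed default LightRAG server address
    headers={"LIGHTRAG-WORKSPACE": "team-a"},  # server falls back to its default if omitted
    json={"query": "What is LightRAG?", "mode": "hybrid"},
)
print(resp.json())
```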
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
@@ -556,9 +511,7 @@ def create_app(args):
return optimized_azure_openai_model_complete
def create_optimized_gemini_llm_func(
config_cache: LLMConfigCache, args, llm_timeout: int
):
def create_optimized_gemini_llm_func(config_cache: LLMConfigCache, args):
"""Create optimized Gemini LLM function with cached configuration"""
async def optimized_gemini_model_complete(
@@ -573,8 +526,6 @@ def create_app(args):
if history_messages is None:
history_messages = []
# Use pre-processed configuration to avoid repeated parsing
kwargs["timeout"] = llm_timeout
if (
config_cache.gemini_llm_options is not None
and "generation_config" not in kwargs
@@ -616,7 +567,7 @@ def create_app(args):
config_cache, args, llm_timeout
)
elif binding == "gemini":
return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
return create_optimized_gemini_llm_func(config_cache, args)
else: # openai and compatible
# Use optimized function with pre-processed configuration
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
@@ -644,108 +595,33 @@ def create_app(args):
def create_optimized_embedding_function(
config_cache: LLMConfigCache, binding, model, host, api_key, args
) -> EmbeddingFunc:
):
"""
Create optimized embedding function and return an EmbeddingFunc instance
with proper max_token_size inheritance from provider defaults.
This function:
1. Imports the provider embedding function
2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
4. Returns a properly configured EmbeddingFunc instance
Create optimized embedding function with pre-processed configuration for applicable bindings.
Uses lazy imports for all bindings and avoids repeated configuration parsing.
"""
# Step 1: Import provider function and extract default attributes
provider_func = None
provider_max_token_size = None
provider_embedding_dim = None
try:
if binding == "openai":
from lightrag.llm.openai import openai_embed
provider_func = openai_embed
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed
provider_func = ollama_embed
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed
provider_func = gemini_embed
elif binding == "jina":
from lightrag.llm.jina import jina_embed
provider_func = jina_embed
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed
provider_func = azure_openai_embed
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed
provider_func = bedrock_embed
elif binding == "lollms":
from lightrag.llm.lollms import lollms_embed
provider_func = lollms_embed
# Extract attributes if provider is an EmbeddingFunc
if provider_func and isinstance(provider_func, EmbeddingFunc):
provider_max_token_size = provider_func.max_token_size
provider_embedding_dim = provider_func.embedding_dim
logger.debug(
f"Extracted from {binding} provider: "
f"max_token_size={provider_max_token_size}, "
f"embedding_dim={provider_embedding_dim}"
)
except ImportError as e:
logger.warning(f"Could not import provider function for {binding}: {e}")
# Step 2: Apply priority (user config > provider default)
# For max_token_size: explicit env var > provider default > None
final_max_token_size = args.embedding_token_limit or provider_max_token_size
# For embedding_dim: user config (always has value) takes priority
# Only use provider default if user config is explicitly None (which shouldn't happen)
final_embedding_dim = (
args.embedding_dim if args.embedding_dim else provider_embedding_dim
)
# Step 3: Create optimized embedding function (calls underlying function directly)
async def optimized_embedding_function(texts, embedding_dim=None):
async def optimized_embedding_function(texts):
try:
if binding == "lollms":
from lightrag.llm.lollms import lollms_embed
# Get real function, skip EmbeddingFunc wrapper if present
actual_func = (
lollms_embed.func
if isinstance(lollms_embed, EmbeddingFunc)
else lollms_embed
)
return await actual_func(
return await lollms_embed(
texts, embed_model=model, host=host, api_key=api_key
)
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed
# Get real function, skip EmbeddingFunc wrapper if present
actual_func = (
ollama_embed.func
if isinstance(ollama_embed, EmbeddingFunc)
else ollama_embed
)
# Use pre-processed configuration if available
# Use pre-processed configuration if available, otherwise fall back to dynamic parsing
if config_cache.ollama_embedding_options is not None:
ollama_options = config_cache.ollama_embedding_options
else:
# Fallback for cases where config cache wasn't initialized properly
from lightrag.llm.binding_options import OllamaEmbeddingOptions
ollama_options = OllamaEmbeddingOptions.options_dict(args)
return await actual_func(
return await ollama_embed(
texts,
embed_model=model,
host=host,
@@ -755,93 +631,27 @@ def create_app(args):
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed
actual_func = (
azure_openai_embed.func
if isinstance(azure_openai_embed, EmbeddingFunc)
else azure_openai_embed
)
return await actual_func(texts, model=model, api_key=api_key)
return await azure_openai_embed(texts, model=model, api_key=api_key)
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed
actual_func = (
bedrock_embed.func
if isinstance(bedrock_embed, EmbeddingFunc)
else bedrock_embed
)
return await actual_func(texts, model=model)
return await bedrock_embed(texts, model=model)
elif binding == "jina":
from lightrag.llm.jina import jina_embed
actual_func = (
jina_embed.func
if isinstance(jina_embed, EmbeddingFunc)
else jina_embed
)
return await actual_func(
texts,
embedding_dim=embedding_dim,
base_url=host,
api_key=api_key,
)
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed
actual_func = (
gemini_embed.func
if isinstance(gemini_embed, EmbeddingFunc)
else gemini_embed
)
# Use pre-processed configuration if available
if config_cache.gemini_embedding_options is not None:
gemini_options = config_cache.gemini_embedding_options
else:
from lightrag.llm.binding_options import GeminiEmbeddingOptions
gemini_options = GeminiEmbeddingOptions.options_dict(args)
return await actual_func(
texts,
model=model,
base_url=host,
api_key=api_key,
embedding_dim=embedding_dim,
task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
return await jina_embed(
texts, base_url=host, api_key=api_key
)
else: # openai and compatible
from lightrag.llm.openai import openai_embed
actual_func = (
openai_embed.func
if isinstance(openai_embed, EmbeddingFunc)
else openai_embed
)
return await actual_func(
texts,
model=model,
base_url=host,
api_key=api_key,
embedding_dim=embedding_dim,
return await openai_embed(
texts, model=model, base_url=host, api_key=api_key
)
except ImportError as e:
raise Exception(f"Failed to import {binding} embedding: {e}")
# Step 4: Wrap in EmbeddingFunc and return
embedding_func_instance = EmbeddingFunc(
embedding_dim=final_embedding_dim,
func=optimized_embedding_function,
max_token_size=final_max_token_size,
send_dimensions=False, # Will be set later based on binding requirements
)
# Log final embedding configuration
logger.info(
f"Embedding config: binding={binding} model={model} "
f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
)
return embedding_func_instance
return optimized_embedding_function
llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
embedding_timeout = get_env_value(
@@ -875,62 +685,45 @@ def create_app(args):
**kwargs,
)
# Create embedding function with optimized configuration and max_token_size inheritance
# Create embedding function with optimized configuration
import inspect
# Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
embedding_func = create_optimized_embedding_function(
# Create the optimized embedding function
optimized_embedding_func = create_optimized_embedding_function(
config_cache=config_cache,
binding=args.embedding_binding,
model=args.embedding_model,
host=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
args=args,
args=args, # Pass args object for fallback option generation
)
# Get embedding_send_dim from centralized configuration
embedding_send_dim = args.embedding_send_dim
# Check if the underlying function signature has embedding_dim parameter
sig = inspect.signature(embedding_func.func)
has_embedding_dim_param = "embedding_dim" in sig.parameters
# Determine send_dimensions value based on binding type
# Jina and Gemini REQUIRE dimension parameter (forced to True)
# OpenAI and others: controlled by EMBEDDING_SEND_DIM environment variable
if args.embedding_binding in ["jina", "gemini"]:
# Jina and Gemini APIs require dimension parameter - always send it
send_dimensions = has_embedding_dim_param
dimension_control = f"forced by {args.embedding_binding.title()} API"
else:
# For OpenAI and other bindings, respect EMBEDDING_SEND_DIM setting
send_dimensions = embedding_send_dim and has_embedding_dim_param
if send_dimensions or not embedding_send_dim:
dimension_control = "by env var"
else:
dimension_control = "by not hasparam"
# Set send_dimensions on the EmbeddingFunc instance
embedding_func.send_dimensions = send_dimensions
# Check environment variable for sending dimensions
embedding_send_dim = os.getenv("EMBEDDING_SEND_DIM", "false").lower() == "true"
# Check if the function signature has embedding_dim parameter
# Note: optimized_embedding_func is a plain async function, so inspect its signature directly
sig = inspect.signature(optimized_embedding_func)
has_embedding_dim_param = "embedding_dim" in sig.parameters
# Determine send_dimensions value
# Only send dimensions if both conditions are met:
# 1. EMBEDDING_SEND_DIM environment variable is true
# 2. The function has embedding_dim parameter
send_dimensions = embedding_send_dim and has_embedding_dim_param
logger.info(
f"Send embedding dimension: {send_dimensions} {dimension_control} "
f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
f"Embedding configuration: send_dimensions={send_dimensions} "
f"(env_var={embedding_send_dim}, has_param={has_embedding_dim_param}, "
f"binding={args.embedding_binding})"
)
# Log max_token_size source
if embedding_func.max_token_size:
source = (
"env variable"
if args.embedding_token_limit
else f"{args.embedding_binding} provider default"
)
logger.info(
f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
)
else:
logger.info("Embedding max_token_size: not set (90% token warning disabled)")
# Create EmbeddingFunc with send_dimensions attribute
embedding_func = EmbeddingFunc(
embedding_dim=args.embedding_dim,
func=optimized_embedding_func,
send_dimensions=send_dimensions,
)
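A minimal, self-contained sketch of the signature probe used above (toy function names, not from the repo):
```python
import inspect

async def embed_with_dim(texts, embedding_dim=None): ...
async def embed_plain(texts): ...

# Mirrors the has_embedding_dim_param check above
print("embedding_dim" in inspect.signature(embed_with_dim).parameters)  # True
print("embedding_dim" in inspect.signature(embed_plain).parameters)     # False
```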
# Configure rerank function based on the args.rerank_binding parameter
rerank_model_func = None
@@ -970,27 +763,15 @@ def create_app(args):
query: str, documents: list, top_n: int = None, extra_body: dict = None
):
"""Server rerank function with configuration from environment variables"""
# Prepare kwargs for rerank function
kwargs = {
"query": query,
"documents": documents,
"top_n": top_n,
"api_key": args.rerank_binding_api_key,
"model": args.rerank_model,
"base_url": args.rerank_binding_host,
}
# Add Cohere-specific parameters if using cohere binding
if args.rerank_binding == "cohere":
# Enable chunking if configured (useful for models with token limits like ColBERT)
kwargs["enable_chunking"] = (
os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true"
)
kwargs["max_tokens_per_doc"] = int(
os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")
)
return await selected_rerank_func(**kwargs, extra_body=extra_body)
return await selected_rerank_func(
query=query,
documents=documents,
top_n=top_n,
api_key=args.rerank_binding_api_key,
model=args.rerank_model,
base_url=args.rerank_binding_host,
extra_body=extra_body,
)
rerank_model_func = server_rerank_func
logger.info(
@@ -1151,10 +932,9 @@ def create_app(args):
}
@app.get("/health", dependencies=[Depends(combined_auth)])
async def get_status(request: Request):
async def get_status():
"""Get current system status"""
try:
default_workspace = get_default_workspace()
pipeline_status = await get_namespace_data("pipeline_status")
if not auth_configured:
@@ -1186,7 +966,7 @@ def create_app(args):
"vector_storage": args.vector_storage,
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
"enable_llm_cache": args.enable_llm_cache,
"workspace": default_workspace,
"workspace": args.workspace,
"max_graph_nodes": args.max_graph_nodes,
# Rerank configuration
"enable_rerank": rerank_model_func is not None,
@@ -1375,12 +1155,6 @@ def check_and_install_dependencies():
def main():
# Explicitly initialize configuration for clarity
# (The proxy will auto-initialize anyway, but this makes intent clear)
from .config import initialize_config
initialize_config()
# Check if running under Gunicorn
if "GUNICORN_CMD_ARGS" in os.environ:
# If started with Gunicorn, return directly as Gunicorn will call get_application

View file

@@ -1,6 +1,4 @@
import os
from typing import Final
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
@@ -21,9 +19,6 @@ from tenacity import (
from lightrag.utils import wrap_embedding_func_with_attrs, logger
DEFAULT_JINA_EMBED_DIM: Final[int] = 2048
async def fetch_data(url, headers, data):
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response:
@@ -63,7 +58,7 @@ async def fetch_data(url, headers, data):
return data_list
@wrap_embedding_func_with_attrs(embedding_dim=DEFAULT_JINA_EMBED_DIM)
@wrap_embedding_func_with_attrs(embedding_dim=2048)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -74,7 +69,7 @@ async def fetch_data(url, headers, data):
)
async def jina_embed(
texts: list[str],
embedding_dim: int | None = DEFAULT_JINA_EMBED_DIM,
embedding_dim: int = 2048,
late_chunking: bool = False,
base_url: str = None,
api_key: str = None,
@@ -100,10 +95,6 @@ async def jina_embed(
aiohttp.ClientError: If there is a connection error with the Jina API.
aiohttp.ClientResponseError: If the Jina API returns an error response.
"""
resolved_embedding_dim = (
embedding_dim if embedding_dim is not None else DEFAULT_JINA_EMBED_DIM
)
if api_key:
os.environ["JINA_API_KEY"] = api_key
@@ -118,7 +109,7 @@ async def jina_embed(
data = {
"model": "jina-embeddings-v4",
"task": "text-matching",
"dimensions": resolved_embedding_dim,
"dimensions": embedding_dim,
"embedding_type": "base64",
"input": texts,
}
@@ -128,7 +119,7 @@
data["late_chunking"] = late_chunking
logger.debug(
f"Jina embedding request: {len(texts)} texts, dimensions: {resolved_embedding_dim}"
f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}"
)
try:

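A hedged usage sketch of jina_embed as defined above (performs a live API call; the key is a placeholder, and the expected shape assumes the API honors the requested dimensions):
```python
import asyncio
from lightrag.llm.jina import jina_embed

vecs = asyncio.run(
    jina_embed(["hello world"], embedding_dim=1024, api_key="jina_xxx")  # placeholder key
)
print(vecs.shape)  # expected (1, 1024)
```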
View file

@@ -11,7 +11,6 @@ if not pm.is_installed("openai"):
pm.install("openai")
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
@@ -27,6 +26,7 @@ from lightrag.utils import (
safe_unicode_decode,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
@@ -36,6 +36,32 @@ from typing import Any, Union
from dotenv import load_dotenv
# Try to import Langfuse for LLM observability (optional)
# Falls back to standard OpenAI client if not available
# Langfuse requires proper configuration to work correctly
LANGFUSE_ENABLED = False
try:
# Check if required Langfuse environment variables are set
langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
# Only enable Langfuse if both keys are configured
if langfuse_public_key and langfuse_secret_key:
from langfuse.openai import AsyncOpenAI
LANGFUSE_ENABLED = True
logger.info("Langfuse observability enabled for OpenAI client")
else:
from openai import AsyncOpenAI
logger.debug(
"Langfuse environment variables not configured, using standard OpenAI client"
)
except ImportError:
from openai import AsyncOpenAI
logger.debug("Langfuse not available, using standard OpenAI client")
# use the .env file inside the current folder
# allows a different .env file for each lightrag instance
# OS environment variables take precedence over the .env file
@@ -370,18 +396,23 @@ async def openai_complete_if_cache(
)
# Ensure resources are released even if no exception occurs
if (
iteration_started
and hasattr(response, "aclose")
and callable(getattr(response, "aclose", None))
):
try:
await response.aclose()
logger.debug("Successfully closed stream response")
except Exception as close_error:
logger.warning(
f"Failed to close stream response in finally block: {close_error}"
)
# Note: Some wrapped clients (e.g., Langfuse) may not implement aclose() properly
if iteration_started and hasattr(response, "aclose"):
aclose_method = getattr(response, "aclose", None)
if callable(aclose_method):
try:
await response.aclose()
logger.debug("Successfully closed stream response")
except (AttributeError, TypeError) as close_error:
# Some wrapper objects may report hasattr(aclose) but fail when called
# This is expected behavior for certain client wrappers
logger.debug(
f"Stream response cleanup not supported by client wrapper: {close_error}"
)
except Exception as close_error:
logger.warning(
f"Unexpected error during stream response cleanup: {close_error}"
)
# This prevents resource leaks since the caller doesn't handle closing
try:
@@ -578,6 +609,7 @@ async def openai_embed(
model: str = "text-embedding-3-small",
base_url: str | None = None,
api_key: str | None = None,
embedding_dim: int | None = None,
client_configs: dict[str, Any] | None = None,
token_tracker: Any | None = None,
) -> np.ndarray:
@@ -588,6 +620,12 @@
model: The OpenAI embedding model to use.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
Do NOT manually pass this parameter when calling the function directly.
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
Manually passing a different value will trigger a warning and be ignored.
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
client_configs: Additional configuration options for the AsyncOpenAI client.
These will override any default configurations but will be overridden by
explicit parameters (api_key, base_url).
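A hedged sketch of the intended call path described here (requires OPENAI_API_KEY to be set; the attributes come from the EmbeddingFunc wrapper in lightrag.utils):
```python
import asyncio
from lightrag.llm.openai import openai_embed  # already an EmbeddingFunc instance

openai_embed.send_dimensions = True  # opt in; False by default
vecs = asyncio.run(openai_embed(["hello world"]))  # wrapper injects embedding_dim
print(vecs.shape[1] == openai_embed.embedding_dim)  # True
```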
@@ -607,17 +645,27 @@
)
async with openai_async_client:
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="base64"
)
# Prepare API call parameters
api_params = {
"model": model,
"input": texts,
"encoding_format": "base64",
}
# Add dimensions parameter only if embedding_dim is provided
if embedding_dim is not None:
api_params["dimensions"] = embedding_dim
# Make API call
response = await openai_async_client.embeddings.create(**api_params)
if token_tracker and hasattr(response, "usage"):
token_counts = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
"total_tokens": getattr(response.usage, "total_tokens", 0),
}
token_tracker.add_usage(token_counts)
return np.array(
[
np.array(dp.embedding, dtype=np.float32)

View file

@@ -1,8 +1,6 @@
from __future__ import annotations
import weakref
import sys
import asyncio
import html
import csv
@@ -42,35 +40,6 @@ from lightrag.constants import (
SOURCE_IDS_LIMIT_METHOD_FIFO,
)
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
class SafeStreamHandler(logging.StreamHandler):
"""StreamHandler that gracefully handles closed streams during shutdown.
This handler prevents "ValueError: I/O operation on closed file" errors
that can occur when pytest or other test frameworks close stdout/stderr
before Python's logging cleanup runs.
"""
def flush(self):
"""Flush the stream, ignoring errors if the stream is closed."""
try:
super().flush()
except (ValueError, OSError):
# Stream is closed or otherwise unavailable, silently ignore
pass
def close(self):
"""Close the handler, ignoring errors if the stream is already closed."""
try:
super().close()
except (ValueError, OSError):
# Stream is closed or otherwise unavailable, silently ignore
pass
# Initialize logger with basic configuration
logger = logging.getLogger("lightrag")
logger.propagate = False  # prevent log messages from propagating to the root logger
@@ -78,7 +47,7 @@ logger.setLevel(logging.INFO)
# Add console handler if no handlers exist
if not logger.handlers:
console_handler = SafeStreamHandler()
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(levelname)s: %(message)s")
console_handler.setFormatter(formatter)
@@ -87,33 +56,6 @@ if not logger.handlers:
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
def _patch_ascii_colors_console_handler() -> None:
"""Prevent ascii_colors from printing flush errors during interpreter exit."""
try:
from ascii_colors import ConsoleHandler
except ImportError:
return
if getattr(ConsoleHandler, "_lightrag_patched", False):
return
original_handle_error = ConsoleHandler.handle_error
def _safe_handle_error(self, message: str) -> None: # type: ignore[override]
exc_type, _, _ = sys.exc_info()
if exc_type in (ValueError, OSError) and "close" in message.lower():
return
original_handle_error(self, message)
ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment]
ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined]
_patch_ascii_colors_console_handler()
# Global import for pypinyin with startup-time logging
try:
import pypinyin
@@ -341,8 +283,8 @@ def setup_logger(
logger_instance.handlers = [] # Clear existing handlers
logger_instance.propagate = False
# Add console handler with safe stream handling
console_handler = SafeStreamHandler()
# Add console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(simple_formatter)
console_handler.setLevel(level)
logger_instance.addHandler(console_handler)
@@ -408,69 +350,28 @@ class TaskState:
@dataclass
class EmbeddingFunc:
"""Embedding function wrapper with dimension validation
This class wraps an embedding function to ensure the output embeddings match the declared dimension.
It must not be applied to the same function more than once, since double wrapping causes duplicate embedding_dim injection.
Args:
embedding_dim: Expected dimension of the embeddings
func: The actual embedding function to wrap
max_token_size: Optional token limit for the embedding model
send_dimensions: Whether to inject embedding_dim as a keyword argument
"""
embedding_dim: int
func: callable
max_token_size: int | None = None # Token limit for the embedding model
send_dimensions: bool = (
False # Control whether to send embedding_dim to the function
)
max_token_size: int | None = None  # deprecated; kept for backward compatibility only
send_dimensions: bool = False  # controls whether embedding_dim is passed to the wrapped function
async def __call__(self, *args, **kwargs) -> np.ndarray:
# Only inject embedding_dim when send_dimensions is True
if self.send_dimensions:
# Check if user provided embedding_dim parameter
if "embedding_dim" in kwargs:
user_provided_dim = kwargs["embedding_dim"]
if "embedding_dim" in kwargs:
user_provided_dim = kwargs["embedding_dim"]
# If user's value differs from class attribute, output warning
if (
user_provided_dim is not None
and user_provided_dim != self.embedding_dim
):
if user_provided_dim is not None and user_provided_dim != self.embedding_dim:
logger.warning(
f"Ignoring user-provided embedding_dim={user_provided_dim}, "
f"using declared embedding_dim={self.embedding_dim} from decorator"
)
# Inject embedding_dim from decorator
kwargs["embedding_dim"] = self.embedding_dim
# Call the actual embedding function
result = await self.func(*args, **kwargs)
# Validate embedding dimensions using total element count
total_elements = result.size # Total number of elements in the numpy array
expected_dim = self.embedding_dim
# Check if total elements can be evenly divided by embedding_dim
if total_elements % expected_dim != 0:
raise ValueError(
f"Embedding dimension mismatch detected: "
f"total elements ({total_elements}) cannot be evenly divided by "
f"expected dimension ({expected_dim}). "
)
# Optional: Verify vector count matches input text count
actual_vectors = total_elements // expected_dim
if args and isinstance(args[0], (list, tuple)):
expected_vectors = len(args[0])
if actual_vectors != expected_vectors:
raise ValueError(
f"Vector count mismatch: "
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
)
return result
kwargs["embedding_dim"] = self.embedding_dim
return await self.func(*args, **kwargs)
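A self-contained toy run of the injection logic above (the embedding function is a stand-in, not a real binding):
```python
import asyncio
import numpy as np

async def toy_embed(texts, embedding_dim=None):
    # Stand-in embedder: returns zeros with the requested dimension
    return np.zeros((len(texts), embedding_dim or 8), dtype=np.float32)

wrapped = EmbeddingFunc(embedding_dim=4, func=toy_embed, send_dimensions=True)
print(asyncio.run(wrapped(["a", "b"])).shape)  # (2, 4): declared dim was injected
```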
def compute_args_hash(*args: Any) -> str:
@@ -1005,76 +906,7 @@
def wrap_embedding_func_with_attrs(**kwargs):
"""Decorator to add embedding dimension and token limit attributes to embedding functions.
This decorator wraps an async embedding function and returns an EmbeddingFunc instance
that automatically handles dimension parameter injection and attribute management.
WARNING: DO NOT apply this decorator to wrapper functions that call other
decorated embedding functions. This will cause double decoration and parameter
injection conflicts.
Correct usage patterns:
1. Direct implementation (decorated):
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def my_embed(texts, embedding_dim=None):
# Direct implementation
return embeddings
```
2. Wrapper calling decorated function (DO NOT decorate wrapper):
```python
# my_embed is already decorated above
async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this!
# Must call .func to access unwrapped implementation
return await my_embed.func(texts, **kwargs)
```
3. Wrapper calling decorated function (properly decorated):
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536)
async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func
# Calling .func avoids double decoration
return await my_embed.func(texts, **kwargs)
```
The decorated function becomes an EmbeddingFunc instance with:
- embedding_dim: The embedding dimension
- max_token_size: Maximum token limit (optional)
- func: The original unwrapped function (access via .func)
- __call__: Wrapper that injects embedding_dim parameter
Double decoration causes:
- Double injection of embedding_dim parameter
- Incorrect parameter passing to the underlying implementation
- Runtime errors due to parameter conflicts
Args:
embedding_dim: The dimension of embedding vectors
max_token_size: Maximum number of tokens (optional)
send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
Returns:
A decorator that wraps the function as an EmbeddingFunc instance
Example of correct wrapper implementation:
```python
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(...)
async def openai_embed(texts, ...):
# Base implementation
pass
@wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here!
async def azure_openai_embed(texts, ...):
# CRITICAL: Call .func to access unwrapped function
return await openai_embed.func(texts, ...) # ✅ Correct
# return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
```
"""
"""Wrap a function with attributes"""
def final_decro(func) -> EmbeddingFunc:
new_func = EmbeddingFunc(**kwargs, func=func)
@@ -1090,123 +922,9 @@ def load_json(file_name):
return json.load(f)
def _sanitize_string_for_json(text: str) -> str:
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization.
Uses a precompiled regex with a zero-copy fast path: clean strings (~99% of cases) are returned unchanged, while dirty strings get a single C-level substitution pass.
Args:
text: String to sanitize
Returns:
Original string if clean (zero-copy), sanitized string if dirty
"""
if not text:
return text
# Fast path: Check if sanitization is needed using C-level regex search
if not _SURROGATE_PATTERN.search(text):
return text # Zero-copy for clean strings - most common case
# Slow path: Remove problematic characters using C-level regex substitution
return _SURROGATE_PATTERN.sub("", text)
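For example (illustrative inputs):
```python
# Clean input is returned unchanged (zero copy); a lone surrogate is stripped.
assert _sanitize_string_for_json("hello") == "hello"
assert _sanitize_string_for_json("bad\ud800text") == "badtext"
```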
class SanitizingJSONEncoder(json.JSONEncoder):
"""
Custom JSON encoder that sanitizes data during serialization.
This encoder cleans strings during the encoding process without creating
a full copy of the data structure, making it memory-efficient for large datasets.
"""
def encode(self, o):
"""Override encode method to handle simple string cases"""
if isinstance(o, str):
return json.encoder.encode_basestring(_sanitize_string_for_json(o))
return super().encode(o)
def iterencode(self, o, _one_shot=False):
"""
Override iterencode to sanitize strings during serialization.
This is the core method that handles complex nested structures.
"""
# Preprocess: sanitize all strings in the object
sanitized = self._sanitize_for_encoding(o)
# Call parent's iterencode with sanitized data
for chunk in super().iterencode(sanitized, _one_shot):
yield chunk
def _sanitize_for_encoding(self, obj):
"""
Recursively sanitize strings in an object.
Creates new objects only when necessary to avoid deep copies.
Args:
obj: Object to sanitize
Returns:
Sanitized object with cleaned strings
"""
if isinstance(obj, str):
return _sanitize_string_for_json(obj)
elif isinstance(obj, dict):
# Create new dict with sanitized keys and values
new_dict = {}
for k, v in obj.items():
clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
clean_v = self._sanitize_for_encoding(v)
new_dict[clean_k] = clean_v
return new_dict
elif isinstance(obj, (list, tuple)):
# Sanitize list/tuple elements
cleaned = [self._sanitize_for_encoding(item) for item in obj]
return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
else:
# Numbers, booleans, None, etc. remain unchanged
return obj
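A short usage sketch of the encoder (illustrative payload):
```python
import json

payload = {"text": "ok\ud800", "n": 1}
# Strings are sanitized during encoding; no deep copy of payload is made.
print(json.dumps(payload, cls=SanitizingJSONEncoder))  # {"text": "ok", "n": 1}
```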
def write_json(json_obj, file_name):
"""
Write JSON data to file with optimized sanitization strategy.
This function uses a two-stage approach:
1. Fast path: Try direct serialization (works for clean data ~99% of the time)
2. Slow path: Use custom encoder that sanitizes during serialization
The custom encoder approach avoids creating a deep copy of the data,
making it memory-efficient. When sanitization occurs, the caller should
reload the cleaned data from the file to update shared memory.
Args:
json_obj: Object to serialize (may be a shallow copy from shared memory)
file_name: Output file path
Returns:
bool: True if sanitization was applied (caller should reload data),
False if direct write succeeded (no reload needed)
"""
try:
# Strategy 1: Fast path - try direct serialization
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)
return False # No sanitization needed, no reload required
except (UnicodeEncodeError, UnicodeDecodeError) as e:
logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
# Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy)
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder)
logger.info(f"JSON sanitization applied during write: {file_name}")
return True # Sanitization applied, reload recommended
json.dump(json_obj, f, indent=2, ensure_ascii=False)
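A hedged usage sketch of the reload contract implemented by the version removed above (file path and data are illustrative):
```python
shared_data = {"doc-1": {"content": "ok\ud800"}}  # contains a lone surrogate
# True means sanitization was applied during write, so reload the cleaned data
if write_json(shared_data, "kv_store.json"):
    shared_data = load_json("kv_store.json")
```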
class TokenizerInterface(Protocol):