Merge branch 'HKUDS:main' into main

Hầu Phi Dao 2025-11-17 12:05:19 +07:00 committed by GitHub
commit 43a9b307bc
32 changed files with 3698 additions and 817 deletions


@ -23,10 +23,11 @@ LightRAG uses dynamic package installation (`pipmaster`) for optional features b
LightRAG dynamically installs packages for:
- **Document Processing**: `docling`, `pypdf2`, `python-docx`, `python-pptx`, `openpyxl`
- **Storage Backends**: `redis`, `neo4j`, `pymilvus`, `pymongo`, `asyncpg`, `qdrant-client`
- **LLM Providers**: `openai`, `anthropic`, `ollama`, `zhipuai`, `aioboto3`, `voyageai`, `llama-index`, `lmdeploy`, `transformers`, `torch`
- Tiktoken Models**: BPE encoding models downloaded from OpenAI CDN
- **Tiktoken Models**: BPE encoding models downloaded from OpenAI CDN
**Note**: Document processing dependencies (`pypdf`, `python-docx`, `python-pptx`, `openpyxl`) are now pre-installed with the `api` extras group and no longer require dynamic installation.
## Quick Start
@ -75,32 +76,31 @@ LightRAG provides flexible dependency groups for different use cases:
| Group | Description | Use Case |
|-------|-------------|----------|
| `offline-docs` | Document processing | PDF, DOCX, PPTX, XLSX files |
| `api` | API server + document processing | FastAPI server with PDF, DOCX, PPTX, XLSX support |
| `offline-storage` | Storage backends | Redis, Neo4j, MongoDB, PostgreSQL, etc. |
| `offline-llm` | LLM providers | OpenAI, Anthropic, Ollama, etc. |
| `offline` | All of the above | Complete offline deployment |
| `offline` | Complete offline package | API + Storage + LLM (all features) |
**Note**: Document processing (PDF, DOCX, PPTX, XLSX) is included in the `api` extras group. The previous `offline-docs` group has been merged into `api` for better integration.
> Software packages requiring `transformers`, `torch`, or `cuda` will not be included in the offline dependency group.
### Installation Examples
```bash
# Install only document processing dependencies
pip install lightrag-hku[offline-docs]
# Install API with document processing
pip install lightrag-hku[api]
# Install document processing and storage backends
pip install lightrag-hku[offline-docs,offline-storage]
# Install API and storage backends
pip install lightrag-hku[api,offline-storage]
# Install all offline dependencies
# Install all offline dependencies (recommended for offline deployment)
pip install lightrag-hku[offline]
```
### Using Individual Requirements Files
```bash
# Document processing only
pip install -r requirements-offline-docs.txt
# Storage backends only
pip install -r requirements-offline-storage.txt
@ -244,8 +244,8 @@ ls -la ~/.tiktoken_cache/
**Solution**:
```bash
# Pre-install the specific package you need
# For document processing:
pip install lightrag-hku[offline-docs]
# For API with document processing:
pip install lightrag-hku[api]
# For storage backends:
pip install lightrag-hku[offline-storage]
@ -297,9 +297,9 @@ mkdir -p ~/my_tiktoken_cache
5. **Minimal Installation**: Only install what you need:
```bash
# If you only process PDFs with OpenAI
pip install lightrag-hku[offline-docs]
# Then manually add: pip install openai
# If you only need API with document processing
pip install lightrag-hku[api]
# Then manually add specific LLM: pip install openai
```
## Additional Resources


@ -29,7 +29,7 @@ WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
# OLLAMA_EMULATING_MODEL_NAME=lightrag
OLLAMA_EMULATING_MODEL_TAG=latest
### Max nodes return from graph retrieval in webui
### Max nodes returned by graph retrieval (also update the WebUI local setting, which is capped by this value)
# MAX_GRAPH_NODES=1000
### Logging level
@ -255,21 +255,23 @@ OLLAMA_LLM_NUM_CTX=32768
### For OpenAI: Set to 'true' to enable dynamic dimension adjustment
### For OpenAI: Set to 'false' (default) to disable sending dimension parameter
### Note: Automatically ignored for backends that don't support the dimension parameter (e.g., Ollama)
# EMBEDDING_SEND_DIM=false
EMBEDDING_BINDING=ollama
EMBEDDING_MODEL=bge-m3:latest
EMBEDDING_DIM=1024
EMBEDDING_BINDING_API_KEY=your_api_key
# If LightRAG deployed in Docker uses host.docker.internal instead of localhost
EMBEDDING_BINDING_HOST=http://localhost:11434
### OpenAI compatible (VoyageAI embedding openai compatible)
# EMBEDDING_BINDING=openai
# EMBEDDING_MODEL=text-embedding-3-large
# EMBEDDING_DIM=3072
# EMBEDDING_BINDING_HOST=https://api.openai.com/v1
# Ollama embedding
# EMBEDDING_BINDING=ollama
# EMBEDDING_MODEL=bge-m3:latest
# EMBEDDING_DIM=1024
# EMBEDDING_BINDING_API_KEY=your_api_key
### If LightRAG is deployed in Docker, use host.docker.internal instead of localhost
# EMBEDDING_BINDING_HOST=http://localhost:11434
### OpenAI compatible embedding
EMBEDDING_BINDING=openai
EMBEDDING_MODEL=text-embedding-3-large
EMBEDDING_DIM=3072
EMBEDDING_SEND_DIM=false
EMBEDDING_TOKEN_LIMIT=8192
EMBEDDING_BINDING_HOST=https://api.openai.com/v1
EMBEDDING_BINDING_API_KEY=your_api_key
### Optional for Azure
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
@ -277,6 +279,16 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
# AZURE_EMBEDDING_API_KEY=your_api_key
### Gemini embedding
# EMBEDDING_BINDING=gemini
# EMBEDDING_MODEL=gemini-embedding-001
# EMBEDDING_DIM=1536
# EMBEDDING_TOKEN_LIMIT=2048
# EMBEDDING_BINDING_HOST=https://generativelanguage.googleapis.com
# EMBEDDING_BINDING_API_KEY=your_api_key
### Gemini embedding requires sending dimension to server
# EMBEDDING_SEND_DIM=true
### Jina AI Embedding
# EMBEDDING_BINDING=jina
# EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings


@ -53,8 +53,8 @@ def xml_to_json(xml_file):
"description": edge.find("./data[@key='d6']", namespace).text
if edge.find("./data[@key='d6']", namespace) is not None
else "",
"keywords": edge.find("./data[@key='d7']", namespace).text
if edge.find("./data[@key='d7']", namespace) is not None
"keywords": edge.find("./data[@key='d9']", namespace).text
if edge.find("./data[@key='d9']", namespace) is not None
else "",
"source_id": edge.find("./data[@key='d8']", namespace).text
if edge.find("./data[@key='d8']", namespace) is not None


@ -1 +1 @@
__api_version__ = "0253"
__api_version__ = "0254"


@ -258,6 +258,14 @@ def parse_args() -> argparse.Namespace:
help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
)
# Document loading engine configuration
parser.add_argument(
"--docling",
action="store_true",
default=False,
help="Enable DOCLING document loading engine (default: from env or DEFAULT)",
)
# Conditionally add binding options defined in binding_options module
# This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
# and corresponding environment variables (e.g., OLLAMA_EMBEDDING_NUM_CTX)
@ -371,8 +379,13 @@ def parse_args() -> argparse.Namespace:
)
args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)
# Select Document loading tool (DOCLING, DEFAULT)
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
# Set document_loading_engine from --docling flag
if args.docling:
args.document_loading_engine = "DOCLING"
else:
args.document_loading_engine = get_env_value(
"DOCUMENT_LOADING_ENGINE", "DEFAULT"
)
# PDF decryption password
args.pdf_decrypt_password = get_env_value("PDF_DECRYPT_PASSWORD", None)
@ -432,6 +445,11 @@ def parse_args() -> argparse.Namespace:
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
)
# Embedding token limit configuration
args.embedding_token_limit = get_env_value(
"EMBEDDING_TOKEN_LIMIT", None, int, special_none=True
)
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
@ -449,4 +467,83 @@ def update_uvicorn_mode_config():
)
global_args = parse_args()
# Global configuration with lazy initialization
_global_args = None
_initialized = False
def initialize_config(args=None, force=False):
"""Initialize global configuration
This function allows explicit initialization of the configuration,
which is useful for programmatic usage, testing, or embedding LightRAG
in other applications.
Args:
args: Pre-parsed argparse.Namespace or None to parse from sys.argv
force: Force re-initialization even if already initialized
Returns:
argparse.Namespace: The configured arguments
Example:
# Use parsed command line arguments (default)
initialize_config()
# Use custom configuration programmatically
custom_args = argparse.Namespace(
host='localhost',
port=8080,
working_dir='./custom_rag',
# ... other config
)
initialize_config(custom_args)
"""
global _global_args, _initialized
if _initialized and not force:
return _global_args
_global_args = args if args is not None else parse_args()
_initialized = True
return _global_args
def get_config():
"""Get global configuration, auto-initializing if needed
Returns:
argparse.Namespace: The configured arguments
"""
if not _initialized:
initialize_config()
return _global_args
class _GlobalArgsProxy:
"""Proxy object that auto-initializes configuration on first access
This maintains backward compatibility with existing code while
allowing programmatic control over initialization timing.
"""
def __getattr__(self, name):
if not _initialized:
initialize_config()
return getattr(_global_args, name)
def __setattr__(self, name, value):
if not _initialized:
initialize_config()
setattr(_global_args, name, value)
def __repr__(self):
if not _initialized:
return "<GlobalArgsProxy: Not initialized>"
return repr(_global_args)
# Create proxy instance for backward compatibility
# Existing code like `from config import global_args` continues to work
# The proxy will auto-initialize on first attribute access
global_args = _GlobalArgsProxy()
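
A minimal usage sketch of the lazy-initialization API above, assuming only that `lightrag.api.config` exposes `initialize_config` and `get_config` as defined in this diff; the `Namespace` fields shown are illustrative, not a complete configuration:

```python
# Sketch: programmatic configuration without parsing sys.argv.
# The Namespace fields are illustrative; a real config needs every
# attribute that parse_args() would normally populate.
import argparse

from lightrag.api.config import initialize_config, get_config

custom_args = argparse.Namespace(
    host="localhost",
    port=8080,
    working_dir="./custom_rag",
)

initialize_config(custom_args)   # explicit one-time initialization
config = get_config()            # returns the same Namespace from now on
assert config.port == 8080
```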


@ -618,33 +618,108 @@ def create_app(args):
def create_optimized_embedding_function(
config_cache: LLMConfigCache, binding, model, host, api_key, args
):
) -> EmbeddingFunc:
"""
Create optimized embedding function with pre-processed configuration for applicable bindings.
Uses lazy imports for all bindings and avoids repeated configuration parsing.
Create optimized embedding function and return an EmbeddingFunc instance
with proper max_token_size inheritance from provider defaults.
This function:
1. Imports the provider embedding function
2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
4. Returns a properly configured EmbeddingFunc instance
"""
# Step 1: Import provider function and extract default attributes
provider_func = None
provider_max_token_size = None
provider_embedding_dim = None
try:
if binding == "openai":
from lightrag.llm.openai import openai_embed
provider_func = openai_embed
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed
provider_func = ollama_embed
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed
provider_func = gemini_embed
elif binding == "jina":
from lightrag.llm.jina import jina_embed
provider_func = jina_embed
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed
provider_func = azure_openai_embed
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed
provider_func = bedrock_embed
elif binding == "lollms":
from lightrag.llm.lollms import lollms_embed
provider_func = lollms_embed
# Extract attributes if provider is an EmbeddingFunc
if provider_func and isinstance(provider_func, EmbeddingFunc):
provider_max_token_size = provider_func.max_token_size
provider_embedding_dim = provider_func.embedding_dim
logger.debug(
f"Extracted from {binding} provider: "
f"max_token_size={provider_max_token_size}, "
f"embedding_dim={provider_embedding_dim}"
)
except ImportError as e:
logger.warning(f"Could not import provider function for {binding}: {e}")
# Step 2: Apply priority (user config > provider default)
# For max_token_size: explicit env var > provider default > None
final_max_token_size = args.embedding_token_limit or provider_max_token_size
# For embedding_dim: user config (always has value) takes priority
# Only use provider default if user config is explicitly None (which shouldn't happen)
final_embedding_dim = (
args.embedding_dim if args.embedding_dim else provider_embedding_dim
)
# Step 3: Create optimized embedding function (calls underlying function directly)
async def optimized_embedding_function(texts, embedding_dim=None):
try:
if binding == "lollms":
from lightrag.llm.lollms import lollms_embed
return await lollms_embed(
# Get real function, skip EmbeddingFunc wrapper if present
actual_func = (
lollms_embed.func
if isinstance(lollms_embed, EmbeddingFunc)
else lollms_embed
)
return await actual_func(
texts, embed_model=model, host=host, api_key=api_key
)
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
# Get real function, skip EmbeddingFunc wrapper if present
actual_func = (
ollama_embed.func
if isinstance(ollama_embed, EmbeddingFunc)
else ollama_embed
)
# Use pre-processed configuration if available
if config_cache.ollama_embedding_options is not None:
ollama_options = config_cache.ollama_embedding_options
else:
# Fallback for cases where config cache wasn't initialized properly
from lightrag.llm.binding_options import OllamaEmbeddingOptions
ollama_options = OllamaEmbeddingOptions.options_dict(args)
return await ollama_embed(
return await actual_func(
texts,
embed_model=model,
host=host,
@ -654,15 +729,30 @@ def create_app(args):
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed
return await azure_openai_embed(texts, model=model, api_key=api_key)
actual_func = (
azure_openai_embed.func
if isinstance(azure_openai_embed, EmbeddingFunc)
else azure_openai_embed
)
return await actual_func(texts, model=model, api_key=api_key)
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed
return await bedrock_embed(texts, model=model)
actual_func = (
bedrock_embed.func
if isinstance(bedrock_embed, EmbeddingFunc)
else bedrock_embed
)
return await actual_func(texts, model=model)
elif binding == "jina":
from lightrag.llm.jina import jina_embed
return await jina_embed(
actual_func = (
jina_embed.func
if isinstance(jina_embed, EmbeddingFunc)
else jina_embed
)
return await actual_func(
texts,
embedding_dim=embedding_dim,
base_url=host,
@ -671,16 +761,21 @@ def create_app(args):
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
actual_func = (
gemini_embed.func
if isinstance(gemini_embed, EmbeddingFunc)
else gemini_embed
)
# Use pre-processed configuration if available
if config_cache.gemini_embedding_options is not None:
gemini_options = config_cache.gemini_embedding_options
else:
# Fallback for cases where config cache wasn't initialized properly
from lightrag.llm.binding_options import GeminiEmbeddingOptions
gemini_options = GeminiEmbeddingOptions.options_dict(args)
return await gemini_embed(
return await actual_func(
texts,
model=model,
base_url=host,
@ -691,7 +786,12 @@ def create_app(args):
else: # openai and compatible
from lightrag.llm.openai import openai_embed
return await openai_embed(
actual_func = (
openai_embed.func
if isinstance(openai_embed, EmbeddingFunc)
else openai_embed
)
return await actual_func(
texts,
model=model,
base_url=host,
@ -701,7 +801,21 @@ def create_app(args):
except ImportError as e:
raise Exception(f"Failed to import {binding} embedding: {e}")
return optimized_embedding_function
# Step 4: Wrap in EmbeddingFunc and return
embedding_func_instance = EmbeddingFunc(
embedding_dim=final_embedding_dim,
func=optimized_embedding_function,
max_token_size=final_max_token_size,
send_dimensions=False, # Will be set later based on binding requirements
)
# Log final embedding configuration
logger.info(
f"Embedding config: binding={binding} model={model} "
f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
)
return embedding_func_instance
llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
embedding_timeout = get_env_value(
@ -735,25 +849,24 @@ def create_app(args):
**kwargs,
)
# Create embedding function with optimized configuration
# Create embedding function with optimized configuration and max_token_size inheritance
import inspect
# Create the optimized embedding function
optimized_embedding_func = create_optimized_embedding_function(
# Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
embedding_func = create_optimized_embedding_function(
config_cache=config_cache,
binding=args.embedding_binding,
model=args.embedding_model,
host=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
args=args, # Pass args object for fallback option generation
args=args,
)
# Get embedding_send_dim from centralized configuration
embedding_send_dim = args.embedding_send_dim
# Check if the function signature has embedding_dim parameter
# Note: Since optimized_embedding_func is an async function, inspect its signature
sig = inspect.signature(optimized_embedding_func)
# Check if the underlying function signature has embedding_dim parameter
sig = inspect.signature(embedding_func.func)
has_embedding_dim_param = "embedding_dim" in sig.parameters
# Determine send_dimensions value based on binding type
@ -771,18 +884,27 @@ def create_app(args):
else:
dimension_control = "by not hasparam"
# Set send_dimensions on the EmbeddingFunc instance
embedding_func.send_dimensions = send_dimensions
logger.info(
f"Send embedding dimension: {send_dimensions} {dimension_control} "
f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
f"binding={args.embedding_binding})"
)
# Create EmbeddingFunc with send_dimensions attribute
embedding_func = EmbeddingFunc(
embedding_dim=args.embedding_dim,
func=optimized_embedding_func,
send_dimensions=send_dimensions,
)
# Log max_token_size source
if embedding_func.max_token_size:
source = (
"env variable"
if args.embedding_token_limit
else f"{args.embedding_binding} provider default"
)
logger.info(
f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
)
else:
logger.info("Embedding max_token_size: not set (90% token warning disabled)")
# Configure rerank function based on the args.rerank_binding parameter
rerank_model_func = None
@ -1214,6 +1336,12 @@ def check_and_install_dependencies():
def main():
# Explicitly initialize configuration for clarity
# (The proxy will auto-initialize anyway, but this makes intent clear)
from .config import initialize_config
initialize_config()
# Check if running under Gunicorn
if "GUNICORN_CMD_ARGS" in os.environ:
# If started with Gunicorn, return directly as Gunicorn will call get_application


@ -3,14 +3,15 @@ This module contains all document-related routes for the LightRAG API.
"""
import asyncio
from functools import lru_cache
from lightrag.utils import logger, get_pinyin_sort_key
import aiofiles
import shutil
import traceback
import pipmaster as pm
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Any, Literal
from io import BytesIO
from fastapi import (
APIRouter,
BackgroundTasks,
@ -28,6 +29,24 @@ from lightrag.api.utils_api import get_combined_auth_dependency
from ..config import global_args
@lru_cache(maxsize=1)
def _is_docling_available() -> bool:
"""Check if docling is available (cached check).
This function uses lru_cache to avoid repeated import attempts.
The result is cached after the first call.
Returns:
bool: True if docling is available, False otherwise
"""
try:
import docling # noqa: F401 # type: ignore[import-not-found]
return True
except ImportError:
return False
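
One behavioral note on the cached probe, as a sketch: `lru_cache` memoizes the first result for the life of the process, so installing docling after startup is only picked up once the cache is cleared (`cache_clear()` is standard `functools.lru_cache` API):

```python
# Sketch: the probe result is memoized, so a runtime install of docling
# is invisible until the cache is reset.
_is_docling_available()              # e.g. False; cached from the first call
# ... `pip install docling` into the same environment ...
_is_docling_available.cache_clear()  # provided by functools.lru_cache
_is_docling_available()              # re-runs the import probe
```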
# Function to format datetime to ISO format string with timezone information
def format_datetime(dt: Any) -> Optional[str]:
"""Format datetime to ISO format string with timezone information
@ -879,7 +898,6 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str
Returns:
str: Unique filename (may have numeric suffix added)
"""
from pathlib import Path
import time
original_path = Path(original_name)
@ -902,6 +920,122 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str
return f"{base_name}_{timestamp}{extension}"
# Document processing helper functions (synchronous)
# These functions run in thread pool via asyncio.to_thread() to avoid blocking the event loop
def _convert_with_docling(file_path: Path) -> str:
"""Convert document using docling (synchronous).
Args:
file_path: Path to the document file
Returns:
str: Extracted markdown content
"""
from docling.document_converter import DocumentConverter # type: ignore
converter = DocumentConverter()
result = converter.convert(file_path)
return result.document.export_to_markdown()
def _extract_pdf_pypdf(file_bytes: bytes, password: str = None) -> str:
"""Extract PDF content using pypdf (synchronous).
Args:
file_bytes: PDF file content as bytes
password: Optional password for encrypted PDFs
Returns:
str: Extracted text content
Raises:
Exception: If PDF is encrypted and password is incorrect or missing
"""
from pypdf import PdfReader # type: ignore
pdf_file = BytesIO(file_bytes)
reader = PdfReader(pdf_file)
# Check if PDF is encrypted
if reader.is_encrypted:
if not password:
raise Exception("PDF is encrypted but no password provided")
decrypt_result = reader.decrypt(password)
if decrypt_result == 0:
raise Exception("Incorrect PDF password")
# Extract text from all pages
content = ""
for page in reader.pages:
content += page.extract_text() + "\n"
return content
def _extract_docx(file_bytes: bytes) -> str:
"""Extract DOCX content (synchronous).
Args:
file_bytes: DOCX file content as bytes
Returns:
str: Extracted text content
"""
from docx import Document # type: ignore
docx_file = BytesIO(file_bytes)
doc = Document(docx_file)
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
def _extract_pptx(file_bytes: bytes) -> str:
"""Extract PPTX content (synchronous).
Args:
file_bytes: PPTX file content as bytes
Returns:
str: Extracted text content
"""
from pptx import Presentation # type: ignore
pptx_file = BytesIO(file_bytes)
prs = Presentation(pptx_file)
content = ""
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
return content
def _extract_xlsx(file_bytes: bytes) -> str:
"""Extract XLSX content (synchronous).
Args:
file_bytes: XLSX file content as bytes
Returns:
str: Extracted text content
"""
from openpyxl import load_workbook # type: ignore
xlsx_file = BytesIO(file_bytes)
wb = load_workbook(xlsx_file)
content = ""
for sheet in wb:
content += f"Sheet: {sheet.title}\n"
for row in sheet.iter_rows(values_only=True):
content += (
"\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
)
content += "\n"
return content
async def pipeline_enqueue_file(
rag: LightRAG, file_path: Path, track_id: str = None
) -> tuple[bool, str]:
@ -1072,87 +1206,28 @@ async def pipeline_enqueue_file(
case ".pdf":
try:
if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore
converter = DocumentConverter()
result = converter.convert(file_path)
content = result.document.export_to_markdown()
# Try DOCLING first if configured and available
if (
global_args.document_loading_engine == "DOCLING"
and _is_docling_available()
):
content = await asyncio.to_thread(
_convert_with_docling, file_path
)
else:
if not pm.is_installed("pypdf"): # type: ignore
pm.install("pypdf")
if not pm.is_installed("pycryptodome"): # type: ignore
pm.install("pycryptodome")
from pypdf import PdfReader # type: ignore
from io import BytesIO
pdf_file = BytesIO(file)
reader = PdfReader(pdf_file)
# Check if PDF is encrypted
if reader.is_encrypted:
pdf_password = global_args.pdf_decrypt_password
if not pdf_password:
# PDF is encrypted but no password provided
error_files = [
{
"file_path": str(file_path.name),
"error_description": "[File Extraction]PDF is encrypted but no password provided",
"original_error": "Please set PDF_DECRYPT_PASSWORD environment variable to decrypt this PDF file",
"file_size": file_size,
}
]
await rag.apipeline_enqueue_error_documents(
error_files, track_id
)
logger.error(
f"[File Extraction]PDF is encrypted but no password provided: {file_path.name}"
)
return False, track_id
# Try to decrypt with password
try:
decrypt_result = reader.decrypt(pdf_password)
if decrypt_result == 0:
# Password is incorrect
error_files = [
{
"file_path": str(file_path.name),
"error_description": "[File Extraction]Failed to decrypt PDF - incorrect password",
"original_error": "The provided PDF_DECRYPT_PASSWORD is incorrect for this file",
"file_size": file_size,
}
]
await rag.apipeline_enqueue_error_documents(
error_files, track_id
)
logger.error(
f"[File Extraction]Incorrect PDF password: {file_path.name}"
)
return False, track_id
except Exception as decrypt_error:
# Decryption process error
error_files = [
{
"file_path": str(file_path.name),
"error_description": "[File Extraction]PDF decryption failed",
"original_error": f"Error during PDF decryption: {str(decrypt_error)}",
"file_size": file_size,
}
]
await rag.apipeline_enqueue_error_documents(
error_files, track_id
)
logger.error(
f"[File Extraction]PDF decryption error for {file_path.name}: {str(decrypt_error)}"
)
return False, track_id
# Extract text from PDF (encrypted PDFs are now decrypted, unencrypted PDFs proceed directly)
for page in reader.pages:
content += page.extract_text() + "\n"
if (
global_args.document_loading_engine == "DOCLING"
and not _is_docling_available()
):
logger.warning(
f"DOCLING engine configured but not available for {file_path.name}. Falling back to pypdf."
)
# Use pypdf (non-blocking via to_thread)
content = await asyncio.to_thread(
_extract_pdf_pypdf,
file,
global_args.pdf_decrypt_password,
)
except Exception as e:
error_files = [
{
@ -1172,28 +1247,24 @@ async def pipeline_enqueue_file(
case ".docx":
try:
if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore
converter = DocumentConverter()
result = converter.convert(file_path)
content = result.document.export_to_markdown()
else:
if not pm.is_installed("python-docx"): # type: ignore
try:
pm.install("python-docx")
except Exception:
pm.install("docx")
from docx import Document # type: ignore
from io import BytesIO
docx_file = BytesIO(file)
doc = Document(docx_file)
content = "\n".join(
[paragraph.text for paragraph in doc.paragraphs]
# Try DOCLING first if configured and available
if (
global_args.document_loading_engine == "DOCLING"
and _is_docling_available()
):
content = await asyncio.to_thread(
_convert_with_docling, file_path
)
else:
if (
global_args.document_loading_engine == "DOCLING"
and not _is_docling_available()
):
logger.warning(
f"DOCLING engine configured but not available for {file_path.name}. Falling back to python-docx."
)
# Use python-docx (non-blocking via to_thread)
content = await asyncio.to_thread(_extract_docx, file)
except Exception as e:
error_files = [
{
@ -1213,26 +1284,24 @@ async def pipeline_enqueue_file(
case ".pptx":
try:
if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore
converter = DocumentConverter()
result = converter.convert(file_path)
content = result.document.export_to_markdown()
# Try DOCLING first if configured and available
if (
global_args.document_loading_engine == "DOCLING"
and _is_docling_available()
):
content = await asyncio.to_thread(
_convert_with_docling, file_path
)
else:
if not pm.is_installed("python-pptx"): # type: ignore
pm.install("pptx")
from pptx import Presentation # type: ignore
from io import BytesIO
pptx_file = BytesIO(file)
prs = Presentation(pptx_file)
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
content += shape.text + "\n"
if (
global_args.document_loading_engine == "DOCLING"
and not _is_docling_available()
):
logger.warning(
f"DOCLING engine configured but not available for {file_path.name}. Falling back to python-pptx."
)
# Use python-pptx (non-blocking via to_thread)
content = await asyncio.to_thread(_extract_pptx, file)
except Exception as e:
error_files = [
{
@ -1252,33 +1321,24 @@ async def pipeline_enqueue_file(
case ".xlsx":
try:
if global_args.document_loading_engine == "DOCLING":
if not pm.is_installed("docling"): # type: ignore
pm.install("docling")
from docling.document_converter import DocumentConverter # type: ignore
converter = DocumentConverter()
result = converter.convert(file_path)
content = result.document.export_to_markdown()
# Try DOCLING first if configured and available
if (
global_args.document_loading_engine == "DOCLING"
and _is_docling_available()
):
content = await asyncio.to_thread(
_convert_with_docling, file_path
)
else:
if not pm.is_installed("openpyxl"): # type: ignore
pm.install("openpyxl")
from openpyxl import load_workbook # type: ignore
from io import BytesIO
xlsx_file = BytesIO(file)
wb = load_workbook(xlsx_file)
for sheet in wb:
content += f"Sheet: {sheet.title}\n"
for row in sheet.iter_rows(values_only=True):
content += (
"\t".join(
str(cell) if cell is not None else ""
for cell in row
)
+ "\n"
)
content += "\n"
if (
global_args.document_loading_engine == "DOCLING"
and not _is_docling_available()
):
logger.warning(
f"DOCLING engine configured but not available for {file_path.name}. Falling back to openpyxl."
)
# Use openpyxl (non-blocking via to_thread)
content = await asyncio.to_thread(_extract_xlsx, file)
except Exception as e:
error_files = [
{


@ -5,6 +5,7 @@ Start LightRAG server with Gunicorn
import os
import sys
import platform
import pipmaster as pm
from lightrag.api.utils_api import display_splash_screen, check_env_file
from lightrag.api.config import global_args
@ -34,6 +35,11 @@ def check_and_install_dependencies():
def main():
# Explicitly initialize configuration for Gunicorn mode
from lightrag.api.config import initialize_config
initialize_config()
# Set Gunicorn mode flag for lifespan cleanup detection
os.environ["LIGHTRAG_GUNICORN_MODE"] = "1"
@ -41,6 +47,68 @@ def main():
if not check_env_file():
sys.exit(1)
# Check DOCLING compatibility with Gunicorn multi-worker mode on macOS
if (
platform.system() == "Darwin"
and global_args.document_loading_engine == "DOCLING"
and global_args.workers > 1
):
print("\n" + "=" * 80)
print("❌ ERROR: Incompatible configuration detected!")
print("=" * 80)
print(
"\nDOCLING engine with Gunicorn multi-worker mode is not supported on macOS"
)
print("\nReason:")
print(" PyTorch (required by DOCLING) has known compatibility issues with")
print(" fork-based multiprocessing on macOS, which can cause crashes or")
print(" unexpected behavior when using Gunicorn with multiple workers.")
print("\nCurrent configuration:")
print(" - Operating System: macOS (Darwin)")
print(f" - Document Engine: {global_args.document_loading_engine}")
print(f" - Workers: {global_args.workers}")
print("\nPossible solutions:")
print(" 1. Use single worker mode:")
print(" --workers 1")
print("\n 2. Change document loading engine in .env:")
print(" DOCUMENT_LOADING_ENGINE=DEFAULT")
print("\n 3. Deploy on Linux where multi-worker mode is fully supported")
print("=" * 80 + "\n")
sys.exit(1)
# Check macOS fork safety environment variable for multi-worker mode
if (
platform.system() == "Darwin"
and global_args.workers > 1
and os.environ.get("OBJC_DISABLE_INITIALIZE_FORK_SAFETY") != "YES"
):
print("\n" + "=" * 80)
print("❌ ERROR: Missing required environment variable on macOS!")
print("=" * 80)
print("\nmacOS with Gunicorn multi-worker mode requires:")
print(" OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES")
print("\nReason:")
print(" NumPy uses macOS's Accelerate framework (Objective-C based) for")
print(" vector computations. The Objective-C runtime has fork safety checks")
print(" that will crash worker processes when embedding functions are called.")
print("\nCurrent configuration:")
print(" - Operating System: macOS (Darwin)")
print(f" - Workers: {global_args.workers}")
print(
f" - Environment Variable: {os.environ.get('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'NOT SET')}"
)
print("\nHow to fix:")
print(" Option 1 - Set environment variable before starting (recommended):")
print(" export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES")
print(" lightrag-server")
print("\n Option 2 - Add to your shell profile (~/.zshrc or ~/.bash_profile):")
print(" echo 'export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES' >> ~/.zshrc")
print(" source ~/.zshrc")
print("\n Option 3 - Use single worker mode (no multiprocessing):")
print(" lightrag-server --workers 1")
print("=" * 80 + "\n")
sys.exit(1)
# Check and install dependencies
check_and_install_dependencies()


@ -161,7 +161,20 @@ class JsonDocStatusStorage(DocStatusStorage):
logger.debug(
f"[{self.workspace}] Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
)
write_json(data_dict, self._file_name)
# Write JSON and check if sanitization was applied
needs_reload = write_json(data_dict, self._file_name)
# If data was sanitized, reload cleaned data to update shared memory
if needs_reload:
logger.info(
f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
)
cleaned_data = load_json(self._file_name)
if cleaned_data is not None:
self._data.clear()
self._data.update(cleaned_data)
await clear_all_update_flags(self.final_namespace)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:


@ -81,7 +81,20 @@ class JsonKVStorage(BaseKVStorage):
logger.debug(
f"[{self.workspace}] Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
)
write_json(data_dict, self._file_name)
# Write JSON and check if sanitization was applied
needs_reload = write_json(data_dict, self._file_name)
# If data was sanitized, reload cleaned data to update shared memory
if needs_reload:
logger.info(
f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
)
cleaned_data = load_json(self._file_name)
if cleaned_data is not None:
self._data.clear()
self._data.update(cleaned_data)
await clear_all_update_flags(self.final_namespace)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
@ -224,7 +237,7 @@ class JsonKVStorage(BaseKVStorage):
data: Original data dictionary that may contain legacy structure
Returns:
Migrated data dictionary with flattened cache keys
Migrated data dictionary with flattened cache keys (sanitized if needed)
"""
from lightrag.utils import generate_cache_key
@ -261,8 +274,17 @@ class JsonKVStorage(BaseKVStorage):
logger.info(
f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure"
)
# Persist migrated data immediately
write_json(migrated_data, self._file_name)
# Persist migrated data immediately and check if sanitization was applied
needs_reload = write_json(migrated_data, self._file_name)
# If data was sanitized during write, reload cleaned data
if needs_reload:
logger.info(
f"[{self.workspace}] Reloading sanitized migration data for {self.namespace}"
)
cleaned_data = load_json(self._file_name)
if cleaned_data is not None:
return cleaned_data # Return cleaned data to update shared memory
return migrated_data


@ -133,6 +133,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
@ -146,7 +147,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
)
await result.consume() # Ensure the result is consumed even on error
if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
@ -170,6 +174,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
workspace_label = self._get_workspace_label()
query = (
@ -190,7 +195,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
)
await result.consume() # Ensure the result is consumed even on error
if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise
async def get_node(self, node_id: str) -> dict[str, str] | None:
@ -312,6 +320,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
workspace_label = self._get_workspace_label()
query = f"""
@ -328,7 +337,10 @@ class MemgraphStorage(BaseGraphStorage):
return labels
except Exception as e:
logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
await result.consume() # Ensure the result is consumed even on error
if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@ -352,6 +364,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
results = None
try:
workspace_label = self._get_workspace_label()
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@ -389,7 +402,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
)
await results.consume() # Ensure results are consumed even on error
if results is not None:
await (
results.consume()
) # Ensure results are consumed even on error
raise
except Exception as e:
logger.error(
@ -419,6 +435,7 @@ class MemgraphStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
workspace_label = self._get_workspace_label()
query = f"""
@ -451,7 +468,10 @@ class MemgraphStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
)
await result.consume() # Ensure the result is consumed even on error
if result is not None:
await (
result.consume()
) # Ensure the result is consumed even on error
raise
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
@ -1030,6 +1050,7 @@ class MemgraphStorage(BaseGraphStorage):
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
result = None
try:
workspace_label = self._get_workspace_label()
async with self._driver.session(
@ -1056,6 +1077,8 @@ class MemgraphStorage(BaseGraphStorage):
return labels
except Exception as e:
logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
if result is not None:
await result.consume()
return []
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
@ -1078,6 +1101,7 @@ class MemgraphStorage(BaseGraphStorage):
if not query_lower:
return []
result = None
try:
workspace_label = self._get_workspace_label()
async with self._driver.session(
@ -1111,4 +1135,6 @@ class MemgraphStorage(BaseGraphStorage):
return labels
except Exception as e:
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
if result is not None:
await result.consume()
return []


@ -371,6 +371,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
result = await session.run(query, entity_id=node_id)
@ -381,7 +382,8 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
)
await result.consume() # Ensure results are consumed even on error
if result is not None:
await result.consume() # Ensure results are consumed even on error
raise
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
@ -403,6 +405,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
query = (
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
@ -420,7 +423,8 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
)
await result.consume() # Ensure results are consumed even on error
if result is not None:
await result.consume() # Ensure results are consumed even on error
raise
async def get_node(self, node_id: str) -> dict[str, str] | None:
@ -799,6 +803,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
results = None
try:
workspace_label = self._get_workspace_label()
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
@ -836,7 +841,10 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
)
await results.consume() # Ensure results are consumed even on error
if results is not None:
await (
results.consume()
) # Ensure results are consumed even on error
raise
except Exception as e:
logger.error(
@ -1592,6 +1600,7 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
result = None
try:
query = f"""
MATCH (n:`{workspace_label}`)
@ -1616,7 +1625,8 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(
f"[{self.workspace}] Error getting popular labels: {str(e)}"
)
await result.consume()
if result is not None:
await result.consume()
raise
async def search_labels(self, query: str, limit: int = 50) -> list[str]:


@ -3,6 +3,7 @@ from __future__ import annotations
import traceback
import asyncio
import configparser
import inspect
import os
import time
import warnings
@ -12,6 +13,7 @@ from functools import partial
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Iterator,
cast,
@ -20,6 +22,7 @@ from typing import (
Optional,
List,
Dict,
Union,
)
from lightrag.prompt import PROMPTS
from lightrag.exceptions import PipelineCancelledException
@ -243,11 +246,13 @@ class LightRAG:
int,
int,
],
List[Dict[str, Any]],
Union[List[Dict[str, Any]], Awaitable[List[Dict[str, Any]]]],
] = field(default_factory=lambda: chunking_by_token_size)
"""
Custom chunking function for splitting text into chunks before processing.
The function can be either synchronous or asynchronous.
The function should take the following parameters:
- `tokenizer`: A Tokenizer instance to use for tokenization.
@ -257,7 +262,8 @@ class LightRAG:
- `chunk_token_size`: The maximum number of tokens per chunk.
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
The function should return a list of dictionaries, where each dictionary contains the following keys:
The function should return a list of dictionaries (or an awaitable that resolves to a list),
where each dictionary contains the following keys:
- `tokens`: The number of tokens in the chunk.
- `content`: The text content of the chunk.
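
As a sketch of this sync-or-async contract, a hypothetical paragraph-based chunker (not part of LightRAG) that satisfies the signature and return shape described above:

```python
# Sketch: a custom async chunking function matching the documented contract.
# `paragraph_chunking_func` is hypothetical; it assumes the passed Tokenizer
# exposes encode() for token counting.
async def paragraph_chunking_func(
    tokenizer,
    content: str,
    split_by_character: str | None,
    split_by_character_only: bool,
    chunk_overlap_token_size: int,
    chunk_token_size: int,
) -> list[dict]:
    chunks = []
    for paragraph in content.split("\n\n"):
        paragraph = paragraph.strip()
        if paragraph:
            chunks.append(
                {
                    "tokens": len(tokenizer.encode(paragraph)),
                    "content": paragraph,
                }
            )
    return chunks

# Usage: LightRAG(..., chunking_func=paragraph_chunking_func)
```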
@ -270,6 +276,9 @@ class LightRAG:
embedding_func: EmbeddingFunc | None = field(default=None)
"""Function for computing text embeddings. Must be set before use."""
embedding_token_limit: int | None = field(default=None, init=False)
"""Token limit for embedding model. Set automatically from embedding_func.max_token_size in __post_init__."""
embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
"""Batch size for embedding computations."""
@ -513,6 +522,16 @@ class LightRAG:
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init Embedding
# Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
embedding_max_token_size = None
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
embedding_max_token_size = self.embedding_func.max_token_size
logger.debug(
f"Captured embedding max_token_size: {embedding_max_token_size}"
)
self.embedding_token_limit = embedding_max_token_size
# Step 2: Apply priority wrapper decorator
self.embedding_func = priority_limit_async_func_call(
self.embedding_func_max_async,
llm_timeout=self.default_embedding_timeout,
@ -1756,7 +1775,28 @@ class LightRAG:
)
content = content_data["content"]
# Generate chunks from document
# Call chunking function, supporting both sync and async implementations
chunking_result = self.chunking_func(
self.tokenizer,
content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
)
# If result is awaitable, await to get actual result
if inspect.isawaitable(chunking_result):
chunking_result = await chunking_result
# Validate return type
if not isinstance(chunking_result, (list, tuple)):
raise TypeError(
f"chunking_func must return a list or tuple of dicts, "
f"got {type(chunking_result)}"
)
# Build chunks dictionary
chunks: dict[str, Any] = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
@ -1764,14 +1804,7 @@ class LightRAG:
"file_path": file_path, # Add file path to each chunk
"llm_cache_list": [], # Initialize empty LLM cache list for each chunk
}
for dp in self.chunking_func(
self.tokenizer,
content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
)
for dp in chunking_result
}
if not chunks:


@ -1,6 +1,7 @@
import copy
import os
import json
import logging
import pipmaster as pm # Pipmaster for dynamic library install
@ -16,6 +17,7 @@ from tenacity import (
)
import sys
from lightrag.utils import wrap_embedding_func_with_attrs
if sys.version_info < (3, 9):
from typing import AsyncIterator
@ -23,21 +25,121 @@ else:
from collections.abc import AsyncIterator
from typing import Union
# Import botocore exceptions for proper exception handling
try:
from botocore.exceptions import (
ClientError,
ConnectionError as BotocoreConnectionError,
ReadTimeoutError,
)
except ImportError:
# If botocore is not installed, define placeholders
ClientError = Exception
BotocoreConnectionError = Exception
ReadTimeoutError = Exception
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
class BedrockRateLimitError(BedrockError):
"""Error for rate limiting and throttling issues"""
class BedrockConnectionError(BedrockError):
"""Error for network and connection issues"""
class BedrockTimeoutError(BedrockError):
"""Error for timeout issues"""
def _set_env_if_present(key: str, value):
"""Set environment variable only if a non-empty value is provided."""
if value is not None and value != "":
os.environ[key] = value
def _handle_bedrock_exception(e: Exception, operation: str = "Bedrock API") -> None:
"""Convert AWS Bedrock exceptions to appropriate custom exceptions.
Args:
e: The exception to handle
operation: Description of the operation for error messages
Raises:
BedrockRateLimitError: For rate limiting and throttling issues (retryable)
BedrockConnectionError: For network and server issues (retryable)
BedrockTimeoutError: For timeout issues (retryable)
BedrockError: For validation and other non-retryable errors
"""
error_message = str(e)
# Handle botocore ClientError with specific error codes
if isinstance(e, ClientError):
error_code = e.response.get("Error", {}).get("Code", "")
error_msg = e.response.get("Error", {}).get("Message", error_message)
# Rate limiting and throttling errors (retryable)
if error_code in [
"ThrottlingException",
"ProvisionedThroughputExceededException",
]:
logging.error(f"{operation} rate limit error: {error_msg}")
raise BedrockRateLimitError(f"Rate limit error: {error_msg}")
# Server errors (retryable)
elif error_code in ["ServiceUnavailableException", "InternalServerException"]:
logging.error(f"{operation} connection error: {error_msg}")
raise BedrockConnectionError(f"Service error: {error_msg}")
# Check for 5xx HTTP status codes (retryable)
elif e.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0) >= 500:
logging.error(f"{operation} server error: {error_msg}")
raise BedrockConnectionError(f"Server error: {error_msg}")
# Validation and other client errors (non-retryable)
else:
logging.error(f"{operation} client error: {error_msg}")
raise BedrockError(f"Client error: {error_msg}")
# Connection errors (retryable)
elif isinstance(e, BotocoreConnectionError):
logging.error(f"{operation} connection error: {error_message}")
raise BedrockConnectionError(f"Connection error: {error_message}")
# Timeout errors (retryable)
elif isinstance(e, (ReadTimeoutError, TimeoutError)):
logging.error(f"{operation} timeout error: {error_message}")
raise BedrockTimeoutError(f"Timeout error: {error_message}")
# Custom Bedrock errors (already properly typed)
elif isinstance(
e,
(
BedrockRateLimitError,
BedrockConnectionError,
BedrockTimeoutError,
BedrockError,
),
):
raise
# Unknown errors (non-retryable)
else:
logging.error(f"{operation} unexpected error: {error_message}")
raise BedrockError(f"Unexpected error: {error_message}")
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, max=60),
retry=retry_if_exception_type((BedrockError)),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(BedrockRateLimitError)
| retry_if_exception_type(BedrockConnectionError)
| retry_if_exception_type(BedrockTimeoutError)
),
)
async def bedrock_complete_if_cache(
model,
@ -158,9 +260,6 @@ async def bedrock_complete_if_cache(
break
except Exception as e:
# Log the specific error for debugging
logging.error(f"Bedrock streaming error: {e}")
# Try to clean up resources if possible
if (
iteration_started
@ -175,7 +274,8 @@ async def bedrock_complete_if_cache(
f"Failed to close Bedrock event stream: {close_error}"
)
raise BedrockError(f"Streaming error: {e}")
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock streaming")
finally:
# Clean up the event stream
@ -231,10 +331,8 @@ async def bedrock_complete_if_cache(
return content
except Exception as e:
if isinstance(e, BedrockError):
raise
else:
raise BedrockError(f"Bedrock API error: {e}")
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock converse")
# Generic Bedrock completion function
@ -253,12 +351,16 @@ async def bedrock_complete(
return result
# @wrap_embedding_func_with_attrs(embedding_dim=1024)
# @retry(
# stop=stop_after_attempt(3),
# wait=wait_exponential(multiplier=1, min=4, max=10),
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(BedrockRateLimitError)
| retry_if_exception_type(BedrockConnectionError)
| retry_if_exception_type(BedrockTimeoutError)
),
)
async def bedrock_embed(
texts: list[str],
model: str = "amazon.titan-embed-text-v2:0",
@ -281,48 +383,101 @@ async def bedrock_embed(
async with session.client(
"bedrock-runtime", region_name=region
) as bedrock_async_client:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
if "v2" in model:
try:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
try:
if "v2" in model:
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
"embeddingTypes": ["float"],
}
)
elif "v1" in model:
body = json.dumps({"inputText": text})
else:
raise BedrockError(f"Model {model} is not supported!")
response = await bedrock_async_client.invoke_model(
modelId=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = await response.get("body").json()
# Validate response structure
if not response_body or "embedding" not in response_body:
raise BedrockError(
f"Invalid embedding response structure for text: {text[:50]}..."
)
embedding = response_body["embedding"]
if not embedding:
raise BedrockError(
f"Received empty embedding for text: {text[:50]}..."
)
embed_texts.append(embedding)
except Exception as e:
# Convert to appropriate exception type
_handle_bedrock_exception(
e, "Bedrock embedding (amazon, text chunk)"
)
elif model_provider == "cohere":
try:
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
"embeddingTypes": ["float"],
"texts": texts,
"input_type": "search_document",
"truncate": "NONE",
}
)
elif "v1" in model:
body = json.dumps({"inputText": text})
else:
raise ValueError(f"Model {model} is not supported!")
response = await bedrock_async_client.invoke_model(
modelId=model,
body=body,
accept="application/json",
contentType="application/json",
response = await bedrock_async_client.invoke_model(
modelId=model,  # invoke_model takes modelId, not model
body=body,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
# Validate response structure
if not response_body or "embeddings" not in response_body:
raise BedrockError(
"Invalid embedding response structure from Cohere"
)
embeddings = response_body["embeddings"]
if not embeddings or len(embeddings) != len(texts):
raise BedrockError(
f"Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}"
)
embed_texts = embeddings
except Exception as e:
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock embedding (cohere)")
else:
raise BedrockError(
f"Model provider '{model_provider}' is not supported!"
)
response_body = await response.get("body").json()
# Final validation
if not embed_texts:
raise BedrockError("No embeddings generated")
embed_texts.append(response_body["embedding"])
elif model_provider == "cohere":
body = json.dumps(
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
)
return np.array(embed_texts)
response = await bedrock_async_client.invoke_model(
model=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embed_texts = response_body["embeddings"]
else:
raise ValueError(f"Model provider '{model_provider}' is not supported!")
return np.array(embed_texts)
except Exception as e:
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock embedding")


@ -453,7 +453,7 @@ async def gemini_model_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),


@ -26,6 +26,7 @@ from lightrag.exceptions import (
)
import torch
import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@ -141,6 +142,7 @@ async def hf_model_complete(
return result
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
# Detect the appropriate device
if torch.cuda.is_available():


@ -58,7 +58,7 @@ async def fetch_data(url, headers, data):
return data_list
@wrap_embedding_func_with_attrs(embedding_dim=2048)
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),


@ -174,7 +174,7 @@ async def llama_index_complete(
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),


@ -26,6 +26,10 @@ from lightrag.exceptions import (
from typing import Union, List
import numpy as np
from lightrag.utils import (
wrap_embedding_func_with_attrs,
)
@retry(
stop=stop_after_attempt(3),
@ -134,6 +138,7 @@ async def lollms_model_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
async def lollms_embed(
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
) -> np.ndarray:

View file

@ -33,7 +33,7 @@ from lightrag.utils import (
import numpy as np
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),

View file

@ -1,4 +1,6 @@
from collections.abc import AsyncIterator
import os
import re
import pipmaster as pm
@ -22,8 +24,31 @@ from lightrag.exceptions import (
from lightrag.api import __api_version__
import numpy as np
from typing import Optional, Union
from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    logger,
)
_OLLAMA_CLOUD_HOST = "https://ollama.com"
_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")
def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
if host:
return host
try:
model_name_str = str(model) if model is not None else ""
except (TypeError, ValueError, AttributeError) as e:
logger.warning(f"Failed to convert model to string: {e}, using empty string")
model_name_str = ""
if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str):
logger.debug(
f"Detected cloud model '{model_name_str}', using Ollama Cloud host"
)
return _OLLAMA_CLOUD_HOST
return host
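For illustration, the helper's behavior on a few assumed model names:

```python
# Explicit hosts always win; only a "-cloud" or ":cloud" suffix triggers the override.
assert _coerce_host_for_cloud_model(None, "gpt-oss:120b-cloud") == "https://ollama.com"
assert _coerce_host_for_cloud_model("http://localhost:11434", "llama3:cloud") == "http://localhost:11434"
assert _coerce_host_for_cloud_model(None, "llama3") is None
```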
@retry(
@ -53,6 +78,9 @@ async def _ollama_model_if_cache(
timeout = None
kwargs.pop("hashing_kv", None)
api_key = kwargs.pop("api_key", None)
# fallback to environment variable when not provided explicitly
if not api_key:
api_key = os.getenv("OLLAMA_API_KEY")
headers = {
"Content-Type": "application/json",
"User-Agent": f"LightRAG/{__api_version__}",
@ -60,6 +88,8 @@ async def _ollama_model_if_cache(
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
host = _coerce_host_for_cloud_model(host, model)
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
try:
@ -142,8 +172,11 @@ async def ollama_model_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
api_key = kwargs.pop("api_key", None)
if not api_key:
api_key = os.getenv("OLLAMA_API_KEY")
headers = {
"Content-Type": "application/json",
"User-Agent": f"LightRAG/{__api_version__}",
@ -154,6 +187,8 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
host = _coerce_host_for_cloud_model(host, embed_model)
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
try:
options = kwargs.pop("options", {})
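Taken together, a cloud-hosted embedding call needs neither an explicit `host` nor an `api_key` argument once `OLLAMA_API_KEY` is set. A sketch, where the model name and import path are assumptions:

```python
import asyncio
import os

from lightrag.llm.ollama import ollama_embed  # import path assumed

os.environ.setdefault("OLLAMA_API_KEY", "<your-ollama-cloud-key>")  # placeholder


async def main() -> None:
    # The ":cloud" suffix routes the request to https://ollama.com automatically.
    vectors = await ollama_embed(["hello world"], embed_model="embeddinggemma:cloud")
    print(vectors.shape)


asyncio.run(main())
```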

View file

@ -47,7 +47,7 @@ try:
# Only enable Langfuse if both keys are configured
if langfuse_public_key and langfuse_secret_key:
from langfuse.openai import AsyncOpenAI  # type: ignore[import-untyped]
LANGFUSE_ENABLED = True
logger.info("Langfuse observability enabled for OpenAI client")
@ -604,7 +604,7 @@ async def nvidia_openai_complete(
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),

View file

@ -345,6 +345,20 @@ async def _summarize_descriptions(
llm_response_cache=llm_response_cache,
cache_type="summary",
)
# Check summary token length against embedding limit
embedding_token_limit = global_config.get("embedding_token_limit")
if embedding_token_limit is not None and summary:
tokenizer = global_config["tokenizer"]
summary_token_count = len(tokenizer.encode(summary))
threshold = int(embedding_token_limit * 0.9)
if summary_token_count > threshold:
logger.warning(
f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
f"({embedding_token_limit}) for {description_type}: {description_name}"
)
return summary
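A quick sketch of the threshold arithmetic, with tiktoken standing in for the configured tokenizer (the limit and text are assumptions):

```python
import tiktoken

tokenizer = tiktoken.get_encoding("cl100k_base")
embedding_token_limit = 8192  # assumed, e.g. the embedding model's max_token_size
summary = "entity description " * 4000  # an implausibly long summary

if len(tokenizer.encode(summary)) > int(embedding_token_limit * 0.9):  # 7372 tokens
    print("summary exceeds 90% of the embedding limit and may be truncated")
```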

View file

@ -56,6 +56,9 @@ if not logger.handlers:
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
# Global import for pypinyin with startup-time logging
try:
import pypinyin
@ -350,9 +353,20 @@ class TaskState:
@dataclass
class EmbeddingFunc:
"""Embedding function wrapper with dimension validation
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
This class should not be wrapped multiple times.
Args:
embedding_dim: Expected dimension of the embeddings
func: The actual embedding function to wrap
max_token_size: Optional token limit for the embedding model
send_dimensions: Whether to inject embedding_dim as a keyword argument
"""
embedding_dim: int
func: callable
max_token_size: int | None = None  # Token limit for the embedding model
send_dimensions: bool = (
False # Control whether to send embedding_dim to the function
)
@ -376,7 +390,32 @@ class EmbeddingFunc:
# Inject embedding_dim from decorator
kwargs["embedding_dim"] = self.embedding_dim
# Call the actual embedding function
result = await self.func(*args, **kwargs)
# Validate embedding dimensions using total element count
total_elements = result.size # Total number of elements in the numpy array
expected_dim = self.embedding_dim
# Check if total elements can be evenly divided by embedding_dim
if total_elements % expected_dim != 0:
raise ValueError(
f"Embedding dimension mismatch detected: "
f"total elements ({total_elements}) cannot be evenly divided by "
f"expected dimension ({expected_dim}). "
)
# Optional: Verify vector count matches input text count
actual_vectors = total_elements // expected_dim
if args and isinstance(args[0], (list, tuple)):
expected_vectors = len(args[0])
if actual_vectors != expected_vectors:
raise ValueError(
f"Vector count mismatch: "
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
)
return result
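Why both checks matter: divisibility alone can be fooled when a wrong dimension happens to divide the element count evenly. A small sketch with assumed shapes:

```python
import numpy as np

texts = ["alpha", "beta", "gamma"]
expected_dim = 1536

good = np.zeros((3, 1536))
assert good.size % expected_dim == 0            # 4608 divides evenly by 1536
assert good.size // expected_dim == len(texts)  # 3 vectors for 3 texts

# A dim-1000 result fails the divisibility check outright (3000 % 1536 != 0)...
assert np.zeros((3, 1000)).size % expected_dim != 0
# ...while a dim-1024 result divides evenly (3072 == 2 * 1536) and is only
# caught by the vector-count check: 2 vectors for 3 texts.
assert np.zeros((3, 1024)).size // expected_dim != len(texts)
```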
def compute_args_hash(*args: Any) -> str:
@ -930,73 +969,120 @@ def load_json(file_name):
def _sanitize_string_for_json(text: str) -> str:
    """Remove characters that cannot be encoded in UTF-8 for JSON serialization.

    Uses regex for optimal performance with zero-copy optimization for clean strings.
    Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.

    Args:
        text: String to sanitize

    Returns:
        Original string if clean (zero-copy), sanitized string if dirty
    """
    if not text:
        return text

    # Fast path: Check if sanitization is needed using C-level regex search
    if not _SURROGATE_PATTERN.search(text):
        return text  # Zero-copy for clean strings - most common case

    # Slow path: Remove problematic characters using C-level regex substitution
    return _SURROGATE_PATTERN.sub("", text)
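For illustration, the two paths on assumed inputs; note the fast path returns the identical object rather than a copy:

```python
clean = "normal text"
dirty = "broken\ud800text"

assert _sanitize_string_for_json(clean) is clean         # fast path: zero-copy
assert _sanitize_string_for_json(dirty) == "brokentext"  # slow path: surrogate stripped
```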
class SanitizingJSONEncoder(json.JSONEncoder):
    """
    Custom JSON encoder that sanitizes data during serialization.

    This encoder cleans strings during the encoding process without creating
    a full copy of the data structure, making it memory-efficient for large datasets.
    """

    def encode(self, o):
        """Override encode method to handle simple string cases"""
        if isinstance(o, str):
            return json.encoder.encode_basestring(_sanitize_string_for_json(o))
        return super().encode(o)

    def iterencode(self, o, _one_shot=False):
        """
        Override iterencode to sanitize strings during serialization.
        This is the core method that handles complex nested structures.
        """
        # Preprocess: sanitize all strings in the object
        sanitized = self._sanitize_for_encoding(o)
        # Call parent's iterencode with sanitized data
        for chunk in super().iterencode(sanitized, _one_shot):
            yield chunk

    def _sanitize_for_encoding(self, obj):
        """
        Recursively sanitize strings in an object.
        Creates new objects only when necessary to avoid deep copies.

        Args:
            obj: Object to sanitize

        Returns:
            Sanitized object with cleaned strings
        """
        if isinstance(obj, str):
            return _sanitize_string_for_json(obj)
        elif isinstance(obj, dict):
            # Create new dict with sanitized keys and values
            new_dict = {}
            for k, v in obj.items():
                clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
                clean_v = self._sanitize_for_encoding(v)
                new_dict[clean_k] = clean_v
            return new_dict
        elif isinstance(obj, (list, tuple)):
            # Sanitize list/tuple elements
            cleaned = [self._sanitize_for_encoding(item) for item in obj]
            return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
        else:
            # Numbers, booleans, None, etc. remain unchanged
            return obj
def write_json(json_obj, file_name):
    """
    Write JSON data to file with optimized sanitization strategy.

    This function uses a two-stage approach:
    1. Fast path: Try direct serialization (works for clean data ~99% of the time)
    2. Slow path: Use a custom encoder that sanitizes during serialization

    The custom encoder approach avoids creating a deep copy of the data,
    making it memory-efficient. When sanitization occurs, the caller should
    reload the cleaned data from the file to update shared memory.

    Args:
        json_obj: Object to serialize (may be a shallow copy from shared memory)
        file_name: Output file path

    Returns:
        bool: True if sanitization was applied (caller should reload data),
              False if direct write succeeded (no reload needed)
    """
    try:
        # Strategy 1: Fast path - try direct serialization
        with open(file_name, "w", encoding="utf-8") as f:
            json.dump(json_obj, f, indent=2, ensure_ascii=False)
        return False  # No sanitization needed, no reload required
    except (UnicodeEncodeError, UnicodeDecodeError) as e:
        logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
        # Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy)
        with open(file_name, "w", encoding="utf-8") as f:
            json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder)
        logger.info(f"JSON sanitization applied during write: {file_name}")
        return True  # Sanitization applied, reload recommended
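The boolean return value encodes the reload contract; a usage sketch (the file path is an assumption):

```python
from lightrag.utils import load_json, write_json

data = {"description": "entity text with a stray surrogate\udc9a"}
if write_json(data, "kv_store_llm_response_cache.json"):
    # Sanitization was applied during the write; reload so in-memory
    # (potentially shared) state matches the cleaned file on disk.
    data = load_json("kv_store_llm_response_cache.json")
```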
class TokenizerInterface(Protocol):

View file

@ -40,7 +40,6 @@ export default function QuerySettings() {
// Default values for reset functionality
const defaultValues = useMemo(() => ({
mode: 'mix' as QueryMode,
response_type: 'Multiple Paragraphs',
top_k: 40,
chunk_top_k: 20,
max_entity_tokens: 6000,
@ -153,46 +152,6 @@ export default function QuerySettings() {
</div>
</>
{/* Response Format */}
<>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<label htmlFor="response_format_select" className="ml-1 cursor-help">
{t('retrievePanel.querySettings.responseFormat')}
</label>
</TooltipTrigger>
<TooltipContent side="left">
<p>{t('retrievePanel.querySettings.responseFormatTooltip')}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
<div className="flex items-center gap-1">
<Select
value={querySettings.response_type}
onValueChange={(v) => handleChange('response_type', v)}
>
<SelectTrigger
id="response_format_select"
className="hover:bg-primary/5 h-9 cursor-pointer focus:ring-0 focus:ring-offset-0 focus:outline-0 active:right-0 flex-1 text-left [&>span]:break-all [&>span]:line-clamp-1"
>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectGroup>
<SelectItem value="Multiple Paragraphs">{t('retrievePanel.querySettings.responseFormatOptions.multipleParagraphs')}</SelectItem>
<SelectItem value="Single Paragraph">{t('retrievePanel.querySettings.responseFormatOptions.singleParagraph')}</SelectItem>
<SelectItem value="Bullet Points">{t('retrievePanel.querySettings.responseFormatOptions.bulletPoints')}</SelectItem>
</SelectGroup>
</SelectContent>
</Select>
<ResetButton
onClick={() => handleReset('response_type')}
title="Reset to default (Multiple Paragraphs)"
/>
</div>
</>
{/* Top K */}
<>
<TooltipProvider>

View file

@ -357,6 +357,7 @@ export default function RetrievalTesting() {
const queryParams = {
...state.querySettings,
query: actualQuery,
response_type: 'Multiple Paragraphs',
conversation_history: effectiveHistoryTurns > 0
? prevMessages
.filter((m) => m.isError !== true)

View file

@ -123,7 +123,6 @@ const useSettingsStoreBase = create<SettingsState>()(
querySettings: {
mode: 'global',
response_type: 'Multiple Paragraphs',
top_k: 40,
chunk_top_k: 20,
max_entity_tokens: 6000,
@ -239,7 +238,7 @@ const useSettingsStoreBase = create<SettingsState>()(
{
name: 'settings-storage',
storage: createJSONStorage(() => localStorage),
version: 19,
migrate: (state: any, version: number) => {
if (version < 2) {
state.showEdgeLabel = false
@ -336,6 +335,12 @@ const useSettingsStoreBase = create<SettingsState>()(
// Add userPromptHistory field for older versions
state.userPromptHistory = []
}
if (version < 19) {
// Remove deprecated response_type parameter
if (state.querySettings) {
delete state.querySettings.response_type
}
}
return state
}
}

View file

@ -29,7 +29,7 @@ dependencies = [
"json_repair",
"nano-vectordb",
"networkx",
"numpy",
"numpy>=1.24.0,<2.0.0",
"pandas>=2.0.0,<2.4.0",
"pipmaster",
"pydantic",
@ -50,7 +50,7 @@ api = [
"json_repair",
"nano-vectordb",
"networkx",
"numpy",
"numpy>=1.24.0,<2.0.0",
"openai>=1.0.0,<3.0.0",
"pandas>=2.0.0,<2.4.0",
"pipmaster",
@ -79,18 +79,23 @@ api = [
"python-multipart",
"pytz",
"uvicorn",
"gunicorn",
# Document processing dependencies (required for API document upload functionality)
"openpyxl>=3.0.0,<4.0.0", # XLSX processing
"pycryptodome>=3.0.0,<4.0.0", # PDF encryption support
"pypdf>=6.1.0", # PDF processing
"python-docx>=0.8.11,<2.0.0", # DOCX processing
"python-pptx>=0.6.21,<2.0.0", # PPTX processing
]
# Advanced document processing engine (optional)
docling = [
# On macOS, PyTorch and frameworks using Objective-C are not fork-safe
# and are not compatible with gunicorn's multi-worker mode
"docling>=2.0.0,<3.0.0; sys_platform != 'darwin'",
]
# Offline deployment dependencies (layered design for flexibility)
offline-docs = [
# Document processing dependencies
"openpyxl>=3.0.0,<4.0.0",
"pycryptodome>=3.0.0,<4.0.0",
"pypdf>=6.1.0",
"python-docx>=0.8.11,<2.0.0",
"python-pptx>=0.6.21,<2.0.0",
]
offline-storage = [
# Storage backend dependencies
"redis>=5.0.0,<8.0.0",
@ -115,8 +120,8 @@ offline-llm = [
]
offline = [
# Complete offline package (includes api for document processing, plus storage and LLM)
"lightrag-hku[api,offline-storage,offline-llm]",
]
evaluation = [

View file

@ -1,15 +0,0 @@
# LightRAG Offline Dependencies - Document Processing
# Install with: pip install -r requirements-offline-docs.txt
# For offline installation:
# pip download -r requirements-offline-docs.txt -d ./packages
# pip install --no-index --find-links=./packages -r requirements-offline-docs.txt
#
# Recommended: Use pip install lightrag-hku[offline-docs] for the same effect
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline-docs.txt
# Document processing dependencies (with version constraints matching pyproject.toml)
openpyxl>=3.0.0,<4.0.0
pycryptodome>=3.0.0,<4.0.0
pypdf>=6.1.0
python-docx>=0.8.11,<2.0.0
python-pptx>=0.6.21,<2.0.0

View file

@ -0,0 +1,387 @@
"""
Test suite for write_json optimization
This test verifies:
1. Fast path works for clean data (no sanitization)
2. Slow path applies sanitization for dirty data
3. Sanitization is done during encoding (memory-efficient)
4. Reloading updates shared memory with cleaned data
"""
import os
import json
import tempfile
from lightrag.utils import write_json, load_json, SanitizingJSONEncoder
class TestWriteJsonOptimization:
"""Test write_json optimization with two-stage approach"""
def test_fast_path_clean_data(self):
"""Test that clean data takes the fast path without sanitization"""
clean_data = {
"name": "John Doe",
"age": 30,
"items": ["apple", "banana", "cherry"],
"nested": {"key": "value", "number": 42},
}
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
temp_file = f.name
try:
# Write clean data - should return False (no sanitization)
needs_reload = write_json(clean_data, temp_file)
assert not needs_reload, "Clean data should not require sanitization"
# Verify data was written correctly
loaded_data = load_json(temp_file)
assert loaded_data == clean_data, "Loaded data should match original"
finally:
os.unlink(temp_file)
def test_slow_path_dirty_data(self):
"""Test that dirty data triggers sanitization"""
# Create data with surrogate characters (U+D800 to U+DFFF)
dirty_string = "Hello\ud800World" # Contains surrogate character
dirty_data = {"text": dirty_string, "number": 123}
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
temp_file = f.name
try:
# Write dirty data - should return True (sanitization applied)
needs_reload = write_json(dirty_data, temp_file)
assert needs_reload, "Dirty data should trigger sanitization"
# Verify data was written and sanitized
loaded_data = load_json(temp_file)
assert loaded_data is not None, "Data should be written"
assert loaded_data["number"] == 123, "Clean fields should remain unchanged"
# Surrogate character should be removed
assert (
"\ud800" not in loaded_data["text"]
), "Surrogate character should be removed"
finally:
os.unlink(temp_file)
def test_sanitizing_encoder_removes_surrogates(self):
"""Test that SanitizingJSONEncoder removes surrogate characters"""
data_with_surrogates = {
"text": "Hello\ud800\udc00World", # Contains surrogate pair
"clean": "Clean text",
"nested": {"dirty_key\ud801": "value", "clean_key": "clean\ud802value"},
}
# Encode using custom encoder
encoded = json.dumps(
data_with_surrogates, cls=SanitizingJSONEncoder, ensure_ascii=False
)
# Verify no surrogate characters in output
assert "\ud800" not in encoded, "Surrogate U+D800 should be removed"
assert "\udc00" not in encoded, "Surrogate U+DC00 should be removed"
assert "\ud801" not in encoded, "Surrogate U+D801 should be removed"
assert "\ud802" not in encoded, "Surrogate U+D802 should be removed"
# Verify clean parts remain
assert "Clean text" in encoded, "Clean text should remain"
assert "clean_key" in encoded, "Clean keys should remain"
def test_nested_structure_sanitization(self):
"""Test sanitization of deeply nested structures"""
nested_data = {
"level1": {
"level2": {
"level3": {"dirty": "text\ud800here", "clean": "normal text"},
"list": ["item1", "item\ud801dirty", "item3"],
}
}
}
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
temp_file = f.name
try:
needs_reload = write_json(nested_data, temp_file)
assert needs_reload, "Nested dirty data should trigger sanitization"
# Verify nested structure is preserved
loaded_data = load_json(temp_file)
assert "level1" in loaded_data
assert "level2" in loaded_data["level1"]
assert "level3" in loaded_data["level1"]["level2"]
# Verify surrogates are removed
dirty_text = loaded_data["level1"]["level2"]["level3"]["dirty"]
assert "\ud800" not in dirty_text, "Nested surrogate should be removed"
# Verify list items are sanitized
list_items = loaded_data["level1"]["level2"]["list"]
assert (
"\ud801" not in list_items[1]
), "List item surrogates should be removed"
finally:
os.unlink(temp_file)
def test_unicode_non_characters_removed(self):
"""Test that Unicode non-characters (U+FFFE, U+FFFF) don't cause encoding errors
Note: U+FFFE and U+FFFF are valid UTF-8 characters (though discouraged),
so they don't trigger sanitization. They only get removed when explicitly
using the SanitizingJSONEncoder.
"""
data_with_nonchars = {"text1": "Hello\ufffeWorld", "text2": "Test\uffffString"}
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
temp_file = f.name
try:
# These characters are valid UTF-8, so they take the fast path
needs_reload = write_json(data_with_nonchars, temp_file)
assert not needs_reload, "U+FFFE/U+FFFF are valid UTF-8 characters"
loaded_data = load_json(temp_file)
# They're written as-is in the fast path
assert loaded_data == data_with_nonchars
finally:
os.unlink(temp_file)
def test_mixed_clean_dirty_data(self):
"""Test data with both clean and dirty fields"""
mixed_data = {
"clean_field": "This is perfectly fine",
"dirty_field": "This has\ud800issues",
"number": 42,
"boolean": True,
"null_value": None,
"clean_list": [1, 2, 3],
"dirty_list": ["clean", "dirty\ud801item"],
}
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
temp_file = f.name
try:
needs_reload = write_json(mixed_data, temp_file)
assert (
needs_reload
), "Mixed data with dirty fields should trigger sanitization"
loaded_data = load_json(temp_file)
# Clean fields should remain unchanged
assert loaded_data["clean_field"] == "This is perfectly fine"
assert loaded_data["number"] == 42
assert loaded_data["boolean"]
assert loaded_data["null_value"] is None
assert loaded_data["clean_list"] == [1, 2, 3]
# Dirty fields should be sanitized
assert "\ud800" not in loaded_data["dirty_field"]
assert "\ud801" not in loaded_data["dirty_list"][1]
finally:
os.unlink(temp_file)
def test_empty_and_none_strings(self):
"""Test handling of empty and None values"""
data = {
"empty": "",
"none": None,
"zero": 0,
"false": False,
"empty_list": [],
"empty_dict": {},
}
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
temp_file = f.name
try:
needs_reload = write_json(data, temp_file)
assert (
not needs_reload
), "Clean empty values should not trigger sanitization"
loaded_data = load_json(temp_file)
assert loaded_data == data, "Empty/None values should be preserved"
finally:
os.unlink(temp_file)
def test_specific_surrogate_udc9a(self):
"""Test specific surrogate character \\udc9a mentioned in the issue"""
# Test the exact surrogate character from the error message:
# UnicodeEncodeError: 'utf-8' codec can't encode character '\\udc9a'
data_with_udc9a = {
"text": "Some text with surrogate\udc9acharacter",
"position": 201, # As mentioned in the error
"clean_field": "Normal text",
}
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
temp_file = f.name
try:
# Write data - should trigger sanitization
needs_reload = write_json(data_with_udc9a, temp_file)
assert needs_reload, "Data with \\udc9a should trigger sanitization"
# Verify surrogate was removed
loaded_data = load_json(temp_file)
assert loaded_data is not None
assert "\udc9a" not in loaded_data["text"], "\\udc9a should be removed"
assert (
loaded_data["clean_field"] == "Normal text"
), "Clean fields should remain"
finally:
os.unlink(temp_file)
def test_migration_with_surrogate_sanitization(self):
"""Test that migration process handles surrogate characters correctly
This test simulates the scenario where legacy cache contains surrogate
characters and ensures they are cleaned during migration.
"""
# Simulate legacy cache data with surrogate characters
legacy_data_with_surrogates = {
"cache_entry_1": {
"return": "Result with\ud800surrogate",
"cache_type": "extract",
"original_prompt": "Some\udc9aprompt",
},
"cache_entry_2": {
"return": "Clean result",
"cache_type": "query",
"original_prompt": "Clean prompt",
},
}
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
temp_file = f.name
try:
# First write the dirty data directly (simulating legacy cache file)
# Use custom encoder to force write even with surrogates
with open(temp_file, "w", encoding="utf-8") as f:
json.dump(
legacy_data_with_surrogates,
f,
cls=SanitizingJSONEncoder,
ensure_ascii=False,
)
# Load and verify surrogates were cleaned during initial write
loaded_data = load_json(temp_file)
assert loaded_data is not None
# The data should be sanitized
assert (
"\ud800" not in loaded_data["cache_entry_1"]["return"]
), "Surrogate in return should be removed"
assert (
"\udc9a" not in loaded_data["cache_entry_1"]["original_prompt"]
), "Surrogate in prompt should be removed"
# Clean data should remain unchanged
assert (
loaded_data["cache_entry_2"]["return"] == "Clean result"
), "Clean data should remain"
finally:
os.unlink(temp_file)
def test_empty_values_after_sanitization(self):
"""Test that data with empty values after sanitization is properly handled
Critical edge case: When sanitization results in data with empty string values,
we must use 'if cleaned_data is not None' instead of 'if cleaned_data' to ensure
a proper reload, since a truthiness check on a dict depends on its content, not just its existence.
"""
# Create data where ALL values are only surrogate characters
all_dirty_data = {
"key1": "\ud800\udc00\ud801",
"key2": "\ud802\ud803",
}
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
temp_file = f.name
try:
# Write dirty data - should trigger sanitization
needs_reload = write_json(all_dirty_data, temp_file)
assert needs_reload, "All-dirty data should trigger sanitization"
# Load the sanitized data
cleaned_data = load_json(temp_file)
# Critical assertions for the edge case
assert cleaned_data is not None, "Cleaned data should not be None"
# Sanitization removes surrogates but preserves keys with empty values
assert cleaned_data == {
"key1": "",
"key2": "",
}, "Surrogates should be removed, keys preserved"
# This dict is truthy because it has keys (even with empty values)
assert cleaned_data, "Dict with keys is truthy"
# Test the actual edge case: empty dict
empty_data = {}
needs_reload2 = write_json(empty_data, temp_file)
assert not needs_reload2, "Empty dict is clean"
reloaded_empty = load_json(temp_file)
assert reloaded_empty is not None, "Empty dict should not be None"
assert reloaded_empty == {}, "Empty dict should remain empty"
assert (
not reloaded_empty
), "Empty dict evaluates to False (the critical check)"
finally:
os.unlink(temp_file)
if __name__ == "__main__":
# Run tests
test = TestWriteJsonOptimization()
print("Running test_fast_path_clean_data...")
test.test_fast_path_clean_data()
print("✓ Passed")
print("Running test_slow_path_dirty_data...")
test.test_slow_path_dirty_data()
print("✓ Passed")
print("Running test_sanitizing_encoder_removes_surrogates...")
test.test_sanitizing_encoder_removes_surrogates()
print("✓ Passed")
print("Running test_nested_structure_sanitization...")
test.test_nested_structure_sanitization()
print("✓ Passed")
print("Running test_unicode_non_characters_removed...")
test.test_unicode_non_characters_removed()
print("✓ Passed")
print("Running test_mixed_clean_dirty_data...")
test.test_mixed_clean_dirty_data()
print("✓ Passed")
print("Running test_empty_and_none_strings...")
test.test_empty_and_none_strings()
print("✓ Passed")
print("Running test_specific_surrogate_udc9a...")
test.test_specific_surrogate_udc9a()
print("✓ Passed")
print("Running test_migration_with_surrogate_sanitization...")
test.test_migration_with_surrogate_sanitization()
print("✓ Passed")
print("Running test_empty_values_after_sanitization...")
test.test_empty_values_after_sanitization()
print("✓ Passed")
print("\n✅ All tests passed!")

uv.lock generated

File diff suppressed because it is too large