Merge branch 'HKUDS:main' into main
This commit is contained in:
commit
43a9b307bc
32 changed files with 3698 additions and 817 deletions
|
|
@ -23,10 +23,11 @@ LightRAG uses dynamic package installation (`pipmaster`) for optional features b
|
|||
|
||||
LightRAG dynamically installs packages for:
|
||||
|
||||
- **Document Processing**: `docling`, `pypdf2`, `python-docx`, `python-pptx`, `openpyxl`
|
||||
- **Storage Backends**: `redis`, `neo4j`, `pymilvus`, `pymongo`, `asyncpg`, `qdrant-client`
|
||||
- **LLM Providers**: `openai`, `anthropic`, `ollama`, `zhipuai`, `aioboto3`, `voyageai`, `llama-index`, `lmdeploy`, `transformers`, `torch`
|
||||
- Tiktoken Models**: BPE encoding models downloaded from OpenAI CDN
|
||||
- **Tiktoken Models**: BPE encoding models downloaded from OpenAI CDN
|
||||
|
||||
**Note**: Document processing dependencies (`pypdf`, `python-docx`, `python-pptx`, `openpyxl`) are now pre-installed with the `api` extras group and no longer require dynamic installation.
|
||||
|
||||
## Quick Start
|
||||
|
||||
|
|
@ -75,32 +76,31 @@ LightRAG provides flexible dependency groups for different use cases:
|
|||
|
||||
| Group | Description | Use Case |
|
||||
|-------|-------------|----------|
|
||||
| `offline-docs` | Document processing | PDF, DOCX, PPTX, XLSX files |
|
||||
| `api` | API server + document processing | FastAPI server with PDF, DOCX, PPTX, XLSX support |
|
||||
| `offline-storage` | Storage backends | Redis, Neo4j, MongoDB, PostgreSQL, etc. |
|
||||
| `offline-llm` | LLM providers | OpenAI, Anthropic, Ollama, etc. |
|
||||
| `offline` | All of the above | Complete offline deployment |
|
||||
| `offline` | Complete offline package | API + Storage + LLM (all features) |
|
||||
|
||||
**Note**: Document processing (PDF, DOCX, PPTX, XLSX) is included in the `api` extras group. The previous `offline-docs` group has been merged into `api` for better integration.
|
||||
|
||||
> Software packages requiring `transformers`, `torch`, or `cuda` will not be included in the offline dependency group.
|
||||
|
||||
### Installation Examples
|
||||
|
||||
```bash
|
||||
# Install only document processing dependencies
|
||||
pip install lightrag-hku[offline-docs]
|
||||
# Install API with document processing
|
||||
pip install lightrag-hku[api]
|
||||
|
||||
# Install document processing and storage backends
|
||||
pip install lightrag-hku[offline-docs,offline-storage]
|
||||
# Install API and storage backends
|
||||
pip install lightrag-hku[api,offline-storage]
|
||||
|
||||
# Install all offline dependencies
|
||||
# Install all offline dependencies (recommended for offline deployment)
|
||||
pip install lightrag-hku[offline]
|
||||
```
|
||||
|
||||
### Using Individual Requirements Files
|
||||
|
||||
```bash
|
||||
# Document processing only
|
||||
pip install -r requirements-offline-docs.txt
|
||||
|
||||
# Storage backends only
|
||||
pip install -r requirements-offline-storage.txt
|
||||
|
||||
|
|
@ -244,8 +244,8 @@ ls -la ~/.tiktoken_cache/
|
|||
**Solution**:
|
||||
```bash
|
||||
# Pre-install the specific package you need
|
||||
# For document processing:
|
||||
pip install lightrag-hku[offline-docs]
|
||||
# For API with document processing:
|
||||
pip install lightrag-hku[api]
|
||||
|
||||
# For storage backends:
|
||||
pip install lightrag-hku[offline-storage]
|
||||
|
|
@ -297,9 +297,9 @@ mkdir -p ~/my_tiktoken_cache
|
|||
|
||||
5. **Minimal Installation**: Only install what you need:
|
||||
```bash
|
||||
# If you only process PDFs with OpenAI
|
||||
pip install lightrag-hku[offline-docs]
|
||||
# Then manually add: pip install openai
|
||||
# If you only need API with document processing
|
||||
pip install lightrag-hku[api]
|
||||
# Then manually add specific LLM: pip install openai
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
|
|
|||
40
env.example
40
env.example
|
|
@ -29,7 +29,7 @@ WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
|
|||
# OLLAMA_EMULATING_MODEL_NAME=lightrag
|
||||
OLLAMA_EMULATING_MODEL_TAG=latest
|
||||
|
||||
### Max nodes return from graph retrieval in webui
|
||||
### Max nodes for graph retrieval (Ensure WebUI local settings are also updated, which is limited to this value)
|
||||
# MAX_GRAPH_NODES=1000
|
||||
|
||||
### Logging level
|
||||
|
|
@ -255,21 +255,23 @@ OLLAMA_LLM_NUM_CTX=32768
|
|||
### For OpenAI: Set to 'true' to enable dynamic dimension adjustment
|
||||
### For OpenAI: Set to 'false' (default) to disable sending dimension parameter
|
||||
### Note: Automatically ignored for backends that don't support dimension parameter (e.g., Ollama)
|
||||
# EMBEDDING_SEND_DIM=false
|
||||
|
||||
EMBEDDING_BINDING=ollama
|
||||
EMBEDDING_MODEL=bge-m3:latest
|
||||
EMBEDDING_DIM=1024
|
||||
EMBEDDING_BINDING_API_KEY=your_api_key
|
||||
# If LightRAG deployed in Docker uses host.docker.internal instead of localhost
|
||||
EMBEDDING_BINDING_HOST=http://localhost:11434
|
||||
|
||||
### OpenAI compatible (VoyageAI embedding openai compatible)
|
||||
# EMBEDDING_BINDING=openai
|
||||
# EMBEDDING_MODEL=text-embedding-3-large
|
||||
# EMBEDDING_DIM=3072
|
||||
# EMBEDDING_BINDING_HOST=https://api.openai.com/v1
|
||||
# Ollama embedding
|
||||
# EMBEDDING_BINDING=ollama
|
||||
# EMBEDDING_MODEL=bge-m3:latest
|
||||
# EMBEDDING_DIM=1024
|
||||
# EMBEDDING_BINDING_API_KEY=your_api_key
|
||||
### If LightRAG deployed in Docker uses host.docker.internal instead of localhost
|
||||
# EMBEDDING_BINDING_HOST=http://localhost:11434
|
||||
|
||||
### OpenAI compatible embedding
|
||||
EMBEDDING_BINDING=openai
|
||||
EMBEDDING_MODEL=text-embedding-3-large
|
||||
EMBEDDING_DIM=3072
|
||||
EMBEDDING_SEND_DIM=false
|
||||
EMBEDDING_TOKEN_LIMIT=8192
|
||||
EMBEDDING_BINDING_HOST=https://api.openai.com/v1
|
||||
EMBEDDING_BINDING_API_KEY=your_api_key
|
||||
|
||||
### Optional for Azure
|
||||
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
|
||||
|
|
@ -277,6 +279,16 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
|
|||
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
|
||||
# AZURE_EMBEDDING_API_KEY=your_api_key
|
||||
|
||||
### Gemini embedding
|
||||
# EMBEDDING_BINDING=gemini
|
||||
# EMBEDDING_MODEL=gemini-embedding-001
|
||||
# EMBEDDING_DIM=1536
|
||||
# EMBEDDING_TOKEN_LIMIT=2048
|
||||
# EMBEDDING_BINDING_HOST=https://generativelanguage.googleapis.com
|
||||
# EMBEDDING_BINDING_API_KEY=your_api_key
|
||||
### Gemini embedding requires sending dimension to server
|
||||
# EMBEDDING_SEND_DIM=true
|
||||
|
||||
### Jina AI Embedding
|
||||
# EMBEDDING_BINDING=jina
|
||||
# EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings
|
||||
|
|
|
|||
|
|
@ -53,8 +53,8 @@ def xml_to_json(xml_file):
|
|||
"description": edge.find("./data[@key='d6']", namespace).text
|
||||
if edge.find("./data[@key='d6']", namespace) is not None
|
||||
else "",
|
||||
"keywords": edge.find("./data[@key='d7']", namespace).text
|
||||
if edge.find("./data[@key='d7']", namespace) is not None
|
||||
"keywords": edge.find("./data[@key='d9']", namespace).text
|
||||
if edge.find("./data[@key='d9']", namespace) is not None
|
||||
else "",
|
||||
"source_id": edge.find("./data[@key='d8']", namespace).text
|
||||
if edge.find("./data[@key='d8']", namespace) is not None
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
__api_version__ = "0253"
|
||||
__api_version__ = "0254"
|
||||
|
|
|
|||
|
|
@ -258,6 +258,14 @@ def parse_args() -> argparse.Namespace:
|
|||
help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
|
||||
)
|
||||
|
||||
# Document loading engine configuration
|
||||
parser.add_argument(
|
||||
"--docling",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable DOCLING document loading engine (default: from env or DEFAULT)",
|
||||
)
|
||||
|
||||
# Conditionally add binding options defined in binding_options module
|
||||
# This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
|
||||
# and corresponding environment variables (e.g., OLLAMA_EMBEDDING_NUM_CTX)
|
||||
|
|
@ -371,8 +379,13 @@ def parse_args() -> argparse.Namespace:
|
|||
)
|
||||
args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)
|
||||
|
||||
# Select Document loading tool (DOCLING, DEFAULT)
|
||||
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
|
||||
# Set document_loading_engine from --docling flag
|
||||
if args.docling:
|
||||
args.document_loading_engine = "DOCLING"
|
||||
else:
|
||||
args.document_loading_engine = get_env_value(
|
||||
"DOCUMENT_LOADING_ENGINE", "DEFAULT"
|
||||
)
|
||||
|
||||
# PDF decryption password
|
||||
args.pdf_decrypt_password = get_env_value("PDF_DECRYPT_PASSWORD", None)
|
||||
|
|
@ -432,6 +445,11 @@ def parse_args() -> argparse.Namespace:
|
|||
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
|
||||
)
|
||||
|
||||
# Embedding token limit configuration
|
||||
args.embedding_token_limit = get_env_value(
|
||||
"EMBEDDING_TOKEN_LIMIT", None, int, special_none=True
|
||||
)
|
||||
|
||||
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
|
||||
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
|
||||
|
||||
|
|
@ -449,4 +467,83 @@ def update_uvicorn_mode_config():
|
|||
)
|
||||
|
||||
|
||||
global_args = parse_args()
|
||||
# Global configuration with lazy initialization
|
||||
_global_args = None
|
||||
_initialized = False
|
||||
|
||||
|
||||
def initialize_config(args=None, force=False):
|
||||
"""Initialize global configuration
|
||||
|
||||
This function allows explicit initialization of the configuration,
|
||||
which is useful for programmatic usage, testing, or embedding LightRAG
|
||||
in other applications.
|
||||
|
||||
Args:
|
||||
args: Pre-parsed argparse.Namespace or None to parse from sys.argv
|
||||
force: Force re-initialization even if already initialized
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: The configured arguments
|
||||
|
||||
Example:
|
||||
# Use parsed command line arguments (default)
|
||||
initialize_config()
|
||||
|
||||
# Use custom configuration programmatically
|
||||
custom_args = argparse.Namespace(
|
||||
host='localhost',
|
||||
port=8080,
|
||||
working_dir='./custom_rag',
|
||||
# ... other config
|
||||
)
|
||||
initialize_config(custom_args)
|
||||
"""
|
||||
global _global_args, _initialized
|
||||
|
||||
if _initialized and not force:
|
||||
return _global_args
|
||||
|
||||
_global_args = args if args is not None else parse_args()
|
||||
_initialized = True
|
||||
return _global_args
|
||||
|
||||
|
||||
def get_config():
|
||||
"""Get global configuration, auto-initializing if needed
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: The configured arguments
|
||||
"""
|
||||
if not _initialized:
|
||||
initialize_config()
|
||||
return _global_args
|
||||
|
||||
|
||||
class _GlobalArgsProxy:
|
||||
"""Proxy object that auto-initializes configuration on first access
|
||||
|
||||
This maintains backward compatibility with existing code while
|
||||
allowing programmatic control over initialization timing.
|
||||
"""
|
||||
|
||||
def __getattr__(self, name):
|
||||
if not _initialized:
|
||||
initialize_config()
|
||||
return getattr(_global_args, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if not _initialized:
|
||||
initialize_config()
|
||||
setattr(_global_args, name, value)
|
||||
|
||||
def __repr__(self):
|
||||
if not _initialized:
|
||||
return "<GlobalArgsProxy: Not initialized>"
|
||||
return repr(_global_args)
|
||||
|
||||
|
||||
# Create proxy instance for backward compatibility
|
||||
# Existing code like `from config import global_args` continues to work
|
||||
# The proxy will auto-initialize on first attribute access
|
||||
global_args = _GlobalArgsProxy()
|
||||
|
|
|
|||
|
|
@ -618,33 +618,108 @@ def create_app(args):
|
|||
|
||||
def create_optimized_embedding_function(
|
||||
config_cache: LLMConfigCache, binding, model, host, api_key, args
|
||||
):
|
||||
) -> EmbeddingFunc:
|
||||
"""
|
||||
Create optimized embedding function with pre-processed configuration for applicable bindings.
|
||||
Uses lazy imports for all bindings and avoids repeated configuration parsing.
|
||||
Create optimized embedding function and return an EmbeddingFunc instance
|
||||
with proper max_token_size inheritance from provider defaults.
|
||||
|
||||
This function:
|
||||
1. Imports the provider embedding function
|
||||
2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
|
||||
3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
|
||||
4. Returns a properly configured EmbeddingFunc instance
|
||||
"""
|
||||
|
||||
# Step 1: Import provider function and extract default attributes
|
||||
provider_func = None
|
||||
provider_max_token_size = None
|
||||
provider_embedding_dim = None
|
||||
|
||||
try:
|
||||
if binding == "openai":
|
||||
from lightrag.llm.openai import openai_embed
|
||||
|
||||
provider_func = openai_embed
|
||||
elif binding == "ollama":
|
||||
from lightrag.llm.ollama import ollama_embed
|
||||
|
||||
provider_func = ollama_embed
|
||||
elif binding == "gemini":
|
||||
from lightrag.llm.gemini import gemini_embed
|
||||
|
||||
provider_func = gemini_embed
|
||||
elif binding == "jina":
|
||||
from lightrag.llm.jina import jina_embed
|
||||
|
||||
provider_func = jina_embed
|
||||
elif binding == "azure_openai":
|
||||
from lightrag.llm.azure_openai import azure_openai_embed
|
||||
|
||||
provider_func = azure_openai_embed
|
||||
elif binding == "aws_bedrock":
|
||||
from lightrag.llm.bedrock import bedrock_embed
|
||||
|
||||
provider_func = bedrock_embed
|
||||
elif binding == "lollms":
|
||||
from lightrag.llm.lollms import lollms_embed
|
||||
|
||||
provider_func = lollms_embed
|
||||
|
||||
# Extract attributes if provider is an EmbeddingFunc
|
||||
if provider_func and isinstance(provider_func, EmbeddingFunc):
|
||||
provider_max_token_size = provider_func.max_token_size
|
||||
provider_embedding_dim = provider_func.embedding_dim
|
||||
logger.debug(
|
||||
f"Extracted from {binding} provider: "
|
||||
f"max_token_size={provider_max_token_size}, "
|
||||
f"embedding_dim={provider_embedding_dim}"
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import provider function for {binding}: {e}")
|
||||
|
||||
# Step 2: Apply priority (user config > provider default)
|
||||
# For max_token_size: explicit env var > provider default > None
|
||||
final_max_token_size = args.embedding_token_limit or provider_max_token_size
|
||||
# For embedding_dim: user config (always has value) takes priority
|
||||
# Only use provider default if user config is explicitly None (which shouldn't happen)
|
||||
final_embedding_dim = (
|
||||
args.embedding_dim if args.embedding_dim else provider_embedding_dim
|
||||
)
|
||||
|
||||
# Step 3: Create optimized embedding function (calls underlying function directly)
|
||||
async def optimized_embedding_function(texts, embedding_dim=None):
|
||||
try:
|
||||
if binding == "lollms":
|
||||
from lightrag.llm.lollms import lollms_embed
|
||||
|
||||
return await lollms_embed(
|
||||
# Get real function, skip EmbeddingFunc wrapper if present
|
||||
actual_func = (
|
||||
lollms_embed.func
|
||||
if isinstance(lollms_embed, EmbeddingFunc)
|
||||
else lollms_embed
|
||||
)
|
||||
return await actual_func(
|
||||
texts, embed_model=model, host=host, api_key=api_key
|
||||
)
|
||||
elif binding == "ollama":
|
||||
from lightrag.llm.ollama import ollama_embed
|
||||
|
||||
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
|
||||
# Get real function, skip EmbeddingFunc wrapper if present
|
||||
actual_func = (
|
||||
ollama_embed.func
|
||||
if isinstance(ollama_embed, EmbeddingFunc)
|
||||
else ollama_embed
|
||||
)
|
||||
|
||||
# Use pre-processed configuration if available
|
||||
if config_cache.ollama_embedding_options is not None:
|
||||
ollama_options = config_cache.ollama_embedding_options
|
||||
else:
|
||||
# Fallback for cases where config cache wasn't initialized properly
|
||||
from lightrag.llm.binding_options import OllamaEmbeddingOptions
|
||||
|
||||
ollama_options = OllamaEmbeddingOptions.options_dict(args)
|
||||
|
||||
return await ollama_embed(
|
||||
return await actual_func(
|
||||
texts,
|
||||
embed_model=model,
|
||||
host=host,
|
||||
|
|
@ -654,15 +729,30 @@ def create_app(args):
|
|||
elif binding == "azure_openai":
|
||||
from lightrag.llm.azure_openai import azure_openai_embed
|
||||
|
||||
return await azure_openai_embed(texts, model=model, api_key=api_key)
|
||||
actual_func = (
|
||||
azure_openai_embed.func
|
||||
if isinstance(azure_openai_embed, EmbeddingFunc)
|
||||
else azure_openai_embed
|
||||
)
|
||||
return await actual_func(texts, model=model, api_key=api_key)
|
||||
elif binding == "aws_bedrock":
|
||||
from lightrag.llm.bedrock import bedrock_embed
|
||||
|
||||
return await bedrock_embed(texts, model=model)
|
||||
actual_func = (
|
||||
bedrock_embed.func
|
||||
if isinstance(bedrock_embed, EmbeddingFunc)
|
||||
else bedrock_embed
|
||||
)
|
||||
return await actual_func(texts, model=model)
|
||||
elif binding == "jina":
|
||||
from lightrag.llm.jina import jina_embed
|
||||
|
||||
return await jina_embed(
|
||||
actual_func = (
|
||||
jina_embed.func
|
||||
if isinstance(jina_embed, EmbeddingFunc)
|
||||
else jina_embed
|
||||
)
|
||||
return await actual_func(
|
||||
texts,
|
||||
embedding_dim=embedding_dim,
|
||||
base_url=host,
|
||||
|
|
@ -671,16 +761,21 @@ def create_app(args):
|
|||
elif binding == "gemini":
|
||||
from lightrag.llm.gemini import gemini_embed
|
||||
|
||||
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
|
||||
actual_func = (
|
||||
gemini_embed.func
|
||||
if isinstance(gemini_embed, EmbeddingFunc)
|
||||
else gemini_embed
|
||||
)
|
||||
|
||||
# Use pre-processed configuration if available
|
||||
if config_cache.gemini_embedding_options is not None:
|
||||
gemini_options = config_cache.gemini_embedding_options
|
||||
else:
|
||||
# Fallback for cases where config cache wasn't initialized properly
|
||||
from lightrag.llm.binding_options import GeminiEmbeddingOptions
|
||||
|
||||
gemini_options = GeminiEmbeddingOptions.options_dict(args)
|
||||
|
||||
return await gemini_embed(
|
||||
return await actual_func(
|
||||
texts,
|
||||
model=model,
|
||||
base_url=host,
|
||||
|
|
@ -691,7 +786,12 @@ def create_app(args):
|
|||
else: # openai and compatible
|
||||
from lightrag.llm.openai import openai_embed
|
||||
|
||||
return await openai_embed(
|
||||
actual_func = (
|
||||
openai_embed.func
|
||||
if isinstance(openai_embed, EmbeddingFunc)
|
||||
else openai_embed
|
||||
)
|
||||
return await actual_func(
|
||||
texts,
|
||||
model=model,
|
||||
base_url=host,
|
||||
|
|
@ -701,7 +801,21 @@ def create_app(args):
|
|||
except ImportError as e:
|
||||
raise Exception(f"Failed to import {binding} embedding: {e}")
|
||||
|
||||
return optimized_embedding_function
|
||||
# Step 4: Wrap in EmbeddingFunc and return
|
||||
embedding_func_instance = EmbeddingFunc(
|
||||
embedding_dim=final_embedding_dim,
|
||||
func=optimized_embedding_function,
|
||||
max_token_size=final_max_token_size,
|
||||
send_dimensions=False, # Will be set later based on binding requirements
|
||||
)
|
||||
|
||||
# Log final embedding configuration
|
||||
logger.info(
|
||||
f"Embedding config: binding={binding} model={model} "
|
||||
f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
|
||||
)
|
||||
|
||||
return embedding_func_instance
|
||||
|
||||
llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
|
||||
embedding_timeout = get_env_value(
|
||||
|
|
@ -735,25 +849,24 @@ def create_app(args):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
# Create embedding function with optimized configuration
|
||||
# Create embedding function with optimized configuration and max_token_size inheritance
|
||||
import inspect
|
||||
|
||||
# Create the optimized embedding function
|
||||
optimized_embedding_func = create_optimized_embedding_function(
|
||||
# Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
|
||||
embedding_func = create_optimized_embedding_function(
|
||||
config_cache=config_cache,
|
||||
binding=args.embedding_binding,
|
||||
model=args.embedding_model,
|
||||
host=args.embedding_binding_host,
|
||||
api_key=args.embedding_binding_api_key,
|
||||
args=args, # Pass args object for fallback option generation
|
||||
args=args,
|
||||
)
|
||||
|
||||
# Get embedding_send_dim from centralized configuration
|
||||
embedding_send_dim = args.embedding_send_dim
|
||||
|
||||
# Check if the function signature has embedding_dim parameter
|
||||
# Note: Since optimized_embedding_func is an async function, inspect its signature
|
||||
sig = inspect.signature(optimized_embedding_func)
|
||||
# Check if the underlying function signature has embedding_dim parameter
|
||||
sig = inspect.signature(embedding_func.func)
|
||||
has_embedding_dim_param = "embedding_dim" in sig.parameters
|
||||
|
||||
# Determine send_dimensions value based on binding type
|
||||
|
|
@ -771,18 +884,27 @@ def create_app(args):
|
|||
else:
|
||||
dimension_control = "by not hasparam"
|
||||
|
||||
# Set send_dimensions on the EmbeddingFunc instance
|
||||
embedding_func.send_dimensions = send_dimensions
|
||||
|
||||
logger.info(
|
||||
f"Send embedding dimension: {send_dimensions} {dimension_control} "
|
||||
f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
|
||||
f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
|
||||
f"binding={args.embedding_binding})"
|
||||
)
|
||||
|
||||
# Create EmbeddingFunc with send_dimensions attribute
|
||||
embedding_func = EmbeddingFunc(
|
||||
embedding_dim=args.embedding_dim,
|
||||
func=optimized_embedding_func,
|
||||
send_dimensions=send_dimensions,
|
||||
)
|
||||
# Log max_token_size source
|
||||
if embedding_func.max_token_size:
|
||||
source = (
|
||||
"env variable"
|
||||
if args.embedding_token_limit
|
||||
else f"{args.embedding_binding} provider default"
|
||||
)
|
||||
logger.info(
|
||||
f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
|
||||
)
|
||||
else:
|
||||
logger.info("Embedding max_token_size: not set (90% token warning disabled)")
|
||||
|
||||
# Configure rerank function based on args.rerank_bindingparameter
|
||||
rerank_model_func = None
|
||||
|
|
@ -1214,6 +1336,12 @@ def check_and_install_dependencies():
|
|||
|
||||
|
||||
def main():
|
||||
# Explicitly initialize configuration for clarity
|
||||
# (The proxy will auto-initialize anyway, but this makes intent clear)
|
||||
from .config import initialize_config
|
||||
|
||||
initialize_config()
|
||||
|
||||
# Check if running under Gunicorn
|
||||
if "GUNICORN_CMD_ARGS" in os.environ:
|
||||
# If started with Gunicorn, return directly as Gunicorn will call get_application
|
||||
|
|
|
|||
|
|
@ -3,14 +3,15 @@ This module contains all document-related routes for the LightRAG API.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
from functools import lru_cache
|
||||
from lightrag.utils import logger, get_pinyin_sort_key
|
||||
import aiofiles
|
||||
import shutil
|
||||
import traceback
|
||||
import pipmaster as pm
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Literal
|
||||
from io import BytesIO
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
|
|
@ -28,6 +29,24 @@ from lightrag.api.utils_api import get_combined_auth_dependency
|
|||
from ..config import global_args
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _is_docling_available() -> bool:
|
||||
"""Check if docling is available (cached check).
|
||||
|
||||
This function uses lru_cache to avoid repeated import attempts.
|
||||
The result is cached after the first call.
|
||||
|
||||
Returns:
|
||||
bool: True if docling is available, False otherwise
|
||||
"""
|
||||
try:
|
||||
import docling # noqa: F401 # type: ignore[import-not-found]
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
# Function to format datetime to ISO format string with timezone information
|
||||
def format_datetime(dt: Any) -> Optional[str]:
|
||||
"""Format datetime to ISO format string with timezone information
|
||||
|
|
@ -879,7 +898,6 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str
|
|||
Returns:
|
||||
str: Unique filename (may have numeric suffix added)
|
||||
"""
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
original_path = Path(original_name)
|
||||
|
|
@ -902,6 +920,122 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str
|
|||
return f"{base_name}_{timestamp}{extension}"
|
||||
|
||||
|
||||
# Document processing helper functions (synchronous)
|
||||
# These functions run in thread pool via asyncio.to_thread() to avoid blocking the event loop
|
||||
|
||||
|
||||
def _convert_with_docling(file_path: Path) -> str:
|
||||
"""Convert document using docling (synchronous).
|
||||
|
||||
Args:
|
||||
file_path: Path to the document file
|
||||
|
||||
Returns:
|
||||
str: Extracted markdown content
|
||||
"""
|
||||
from docling.document_converter import DocumentConverter # type: ignore
|
||||
|
||||
converter = DocumentConverter()
|
||||
result = converter.convert(file_path)
|
||||
return result.document.export_to_markdown()
|
||||
|
||||
|
||||
def _extract_pdf_pypdf(file_bytes: bytes, password: str = None) -> str:
|
||||
"""Extract PDF content using pypdf (synchronous).
|
||||
|
||||
Args:
|
||||
file_bytes: PDF file content as bytes
|
||||
password: Optional password for encrypted PDFs
|
||||
|
||||
Returns:
|
||||
str: Extracted text content
|
||||
|
||||
Raises:
|
||||
Exception: If PDF is encrypted and password is incorrect or missing
|
||||
"""
|
||||
from pypdf import PdfReader # type: ignore
|
||||
|
||||
pdf_file = BytesIO(file_bytes)
|
||||
reader = PdfReader(pdf_file)
|
||||
|
||||
# Check if PDF is encrypted
|
||||
if reader.is_encrypted:
|
||||
if not password:
|
||||
raise Exception("PDF is encrypted but no password provided")
|
||||
|
||||
decrypt_result = reader.decrypt(password)
|
||||
if decrypt_result == 0:
|
||||
raise Exception("Incorrect PDF password")
|
||||
|
||||
# Extract text from all pages
|
||||
content = ""
|
||||
for page in reader.pages:
|
||||
content += page.extract_text() + "\n"
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def _extract_docx(file_bytes: bytes) -> str:
|
||||
"""Extract DOCX content (synchronous).
|
||||
|
||||
Args:
|
||||
file_bytes: DOCX file content as bytes
|
||||
|
||||
Returns:
|
||||
str: Extracted text content
|
||||
"""
|
||||
from docx import Document # type: ignore
|
||||
|
||||
docx_file = BytesIO(file_bytes)
|
||||
doc = Document(docx_file)
|
||||
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
|
||||
|
||||
|
||||
def _extract_pptx(file_bytes: bytes) -> str:
|
||||
"""Extract PPTX content (synchronous).
|
||||
|
||||
Args:
|
||||
file_bytes: PPTX file content as bytes
|
||||
|
||||
Returns:
|
||||
str: Extracted text content
|
||||
"""
|
||||
from pptx import Presentation # type: ignore
|
||||
|
||||
pptx_file = BytesIO(file_bytes)
|
||||
prs = Presentation(pptx_file)
|
||||
content = ""
|
||||
for slide in prs.slides:
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, "text"):
|
||||
content += shape.text + "\n"
|
||||
return content
|
||||
|
||||
|
||||
def _extract_xlsx(file_bytes: bytes) -> str:
|
||||
"""Extract XLSX content (synchronous).
|
||||
|
||||
Args:
|
||||
file_bytes: XLSX file content as bytes
|
||||
|
||||
Returns:
|
||||
str: Extracted text content
|
||||
"""
|
||||
from openpyxl import load_workbook # type: ignore
|
||||
|
||||
xlsx_file = BytesIO(file_bytes)
|
||||
wb = load_workbook(xlsx_file)
|
||||
content = ""
|
||||
for sheet in wb:
|
||||
content += f"Sheet: {sheet.title}\n"
|
||||
for row in sheet.iter_rows(values_only=True):
|
||||
content += (
|
||||
"\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
|
||||
)
|
||||
content += "\n"
|
||||
return content
|
||||
|
||||
|
||||
async def pipeline_enqueue_file(
|
||||
rag: LightRAG, file_path: Path, track_id: str = None
|
||||
) -> tuple[bool, str]:
|
||||
|
|
@ -1072,87 +1206,28 @@ async def pipeline_enqueue_file(
|
|||
|
||||
case ".pdf":
|
||||
try:
|
||||
if global_args.document_loading_engine == "DOCLING":
|
||||
if not pm.is_installed("docling"): # type: ignore
|
||||
pm.install("docling")
|
||||
from docling.document_converter import DocumentConverter # type: ignore
|
||||
|
||||
converter = DocumentConverter()
|
||||
result = converter.convert(file_path)
|
||||
content = result.document.export_to_markdown()
|
||||
# Try DOCLING first if configured and available
|
||||
if (
|
||||
global_args.document_loading_engine == "DOCLING"
|
||||
and _is_docling_available()
|
||||
):
|
||||
content = await asyncio.to_thread(
|
||||
_convert_with_docling, file_path
|
||||
)
|
||||
else:
|
||||
if not pm.is_installed("pypdf"): # type: ignore
|
||||
pm.install("pypdf")
|
||||
if not pm.is_installed("pycryptodome"): # type: ignore
|
||||
pm.install("pycryptodome")
|
||||
from pypdf import PdfReader # type: ignore
|
||||
from io import BytesIO
|
||||
|
||||
pdf_file = BytesIO(file)
|
||||
reader = PdfReader(pdf_file)
|
||||
|
||||
# Check if PDF is encrypted
|
||||
if reader.is_encrypted:
|
||||
pdf_password = global_args.pdf_decrypt_password
|
||||
if not pdf_password:
|
||||
# PDF is encrypted but no password provided
|
||||
error_files = [
|
||||
{
|
||||
"file_path": str(file_path.name),
|
||||
"error_description": "[File Extraction]PDF is encrypted but no password provided",
|
||||
"original_error": "Please set PDF_DECRYPT_PASSWORD environment variable to decrypt this PDF file",
|
||||
"file_size": file_size,
|
||||
}
|
||||
]
|
||||
await rag.apipeline_enqueue_error_documents(
|
||||
error_files, track_id
|
||||
)
|
||||
logger.error(
|
||||
f"[File Extraction]PDF is encrypted but no password provided: {file_path.name}"
|
||||
)
|
||||
return False, track_id
|
||||
|
||||
# Try to decrypt with password
|
||||
try:
|
||||
decrypt_result = reader.decrypt(pdf_password)
|
||||
if decrypt_result == 0:
|
||||
# Password is incorrect
|
||||
error_files = [
|
||||
{
|
||||
"file_path": str(file_path.name),
|
||||
"error_description": "[File Extraction]Failed to decrypt PDF - incorrect password",
|
||||
"original_error": "The provided PDF_DECRYPT_PASSWORD is incorrect for this file",
|
||||
"file_size": file_size,
|
||||
}
|
||||
]
|
||||
await rag.apipeline_enqueue_error_documents(
|
||||
error_files, track_id
|
||||
)
|
||||
logger.error(
|
||||
f"[File Extraction]Incorrect PDF password: {file_path.name}"
|
||||
)
|
||||
return False, track_id
|
||||
except Exception as decrypt_error:
|
||||
# Decryption process error
|
||||
error_files = [
|
||||
{
|
||||
"file_path": str(file_path.name),
|
||||
"error_description": "[File Extraction]PDF decryption failed",
|
||||
"original_error": f"Error during PDF decryption: {str(decrypt_error)}",
|
||||
"file_size": file_size,
|
||||
}
|
||||
]
|
||||
await rag.apipeline_enqueue_error_documents(
|
||||
error_files, track_id
|
||||
)
|
||||
logger.error(
|
||||
f"[File Extraction]PDF decryption error for {file_path.name}: {str(decrypt_error)}"
|
||||
)
|
||||
return False, track_id
|
||||
|
||||
# Extract text from PDF (encrypted PDFs are now decrypted, unencrypted PDFs proceed directly)
|
||||
for page in reader.pages:
|
||||
content += page.extract_text() + "\n"
|
||||
if (
|
||||
global_args.document_loading_engine == "DOCLING"
|
||||
and not _is_docling_available()
|
||||
):
|
||||
logger.warning(
|
||||
f"DOCLING engine configured but not available for {file_path.name}. Falling back to pypdf."
|
||||
)
|
||||
# Use pypdf (non-blocking via to_thread)
|
||||
content = await asyncio.to_thread(
|
||||
_extract_pdf_pypdf,
|
||||
file,
|
||||
global_args.pdf_decrypt_password,
|
||||
)
|
||||
except Exception as e:
|
||||
error_files = [
|
||||
{
|
||||
|
|
@ -1172,28 +1247,24 @@ async def pipeline_enqueue_file(
|
|||
|
||||
case ".docx":
|
||||
try:
|
||||
if global_args.document_loading_engine == "DOCLING":
|
||||
if not pm.is_installed("docling"): # type: ignore
|
||||
pm.install("docling")
|
||||
from docling.document_converter import DocumentConverter # type: ignore
|
||||
|
||||
converter = DocumentConverter()
|
||||
result = converter.convert(file_path)
|
||||
content = result.document.export_to_markdown()
|
||||
else:
|
||||
if not pm.is_installed("python-docx"): # type: ignore
|
||||
try:
|
||||
pm.install("python-docx")
|
||||
except Exception:
|
||||
pm.install("docx")
|
||||
from docx import Document # type: ignore
|
||||
from io import BytesIO
|
||||
|
||||
docx_file = BytesIO(file)
|
||||
doc = Document(docx_file)
|
||||
content = "\n".join(
|
||||
[paragraph.text for paragraph in doc.paragraphs]
|
||||
# Try DOCLING first if configured and available
|
||||
if (
|
||||
global_args.document_loading_engine == "DOCLING"
|
||||
and _is_docling_available()
|
||||
):
|
||||
content = await asyncio.to_thread(
|
||||
_convert_with_docling, file_path
|
||||
)
|
||||
else:
|
||||
if (
|
||||
global_args.document_loading_engine == "DOCLING"
|
||||
and not _is_docling_available()
|
||||
):
|
||||
logger.warning(
|
||||
f"DOCLING engine configured but not available for {file_path.name}. Falling back to python-docx."
|
||||
)
|
||||
# Use python-docx (non-blocking via to_thread)
|
||||
content = await asyncio.to_thread(_extract_docx, file)
|
||||
except Exception as e:
|
||||
error_files = [
|
||||
{
|
||||
|
|
@ -1213,26 +1284,24 @@ async def pipeline_enqueue_file(
|
|||
|
||||
case ".pptx":
|
||||
try:
|
||||
if global_args.document_loading_engine == "DOCLING":
|
||||
if not pm.is_installed("docling"): # type: ignore
|
||||
pm.install("docling")
|
||||
from docling.document_converter import DocumentConverter # type: ignore
|
||||
|
||||
converter = DocumentConverter()
|
||||
result = converter.convert(file_path)
|
||||
content = result.document.export_to_markdown()
|
||||
# Try DOCLING first if configured and available
|
||||
if (
|
||||
global_args.document_loading_engine == "DOCLING"
|
||||
and _is_docling_available()
|
||||
):
|
||||
content = await asyncio.to_thread(
|
||||
_convert_with_docling, file_path
|
||||
)
|
||||
else:
|
||||
if not pm.is_installed("python-pptx"): # type: ignore
|
||||
pm.install("pptx")
|
||||
from pptx import Presentation # type: ignore
|
||||
from io import BytesIO
|
||||
|
||||
pptx_file = BytesIO(file)
|
||||
prs = Presentation(pptx_file)
|
||||
for slide in prs.slides:
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, "text"):
|
||||
content += shape.text + "\n"
|
||||
if (
|
||||
global_args.document_loading_engine == "DOCLING"
|
||||
and not _is_docling_available()
|
||||
):
|
||||
logger.warning(
|
||||
f"DOCLING engine configured but not available for {file_path.name}. Falling back to python-pptx."
|
||||
)
|
||||
# Use python-pptx (non-blocking via to_thread)
|
||||
content = await asyncio.to_thread(_extract_pptx, file)
|
||||
except Exception as e:
|
||||
error_files = [
|
||||
{
|
||||
|
|
@ -1252,33 +1321,24 @@ async def pipeline_enqueue_file(
|
|||
|
||||
case ".xlsx":
|
||||
try:
|
||||
if global_args.document_loading_engine == "DOCLING":
|
||||
if not pm.is_installed("docling"): # type: ignore
|
||||
pm.install("docling")
|
||||
from docling.document_converter import DocumentConverter # type: ignore
|
||||
|
||||
converter = DocumentConverter()
|
||||
result = converter.convert(file_path)
|
||||
content = result.document.export_to_markdown()
|
||||
# Try DOCLING first if configured and available
|
||||
if (
|
||||
global_args.document_loading_engine == "DOCLING"
|
||||
and _is_docling_available()
|
||||
):
|
||||
content = await asyncio.to_thread(
|
||||
_convert_with_docling, file_path
|
||||
)
|
||||
else:
|
||||
if not pm.is_installed("openpyxl"): # type: ignore
|
||||
pm.install("openpyxl")
|
||||
from openpyxl import load_workbook # type: ignore
|
||||
from io import BytesIO
|
||||
|
||||
xlsx_file = BytesIO(file)
|
||||
wb = load_workbook(xlsx_file)
|
||||
for sheet in wb:
|
||||
content += f"Sheet: {sheet.title}\n"
|
||||
for row in sheet.iter_rows(values_only=True):
|
||||
content += (
|
||||
"\t".join(
|
||||
str(cell) if cell is not None else ""
|
||||
for cell in row
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
content += "\n"
|
||||
if (
|
||||
global_args.document_loading_engine == "DOCLING"
|
||||
and not _is_docling_available()
|
||||
):
|
||||
logger.warning(
|
||||
f"DOCLING engine configured but not available for {file_path.name}. Falling back to openpyxl."
|
||||
)
|
||||
# Use openpyxl (non-blocking via to_thread)
|
||||
content = await asyncio.to_thread(_extract_xlsx, file)
|
||||
except Exception as e:
|
||||
error_files = [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Start LightRAG server with Gunicorn
|
|||
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
import pipmaster as pm
|
||||
from lightrag.api.utils_api import display_splash_screen, check_env_file
|
||||
from lightrag.api.config import global_args
|
||||
|
|
@ -34,6 +35,11 @@ def check_and_install_dependencies():
|
|||
|
||||
|
||||
def main():
|
||||
# Explicitly initialize configuration for Gunicorn mode
|
||||
from lightrag.api.config import initialize_config
|
||||
|
||||
initialize_config()
|
||||
|
||||
# Set Gunicorn mode flag for lifespan cleanup detection
|
||||
os.environ["LIGHTRAG_GUNICORN_MODE"] = "1"
|
||||
|
||||
|
|
@ -41,6 +47,68 @@ def main():
|
|||
if not check_env_file():
|
||||
sys.exit(1)
|
||||
|
||||
# Check DOCLING compatibility with Gunicorn multi-worker mode on macOS
|
||||
if (
|
||||
platform.system() == "Darwin"
|
||||
and global_args.document_loading_engine == "DOCLING"
|
||||
and global_args.workers > 1
|
||||
):
|
||||
print("\n" + "=" * 80)
|
||||
print("❌ ERROR: Incompatible configuration detected!")
|
||||
print("=" * 80)
|
||||
print(
|
||||
"\nDOCLING engine with Gunicorn multi-worker mode is not supported on macOS"
|
||||
)
|
||||
print("\nReason:")
|
||||
print(" PyTorch (required by DOCLING) has known compatibility issues with")
|
||||
print(" fork-based multiprocessing on macOS, which can cause crashes or")
|
||||
print(" unexpected behavior when using Gunicorn with multiple workers.")
|
||||
print("\nCurrent configuration:")
|
||||
print(" - Operating System: macOS (Darwin)")
|
||||
print(f" - Document Engine: {global_args.document_loading_engine}")
|
||||
print(f" - Workers: {global_args.workers}")
|
||||
print("\nPossible solutions:")
|
||||
print(" 1. Use single worker mode:")
|
||||
print(" --workers 1")
|
||||
print("\n 2. Change document loading engine in .env:")
|
||||
print(" DOCUMENT_LOADING_ENGINE=DEFAULT")
|
||||
print("\n 3. Deploy on Linux where multi-worker mode is fully supported")
|
||||
print("=" * 80 + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
# Check macOS fork safety environment variable for multi-worker mode
|
||||
if (
|
||||
platform.system() == "Darwin"
|
||||
and global_args.workers > 1
|
||||
and os.environ.get("OBJC_DISABLE_INITIALIZE_FORK_SAFETY") != "YES"
|
||||
):
|
||||
print("\n" + "=" * 80)
|
||||
print("❌ ERROR: Missing required environment variable on macOS!")
|
||||
print("=" * 80)
|
||||
print("\nmacOS with Gunicorn multi-worker mode requires:")
|
||||
print(" OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES")
|
||||
print("\nReason:")
|
||||
print(" NumPy uses macOS's Accelerate framework (Objective-C based) for")
|
||||
print(" vector computations. The Objective-C runtime has fork safety checks")
|
||||
print(" that will crash worker processes when embedding functions are called.")
|
||||
print("\nCurrent configuration:")
|
||||
print(" - Operating System: macOS (Darwin)")
|
||||
print(f" - Workers: {global_args.workers}")
|
||||
print(
|
||||
f" - Environment Variable: {os.environ.get('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'NOT SET')}"
|
||||
)
|
||||
print("\nHow to fix:")
|
||||
print(" Option 1 - Set environment variable before starting (recommended):")
|
||||
print(" export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES")
|
||||
print(" lightrag-server")
|
||||
print("\n Option 2 - Add to your shell profile (~/.zshrc or ~/.bash_profile):")
|
||||
print(" echo 'export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES' >> ~/.zshrc")
|
||||
print(" source ~/.zshrc")
|
||||
print("\n Option 3 - Use single worker mode (no multiprocessing):")
|
||||
print(" lightrag-server --workers 1")
|
||||
print("=" * 80 + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
# Check and install dependencies
|
||||
check_and_install_dependencies()
|
||||
|
||||
|
|
|
|||
|
|
@ -161,7 +161,20 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|||
logger.debug(
|
||||
f"[{self.workspace}] Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
|
||||
)
|
||||
write_json(data_dict, self._file_name)
|
||||
|
||||
# Write JSON and check if sanitization was applied
|
||||
needs_reload = write_json(data_dict, self._file_name)
|
||||
|
||||
# If data was sanitized, reload cleaned data to update shared memory
|
||||
if needs_reload:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
|
||||
)
|
||||
cleaned_data = load_json(self._file_name)
|
||||
if cleaned_data is not None:
|
||||
self._data.clear()
|
||||
self._data.update(cleaned_data)
|
||||
|
||||
await clear_all_update_flags(self.final_namespace)
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
|
|
|
|||
|
|
@ -81,7 +81,20 @@ class JsonKVStorage(BaseKVStorage):
|
|||
logger.debug(
|
||||
f"[{self.workspace}] Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
|
||||
)
|
||||
write_json(data_dict, self._file_name)
|
||||
|
||||
# Write JSON and check if sanitization was applied
|
||||
needs_reload = write_json(data_dict, self._file_name)
|
||||
|
||||
# If data was sanitized, reload cleaned data to update shared memory
|
||||
if needs_reload:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
|
||||
)
|
||||
cleaned_data = load_json(self._file_name)
|
||||
if cleaned_data is not None:
|
||||
self._data.clear()
|
||||
self._data.update(cleaned_data)
|
||||
|
||||
await clear_all_update_flags(self.final_namespace)
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
|
|
@ -224,7 +237,7 @@ class JsonKVStorage(BaseKVStorage):
|
|||
data: Original data dictionary that may contain legacy structure
|
||||
|
||||
Returns:
|
||||
Migrated data dictionary with flattened cache keys
|
||||
Migrated data dictionary with flattened cache keys (sanitized if needed)
|
||||
"""
|
||||
from lightrag.utils import generate_cache_key
|
||||
|
||||
|
|
@ -261,8 +274,17 @@ class JsonKVStorage(BaseKVStorage):
|
|||
logger.info(
|
||||
f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure"
|
||||
)
|
||||
# Persist migrated data immediately
|
||||
write_json(migrated_data, self._file_name)
|
||||
# Persist migrated data immediately and check if sanitization was applied
|
||||
needs_reload = write_json(migrated_data, self._file_name)
|
||||
|
||||
# If data was sanitized during write, reload cleaned data
|
||||
if needs_reload:
|
||||
logger.info(
|
||||
f"[{self.workspace}] Reloading sanitized migration data for {self.namespace}"
|
||||
)
|
||||
cleaned_data = load_json(self._file_name)
|
||||
if cleaned_data is not None:
|
||||
return cleaned_data # Return cleaned data to update shared memory
|
||||
|
||||
return migrated_data
|
||||
|
||||
|
|
|
|||
|
|
@ -133,6 +133,7 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
result = None
|
||||
try:
|
||||
workspace_label = self._get_workspace_label()
|
||||
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
|
||||
|
|
@ -146,7 +147,10 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
logger.error(
|
||||
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
|
||||
)
|
||||
await result.consume() # Ensure the result is consumed even on error
|
||||
if result is not None:
|
||||
await (
|
||||
result.consume()
|
||||
) # Ensure the result is consumed even on error
|
||||
raise
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
|
|
@ -170,6 +174,7 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
result = None
|
||||
try:
|
||||
workspace_label = self._get_workspace_label()
|
||||
query = (
|
||||
|
|
@ -190,7 +195,10 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
logger.error(
|
||||
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
||||
)
|
||||
await result.consume() # Ensure the result is consumed even on error
|
||||
if result is not None:
|
||||
await (
|
||||
result.consume()
|
||||
) # Ensure the result is consumed even on error
|
||||
raise
|
||||
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
|
|
@ -312,6 +320,7 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
result = None
|
||||
try:
|
||||
workspace_label = self._get_workspace_label()
|
||||
query = f"""
|
||||
|
|
@ -328,7 +337,10 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
return labels
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
|
||||
await result.consume() # Ensure the result is consumed even on error
|
||||
if result is not None:
|
||||
await (
|
||||
result.consume()
|
||||
) # Ensure the result is consumed even on error
|
||||
raise
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
|
|
@ -352,6 +364,7 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
results = None
|
||||
try:
|
||||
workspace_label = self._get_workspace_label()
|
||||
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
|
||||
|
|
@ -389,7 +402,10 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
logger.error(
|
||||
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
|
||||
)
|
||||
await results.consume() # Ensure results are consumed even on error
|
||||
if results is not None:
|
||||
await (
|
||||
results.consume()
|
||||
) # Ensure results are consumed even on error
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
@ -419,6 +435,7 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
result = None
|
||||
try:
|
||||
workspace_label = self._get_workspace_label()
|
||||
query = f"""
|
||||
|
|
@ -451,7 +468,10 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
logger.error(
|
||||
f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
|
||||
)
|
||||
await result.consume() # Ensure the result is consumed even on error
|
||||
if result is not None:
|
||||
await (
|
||||
result.consume()
|
||||
) # Ensure the result is consumed even on error
|
||||
raise
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
|
|
@ -1030,6 +1050,7 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||
)
|
||||
|
||||
result = None
|
||||
try:
|
||||
workspace_label = self._get_workspace_label()
|
||||
async with self._driver.session(
|
||||
|
|
@ -1056,6 +1077,8 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
return labels
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
|
||||
if result is not None:
|
||||
await result.consume()
|
||||
return []
|
||||
|
||||
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
||||
|
|
@ -1078,6 +1101,7 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
if not query_lower:
|
||||
return []
|
||||
|
||||
result = None
|
||||
try:
|
||||
workspace_label = self._get_workspace_label()
|
||||
async with self._driver.session(
|
||||
|
|
@ -1111,4 +1135,6 @@ class MemgraphStorage(BaseGraphStorage):
|
|||
return labels
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
|
||||
if result is not None:
|
||||
await result.consume()
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -371,6 +371,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
result = None
|
||||
try:
|
||||
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
|
||||
result = await session.run(query, entity_id=node_id)
|
||||
|
|
@ -381,7 +382,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
logger.error(
|
||||
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
|
||||
)
|
||||
await result.consume() # Ensure results are consumed even on error
|
||||
if result is not None:
|
||||
await result.consume() # Ensure results are consumed even on error
|
||||
raise
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
|
|
@ -403,6 +405,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
result = None
|
||||
try:
|
||||
query = (
|
||||
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
|
||||
|
|
@ -420,7 +423,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
logger.error(
|
||||
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
||||
)
|
||||
await result.consume() # Ensure results are consumed even on error
|
||||
if result is not None:
|
||||
await result.consume() # Ensure results are consumed even on error
|
||||
raise
|
||||
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
|
|
@ -799,6 +803,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
results = None
|
||||
try:
|
||||
workspace_label = self._get_workspace_label()
|
||||
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
|
||||
|
|
@ -836,7 +841,10 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
logger.error(
|
||||
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
|
||||
)
|
||||
await results.consume() # Ensure results are consumed even on error
|
||||
if results is not None:
|
||||
await (
|
||||
results.consume()
|
||||
) # Ensure results are consumed even on error
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
@ -1592,6 +1600,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
result = None
|
||||
try:
|
||||
query = f"""
|
||||
MATCH (n:`{workspace_label}`)
|
||||
|
|
@ -1616,7 +1625,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|||
logger.error(
|
||||
f"[{self.workspace}] Error getting popular labels: {str(e)}"
|
||||
)
|
||||
await result.consume()
|
||||
if result is not None:
|
||||
await result.consume()
|
||||
raise
|
||||
|
||||
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
import traceback
|
||||
import asyncio
|
||||
import configparser
|
||||
import inspect
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
|
|
@ -12,6 +13,7 @@ from functools import partial
|
|||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Iterator,
|
||||
cast,
|
||||
|
|
@ -20,6 +22,7 @@ from typing import (
|
|||
Optional,
|
||||
List,
|
||||
Dict,
|
||||
Union,
|
||||
)
|
||||
from lightrag.prompt import PROMPTS
|
||||
from lightrag.exceptions import PipelineCancelledException
|
||||
|
|
@ -243,11 +246,13 @@ class LightRAG:
|
|||
int,
|
||||
int,
|
||||
],
|
||||
List[Dict[str, Any]],
|
||||
Union[List[Dict[str, Any]], Awaitable[List[Dict[str, Any]]]],
|
||||
] = field(default_factory=lambda: chunking_by_token_size)
|
||||
"""
|
||||
Custom chunking function for splitting text into chunks before processing.
|
||||
|
||||
The function can be either synchronous or asynchronous.
|
||||
|
||||
The function should take the following parameters:
|
||||
|
||||
- `tokenizer`: A Tokenizer instance to use for tokenization.
|
||||
|
|
@ -257,7 +262,8 @@ class LightRAG:
|
|||
- `chunk_token_size`: The maximum number of tokens per chunk.
|
||||
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
|
||||
|
||||
The function should return a list of dictionaries, where each dictionary contains the following keys:
|
||||
The function should return a list of dictionaries (or an awaitable that resolves to a list),
|
||||
where each dictionary contains the following keys:
|
||||
- `tokens`: The number of tokens in the chunk.
|
||||
- `content`: The text content of the chunk.
|
||||
|
||||
|
|
@ -270,6 +276,9 @@ class LightRAG:
|
|||
embedding_func: EmbeddingFunc | None = field(default=None)
|
||||
"""Function for computing text embeddings. Must be set before use."""
|
||||
|
||||
embedding_token_limit: int | None = field(default=None, init=False)
|
||||
"""Token limit for embedding model. Set automatically from embedding_func.max_token_size in __post_init__."""
|
||||
|
||||
embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
|
||||
"""Batch size for embedding computations."""
|
||||
|
||||
|
|
@ -513,6 +522,16 @@ class LightRAG:
|
|||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||
|
||||
# Init Embedding
|
||||
# Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
|
||||
embedding_max_token_size = None
|
||||
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
|
||||
embedding_max_token_size = self.embedding_func.max_token_size
|
||||
logger.debug(
|
||||
f"Captured embedding max_token_size: {embedding_max_token_size}"
|
||||
)
|
||||
self.embedding_token_limit = embedding_max_token_size
|
||||
|
||||
# Step 2: Apply priority wrapper decorator
|
||||
self.embedding_func = priority_limit_async_func_call(
|
||||
self.embedding_func_max_async,
|
||||
llm_timeout=self.default_embedding_timeout,
|
||||
|
|
@ -1756,7 +1775,28 @@ class LightRAG:
|
|||
)
|
||||
content = content_data["content"]
|
||||
|
||||
# Generate chunks from document
|
||||
# Call chunking function, supporting both sync and async implementations
|
||||
chunking_result = self.chunking_func(
|
||||
self.tokenizer,
|
||||
content,
|
||||
split_by_character,
|
||||
split_by_character_only,
|
||||
self.chunk_overlap_token_size,
|
||||
self.chunk_token_size,
|
||||
)
|
||||
|
||||
# If result is awaitable, await to get actual result
|
||||
if inspect.isawaitable(chunking_result):
|
||||
chunking_result = await chunking_result
|
||||
|
||||
# Validate return type
|
||||
if not isinstance(chunking_result, (list, tuple)):
|
||||
raise TypeError(
|
||||
f"chunking_func must return a list or tuple of dicts, "
|
||||
f"got {type(chunking_result)}"
|
||||
)
|
||||
|
||||
# Build chunks dictionary
|
||||
chunks: dict[str, Any] = {
|
||||
compute_mdhash_id(dp["content"], prefix="chunk-"): {
|
||||
**dp,
|
||||
|
|
@ -1764,14 +1804,7 @@ class LightRAG:
|
|||
"file_path": file_path, # Add file path to each chunk
|
||||
"llm_cache_list": [], # Initialize empty LLM cache list for each chunk
|
||||
}
|
||||
for dp in self.chunking_func(
|
||||
self.tokenizer,
|
||||
content,
|
||||
split_by_character,
|
||||
split_by_character_only,
|
||||
self.chunk_overlap_token_size,
|
||||
self.chunk_token_size,
|
||||
)
|
||||
for dp in chunking_result
|
||||
}
|
||||
|
||||
if not chunks:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import copy
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
|
|
@ -16,6 +17,7 @@ from tenacity import (
|
|||
)
|
||||
|
||||
import sys
|
||||
from lightrag.utils import wrap_embedding_func_with_attrs
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
from typing import AsyncIterator
|
||||
|
|
@ -23,21 +25,121 @@ else:
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import Union
|
||||
|
||||
# Import botocore exceptions for proper exception handling
|
||||
try:
|
||||
from botocore.exceptions import (
|
||||
ClientError,
|
||||
ConnectionError as BotocoreConnectionError,
|
||||
ReadTimeoutError,
|
||||
)
|
||||
except ImportError:
|
||||
# If botocore is not installed, define placeholders
|
||||
ClientError = Exception
|
||||
BotocoreConnectionError = Exception
|
||||
ReadTimeoutError = Exception
|
||||
|
||||
|
||||
class BedrockError(Exception):
|
||||
"""Generic error for issues related to Amazon Bedrock"""
|
||||
|
||||
|
||||
class BedrockRateLimitError(BedrockError):
|
||||
"""Error for rate limiting and throttling issues"""
|
||||
|
||||
|
||||
class BedrockConnectionError(BedrockError):
|
||||
"""Error for network and connection issues"""
|
||||
|
||||
|
||||
class BedrockTimeoutError(BedrockError):
|
||||
"""Error for timeout issues"""
|
||||
|
||||
|
||||
def _set_env_if_present(key: str, value):
|
||||
"""Set environment variable only if a non-empty value is provided."""
|
||||
if value is not None and value != "":
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
def _handle_bedrock_exception(e: Exception, operation: str = "Bedrock API") -> None:
|
||||
"""Convert AWS Bedrock exceptions to appropriate custom exceptions.
|
||||
|
||||
Args:
|
||||
e: The exception to handle
|
||||
operation: Description of the operation for error messages
|
||||
|
||||
Raises:
|
||||
BedrockRateLimitError: For rate limiting and throttling issues (retryable)
|
||||
BedrockConnectionError: For network and server issues (retryable)
|
||||
BedrockTimeoutError: For timeout issues (retryable)
|
||||
BedrockError: For validation and other non-retryable errors
|
||||
"""
|
||||
error_message = str(e)
|
||||
|
||||
# Handle botocore ClientError with specific error codes
|
||||
if isinstance(e, ClientError):
|
||||
error_code = e.response.get("Error", {}).get("Code", "")
|
||||
error_msg = e.response.get("Error", {}).get("Message", error_message)
|
||||
|
||||
# Rate limiting and throttling errors (retryable)
|
||||
if error_code in [
|
||||
"ThrottlingException",
|
||||
"ProvisionedThroughputExceededException",
|
||||
]:
|
||||
logging.error(f"{operation} rate limit error: {error_msg}")
|
||||
raise BedrockRateLimitError(f"Rate limit error: {error_msg}")
|
||||
|
||||
# Server errors (retryable)
|
||||
elif error_code in ["ServiceUnavailableException", "InternalServerException"]:
|
||||
logging.error(f"{operation} connection error: {error_msg}")
|
||||
raise BedrockConnectionError(f"Service error: {error_msg}")
|
||||
|
||||
# Check for 5xx HTTP status codes (retryable)
|
||||
elif e.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0) >= 500:
|
||||
logging.error(f"{operation} server error: {error_msg}")
|
||||
raise BedrockConnectionError(f"Server error: {error_msg}")
|
||||
|
||||
# Validation and other client errors (non-retryable)
|
||||
else:
|
||||
logging.error(f"{operation} client error: {error_msg}")
|
||||
raise BedrockError(f"Client error: {error_msg}")
|
||||
|
||||
# Connection errors (retryable)
|
||||
elif isinstance(e, BotocoreConnectionError):
|
||||
logging.error(f"{operation} connection error: {error_message}")
|
||||
raise BedrockConnectionError(f"Connection error: {error_message}")
|
||||
|
||||
# Timeout errors (retryable)
|
||||
elif isinstance(e, (ReadTimeoutError, TimeoutError)):
|
||||
logging.error(f"{operation} timeout error: {error_message}")
|
||||
raise BedrockTimeoutError(f"Timeout error: {error_message}")
|
||||
|
||||
# Custom Bedrock errors (already properly typed)
|
||||
elif isinstance(
|
||||
e,
|
||||
(
|
||||
BedrockRateLimitError,
|
||||
BedrockConnectionError,
|
||||
BedrockTimeoutError,
|
||||
BedrockError,
|
||||
),
|
||||
):
|
||||
raise
|
||||
|
||||
# Unknown errors (non-retryable)
|
||||
else:
|
||||
logging.error(f"{operation} unexpected error: {error_message}")
|
||||
raise BedrockError(f"Unexpected error: {error_message}")
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, max=60),
|
||||
retry=retry_if_exception_type((BedrockError)),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=(
|
||||
retry_if_exception_type(BedrockRateLimitError)
|
||||
| retry_if_exception_type(BedrockConnectionError)
|
||||
| retry_if_exception_type(BedrockTimeoutError)
|
||||
),
|
||||
)
|
||||
async def bedrock_complete_if_cache(
|
||||
model,
|
||||
|
|
@ -158,9 +260,6 @@ async def bedrock_complete_if_cache(
|
|||
break
|
||||
|
||||
except Exception as e:
|
||||
# Log the specific error for debugging
|
||||
logging.error(f"Bedrock streaming error: {e}")
|
||||
|
||||
# Try to clean up resources if possible
|
||||
if (
|
||||
iteration_started
|
||||
|
|
@ -175,7 +274,8 @@ async def bedrock_complete_if_cache(
|
|||
f"Failed to close Bedrock event stream: {close_error}"
|
||||
)
|
||||
|
||||
raise BedrockError(f"Streaming error: {e}")
|
||||
# Convert to appropriate exception type
|
||||
_handle_bedrock_exception(e, "Bedrock streaming")
|
||||
|
||||
finally:
|
||||
# Clean up the event stream
|
||||
|
|
@ -231,10 +331,8 @@ async def bedrock_complete_if_cache(
|
|||
return content
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, BedrockError):
|
||||
raise
|
||||
else:
|
||||
raise BedrockError(f"Bedrock API error: {e}")
|
||||
# Convert to appropriate exception type
|
||||
_handle_bedrock_exception(e, "Bedrock converse")
|
||||
|
||||
|
||||
# Generic Bedrock completion function
|
||||
|
|
@ -253,12 +351,16 @@ async def bedrock_complete(
|
|||
return result
|
||||
|
||||
|
||||
# @wrap_embedding_func_with_attrs(embedding_dim=1024)
|
||||
# @retry(
|
||||
# stop=stop_after_attempt(3),
|
||||
# wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
|
||||
# )
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=(
|
||||
retry_if_exception_type(BedrockRateLimitError)
|
||||
| retry_if_exception_type(BedrockConnectionError)
|
||||
| retry_if_exception_type(BedrockTimeoutError)
|
||||
),
|
||||
)
|
||||
async def bedrock_embed(
|
||||
texts: list[str],
|
||||
model: str = "amazon.titan-embed-text-v2:0",
|
||||
|
|
@ -281,48 +383,101 @@ async def bedrock_embed(
|
|||
async with session.client(
|
||||
"bedrock-runtime", region_name=region
|
||||
) as bedrock_async_client:
|
||||
if (model_provider := model.split(".")[0]) == "amazon":
|
||||
embed_texts = []
|
||||
for text in texts:
|
||||
if "v2" in model:
|
||||
try:
|
||||
if (model_provider := model.split(".")[0]) == "amazon":
|
||||
embed_texts = []
|
||||
for text in texts:
|
||||
try:
|
||||
if "v2" in model:
|
||||
body = json.dumps(
|
||||
{
|
||||
"inputText": text,
|
||||
# 'dimensions': embedding_dim,
|
||||
"embeddingTypes": ["float"],
|
||||
}
|
||||
)
|
||||
elif "v1" in model:
|
||||
body = json.dumps({"inputText": text})
|
||||
else:
|
||||
raise BedrockError(f"Model {model} is not supported!")
|
||||
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
modelId=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = await response.get("body").json()
|
||||
|
||||
# Validate response structure
|
||||
if not response_body or "embedding" not in response_body:
|
||||
raise BedrockError(
|
||||
f"Invalid embedding response structure for text: {text[:50]}..."
|
||||
)
|
||||
|
||||
embedding = response_body["embedding"]
|
||||
if not embedding:
|
||||
raise BedrockError(
|
||||
f"Received empty embedding for text: {text[:50]}..."
|
||||
)
|
||||
|
||||
embed_texts.append(embedding)
|
||||
|
||||
except Exception as e:
|
||||
# Convert to appropriate exception type
|
||||
_handle_bedrock_exception(
|
||||
e, "Bedrock embedding (amazon, text chunk)"
|
||||
)
|
||||
|
||||
elif model_provider == "cohere":
|
||||
try:
|
||||
body = json.dumps(
|
||||
{
|
||||
"inputText": text,
|
||||
# 'dimensions': embedding_dim,
|
||||
"embeddingTypes": ["float"],
|
||||
"texts": texts,
|
||||
"input_type": "search_document",
|
||||
"truncate": "NONE",
|
||||
}
|
||||
)
|
||||
elif "v1" in model:
|
||||
body = json.dumps({"inputText": text})
|
||||
else:
|
||||
raise ValueError(f"Model {model} is not supported!")
|
||||
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
modelId=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
model=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = json.loads(response.get("body").read())
|
||||
|
||||
# Validate response structure
|
||||
if not response_body or "embeddings" not in response_body:
|
||||
raise BedrockError(
|
||||
"Invalid embedding response structure from Cohere"
|
||||
)
|
||||
|
||||
embeddings = response_body["embeddings"]
|
||||
if not embeddings or len(embeddings) != len(texts):
|
||||
raise BedrockError(
|
||||
f"Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}"
|
||||
)
|
||||
|
||||
embed_texts = embeddings
|
||||
|
||||
except Exception as e:
|
||||
# Convert to appropriate exception type
|
||||
_handle_bedrock_exception(e, "Bedrock embedding (cohere)")
|
||||
|
||||
else:
|
||||
raise BedrockError(
|
||||
f"Model provider '{model_provider}' is not supported!"
|
||||
)
|
||||
|
||||
response_body = await response.get("body").json()
|
||||
# Final validation
|
||||
if not embed_texts:
|
||||
raise BedrockError("No embeddings generated")
|
||||
|
||||
embed_texts.append(response_body["embedding"])
|
||||
elif model_provider == "cohere":
|
||||
body = json.dumps(
|
||||
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
|
||||
)
|
||||
return np.array(embed_texts)
|
||||
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
model=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = json.loads(response.get("body").read())
|
||||
|
||||
embed_texts = response_body["embeddings"]
|
||||
else:
|
||||
raise ValueError(f"Model provider '{model_provider}' is not supported!")
|
||||
|
||||
return np.array(embed_texts)
|
||||
except Exception as e:
|
||||
# Convert to appropriate exception type
|
||||
_handle_bedrock_exception(e, "Bedrock embedding")
|
||||
|
|
|
|||
|
|
@ -453,7 +453,7 @@ async def gemini_model_complete(
|
|||
)
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from lightrag.exceptions import (
|
|||
)
|
||||
import torch
|
||||
import numpy as np
|
||||
from lightrag.utils import wrap_embedding_func_with_attrs
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
|
@ -141,6 +142,7 @@ async def hf_model_complete(
|
|||
return result
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
||||
# Detect the appropriate device
|
||||
if torch.cuda.is_available():
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ async def fetch_data(url, headers, data):
|
|||
return data_list
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=2048)
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
|
|
|
|||
|
|
@ -174,7 +174,7 @@ async def llama_index_complete(
|
|||
return result
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
|
|
|
|||
|
|
@ -26,6 +26,10 @@ from lightrag.exceptions import (
|
|||
from typing import Union, List
|
||||
import numpy as np
|
||||
|
||||
from lightrag.utils import (
|
||||
wrap_embedding_func_with_attrs,
|
||||
)
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
|
|
@ -134,6 +138,7 @@ async def lollms_model_complete(
|
|||
)
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||
async def lollms_embed(
|
||||
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
|
||||
) -> np.ndarray:
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from lightrag.utils import (
|
|||
import numpy as np
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=2048)
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from collections.abc import AsyncIterator
|
||||
import os
|
||||
import re
|
||||
|
||||
import pipmaster as pm
|
||||
|
||||
|
|
@ -22,8 +24,31 @@ from lightrag.exceptions import (
|
|||
from lightrag.api import __api_version__
|
||||
|
||||
import numpy as np
|
||||
from typing import Union
|
||||
from lightrag.utils import logger
|
||||
from typing import Optional, Union
|
||||
from lightrag.utils import (
|
||||
wrap_embedding_func_with_attrs,
|
||||
logger,
|
||||
)
|
||||
|
||||
|
||||
_OLLAMA_CLOUD_HOST = "https://ollama.com"
|
||||
_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")
|
||||
|
||||
|
||||
def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
|
||||
if host:
|
||||
return host
|
||||
try:
|
||||
model_name_str = str(model) if model is not None else ""
|
||||
except (TypeError, ValueError, AttributeError) as e:
|
||||
logger.warning(f"Failed to convert model to string: {e}, using empty string")
|
||||
model_name_str = ""
|
||||
if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str):
|
||||
logger.debug(
|
||||
f"Detected cloud model '{model_name_str}', using Ollama Cloud host"
|
||||
)
|
||||
return _OLLAMA_CLOUD_HOST
|
||||
return host
|
||||
|
||||
|
||||
@retry(
|
||||
|
|
@ -53,6 +78,9 @@ async def _ollama_model_if_cache(
|
|||
timeout = None
|
||||
kwargs.pop("hashing_kv", None)
|
||||
api_key = kwargs.pop("api_key", None)
|
||||
# fallback to environment variable when not provided explicitly
|
||||
if not api_key:
|
||||
api_key = os.getenv("OLLAMA_API_KEY")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"LightRAG/{__api_version__}",
|
||||
|
|
@ -60,6 +88,8 @@ async def _ollama_model_if_cache(
|
|||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
host = _coerce_host_for_cloud_model(host, model)
|
||||
|
||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||
|
||||
try:
|
||||
|
|
@ -142,8 +172,11 @@ async def ollama_model_complete(
|
|||
)
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||
api_key = kwargs.pop("api_key", None)
|
||||
if not api_key:
|
||||
api_key = os.getenv("OLLAMA_API_KEY")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"LightRAG/{__api_version__}",
|
||||
|
|
@ -154,6 +187,8 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
|||
host = kwargs.pop("host", None)
|
||||
timeout = kwargs.pop("timeout", None)
|
||||
|
||||
host = _coerce_host_for_cloud_model(host, embed_model)
|
||||
|
||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||
try:
|
||||
options = kwargs.pop("options", {})
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ try:
|
|||
|
||||
# Only enable Langfuse if both keys are configured
|
||||
if langfuse_public_key and langfuse_secret_key:
|
||||
from langfuse.openai import AsyncOpenAI
|
||||
from langfuse.openai import AsyncOpenAI # type: ignore[import-untyped]
|
||||
|
||||
LANGFUSE_ENABLED = True
|
||||
logger.info("Langfuse observability enabled for OpenAI client")
|
||||
|
|
@ -604,7 +604,7 @@ async def nvidia_openai_complete(
|
|||
return result
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
|
|
|
|||
|
|
@ -345,6 +345,20 @@ async def _summarize_descriptions(
|
|||
llm_response_cache=llm_response_cache,
|
||||
cache_type="summary",
|
||||
)
|
||||
|
||||
# Check summary token length against embedding limit
|
||||
embedding_token_limit = global_config.get("embedding_token_limit")
|
||||
if embedding_token_limit is not None and summary:
|
||||
tokenizer = global_config["tokenizer"]
|
||||
summary_token_count = len(tokenizer.encode(summary))
|
||||
threshold = int(embedding_token_limit * 0.9)
|
||||
|
||||
if summary_token_count > threshold:
|
||||
logger.warning(
|
||||
f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
|
||||
f"({embedding_token_limit}) for {description_type}: {description_name}"
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,9 @@ if not logger.handlers:
|
|||
# Set httpx logging level to WARNING
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
|
||||
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
|
||||
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
|
||||
|
||||
# Global import for pypinyin with startup-time logging
|
||||
try:
|
||||
import pypinyin
|
||||
|
|
@ -350,9 +353,20 @@ class TaskState:
|
|||
|
||||
@dataclass
|
||||
class EmbeddingFunc:
|
||||
"""Embedding function wrapper with dimension validation
|
||||
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
|
||||
This class should not be wrapped multiple times.
|
||||
|
||||
Args:
|
||||
embedding_dim: Expected dimension of the embeddings
|
||||
func: The actual embedding function to wrap
|
||||
max_token_size: Optional token limit for the embedding model
|
||||
send_dimensions: Whether to inject embedding_dim as a keyword argument
|
||||
"""
|
||||
|
||||
embedding_dim: int
|
||||
func: callable
|
||||
max_token_size: int | None = None # deprecated keep it for compatible only
|
||||
max_token_size: int | None = None # Token limit for the embedding model
|
||||
send_dimensions: bool = (
|
||||
False # Control whether to send embedding_dim to the function
|
||||
)
|
||||
|
|
@ -376,7 +390,32 @@ class EmbeddingFunc:
|
|||
# Inject embedding_dim from decorator
|
||||
kwargs["embedding_dim"] = self.embedding_dim
|
||||
|
||||
return await self.func(*args, **kwargs)
|
||||
# Call the actual embedding function
|
||||
result = await self.func(*args, **kwargs)
|
||||
|
||||
# Validate embedding dimensions using total element count
|
||||
total_elements = result.size # Total number of elements in the numpy array
|
||||
expected_dim = self.embedding_dim
|
||||
|
||||
# Check if total elements can be evenly divided by embedding_dim
|
||||
if total_elements % expected_dim != 0:
|
||||
raise ValueError(
|
||||
f"Embedding dimension mismatch detected: "
|
||||
f"total elements ({total_elements}) cannot be evenly divided by "
|
||||
f"expected dimension ({expected_dim}). "
|
||||
)
|
||||
|
||||
# Optional: Verify vector count matches input text count
|
||||
actual_vectors = total_elements // expected_dim
|
||||
if args and isinstance(args[0], (list, tuple)):
|
||||
expected_vectors = len(args[0])
|
||||
if actual_vectors != expected_vectors:
|
||||
raise ValueError(
|
||||
f"Vector count mismatch: "
|
||||
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def compute_args_hash(*args: Any) -> str:
|
||||
|
|
@ -930,73 +969,120 @@ def load_json(file_name):
|
|||
def _sanitize_string_for_json(text: str) -> str:
|
||||
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization.
|
||||
|
||||
This is a simpler sanitizer specifically for JSON that directly removes
|
||||
problematic characters without attempting to encode first.
|
||||
Uses regex for optimal performance with zero-copy optimization for clean strings.
|
||||
Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
|
||||
|
||||
Args:
|
||||
text: String to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized string safe for UTF-8 encoding in JSON
|
||||
Original string if clean (zero-copy), sanitized string if dirty
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# Directly filter out problematic characters without pre-validation
|
||||
sanitized = ""
|
||||
for char in text:
|
||||
code_point = ord(char)
|
||||
# Skip surrogate characters (U+D800 to U+DFFF) - main cause of encoding errors
|
||||
if 0xD800 <= code_point <= 0xDFFF:
|
||||
continue
|
||||
# Skip other non-characters in Unicode
|
||||
elif code_point == 0xFFFE or code_point == 0xFFFF:
|
||||
continue
|
||||
else:
|
||||
sanitized += char
|
||||
# Fast path: Check if sanitization is needed using C-level regex search
|
||||
if not _SURROGATE_PATTERN.search(text):
|
||||
return text # Zero-copy for clean strings - most common case
|
||||
|
||||
return sanitized
|
||||
# Slow path: Remove problematic characters using C-level regex substitution
|
||||
return _SURROGATE_PATTERN.sub("", text)
|
||||
|
||||
|
||||
def _sanitize_json_data(data: Any) -> Any:
|
||||
"""Recursively sanitize all string values in data structure for safe UTF-8 encoding
|
||||
|
||||
Handles all JSON-serializable types including:
|
||||
- Dictionary keys and values
|
||||
- Lists and tuples (preserves type)
|
||||
- Nested structures
|
||||
- Strings at any level
|
||||
|
||||
Args:
|
||||
data: Data to sanitize (dict, list, tuple, str, or other types)
|
||||
|
||||
Returns:
|
||||
Sanitized data with all strings cleaned of problematic characters
|
||||
class SanitizingJSONEncoder(json.JSONEncoder):
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
# Sanitize both keys and values
|
||||
return {
|
||||
_sanitize_string_for_json(k)
|
||||
if isinstance(k, str)
|
||||
else k: _sanitize_json_data(v)
|
||||
for k, v in data.items()
|
||||
}
|
||||
elif isinstance(data, (list, tuple)):
|
||||
# Handle both lists and tuples, preserve original type
|
||||
sanitized = [_sanitize_json_data(item) for item in data]
|
||||
return type(data)(sanitized)
|
||||
elif isinstance(data, str):
|
||||
return _sanitize_string_for_json(data)
|
||||
else:
|
||||
# Numbers, booleans, None, etc. - return as-is
|
||||
return data
|
||||
Custom JSON encoder that sanitizes data during serialization.
|
||||
|
||||
This encoder cleans strings during the encoding process without creating
|
||||
a full copy of the data structure, making it memory-efficient for large datasets.
|
||||
"""
|
||||
|
||||
def encode(self, o):
|
||||
"""Override encode method to handle simple string cases"""
|
||||
if isinstance(o, str):
|
||||
return json.encoder.encode_basestring(_sanitize_string_for_json(o))
|
||||
return super().encode(o)
|
||||
|
||||
def iterencode(self, o, _one_shot=False):
|
||||
"""
|
||||
Override iterencode to sanitize strings during serialization.
|
||||
This is the core method that handles complex nested structures.
|
||||
"""
|
||||
# Preprocess: sanitize all strings in the object
|
||||
sanitized = self._sanitize_for_encoding(o)
|
||||
|
||||
# Call parent's iterencode with sanitized data
|
||||
for chunk in super().iterencode(sanitized, _one_shot):
|
||||
yield chunk
|
||||
|
||||
def _sanitize_for_encoding(self, obj):
|
||||
"""
|
||||
Recursively sanitize strings in an object.
|
||||
Creates new objects only when necessary to avoid deep copies.
|
||||
|
||||
Args:
|
||||
obj: Object to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized object with cleaned strings
|
||||
"""
|
||||
if isinstance(obj, str):
|
||||
return _sanitize_string_for_json(obj)
|
||||
|
||||
elif isinstance(obj, dict):
|
||||
# Create new dict with sanitized keys and values
|
||||
new_dict = {}
|
||||
for k, v in obj.items():
|
||||
clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
|
||||
clean_v = self._sanitize_for_encoding(v)
|
||||
new_dict[clean_k] = clean_v
|
||||
return new_dict
|
||||
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
# Sanitize list/tuple elements
|
||||
cleaned = [self._sanitize_for_encoding(item) for item in obj]
|
||||
return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
|
||||
|
||||
else:
|
||||
# Numbers, booleans, None, etc. remain unchanged
|
||||
return obj
|
||||
|
||||
|
||||
def write_json(json_obj, file_name):
|
||||
# Sanitize data before writing to prevent UTF-8 encoding errors
|
||||
sanitized_obj = _sanitize_json_data(json_obj)
|
||||
"""
|
||||
Write JSON data to file with optimized sanitization strategy.
|
||||
|
||||
This function uses a two-stage approach:
|
||||
1. Fast path: Try direct serialization (works for clean data ~99% of time)
|
||||
2. Slow path: Use custom encoder that sanitizes during serialization
|
||||
|
||||
The custom encoder approach avoids creating a deep copy of the data,
|
||||
making it memory-efficient. When sanitization occurs, the caller should
|
||||
reload the cleaned data from the file to update shared memory.
|
||||
|
||||
Args:
|
||||
json_obj: Object to serialize (may be a shallow copy from shared memory)
|
||||
file_name: Output file path
|
||||
|
||||
Returns:
|
||||
bool: True if sanitization was applied (caller should reload data),
|
||||
False if direct write succeeded (no reload needed)
|
||||
"""
|
||||
try:
|
||||
# Strategy 1: Fast path - try direct serialization
|
||||
with open(file_name, "w", encoding="utf-8") as f:
|
||||
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
||||
return False # No sanitization needed, no reload required
|
||||
|
||||
except (UnicodeEncodeError, UnicodeDecodeError) as e:
|
||||
logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
|
||||
|
||||
# Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy)
|
||||
with open(file_name, "w", encoding="utf-8") as f:
|
||||
json.dump(sanitized_obj, f, indent=2, ensure_ascii=False)
|
||||
json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder)
|
||||
|
||||
logger.info(f"JSON sanitization applied during write: {file_name}")
|
||||
return True # Sanitization applied, reload recommended
|
||||
|
||||
|
||||
class TokenizerInterface(Protocol):
|
||||
|
|
|
|||
|
|
@ -40,7 +40,6 @@ export default function QuerySettings() {
|
|||
// Default values for reset functionality
|
||||
const defaultValues = useMemo(() => ({
|
||||
mode: 'mix' as QueryMode,
|
||||
response_type: 'Multiple Paragraphs',
|
||||
top_k: 40,
|
||||
chunk_top_k: 20,
|
||||
max_entity_tokens: 6000,
|
||||
|
|
@ -153,46 +152,6 @@ export default function QuerySettings() {
|
|||
</div>
|
||||
</>
|
||||
|
||||
{/* Response Format */}
|
||||
<>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<label htmlFor="response_format_select" className="ml-1 cursor-help">
|
||||
{t('retrievePanel.querySettings.responseFormat')}
|
||||
</label>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="left">
|
||||
<p>{t('retrievePanel.querySettings.responseFormatTooltip')}</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<div className="flex items-center gap-1">
|
||||
<Select
|
||||
value={querySettings.response_type}
|
||||
onValueChange={(v) => handleChange('response_type', v)}
|
||||
>
|
||||
<SelectTrigger
|
||||
id="response_format_select"
|
||||
className="hover:bg-primary/5 h-9 cursor-pointer focus:ring-0 focus:ring-offset-0 focus:outline-0 active:right-0 flex-1 text-left [&>span]:break-all [&>span]:line-clamp-1"
|
||||
>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectGroup>
|
||||
<SelectItem value="Multiple Paragraphs">{t('retrievePanel.querySettings.responseFormatOptions.multipleParagraphs')}</SelectItem>
|
||||
<SelectItem value="Single Paragraph">{t('retrievePanel.querySettings.responseFormatOptions.singleParagraph')}</SelectItem>
|
||||
<SelectItem value="Bullet Points">{t('retrievePanel.querySettings.responseFormatOptions.bulletPoints')}</SelectItem>
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<ResetButton
|
||||
onClick={() => handleReset('response_type')}
|
||||
title="Reset to default (Multiple Paragraphs)"
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
|
||||
{/* Top K */}
|
||||
<>
|
||||
<TooltipProvider>
|
||||
|
|
|
|||
|
|
@ -357,6 +357,7 @@ export default function RetrievalTesting() {
|
|||
const queryParams = {
|
||||
...state.querySettings,
|
||||
query: actualQuery,
|
||||
response_type: 'Multiple Paragraphs',
|
||||
conversation_history: effectiveHistoryTurns > 0
|
||||
? prevMessages
|
||||
.filter((m) => m.isError !== true)
|
||||
|
|
|
|||
|
|
@ -123,7 +123,6 @@ const useSettingsStoreBase = create<SettingsState>()(
|
|||
|
||||
querySettings: {
|
||||
mode: 'global',
|
||||
response_type: 'Multiple Paragraphs',
|
||||
top_k: 40,
|
||||
chunk_top_k: 20,
|
||||
max_entity_tokens: 6000,
|
||||
|
|
@ -239,7 +238,7 @@ const useSettingsStoreBase = create<SettingsState>()(
|
|||
{
|
||||
name: 'settings-storage',
|
||||
storage: createJSONStorage(() => localStorage),
|
||||
version: 18,
|
||||
version: 19,
|
||||
migrate: (state: any, version: number) => {
|
||||
if (version < 2) {
|
||||
state.showEdgeLabel = false
|
||||
|
|
@ -336,6 +335,12 @@ const useSettingsStoreBase = create<SettingsState>()(
|
|||
// Add userPromptHistory field for older versions
|
||||
state.userPromptHistory = []
|
||||
}
|
||||
if (version < 19) {
|
||||
// Remove deprecated response_type parameter
|
||||
if (state.querySettings) {
|
||||
delete state.querySettings.response_type
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ dependencies = [
|
|||
"json_repair",
|
||||
"nano-vectordb",
|
||||
"networkx",
|
||||
"numpy",
|
||||
"numpy>=1.24.0,<2.0.0",
|
||||
"pandas>=2.0.0,<2.4.0",
|
||||
"pipmaster",
|
||||
"pydantic",
|
||||
|
|
@ -50,7 +50,7 @@ api = [
|
|||
"json_repair",
|
||||
"nano-vectordb",
|
||||
"networkx",
|
||||
"numpy",
|
||||
"numpy>=1.24.0,<2.0.0",
|
||||
"openai>=1.0.0,<3.0.0",
|
||||
"pandas>=2.0.0,<2.4.0",
|
||||
"pipmaster",
|
||||
|
|
@ -79,18 +79,23 @@ api = [
|
|||
"python-multipart",
|
||||
"pytz",
|
||||
"uvicorn",
|
||||
"gunicorn",
|
||||
# Document processing dependencies (required for API document upload functionality)
|
||||
"openpyxl>=3.0.0,<4.0.0", # XLSX processing
|
||||
"pycryptodome>=3.0.0,<4.0.0", # PDF encryption support
|
||||
"pypdf>=6.1.0", # PDF processing
|
||||
"python-docx>=0.8.11,<2.0.0", # DOCX processing
|
||||
"python-pptx>=0.6.21,<2.0.0", # PPTX processing
|
||||
]
|
||||
|
||||
# Advanced document processing engine (optional)
|
||||
docling = [
|
||||
# On macOS, pytorch and frameworks use Objective-C are not fork-safe,
|
||||
# and not compatible to gunicorn multi-worker mode
|
||||
"docling>=2.0.0,<3.0.0; sys_platform != 'darwin'",
|
||||
]
|
||||
|
||||
# Offline deployment dependencies (layered design for flexibility)
|
||||
offline-docs = [
|
||||
# Document processing dependencies
|
||||
"openpyxl>=3.0.0,<4.0.0",
|
||||
"pycryptodome>=3.0.0,<4.0.0",
|
||||
"pypdf>=6.1.0",
|
||||
"python-docx>=0.8.11,<2.0.0",
|
||||
"python-pptx>=0.6.21,<2.0.0",
|
||||
]
|
||||
|
||||
offline-storage = [
|
||||
# Storage backend dependencies
|
||||
"redis>=5.0.0,<8.0.0",
|
||||
|
|
@ -115,8 +120,8 @@ offline-llm = [
|
|||
]
|
||||
|
||||
offline = [
|
||||
# Complete offline package (includes all offline dependencies)
|
||||
"lightrag-hku[offline-docs,offline-storage,offline-llm]",
|
||||
# Complete offline package (includes api for document processing, plus storage and LLM)
|
||||
"lightrag-hku[api,offline-storage,offline-llm]",
|
||||
]
|
||||
|
||||
evaluation = [
|
||||
|
|
|
|||
|
|
@ -1,15 +0,0 @@
|
|||
# LightRAG Offline Dependencies - Document Processing
|
||||
# Install with: pip install -r requirements-offline-docs.txt
|
||||
# For offline installation:
|
||||
# pip download -r requirements-offline-docs.txt -d ./packages
|
||||
# pip install --no-index --find-links=./packages -r requirements-offline-docs.txt
|
||||
#
|
||||
# Recommended: Use pip install lightrag-hku[offline-docs] for the same effect
|
||||
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline-docs.txt
|
||||
|
||||
# Document processing dependencies (with version constraints matching pyproject.toml)
|
||||
openpyxl>=3.0.0,<4.0.0
|
||||
pycryptodome>=3.0.0,<4.0.0
|
||||
pypdf>=6.1.0
|
||||
python-docx>=0.8.11,<2.0.0
|
||||
python-pptx>=0.6.21,<2.0.0
|
||||
387
tests/test_write_json_optimization.py
Normal file
387
tests/test_write_json_optimization.py
Normal file
|
|
@ -0,0 +1,387 @@
|
|||
"""
|
||||
Test suite for write_json optimization
|
||||
|
||||
This test verifies:
|
||||
1. Fast path works for clean data (no sanitization)
|
||||
2. Slow path applies sanitization for dirty data
|
||||
3. Sanitization is done during encoding (memory-efficient)
|
||||
4. Reloading updates shared memory with cleaned data
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from lightrag.utils import write_json, load_json, SanitizingJSONEncoder
|
||||
|
||||
|
||||
class TestWriteJsonOptimization:
|
||||
"""Test write_json optimization with two-stage approach"""
|
||||
|
||||
def test_fast_path_clean_data(self):
|
||||
"""Test that clean data takes the fast path without sanitization"""
|
||||
clean_data = {
|
||||
"name": "John Doe",
|
||||
"age": 30,
|
||||
"items": ["apple", "banana", "cherry"],
|
||||
"nested": {"key": "value", "number": 42},
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
|
||||
temp_file = f.name
|
||||
|
||||
try:
|
||||
# Write clean data - should return False (no sanitization)
|
||||
needs_reload = write_json(clean_data, temp_file)
|
||||
assert not needs_reload, "Clean data should not require sanitization"
|
||||
|
||||
# Verify data was written correctly
|
||||
loaded_data = load_json(temp_file)
|
||||
assert loaded_data == clean_data, "Loaded data should match original"
|
||||
finally:
|
||||
os.unlink(temp_file)
|
||||
|
||||
def test_slow_path_dirty_data(self):
|
||||
"""Test that dirty data triggers sanitization"""
|
||||
# Create data with surrogate characters (U+D800 to U+DFFF)
|
||||
dirty_string = "Hello\ud800World" # Contains surrogate character
|
||||
dirty_data = {"text": dirty_string, "number": 123}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
|
||||
temp_file = f.name
|
||||
|
||||
try:
|
||||
# Write dirty data - should return True (sanitization applied)
|
||||
needs_reload = write_json(dirty_data, temp_file)
|
||||
assert needs_reload, "Dirty data should trigger sanitization"
|
||||
|
||||
# Verify data was written and sanitized
|
||||
loaded_data = load_json(temp_file)
|
||||
assert loaded_data is not None, "Data should be written"
|
||||
assert loaded_data["number"] == 123, "Clean fields should remain unchanged"
|
||||
# Surrogate character should be removed
|
||||
assert (
|
||||
"\ud800" not in loaded_data["text"]
|
||||
), "Surrogate character should be removed"
|
||||
finally:
|
||||
os.unlink(temp_file)
|
||||
|
||||
def test_sanitizing_encoder_removes_surrogates(self):
|
||||
"""Test that SanitizingJSONEncoder removes surrogate characters"""
|
||||
data_with_surrogates = {
|
||||
"text": "Hello\ud800\udc00World", # Contains surrogate pair
|
||||
"clean": "Clean text",
|
||||
"nested": {"dirty_key\ud801": "value", "clean_key": "clean\ud802value"},
|
||||
}
|
||||
|
||||
# Encode using custom encoder
|
||||
encoded = json.dumps(
|
||||
data_with_surrogates, cls=SanitizingJSONEncoder, ensure_ascii=False
|
||||
)
|
||||
|
||||
# Verify no surrogate characters in output
|
||||
assert "\ud800" not in encoded, "Surrogate U+D800 should be removed"
|
||||
assert "\udc00" not in encoded, "Surrogate U+DC00 should be removed"
|
||||
assert "\ud801" not in encoded, "Surrogate U+D801 should be removed"
|
||||
assert "\ud802" not in encoded, "Surrogate U+D802 should be removed"
|
||||
|
||||
# Verify clean parts remain
|
||||
assert "Clean text" in encoded, "Clean text should remain"
|
||||
assert "clean_key" in encoded, "Clean keys should remain"
|
||||
|
||||
def test_nested_structure_sanitization(self):
|
||||
"""Test sanitization of deeply nested structures"""
|
||||
nested_data = {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": {"dirty": "text\ud800here", "clean": "normal text"},
|
||||
"list": ["item1", "item\ud801dirty", "item3"],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
|
||||
temp_file = f.name
|
||||
|
||||
try:
|
||||
needs_reload = write_json(nested_data, temp_file)
|
||||
assert needs_reload, "Nested dirty data should trigger sanitization"
|
||||
|
||||
# Verify nested structure is preserved
|
||||
loaded_data = load_json(temp_file)
|
||||
assert "level1" in loaded_data
|
||||
assert "level2" in loaded_data["level1"]
|
||||
assert "level3" in loaded_data["level1"]["level2"]
|
||||
|
||||
# Verify surrogates are removed
|
||||
dirty_text = loaded_data["level1"]["level2"]["level3"]["dirty"]
|
||||
assert "\ud800" not in dirty_text, "Nested surrogate should be removed"
|
||||
|
||||
# Verify list items are sanitized
|
||||
list_items = loaded_data["level1"]["level2"]["list"]
|
||||
assert (
|
||||
"\ud801" not in list_items[1]
|
||||
), "List item surrogates should be removed"
|
||||
finally:
|
||||
os.unlink(temp_file)
|
||||
|
||||
def test_unicode_non_characters_removed(self):
    """Unicode non-characters (U+FFFE, U+FFFF) must not cause encoding errors.

    Note: U+FFFE and U+FFFF are valid UTF-8 characters (though discouraged),
    so they don't trigger sanitization. They only get removed when explicitly
    using the SanitizingJSONEncoder.
    """
    noncharacter_payload = {"text1": "Hello\ufffeWorld", "text2": "Test\uffffString"}

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tmp:
        json_path = tmp.name

    try:
        # Valid UTF-8, so the fast (no-sanitization) path is taken...
        assert not write_json(
            noncharacter_payload, json_path
        ), "U+FFFE/U+FFFF are valid UTF-8 characters"

        # ...and the payload round-trips verbatim.
        assert load_json(json_path) == noncharacter_payload
    finally:
        os.unlink(json_path)
def test_mixed_clean_dirty_data(self):
    """Sanitization must touch only the dirty fields of a mixed payload."""
    # Clean scalars and containers mixed with surrogate-tainted strings.
    payload = {
        "clean_field": "This is perfectly fine",
        "dirty_field": "This has\ud800issues",
        "number": 42,
        "boolean": True,
        "null_value": None,
        "clean_list": [1, 2, 3],
        "dirty_list": ["clean", "dirty\ud801item"],
    }

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tmp:
        json_path = tmp.name

    try:
        assert write_json(
            payload, json_path
        ), "Mixed data with dirty fields should trigger sanitization"

        result = load_json(json_path)

        # Clean fields survive untouched.
        assert result["clean_field"] == "This is perfectly fine"
        assert result["number"] == 42
        assert result["boolean"]
        assert result["null_value"] is None
        assert result["clean_list"] == [1, 2, 3]

        # Dirty fields come back without their surrogates.
        assert "\ud800" not in result["dirty_field"]
        assert "\ud801" not in result["dirty_list"][1]
    finally:
        os.unlink(json_path)
def test_empty_and_none_strings(self):
    """Empty and falsy values are clean data and must round-trip unchanged."""
    falsy_payload = {
        "empty": "",
        "none": None,
        "zero": 0,
        "false": False,
        "empty_list": [],
        "empty_dict": {},
    }

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tmp:
        json_path = tmp.name

    try:
        # Falsy values contain no surrogates, so the fast path applies.
        assert not write_json(
            falsy_payload, json_path
        ), "Clean empty values should not trigger sanitization"

        assert (
            load_json(json_path) == falsy_payload
        ), "Empty/None values should be preserved"
    finally:
        os.unlink(json_path)
def test_specific_surrogate_udc9a(self):
    """Test specific surrogate character \\udc9a mentioned in the issue"""
    # Regression input for:
    # UnicodeEncodeError: 'utf-8' codec can't encode character '\udc9a'
    payload = {
        "text": "Some text with surrogate\udc9acharacter",
        "position": 201,  # position reported by the original error
        "clean_field": "Normal text",
    }

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tmp:
        json_path = tmp.name

    try:
        # Writing must take the sanitization path.
        assert write_json(
            payload, json_path
        ), "Data with \\udc9a should trigger sanitization"

        # The surrogate is gone and clean fields are intact.
        result = load_json(json_path)
        assert result is not None
        assert "\udc9a" not in result["text"], "\\udc9a should be removed"
        assert (
            result["clean_field"] == "Normal text"
        ), "Clean fields should remain"
    finally:
        os.unlink(json_path)
def test_migration_with_surrogate_sanitization(self):
    """Migration of a legacy cache containing surrogates must clean them.

    Simulates a legacy cache file whose entries carry surrogate characters
    and verifies they are stripped while clean entries pass through intact.
    """
    legacy_cache = {
        "cache_entry_1": {
            "return": "Result with\ud800surrogate",
            "cache_type": "extract",
            "original_prompt": "Some\udc9aprompt",
        },
        "cache_entry_2": {
            "return": "Clean result",
            "cache_type": "query",
            "original_prompt": "Clean prompt",
        },
    }

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tmp:
        json_path = tmp.name

    try:
        # Write the dirty payload directly, as a legacy cache file would have
        # been produced; the sanitizing encoder forces the write through even
        # with surrogates present.
        with open(json_path, "w", encoding="utf-8") as handle:
            json.dump(
                legacy_cache,
                handle,
                cls=SanitizingJSONEncoder,
                ensure_ascii=False,
            )

        # Loading back shows the surrogates were cleaned at write time.
        migrated = load_json(json_path)
        assert migrated is not None

        dirty_entry = migrated["cache_entry_1"]
        assert (
            "\ud800" not in dirty_entry["return"]
        ), "Surrogate in return should be removed"
        assert (
            "\udc9a" not in dirty_entry["original_prompt"]
        ), "Surrogate in prompt should be removed"

        # The untainted entry must be byte-identical.
        assert (
            migrated["cache_entry_2"]["return"] == "Clean result"
        ), "Clean data should remain"

    finally:
        os.unlink(json_path)
def test_empty_values_after_sanitization(self):
    """Data left with empty values after sanitization must reload correctly.

    Critical edge case: When sanitization results in data with empty string values,
    we must use 'if cleaned_data is not None' instead of 'if cleaned_data' to ensure
    proper reload, since truthy check on dict depends on content, not just existence.
    """
    # Every value consists solely of surrogate code points.
    surrogate_only = {
        "key1": "\ud800\udc00\ud801",
        "key2": "\ud802\ud803",
    }

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tmp:
        json_path = tmp.name

    try:
        # Dirty input forces the sanitizing write path.
        assert write_json(
            surrogate_only, json_path
        ), "All-dirty data should trigger sanitization"

        sanitized = load_json(json_path)

        # Keys survive with emptied values — the dict is still truthy.
        assert sanitized is not None, "Cleaned data should not be None"
        assert sanitized == {
            "key1": "",
            "key2": "",
        }, "Surrogates should be removed, keys preserved"
        assert sanitized, "Dict with keys is truthy"

        # The genuine edge case: a truly empty dict is clean, reloads as {},
        # and evaluates falsy — which is why 'is not None' matters.
        assert not write_json({}, json_path), "Empty dict is clean"

        reread = load_json(json_path)
        assert reread is not None, "Empty dict should not be None"
        assert reread == {}, "Empty dict should remain empty"
        assert (
            not reread
        ), "Empty dict evaluates to False (the critical check)"

    finally:
        os.unlink(json_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual runner: drive every test case in order without pytest.
    suite = TestWriteJsonOptimization()

    for case_name in (
        "test_fast_path_clean_data",
        "test_slow_path_dirty_data",
        "test_sanitizing_encoder_removes_surrogates",
        "test_nested_structure_sanitization",
        "test_unicode_non_characters_removed",
        "test_mixed_clean_dirty_data",
        "test_empty_and_none_strings",
        "test_specific_surrogate_udc9a",
        "test_migration_with_surrogate_sanitization",
        "test_empty_values_after_sanitization",
    ):
        print(f"Running {case_name}...")
        getattr(suite, case_name)()
        print("✓ Passed")

    print("\n✅ All tests passed!")
Loading…
Add table
Reference in a new issue