Merge branch 'HKUDS:main' into main
This commit is contained in:
commit
43a9b307bc
32 changed files with 3698 additions and 817 deletions
|
|
@ -23,10 +23,11 @@ LightRAG uses dynamic package installation (`pipmaster`) for optional features b
|
||||||
|
|
||||||
LightRAG dynamically installs packages for:
|
LightRAG dynamically installs packages for:
|
||||||
|
|
||||||
- **Document Processing**: `docling`, `pypdf2`, `python-docx`, `python-pptx`, `openpyxl`
|
|
||||||
- **Storage Backends**: `redis`, `neo4j`, `pymilvus`, `pymongo`, `asyncpg`, `qdrant-client`
|
- **Storage Backends**: `redis`, `neo4j`, `pymilvus`, `pymongo`, `asyncpg`, `qdrant-client`
|
||||||
- **LLM Providers**: `openai`, `anthropic`, `ollama`, `zhipuai`, `aioboto3`, `voyageai`, `llama-index`, `lmdeploy`, `transformers`, `torch`
|
- **LLM Providers**: `openai`, `anthropic`, `ollama`, `zhipuai`, `aioboto3`, `voyageai`, `llama-index`, `lmdeploy`, `transformers`, `torch`
|
||||||
- Tiktoken Models**: BPE encoding models downloaded from OpenAI CDN
|
- **Tiktoken Models**: BPE encoding models downloaded from OpenAI CDN
|
||||||
|
|
||||||
|
**Note**: Document processing dependencies (`pypdf`, `python-docx`, `python-pptx`, `openpyxl`) are now pre-installed with the `api` extras group and no longer require dynamic installation.
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
|
|
@ -75,32 +76,31 @@ LightRAG provides flexible dependency groups for different use cases:
|
||||||
|
|
||||||
| Group | Description | Use Case |
|
| Group | Description | Use Case |
|
||||||
|-------|-------------|----------|
|
|-------|-------------|----------|
|
||||||
| `offline-docs` | Document processing | PDF, DOCX, PPTX, XLSX files |
|
| `api` | API server + document processing | FastAPI server with PDF, DOCX, PPTX, XLSX support |
|
||||||
| `offline-storage` | Storage backends | Redis, Neo4j, MongoDB, PostgreSQL, etc. |
|
| `offline-storage` | Storage backends | Redis, Neo4j, MongoDB, PostgreSQL, etc. |
|
||||||
| `offline-llm` | LLM providers | OpenAI, Anthropic, Ollama, etc. |
|
| `offline-llm` | LLM providers | OpenAI, Anthropic, Ollama, etc. |
|
||||||
| `offline` | All of the above | Complete offline deployment |
|
| `offline` | Complete offline package | API + Storage + LLM (all features) |
|
||||||
|
|
||||||
|
**Note**: Document processing (PDF, DOCX, PPTX, XLSX) is included in the `api` extras group. The previous `offline-docs` group has been merged into `api` for better integration.
|
||||||
|
|
||||||
> Software packages requiring `transformers`, `torch`, or `cuda` will not be included in the offline dependency group.
|
> Software packages requiring `transformers`, `torch`, or `cuda` will not be included in the offline dependency group.
|
||||||
|
|
||||||
### Installation Examples
|
### Installation Examples
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Install only document processing dependencies
|
# Install API with document processing
|
||||||
pip install lightrag-hku[offline-docs]
|
pip install lightrag-hku[api]
|
||||||
|
|
||||||
# Install document processing and storage backends
|
# Install API and storage backends
|
||||||
pip install lightrag-hku[offline-docs,offline-storage]
|
pip install lightrag-hku[api,offline-storage]
|
||||||
|
|
||||||
# Install all offline dependencies
|
# Install all offline dependencies (recommended for offline deployment)
|
||||||
pip install lightrag-hku[offline]
|
pip install lightrag-hku[offline]
|
||||||
```
|
```
|
||||||
|
|
||||||
### Using Individual Requirements Files
|
### Using Individual Requirements Files
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Document processing only
|
|
||||||
pip install -r requirements-offline-docs.txt
|
|
||||||
|
|
||||||
# Storage backends only
|
# Storage backends only
|
||||||
pip install -r requirements-offline-storage.txt
|
pip install -r requirements-offline-storage.txt
|
||||||
|
|
||||||
|
|
@ -244,8 +244,8 @@ ls -la ~/.tiktoken_cache/
|
||||||
**Solution**:
|
**Solution**:
|
||||||
```bash
|
```bash
|
||||||
# Pre-install the specific package you need
|
# Pre-install the specific package you need
|
||||||
# For document processing:
|
# For API with document processing:
|
||||||
pip install lightrag-hku[offline-docs]
|
pip install lightrag-hku[api]
|
||||||
|
|
||||||
# For storage backends:
|
# For storage backends:
|
||||||
pip install lightrag-hku[offline-storage]
|
pip install lightrag-hku[offline-storage]
|
||||||
|
|
@ -297,9 +297,9 @@ mkdir -p ~/my_tiktoken_cache
|
||||||
|
|
||||||
5. **Minimal Installation**: Only install what you need:
|
5. **Minimal Installation**: Only install what you need:
|
||||||
```bash
|
```bash
|
||||||
# If you only process PDFs with OpenAI
|
# If you only need API with document processing
|
||||||
pip install lightrag-hku[offline-docs]
|
pip install lightrag-hku[api]
|
||||||
# Then manually add: pip install openai
|
# Then manually add specific LLM: pip install openai
|
||||||
```
|
```
|
||||||
|
|
||||||
## Additional Resources
|
## Additional Resources
|
||||||
|
|
|
||||||
40
env.example
40
env.example
|
|
@ -29,7 +29,7 @@ WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
|
||||||
# OLLAMA_EMULATING_MODEL_NAME=lightrag
|
# OLLAMA_EMULATING_MODEL_NAME=lightrag
|
||||||
OLLAMA_EMULATING_MODEL_TAG=latest
|
OLLAMA_EMULATING_MODEL_TAG=latest
|
||||||
|
|
||||||
### Max nodes return from graph retrieval in webui
|
### Max nodes for graph retrieval (Ensure WebUI local settings are also updated, which is limited to this value)
|
||||||
# MAX_GRAPH_NODES=1000
|
# MAX_GRAPH_NODES=1000
|
||||||
|
|
||||||
### Logging level
|
### Logging level
|
||||||
|
|
@ -255,21 +255,23 @@ OLLAMA_LLM_NUM_CTX=32768
|
||||||
### For OpenAI: Set to 'true' to enable dynamic dimension adjustment
|
### For OpenAI: Set to 'true' to enable dynamic dimension adjustment
|
||||||
### For OpenAI: Set to 'false' (default) to disable sending dimension parameter
|
### For OpenAI: Set to 'false' (default) to disable sending dimension parameter
|
||||||
### Note: Automatically ignored for backends that don't support dimension parameter (e.g., Ollama)
|
### Note: Automatically ignored for backends that don't support dimension parameter (e.g., Ollama)
|
||||||
# EMBEDDING_SEND_DIM=false
|
|
||||||
|
|
||||||
EMBEDDING_BINDING=ollama
|
# Ollama embedding
|
||||||
EMBEDDING_MODEL=bge-m3:latest
|
# EMBEDDING_BINDING=ollama
|
||||||
EMBEDDING_DIM=1024
|
# EMBEDDING_MODEL=bge-m3:latest
|
||||||
EMBEDDING_BINDING_API_KEY=your_api_key
|
# EMBEDDING_DIM=1024
|
||||||
# If LightRAG deployed in Docker uses host.docker.internal instead of localhost
|
|
||||||
EMBEDDING_BINDING_HOST=http://localhost:11434
|
|
||||||
|
|
||||||
### OpenAI compatible (VoyageAI embedding openai compatible)
|
|
||||||
# EMBEDDING_BINDING=openai
|
|
||||||
# EMBEDDING_MODEL=text-embedding-3-large
|
|
||||||
# EMBEDDING_DIM=3072
|
|
||||||
# EMBEDDING_BINDING_HOST=https://api.openai.com/v1
|
|
||||||
# EMBEDDING_BINDING_API_KEY=your_api_key
|
# EMBEDDING_BINDING_API_KEY=your_api_key
|
||||||
|
### If LightRAG deployed in Docker uses host.docker.internal instead of localhost
|
||||||
|
# EMBEDDING_BINDING_HOST=http://localhost:11434
|
||||||
|
|
||||||
|
### OpenAI compatible embedding
|
||||||
|
EMBEDDING_BINDING=openai
|
||||||
|
EMBEDDING_MODEL=text-embedding-3-large
|
||||||
|
EMBEDDING_DIM=3072
|
||||||
|
EMBEDDING_SEND_DIM=false
|
||||||
|
EMBEDDING_TOKEN_LIMIT=8192
|
||||||
|
EMBEDDING_BINDING_HOST=https://api.openai.com/v1
|
||||||
|
EMBEDDING_BINDING_API_KEY=your_api_key
|
||||||
|
|
||||||
### Optional for Azure
|
### Optional for Azure
|
||||||
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
|
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
|
||||||
|
|
@ -277,6 +279,16 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
|
||||||
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
|
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
|
||||||
# AZURE_EMBEDDING_API_KEY=your_api_key
|
# AZURE_EMBEDDING_API_KEY=your_api_key
|
||||||
|
|
||||||
|
### Gemini embedding
|
||||||
|
# EMBEDDING_BINDING=gemini
|
||||||
|
# EMBEDDING_MODEL=gemini-embedding-001
|
||||||
|
# EMBEDDING_DIM=1536
|
||||||
|
# EMBEDDING_TOKEN_LIMIT=2048
|
||||||
|
# EMBEDDING_BINDING_HOST=https://generativelanguage.googleapis.com
|
||||||
|
# EMBEDDING_BINDING_API_KEY=your_api_key
|
||||||
|
### Gemini embedding requires sending dimension to server
|
||||||
|
# EMBEDDING_SEND_DIM=true
|
||||||
|
|
||||||
### Jina AI Embedding
|
### Jina AI Embedding
|
||||||
# EMBEDDING_BINDING=jina
|
# EMBEDDING_BINDING=jina
|
||||||
# EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings
|
# EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings
|
||||||
|
|
|
||||||
|
|
@ -53,8 +53,8 @@ def xml_to_json(xml_file):
|
||||||
"description": edge.find("./data[@key='d6']", namespace).text
|
"description": edge.find("./data[@key='d6']", namespace).text
|
||||||
if edge.find("./data[@key='d6']", namespace) is not None
|
if edge.find("./data[@key='d6']", namespace) is not None
|
||||||
else "",
|
else "",
|
||||||
"keywords": edge.find("./data[@key='d7']", namespace).text
|
"keywords": edge.find("./data[@key='d9']", namespace).text
|
||||||
if edge.find("./data[@key='d7']", namespace) is not None
|
if edge.find("./data[@key='d9']", namespace) is not None
|
||||||
else "",
|
else "",
|
||||||
"source_id": edge.find("./data[@key='d8']", namespace).text
|
"source_id": edge.find("./data[@key='d8']", namespace).text
|
||||||
if edge.find("./data[@key='d8']", namespace) is not None
|
if edge.find("./data[@key='d8']", namespace) is not None
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
__api_version__ = "0253"
|
__api_version__ = "0254"
|
||||||
|
|
|
||||||
|
|
@ -258,6 +258,14 @@ def parse_args() -> argparse.Namespace:
|
||||||
help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
|
help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Document loading engine configuration
|
||||||
|
parser.add_argument(
|
||||||
|
"--docling",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Enable DOCLING document loading engine (default: from env or DEFAULT)",
|
||||||
|
)
|
||||||
|
|
||||||
# Conditionally add binding options defined in binding_options module
|
# Conditionally add binding options defined in binding_options module
|
||||||
# This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
|
# This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
|
||||||
# and corresponding environment variables (e.g., OLLAMA_EMBEDDING_NUM_CTX)
|
# and corresponding environment variables (e.g., OLLAMA_EMBEDDING_NUM_CTX)
|
||||||
|
|
@ -371,8 +379,13 @@ def parse_args() -> argparse.Namespace:
|
||||||
)
|
)
|
||||||
args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)
|
args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)
|
||||||
|
|
||||||
# Select Document loading tool (DOCLING, DEFAULT)
|
# Set document_loading_engine from --docling flag
|
||||||
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
|
if args.docling:
|
||||||
|
args.document_loading_engine = "DOCLING"
|
||||||
|
else:
|
||||||
|
args.document_loading_engine = get_env_value(
|
||||||
|
"DOCUMENT_LOADING_ENGINE", "DEFAULT"
|
||||||
|
)
|
||||||
|
|
||||||
# PDF decryption password
|
# PDF decryption password
|
||||||
args.pdf_decrypt_password = get_env_value("PDF_DECRYPT_PASSWORD", None)
|
args.pdf_decrypt_password = get_env_value("PDF_DECRYPT_PASSWORD", None)
|
||||||
|
|
@ -432,6 +445,11 @@ def parse_args() -> argparse.Namespace:
|
||||||
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
|
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Embedding token limit configuration
|
||||||
|
args.embedding_token_limit = get_env_value(
|
||||||
|
"EMBEDDING_TOKEN_LIMIT", None, int, special_none=True
|
||||||
|
)
|
||||||
|
|
||||||
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
|
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
|
||||||
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
|
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
|
||||||
|
|
||||||
|
|
@ -449,4 +467,83 @@ def update_uvicorn_mode_config():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
global_args = parse_args()
|
# Global configuration with lazy initialization
|
||||||
|
_global_args = None
|
||||||
|
_initialized = False
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_config(args=None, force=False):
|
||||||
|
"""Initialize global configuration
|
||||||
|
|
||||||
|
This function allows explicit initialization of the configuration,
|
||||||
|
which is useful for programmatic usage, testing, or embedding LightRAG
|
||||||
|
in other applications.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Pre-parsed argparse.Namespace or None to parse from sys.argv
|
||||||
|
force: Force re-initialization even if already initialized
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
argparse.Namespace: The configured arguments
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Use parsed command line arguments (default)
|
||||||
|
initialize_config()
|
||||||
|
|
||||||
|
# Use custom configuration programmatically
|
||||||
|
custom_args = argparse.Namespace(
|
||||||
|
host='localhost',
|
||||||
|
port=8080,
|
||||||
|
working_dir='./custom_rag',
|
||||||
|
# ... other config
|
||||||
|
)
|
||||||
|
initialize_config(custom_args)
|
||||||
|
"""
|
||||||
|
global _global_args, _initialized
|
||||||
|
|
||||||
|
if _initialized and not force:
|
||||||
|
return _global_args
|
||||||
|
|
||||||
|
_global_args = args if args is not None else parse_args()
|
||||||
|
_initialized = True
|
||||||
|
return _global_args
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
"""Get global configuration, auto-initializing if needed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
argparse.Namespace: The configured arguments
|
||||||
|
"""
|
||||||
|
if not _initialized:
|
||||||
|
initialize_config()
|
||||||
|
return _global_args
|
||||||
|
|
||||||
|
|
||||||
|
class _GlobalArgsProxy:
|
||||||
|
"""Proxy object that auto-initializes configuration on first access
|
||||||
|
|
||||||
|
This maintains backward compatibility with existing code while
|
||||||
|
allowing programmatic control over initialization timing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if not _initialized:
|
||||||
|
initialize_config()
|
||||||
|
return getattr(_global_args, name)
|
||||||
|
|
||||||
|
def __setattr__(self, name, value):
|
||||||
|
if not _initialized:
|
||||||
|
initialize_config()
|
||||||
|
setattr(_global_args, name, value)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
if not _initialized:
|
||||||
|
return "<GlobalArgsProxy: Not initialized>"
|
||||||
|
return repr(_global_args)
|
||||||
|
|
||||||
|
|
||||||
|
# Create proxy instance for backward compatibility
|
||||||
|
# Existing code like `from config import global_args` continues to work
|
||||||
|
# The proxy will auto-initialize on first attribute access
|
||||||
|
global_args = _GlobalArgsProxy()
|
||||||
|
|
|
||||||
|
|
@ -618,33 +618,108 @@ def create_app(args):
|
||||||
|
|
||||||
def create_optimized_embedding_function(
|
def create_optimized_embedding_function(
|
||||||
config_cache: LLMConfigCache, binding, model, host, api_key, args
|
config_cache: LLMConfigCache, binding, model, host, api_key, args
|
||||||
):
|
) -> EmbeddingFunc:
|
||||||
"""
|
"""
|
||||||
Create optimized embedding function with pre-processed configuration for applicable bindings.
|
Create optimized embedding function and return an EmbeddingFunc instance
|
||||||
Uses lazy imports for all bindings and avoids repeated configuration parsing.
|
with proper max_token_size inheritance from provider defaults.
|
||||||
|
|
||||||
|
This function:
|
||||||
|
1. Imports the provider embedding function
|
||||||
|
2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
|
||||||
|
3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
|
||||||
|
4. Returns a properly configured EmbeddingFunc instance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Step 1: Import provider function and extract default attributes
|
||||||
|
provider_func = None
|
||||||
|
provider_max_token_size = None
|
||||||
|
provider_embedding_dim = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if binding == "openai":
|
||||||
|
from lightrag.llm.openai import openai_embed
|
||||||
|
|
||||||
|
provider_func = openai_embed
|
||||||
|
elif binding == "ollama":
|
||||||
|
from lightrag.llm.ollama import ollama_embed
|
||||||
|
|
||||||
|
provider_func = ollama_embed
|
||||||
|
elif binding == "gemini":
|
||||||
|
from lightrag.llm.gemini import gemini_embed
|
||||||
|
|
||||||
|
provider_func = gemini_embed
|
||||||
|
elif binding == "jina":
|
||||||
|
from lightrag.llm.jina import jina_embed
|
||||||
|
|
||||||
|
provider_func = jina_embed
|
||||||
|
elif binding == "azure_openai":
|
||||||
|
from lightrag.llm.azure_openai import azure_openai_embed
|
||||||
|
|
||||||
|
provider_func = azure_openai_embed
|
||||||
|
elif binding == "aws_bedrock":
|
||||||
|
from lightrag.llm.bedrock import bedrock_embed
|
||||||
|
|
||||||
|
provider_func = bedrock_embed
|
||||||
|
elif binding == "lollms":
|
||||||
|
from lightrag.llm.lollms import lollms_embed
|
||||||
|
|
||||||
|
provider_func = lollms_embed
|
||||||
|
|
||||||
|
# Extract attributes if provider is an EmbeddingFunc
|
||||||
|
if provider_func and isinstance(provider_func, EmbeddingFunc):
|
||||||
|
provider_max_token_size = provider_func.max_token_size
|
||||||
|
provider_embedding_dim = provider_func.embedding_dim
|
||||||
|
logger.debug(
|
||||||
|
f"Extracted from {binding} provider: "
|
||||||
|
f"max_token_size={provider_max_token_size}, "
|
||||||
|
f"embedding_dim={provider_embedding_dim}"
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"Could not import provider function for {binding}: {e}")
|
||||||
|
|
||||||
|
# Step 2: Apply priority (user config > provider default)
|
||||||
|
# For max_token_size: explicit env var > provider default > None
|
||||||
|
final_max_token_size = args.embedding_token_limit or provider_max_token_size
|
||||||
|
# For embedding_dim: user config (always has value) takes priority
|
||||||
|
# Only use provider default if user config is explicitly None (which shouldn't happen)
|
||||||
|
final_embedding_dim = (
|
||||||
|
args.embedding_dim if args.embedding_dim else provider_embedding_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Create optimized embedding function (calls underlying function directly)
|
||||||
async def optimized_embedding_function(texts, embedding_dim=None):
|
async def optimized_embedding_function(texts, embedding_dim=None):
|
||||||
try:
|
try:
|
||||||
if binding == "lollms":
|
if binding == "lollms":
|
||||||
from lightrag.llm.lollms import lollms_embed
|
from lightrag.llm.lollms import lollms_embed
|
||||||
|
|
||||||
return await lollms_embed(
|
# Get real function, skip EmbeddingFunc wrapper if present
|
||||||
|
actual_func = (
|
||||||
|
lollms_embed.func
|
||||||
|
if isinstance(lollms_embed, EmbeddingFunc)
|
||||||
|
else lollms_embed
|
||||||
|
)
|
||||||
|
return await actual_func(
|
||||||
texts, embed_model=model, host=host, api_key=api_key
|
texts, embed_model=model, host=host, api_key=api_key
|
||||||
)
|
)
|
||||||
elif binding == "ollama":
|
elif binding == "ollama":
|
||||||
from lightrag.llm.ollama import ollama_embed
|
from lightrag.llm.ollama import ollama_embed
|
||||||
|
|
||||||
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
|
# Get real function, skip EmbeddingFunc wrapper if present
|
||||||
|
actual_func = (
|
||||||
|
ollama_embed.func
|
||||||
|
if isinstance(ollama_embed, EmbeddingFunc)
|
||||||
|
else ollama_embed
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use pre-processed configuration if available
|
||||||
if config_cache.ollama_embedding_options is not None:
|
if config_cache.ollama_embedding_options is not None:
|
||||||
ollama_options = config_cache.ollama_embedding_options
|
ollama_options = config_cache.ollama_embedding_options
|
||||||
else:
|
else:
|
||||||
# Fallback for cases where config cache wasn't initialized properly
|
|
||||||
from lightrag.llm.binding_options import OllamaEmbeddingOptions
|
from lightrag.llm.binding_options import OllamaEmbeddingOptions
|
||||||
|
|
||||||
ollama_options = OllamaEmbeddingOptions.options_dict(args)
|
ollama_options = OllamaEmbeddingOptions.options_dict(args)
|
||||||
|
|
||||||
return await ollama_embed(
|
return await actual_func(
|
||||||
texts,
|
texts,
|
||||||
embed_model=model,
|
embed_model=model,
|
||||||
host=host,
|
host=host,
|
||||||
|
|
@ -654,15 +729,30 @@ def create_app(args):
|
||||||
elif binding == "azure_openai":
|
elif binding == "azure_openai":
|
||||||
from lightrag.llm.azure_openai import azure_openai_embed
|
from lightrag.llm.azure_openai import azure_openai_embed
|
||||||
|
|
||||||
return await azure_openai_embed(texts, model=model, api_key=api_key)
|
actual_func = (
|
||||||
|
azure_openai_embed.func
|
||||||
|
if isinstance(azure_openai_embed, EmbeddingFunc)
|
||||||
|
else azure_openai_embed
|
||||||
|
)
|
||||||
|
return await actual_func(texts, model=model, api_key=api_key)
|
||||||
elif binding == "aws_bedrock":
|
elif binding == "aws_bedrock":
|
||||||
from lightrag.llm.bedrock import bedrock_embed
|
from lightrag.llm.bedrock import bedrock_embed
|
||||||
|
|
||||||
return await bedrock_embed(texts, model=model)
|
actual_func = (
|
||||||
|
bedrock_embed.func
|
||||||
|
if isinstance(bedrock_embed, EmbeddingFunc)
|
||||||
|
else bedrock_embed
|
||||||
|
)
|
||||||
|
return await actual_func(texts, model=model)
|
||||||
elif binding == "jina":
|
elif binding == "jina":
|
||||||
from lightrag.llm.jina import jina_embed
|
from lightrag.llm.jina import jina_embed
|
||||||
|
|
||||||
return await jina_embed(
|
actual_func = (
|
||||||
|
jina_embed.func
|
||||||
|
if isinstance(jina_embed, EmbeddingFunc)
|
||||||
|
else jina_embed
|
||||||
|
)
|
||||||
|
return await actual_func(
|
||||||
texts,
|
texts,
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
base_url=host,
|
base_url=host,
|
||||||
|
|
@ -671,16 +761,21 @@ def create_app(args):
|
||||||
elif binding == "gemini":
|
elif binding == "gemini":
|
||||||
from lightrag.llm.gemini import gemini_embed
|
from lightrag.llm.gemini import gemini_embed
|
||||||
|
|
||||||
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
|
actual_func = (
|
||||||
|
gemini_embed.func
|
||||||
|
if isinstance(gemini_embed, EmbeddingFunc)
|
||||||
|
else gemini_embed
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use pre-processed configuration if available
|
||||||
if config_cache.gemini_embedding_options is not None:
|
if config_cache.gemini_embedding_options is not None:
|
||||||
gemini_options = config_cache.gemini_embedding_options
|
gemini_options = config_cache.gemini_embedding_options
|
||||||
else:
|
else:
|
||||||
# Fallback for cases where config cache wasn't initialized properly
|
|
||||||
from lightrag.llm.binding_options import GeminiEmbeddingOptions
|
from lightrag.llm.binding_options import GeminiEmbeddingOptions
|
||||||
|
|
||||||
gemini_options = GeminiEmbeddingOptions.options_dict(args)
|
gemini_options = GeminiEmbeddingOptions.options_dict(args)
|
||||||
|
|
||||||
return await gemini_embed(
|
return await actual_func(
|
||||||
texts,
|
texts,
|
||||||
model=model,
|
model=model,
|
||||||
base_url=host,
|
base_url=host,
|
||||||
|
|
@ -691,7 +786,12 @@ def create_app(args):
|
||||||
else: # openai and compatible
|
else: # openai and compatible
|
||||||
from lightrag.llm.openai import openai_embed
|
from lightrag.llm.openai import openai_embed
|
||||||
|
|
||||||
return await openai_embed(
|
actual_func = (
|
||||||
|
openai_embed.func
|
||||||
|
if isinstance(openai_embed, EmbeddingFunc)
|
||||||
|
else openai_embed
|
||||||
|
)
|
||||||
|
return await actual_func(
|
||||||
texts,
|
texts,
|
||||||
model=model,
|
model=model,
|
||||||
base_url=host,
|
base_url=host,
|
||||||
|
|
@ -701,7 +801,21 @@ def create_app(args):
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise Exception(f"Failed to import {binding} embedding: {e}")
|
raise Exception(f"Failed to import {binding} embedding: {e}")
|
||||||
|
|
||||||
return optimized_embedding_function
|
# Step 4: Wrap in EmbeddingFunc and return
|
||||||
|
embedding_func_instance = EmbeddingFunc(
|
||||||
|
embedding_dim=final_embedding_dim,
|
||||||
|
func=optimized_embedding_function,
|
||||||
|
max_token_size=final_max_token_size,
|
||||||
|
send_dimensions=False, # Will be set later based on binding requirements
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log final embedding configuration
|
||||||
|
logger.info(
|
||||||
|
f"Embedding config: binding={binding} model={model} "
|
||||||
|
f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return embedding_func_instance
|
||||||
|
|
||||||
llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
|
llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
|
||||||
embedding_timeout = get_env_value(
|
embedding_timeout = get_env_value(
|
||||||
|
|
@ -735,25 +849,24 @@ def create_app(args):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create embedding function with optimized configuration
|
# Create embedding function with optimized configuration and max_token_size inheritance
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
# Create the optimized embedding function
|
# Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
|
||||||
optimized_embedding_func = create_optimized_embedding_function(
|
embedding_func = create_optimized_embedding_function(
|
||||||
config_cache=config_cache,
|
config_cache=config_cache,
|
||||||
binding=args.embedding_binding,
|
binding=args.embedding_binding,
|
||||||
model=args.embedding_model,
|
model=args.embedding_model,
|
||||||
host=args.embedding_binding_host,
|
host=args.embedding_binding_host,
|
||||||
api_key=args.embedding_binding_api_key,
|
api_key=args.embedding_binding_api_key,
|
||||||
args=args, # Pass args object for fallback option generation
|
args=args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get embedding_send_dim from centralized configuration
|
# Get embedding_send_dim from centralized configuration
|
||||||
embedding_send_dim = args.embedding_send_dim
|
embedding_send_dim = args.embedding_send_dim
|
||||||
|
|
||||||
# Check if the function signature has embedding_dim parameter
|
# Check if the underlying function signature has embedding_dim parameter
|
||||||
# Note: Since optimized_embedding_func is an async function, inspect its signature
|
sig = inspect.signature(embedding_func.func)
|
||||||
sig = inspect.signature(optimized_embedding_func)
|
|
||||||
has_embedding_dim_param = "embedding_dim" in sig.parameters
|
has_embedding_dim_param = "embedding_dim" in sig.parameters
|
||||||
|
|
||||||
# Determine send_dimensions value based on binding type
|
# Determine send_dimensions value based on binding type
|
||||||
|
|
@ -771,18 +884,27 @@ def create_app(args):
|
||||||
else:
|
else:
|
||||||
dimension_control = "by not hasparam"
|
dimension_control = "by not hasparam"
|
||||||
|
|
||||||
|
# Set send_dimensions on the EmbeddingFunc instance
|
||||||
|
embedding_func.send_dimensions = send_dimensions
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Send embedding dimension: {send_dimensions} {dimension_control} "
|
f"Send embedding dimension: {send_dimensions} {dimension_control} "
|
||||||
f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
|
f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
|
||||||
f"binding={args.embedding_binding})"
|
f"binding={args.embedding_binding})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create EmbeddingFunc with send_dimensions attribute
|
# Log max_token_size source
|
||||||
embedding_func = EmbeddingFunc(
|
if embedding_func.max_token_size:
|
||||||
embedding_dim=args.embedding_dim,
|
source = (
|
||||||
func=optimized_embedding_func,
|
"env variable"
|
||||||
send_dimensions=send_dimensions,
|
if args.embedding_token_limit
|
||||||
)
|
else f"{args.embedding_binding} provider default"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Embedding max_token_size: not set (90% token warning disabled)")
|
||||||
|
|
||||||
# Configure rerank function based on args.rerank_bindingparameter
|
# Configure rerank function based on args.rerank_bindingparameter
|
||||||
rerank_model_func = None
|
rerank_model_func = None
|
||||||
|
|
@ -1214,6 +1336,12 @@ def check_and_install_dependencies():
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# Explicitly initialize configuration for clarity
|
||||||
|
# (The proxy will auto-initialize anyway, but this makes intent clear)
|
||||||
|
from .config import initialize_config
|
||||||
|
|
||||||
|
initialize_config()
|
||||||
|
|
||||||
# Check if running under Gunicorn
|
# Check if running under Gunicorn
|
||||||
if "GUNICORN_CMD_ARGS" in os.environ:
|
if "GUNICORN_CMD_ARGS" in os.environ:
|
||||||
# If started with Gunicorn, return directly as Gunicorn will call get_application
|
# If started with Gunicorn, return directly as Gunicorn will call get_application
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,15 @@ This module contains all document-related routes for the LightRAG API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from functools import lru_cache
|
||||||
from lightrag.utils import logger, get_pinyin_sort_key
|
from lightrag.utils import logger, get_pinyin_sort_key
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import shutil
|
import shutil
|
||||||
import traceback
|
import traceback
|
||||||
import pipmaster as pm
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Any, Literal
|
from typing import Dict, List, Optional, Any, Literal
|
||||||
|
from io import BytesIO
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
BackgroundTasks,
|
BackgroundTasks,
|
||||||
|
|
@ -28,6 +29,24 @@ from lightrag.api.utils_api import get_combined_auth_dependency
|
||||||
from ..config import global_args
|
from ..config import global_args
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def _is_docling_available() -> bool:
|
||||||
|
"""Check if docling is available (cached check).
|
||||||
|
|
||||||
|
This function uses lru_cache to avoid repeated import attempts.
|
||||||
|
The result is cached after the first call.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if docling is available, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import docling # noqa: F401 # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# Function to format datetime to ISO format string with timezone information
|
# Function to format datetime to ISO format string with timezone information
|
||||||
def format_datetime(dt: Any) -> Optional[str]:
|
def format_datetime(dt: Any) -> Optional[str]:
|
||||||
"""Format datetime to ISO format string with timezone information
|
"""Format datetime to ISO format string with timezone information
|
||||||
|
|
@ -879,7 +898,6 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str
|
||||||
Returns:
|
Returns:
|
||||||
str: Unique filename (may have numeric suffix added)
|
str: Unique filename (may have numeric suffix added)
|
||||||
"""
|
"""
|
||||||
from pathlib import Path
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
original_path = Path(original_name)
|
original_path = Path(original_name)
|
||||||
|
|
@ -902,6 +920,122 @@ def get_unique_filename_in_enqueued(target_dir: Path, original_name: str) -> str
|
||||||
return f"{base_name}_{timestamp}{extension}"
|
return f"{base_name}_{timestamp}{extension}"
|
||||||
|
|
||||||
|
|
||||||
|
# Document processing helper functions (synchronous)
|
||||||
|
# These functions run in thread pool via asyncio.to_thread() to avoid blocking the event loop
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_with_docling(file_path: Path) -> str:
|
||||||
|
"""Convert document using docling (synchronous).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the document file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Extracted markdown content
|
||||||
|
"""
|
||||||
|
from docling.document_converter import DocumentConverter # type: ignore
|
||||||
|
|
||||||
|
converter = DocumentConverter()
|
||||||
|
result = converter.convert(file_path)
|
||||||
|
return result.document.export_to_markdown()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_pdf_pypdf(file_bytes: bytes, password: str = None) -> str:
|
||||||
|
"""Extract PDF content using pypdf (synchronous).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_bytes: PDF file content as bytes
|
||||||
|
password: Optional password for encrypted PDFs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Extracted text content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If PDF is encrypted and password is incorrect or missing
|
||||||
|
"""
|
||||||
|
from pypdf import PdfReader # type: ignore
|
||||||
|
|
||||||
|
pdf_file = BytesIO(file_bytes)
|
||||||
|
reader = PdfReader(pdf_file)
|
||||||
|
|
||||||
|
# Check if PDF is encrypted
|
||||||
|
if reader.is_encrypted:
|
||||||
|
if not password:
|
||||||
|
raise Exception("PDF is encrypted but no password provided")
|
||||||
|
|
||||||
|
decrypt_result = reader.decrypt(password)
|
||||||
|
if decrypt_result == 0:
|
||||||
|
raise Exception("Incorrect PDF password")
|
||||||
|
|
||||||
|
# Extract text from all pages
|
||||||
|
content = ""
|
||||||
|
for page in reader.pages:
|
||||||
|
content += page.extract_text() + "\n"
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_docx(file_bytes: bytes) -> str:
|
||||||
|
"""Extract DOCX content (synchronous).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_bytes: DOCX file content as bytes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Extracted text content
|
||||||
|
"""
|
||||||
|
from docx import Document # type: ignore
|
||||||
|
|
||||||
|
docx_file = BytesIO(file_bytes)
|
||||||
|
doc = Document(docx_file)
|
||||||
|
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_pptx(file_bytes: bytes) -> str:
|
||||||
|
"""Extract PPTX content (synchronous).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_bytes: PPTX file content as bytes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Extracted text content
|
||||||
|
"""
|
||||||
|
from pptx import Presentation # type: ignore
|
||||||
|
|
||||||
|
pptx_file = BytesIO(file_bytes)
|
||||||
|
prs = Presentation(pptx_file)
|
||||||
|
content = ""
|
||||||
|
for slide in prs.slides:
|
||||||
|
for shape in slide.shapes:
|
||||||
|
if hasattr(shape, "text"):
|
||||||
|
content += shape.text + "\n"
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_xlsx(file_bytes: bytes) -> str:
|
||||||
|
"""Extract XLSX content (synchronous).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_bytes: XLSX file content as bytes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Extracted text content
|
||||||
|
"""
|
||||||
|
from openpyxl import load_workbook # type: ignore
|
||||||
|
|
||||||
|
xlsx_file = BytesIO(file_bytes)
|
||||||
|
wb = load_workbook(xlsx_file)
|
||||||
|
content = ""
|
||||||
|
for sheet in wb:
|
||||||
|
content += f"Sheet: {sheet.title}\n"
|
||||||
|
for row in sheet.iter_rows(values_only=True):
|
||||||
|
content += (
|
||||||
|
"\t".join(str(cell) if cell is not None else "" for cell in row) + "\n"
|
||||||
|
)
|
||||||
|
content += "\n"
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
async def pipeline_enqueue_file(
|
async def pipeline_enqueue_file(
|
||||||
rag: LightRAG, file_path: Path, track_id: str = None
|
rag: LightRAG, file_path: Path, track_id: str = None
|
||||||
) -> tuple[bool, str]:
|
) -> tuple[bool, str]:
|
||||||
|
|
@ -1072,87 +1206,28 @@ async def pipeline_enqueue_file(
|
||||||
|
|
||||||
case ".pdf":
|
case ".pdf":
|
||||||
try:
|
try:
|
||||||
if global_args.document_loading_engine == "DOCLING":
|
# Try DOCLING first if configured and available
|
||||||
if not pm.is_installed("docling"): # type: ignore
|
if (
|
||||||
pm.install("docling")
|
global_args.document_loading_engine == "DOCLING"
|
||||||
from docling.document_converter import DocumentConverter # type: ignore
|
and _is_docling_available()
|
||||||
|
):
|
||||||
converter = DocumentConverter()
|
content = await asyncio.to_thread(
|
||||||
result = converter.convert(file_path)
|
_convert_with_docling, file_path
|
||||||
content = result.document.export_to_markdown()
|
)
|
||||||
else:
|
else:
|
||||||
if not pm.is_installed("pypdf"): # type: ignore
|
if (
|
||||||
pm.install("pypdf")
|
global_args.document_loading_engine == "DOCLING"
|
||||||
if not pm.is_installed("pycryptodome"): # type: ignore
|
and not _is_docling_available()
|
||||||
pm.install("pycryptodome")
|
):
|
||||||
from pypdf import PdfReader # type: ignore
|
logger.warning(
|
||||||
from io import BytesIO
|
f"DOCLING engine configured but not available for {file_path.name}. Falling back to pypdf."
|
||||||
|
)
|
||||||
pdf_file = BytesIO(file)
|
# Use pypdf (non-blocking via to_thread)
|
||||||
reader = PdfReader(pdf_file)
|
content = await asyncio.to_thread(
|
||||||
|
_extract_pdf_pypdf,
|
||||||
# Check if PDF is encrypted
|
file,
|
||||||
if reader.is_encrypted:
|
global_args.pdf_decrypt_password,
|
||||||
pdf_password = global_args.pdf_decrypt_password
|
)
|
||||||
if not pdf_password:
|
|
||||||
# PDF is encrypted but no password provided
|
|
||||||
error_files = [
|
|
||||||
{
|
|
||||||
"file_path": str(file_path.name),
|
|
||||||
"error_description": "[File Extraction]PDF is encrypted but no password provided",
|
|
||||||
"original_error": "Please set PDF_DECRYPT_PASSWORD environment variable to decrypt this PDF file",
|
|
||||||
"file_size": file_size,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
await rag.apipeline_enqueue_error_documents(
|
|
||||||
error_files, track_id
|
|
||||||
)
|
|
||||||
logger.error(
|
|
||||||
f"[File Extraction]PDF is encrypted but no password provided: {file_path.name}"
|
|
||||||
)
|
|
||||||
return False, track_id
|
|
||||||
|
|
||||||
# Try to decrypt with password
|
|
||||||
try:
|
|
||||||
decrypt_result = reader.decrypt(pdf_password)
|
|
||||||
if decrypt_result == 0:
|
|
||||||
# Password is incorrect
|
|
||||||
error_files = [
|
|
||||||
{
|
|
||||||
"file_path": str(file_path.name),
|
|
||||||
"error_description": "[File Extraction]Failed to decrypt PDF - incorrect password",
|
|
||||||
"original_error": "The provided PDF_DECRYPT_PASSWORD is incorrect for this file",
|
|
||||||
"file_size": file_size,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
await rag.apipeline_enqueue_error_documents(
|
|
||||||
error_files, track_id
|
|
||||||
)
|
|
||||||
logger.error(
|
|
||||||
f"[File Extraction]Incorrect PDF password: {file_path.name}"
|
|
||||||
)
|
|
||||||
return False, track_id
|
|
||||||
except Exception as decrypt_error:
|
|
||||||
# Decryption process error
|
|
||||||
error_files = [
|
|
||||||
{
|
|
||||||
"file_path": str(file_path.name),
|
|
||||||
"error_description": "[File Extraction]PDF decryption failed",
|
|
||||||
"original_error": f"Error during PDF decryption: {str(decrypt_error)}",
|
|
||||||
"file_size": file_size,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
await rag.apipeline_enqueue_error_documents(
|
|
||||||
error_files, track_id
|
|
||||||
)
|
|
||||||
logger.error(
|
|
||||||
f"[File Extraction]PDF decryption error for {file_path.name}: {str(decrypt_error)}"
|
|
||||||
)
|
|
||||||
return False, track_id
|
|
||||||
|
|
||||||
# Extract text from PDF (encrypted PDFs are now decrypted, unencrypted PDFs proceed directly)
|
|
||||||
for page in reader.pages:
|
|
||||||
content += page.extract_text() + "\n"
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_files = [
|
error_files = [
|
||||||
{
|
{
|
||||||
|
|
@ -1172,28 +1247,24 @@ async def pipeline_enqueue_file(
|
||||||
|
|
||||||
case ".docx":
|
case ".docx":
|
||||||
try:
|
try:
|
||||||
if global_args.document_loading_engine == "DOCLING":
|
# Try DOCLING first if configured and available
|
||||||
if not pm.is_installed("docling"): # type: ignore
|
if (
|
||||||
pm.install("docling")
|
global_args.document_loading_engine == "DOCLING"
|
||||||
from docling.document_converter import DocumentConverter # type: ignore
|
and _is_docling_available()
|
||||||
|
):
|
||||||
converter = DocumentConverter()
|
content = await asyncio.to_thread(
|
||||||
result = converter.convert(file_path)
|
_convert_with_docling, file_path
|
||||||
content = result.document.export_to_markdown()
|
|
||||||
else:
|
|
||||||
if not pm.is_installed("python-docx"): # type: ignore
|
|
||||||
try:
|
|
||||||
pm.install("python-docx")
|
|
||||||
except Exception:
|
|
||||||
pm.install("docx")
|
|
||||||
from docx import Document # type: ignore
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
docx_file = BytesIO(file)
|
|
||||||
doc = Document(docx_file)
|
|
||||||
content = "\n".join(
|
|
||||||
[paragraph.text for paragraph in doc.paragraphs]
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
global_args.document_loading_engine == "DOCLING"
|
||||||
|
and not _is_docling_available()
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"DOCLING engine configured but not available for {file_path.name}. Falling back to python-docx."
|
||||||
|
)
|
||||||
|
# Use python-docx (non-blocking via to_thread)
|
||||||
|
content = await asyncio.to_thread(_extract_docx, file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_files = [
|
error_files = [
|
||||||
{
|
{
|
||||||
|
|
@ -1213,26 +1284,24 @@ async def pipeline_enqueue_file(
|
||||||
|
|
||||||
case ".pptx":
|
case ".pptx":
|
||||||
try:
|
try:
|
||||||
if global_args.document_loading_engine == "DOCLING":
|
# Try DOCLING first if configured and available
|
||||||
if not pm.is_installed("docling"): # type: ignore
|
if (
|
||||||
pm.install("docling")
|
global_args.document_loading_engine == "DOCLING"
|
||||||
from docling.document_converter import DocumentConverter # type: ignore
|
and _is_docling_available()
|
||||||
|
):
|
||||||
converter = DocumentConverter()
|
content = await asyncio.to_thread(
|
||||||
result = converter.convert(file_path)
|
_convert_with_docling, file_path
|
||||||
content = result.document.export_to_markdown()
|
)
|
||||||
else:
|
else:
|
||||||
if not pm.is_installed("python-pptx"): # type: ignore
|
if (
|
||||||
pm.install("pptx")
|
global_args.document_loading_engine == "DOCLING"
|
||||||
from pptx import Presentation # type: ignore
|
and not _is_docling_available()
|
||||||
from io import BytesIO
|
):
|
||||||
|
logger.warning(
|
||||||
pptx_file = BytesIO(file)
|
f"DOCLING engine configured but not available for {file_path.name}. Falling back to python-pptx."
|
||||||
prs = Presentation(pptx_file)
|
)
|
||||||
for slide in prs.slides:
|
# Use python-pptx (non-blocking via to_thread)
|
||||||
for shape in slide.shapes:
|
content = await asyncio.to_thread(_extract_pptx, file)
|
||||||
if hasattr(shape, "text"):
|
|
||||||
content += shape.text + "\n"
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_files = [
|
error_files = [
|
||||||
{
|
{
|
||||||
|
|
@ -1252,33 +1321,24 @@ async def pipeline_enqueue_file(
|
||||||
|
|
||||||
case ".xlsx":
|
case ".xlsx":
|
||||||
try:
|
try:
|
||||||
if global_args.document_loading_engine == "DOCLING":
|
# Try DOCLING first if configured and available
|
||||||
if not pm.is_installed("docling"): # type: ignore
|
if (
|
||||||
pm.install("docling")
|
global_args.document_loading_engine == "DOCLING"
|
||||||
from docling.document_converter import DocumentConverter # type: ignore
|
and _is_docling_available()
|
||||||
|
):
|
||||||
converter = DocumentConverter()
|
content = await asyncio.to_thread(
|
||||||
result = converter.convert(file_path)
|
_convert_with_docling, file_path
|
||||||
content = result.document.export_to_markdown()
|
)
|
||||||
else:
|
else:
|
||||||
if not pm.is_installed("openpyxl"): # type: ignore
|
if (
|
||||||
pm.install("openpyxl")
|
global_args.document_loading_engine == "DOCLING"
|
||||||
from openpyxl import load_workbook # type: ignore
|
and not _is_docling_available()
|
||||||
from io import BytesIO
|
):
|
||||||
|
logger.warning(
|
||||||
xlsx_file = BytesIO(file)
|
f"DOCLING engine configured but not available for {file_path.name}. Falling back to openpyxl."
|
||||||
wb = load_workbook(xlsx_file)
|
)
|
||||||
for sheet in wb:
|
# Use openpyxl (non-blocking via to_thread)
|
||||||
content += f"Sheet: {sheet.title}\n"
|
content = await asyncio.to_thread(_extract_xlsx, file)
|
||||||
for row in sheet.iter_rows(values_only=True):
|
|
||||||
content += (
|
|
||||||
"\t".join(
|
|
||||||
str(cell) if cell is not None else ""
|
|
||||||
for cell in row
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
content += "\n"
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_files = [
|
error_files = [
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ Start LightRAG server with Gunicorn
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import platform
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
from lightrag.api.utils_api import display_splash_screen, check_env_file
|
from lightrag.api.utils_api import display_splash_screen, check_env_file
|
||||||
from lightrag.api.config import global_args
|
from lightrag.api.config import global_args
|
||||||
|
|
@ -34,6 +35,11 @@ def check_and_install_dependencies():
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# Explicitly initialize configuration for Gunicorn mode
|
||||||
|
from lightrag.api.config import initialize_config
|
||||||
|
|
||||||
|
initialize_config()
|
||||||
|
|
||||||
# Set Gunicorn mode flag for lifespan cleanup detection
|
# Set Gunicorn mode flag for lifespan cleanup detection
|
||||||
os.environ["LIGHTRAG_GUNICORN_MODE"] = "1"
|
os.environ["LIGHTRAG_GUNICORN_MODE"] = "1"
|
||||||
|
|
||||||
|
|
@ -41,6 +47,68 @@ def main():
|
||||||
if not check_env_file():
|
if not check_env_file():
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Check DOCLING compatibility with Gunicorn multi-worker mode on macOS
|
||||||
|
if (
|
||||||
|
platform.system() == "Darwin"
|
||||||
|
and global_args.document_loading_engine == "DOCLING"
|
||||||
|
and global_args.workers > 1
|
||||||
|
):
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("❌ ERROR: Incompatible configuration detected!")
|
||||||
|
print("=" * 80)
|
||||||
|
print(
|
||||||
|
"\nDOCLING engine with Gunicorn multi-worker mode is not supported on macOS"
|
||||||
|
)
|
||||||
|
print("\nReason:")
|
||||||
|
print(" PyTorch (required by DOCLING) has known compatibility issues with")
|
||||||
|
print(" fork-based multiprocessing on macOS, which can cause crashes or")
|
||||||
|
print(" unexpected behavior when using Gunicorn with multiple workers.")
|
||||||
|
print("\nCurrent configuration:")
|
||||||
|
print(" - Operating System: macOS (Darwin)")
|
||||||
|
print(f" - Document Engine: {global_args.document_loading_engine}")
|
||||||
|
print(f" - Workers: {global_args.workers}")
|
||||||
|
print("\nPossible solutions:")
|
||||||
|
print(" 1. Use single worker mode:")
|
||||||
|
print(" --workers 1")
|
||||||
|
print("\n 2. Change document loading engine in .env:")
|
||||||
|
print(" DOCUMENT_LOADING_ENGINE=DEFAULT")
|
||||||
|
print("\n 3. Deploy on Linux where multi-worker mode is fully supported")
|
||||||
|
print("=" * 80 + "\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Check macOS fork safety environment variable for multi-worker mode
|
||||||
|
if (
|
||||||
|
platform.system() == "Darwin"
|
||||||
|
and global_args.workers > 1
|
||||||
|
and os.environ.get("OBJC_DISABLE_INITIALIZE_FORK_SAFETY") != "YES"
|
||||||
|
):
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("❌ ERROR: Missing required environment variable on macOS!")
|
||||||
|
print("=" * 80)
|
||||||
|
print("\nmacOS with Gunicorn multi-worker mode requires:")
|
||||||
|
print(" OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES")
|
||||||
|
print("\nReason:")
|
||||||
|
print(" NumPy uses macOS's Accelerate framework (Objective-C based) for")
|
||||||
|
print(" vector computations. The Objective-C runtime has fork safety checks")
|
||||||
|
print(" that will crash worker processes when embedding functions are called.")
|
||||||
|
print("\nCurrent configuration:")
|
||||||
|
print(" - Operating System: macOS (Darwin)")
|
||||||
|
print(f" - Workers: {global_args.workers}")
|
||||||
|
print(
|
||||||
|
f" - Environment Variable: {os.environ.get('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'NOT SET')}"
|
||||||
|
)
|
||||||
|
print("\nHow to fix:")
|
||||||
|
print(" Option 1 - Set environment variable before starting (recommended):")
|
||||||
|
print(" export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES")
|
||||||
|
print(" lightrag-server")
|
||||||
|
print("\n Option 2 - Add to your shell profile (~/.zshrc or ~/.bash_profile):")
|
||||||
|
print(" echo 'export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES' >> ~/.zshrc")
|
||||||
|
print(" source ~/.zshrc")
|
||||||
|
print("\n Option 3 - Use single worker mode (no multiprocessing):")
|
||||||
|
print(" lightrag-server --workers 1")
|
||||||
|
print("=" * 80 + "\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
# Check and install dependencies
|
# Check and install dependencies
|
||||||
check_and_install_dependencies()
|
check_and_install_dependencies()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -161,7 +161,20 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{self.workspace}] Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
|
f"[{self.workspace}] Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
|
||||||
)
|
)
|
||||||
write_json(data_dict, self._file_name)
|
|
||||||
|
# Write JSON and check if sanitization was applied
|
||||||
|
needs_reload = write_json(data_dict, self._file_name)
|
||||||
|
|
||||||
|
# If data was sanitized, reload cleaned data to update shared memory
|
||||||
|
if needs_reload:
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
|
||||||
|
)
|
||||||
|
cleaned_data = load_json(self._file_name)
|
||||||
|
if cleaned_data is not None:
|
||||||
|
self._data.clear()
|
||||||
|
self._data.update(cleaned_data)
|
||||||
|
|
||||||
await clear_all_update_flags(self.final_namespace)
|
await clear_all_update_flags(self.final_namespace)
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,20 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{self.workspace}] Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
|
f"[{self.workspace}] Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
|
||||||
)
|
)
|
||||||
write_json(data_dict, self._file_name)
|
|
||||||
|
# Write JSON and check if sanitization was applied
|
||||||
|
needs_reload = write_json(data_dict, self._file_name)
|
||||||
|
|
||||||
|
# If data was sanitized, reload cleaned data to update shared memory
|
||||||
|
if needs_reload:
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
|
||||||
|
)
|
||||||
|
cleaned_data = load_json(self._file_name)
|
||||||
|
if cleaned_data is not None:
|
||||||
|
self._data.clear()
|
||||||
|
self._data.update(cleaned_data)
|
||||||
|
|
||||||
await clear_all_update_flags(self.final_namespace)
|
await clear_all_update_flags(self.final_namespace)
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||||
|
|
@ -224,7 +237,7 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
data: Original data dictionary that may contain legacy structure
|
data: Original data dictionary that may contain legacy structure
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Migrated data dictionary with flattened cache keys
|
Migrated data dictionary with flattened cache keys (sanitized if needed)
|
||||||
"""
|
"""
|
||||||
from lightrag.utils import generate_cache_key
|
from lightrag.utils import generate_cache_key
|
||||||
|
|
||||||
|
|
@ -261,8 +274,17 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure"
|
f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure"
|
||||||
)
|
)
|
||||||
# Persist migrated data immediately
|
# Persist migrated data immediately and check if sanitization was applied
|
||||||
write_json(migrated_data, self._file_name)
|
needs_reload = write_json(migrated_data, self._file_name)
|
||||||
|
|
||||||
|
# If data was sanitized during write, reload cleaned data
|
||||||
|
if needs_reload:
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] Reloading sanitized migration data for {self.namespace}"
|
||||||
|
)
|
||||||
|
cleaned_data = load_json(self._file_name)
|
||||||
|
if cleaned_data is not None:
|
||||||
|
return cleaned_data # Return cleaned data to update shared memory
|
||||||
|
|
||||||
return migrated_data
|
return migrated_data
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -133,6 +133,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
|
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
|
||||||
|
|
@ -146,7 +147,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
|
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume() # Ensure the result is consumed even on error
|
if result is not None:
|
||||||
|
await (
|
||||||
|
result.consume()
|
||||||
|
) # Ensure the result is consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
|
|
@ -170,6 +174,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = (
|
query = (
|
||||||
|
|
@ -190,7 +195,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume() # Ensure the result is consumed even on error
|
if result is not None:
|
||||||
|
await (
|
||||||
|
result.consume()
|
||||||
|
) # Ensure the result is consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
|
|
@ -312,6 +320,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = f"""
|
query = f"""
|
||||||
|
|
@ -328,7 +337,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
return labels
|
return labels
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
|
logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}")
|
||||||
await result.consume() # Ensure the result is consumed even on error
|
if result is not None:
|
||||||
|
await (
|
||||||
|
result.consume()
|
||||||
|
) # Ensure the result is consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
|
|
@ -352,6 +364,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
results = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
|
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
|
||||||
|
|
@ -389,7 +402,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
|
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await results.consume() # Ensure results are consumed even on error
|
if results is not None:
|
||||||
|
await (
|
||||||
|
results.consume()
|
||||||
|
) # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -419,6 +435,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = f"""
|
query = f"""
|
||||||
|
|
@ -451,7 +468,10 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
|
f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume() # Ensure the result is consumed even on error
|
if result is not None:
|
||||||
|
await (
|
||||||
|
result.consume()
|
||||||
|
) # Ensure the result is consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||||
|
|
@ -1030,6 +1050,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
|
|
@ -1056,6 +1077,8 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
return labels
|
return labels
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
|
logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
|
||||||
|
if result is not None:
|
||||||
|
await result.consume()
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
||||||
|
|
@ -1078,6 +1101,7 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
if not query_lower:
|
if not query_lower:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
|
|
@ -1111,4 +1135,6 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
return labels
|
return labels
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
|
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
|
||||||
|
if result is not None:
|
||||||
|
await result.consume()
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -371,6 +371,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
|
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
|
||||||
result = await session.run(query, entity_id=node_id)
|
result = await session.run(query, entity_id=node_id)
|
||||||
|
|
@ -381,7 +382,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
|
f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume() # Ensure results are consumed even on error
|
if result is not None:
|
||||||
|
await result.consume() # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
|
|
@ -403,6 +405,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
query = (
|
query = (
|
||||||
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
|
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
|
||||||
|
|
@ -420,7 +423,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume() # Ensure results are consumed even on error
|
if result is not None:
|
||||||
|
await result.consume() # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
|
|
@ -799,6 +803,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
results = None
|
||||||
try:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
workspace_label = self._get_workspace_label()
|
||||||
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
|
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
|
||||||
|
|
@ -836,7 +841,10 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
|
f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
await results.consume() # Ensure results are consumed even on error
|
if results is not None:
|
||||||
|
await (
|
||||||
|
results.consume()
|
||||||
|
) # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -1592,6 +1600,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
result = None
|
||||||
try:
|
try:
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:`{workspace_label}`)
|
MATCH (n:`{workspace_label}`)
|
||||||
|
|
@ -1616,7 +1625,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error getting popular labels: {str(e)}"
|
f"[{self.workspace}] Error getting popular labels: {str(e)}"
|
||||||
)
|
)
|
||||||
await result.consume()
|
if result is not None:
|
||||||
|
await result.consume()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import configparser
|
import configparser
|
||||||
|
import inspect
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
|
@ -12,6 +13,7 @@ from functools import partial
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Iterator,
|
Iterator,
|
||||||
cast,
|
cast,
|
||||||
|
|
@ -20,6 +22,7 @@ from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
List,
|
List,
|
||||||
Dict,
|
Dict,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
from lightrag.prompt import PROMPTS
|
from lightrag.prompt import PROMPTS
|
||||||
from lightrag.exceptions import PipelineCancelledException
|
from lightrag.exceptions import PipelineCancelledException
|
||||||
|
|
@ -243,11 +246,13 @@ class LightRAG:
|
||||||
int,
|
int,
|
||||||
int,
|
int,
|
||||||
],
|
],
|
||||||
List[Dict[str, Any]],
|
Union[List[Dict[str, Any]], Awaitable[List[Dict[str, Any]]]],
|
||||||
] = field(default_factory=lambda: chunking_by_token_size)
|
] = field(default_factory=lambda: chunking_by_token_size)
|
||||||
"""
|
"""
|
||||||
Custom chunking function for splitting text into chunks before processing.
|
Custom chunking function for splitting text into chunks before processing.
|
||||||
|
|
||||||
|
The function can be either synchronous or asynchronous.
|
||||||
|
|
||||||
The function should take the following parameters:
|
The function should take the following parameters:
|
||||||
|
|
||||||
- `tokenizer`: A Tokenizer instance to use for tokenization.
|
- `tokenizer`: A Tokenizer instance to use for tokenization.
|
||||||
|
|
@ -257,7 +262,8 @@ class LightRAG:
|
||||||
- `chunk_token_size`: The maximum number of tokens per chunk.
|
- `chunk_token_size`: The maximum number of tokens per chunk.
|
||||||
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
|
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
|
||||||
|
|
||||||
The function should return a list of dictionaries, where each dictionary contains the following keys:
|
The function should return a list of dictionaries (or an awaitable that resolves to a list),
|
||||||
|
where each dictionary contains the following keys:
|
||||||
- `tokens`: The number of tokens in the chunk.
|
- `tokens`: The number of tokens in the chunk.
|
||||||
- `content`: The text content of the chunk.
|
- `content`: The text content of the chunk.
|
||||||
|
|
||||||
|
|
@ -270,6 +276,9 @@ class LightRAG:
|
||||||
embedding_func: EmbeddingFunc | None = field(default=None)
|
embedding_func: EmbeddingFunc | None = field(default=None)
|
||||||
"""Function for computing text embeddings. Must be set before use."""
|
"""Function for computing text embeddings. Must be set before use."""
|
||||||
|
|
||||||
|
embedding_token_limit: int | None = field(default=None, init=False)
|
||||||
|
"""Token limit for embedding model. Set automatically from embedding_func.max_token_size in __post_init__."""
|
||||||
|
|
||||||
embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
|
embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
|
||||||
"""Batch size for embedding computations."""
|
"""Batch size for embedding computations."""
|
||||||
|
|
||||||
|
|
@ -513,6 +522,16 @@ class LightRAG:
|
||||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||||
|
|
||||||
# Init Embedding
|
# Init Embedding
|
||||||
|
# Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
|
||||||
|
embedding_max_token_size = None
|
||||||
|
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
|
||||||
|
embedding_max_token_size = self.embedding_func.max_token_size
|
||||||
|
logger.debug(
|
||||||
|
f"Captured embedding max_token_size: {embedding_max_token_size}"
|
||||||
|
)
|
||||||
|
self.embedding_token_limit = embedding_max_token_size
|
||||||
|
|
||||||
|
# Step 2: Apply priority wrapper decorator
|
||||||
self.embedding_func = priority_limit_async_func_call(
|
self.embedding_func = priority_limit_async_func_call(
|
||||||
self.embedding_func_max_async,
|
self.embedding_func_max_async,
|
||||||
llm_timeout=self.default_embedding_timeout,
|
llm_timeout=self.default_embedding_timeout,
|
||||||
|
|
@ -1756,7 +1775,28 @@ class LightRAG:
|
||||||
)
|
)
|
||||||
content = content_data["content"]
|
content = content_data["content"]
|
||||||
|
|
||||||
# Generate chunks from document
|
# Call chunking function, supporting both sync and async implementations
|
||||||
|
chunking_result = self.chunking_func(
|
||||||
|
self.tokenizer,
|
||||||
|
content,
|
||||||
|
split_by_character,
|
||||||
|
split_by_character_only,
|
||||||
|
self.chunk_overlap_token_size,
|
||||||
|
self.chunk_token_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# If result is awaitable, await to get actual result
|
||||||
|
if inspect.isawaitable(chunking_result):
|
||||||
|
chunking_result = await chunking_result
|
||||||
|
|
||||||
|
# Validate return type
|
||||||
|
if not isinstance(chunking_result, (list, tuple)):
|
||||||
|
raise TypeError(
|
||||||
|
f"chunking_func must return a list or tuple of dicts, "
|
||||||
|
f"got {type(chunking_result)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build chunks dictionary
|
||||||
chunks: dict[str, Any] = {
|
chunks: dict[str, Any] = {
|
||||||
compute_mdhash_id(dp["content"], prefix="chunk-"): {
|
compute_mdhash_id(dp["content"], prefix="chunk-"): {
|
||||||
**dp,
|
**dp,
|
||||||
|
|
@ -1764,14 +1804,7 @@ class LightRAG:
|
||||||
"file_path": file_path, # Add file path to each chunk
|
"file_path": file_path, # Add file path to each chunk
|
||||||
"llm_cache_list": [], # Initialize empty LLM cache list for each chunk
|
"llm_cache_list": [], # Initialize empty LLM cache list for each chunk
|
||||||
}
|
}
|
||||||
for dp in self.chunking_func(
|
for dp in chunking_result
|
||||||
self.tokenizer,
|
|
||||||
content,
|
|
||||||
split_by_character,
|
|
||||||
split_by_character_only,
|
|
||||||
self.chunk_overlap_token_size,
|
|
||||||
self.chunk_token_size,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if not chunks:
|
if not chunks:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
import pipmaster as pm # Pipmaster for dynamic library install
|
import pipmaster as pm # Pipmaster for dynamic library install
|
||||||
|
|
||||||
|
|
@ -16,6 +17,7 @@ from tenacity import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
from lightrag.utils import wrap_embedding_func_with_attrs
|
||||||
|
|
||||||
if sys.version_info < (3, 9):
|
if sys.version_info < (3, 9):
|
||||||
from typing import AsyncIterator
|
from typing import AsyncIterator
|
||||||
|
|
@ -23,21 +25,121 @@ else:
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
# Import botocore exceptions for proper exception handling
|
||||||
|
try:
|
||||||
|
from botocore.exceptions import (
|
||||||
|
ClientError,
|
||||||
|
ConnectionError as BotocoreConnectionError,
|
||||||
|
ReadTimeoutError,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
# If botocore is not installed, define placeholders
|
||||||
|
ClientError = Exception
|
||||||
|
BotocoreConnectionError = Exception
|
||||||
|
ReadTimeoutError = Exception
|
||||||
|
|
||||||
|
|
||||||
class BedrockError(Exception):
|
class BedrockError(Exception):
|
||||||
"""Generic error for issues related to Amazon Bedrock"""
|
"""Generic error for issues related to Amazon Bedrock"""
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockRateLimitError(BedrockError):
|
||||||
|
"""Error for rate limiting and throttling issues"""
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockConnectionError(BedrockError):
|
||||||
|
"""Error for network and connection issues"""
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockTimeoutError(BedrockError):
|
||||||
|
"""Error for timeout issues"""
|
||||||
|
|
||||||
|
|
||||||
def _set_env_if_present(key: str, value):
|
def _set_env_if_present(key: str, value):
|
||||||
"""Set environment variable only if a non-empty value is provided."""
|
"""Set environment variable only if a non-empty value is provided."""
|
||||||
if value is not None and value != "":
|
if value is not None and value != "":
|
||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_bedrock_exception(e: Exception, operation: str = "Bedrock API") -> None:
|
||||||
|
"""Convert AWS Bedrock exceptions to appropriate custom exceptions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
e: The exception to handle
|
||||||
|
operation: Description of the operation for error messages
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BedrockRateLimitError: For rate limiting and throttling issues (retryable)
|
||||||
|
BedrockConnectionError: For network and server issues (retryable)
|
||||||
|
BedrockTimeoutError: For timeout issues (retryable)
|
||||||
|
BedrockError: For validation and other non-retryable errors
|
||||||
|
"""
|
||||||
|
error_message = str(e)
|
||||||
|
|
||||||
|
# Handle botocore ClientError with specific error codes
|
||||||
|
if isinstance(e, ClientError):
|
||||||
|
error_code = e.response.get("Error", {}).get("Code", "")
|
||||||
|
error_msg = e.response.get("Error", {}).get("Message", error_message)
|
||||||
|
|
||||||
|
# Rate limiting and throttling errors (retryable)
|
||||||
|
if error_code in [
|
||||||
|
"ThrottlingException",
|
||||||
|
"ProvisionedThroughputExceededException",
|
||||||
|
]:
|
||||||
|
logging.error(f"{operation} rate limit error: {error_msg}")
|
||||||
|
raise BedrockRateLimitError(f"Rate limit error: {error_msg}")
|
||||||
|
|
||||||
|
# Server errors (retryable)
|
||||||
|
elif error_code in ["ServiceUnavailableException", "InternalServerException"]:
|
||||||
|
logging.error(f"{operation} connection error: {error_msg}")
|
||||||
|
raise BedrockConnectionError(f"Service error: {error_msg}")
|
||||||
|
|
||||||
|
# Check for 5xx HTTP status codes (retryable)
|
||||||
|
elif e.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0) >= 500:
|
||||||
|
logging.error(f"{operation} server error: {error_msg}")
|
||||||
|
raise BedrockConnectionError(f"Server error: {error_msg}")
|
||||||
|
|
||||||
|
# Validation and other client errors (non-retryable)
|
||||||
|
else:
|
||||||
|
logging.error(f"{operation} client error: {error_msg}")
|
||||||
|
raise BedrockError(f"Client error: {error_msg}")
|
||||||
|
|
||||||
|
# Connection errors (retryable)
|
||||||
|
elif isinstance(e, BotocoreConnectionError):
|
||||||
|
logging.error(f"{operation} connection error: {error_message}")
|
||||||
|
raise BedrockConnectionError(f"Connection error: {error_message}")
|
||||||
|
|
||||||
|
# Timeout errors (retryable)
|
||||||
|
elif isinstance(e, (ReadTimeoutError, TimeoutError)):
|
||||||
|
logging.error(f"{operation} timeout error: {error_message}")
|
||||||
|
raise BedrockTimeoutError(f"Timeout error: {error_message}")
|
||||||
|
|
||||||
|
# Custom Bedrock errors (already properly typed)
|
||||||
|
elif isinstance(
|
||||||
|
e,
|
||||||
|
(
|
||||||
|
BedrockRateLimitError,
|
||||||
|
BedrockConnectionError,
|
||||||
|
BedrockTimeoutError,
|
||||||
|
BedrockError,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Unknown errors (non-retryable)
|
||||||
|
else:
|
||||||
|
logging.error(f"{operation} unexpected error: {error_message}")
|
||||||
|
raise BedrockError(f"Unexpected error: {error_message}")
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(5),
|
stop=stop_after_attempt(5),
|
||||||
wait=wait_exponential(multiplier=1, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
retry=retry_if_exception_type((BedrockError)),
|
retry=(
|
||||||
|
retry_if_exception_type(BedrockRateLimitError)
|
||||||
|
| retry_if_exception_type(BedrockConnectionError)
|
||||||
|
| retry_if_exception_type(BedrockTimeoutError)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
async def bedrock_complete_if_cache(
|
async def bedrock_complete_if_cache(
|
||||||
model,
|
model,
|
||||||
|
|
@ -158,9 +260,6 @@ async def bedrock_complete_if_cache(
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log the specific error for debugging
|
|
||||||
logging.error(f"Bedrock streaming error: {e}")
|
|
||||||
|
|
||||||
# Try to clean up resources if possible
|
# Try to clean up resources if possible
|
||||||
if (
|
if (
|
||||||
iteration_started
|
iteration_started
|
||||||
|
|
@ -175,7 +274,8 @@ async def bedrock_complete_if_cache(
|
||||||
f"Failed to close Bedrock event stream: {close_error}"
|
f"Failed to close Bedrock event stream: {close_error}"
|
||||||
)
|
)
|
||||||
|
|
||||||
raise BedrockError(f"Streaming error: {e}")
|
# Convert to appropriate exception type
|
||||||
|
_handle_bedrock_exception(e, "Bedrock streaming")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up the event stream
|
# Clean up the event stream
|
||||||
|
|
@ -231,10 +331,8 @@ async def bedrock_complete_if_cache(
|
||||||
return content
|
return content
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, BedrockError):
|
# Convert to appropriate exception type
|
||||||
raise
|
_handle_bedrock_exception(e, "Bedrock converse")
|
||||||
else:
|
|
||||||
raise BedrockError(f"Bedrock API error: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
# Generic Bedrock completion function
|
# Generic Bedrock completion function
|
||||||
|
|
@ -253,12 +351,16 @@ async def bedrock_complete(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# @wrap_embedding_func_with_attrs(embedding_dim=1024)
|
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
# @retry(
|
@retry(
|
||||||
# stop=stop_after_attempt(3),
|
stop=stop_after_attempt(5),
|
||||||
# wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
|
retry=(
|
||||||
# )
|
retry_if_exception_type(BedrockRateLimitError)
|
||||||
|
| retry_if_exception_type(BedrockConnectionError)
|
||||||
|
| retry_if_exception_type(BedrockTimeoutError)
|
||||||
|
),
|
||||||
|
)
|
||||||
async def bedrock_embed(
|
async def bedrock_embed(
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
model: str = "amazon.titan-embed-text-v2:0",
|
model: str = "amazon.titan-embed-text-v2:0",
|
||||||
|
|
@ -281,48 +383,101 @@ async def bedrock_embed(
|
||||||
async with session.client(
|
async with session.client(
|
||||||
"bedrock-runtime", region_name=region
|
"bedrock-runtime", region_name=region
|
||||||
) as bedrock_async_client:
|
) as bedrock_async_client:
|
||||||
if (model_provider := model.split(".")[0]) == "amazon":
|
try:
|
||||||
embed_texts = []
|
if (model_provider := model.split(".")[0]) == "amazon":
|
||||||
for text in texts:
|
embed_texts = []
|
||||||
if "v2" in model:
|
for text in texts:
|
||||||
|
try:
|
||||||
|
if "v2" in model:
|
||||||
|
body = json.dumps(
|
||||||
|
{
|
||||||
|
"inputText": text,
|
||||||
|
# 'dimensions': embedding_dim,
|
||||||
|
"embeddingTypes": ["float"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif "v1" in model:
|
||||||
|
body = json.dumps({"inputText": text})
|
||||||
|
else:
|
||||||
|
raise BedrockError(f"Model {model} is not supported!")
|
||||||
|
|
||||||
|
response = await bedrock_async_client.invoke_model(
|
||||||
|
modelId=model,
|
||||||
|
body=body,
|
||||||
|
accept="application/json",
|
||||||
|
contentType="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
response_body = await response.get("body").json()
|
||||||
|
|
||||||
|
# Validate response structure
|
||||||
|
if not response_body or "embedding" not in response_body:
|
||||||
|
raise BedrockError(
|
||||||
|
f"Invalid embedding response structure for text: {text[:50]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding = response_body["embedding"]
|
||||||
|
if not embedding:
|
||||||
|
raise BedrockError(
|
||||||
|
f"Received empty embedding for text: {text[:50]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
embed_texts.append(embedding)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Convert to appropriate exception type
|
||||||
|
_handle_bedrock_exception(
|
||||||
|
e, "Bedrock embedding (amazon, text chunk)"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif model_provider == "cohere":
|
||||||
|
try:
|
||||||
body = json.dumps(
|
body = json.dumps(
|
||||||
{
|
{
|
||||||
"inputText": text,
|
"texts": texts,
|
||||||
# 'dimensions': embedding_dim,
|
"input_type": "search_document",
|
||||||
"embeddingTypes": ["float"],
|
"truncate": "NONE",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif "v1" in model:
|
|
||||||
body = json.dumps({"inputText": text})
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Model {model} is not supported!")
|
|
||||||
|
|
||||||
response = await bedrock_async_client.invoke_model(
|
response = await bedrock_async_client.invoke_model(
|
||||||
modelId=model,
|
model=model,
|
||||||
body=body,
|
body=body,
|
||||||
accept="application/json",
|
accept="application/json",
|
||||||
contentType="application/json",
|
contentType="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
response_body = json.loads(response.get("body").read())
|
||||||
|
|
||||||
|
# Validate response structure
|
||||||
|
if not response_body or "embeddings" not in response_body:
|
||||||
|
raise BedrockError(
|
||||||
|
"Invalid embedding response structure from Cohere"
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = response_body["embeddings"]
|
||||||
|
if not embeddings or len(embeddings) != len(texts):
|
||||||
|
raise BedrockError(
|
||||||
|
f"Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}"
|
||||||
|
)
|
||||||
|
|
||||||
|
embed_texts = embeddings
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Convert to appropriate exception type
|
||||||
|
_handle_bedrock_exception(e, "Bedrock embedding (cohere)")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise BedrockError(
|
||||||
|
f"Model provider '{model_provider}' is not supported!"
|
||||||
)
|
)
|
||||||
|
|
||||||
response_body = await response.get("body").json()
|
# Final validation
|
||||||
|
if not embed_texts:
|
||||||
|
raise BedrockError("No embeddings generated")
|
||||||
|
|
||||||
embed_texts.append(response_body["embedding"])
|
return np.array(embed_texts)
|
||||||
elif model_provider == "cohere":
|
|
||||||
body = json.dumps(
|
|
||||||
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await bedrock_async_client.invoke_model(
|
except Exception as e:
|
||||||
model=model,
|
# Convert to appropriate exception type
|
||||||
body=body,
|
_handle_bedrock_exception(e, "Bedrock embedding")
|
||||||
accept="application/json",
|
|
||||||
contentType="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
response_body = json.loads(response.get("body").read())
|
|
||||||
|
|
||||||
embed_texts = response_body["embeddings"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Model provider '{model_provider}' is not supported!")
|
|
||||||
|
|
||||||
return np.array(embed_texts)
|
|
||||||
|
|
|
||||||
|
|
@ -453,7 +453,7 @@ async def gemini_model_complete(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from lightrag.exceptions import (
|
||||||
)
|
)
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from lightrag.utils import wrap_embedding_func_with_attrs
|
||||||
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
@ -141,6 +142,7 @@ async def hf_model_complete(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
||||||
# Detect the appropriate device
|
# Detect the appropriate device
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ async def fetch_data(url, headers, data):
|
||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=2048)
|
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
|
||||||
|
|
@ -174,7 +174,7 @@ async def llama_index_complete(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,10 @@ from lightrag.exceptions import (
|
||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from lightrag.utils import (
|
||||||
|
wrap_embedding_func_with_attrs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
|
|
@ -134,6 +138,7 @@ async def lollms_model_complete(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
async def lollms_embed(
|
async def lollms_embed(
|
||||||
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
|
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ from lightrag.utils import (
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=2048)
|
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
|
|
@ -22,8 +24,31 @@ from lightrag.exceptions import (
|
||||||
from lightrag.api import __api_version__
|
from lightrag.api import __api_version__
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
from lightrag.utils import logger
|
from lightrag.utils import (
|
||||||
|
wrap_embedding_func_with_attrs,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_OLLAMA_CLOUD_HOST = "https://ollama.com"
|
||||||
|
_CLOUD_MODEL_SUFFIX_PATTERN = re.compile(r"(?:-cloud|:cloud)$")
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_host_for_cloud_model(host: Optional[str], model: object) -> Optional[str]:
|
||||||
|
if host:
|
||||||
|
return host
|
||||||
|
try:
|
||||||
|
model_name_str = str(model) if model is not None else ""
|
||||||
|
except (TypeError, ValueError, AttributeError) as e:
|
||||||
|
logger.warning(f"Failed to convert model to string: {e}, using empty string")
|
||||||
|
model_name_str = ""
|
||||||
|
if _CLOUD_MODEL_SUFFIX_PATTERN.search(model_name_str):
|
||||||
|
logger.debug(
|
||||||
|
f"Detected cloud model '{model_name_str}', using Ollama Cloud host"
|
||||||
|
)
|
||||||
|
return _OLLAMA_CLOUD_HOST
|
||||||
|
return host
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
|
|
@ -53,6 +78,9 @@ async def _ollama_model_if_cache(
|
||||||
timeout = None
|
timeout = None
|
||||||
kwargs.pop("hashing_kv", None)
|
kwargs.pop("hashing_kv", None)
|
||||||
api_key = kwargs.pop("api_key", None)
|
api_key = kwargs.pop("api_key", None)
|
||||||
|
# fallback to environment variable when not provided explicitly
|
||||||
|
if not api_key:
|
||||||
|
api_key = os.getenv("OLLAMA_API_KEY")
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"User-Agent": f"LightRAG/{__api_version__}",
|
"User-Agent": f"LightRAG/{__api_version__}",
|
||||||
|
|
@ -60,6 +88,8 @@ async def _ollama_model_if_cache(
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
host = _coerce_host_for_cloud_model(host, model)
|
||||||
|
|
||||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -142,8 +172,11 @@ async def ollama_model_complete(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||||
api_key = kwargs.pop("api_key", None)
|
api_key = kwargs.pop("api_key", None)
|
||||||
|
if not api_key:
|
||||||
|
api_key = os.getenv("OLLAMA_API_KEY")
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"User-Agent": f"LightRAG/{__api_version__}",
|
"User-Agent": f"LightRAG/{__api_version__}",
|
||||||
|
|
@ -154,6 +187,8 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||||
host = kwargs.pop("host", None)
|
host = kwargs.pop("host", None)
|
||||||
timeout = kwargs.pop("timeout", None)
|
timeout = kwargs.pop("timeout", None)
|
||||||
|
|
||||||
|
host = _coerce_host_for_cloud_model(host, embed_model)
|
||||||
|
|
||||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
|
||||||
try:
|
try:
|
||||||
options = kwargs.pop("options", {})
|
options = kwargs.pop("options", {})
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ try:
|
||||||
|
|
||||||
# Only enable Langfuse if both keys are configured
|
# Only enable Langfuse if both keys are configured
|
||||||
if langfuse_public_key and langfuse_secret_key:
|
if langfuse_public_key and langfuse_secret_key:
|
||||||
from langfuse.openai import AsyncOpenAI
|
from langfuse.openai import AsyncOpenAI # type: ignore[import-untyped]
|
||||||
|
|
||||||
LANGFUSE_ENABLED = True
|
LANGFUSE_ENABLED = True
|
||||||
logger.info("Langfuse observability enabled for OpenAI client")
|
logger.info("Langfuse observability enabled for OpenAI client")
|
||||||
|
|
@ -604,7 +604,7 @@ async def nvidia_openai_complete(
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
|
||||||
|
|
@ -345,6 +345,20 @@ async def _summarize_descriptions(
|
||||||
llm_response_cache=llm_response_cache,
|
llm_response_cache=llm_response_cache,
|
||||||
cache_type="summary",
|
cache_type="summary",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check summary token length against embedding limit
|
||||||
|
embedding_token_limit = global_config.get("embedding_token_limit")
|
||||||
|
if embedding_token_limit is not None and summary:
|
||||||
|
tokenizer = global_config["tokenizer"]
|
||||||
|
summary_token_count = len(tokenizer.encode(summary))
|
||||||
|
threshold = int(embedding_token_limit * 0.9)
|
||||||
|
|
||||||
|
if summary_token_count > threshold:
|
||||||
|
logger.warning(
|
||||||
|
f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
|
||||||
|
f"({embedding_token_limit}) for {description_type}: {description_name}"
|
||||||
|
)
|
||||||
|
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,9 @@ if not logger.handlers:
|
||||||
# Set httpx logging level to WARNING
|
# Set httpx logging level to WARNING
|
||||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
|
||||||
|
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
|
||||||
|
|
||||||
# Global import for pypinyin with startup-time logging
|
# Global import for pypinyin with startup-time logging
|
||||||
try:
|
try:
|
||||||
import pypinyin
|
import pypinyin
|
||||||
|
|
@ -350,9 +353,20 @@ class TaskState:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingFunc:
|
class EmbeddingFunc:
|
||||||
|
"""Embedding function wrapper with dimension validation
|
||||||
|
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
|
||||||
|
This class should not be wrapped multiple times.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_dim: Expected dimension of the embeddings
|
||||||
|
func: The actual embedding function to wrap
|
||||||
|
max_token_size: Optional token limit for the embedding model
|
||||||
|
send_dimensions: Whether to inject embedding_dim as a keyword argument
|
||||||
|
"""
|
||||||
|
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
func: callable
|
func: callable
|
||||||
max_token_size: int | None = None # deprecated keep it for compatible only
|
max_token_size: int | None = None # Token limit for the embedding model
|
||||||
send_dimensions: bool = (
|
send_dimensions: bool = (
|
||||||
False # Control whether to send embedding_dim to the function
|
False # Control whether to send embedding_dim to the function
|
||||||
)
|
)
|
||||||
|
|
@ -376,7 +390,32 @@ class EmbeddingFunc:
|
||||||
# Inject embedding_dim from decorator
|
# Inject embedding_dim from decorator
|
||||||
kwargs["embedding_dim"] = self.embedding_dim
|
kwargs["embedding_dim"] = self.embedding_dim
|
||||||
|
|
||||||
return await self.func(*args, **kwargs)
|
# Call the actual embedding function
|
||||||
|
result = await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Validate embedding dimensions using total element count
|
||||||
|
total_elements = result.size # Total number of elements in the numpy array
|
||||||
|
expected_dim = self.embedding_dim
|
||||||
|
|
||||||
|
# Check if total elements can be evenly divided by embedding_dim
|
||||||
|
if total_elements % expected_dim != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Embedding dimension mismatch detected: "
|
||||||
|
f"total elements ({total_elements}) cannot be evenly divided by "
|
||||||
|
f"expected dimension ({expected_dim}). "
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optional: Verify vector count matches input text count
|
||||||
|
actual_vectors = total_elements // expected_dim
|
||||||
|
if args and isinstance(args[0], (list, tuple)):
|
||||||
|
expected_vectors = len(args[0])
|
||||||
|
if actual_vectors != expected_vectors:
|
||||||
|
raise ValueError(
|
||||||
|
f"Vector count mismatch: "
|
||||||
|
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def compute_args_hash(*args: Any) -> str:
|
def compute_args_hash(*args: Any) -> str:
|
||||||
|
|
@ -930,73 +969,120 @@ def load_json(file_name):
|
||||||
def _sanitize_string_for_json(text: str) -> str:
|
def _sanitize_string_for_json(text: str) -> str:
|
||||||
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization.
|
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization.
|
||||||
|
|
||||||
This is a simpler sanitizer specifically for JSON that directly removes
|
Uses regex for optimal performance with zero-copy optimization for clean strings.
|
||||||
problematic characters without attempting to encode first.
|
Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: String to sanitize
|
text: String to sanitize
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Sanitized string safe for UTF-8 encoding in JSON
|
Original string if clean (zero-copy), sanitized string if dirty
|
||||||
"""
|
"""
|
||||||
if not text:
|
if not text:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
# Directly filter out problematic characters without pre-validation
|
# Fast path: Check if sanitization is needed using C-level regex search
|
||||||
sanitized = ""
|
if not _SURROGATE_PATTERN.search(text):
|
||||||
for char in text:
|
return text # Zero-copy for clean strings - most common case
|
||||||
code_point = ord(char)
|
|
||||||
# Skip surrogate characters (U+D800 to U+DFFF) - main cause of encoding errors
|
|
||||||
if 0xD800 <= code_point <= 0xDFFF:
|
|
||||||
continue
|
|
||||||
# Skip other non-characters in Unicode
|
|
||||||
elif code_point == 0xFFFE or code_point == 0xFFFF:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
sanitized += char
|
|
||||||
|
|
||||||
return sanitized
|
# Slow path: Remove problematic characters using C-level regex substitution
|
||||||
|
return _SURROGATE_PATTERN.sub("", text)
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_json_data(data: Any) -> Any:
|
class SanitizingJSONEncoder(json.JSONEncoder):
|
||||||
"""Recursively sanitize all string values in data structure for safe UTF-8 encoding
|
|
||||||
|
|
||||||
Handles all JSON-serializable types including:
|
|
||||||
- Dictionary keys and values
|
|
||||||
- Lists and tuples (preserves type)
|
|
||||||
- Nested structures
|
|
||||||
- Strings at any level
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: Data to sanitize (dict, list, tuple, str, or other types)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized data with all strings cleaned of problematic characters
|
|
||||||
"""
|
"""
|
||||||
if isinstance(data, dict):
|
Custom JSON encoder that sanitizes data during serialization.
|
||||||
# Sanitize both keys and values
|
|
||||||
return {
|
This encoder cleans strings during the encoding process without creating
|
||||||
_sanitize_string_for_json(k)
|
a full copy of the data structure, making it memory-efficient for large datasets.
|
||||||
if isinstance(k, str)
|
"""
|
||||||
else k: _sanitize_json_data(v)
|
|
||||||
for k, v in data.items()
|
def encode(self, o):
|
||||||
}
|
"""Override encode method to handle simple string cases"""
|
||||||
elif isinstance(data, (list, tuple)):
|
if isinstance(o, str):
|
||||||
# Handle both lists and tuples, preserve original type
|
return json.encoder.encode_basestring(_sanitize_string_for_json(o))
|
||||||
sanitized = [_sanitize_json_data(item) for item in data]
|
return super().encode(o)
|
||||||
return type(data)(sanitized)
|
|
||||||
elif isinstance(data, str):
|
def iterencode(self, o, _one_shot=False):
|
||||||
return _sanitize_string_for_json(data)
|
"""
|
||||||
else:
|
Override iterencode to sanitize strings during serialization.
|
||||||
# Numbers, booleans, None, etc. - return as-is
|
This is the core method that handles complex nested structures.
|
||||||
return data
|
"""
|
||||||
|
# Preprocess: sanitize all strings in the object
|
||||||
|
sanitized = self._sanitize_for_encoding(o)
|
||||||
|
|
||||||
|
# Call parent's iterencode with sanitized data
|
||||||
|
for chunk in super().iterencode(sanitized, _one_shot):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
def _sanitize_for_encoding(self, obj):
|
||||||
|
"""
|
||||||
|
Recursively sanitize strings in an object.
|
||||||
|
Creates new objects only when necessary to avoid deep copies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: Object to sanitize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized object with cleaned strings
|
||||||
|
"""
|
||||||
|
if isinstance(obj, str):
|
||||||
|
return _sanitize_string_for_json(obj)
|
||||||
|
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
# Create new dict with sanitized keys and values
|
||||||
|
new_dict = {}
|
||||||
|
for k, v in obj.items():
|
||||||
|
clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
|
||||||
|
clean_v = self._sanitize_for_encoding(v)
|
||||||
|
new_dict[clean_k] = clean_v
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
# Sanitize list/tuple elements
|
||||||
|
cleaned = [self._sanitize_for_encoding(item) for item in obj]
|
||||||
|
return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Numbers, booleans, None, etc. remain unchanged
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def write_json(json_obj, file_name):
|
def write_json(json_obj, file_name):
|
||||||
# Sanitize data before writing to prevent UTF-8 encoding errors
|
"""
|
||||||
sanitized_obj = _sanitize_json_data(json_obj)
|
Write JSON data to file with optimized sanitization strategy.
|
||||||
|
|
||||||
|
This function uses a two-stage approach:
|
||||||
|
1. Fast path: Try direct serialization (works for clean data ~99% of time)
|
||||||
|
2. Slow path: Use custom encoder that sanitizes during serialization
|
||||||
|
|
||||||
|
The custom encoder approach avoids creating a deep copy of the data,
|
||||||
|
making it memory-efficient. When sanitization occurs, the caller should
|
||||||
|
reload the cleaned data from the file to update shared memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_obj: Object to serialize (may be a shallow copy from shared memory)
|
||||||
|
file_name: Output file path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if sanitization was applied (caller should reload data),
|
||||||
|
False if direct write succeeded (no reload needed)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Strategy 1: Fast path - try direct serialization
|
||||||
|
with open(file_name, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
||||||
|
return False # No sanitization needed, no reload required
|
||||||
|
|
||||||
|
except (UnicodeEncodeError, UnicodeDecodeError) as e:
|
||||||
|
logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
|
||||||
|
|
||||||
|
# Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy)
|
||||||
with open(file_name, "w", encoding="utf-8") as f:
|
with open(file_name, "w", encoding="utf-8") as f:
|
||||||
json.dump(sanitized_obj, f, indent=2, ensure_ascii=False)
|
json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder)
|
||||||
|
|
||||||
|
logger.info(f"JSON sanitization applied during write: {file_name}")
|
||||||
|
return True # Sanitization applied, reload recommended
|
||||||
|
|
||||||
|
|
||||||
class TokenizerInterface(Protocol):
|
class TokenizerInterface(Protocol):
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,6 @@ export default function QuerySettings() {
|
||||||
// Default values for reset functionality
|
// Default values for reset functionality
|
||||||
const defaultValues = useMemo(() => ({
|
const defaultValues = useMemo(() => ({
|
||||||
mode: 'mix' as QueryMode,
|
mode: 'mix' as QueryMode,
|
||||||
response_type: 'Multiple Paragraphs',
|
|
||||||
top_k: 40,
|
top_k: 40,
|
||||||
chunk_top_k: 20,
|
chunk_top_k: 20,
|
||||||
max_entity_tokens: 6000,
|
max_entity_tokens: 6000,
|
||||||
|
|
@ -153,46 +152,6 @@ export default function QuerySettings() {
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
|
|
||||||
{/* Response Format */}
|
|
||||||
<>
|
|
||||||
<TooltipProvider>
|
|
||||||
<Tooltip>
|
|
||||||
<TooltipTrigger asChild>
|
|
||||||
<label htmlFor="response_format_select" className="ml-1 cursor-help">
|
|
||||||
{t('retrievePanel.querySettings.responseFormat')}
|
|
||||||
</label>
|
|
||||||
</TooltipTrigger>
|
|
||||||
<TooltipContent side="left">
|
|
||||||
<p>{t('retrievePanel.querySettings.responseFormatTooltip')}</p>
|
|
||||||
</TooltipContent>
|
|
||||||
</Tooltip>
|
|
||||||
</TooltipProvider>
|
|
||||||
<div className="flex items-center gap-1">
|
|
||||||
<Select
|
|
||||||
value={querySettings.response_type}
|
|
||||||
onValueChange={(v) => handleChange('response_type', v)}
|
|
||||||
>
|
|
||||||
<SelectTrigger
|
|
||||||
id="response_format_select"
|
|
||||||
className="hover:bg-primary/5 h-9 cursor-pointer focus:ring-0 focus:ring-offset-0 focus:outline-0 active:right-0 flex-1 text-left [&>span]:break-all [&>span]:line-clamp-1"
|
|
||||||
>
|
|
||||||
<SelectValue />
|
|
||||||
</SelectTrigger>
|
|
||||||
<SelectContent>
|
|
||||||
<SelectGroup>
|
|
||||||
<SelectItem value="Multiple Paragraphs">{t('retrievePanel.querySettings.responseFormatOptions.multipleParagraphs')}</SelectItem>
|
|
||||||
<SelectItem value="Single Paragraph">{t('retrievePanel.querySettings.responseFormatOptions.singleParagraph')}</SelectItem>
|
|
||||||
<SelectItem value="Bullet Points">{t('retrievePanel.querySettings.responseFormatOptions.bulletPoints')}</SelectItem>
|
|
||||||
</SelectGroup>
|
|
||||||
</SelectContent>
|
|
||||||
</Select>
|
|
||||||
<ResetButton
|
|
||||||
onClick={() => handleReset('response_type')}
|
|
||||||
title="Reset to default (Multiple Paragraphs)"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</>
|
|
||||||
|
|
||||||
{/* Top K */}
|
{/* Top K */}
|
||||||
<>
|
<>
|
||||||
<TooltipProvider>
|
<TooltipProvider>
|
||||||
|
|
|
||||||
|
|
@ -357,6 +357,7 @@ export default function RetrievalTesting() {
|
||||||
const queryParams = {
|
const queryParams = {
|
||||||
...state.querySettings,
|
...state.querySettings,
|
||||||
query: actualQuery,
|
query: actualQuery,
|
||||||
|
response_type: 'Multiple Paragraphs',
|
||||||
conversation_history: effectiveHistoryTurns > 0
|
conversation_history: effectiveHistoryTurns > 0
|
||||||
? prevMessages
|
? prevMessages
|
||||||
.filter((m) => m.isError !== true)
|
.filter((m) => m.isError !== true)
|
||||||
|
|
|
||||||
|
|
@ -123,7 +123,6 @@ const useSettingsStoreBase = create<SettingsState>()(
|
||||||
|
|
||||||
querySettings: {
|
querySettings: {
|
||||||
mode: 'global',
|
mode: 'global',
|
||||||
response_type: 'Multiple Paragraphs',
|
|
||||||
top_k: 40,
|
top_k: 40,
|
||||||
chunk_top_k: 20,
|
chunk_top_k: 20,
|
||||||
max_entity_tokens: 6000,
|
max_entity_tokens: 6000,
|
||||||
|
|
@ -239,7 +238,7 @@ const useSettingsStoreBase = create<SettingsState>()(
|
||||||
{
|
{
|
||||||
name: 'settings-storage',
|
name: 'settings-storage',
|
||||||
storage: createJSONStorage(() => localStorage),
|
storage: createJSONStorage(() => localStorage),
|
||||||
version: 18,
|
version: 19,
|
||||||
migrate: (state: any, version: number) => {
|
migrate: (state: any, version: number) => {
|
||||||
if (version < 2) {
|
if (version < 2) {
|
||||||
state.showEdgeLabel = false
|
state.showEdgeLabel = false
|
||||||
|
|
@ -336,6 +335,12 @@ const useSettingsStoreBase = create<SettingsState>()(
|
||||||
// Add userPromptHistory field for older versions
|
// Add userPromptHistory field for older versions
|
||||||
state.userPromptHistory = []
|
state.userPromptHistory = []
|
||||||
}
|
}
|
||||||
|
if (version < 19) {
|
||||||
|
// Remove deprecated response_type parameter
|
||||||
|
if (state.querySettings) {
|
||||||
|
delete state.querySettings.response_type
|
||||||
|
}
|
||||||
|
}
|
||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ dependencies = [
|
||||||
"json_repair",
|
"json_repair",
|
||||||
"nano-vectordb",
|
"nano-vectordb",
|
||||||
"networkx",
|
"networkx",
|
||||||
"numpy",
|
"numpy>=1.24.0,<2.0.0",
|
||||||
"pandas>=2.0.0,<2.4.0",
|
"pandas>=2.0.0,<2.4.0",
|
||||||
"pipmaster",
|
"pipmaster",
|
||||||
"pydantic",
|
"pydantic",
|
||||||
|
|
@ -50,7 +50,7 @@ api = [
|
||||||
"json_repair",
|
"json_repair",
|
||||||
"nano-vectordb",
|
"nano-vectordb",
|
||||||
"networkx",
|
"networkx",
|
||||||
"numpy",
|
"numpy>=1.24.0,<2.0.0",
|
||||||
"openai>=1.0.0,<3.0.0",
|
"openai>=1.0.0,<3.0.0",
|
||||||
"pandas>=2.0.0,<2.4.0",
|
"pandas>=2.0.0,<2.4.0",
|
||||||
"pipmaster",
|
"pipmaster",
|
||||||
|
|
@ -79,18 +79,23 @@ api = [
|
||||||
"python-multipart",
|
"python-multipart",
|
||||||
"pytz",
|
"pytz",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
|
"gunicorn",
|
||||||
|
# Document processing dependencies (required for API document upload functionality)
|
||||||
|
"openpyxl>=3.0.0,<4.0.0", # XLSX processing
|
||||||
|
"pycryptodome>=3.0.0,<4.0.0", # PDF encryption support
|
||||||
|
"pypdf>=6.1.0", # PDF processing
|
||||||
|
"python-docx>=0.8.11,<2.0.0", # DOCX processing
|
||||||
|
"python-pptx>=0.6.21,<2.0.0", # PPTX processing
|
||||||
|
]
|
||||||
|
|
||||||
|
# Advanced document processing engine (optional)
|
||||||
|
docling = [
|
||||||
|
# On macOS, pytorch and frameworks use Objective-C are not fork-safe,
|
||||||
|
# and not compatible to gunicorn multi-worker mode
|
||||||
|
"docling>=2.0.0,<3.0.0; sys_platform != 'darwin'",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Offline deployment dependencies (layered design for flexibility)
|
# Offline deployment dependencies (layered design for flexibility)
|
||||||
offline-docs = [
|
|
||||||
# Document processing dependencies
|
|
||||||
"openpyxl>=3.0.0,<4.0.0",
|
|
||||||
"pycryptodome>=3.0.0,<4.0.0",
|
|
||||||
"pypdf>=6.1.0",
|
|
||||||
"python-docx>=0.8.11,<2.0.0",
|
|
||||||
"python-pptx>=0.6.21,<2.0.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
offline-storage = [
|
offline-storage = [
|
||||||
# Storage backend dependencies
|
# Storage backend dependencies
|
||||||
"redis>=5.0.0,<8.0.0",
|
"redis>=5.0.0,<8.0.0",
|
||||||
|
|
@ -115,8 +120,8 @@ offline-llm = [
|
||||||
]
|
]
|
||||||
|
|
||||||
offline = [
|
offline = [
|
||||||
# Complete offline package (includes all offline dependencies)
|
# Complete offline package (includes api for document processing, plus storage and LLM)
|
||||||
"lightrag-hku[offline-docs,offline-storage,offline-llm]",
|
"lightrag-hku[api,offline-storage,offline-llm]",
|
||||||
]
|
]
|
||||||
|
|
||||||
evaluation = [
|
evaluation = [
|
||||||
|
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
# LightRAG Offline Dependencies - Document Processing
|
|
||||||
# Install with: pip install -r requirements-offline-docs.txt
|
|
||||||
# For offline installation:
|
|
||||||
# pip download -r requirements-offline-docs.txt -d ./packages
|
|
||||||
# pip install --no-index --find-links=./packages -r requirements-offline-docs.txt
|
|
||||||
#
|
|
||||||
# Recommended: Use pip install lightrag-hku[offline-docs] for the same effect
|
|
||||||
# Or use constraints: pip install --constraint constraints-offline.txt -r requirements-offline-docs.txt
|
|
||||||
|
|
||||||
# Document processing dependencies (with version constraints matching pyproject.toml)
|
|
||||||
openpyxl>=3.0.0,<4.0.0
|
|
||||||
pycryptodome>=3.0.0,<4.0.0
|
|
||||||
pypdf>=6.1.0
|
|
||||||
python-docx>=0.8.11,<2.0.0
|
|
||||||
python-pptx>=0.6.21,<2.0.0
|
|
||||||
387
tests/test_write_json_optimization.py
Normal file
387
tests/test_write_json_optimization.py
Normal file
|
|
@ -0,0 +1,387 @@
|
||||||
|
"""
|
||||||
|
Test suite for write_json optimization
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
1. Fast path works for clean data (no sanitization)
|
||||||
|
2. Slow path applies sanitization for dirty data
|
||||||
|
3. Sanitization is done during encoding (memory-efficient)
|
||||||
|
4. Reloading updates shared memory with cleaned data
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
from lightrag.utils import write_json, load_json, SanitizingJSONEncoder
|
||||||
|
|
||||||
|
|
||||||
|
class TestWriteJsonOptimization:
|
||||||
|
"""Test write_json optimization with two-stage approach"""
|
||||||
|
|
||||||
|
def test_fast_path_clean_data(self):
|
||||||
|
"""Test that clean data takes the fast path without sanitization"""
|
||||||
|
clean_data = {
|
||||||
|
"name": "John Doe",
|
||||||
|
"age": 30,
|
||||||
|
"items": ["apple", "banana", "cherry"],
|
||||||
|
"nested": {"key": "value", "number": 42},
|
||||||
|
}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
|
||||||
|
temp_file = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Write clean data - should return False (no sanitization)
|
||||||
|
needs_reload = write_json(clean_data, temp_file)
|
||||||
|
assert not needs_reload, "Clean data should not require sanitization"
|
||||||
|
|
||||||
|
# Verify data was written correctly
|
||||||
|
loaded_data = load_json(temp_file)
|
||||||
|
assert loaded_data == clean_data, "Loaded data should match original"
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
|
||||||
|
def test_slow_path_dirty_data(self):
|
||||||
|
"""Test that dirty data triggers sanitization"""
|
||||||
|
# Create data with surrogate characters (U+D800 to U+DFFF)
|
||||||
|
dirty_string = "Hello\ud800World" # Contains surrogate character
|
||||||
|
dirty_data = {"text": dirty_string, "number": 123}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
|
||||||
|
temp_file = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Write dirty data - should return True (sanitization applied)
|
||||||
|
needs_reload = write_json(dirty_data, temp_file)
|
||||||
|
assert needs_reload, "Dirty data should trigger sanitization"
|
||||||
|
|
||||||
|
# Verify data was written and sanitized
|
||||||
|
loaded_data = load_json(temp_file)
|
||||||
|
assert loaded_data is not None, "Data should be written"
|
||||||
|
assert loaded_data["number"] == 123, "Clean fields should remain unchanged"
|
||||||
|
# Surrogate character should be removed
|
||||||
|
assert (
|
||||||
|
"\ud800" not in loaded_data["text"]
|
||||||
|
), "Surrogate character should be removed"
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
|
||||||
|
def test_sanitizing_encoder_removes_surrogates(self):
|
||||||
|
"""Test that SanitizingJSONEncoder removes surrogate characters"""
|
||||||
|
data_with_surrogates = {
|
||||||
|
"text": "Hello\ud800\udc00World", # Contains surrogate pair
|
||||||
|
"clean": "Clean text",
|
||||||
|
"nested": {"dirty_key\ud801": "value", "clean_key": "clean\ud802value"},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Encode using custom encoder
|
||||||
|
encoded = json.dumps(
|
||||||
|
data_with_surrogates, cls=SanitizingJSONEncoder, ensure_ascii=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify no surrogate characters in output
|
||||||
|
assert "\ud800" not in encoded, "Surrogate U+D800 should be removed"
|
||||||
|
assert "\udc00" not in encoded, "Surrogate U+DC00 should be removed"
|
||||||
|
assert "\ud801" not in encoded, "Surrogate U+D801 should be removed"
|
||||||
|
assert "\ud802" not in encoded, "Surrogate U+D802 should be removed"
|
||||||
|
|
||||||
|
# Verify clean parts remain
|
||||||
|
assert "Clean text" in encoded, "Clean text should remain"
|
||||||
|
assert "clean_key" in encoded, "Clean keys should remain"
|
||||||
|
|
||||||
|
def test_nested_structure_sanitization(self):
|
||||||
|
"""Test sanitization of deeply nested structures"""
|
||||||
|
nested_data = {
|
||||||
|
"level1": {
|
||||||
|
"level2": {
|
||||||
|
"level3": {"dirty": "text\ud800here", "clean": "normal text"},
|
||||||
|
"list": ["item1", "item\ud801dirty", "item3"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
|
||||||
|
temp_file = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
needs_reload = write_json(nested_data, temp_file)
|
||||||
|
assert needs_reload, "Nested dirty data should trigger sanitization"
|
||||||
|
|
||||||
|
# Verify nested structure is preserved
|
||||||
|
loaded_data = load_json(temp_file)
|
||||||
|
assert "level1" in loaded_data
|
||||||
|
assert "level2" in loaded_data["level1"]
|
||||||
|
assert "level3" in loaded_data["level1"]["level2"]
|
||||||
|
|
||||||
|
# Verify surrogates are removed
|
||||||
|
dirty_text = loaded_data["level1"]["level2"]["level3"]["dirty"]
|
||||||
|
assert "\ud800" not in dirty_text, "Nested surrogate should be removed"
|
||||||
|
|
||||||
|
# Verify list items are sanitized
|
||||||
|
list_items = loaded_data["level1"]["level2"]["list"]
|
||||||
|
assert (
|
||||||
|
"\ud801" not in list_items[1]
|
||||||
|
), "List item surrogates should be removed"
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
|
||||||
|
def test_unicode_non_characters_removed(self):
|
||||||
|
"""Test that Unicode non-characters (U+FFFE, U+FFFF) don't cause encoding errors
|
||||||
|
|
||||||
|
Note: U+FFFE and U+FFFF are valid UTF-8 characters (though discouraged),
|
||||||
|
so they don't trigger sanitization. They only get removed when explicitly
|
||||||
|
using the SanitizingJSONEncoder.
|
||||||
|
"""
|
||||||
|
data_with_nonchars = {"text1": "Hello\ufffeWorld", "text2": "Test\uffffString"}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
|
||||||
|
temp_file = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# These characters are valid UTF-8, so they take the fast path
|
||||||
|
needs_reload = write_json(data_with_nonchars, temp_file)
|
||||||
|
assert not needs_reload, "U+FFFE/U+FFFF are valid UTF-8 characters"
|
||||||
|
|
||||||
|
loaded_data = load_json(temp_file)
|
||||||
|
# They're written as-is in the fast path
|
||||||
|
assert loaded_data == data_with_nonchars
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
|
||||||
|
def test_mixed_clean_dirty_data(self):
|
||||||
|
"""Test data with both clean and dirty fields"""
|
||||||
|
mixed_data = {
|
||||||
|
"clean_field": "This is perfectly fine",
|
||||||
|
"dirty_field": "This has\ud800issues",
|
||||||
|
"number": 42,
|
||||||
|
"boolean": True,
|
||||||
|
"null_value": None,
|
||||||
|
"clean_list": [1, 2, 3],
|
||||||
|
"dirty_list": ["clean", "dirty\ud801item"],
|
||||||
|
}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
|
||||||
|
temp_file = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
needs_reload = write_json(mixed_data, temp_file)
|
||||||
|
assert (
|
||||||
|
needs_reload
|
||||||
|
), "Mixed data with dirty fields should trigger sanitization"
|
||||||
|
|
||||||
|
loaded_data = load_json(temp_file)
|
||||||
|
|
||||||
|
# Clean fields should remain unchanged
|
||||||
|
assert loaded_data["clean_field"] == "This is perfectly fine"
|
||||||
|
assert loaded_data["number"] == 42
|
||||||
|
assert loaded_data["boolean"]
|
||||||
|
assert loaded_data["null_value"] is None
|
||||||
|
assert loaded_data["clean_list"] == [1, 2, 3]
|
||||||
|
|
||||||
|
# Dirty fields should be sanitized
|
||||||
|
assert "\ud800" not in loaded_data["dirty_field"]
|
||||||
|
assert "\ud801" not in loaded_data["dirty_list"][1]
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
|
||||||
|
def test_empty_and_none_strings(self):
|
||||||
|
"""Test handling of empty and None values"""
|
||||||
|
data = {
|
||||||
|
"empty": "",
|
||||||
|
"none": None,
|
||||||
|
"zero": 0,
|
||||||
|
"false": False,
|
||||||
|
"empty_list": [],
|
||||||
|
"empty_dict": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
|
||||||
|
temp_file = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
needs_reload = write_json(data, temp_file)
|
||||||
|
assert (
|
||||||
|
not needs_reload
|
||||||
|
), "Clean empty values should not trigger sanitization"
|
||||||
|
|
||||||
|
loaded_data = load_json(temp_file)
|
||||||
|
assert loaded_data == data, "Empty/None values should be preserved"
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
|
||||||
|
def test_specific_surrogate_udc9a(self):
|
||||||
|
"""Test specific surrogate character \\udc9a mentioned in the issue"""
|
||||||
|
# Test the exact surrogate character from the error message:
|
||||||
|
# UnicodeEncodeError: 'utf-8' codec can't encode character '\\udc9a'
|
||||||
|
data_with_udc9a = {
|
||||||
|
"text": "Some text with surrogate\udc9acharacter",
|
||||||
|
"position": 201, # As mentioned in the error
|
||||||
|
"clean_field": "Normal text",
|
||||||
|
}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
|
||||||
|
temp_file = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Write data - should trigger sanitization
|
||||||
|
needs_reload = write_json(data_with_udc9a, temp_file)
|
||||||
|
assert needs_reload, "Data with \\udc9a should trigger sanitization"
|
||||||
|
|
||||||
|
# Verify surrogate was removed
|
||||||
|
loaded_data = load_json(temp_file)
|
||||||
|
assert loaded_data is not None
|
||||||
|
assert "\udc9a" not in loaded_data["text"], "\\udc9a should be removed"
|
||||||
|
assert (
|
||||||
|
loaded_data["clean_field"] == "Normal text"
|
||||||
|
), "Clean fields should remain"
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
|
||||||
|
def test_migration_with_surrogate_sanitization(self):
|
||||||
|
"""Test that migration process handles surrogate characters correctly
|
||||||
|
|
||||||
|
This test simulates the scenario where legacy cache contains surrogate
|
||||||
|
characters and ensures they are cleaned during migration.
|
||||||
|
"""
|
||||||
|
# Simulate legacy cache data with surrogate characters
|
||||||
|
legacy_data_with_surrogates = {
|
||||||
|
"cache_entry_1": {
|
||||||
|
"return": "Result with\ud800surrogate",
|
||||||
|
"cache_type": "extract",
|
||||||
|
"original_prompt": "Some\udc9aprompt",
|
||||||
|
},
|
||||||
|
"cache_entry_2": {
|
||||||
|
"return": "Clean result",
|
||||||
|
"cache_type": "query",
|
||||||
|
"original_prompt": "Clean prompt",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
|
||||||
|
temp_file = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# First write the dirty data directly (simulating legacy cache file)
|
||||||
|
# Use custom encoder to force write even with surrogates
|
||||||
|
with open(temp_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(
|
||||||
|
legacy_data_with_surrogates,
|
||||||
|
f,
|
||||||
|
cls=SanitizingJSONEncoder,
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load and verify surrogates were cleaned during initial write
|
||||||
|
loaded_data = load_json(temp_file)
|
||||||
|
assert loaded_data is not None
|
||||||
|
|
||||||
|
# The data should be sanitized
|
||||||
|
assert (
|
||||||
|
"\ud800" not in loaded_data["cache_entry_1"]["return"]
|
||||||
|
), "Surrogate in return should be removed"
|
||||||
|
assert (
|
||||||
|
"\udc9a" not in loaded_data["cache_entry_1"]["original_prompt"]
|
||||||
|
), "Surrogate in prompt should be removed"
|
||||||
|
|
||||||
|
# Clean data should remain unchanged
|
||||||
|
assert (
|
||||||
|
loaded_data["cache_entry_2"]["return"] == "Clean result"
|
||||||
|
), "Clean data should remain"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
|
||||||
|
def test_empty_values_after_sanitization(self):
    """Verify handling of data whose values become empty after sanitization.

    Edge case under test: when every value is made of surrogate characters,
    sanitization yields a dict of empty strings. Callers therefore must test
    ``cleaned_data is not None`` rather than relying on truthiness, because a
    dict's truthiness reflects its contents, not its existence.
    """
    # Every value consists solely of lone surrogate code points.
    dirty_payload = {
        "key1": "\ud800\udc00\ud801",
        "key2": "\ud802\ud803",
    }

    # Reserve a temporary JSON file path; the handle is closed immediately.
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as handle:
        cache_path = handle.name

    try:
        # Writing all-dirty data must signal that sanitization occurred.
        reload_flag = write_json(dirty_payload, cache_path)
        assert reload_flag, "All-dirty data should trigger sanitization"

        # Reading back yields the sanitized form.
        sanitized = load_json(cache_path)

        # The critical checks for this edge case.
        assert sanitized is not None, "Cleaned data should not be None"
        # Keys survive sanitization even though their values were emptied.
        expected = {"key1": "", "key2": ""}
        assert sanitized == expected, "Surrogates should be removed, keys preserved"
        # A dict with keys — even empty-string values — is still truthy.
        assert sanitized, "Dict with keys is truthy"

        # Now the genuinely-empty case: an empty dict round-trips cleanly.
        second_flag = write_json({}, cache_path)
        assert not second_flag, "Empty dict is clean"

        roundtrip = load_json(cache_path)
        assert roundtrip is not None, "Empty dict should not be None"
        assert roundtrip == {}, "Empty dict should remain empty"
        assert not roundtrip, "Empty dict evaluates to False (the critical check)"
    finally:
        # Always remove the temporary cache file.
        os.unlink(cache_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual runner: execute every scenario in order, echoing progress.
    test = TestWriteJsonOptimization()

    # Bound methods carry their own names, so the progress lines printed
    # here are identical to spelling each call out by hand.
    for check in (
        test.test_fast_path_clean_data,
        test.test_slow_path_dirty_data,
        test.test_sanitizing_encoder_removes_surrogates,
        test.test_nested_structure_sanitization,
        test.test_unicode_non_characters_removed,
        test.test_mixed_clean_dirty_data,
        test.test_empty_and_none_strings,
        test.test_specific_surrogate_udc9a,
        test.test_migration_with_surrogate_sanitization,
        test.test_empty_values_after_sanitization,
    ):
        print(f"Running {check.__name__}...")
        check()
        print("✓ Passed")

    print("\n✅ All tests passed!")
|
||||||
Loading…
Add table
Reference in a new issue