Format entire codebase with ruff and add type hints across all modules: - Apply ruff formatting to all Python files (121 files, 17K insertions) - Add type hints to function signatures throughout lightrag core and API - Update test suite with improved type annotations and docstrings - Add pyrightconfig.json for static type checking configuration - Create prompt_optimized.py and test_extraction_prompt_ab.py test files - Update ruff.toml and .gitignore for improved linting configuration - Standardize code style across examples, reproduce scripts, and utilities
332 lines
14 KiB
Python
332 lines
14 KiB
Python
"""
|
|
Utility functions for the LightRAG API.
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
|
|
from ascii_colors import ASCIIColors
|
|
from fastapi import HTTPException, Request, Security, status
|
|
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
|
from starlette.status import HTTP_403_FORBIDDEN
|
|
|
|
from lightrag import __version__ as core_version
|
|
from lightrag.api import __api_version__ as api_version
|
|
from lightrag.constants import (
|
|
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
|
|
)
|
|
|
|
from .auth import auth_handler
|
|
from .config import get_env_value, global_args, ollama_server_infos
|
|
|
|
|
|
def check_env_file() -> bool:
    """
    Check whether a ``.env`` file exists in the current working directory.

    If it is missing, a warning is printed. When stdin is an interactive
    terminal the user is asked whether startup should continue; any answer
    other than ``yes`` (case-insensitive, surrounding whitespace ignored)
    cancels startup. Ending the prompt with EOF (e.g. Ctrl-D) counts as a
    refusal instead of crashing with ``EOFError``.

    Returns:
        bool: True if server startup should continue, False if it should exit.
    """
    if os.path.exists('.env'):
        return True

    warning_msg = 'Warning: Startup directory must contain .env file for multi-instance support.'
    ASCIIColors.yellow(warning_msg)

    # Only prompt when stdin is an interactive terminal; otherwise continue
    # with just the warning. ``sys.stdin`` can be None when the interpreter
    # has no standard input attached, so guard before calling isatty().
    if sys.stdin is not None and sys.stdin.isatty():
        try:
            response = input('Do you want to continue? (yes/no): ')
        except EOFError:
            response = ''
        if response.strip().lower() != 'yes':
            ASCIIColors.red('Server startup cancelled')
            return False

    return True
|
|
|
|
|
|
def _compile_whitelist_patterns(raw_paths: list[str]) -> list[tuple[str, bool]]:
    """
    Turn whitelist entries into ``(pattern, is_prefix_match)`` tuples.

    Each entry has surrounding whitespace stripped, and empty entries are
    skipped. An entry ending in ``/*`` becomes a prefix pattern with that
    two-character suffix removed (``'/api/*'`` -> ``('/api', True)``). Any
    other entry is kept as an exact-match pattern (``'/health'`` ->
    ``('/health', False)``).
    """
    patterns: list[tuple[str, bool]] = []
    for entry in raw_paths:
        entry = entry.strip()
        if not entry:
            continue
        if entry.endswith('/*'):
            patterns.append((entry[:-2], True))  # (prefix, is_prefix_match)
        else:
            patterns.append((entry, False))  # (exact_path, is_prefix_match)
    return patterns


# Get whitelist paths from global_args, only once during initialization
whitelist_paths = global_args.whitelist_paths.split(',')

# Pre-compile path matching patterns. Doing this inside a helper keeps the
# loop temporaries out of this module's namespace.
whitelist_patterns: list[tuple[str, bool]] = _compile_whitelist_patterns(whitelist_paths)

# Global authentication configuration: True when auth_handler has any accounts
auth_configured = bool(auth_handler.accounts)
|
|
|
|
|
|
def _path_is_whitelisted(path: str, patterns: list[tuple[str, bool]]) -> bool:
    """
    Return True if ``path`` matches one of the ``(pattern, is_prefix)`` entries.

    Exact patterns must equal the path. A prefix pattern matches the prefix
    itself or anything below it on a ``/`` boundary. For example, the prefix
    built from ``/api/*`` allows ``/api`` and ``/api/chat`` but not
    ``/apikeys``.
    """
    for pattern, is_prefix in patterns:
        if not is_prefix:
            if path == pattern:
                return True
            continue
        # Match on a path-segment boundary: a bare startswith() would let
        # '/api/*' also whitelist unrelated routes such as '/api-admin'.
        base = pattern.rstrip('/')
        if path == base or path.startswith(base + '/'):
            return True
    return False


def get_combined_auth_dependency(api_key: str | None = None):
    """
    Create a combined authentication dependency that implements authentication logic
    based on API key, OAuth2 token, and whitelist paths.

    The returned coroutine applies these checks in order:

    1. A request whose path matches the module-level ``whitelist_patterns`` is allowed.
    2. If a bearer token is sent, it is validated with ``auth_handler.validate_token``.
       A ``guest`` role token is accepted only when no accounts are configured.
       A non-guest token is accepted only when accounts are configured. Any other
       validated token yields 401. A 401 ``HTTPException`` from the validator is
       propagated. Any other ``HTTPException`` from it is ignored and the
       remaining checks run.
    3. If neither accounts nor an API key are configured, the request is allowed.
    4. An ``X-API-Key`` header equal to ``api_key`` is accepted.

    Any request that passes none of these checks is rejected: 401 when accounts
    are configured and no token was sent, otherwise 403.

    Args:
        api_key (str | None): Expected API key. A falsy value disables API-key auth.

    Returns:
        Callable: An async FastAPI dependency that returns None when access is
        granted and raises ``HTTPException`` otherwise.
    """
    import secrets  # constant-time comparison for the API key check

    # whitelist_patterns and auth_configured are already computed at module
    # level; only api_key_configured depends on this call's argument.
    api_key_configured = bool(api_key)
    # Encode once so each request only encodes the header value. 'surrogatepass'
    # lets any str encode, and distinct strs still map to distinct bytes, so
    # byte equality matches str equality exactly.
    expected_key = api_key.encode('utf-8', 'surrogatepass') if api_key else b''

    # Create security dependencies with proper descriptions for Swagger UI.
    # auto_error=False makes them yield None instead of raising, so the
    # checks below decide which error (if any) is returned.
    oauth2_scheme = OAuth2PasswordBearer(
        tokenUrl='login', auto_error=False, description='OAuth2 Password Authentication'
    )

    # Only declare the X-API-Key header security when an API key is configured
    api_key_header = None
    if api_key_configured:
        api_key_header = APIKeyHeader(name='X-API-Key', auto_error=False, description='API Key Authentication')

    async def combined_dependency(
        request: Request,
        token: str | None = Security(oauth2_scheme),
        api_key_header_value: str | None = None if api_key_header is None else Security(api_key_header),
    ) -> None:
        # 1. Check if path is in whitelist
        if _path_is_whitelisted(request.url.path, whitelist_patterns):
            return  # Whitelist path, allow access

        # 2. Validate token first if provided in the request (Ensure 401 error if token is invalid)
        if token:
            try:
                token_info = auth_handler.validate_token(token)
                role = token_info.get('role')
                # Accept guest token if no auth is configured.
                # NOTE(review): when an API key is configured but no accounts
                # are, a guest token is accepted here without checking
                # X-API-Key. Confirm this is intended.
                if not auth_configured and role == 'guest':
                    return
                # Accept non-guest token if auth is configured
                if auth_configured and role != 'guest':
                    return

                # Token validation failed, immediately return 401 error
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail='Invalid token. Please login again.',
                )
            except HTTPException as e:
                # If already a 401 error, re-raise it
                if e.status_code == status.HTTP_401_UNAUTHORIZED:
                    raise
                # For other HTTP errors, continue with the checks below

        # 3. Accept all requests if no API protection is needed
        if not auth_configured and not api_key_configured:
            return

        # 4. Validate API key if provided and API-Key authentication is configured.
        # compare_digest takes the same time however much of a guess matches,
        # so response timing cannot be used to recover the key.
        if (
            api_key_configured
            and api_key_header_value
            and secrets.compare_digest(api_key_header_value.encode('utf-8', 'surrogatepass'), expected_key)
        ):
            return  # API key validation successful

        ### Authentication failed ###

        # Password authentication is configured but no token was provided
        if auth_configured and not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail='No credentials provided. Please login.',
            )

        # An API key was provided but did not match
        if api_key_header_value:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail='Invalid API Key',
            )

        # An API key is configured but none was provided
        if api_key_configured and not api_key_header_value:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail='API Key required',
            )

        # Otherwise: refuse access and return 403 error
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail='API Key required or login authentication required.',
        )

    return combined_dependency
|
|
|
|
|
|
def _print_splash_item(label: str, value: object) -> None:
    """Print one splash-screen row: ``label`` in white, then ``value`` in yellow on the same line."""
    ASCIIColors.white(label, end='')
    ASCIIColors.yellow(f'{value}')


def display_splash_screen(args: argparse.Namespace) -> None:
    """
    Display a colorful splash screen showing LightRAG server configuration.

    Prints a version banner, then one section per configuration area (server,
    directories, LLM, embedding, RAG, storage), then access URLs and security
    notices. Finally stdout is flushed.

    Args:
        args: Parsed command line arguments
    """
    # Banner
    top_border = '╔══════════════════════════════════════════════════════════════╗'
    bottom_border = '╚══════════════════════════════════════════════════════════════╝'
    width = len(top_border) - 4  # width inside the borders

    line1_text = f'LightRAG Server v{core_version}/{api_version}'
    line2_text = 'Fast, Lightweight RAG Server Implementation'

    line1 = f'║ {line1_text.center(width)} ║'
    line2 = f'║ {line2_text.center(width)} ║'

    banner = f'\n{top_border}\n{line1}\n{line2}\n{bottom_border}\n'
    ASCIIColors.cyan(banner)

    # Server Configuration
    ASCIIColors.magenta('\n📡 Server Configuration:')
    _print_splash_item(' ├─ Host: ', args.host)
    _print_splash_item(' ├─ Port: ', args.port)
    _print_splash_item(' ├─ Workers: ', args.workers)
    _print_splash_item(' ├─ Timeout: ', args.timeout)
    _print_splash_item(' ├─ CORS Origins: ', args.cors_origins)
    _print_splash_item(' ├─ SSL Enabled: ', args.ssl)
    if args.ssl:
        _print_splash_item(' ├─ SSL Cert: ', args.ssl_certfile)
        _print_splash_item(' ├─ SSL Key: ', args.ssl_keyfile)
    _print_splash_item(' ├─ Ollama Emulating Model: ', ollama_server_infos.LIGHTRAG_MODEL)
    _print_splash_item(' ├─ Log Level: ', args.log_level)
    _print_splash_item(' ├─ Verbose Debug: ', args.verbose)
    _print_splash_item(' ├─ API Key: ', 'Set' if args.key else 'Not Set')
    _print_splash_item(' └─ JWT Auth: ', 'Enabled' if args.auth_accounts else 'Disabled')

    # Directory Configuration
    ASCIIColors.magenta('\n📂 Directory Configuration:')
    _print_splash_item(' ├─ Working Directory: ', args.working_dir)
    _print_splash_item(' └─ Input Directory: ', args.input_dir)

    # LLM Configuration
    ASCIIColors.magenta('\n🤖 LLM Configuration:')
    _print_splash_item(' ├─ Binding: ', args.llm_binding)
    _print_splash_item(' ├─ Host: ', args.llm_binding_host)
    _print_splash_item(' ├─ Model: ', args.llm_model)
    _print_splash_item(' ├─ Max Async for LLM: ', args.max_async)
    _print_splash_item(' ├─ Summary Context Size: ', args.summary_context_size)
    _print_splash_item(' ├─ LLM Cache Enabled: ', args.enable_llm_cache)
    _print_splash_item(' └─ LLM Cache for Extraction Enabled: ', args.enable_llm_cache_for_extract)

    # Embedding Configuration
    ASCIIColors.magenta('\n📊 Embedding Configuration:')
    _print_splash_item(' ├─ Binding: ', args.embedding_binding)
    _print_splash_item(' ├─ Host: ', args.embedding_binding_host)
    _print_splash_item(' ├─ Model: ', args.embedding_model)
    _print_splash_item(' └─ Dimensions: ', args.embedding_dim)

    # RAG Configuration
    ASCIIColors.magenta('\n⚙️ RAG Configuration:')
    _print_splash_item(' ├─ Summary Language: ', args.summary_language)
    _print_splash_item(' ├─ Entity Types: ', args.entity_types)
    _print_splash_item(' ├─ Max Parallel Insert: ', args.max_parallel_insert)
    _print_splash_item(' ├─ Chunk Size: ', args.chunk_size)
    _print_splash_item(' ├─ Chunk Overlap Size: ', args.chunk_overlap_size)
    _print_splash_item(' ├─ Cosine Threshold: ', args.cosine_threshold)
    _print_splash_item(' ├─ Top-K: ', args.top_k)
    _print_splash_item(
        ' └─ Force LLM Summary on Merge: ',
        get_env_value('FORCE_LLM_SUMMARY_ON_MERGE', DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int),
    )

    # Storage Configuration
    ASCIIColors.magenta('\n💾 Storage Configuration:')
    _print_splash_item(' ├─ KV Storage: ', args.kv_storage)
    _print_splash_item(' ├─ Vector Storage: ', args.vector_storage)
    _print_splash_item(' ├─ Graph Storage: ', args.graph_storage)
    _print_splash_item(' ├─ Document Status Storage: ', args.doc_status_storage)
    _print_splash_item(' └─ Workspace: ', args.workspace if args.workspace else '-')

    # Server Status
    ASCIIColors.green('\n✨ Server starting up...\n')

    # Server Access Information
    protocol = 'https' if args.ssl else 'http'
    if args.host == '0.0.0.0':
        ASCIIColors.magenta('\n🌐 Server Access Information:')
        _print_splash_item(' ├─ WebUI (local): ', f'{protocol}://localhost:{args.port}')
        _print_splash_item(' ├─ Remote Access: ', f'{protocol}://<your-ip-address>:{args.port}')
        _print_splash_item(' ├─ API Documentation (local): ', f'{protocol}://localhost:{args.port}/docs')
        _print_splash_item(' └─ Alternative Documentation (local): ', f'{protocol}://localhost:{args.port}/redoc')

        ASCIIColors.magenta('\n📝 Note:')
        ASCIIColors.cyan(
            ' Since the server is running on 0.0.0.0:\n'
            "- Use 'localhost' or '127.0.0.1' for local access\n"
            "- Use your machine's IP address for remote access\n"
            '- To find your IP address:\n'
            "• Windows: Run 'ipconfig' in terminal\n"
            "• Linux/Mac: Run 'ifconfig' or 'ip addr' in terminal\n"
        )
    else:
        # IPv6 literals (anything containing ':') must be wrapped in brackets
        # inside a URL (RFC 3986 IP-literal), otherwise the port is ambiguous.
        host = f'{args.host}'
        if ':' in host and not host.startswith('['):
            host = f'[{host}]'
        base_url = f'{protocol}://{host}:{args.port}'
        ASCIIColors.magenta('\n🌐 Server Access Information:')
        _print_splash_item(' ├─ WebUI (local): ', base_url)
        _print_splash_item(' ├─ API Documentation: ', f'{base_url}/docs')
        _print_splash_item(' └─ Alternative Documentation: ', f'{base_url}/redoc')

    # Security Notice
    if args.key:
        ASCIIColors.yellow('\n⚠️ Security Notice:')
        ASCIIColors.white(
            ' API Key authentication is enabled.\n'
            'Make sure to include the X-API-Key header in all your requests.\n'
        )
    if args.auth_accounts:
        ASCIIColors.yellow('\n⚠️ Security Notice:')
        ASCIIColors.white(
            ' JWT authentication is enabled.\n'
            "Make sure to login before making the request, and include the 'Authorization' in the header.\n"
        )

    # Ensure splash output flush to system log
    sys.stdout.flush()
|