"""
Utility functions for the LightRAG API.
"""

import argparse
import os
import secrets
import sys

from ascii_colors import ASCIIColors
from fastapi import HTTPException, Request, Security, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN

from lightrag import __version__ as core_version
from lightrag.api import __api_version__ as api_version
from lightrag.constants import (
    DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
)

from .auth import auth_handler
from .config import get_env_value, global_args, ollama_server_infos


def check_env_file() -> bool:
    """
    Check if .env file exists and handle user confirmation if needed.

    When `.env` is missing from the current working directory a warning is
    printed. If stdin is an interactive terminal the user is asked whether to
    continue; any answer other than "yes" (case-insensitive) cancels startup.
    In a non-interactive session the warning is printed and startup continues.

    Returns:
        bool: True if startup should continue, False if it should exit.
    """
    if not os.path.exists('.env'):
        warning_msg = 'Warning: Startup directory must contain .env file for multi-instance support.'
        ASCIIColors.yellow(warning_msg)

        # Only prompt when a human can actually answer
        if sys.stdin.isatty():
            response = input('Do you want to continue? (yes/no): ')
            if response.lower() != 'yes':
                ASCIIColors.red('Server startup cancelled')
                return False
    return True


# Get whitelist paths from global_args, only once during initialization
whitelist_paths = global_args.whitelist_paths.split(',')

# Pre-compile path matching patterns as (path_or_prefix, is_prefix_match) tuples
whitelist_patterns: list[tuple[str, bool]] = []
for path in whitelist_paths:
    path = path.strip()
    if path:
        # If path ends with /*, match all paths under that prefix
        if path.endswith('/*'):
            prefix = path[:-2]
            whitelist_patterns.append((prefix, True))  # (prefix, is_prefix_match)
        else:
            whitelist_patterns.append((path, False))  # (exact_path, is_prefix_match)

# Global authentication configuration
auth_configured = bool(auth_handler.accounts)


def _is_whitelisted(path: str) -> bool:
    """Return True if `path` matches an entry of the module-level whitelist.

    Prefix entries (configured as ``/something/*``) match the prefix itself and
    anything below it on a path-segment boundary. A bare ``startswith`` check
    would let ``/api/*`` also whitelist unrelated routes such as ``/apikeys``.
    """
    for pattern, is_prefix in whitelist_patterns:
        if is_prefix:
            if path == pattern or path.startswith(pattern + '/'):
                return True
        elif path == pattern:
            return True
    return False


def get_combined_auth_dependency(api_key: str | None = None):
    """
    Create a combined authentication dependency that implements authentication logic
    based on API key, OAuth2 token, and whitelist paths.

    The returned dependency evaluates a request in this order:
      1. Whitelisted paths are always allowed.
      2. If a bearer token is present it is validated; an unacceptable token
         yields 401 immediately.
      3. If neither account auth nor an API key is configured, everything is allowed.
      4. A matching ``X-API-Key`` header is accepted.
      5. Otherwise the request is rejected with 401 or 403.

    Args:
        api_key (Optional[str]): API key for validation

    Returns:
        Callable: A dependency function that implements the authentication logic
    """
    # whitelist_patterns and auth_configured are already initialized at module level;
    # only api_key_configured depends on the function parameter
    api_key_configured = bool(api_key)

    # Create security dependencies with proper descriptions for Swagger UI
    oauth2_scheme = OAuth2PasswordBearer(
        tokenUrl='login', auto_error=False, description='OAuth2 Password Authentication'
    )

    # If API key is configured, create an API key header security
    api_key_header = None
    if api_key_configured:
        api_key_header = APIKeyHeader(name='X-API-Key', auto_error=False, description='API Key Authentication')

    async def combined_dependency(
        request: Request,
        token: str = Security(oauth2_scheme),
        api_key_header_value: str | None = None if api_key_header is None else Security(api_key_header),
    ):
        # 1. Check if path is in whitelist
        if _is_whitelisted(request.url.path):
            return  # Whitelist path, allow access

        # 2. Validate token first if provided in the request (Ensure 401 error if token is invalid)
        if token:
            try:
                token_info = auth_handler.validate_token(token)
                # Accept guest token if no auth is configured
                if not auth_configured and token_info.get('role') == 'guest':
                    return
                # Accept non-guest token if auth is configured
                if auth_configured and token_info.get('role') != 'guest':
                    return

                # Token validation failed, immediately return 401 error
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail='Invalid token. Please login again.',
                )
            except HTTPException as e:
                # If already a 401 error, re-raise it
                if e.status_code == status.HTTP_401_UNAUTHORIZED:
                    raise
                # For other HTTP errors, continue processing

        # 3. Accept all requests if no API protection is needed
        if not auth_configured and not api_key_configured:
            return

        # 4. Validate API key if provided and API-Key authentication is configured.
        #    compare_digest avoids leaking key contents through comparison timing;
        #    both sides are encoded because it rejects non-ASCII str arguments.
        if (
            api_key_configured
            and api_key_header_value
            and secrets.compare_digest(api_key_header_value.encode('utf-8'), api_key.encode('utf-8'))
        ):
            return  # API key validation successful

        ### Authentication failed ####

        # if password authentication is configured but no token was provided, ensure 401 error
        if auth_configured and not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail='No credentials provided. Please login.',
            )

        # if api key is provided but validation failed
        if api_key_header_value:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail='Invalid API Key',
            )

        # if api_key_configured but not provided
        if api_key_configured and not api_key_header_value:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail='API Key required',
            )

        # Otherwise: refuse access and return 403 error
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail='API Key required or login authentication required.',
        )

    return combined_dependency


def _print_setting(label: str, value: object) -> None:
    """Print one splash-screen entry: white `label` followed by yellow `value` on the same line."""
    ASCIIColors.white(label, end='')
    ASCIIColors.yellow(f'{value}')


def display_splash_screen(args: argparse.Namespace) -> None:
    """
    Display a colorful splash screen showing LightRAG server configuration

    Prints the banner, the server / directory / LLM / embedding / RAG / storage
    settings read from `args`, the URLs the server can be reached at, and
    security notices when an API key or auth accounts are configured.
    Finally flushes stdout so the output reaches the system log.

    Args:
        args: Parsed command line arguments
    """
    # Banner
    top_border = '╔══════════════════════════════════════════════════════════════╗'
    bottom_border = '╚══════════════════════════════════════════════════════════════╝'
    width = len(top_border) - 4  # width inside the borders

    line1_text = f'LightRAG Server v{core_version}/{api_version}'
    line2_text = 'Fast, Lightweight RAG Server Implementation'

    line1 = f'║ {line1_text.center(width)} ║'
    line2 = f'║ {line2_text.center(width)} ║'

    banner = f"""
    {top_border}
    {line1}
    {line2}
    {bottom_border}
    """
    ASCIIColors.cyan(banner)

    # Server Configuration
    ASCIIColors.magenta('\n📡 Server Configuration:')
    _print_setting(' ├─ Host: ', args.host)
    _print_setting(' ├─ Port: ', args.port)
    _print_setting(' ├─ Workers: ', args.workers)
    _print_setting(' ├─ Timeout: ', args.timeout)
    _print_setting(' ├─ CORS Origins: ', args.cors_origins)
    _print_setting(' ├─ SSL Enabled: ', args.ssl)
    if args.ssl:
        _print_setting(' ├─ SSL Cert: ', args.ssl_certfile)
        _print_setting(' ├─ SSL Key: ', args.ssl_keyfile)
    _print_setting(' ├─ Ollama Emulating Model: ', ollama_server_infos.LIGHTRAG_MODEL)
    _print_setting(' ├─ Log Level: ', args.log_level)
    _print_setting(' ├─ Verbose Debug: ', args.verbose)
    # Only report whether secrets are configured, never their values
    _print_setting(' ├─ API Key: ', 'Set' if args.key else 'Not Set')
    _print_setting(' └─ JWT Auth: ', 'Enabled' if args.auth_accounts else 'Disabled')

    # Directory Configuration
    ASCIIColors.magenta('\n📂 Directory Configuration:')
    _print_setting(' ├─ Working Directory: ', args.working_dir)
    _print_setting(' └─ Input Directory: ', args.input_dir)

    # LLM Configuration
    ASCIIColors.magenta('\n🤖 LLM Configuration:')
    _print_setting(' ├─ Binding: ', args.llm_binding)
    _print_setting(' ├─ Host: ', args.llm_binding_host)
    _print_setting(' ├─ Model: ', args.llm_model)
    _print_setting(' ├─ Max Async for LLM: ', args.max_async)
    _print_setting(' ├─ Summary Context Size: ', args.summary_context_size)
    _print_setting(' ├─ LLM Cache Enabled: ', args.enable_llm_cache)
    _print_setting(' └─ LLM Cache for Extraction Enabled: ', args.enable_llm_cache_for_extract)

    # Embedding Configuration
    ASCIIColors.magenta('\n📊 Embedding Configuration:')
    _print_setting(' ├─ Binding: ', args.embedding_binding)
    _print_setting(' ├─ Host: ', args.embedding_binding_host)
    _print_setting(' ├─ Model: ', args.embedding_model)
    _print_setting(' └─ Dimensions: ', args.embedding_dim)

    # RAG Configuration
    ASCIIColors.magenta('\n⚙️ RAG Configuration:')
    _print_setting(' ├─ Summary Language: ', args.summary_language)
    _print_setting(' ├─ Entity Types: ', args.entity_types)
    _print_setting(' ├─ Max Parallel Insert: ', args.max_parallel_insert)
    _print_setting(' ├─ Chunk Size: ', args.chunk_size)
    _print_setting(' ├─ Chunk Overlap Size: ', args.chunk_overlap_size)
    _print_setting(' ├─ Cosine Threshold: ', args.cosine_threshold)
    _print_setting(' ├─ Top-K: ', args.top_k)
    _print_setting(
        ' └─ Force LLM Summary on Merge: ',
        get_env_value('FORCE_LLM_SUMMARY_ON_MERGE', DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int),
    )

    # Storage Configuration
    ASCIIColors.magenta('\n💾 Storage Configuration:')
    _print_setting(' ├─ KV Storage: ', args.kv_storage)
    _print_setting(' ├─ Vector Storage: ', args.vector_storage)
    _print_setting(' ├─ Graph Storage: ', args.graph_storage)
    _print_setting(' ├─ Document Status Storage: ', args.doc_status_storage)
    _print_setting(' └─ Workspace: ', args.workspace if args.workspace else '-')

    # Server Status
    ASCIIColors.green('\n✨ Server starting up...\n')

    # Server Access Information
    protocol = 'https' if args.ssl else 'http'
    ASCIIColors.magenta('\n🌐 Server Access Information:')
    if args.host == '0.0.0.0':
        # Bound to all interfaces: show localhost URLs plus a placeholder for the
        # machine's own address (an empty host would print a malformed URL)
        _print_setting(' ├─ WebUI (local): ', f'{protocol}://localhost:{args.port}')
        _print_setting(' ├─ Remote Access: ', f'{protocol}://<your_ip_address>:{args.port}')
        _print_setting(' ├─ API Documentation (local): ', f'{protocol}://localhost:{args.port}/docs')
        _print_setting(' └─ Alternative Documentation (local): ', f'{protocol}://localhost:{args.port}/redoc')

        ASCIIColors.magenta('\n📝 Note:')
        ASCIIColors.cyan("""    Since the server is running on 0.0.0.0:
    - Use 'localhost' or '127.0.0.1' for local access
    - Use your machine's IP address for remote access
    - To find your IP address:
      • Windows: Run 'ipconfig' in terminal
      • Linux/Mac: Run 'ifconfig' or 'ip addr' in terminal
""")
    else:
        base_url = f'{protocol}://{args.host}:{args.port}'
        _print_setting(' ├─ WebUI (local): ', base_url)
        _print_setting(' ├─ API Documentation: ', f'{base_url}/docs')
        _print_setting(' └─ Alternative Documentation: ', f'{base_url}/redoc')

    # Security Notice
    if args.key:
        ASCIIColors.yellow('\n⚠️ Security Notice:')
        ASCIIColors.white("""    API Key authentication is enabled.
    Make sure to include the X-API-Key header in all your requests.
""")
    if args.auth_accounts:
        ASCIIColors.yellow('\n⚠️ Security Notice:')
        ASCIIColors.white("""    JWT authentication is enabled.
    Make sure to login before making the request, and include the 'Authorization' in the header.
""")

    # Ensure splash output flush to system log
    sys.stdout.flush()