LightRAG/lightrag/api/utils_api.py
clssck 69358d830d test(lightrag,examples,api): comprehensive ruff formatting and type hints
Format entire codebase with ruff and add type hints across all modules:
- Apply ruff formatting to all Python files (121 files, 17K insertions)
- Add type hints to function signatures throughout lightrag core and API
- Update test suite with improved type annotations and docstrings
- Add pyrightconfig.json for static type checking configuration
- Create prompt_optimized.py and test_extraction_prompt_ab.py test files
- Update ruff.toml and .gitignore for improved linting configuration
- Standardize code style across examples, reproduce scripts, and utilities
2025-12-05 15:17:06 +01:00

332 lines
14 KiB
Python

"""
Utility functions for the LightRAG API.
"""
import argparse
import os
import sys
from ascii_colors import ASCIIColors
from fastapi import HTTPException, Request, Security, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN
from lightrag import __version__ as core_version
from lightrag.api import __api_version__ as api_version
from lightrag.constants import (
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
)
from .auth import auth_handler
from .config import get_env_value, global_args, ollama_server_infos
def check_env_file() -> bool:
    """
    Check that a .env file exists in the startup directory.

    If the file is missing, a warning is printed. When standard input is an
    interactive terminal the user is then asked whether to continue; any
    answer other than "yes" (case-insensitive, surrounding whitespace
    ignored) cancels startup. End-of-input at the prompt is treated as a
    refusal. In a non-interactive run the function only warns and continues.

    Returns:
        bool: True if startup should continue, False if it should be aborted.
    """
    if os.path.exists('.env'):
        return True

    warning_msg = 'Warning: Startup directory must contain .env file for multi-instance support.'
    ASCIIColors.yellow(warning_msg)

    # Without an interactive terminal nobody can answer the prompt, so just warn.
    if not sys.stdin.isatty():
        return True

    try:
        response = input('Do you want to continue? (yes/no): ')
    except EOFError:
        # Ctrl-D / closed stdin at the prompt: no confirmation was given.
        response = ''

    # Strip so that an accidental trailing space does not reject a "yes".
    if response.strip().lower() != 'yes':
        ASCIIColors.red('Server startup cancelled')
        return False
    return True
# Split the configured whitelist a single time, at import.
whitelist_paths = global_args.whitelist_paths.split(',')

# Pre-computed match table of (pattern, is_prefix_match) pairs:
#   - an entry ending in "/*" is stored without that suffix and flagged True
#   - any other non-empty entry is stored verbatim and flagged False
# Entries that are empty after stripping whitespace are skipped.
whitelist_patterns: list[tuple[str, bool]] = [
    (entry[:-2], True) if entry.endswith('/*') else (entry, False)
    for entry in (raw_entry.strip() for raw_entry in whitelist_paths)
    if entry
]

# Global authentication configuration: True when auth_handler has accounts.
auth_configured = bool(auth_handler.accounts)
def _path_is_whitelisted(path: str, patterns: list[tuple[str, bool]]) -> bool:
    """
    Tell whether a request path matches one of the whitelist patterns.

    Args:
        path: URL path of the incoming request.
        patterns: (pattern, is_prefix_match) pairs. Exact entries must equal
            the path. Prefix entries match the prefix itself and anything
            below it, i.e. the prefix followed by "/".

    Returns:
        bool: True if the path is whitelisted.
    """
    for pattern, is_prefix in patterns:
        if path == pattern:
            return True
        # Match on a path-segment boundary: "/api" (from "/api/*") must cover
        # "/api/chat" but not an unrelated sibling such as "/apikeys".
        if is_prefix and path.startswith(pattern.rstrip('/') + '/'):
            return True
    return False


def get_combined_auth_dependency(api_key: str | None = None):
    """
    Create a combined authentication dependency that implements authentication logic
    based on API key, OAuth2 token, and whitelist paths.

    The returned dependency evaluates a request in this order:
      1. Whitelisted paths are always allowed.
      2. If a bearer token is present it is validated; a guest token is accepted
         only when no accounts are configured, a non-guest token only when
         accounts are configured. Any other validated token yields 401.
      3. If neither accounts nor an API key are configured, the request is allowed.
      4. A matching X-API-Key header is accepted when an API key is configured.
      5. Otherwise the request is rejected with 401 or 403.

    Args:
        api_key (Optional[str]): API key for validation

    Returns:
        Callable: A dependency function that implements the authentication logic

    Raises:
        HTTPException: raised by the returned dependency with status 401 when a
            token is invalid or login credentials are missing, and 403 when the
            API key is missing or wrong.
    """
    # Imported here so this function does not rely on a module-level import.
    import hmac

    # whitelist_patterns and auth_configured are module-level values;
    # only the API-key state depends on the function parameter.
    api_key_configured = bool(api_key)
    expected_key = api_key.encode('utf-8') if api_key else None

    # Create security dependencies with proper descriptions for Swagger UI
    oauth2_scheme = OAuth2PasswordBearer(
        tokenUrl='login', auto_error=False, description='OAuth2 Password Authentication'
    )

    # If API key is configured, create an API key header security
    api_key_header = None
    if api_key_configured:
        api_key_header = APIKeyHeader(name='X-API-Key', auto_error=False, description='API Key Authentication')

    async def combined_dependency(
        request: Request,
        token: str = Security(oauth2_scheme),
        api_key_header_value: str | None = None if api_key_header is None else Security(api_key_header),
    ):
        # 1. Check if path is in whitelist
        if _path_is_whitelisted(request.url.path, whitelist_patterns):
            return  # Whitelist path, allow access

        # 2. Validate token first if provided in the request (Ensure 401 error if token is invalid)
        if token:
            token_info = None
            try:
                token_info = auth_handler.validate_token(token)
            except HTTPException as e:
                # A 401 from validation is final; any other HTTPException
                # falls through to the remaining checks.
                if e.status_code == status.HTTP_401_UNAUTHORIZED:
                    raise

            if token_info is not None:
                is_guest = token_info.get('role') == 'guest'
                # Accept guest token if no auth is configured
                if not auth_configured and is_guest:
                    return
                # Accept non-guest token if auth is configured
                if auth_configured and not is_guest:
                    return
                # Token role does not fit the configuration, immediately return 401 error
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail='Invalid token. Please login again.',
                )

        # 3. Accept all requests if no API protection is needed
        if not auth_configured and not api_key_configured:
            return

        # 4. Validate API key if provided and API-Key authentication is configured.
        #    compare_digest avoids leaking the key through comparison timing;
        #    both sides are encoded because it rejects non-ASCII str arguments.
        if (
            expected_key is not None
            and api_key_header_value
            and hmac.compare_digest(api_key_header_value.encode('utf-8'), expected_key)
        ):
            return  # API key validation successful

        ### Authentication failed ####
        # Password authentication is configured but no token was sent: 401
        if auth_configured and not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail='No credentials provided. Please login.',
            )

        # An API key was sent but did not validate
        if api_key_header_value:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail='Invalid API Key',
            )

        # An API key is configured but none was sent
        if api_key_configured and not api_key_header_value:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail='API Key required',
            )

        # Otherwise: refuse access and return 403 error
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail='API Key required or login authentication required.',
        )

    return combined_dependency
def display_splash_screen(args: argparse.Namespace) -> None:
    """
    Print the startup splash: a banner followed by the server, directory, LLM,
    embedding, RAG and storage settings, access URLs and security notices.

    Args:
        args: Parsed command line arguments
    """

    def section(title: str) -> None:
        # Section headings are printed in magenta.
        ASCIIColors.magenta(title)

    def entry(label: str, value: object) -> None:
        # One settings row: white label without a newline, then the yellow value.
        ASCIIColors.white(label, end='')
        ASCIIColors.yellow(f'{value}')

    # Banner
    top_border = '╔══════════════════════════════════════════════════════════════╗'
    bottom_border = '╚══════════════════════════════════════════════════════════════╝'
    inner_width = len(top_border) - 4  # width inside the borders
    title_line = f'LightRAG Server v{core_version}/{api_version}'.center(inner_width)
    tagline_line = 'Fast, Lightweight RAG Server Implementation'.center(inner_width)
    ASCIIColors.cyan(f'\n{top_border}\n{title_line}\n{tagline_line}\n{bottom_border}\n')

    # Server Configuration
    section('\n📡 Server Configuration:')
    entry(' ├─ Host: ', args.host)
    entry(' ├─ Port: ', args.port)
    entry(' ├─ Workers: ', args.workers)
    entry(' ├─ Timeout: ', args.timeout)
    entry(' ├─ CORS Origins: ', args.cors_origins)
    entry(' ├─ SSL Enabled: ', args.ssl)
    if args.ssl:
        entry(' ├─ SSL Cert: ', args.ssl_certfile)
        entry(' ├─ SSL Key: ', args.ssl_keyfile)
    entry(' ├─ Ollama Emulating Model: ', ollama_server_infos.LIGHTRAG_MODEL)
    entry(' ├─ Log Level: ', args.log_level)
    entry(' ├─ Verbose Debug: ', args.verbose)
    entry(' ├─ API Key: ', 'Set' if args.key else 'Not Set')
    entry(' └─ JWT Auth: ', 'Enabled' if args.auth_accounts else 'Disabled')

    # Directory Configuration
    section('\n📂 Directory Configuration:')
    entry(' ├─ Working Directory: ', args.working_dir)
    entry(' └─ Input Directory: ', args.input_dir)

    # LLM Configuration
    section('\n🤖 LLM Configuration:')
    entry(' ├─ Binding: ', args.llm_binding)
    entry(' ├─ Host: ', args.llm_binding_host)
    entry(' ├─ Model: ', args.llm_model)
    entry(' ├─ Max Async for LLM: ', args.max_async)
    entry(' ├─ Summary Context Size: ', args.summary_context_size)
    entry(' ├─ LLM Cache Enabled: ', args.enable_llm_cache)
    entry(' └─ LLM Cache for Extraction Enabled: ', args.enable_llm_cache_for_extract)

    # Embedding Configuration
    section('\n📊 Embedding Configuration:')
    entry(' ├─ Binding: ', args.embedding_binding)
    entry(' ├─ Host: ', args.embedding_binding_host)
    entry(' ├─ Model: ', args.embedding_model)
    entry(' └─ Dimensions: ', args.embedding_dim)

    # RAG Configuration
    section('\n⚙️ RAG Configuration:')
    entry(' ├─ Summary Language: ', args.summary_language)
    entry(' ├─ Entity Types: ', args.entity_types)
    entry(' ├─ Max Parallel Insert: ', args.max_parallel_insert)
    entry(' ├─ Chunk Size: ', args.chunk_size)
    entry(' ├─ Chunk Overlap Size: ', args.chunk_overlap_size)
    entry(' ├─ Cosine Threshold: ', args.cosine_threshold)
    entry(' ├─ Top-K: ', args.top_k)
    entry(
        ' └─ Force LLM Summary on Merge: ',
        get_env_value('FORCE_LLM_SUMMARY_ON_MERGE', DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int),
    )

    # Storage Configuration
    section('\n💾 Storage Configuration:')
    entry(' ├─ KV Storage: ', args.kv_storage)
    entry(' ├─ Vector Storage: ', args.vector_storage)
    entry(' ├─ Graph Storage: ', args.graph_storage)
    entry(' ├─ Document Status Storage: ', args.doc_status_storage)
    entry(' └─ Workspace: ', args.workspace or '-')

    # Server Status
    ASCIIColors.green('\n✨ Server starting up...\n')

    # Server Access Information
    protocol = 'https' if args.ssl else 'http'
    section('\n🌐 Server Access Information:')
    if args.host == '0.0.0.0':
        # Bound to all interfaces: show localhost URLs plus a remote-access hint.
        local_url = f'{protocol}://localhost:{args.port}'
        entry(' ├─ WebUI (local): ', local_url)
        entry(' ├─ Remote Access: ', f'{protocol}://<your-ip-address>:{args.port}')
        entry(' ├─ API Documentation (local): ', f'{local_url}/docs')
        entry(' └─ Alternative Documentation (local): ', f'{local_url}/redoc')

        section('\n📝 Note:')
        ASCIIColors.cyan(
            ' Since the server is running on 0.0.0.0:\n'
            "- Use 'localhost' or '127.0.0.1' for local access\n"
            "- Use your machine's IP address for remote access\n"
            '- To find your IP address:\n'
            "• Windows: Run 'ipconfig' in terminal\n"
            "• Linux/Mac: Run 'ifconfig' or 'ip addr' in terminal\n"
        )
    else:
        base_url = f'{protocol}://{args.host}:{args.port}'
        entry(' ├─ WebUI (local): ', base_url)
        entry(' ├─ API Documentation: ', f'{base_url}/docs')
        entry(' └─ Alternative Documentation: ', f'{base_url}/redoc')

    # Security Notices
    if args.key:
        ASCIIColors.yellow('\n⚠️ Security Notice:')
        ASCIIColors.white(
            ' API Key authentication is enabled.\n'
            'Make sure to include the X-API-Key header in all your requests.\n'
        )
    if args.auth_accounts:
        ASCIIColors.yellow('\n⚠️ Security Notice:')
        ASCIIColors.white(
            ' JWT authentication is enabled.\n'
            "Make sure to login before making the request, and include the 'Authorization' in the header.\n"
        )

    # Ensure splash output flush to system log
    sys.stdout.flush()