LightRAG/lightrag/api/lightrag_server.py

"""
LightRAG FastAPI Server
"""
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.openapi.docs import (
get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html,
)
import os
import logging
import logging.config
import sys
import uvicorn
import pipmaster as pm
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from pathlib import Path
import configparser
from ascii_colors import ASCIIColors
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from lightrag.api.utils_api import (
get_combined_auth_dependency,
display_splash_screen,
check_env_file,
)
from .config import (
global_args,
update_uvicorn_mode_config,
get_default_host,
)
from lightrag.utils import get_env_value
from lightrag import LightRAG, __version__ as core_version
from lightrag.api import __api_version__
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.utils import EmbeddingFunc
from lightrag.constants import (
DEFAULT_LOG_MAX_BYTES,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_FILENAME,
DEFAULT_LLM_TIMEOUT,
DEFAULT_EMBEDDING_TIMEOUT,
)
from lightrag.api.routers.document_routes import (
DocumentManager,
create_document_routes,
)
from lightrag.api.routers.query_routes import create_query_routes
from lightrag.api.routers.graph_routes import create_graph_routes
from lightrag.api.routers.ollama_api import OllamaAPI
from lightrag.utils import logger, set_verbose_debug
from lightrag.kg.shared_storage import (
get_namespace_data,
initialize_pipeline_status,
cleanup_keyed_lock,
finalize_share_data,
)
from fastapi.security import OAuth2PasswordRequestForm
from lightrag.api.auth import auth_handler
# Use the .env file in the current working directory.
# This allows each LightRAG instance to use its own .env file.
# OS environment variables take precedence over values from the .env file.
load_dotenv(dotenv_path=".env", override=False)
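# Illustrative .env sketch (hypothetical values, not shipped with this file):
#   WEBUI_TITLE=My LightRAG
#   WEBUI_DESCRIPTION=Team knowledge base
#   LIGHTRAG_API_KEY=change-me
# Because override=False above, an already-exported OS variable
# (e.g. `export WEBUI_TITLE=Prod`) wins over the matching .env entry.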
webui_title = os.getenv("WEBUI_TITLE")
webui_description = os.getenv("WEBUI_DESCRIPTION")
# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")
# Global authentication configuration
auth_configured = bool(auth_handler.accounts)
class LLMConfigCache:
"""Smart LLM and Embedding configuration cache class"""
def __init__(self, args):
self.args = args
# Initialize configurations based on binding conditions
self.openai_llm_options = None
self.gemini_llm_options = None
self.ollama_llm_options = None
self.ollama_embedding_options = None
# Only initialize and log OpenAI options when using OpenAI-related bindings
if args.llm_binding in ["openai", "azure_openai"]:
from lightrag.llm.binding_options import OpenAILLMOptions
self.openai_llm_options = OpenAILLMOptions.options_dict(args)
logger.info(f"OpenAI LLM Options: {self.openai_llm_options}")
if args.llm_binding == "gemini":
from lightrag.llm.binding_options import GeminiLLMOptions
self.gemini_llm_options = GeminiLLMOptions.options_dict(args)
logger.info(f"Gemini LLM Options: {self.gemini_llm_options}")
# Only initialize and log Ollama LLM options when using Ollama LLM binding
if args.llm_binding == "ollama":
try:
from lightrag.llm.binding_options import OllamaLLMOptions
self.ollama_llm_options = OllamaLLMOptions.options_dict(args)
logger.info(f"Ollama LLM Options: {self.ollama_llm_options}")
except ImportError:
logger.warning(
"OllamaLLMOptions not available, using default configuration"
)
self.ollama_llm_options = {}
# Only initialize and log Ollama Embedding options when using Ollama Embedding binding
if args.embedding_binding == "ollama":
try:
from lightrag.llm.binding_options import OllamaEmbeddingOptions
self.ollama_embedding_options = OllamaEmbeddingOptions.options_dict(
args
)
logger.info(
f"Ollama Embedding Options: {self.ollama_embedding_options}"
)
except ImportError:
logger.warning(
"OllamaEmbeddingOptions not available, using default configuration"
)
self.ollama_embedding_options = {}
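# LLMConfigCache is instantiated once per create_app() call (config_cache below); the
# cached option dicts are then reused by the optimized LLM/embedding factories so the
# binding options are not re-parsed on every model call.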
def check_frontend_build():
"""Check if frontend is built and optionally check if source is up-to-date
Returns:
bool: True if frontend is outdated, False if up-to-date or production environment
"""
webui_dir = Path(__file__).parent / "webui"
index_html = webui_dir / "index.html"
# 1. Check if build files exist (required)
if not index_html.exists():
ASCIIColors.red("\n" + "=" * 80)
ASCIIColors.red("ERROR: Frontend Not Built")
ASCIIColors.red("=" * 80)
ASCIIColors.yellow("The WebUI frontend has not been built yet.")
ASCIIColors.yellow(
"Please build the frontend code first using the following commands:\n"
)
ASCIIColors.cyan(" cd lightrag_webui")
ASCIIColors.cyan(" bun install --frozen-lockfile")
ASCIIColors.cyan(" bun run build")
ASCIIColors.cyan(" cd ..")
ASCIIColors.yellow("\nThen restart the service.\n")
ASCIIColors.cyan(
"Note: Make sure you have Bun installed. Visit https://bun.sh for installation."
)
ASCIIColors.red("=" * 80 + "\n")
sys.exit(1) # Exit immediately
# 2. Check if this is a development environment (source directory exists)
try:
source_dir = Path(__file__).parent.parent.parent / "lightrag_webui"
src_dir = source_dir / "src"
# Determine if this is a development environment: source directory exists and contains src directory
if not source_dir.exists() or not src_dir.exists():
# Production environment, skip source code check
logger.debug(
"Production environment detected, skipping source freshness check"
)
return False
# Development environment, perform source code timestamp check
logger.debug("Development environment detected, checking source freshness")
# Source code file extensions (files to check)
source_extensions = {
".ts",
".tsx",
".js",
".jsx",
".mjs",
".cjs", # TypeScript/JavaScript
".css",
".scss",
".sass",
".less", # Style files
".json",
".jsonc", # Configuration/data files
".html",
".htm", # Template files
".md",
".mdx", # Markdown
}
# Key configuration files (in lightrag_webui root directory)
key_files = [
source_dir / "package.json",
source_dir / "bun.lock",
source_dir / "vite.config.ts",
source_dir / "tsconfig.json",
source_dir / "tailraid.config.js",
source_dir / "index.html",
]
# Get the latest modification time of source code
latest_source_time = 0
# Check source code files in src directory
for file_path in src_dir.rglob("*"):
if file_path.is_file():
# Only check source code files, ignore temporary files and logs
if file_path.suffix.lower() in source_extensions:
mtime = file_path.stat().st_mtime
latest_source_time = max(latest_source_time, mtime)
# Check key configuration files
for key_file in key_files:
if key_file.exists():
mtime = key_file.stat().st_mtime
latest_source_time = max(latest_source_time, mtime)
# Get build time
build_time = index_html.stat().st_mtime
# Compare timestamps (5 second tolerance to avoid file system time precision issues)
if latest_source_time > build_time + 5:
ASCIIColors.yellow("\n" + "=" * 80)
ASCIIColors.yellow("WARNING: Frontend Source Code Has Been Updated")
ASCIIColors.yellow("=" * 80)
ASCIIColors.yellow(
"The frontend source code is newer than the current build."
)
ASCIIColors.yellow(
"This might happen after 'git pull' or manual code changes.\n"
)
ASCIIColors.cyan(
"Recommended: Rebuild the frontend to use the latest changes:"
)
ASCIIColors.cyan(" cd lightrag_webui")
ASCIIColors.cyan(" bun install --frozen-lockfile")
ASCIIColors.cyan(" bun run build")
ASCIIColors.cyan(" cd ..")
ASCIIColors.yellow("\nThe server will continue with the current build.")
ASCIIColors.yellow("=" * 80 + "\n")
return True # Frontend is outdated
else:
logger.info("Frontend build is up-to-date")
return False # Frontend is up-to-date
except Exception as e:
# If check fails, log warning but don't affect startup
logger.warning(f"Failed to check frontend source freshness: {e}")
return False # Assume up-to-date on error
def create_app(args):
# Check frontend build first and get outdated status
is_frontend_outdated = check_frontend_build()
# Create unified API version display with warning symbol if frontend is outdated
api_version_display = (
f"{__api_version__}⚠️" if is_frontend_outdated else __api_version__
)
# Setup logging
logger.setLevel(args.log_level)
set_verbose_debug(args.verbose)
# Create configuration cache (this will output configuration logs)
config_cache = LLMConfigCache(args)
# Verify that bindings are correctly setup
if args.llm_binding not in [
"lollms",
"ollama",
"openai",
"azure_openai",
"aws_bedrock",
"gemini",
]:
raise Exception("llm binding not supported")
if args.embedding_binding not in [
"lollms",
"ollama",
"openai",
"azure_openai",
"aws_bedrock",
"jina",
]:
raise Exception("embedding binding not supported")
# Set default hosts if not provided
if args.llm_binding_host is None:
args.llm_binding_host = get_default_host(args.llm_binding)
if args.embedding_binding_host is None:
args.embedding_binding_host = get_default_host(args.embedding_binding)
# Add SSL validation
if args.ssl:
if not args.ssl_certfile or not args.ssl_keyfile:
raise Exception(
"SSL certificate and key files must be provided when SSL is enabled"
)
if not os.path.exists(args.ssl_certfile):
raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
if not os.path.exists(args.ssl_keyfile):
raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize document manager with workspace support for data isolation
doc_manager = DocumentManager(args.input_dir, workspace=args.workspace)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events"""
# Store background tasks
app.state.background_tasks = set()
try:
# Initialize database connections
await rag.initialize_storages()
await initialize_pipeline_status()
# Data migration regardless of storage implementation
await rag.check_and_migrate_data()
ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
yield
finally:
# Clean up database connections
await rag.finalize_storages()
if "LIGHTRAG_GUNICORN_MODE" not in os.environ:
# Only perform cleanup in Uvicorn single-process mode
logger.debug("Unvicorn Mode: finalizing shared storage...")
finalize_share_data()
else:
# In Gunicorn mode with preload_app=True, cleanup is handled by on_exit hooks
logger.debug(
"Gunicorn Mode: postpone shared storage finalization to master process"
)
# Initialize FastAPI
base_description = (
"Providing API for LightRAG core, Web UI and Ollama Model Emulation"
)
swagger_description = (
base_description
+ (" (API-Key Enabled)" if api_key else "")
+ "\n\n[View ReDoc documentation](/redoc)"
)
app_kwargs = {
"title": "LightRAG Server API",
"description": swagger_description,
"version": __api_version__,
"openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
"docs_url": None, # Disable default docs, we'll create custom endpoint
"redoc_url": "/redoc", # Explicitly set redoc URL
"lifespan": lifespan,
}
# Configure Swagger UI parameters
# Enable persistAuthorization and tryItOutEnabled for better user experience
app_kwargs["swagger_ui_parameters"] = {
"persistAuthorization": True,
"tryItOutEnabled": True,
}
app = FastAPI(**app_kwargs)
# Add custom validation error handler for /query/data endpoint
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request, exc: RequestValidationError
):
# Check if this is a request to /query/data endpoint
if request.url.path.endswith("/query/data"):
# Extract error details
error_details = []
for error in exc.errors():
field_path = " -> ".join(str(loc) for loc in error["loc"])
error_details.append(f"{field_path}: {error['msg']}")
error_message = "; ".join(error_details)
# Return in the expected format for /query/data
return JSONResponse(
status_code=400,
content={
"status": "failure",
"message": f"Validation error: {error_message}",
"data": {},
"metadata": {},
},
)
else:
# For other endpoints, return the default FastAPI validation error
return JSONResponse(status_code=422, content={"detail": exc.errors()})
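# Illustrative failure payload for /query/data produced by the handler above
# (the field name "query" is hypothetical; the envelope shape comes from the code):
#   {"status": "failure",
#    "message": "Validation error: body -> query: field required",
#    "data": {}, "metadata": {}}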
def get_cors_origins():
"""Get allowed origins from global_args
Returns a list of allowed origins, defaults to ["*"] if not set
"""
origins_str = global_args.cors_origins
if origins_str == "*":
return ["*"]
return [origin.strip() for origin in origins_str.split(",")]
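# Example (assumed env var name; get_cors_origins only reads global_args.cors_origins):
#   CORS_ORIGINS="http://localhost:5173,https://rag.example.com"
# parses to ["http://localhost:5173", "https://rag.example.com"];
# the default "*" keeps all origins allowed.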
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=get_cors_origins(),
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create combined auth dependency for all endpoints
combined_auth = get_combined_auth_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
def create_optimized_openai_llm_func(
config_cache: LLMConfigCache, args, llm_timeout: int
):
"""Create optimized OpenAI LLM function with pre-processed configuration"""
async def optimized_openai_alike_model_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
from lightrag.llm.openai import openai_complete_if_cache
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
if history_messages is None:
history_messages = []
# Use pre-processed configuration to avoid repeated parsing
kwargs["timeout"] = llm_timeout
if config_cache.openai_llm_options:
kwargs.update(config_cache.openai_llm_options)
return await openai_complete_if_cache(
args.llm_model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url=args.llm_binding_host,
api_key=args.llm_binding_api_key,
**kwargs,
)
return optimized_openai_alike_model_complete
def create_optimized_azure_openai_llm_func(
config_cache: LLMConfigCache, args, llm_timeout: int
):
"""Create optimized Azure OpenAI LLM function with pre-processed configuration"""
async def optimized_azure_openai_model_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
from lightrag.llm.azure_openai import azure_openai_complete_if_cache
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
if history_messages is None:
history_messages = []
# Use pre-processed configuration to avoid repeated parsing
kwargs["timeout"] = llm_timeout
if config_cache.openai_llm_options:
kwargs.update(config_cache.openai_llm_options)
return await azure_openai_complete_if_cache(
args.llm_model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url=args.llm_binding_host,
api_key=os.getenv("AZURE_OPENAI_API_KEY", args.llm_binding_api_key),
api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
**kwargs,
)
return optimized_azure_openai_model_complete
def create_optimized_gemini_llm_func(
config_cache: LLMConfigCache, args, llm_timeout: int
):
"""Create optimized Gemini LLM function with cached configuration"""
async def optimized_gemini_model_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
from lightrag.llm.gemini import gemini_complete_if_cache
if history_messages is None:
history_messages = []
# Use pre-processed configuration to avoid repeated parsing
kwargs["timeout"] = llm_timeout
if (
config_cache.gemini_llm_options is not None
and "generation_config" not in kwargs
):
kwargs["generation_config"] = dict(config_cache.gemini_llm_options)
return await gemini_complete_if_cache(
args.llm_model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=args.llm_binding_api_key,
base_url=args.llm_binding_host,
keyword_extraction=keyword_extraction,
**kwargs,
)
return optimized_gemini_model_complete
def create_llm_model_func(binding: str):
"""
Create LLM model function based on binding type.
Uses optimized functions for OpenAI bindings and lazy import for others.
"""
try:
if binding == "lollms":
from lightrag.llm.lollms import lollms_model_complete
return lollms_model_complete
elif binding == "ollama":
from lightrag.llm.ollama import ollama_model_complete
return ollama_model_complete
elif binding == "aws_bedrock":
return bedrock_model_complete # Already defined locally
elif binding == "azure_openai":
# Use optimized function with pre-processed configuration
return create_optimized_azure_openai_llm_func(
config_cache, args, llm_timeout
)
elif binding == "gemini":
return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
else: # openai and compatible
# Use optimized function with pre-processed configuration
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
except ImportError as e:
raise Exception(f"Failed to import {binding} LLM binding: {e}")
def create_llm_model_kwargs(binding: str, args, llm_timeout: int) -> dict:
"""
Create LLM model kwargs based on binding type.
Uses lazy import for binding-specific options.
"""
if binding in ["lollms", "ollama"]:
try:
from lightrag.llm.binding_options import OllamaLLMOptions
return {
"host": args.llm_binding_host,
"timeout": llm_timeout,
"options": OllamaLLMOptions.options_dict(args),
"api_key": args.llm_binding_api_key,
}
except ImportError as e:
raise Exception(f"Failed to import {binding} options: {e}")
return {}
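# Illustrative result for binding == "ollama" (placeholder host/timeout values, not
# defaults guaranteed by this file):
#   {"host": "http://localhost:11434", "timeout": 240,
#    "options": {...parsed OllamaLLMOptions...}, "api_key": None}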
def create_optimized_embedding_function(
config_cache: LLMConfigCache, binding, model, host, api_key, args
):
"""
Create optimized embedding function with pre-processed configuration for applicable bindings.
Uses lazy imports for all bindings and avoids repeated configuration parsing.
"""
async def optimized_embedding_function(texts, embedding_dim=None):
try:
if binding == "lollms":
from lightrag.llm.lollms import lollms_embed
return await lollms_embed(
texts, embed_model=model, host=host, api_key=api_key
)
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
if config_cache.ollama_embedding_options is not None:
ollama_options = config_cache.ollama_embedding_options
else:
# Fallback for cases where config cache wasn't initialized properly
from lightrag.llm.binding_options import OllamaEmbeddingOptions
ollama_options = OllamaEmbeddingOptions.options_dict(args)
return await ollama_embed(
texts,
embed_model=model,
host=host,
api_key=api_key,
options=ollama_options,
)
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed
return await azure_openai_embed(texts, model=model, api_key=api_key)
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed
return await bedrock_embed(texts, model=model)
elif binding == "jina":
from lightrag.llm.jina import jina_embed
return await jina_embed(
texts,
embedding_dim=embedding_dim,
base_url=host,
api_key=api_key,
)
else: # openai and compatible
from lightrag.llm.openai import openai_embed
return await openai_embed(
texts,
model=model,
base_url=host,
api_key=api_key,
embedding_dim=embedding_dim,
)
except ImportError as e:
raise Exception(f"Failed to import {binding} embedding: {e}")
return optimized_embedding_function
llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
embedding_timeout = get_env_value(
"EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT, int
)
async def bedrock_model_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
# Lazy import
from lightrag.llm.bedrock import bedrock_complete_if_cache
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
if history_messages is None:
history_messages = []
# Use global temperature for Bedrock
kwargs["temperature"] = get_env_value("BEDROCK_LLM_TEMPERATURE", 1.0, float)
return await bedrock_complete_if_cache(
args.llm_model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
# Create embedding function with optimized configuration
import inspect
# Create the optimized embedding function
optimized_embedding_func = create_optimized_embedding_function(
config_cache=config_cache,
binding=args.embedding_binding,
model=args.embedding_model,
host=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
args=args, # Pass args object for fallback option generation
)
# Get embedding_send_dim from centralized configuration
embedding_send_dim = args.embedding_send_dim
# Check if the function signature has embedding_dim parameter
# Note: Since optimized_embedding_func is an async function, inspect its signature
sig = inspect.signature(optimized_embedding_func)
has_embedding_dim_param = "embedding_dim" in sig.parameters
# Determine send_dimensions value based on binding type
# Jina REQUIRES dimension parameter (forced to True)
# OpenAI and others: controlled by EMBEDDING_SEND_DIM environment variable
if args.embedding_binding == "jina":
# Jina API requires dimension parameter - always send it
send_dimensions = has_embedding_dim_param
dimension_control = "forced by Jina API"
else:
# For OpenAI and other bindings, respect EMBEDDING_SEND_DIM setting
send_dimensions = embedding_send_dim and has_embedding_dim_param
if send_dimensions or not embedding_send_dim:
dimension_control = "by env var"
else:
dimension_control = "disabled: function lacks embedding_dim param"
logger.info(
f"Send embedding dimension: {send_dimensions} {dimension_control} "
f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
f"binding={args.embedding_binding})"
)
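# Worked example (illustrative): with EMBEDDING_SEND_DIM enabled but an embedding
# function that exposes no embedding_dim parameter, send_dimensions resolves to False;
# with binding == "jina" it is True whenever the parameter exists, regardless of the env var.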
# Create EmbeddingFunc with send_dimensions attribute
embedding_func = EmbeddingFunc(
embedding_dim=args.embedding_dim,
func=optimized_embedding_func,
send_dimensions=send_dimensions,
)
# Configure rerank function based on the args.rerank_binding parameter
rerank_model_func = None
if args.rerank_binding != "null":
from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank
# Map rerank binding to corresponding function
rerank_functions = {
"cohere": cohere_rerank,
"jina": jina_rerank,
"aliyun": ali_rerank,
}
# Select the appropriate rerank function based on binding
selected_rerank_func = rerank_functions.get(args.rerank_binding)
if not selected_rerank_func:
logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")
# Get default values from selected_rerank_func if args values are None
if args.rerank_model is None or args.rerank_binding_host is None:
sig = inspect.signature(selected_rerank_func)
# Set default model if args.rerank_model is None
if args.rerank_model is None and "model" in sig.parameters:
default_model = sig.parameters["model"].default
if default_model != inspect.Parameter.empty:
args.rerank_model = default_model
# Set default base_url if args.rerank_binding_host is None
if args.rerank_binding_host is None and "base_url" in sig.parameters:
default_base_url = sig.parameters["base_url"].default
if default_base_url != inspect.Parameter.empty:
args.rerank_binding_host = default_base_url
async def server_rerank_func(
query: str, documents: list, top_n: int = None, extra_body: dict = None
):
"""Server rerank function with configuration from environment variables"""
return await selected_rerank_func(
query=query,
documents=documents,
top_n=top_n,
api_key=args.rerank_binding_api_key,
model=args.rerank_model,
base_url=args.rerank_binding_host,
extra_body=extra_body,
)
rerank_model_func = server_rerank_func
logger.info(
f"Reranking is enabled: {args.rerank_model or 'default model'} using {args.rerank_binding} provider"
)
else:
logger.info("Reranking is disabled")
# Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos
ollama_server_infos = OllamaServerInfos(
name=args.simulated_model_name, tag=args.simulated_model_tag
)
# Initialize RAG with unified configuration
try:
rag = LightRAG(
working_dir=args.working_dir,
workspace=args.workspace,
llm_model_func=create_llm_model_func(args.llm_binding),
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
summary_max_tokens=args.summary_max_tokens,
summary_context_size=args.summary_context_size,
chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size),
llm_model_kwargs=create_llm_model_kwargs(
args.llm_binding, args, llm_timeout
),
embedding_func=embedding_func,
default_llm_timeout=llm_timeout,
default_embedding_timeout=embedding_timeout,
kv_storage=args.kv_storage,
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
rerank_model_func=rerank_model_func,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={
"language": args.summary_language,
"entity_types": args.entity_types,
},
ollama_server_infos=ollama_server_infos,
)
except Exception as e:
logger.error(f"Failed to initialize LightRAG: {e}")
raise
# Add routes
app.include_router(
create_document_routes(
rag,
doc_manager,
api_key,
)
)
app.include_router(create_query_routes(rag, api_key, args.top_k))
app.include_router(create_graph_routes(rag, api_key))
# Add Ollama API routes
ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key)
app.include_router(ollama_api.router, prefix="/api")
# Custom Swagger UI endpoint for offline support
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
"""Custom Swagger UI HTML with local static files"""
return get_swagger_ui_html(
openapi_url=app.openapi_url,
title=app.title + " - Swagger UI",
oauth2_redirect_url="/docs/oauth2-redirect",
swagger_js_url="/static/swagger-ui/swagger-ui-bundle.js",
swagger_css_url="/static/swagger-ui/swagger-ui.css",
swagger_favicon_url="/static/swagger-ui/favicon-32x32.png",
swagger_ui_parameters=app.swagger_ui_parameters,
)
@app.get("/docs/oauth2-redirect", include_in_schema=False)
async def swagger_ui_redirect():
"""OAuth2 redirect for Swagger UI"""
return get_swagger_ui_oauth2_redirect_html()
@app.get("/")
async def redirect_to_webui():
"""Redirect root path to /webui"""
return RedirectResponse(url="/webui")
@app.get("/auth-status")
async def get_auth_status():
"""Get authentication status and guest token if auth is not configured"""
if not auth_handler.accounts:
# Authentication not configured, return guest token
guest_token = auth_handler.create_token(
username="guest", role="guest", metadata={"auth_mode": "disabled"}
)
return {
"auth_configured": False,
"access_token": guest_token,
"token_type": "bearer",
"auth_mode": "disabled",
"message": "Authentication is disabled. Using guest access.",
"core_version": core_version,
"api_version": api_version_display,
"webui_title": webui_title,
"webui_description": webui_description,
}
return {
"auth_configured": True,
"auth_mode": "enabled",
"core_version": core_version,
"api_version": api_version_display,
"webui_title": webui_title,
"webui_description": webui_description,
}
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
if not auth_handler.accounts:
# Authentication not configured, return guest token
guest_token = auth_handler.create_token(
username="guest", role="guest", metadata={"auth_mode": "disabled"}
)
return {
"access_token": guest_token,
"token_type": "bearer",
"auth_mode": "disabled",
"message": "Authentication is disabled. Using guest access.",
"core_version": core_version,
"api_version": api_version_display,
"webui_title": webui_title,
"webui_description": webui_description,
}
username = form_data.username
if auth_handler.accounts.get(username) != form_data.password:
raise HTTPException(status_code=401, detail="Incorrect credentials")
# Regular user login
user_token = auth_handler.create_token(
username=username, role="user", metadata={"auth_mode": "enabled"}
)
return {
"access_token": user_token,
"token_type": "bearer",
"auth_mode": "enabled",
"core_version": core_version,
"api_version": api_version_display,
"webui_title": webui_title,
"webui_description": webui_description,
}
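# Hedged usage sketch (host, port and credentials are placeholders): the endpoint expects
# form-encoded fields per OAuth2PasswordRequestForm, e.g.
#   curl -X POST http://<host>:<port>/login -d "username=admin&password=secret"
# When no accounts are configured, the same endpoint returns a guest token instead.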
@app.get("/health", dependencies=[Depends(combined_auth)])
async def get_status():
"""Get current system status"""
try:
pipeline_status = await get_namespace_data("pipeline_status")
if not auth_configured:
auth_mode = "disabled"
else:
auth_mode = "enabled"
# Cleanup expired keyed locks and get status
keyed_lock_info = cleanup_keyed_lock()
return {
"status": "healthy",
"working_directory": str(args.working_dir),
"input_directory": str(args.input_dir),
"configuration": {
# LLM configuration binding/host address (if applicable)/model (if applicable)
"llm_binding": args.llm_binding,
"llm_binding_host": args.llm_binding_host,
"llm_model": args.llm_model,
# embedding model configuration binding/host address (if applicable)/model (if applicable)
"embedding_binding": args.embedding_binding,
"embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model,
"summary_max_tokens": args.summary_max_tokens,
"summary_context_size": args.summary_context_size,
"kv_storage": args.kv_storage,
"doc_status_storage": args.doc_status_storage,
"graph_storage": args.graph_storage,
"vector_storage": args.vector_storage,
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
"enable_llm_cache": args.enable_llm_cache,
"workspace": args.workspace,
"max_graph_nodes": args.max_graph_nodes,
# Rerank configuration
"enable_rerank": rerank_model_func is not None,
"rerank_binding": args.rerank_binding,
"rerank_model": args.rerank_model if rerank_model_func else None,
"rerank_binding_host": args.rerank_binding_host
if rerank_model_func
else None,
# Environment variable status (requested configuration)
"summary_language": args.summary_language,
"force_llm_summary_on_merge": args.force_llm_summary_on_merge,
"max_parallel_insert": args.max_parallel_insert,
"cosine_threshold": args.cosine_threshold,
"min_rerank_score": args.min_rerank_score,
"related_chunk_number": args.related_chunk_number,
"max_async": args.max_async,
"embedding_func_max_async": args.embedding_func_max_async,
"embedding_batch_num": args.embedding_batch_num,
},
"auth_mode": auth_mode,
"pipeline_busy": pipeline_status.get("busy", False),
"keyed_locks": keyed_lock_info,
"core_version": core_version,
"api_version": api_version_display,
"webui_title": webui_title,
"webui_description": webui_description,
}
except Exception as e:
logger.error(f"Error getting health status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
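# Hedged usage sketch: with auth enabled, pass the bearer token returned by /login
# (or an API key accepted by combined_auth), e.g.
#   curl -H "Authorization: Bearer <token>" http://<host>:<port>/health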
# Custom StaticFiles class for smart caching
class SmartStaticFiles(StaticFiles): # Renamed from NoCacheStaticFiles
async def get_response(self, path: str, scope):
response = await super().get_response(path, scope)
is_html = path.endswith(".html") or response.media_type == "text/html"
if is_html:
response.headers["Cache-Control"] = (
"no-cache, no-store, must-revalidate"
)
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
elif (
"/assets/" in path
): # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename
response.headers["Cache-Control"] = (
"public, max-age=31536000, immutable"
)
# Add other rules here if needed for non-HTML, non-asset files
# Ensure correct Content-Type
if path.endswith(".js"):
response.headers["Content-Type"] = "application/javascript"
elif path.endswith(".css"):
response.headers["Content-Type"] = "text/css"
return response
# Mount Swagger UI static files for offline support
swagger_static_dir = Path(__file__).parent / "static" / "swagger-ui"
if swagger_static_dir.exists():
app.mount(
"/static/swagger-ui",
StaticFiles(directory=swagger_static_dir),
name="swagger-ui-static",
)
# Webui mount webui/index.html
static_dir = Path(__file__).parent / "webui"
static_dir.mkdir(exist_ok=True)
app.mount(
"/webui",
SmartStaticFiles(
directory=static_dir, html=True, check_dir=True
), # Use SmartStaticFiles
name="webui",
)
return app
def get_application(args=None):
"""Factory function for creating the FastAPI application"""
if args is None:
args = global_args
return create_app(args)
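# Usage sketch (assumed invocation; consult the project docs for the canonical command):
#   uvicorn --factory lightrag.api.lightrag_server:get_application
# Under Gunicorn the factory is called by the Gunicorn config instead, and main() below
# returns early.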
def configure_logging():
"""Configure logging for uvicorn startup"""
# Reset any existing handlers to ensure clean configuration
for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
logger = logging.getLogger(logger_name)
logger.handlers = []
logger.filters = []
# Get log directory path from environment variable
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, DEFAULT_LOG_FILENAME))
print(f"\nLightRAG log file: {log_file_path}\n")
# Ensure the log directory itself exists before the rotating file handler opens the log file
os.makedirs(log_dir, exist_ok=True)
# Get log file max size and backup count from environment variables
log_max_bytes = get_env_value("LOG_MAX_BYTES", DEFAULT_LOG_MAX_BYTES, int)
log_backup_count = get_env_value("LOG_BACKUP_COUNT", DEFAULT_LOG_BACKUP_COUNT, int)
logging.config.dictConfig(
{
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(levelname)s: %(message)s",
},
"detailed": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
},
},
"handlers": {
"console": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"file": {
"formatter": "detailed",
"class": "logging.handlers.RotatingFileHandler",
"filename": log_file_path,
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf-8",
},
},
"loggers": {
# Configure all uvicorn related loggers
"uvicorn": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
},
"uvicorn.access": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
"filters": ["path_filter"],
},
"uvicorn.error": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
},
"lightrag": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
"filters": ["path_filter"],
},
},
"filters": {
"path_filter": {
"()": "lightrag.utils.LightragPathFilter",
},
},
}
)
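# Illustrative settings (assumed values, not defaults defined here):
#   LOG_DIR=/var/log/lightrag  LOG_MAX_BYTES=10485760  LOG_BACKUP_COUNT=5
# rotate <LOG_DIR>/<DEFAULT_LOG_FILENAME> at ~10 MiB while keeping five backups.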
def check_and_install_dependencies():
"""Check and install required dependencies"""
required_packages = [
"uvicorn",
"tiktoken",
"fastapi",
# Add other required packages here
]
for package in required_packages:
if not pm.is_installed(package):
print(f"Installing {package}...")
pm.install(package)
print(f"{package} installed successfully")
def main():
# Check if running under Gunicorn
if "GUNICORN_CMD_ARGS" in os.environ:
# If started with Gunicorn, return directly as Gunicorn will call get_application
print("Running under Gunicorn - worker management handled by Gunicorn")
return
# Check .env file
if not check_env_file():
sys.exit(1)
# Check and install dependencies
check_and_install_dependencies()
from multiprocessing import freeze_support
freeze_support()
# Configure logging before parsing args
configure_logging()
update_uvicorn_mode_config()
display_splash_screen(global_args)
# Note: Signal handlers are NOT registered here because:
# - Uvicorn has built-in signal handling that properly calls lifespan shutdown
# - Custom signal handlers can interfere with uvicorn's graceful shutdown
# - Cleanup is handled by the lifespan context manager's finally block
# Create application instance directly instead of using factory function
app = create_app(global_args)
# Start Uvicorn in single process mode
uvicorn_config = {
"app": app, # Pass application instance directly instead of string path
"host": global_args.host,
"port": global_args.port,
"log_config": None, # Disable default config
}
if global_args.ssl:
uvicorn_config.update(
{
"ssl_certfile": global_args.ssl_certfile,
"ssl_keyfile": global_args.ssl_keyfile,
}
)
print(
f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
)
uvicorn.run(**uvicorn_config)
if __name__ == "__main__":
main()