LightRAG/lightrag/api/middleware/tenant.py
2025-12-05 14:31:13 +08:00

82 lines
3.1 KiB
Python

import contextvars
import logging

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from lightrag.api.auth import auth_handler
from lightrag.utils_context import set_current_tenant_id

# Module-level logger named after this module; the middleware below uses it
# for debug-level tracing only.
logger = logging.getLogger(__name__)
class TenantMiddleware(BaseHTTPMiddleware):
    """Resolve the tenant of a request and expose it to downstream code.

    The tenant is read from the ``metadata.tenant_id`` claim of the Bearer
    JWT carried in the ``Authorization`` header.  When a tenant is found it
    is stored on ``request.state.tenant_id`` and handed to
    ``set_current_tenant_id`` for the duration of the request.

    Subdomain-based identification (e.g. ``tenant.app.com``) is NOT
    implemented.  The ``Host`` header is client-controlled, so a subdomain
    must be checked against a list of known tenants/domains before it can be
    trusted as a tenant id.

    This middleware never rejects a request: a missing, malformed, or invalid
    token simply results in no tenant being set.  Enforcing authentication is
    left to whatever runs downstream.
    """

    # Exact paths that are served without any tenant resolution.
    _PUBLIC_PATHS = frozenset(
        {
            "/health",
            "/docs",
            "/openapi.json",
            "/redoc",
            "/auth-status",
            "/login",
            "/",
            "/webui",
        }
    )
    # Every path starting with this prefix is also treated as public.
    _PUBLIC_PREFIX = "/assets"
    # Only this exact (case-sensitive) authorization scheme is inspected.
    _BEARER_PREFIX = "Bearer "

    async def dispatch(self, request: Request, call_next):
        """Attach the tenant id (if any) to the request, then call downstream.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next handler
                and returns its response.

        Returns:
            The response produced by ``call_next``.
        """
        path = request.url.path
        if path in self._PUBLIC_PATHS or path.startswith(self._PUBLIC_PREFIX):
            return await call_next(request)

        # TODO: subdomain-based tenant lookup would go here, ahead of the JWT,
        # once there is a trusted registry of tenant domains to check against.
        tenant_id = self._tenant_from_authorization(
            request.headers.get("Authorization")
        )

        if not tenant_id:
            logger.debug("TenantMiddleware: No tenant_id found")
            return await call_next(request)

        request.state.tenant_id = tenant_id
        # Also publish the tenant through the context helper so code that has
        # no access to ``request`` can read it.
        context_token = set_current_tenant_id(tenant_id)
        logger.debug("TenantMiddleware: Set tenant_id=%s", tenant_id)
        try:
            return await call_next(request)
        finally:
            # Restore the previous context value so the tenant cannot outlive
            # this request.  NOTE(review): set_current_tenant_id presumably
            # returns a contextvars.Token; if it does not, the reset is
            # skipped and behaviour is as before.
            if isinstance(context_token, contextvars.Token):
                context_token.var.reset(context_token)

    @staticmethod
    def _tenant_from_authorization(auth_header):
        """Return the tenant id carried by a Bearer token, or ``None``.

        Args:
            auth_header: Raw ``Authorization`` header value, or ``None`` when
                the header is absent.

        Returns:
            The truthy ``metadata.tenant_id`` claim of the token, or ``None``
            when the header is absent, is not a Bearer credential, has empty
            credentials, fails validation, or carries no tenant id.
        """
        if not auth_header or not auth_header.startswith(
            TenantMiddleware._BEARER_PREFIX
        ):
            return None

        _, _, credentials = auth_header.partition(" ")
        jwt_token = credentials.strip()
        if not jwt_token:
            # "Bearer " with nothing after it: there is nothing to validate.
            return None

        try:
            claims = auth_handler.validate_token(jwt_token)
            # "metadata" may be missing or explicitly null.
            metadata = claims.get("metadata") or {}
            return metadata.get("tenant_id") or None
        except Exception as exc:
            # Deliberately best-effort: an invalid or expired token, or an
            # unexpected claims shape, must not block the request here.
            logger.debug("TenantMiddleware: Token validation failed: %s", exc)
            return None