From 778e6d57c4ce4a5bb57314dc0764c6baeb8b5e1e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?=
Date: Thu, 4 Dec 2025 19:14:27 +0800
Subject: [PATCH] cherry-pick 6b2af2b5

---
 lightrag/api/lightrag_server.py | 67 +++++++++++++++++++++------------
 1 file changed, 43 insertions(+), 24 deletions(-)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 41a07f7f..8f3fbae1 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -452,6 +452,29 @@ def create_app(args):
     # Create combined auth dependency for all endpoints
     combined_auth = get_combined_auth_dependency(api_key)
 
+    def get_workspace_from_request(request: Request) -> str:
+        """
+        Extract workspace from HTTP request header or use default.
+
+        This enables multi-workspace API support by checking the custom
+        'LIGHTRAG-WORKSPACE' header. If not present, falls back to the
+        server's default workspace configuration.
+
+        Args:
+            request: FastAPI Request object
+
+        Returns:
+            Workspace identifier (may be empty string for global namespace)
+        """
+        # Check custom header first
+        workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()
+
+        # Fall back to server default if header not provided
+        if not workspace:
+            workspace = args.workspace
+
+        return workspace
+
     # Create working directory if it doesn't exist
     Path(args.working_dir).mkdir(parents=True, exist_ok=True)
 
@@ -632,8 +655,8 @@ def create_app(args):
 
         # Step 1: Import provider function and extract default attributes
         provider_func = None
-        provider_max_token_size = None
-        provider_embedding_dim = None
+        default_max_token_size = None
+        default_embedding_dim = args.embedding_dim  # Use config as default
 
         try:
             if binding == "openai":
@@ -667,24 +690,18 @@ def create_app(args):
 
             # Extract attributes if provider is an EmbeddingFunc
             if provider_func and isinstance(provider_func, EmbeddingFunc):
-                provider_max_token_size = provider_func.max_token_size
-                provider_embedding_dim = provider_func.embedding_dim
+                default_max_token_size = provider_func.max_token_size
+                default_embedding_dim = provider_func.embedding_dim
                 logger.debug(
                     f"Extracted from {binding} provider: "
-                    f"max_token_size={provider_max_token_size}, "
-                    f"embedding_dim={provider_embedding_dim}"
+                    f"max_token_size={default_max_token_size}, "
+                    f"embedding_dim={default_embedding_dim}"
                 )
         except ImportError as e:
             logger.warning(f"Could not import provider function for {binding}: {e}")
 
-        # Step 2: Apply priority (user config > provider default)
-        # For max_token_size: explicit env var > provider default > None
-        final_max_token_size = args.embedding_token_limit or provider_max_token_size
-        # For embedding_dim: user config (always has value) takes priority
-        # Only use provider default if user config is explicitly None (which shouldn't happen)
-        final_embedding_dim = (
-            args.embedding_dim if args.embedding_dim else provider_embedding_dim
-        )
+        # Step 2: Apply priority (environment variable > provider default)
+        final_max_token_size = args.embedding_token_limit or default_max_token_size
 
         # Step 3: Create optimized embedding function (calls underlying function directly)
         async def optimized_embedding_function(texts, embedding_dim=None):
@@ -803,18 +820,12 @@ def create_app(args):
 
         # Step 4: Wrap in EmbeddingFunc and return
         embedding_func_instance = EmbeddingFunc(
-            embedding_dim=final_embedding_dim,
+            embedding_dim=default_embedding_dim,
             func=optimized_embedding_function,
             max_token_size=final_max_token_size,
             send_dimensions=False,  # Will be set later based on binding requirements
         )
 
-        # Log final embedding configuration
-        logger.info(
-            f"Embedding config: binding={binding} model={model} "
-            f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
-        )
-
         return embedding_func_instance
 
     llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
@@ -1113,10 +1124,17 @@ def create_app(args):
         }
 
     @app.get("/health", dependencies=[Depends(combined_auth)])
-    async def get_status():
+    async def get_status(request: Request):
         """Get current system status"""
         try:
-            pipeline_status = await get_namespace_data("pipeline_status")
+            # Extract workspace from request header or use default
+            workspace = get_workspace_from_request(request)
+
+            # Construct namespace (following GraphDB pattern)
+            namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
+
+            # Get workspace-specific pipeline status
+            pipeline_status = await get_namespace_data(namespace)
 
             if not auth_configured:
                 auth_mode = "disabled"
@@ -1147,7 +1165,8 @@ def create_app(args):
                 "vector_storage": args.vector_storage,
                 "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
                 "enable_llm_cache": args.enable_llm_cache,
-                "workspace": args.workspace,
+                "workspace": workspace,
+                "default_workspace": args.workspace,
                 "max_graph_nodes": args.max_graph_nodes,
                 # Rerank configuration
                 "enable_rerank": rerank_model_func is not None,
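
A minimal usage sketch, not part of the patch itself: once the diff above is
applied, a client can select a workspace per request through the
LIGHTRAG-WORKSPACE header that get_workspace_from_request() reads. The base
URL below assumes LightRAG's default server port (9621), and the workspace
name is a hypothetical placeholder; adjust both for your deployment.

    import requests

    # Ask /health for the status of a specific workspace; the header name
    # comes from get_workspace_from_request() in the patch above.
    resp = requests.get(
        "http://localhost:9621/health",
        headers={"LIGHTRAG-WORKSPACE": "my_workspace"},  # hypothetical workspace id
    )

    # Per the patch, the response now reports both the workspace resolved
    # for this request ("workspace") and the server-wide default
    # ("default_workspace"); the exact nesting of these fields in the JSON
    # body is an assumption, so we just print the full payload here.
    print(resp.json())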