cherry-pick 6b2af2b5
This commit is contained in:
parent 7f5afd0a4d
commit 778e6d57c4
1 changed file with 43 additions and 24 deletions
@@ -452,6 +452,29 @@ def create_app(args):
     # Create combined auth dependency for all endpoints
     combined_auth = get_combined_auth_dependency(api_key)
 
+    def get_workspace_from_request(request: Request) -> str:
+        """
+        Extract workspace from HTTP request header or use default.
+
+        This enables multi-workspace API support by checking the custom
+        'LIGHTRAG-WORKSPACE' header. If not present, falls back to the
+        server's default workspace configuration.
+
+        Args:
+            request: FastAPI Request object
+
+        Returns:
+            Workspace identifier (may be empty string for global namespace)
+        """
+        # Check custom header first
+        workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()
+
+        # Fall back to server default if header not provided
+        if not workspace:
+            workspace = args.workspace
+
+        return workspace
+
     # Create working directory if it doesn't exist
     Path(args.working_dir).mkdir(parents=True, exist_ok=True)
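For context, a caller can opt into a specific workspace per request with this header; omitting it falls back to the server default, exactly as the helper's docstring above describes. A minimal client sketch (host, port, and workspace name are made-up values, and the `requests` package is used purely for illustration):

import requests

# Query a specific workspace via the custom header (hypothetical values).
resp = requests.get(
    "http://localhost:9621/health",
    headers={"LIGHTRAG-WORKSPACE": "team_a"},
)
print(resp.json())

# Without the header, the server falls back to its default workspace.
resp = requests.get("http://localhost:9621/health")
print(resp.json())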
@@ -632,8 +655,8 @@ def create_app(args):
 
     # Step 1: Import provider function and extract default attributes
     provider_func = None
-    provider_max_token_size = None
-    provider_embedding_dim = None
+    default_max_token_size = None
+    default_embedding_dim = args.embedding_dim  # Use config as default
 
     try:
         if binding == "openai":
@@ -667,24 +690,18 @@ def create_app(args):
 
         # Extract attributes if provider is an EmbeddingFunc
         if provider_func and isinstance(provider_func, EmbeddingFunc):
-            provider_max_token_size = provider_func.max_token_size
-            provider_embedding_dim = provider_func.embedding_dim
+            default_max_token_size = provider_func.max_token_size
+            default_embedding_dim = provider_func.embedding_dim
             logger.debug(
                 f"Extracted from {binding} provider: "
-                f"max_token_size={provider_max_token_size}, "
-                f"embedding_dim={provider_embedding_dim}"
+                f"max_token_size={default_max_token_size}, "
+                f"embedding_dim={default_embedding_dim}"
             )
     except ImportError as e:
         logger.warning(f"Could not import provider function for {binding}: {e}")
 
-    # Step 2: Apply priority (user config > provider default)
-    # For max_token_size: explicit env var > provider default > None
-    final_max_token_size = args.embedding_token_limit or provider_max_token_size
-    # For embedding_dim: user config (always has value) takes priority
-    # Only use provider default if user config is explicitly None (which shouldn't happen)
-    final_embedding_dim = (
-        args.embedding_dim if args.embedding_dim else provider_embedding_dim
-    )
+    # Step 2: Apply priority (environment variable > provider default)
+    final_max_token_size = args.embedding_token_limit or default_max_token_size
 
     # Step 3: Create optimized embedding function (calls underlying function directly)
     async def optimized_embedding_function(texts, embedding_dim=None):
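As a rough illustration of the Step 2 precedence (an explicitly configured token limit wins, otherwise the provider default applies), with made-up values:

# Hypothetical values, for illustration only.
provider_default_max_tokens = 8192   # e.g. taken from the provider's EmbeddingFunc
configured_token_limit = None        # no explicit limit configured

# Same `or`-based fallback as in the diff above.
final_max_token_size = configured_token_limit or provider_default_max_tokens
print(final_max_token_size)  # -> 8192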
@@ -803,18 +820,12 @@ def create_app(args):
 
     # Step 4: Wrap in EmbeddingFunc and return
     embedding_func_instance = EmbeddingFunc(
-        embedding_dim=final_embedding_dim,
+        embedding_dim=default_embedding_dim,
         func=optimized_embedding_function,
         max_token_size=final_max_token_size,
         send_dimensions=False,  # Will be set later based on binding requirements
     )
 
-    # Log final embedding configuration
-    logger.info(
-        f"Embedding config: binding={binding} model={model} "
-        f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
-    )
-
     return embedding_func_instance
 
     llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
@@ -1113,10 +1124,17 @@ def create_app(args):
         }
 
     @app.get("/health", dependencies=[Depends(combined_auth)])
-    async def get_status():
+    async def get_status(request: Request):
        """Get current system status"""
        try:
-            pipeline_status = await get_namespace_data("pipeline_status")
+            # Extract workspace from request header or use default
+            workspace = get_workspace_from_request(request)
+
+            # Construct namespace (following GraphDB pattern)
+            namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
+
+            # Get workspace-specific pipeline status
+            pipeline_status = await get_namespace_data(namespace)
 
            if not auth_configured:
                auth_mode = "disabled"
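To make the namespace scheme concrete, the conditional used above maps an empty workspace to the shared default key and a named workspace to a prefixed key (the workspace names here are invented):

def pipeline_namespace(workspace: str) -> str:
    # Mirrors the conditional expression used in get_status above.
    return f"{workspace}:pipeline" if workspace else "pipeline_status"

print(pipeline_namespace(""))        # -> pipeline_status (default/global namespace)
print(pipeline_namespace("team_a"))  # -> team_a:pipeline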
@@ -1147,7 +1165,8 @@ def create_app(args):
                 "vector_storage": args.vector_storage,
                 "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
                 "enable_llm_cache": args.enable_llm_cache,
-                "workspace": args.workspace,
+                "workspace": workspace,
+                "default_workspace": args.workspace,
                 "max_graph_nodes": args.max_graph_nodes,
                 # Rerank configuration
                 "enable_rerank": rerank_model_func is not None,