Ruff formatted

Author: SJ, 2025-08-15 22:21:34 +00:00
Parent: 3aa3332505
Commit: f7ca9ae16a
3 changed files with 142 additions and 46 deletions
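Since this commit is a pure Ruff formatting pass, the before and after versions of each touched file should parse to identical ASTs. A minimal way to sanity-check that, sketched under the assumption that both revisions are saved locally (the filenames below are illustrative, not part of the commit):

import ast

def same_ast(path_before: str, path_after: str) -> bool:
    """True if two Python files differ only in formatting (identical ASTs)."""
    with open(path_before) as fa, open(path_after) as fb:
        return ast.dump(ast.parse(fa.read())) == ast.dump(ast.parse(fb.read()))

# e.g. same_ast("config_before.py", "config_after.py") should return True
# for a formatting-only change such as this one.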


@@ -77,7 +77,9 @@ def parse_args() -> argparse.Namespace:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(description="LightRAG FastAPI Server with separate working and input directories")
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
# Server configuration
parser.add_argument(
@@ -207,7 +209,14 @@ def parse_args() -> argparse.Namespace:
"--llm-binding",
type=str,
default=get_env_value("LLM_BINDING", "ollama"),
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai", "aws_bedrock"],
choices=[
"lollms",
"ollama",
"openai",
"openai-ollama",
"azure_openai",
"aws_bedrock",
],
help="LLM binding type (default: from env or ollama)",
)
parser.add_argument(
@@ -270,10 +279,18 @@ def parse_args() -> argparse.Namespace:
args.input_dir = os.path.abspath(args.input_dir)
# Inject storage configuration from environment variables
args.kv_storage = get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE)
args.doc_status_storage = get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE)
args.graph_storage = get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE)
args.vector_storage = get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE)
args.kv_storage = get_env_value(
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
)
args.doc_status_storage = get_env_value(
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
)
args.graph_storage = get_env_value(
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
)
args.vector_storage = get_env_value(
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
)
# Get MAX_PARALLEL_INSERT from environment
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
@@ -289,8 +306,12 @@ def parse_args() -> argparse.Namespace:
# Ollama ctx_num
args.ollama_num_ctx = get_env_value("OLLAMA_NUM_CTX", 32768, int)
args.llm_binding_host = get_env_value("LLM_BINDING_HOST", get_default_host(args.llm_binding))
args.embedding_binding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding))
args.llm_binding_host = get_env_value(
"LLM_BINDING_HOST", get_default_host(args.llm_binding)
)
args.embedding_binding_host = get_env_value(
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
)
args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
@@ -304,7 +325,9 @@ def parse_args() -> argparse.Namespace:
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
# Inject LLM cache configuration
args.enable_llm_cache_for_extract = get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool)
args.enable_llm_cache_for_extract = get_env_value(
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
)
args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)
# Handle Ollama LLM temperature with priority cascade when llm-binding is ollama
@@ -354,24 +377,40 @@ def parse_args() -> argparse.Namespace:
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
# Min rerank score configuration
args.min_rerank_score = get_env_value("MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float)
args.min_rerank_score = get_env_value(
"MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float
)
# Query configuration
args.history_turns = get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int)
args.top_k = get_env_value("TOP_K", DEFAULT_TOP_K, int)
args.chunk_top_k = get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int)
args.max_entity_tokens = get_env_value("MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int)
args.max_relation_tokens = get_env_value("MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int)
args.max_total_tokens = get_env_value("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int)
args.cosine_threshold = get_env_value("COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float)
args.related_chunk_number = get_env_value("RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int)
args.max_entity_tokens = get_env_value(
"MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int
)
args.max_relation_tokens = get_env_value(
"MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int
)
args.max_total_tokens = get_env_value(
"MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int
)
args.cosine_threshold = get_env_value(
"COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float
)
args.related_chunk_number = get_env_value(
"RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int
)
# Add missing environment variables for health endpoint
args.force_llm_summary_on_merge = get_env_value(
"FORCE_LLM_SUMMARY_ON_MERGE", DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int
)
args.embedding_func_max_async = get_env_value("EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int)
args.embedding_batch_num = get_env_value("EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int)
args.embedding_func_max_async = get_env_value(
"EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int
)
args.embedding_batch_num = get_env_value(
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
)
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
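The hunks above lean heavily on get_env_value(name, default, type) to inject configuration from the environment. Its definition is not part of this diff; a minimal sketch of the assumed behaviour (read the variable, fall back to the default, cast the result, with a special case for booleans) looks roughly like this:

import os

def get_env_value(name, default, value_type=str):
    """Assumed helper: return the env var cast to value_type, or the default if unset."""
    raw = os.getenv(name)
    if raw is None or raw == "":
        return default
    if value_type is bool:
        # Assumption: common truthy strings count as True.
        return raw.strip().lower() in ("1", "true", "yes", "on")
    try:
        return value_type(raw)
    except (TypeError, ValueError):
        return default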
@@ -385,7 +424,9 @@ def update_uvicorn_mode_config():
original_workers = global_args.workers
global_args.workers = 1
# Log warning directly here
logging.warning(f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1")
logging.warning(
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
)
global_args = parse_args()


@@ -130,7 +130,9 @@ def create_app(args):
# Add SSL validation
if args.ssl:
if not args.ssl_certfile or not args.ssl_keyfile:
raise Exception("SSL certificate and key files must be provided when SSL is enabled")
raise Exception(
"SSL certificate and key files must be provided when SSL is enabled"
)
if not os.path.exists(args.ssl_certfile):
raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
if not os.path.exists(args.ssl_keyfile):
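The SSL block above only verifies that both files were supplied and exist; the paths are later handed to the web server. A hedged sketch of that fail-fast check and hand-off, assuming Uvicorn's standard ssl_certfile/ssl_keyfile arguments (paths and the app reference are placeholders, not values from this diff):

import os

def require_ssl_files(certfile: str, keyfile: str) -> None:
    """Fail fast if either SSL file is missing, mirroring the validation above."""
    for path in (certfile, keyfile):
        if not os.path.exists(path):
            raise Exception(f"SSL file not found: {path}")

# require_ssl_files("certs/server.crt", "certs/server.key")  # placeholder paths
# uvicorn.run(app, ssl_certfile="certs/server.crt", ssl_keyfile="certs/server.key")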
@@ -189,7 +191,8 @@ def create_app(args):
app_kwargs = {
"title": "LightRAG Server API",
"description": (
"Providing API for LightRAG core, Web UI and Ollama Model Emulation" + "(With authentication)"
"Providing API for LightRAG core, Web UI and Ollama Model Emulation"
+ "(With authentication)"
if api_key
else ""
),
@@ -395,7 +398,9 @@ def create_app(args):
if args.rerank_binding_api_key and args.rerank_binding_host:
from lightrag.rerank import custom_rerank
async def server_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
async def server_rerank_func(
query: str, documents: list, top_n: int = None, **kwargs
):
"""Server rerank function with configuration from environment variables"""
return await custom_rerank(
query=query,
@@ -408,7 +413,9 @@ def create_app(args):
)
rerank_model_func = server_rerank_func
logger.info(f"Rerank model configured: {args.rerank_model} (can be enabled per query)")
logger.info(
f"Rerank model configured: {args.rerank_model} (can be enabled per query)"
)
else:
logger.info(
"Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
@@ -417,7 +424,9 @@ def create_app(args):
# Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos
ollama_server_infos = OllamaServerInfos(name=args.simulated_model_name, tag=args.simulated_model_tag)
ollama_server_infos = OllamaServerInfos(
name=args.simulated_model_name, tag=args.simulated_model_tag
)
# Initialize RAG
if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]:
@@ -430,7 +439,9 @@ def create_app(args):
else (
ollama_model_complete
if args.llm_binding == "ollama"
else bedrock_model_complete if args.llm_binding == "aws_bedrock" else openai_alike_model_complete
else bedrock_model_complete
if args.llm_binding == "aws_bedrock"
else openai_alike_model_complete
)
),
llm_model_name=args.llm_model,
@@ -453,7 +464,9 @@ def create_app(args):
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold},
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
rerank_model_func=rerank_model_func,
@@ -480,7 +493,9 @@ def create_app(args):
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold},
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
rerank_model_func=rerank_model_func,
@@ -516,7 +531,9 @@ def create_app(args):
if not auth_handler.accounts:
# Authentication not configured, return guest token
guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"})
guest_token = auth_handler.create_token(
username="guest", role="guest", metadata={"auth_mode": "disabled"}
)
return {
"auth_configured": False,
"access_token": guest_token,
@@ -542,7 +559,9 @@ def create_app(args):
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
if not auth_handler.accounts:
# Authentication not configured, return guest token
guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"})
guest_token = auth_handler.create_token(
username="guest", role="guest", metadata={"auth_mode": "disabled"}
)
return {
"access_token": guest_token,
"token_type": "bearer",
@@ -555,10 +574,14 @@ def create_app(args):
}
username = form_data.username
if auth_handler.accounts.get(username) != form_data.password:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
)
# Regular user login
user_token = auth_handler.create_token(username=username, role="user", metadata={"auth_mode": "enabled"})
user_token = auth_handler.create_token(
username=username, role="user", metadata={"auth_mode": "enabled"}
)
return {
"access_token": user_token,
"token_type": "bearer",
@@ -607,8 +630,12 @@ def create_app(args):
"max_graph_nodes": args.max_graph_nodes,
# Rerank configuration (based on whether rerank model is configured)
"enable_rerank": rerank_model_func is not None,
"rerank_model": args.rerank_model if rerank_model_func is not None else None,
"rerank_binding_host": args.rerank_binding_host if rerank_model_func is not None else None,
"rerank_model": args.rerank_model
if rerank_model_func is not None
else None,
"rerank_binding_host": args.rerank_binding_host
if rerank_model_func is not None
else None,
# Environment variable status (requested configuration)
"summary_language": args.summary_language,
"force_llm_summary_on_merge": args.force_llm_summary_on_merge,
@@ -638,11 +665,17 @@ def create_app(args):
response = await super().get_response(path, scope)
if path.endswith(".html"):
response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
response.headers["Cache-Control"] = (
"no-cache, no-store, must-revalidate"
)
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
elif "/assets/" in path: # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename
response.headers["Cache-Control"] = "public, max-age=31536000, immutable"
elif (
"/assets/" in path
): # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename
response.headers["Cache-Control"] = (
"public, max-age=31536000, immutable"
)
# Add other rules here if needed for non-HTML, non-asset files
# Ensure correct Content-Type
@@ -658,7 +691,9 @@ def create_app(args):
static_dir.mkdir(exist_ok=True)
app.mount(
"/webui",
SmartStaticFiles(directory=static_dir, html=True, check_dir=True), # Use SmartStaticFiles
SmartStaticFiles(
directory=static_dir, html=True, check_dir=True
), # Use SmartStaticFiles
name="webui",
)
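SmartStaticFiles above pairs two caching policies: HTML entry points must always be revalidated, while hash-named Vite assets can be cached indefinitely. A minimal sketch of the same idea, assuming Starlette's StaticFiles.get_response hook (class and directory names below are illustrative):

from starlette.staticfiles import StaticFiles

class CachingStaticFiles(StaticFiles):
    async def get_response(self, path, scope):
        response = await super().get_response(path, scope)
        if path.endswith(".html"):
            # HTML entry points: always revalidate so UI updates are picked up.
            response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
        elif "/assets/" in path:
            # Hash-named build artifacts never change, so cache them aggressively.
            response.headers["Cache-Control"] = "public, max-age=31536000, immutable"
        return response

# app.mount("/webui", CachingStaticFiles(directory="webui", html=True), name="webui")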
@@ -814,7 +849,9 @@ def main():
}
)
print(f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}")
print(
f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
)
uvicorn.run(**uvicorn_config)


@@ -100,10 +100,14 @@ async def bedrock_complete_if_cache(
"top_p": "topP",
"stop_sequences": "stopSequences",
}
if inference_params := list(set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])):
if inference_params := list(
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
):
args["inferenceConfig"] = {}
for param in inference_params:
args["inferenceConfig"][inference_params_map.get(param, param)] = kwargs.pop(param)
args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param)
)
# Import logging for error handling
import logging
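The walrus-formatted block above translates generic keyword arguments into the names Bedrock's Converse API expects and moves them under inferenceConfig, popping them out of kwargs so they are not passed twice. A small worked example of that mapping (input values are illustrative):

inference_params_map = {
    "max_tokens": "maxTokens",
    "top_p": "topP",
    "stop_sequences": "stopSequences",
}
kwargs = {"max_tokens": 512, "temperature": 0.2, "stream": False}  # example inputs
args = {}
if inference_params := list(
    set(kwargs) & {"max_tokens", "temperature", "top_p", "stop_sequences"}
):
    args["inferenceConfig"] = {}
    for param in inference_params:
        args["inferenceConfig"][inference_params_map.get(param, param)] = kwargs.pop(param)

# args   -> {"inferenceConfig": {"maxTokens": 512, "temperature": 0.2}}
# kwargs -> {"stream": False}; "temperature" keeps its name because it has no mapping entry.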
@@ -119,7 +123,9 @@ async def bedrock_complete_if_cache(
nonlocal client
# Create the client outside the generator to ensure it stays open
client = await session.client("bedrock-runtime", region_name=region).__aenter__()
client = await session.client(
"bedrock-runtime", region_name=region
).__aenter__()
event_stream = None
iteration_started = False
@@ -158,7 +164,9 @@ async def bedrock_complete_if_cache(
try:
await event_stream.aclose()
except Exception as close_error:
logging.warning(f"Failed to close Bedrock event stream: {close_error}")
logging.warning(
f"Failed to close Bedrock event stream: {close_error}"
)
raise BedrockError(f"Streaming error: {e}")
@@ -173,21 +181,27 @@ async def bedrock_complete_if_cache(
try:
await event_stream.aclose()
except Exception as close_error:
logging.warning(f"Failed to close Bedrock event stream in finally block: {close_error}")
logging.warning(
f"Failed to close Bedrock event stream in finally block: {close_error}"
)
# Clean up the client
if client:
try:
await client.__aexit__(None, None, None)
except Exception as client_close_error:
logging.warning(f"Failed to close Bedrock client: {client_close_error}")
logging.warning(
f"Failed to close Bedrock client: {client_close_error}"
)
# Return the generator that manages its own lifecycle
return stream_generator()
# For non-streaming responses, use the standard async context manager pattern
session = aioboto3.Session()
async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client:
async with session.client(
"bedrock-runtime", region_name=region
) as bedrock_async_client:
try:
# Use converse for non-streaming responses
response = await bedrock_async_client.converse(**args, **kwargs)
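The streaming branch above cannot use async with because the Bedrock client must stay open for the lifetime of a generator that is returned to the caller, so it enters the context manager manually and closes it in the generator's finally block. A hedged, self-contained sketch of that lifecycle pattern with a stand-in resource (no aioboto3 involved):

import asyncio
from contextlib import asynccontextmanager

@asynccontextmanager
async def open_resource():
    print("open")
    try:
        yield ["chunk-1", "chunk-2"]
    finally:
        print("closed")

def make_stream():
    async def stream():
        cm = open_resource()
        resource = await cm.__aenter__()  # kept open across yields
        try:
            for chunk in resource:
                yield chunk
        finally:
            await cm.__aexit__(None, None, None)  # cleanup once the consumer finishes
    return stream()

async def main():
    async for chunk in make_stream():
        print(chunk)

# asyncio.run(main())  # prints: open, chunk-1, chunk-2, closed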
@@ -257,7 +271,9 @@ async def bedrock_embed(
region = os.environ.get("AWS_REGION")
session = aioboto3.Session()
async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client:
async with session.client(
"bedrock-runtime", region_name=region
) as bedrock_async_client:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
@@ -285,7 +301,9 @@ async def bedrock_embed(
embed_texts.append(response_body["embedding"])
elif model_provider == "cohere":
body = json.dumps({"texts": texts, "input_type": "search_document", "truncate": "NONE"})
body = json.dumps(
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
)
response = await bedrock_async_client.invoke_model(
model=model,