From f7ca9ae16a26a81d9b4b188f7a4a1e071b6419a1 Mon Sep 17 00:00:00 2001
From: SJ
Date: Fri, 15 Aug 2025 22:21:34 +0000
Subject: [PATCH] Ruff formatted

---
 lightrag/api/config.py          | 77 +++++++++++++++++++++++++--------
 lightrag/api/lightrag_server.py | 75 ++++++++++++++++++++++++--------
 lightrag/llm/bedrock.py         | 36 +++++++++++----
 3 files changed, 142 insertions(+), 46 deletions(-)

diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index 0a16e89b..01d0dd75 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -77,7 +77,9 @@ def parse_args() -> argparse.Namespace:
         argparse.Namespace: Parsed arguments
     """
 
-    parser = argparse.ArgumentParser(description="LightRAG FastAPI Server with separate working and input directories")
+    parser = argparse.ArgumentParser(
+        description="LightRAG FastAPI Server with separate working and input directories"
+    )
 
     # Server configuration
     parser.add_argument(
@@ -207,7 +209,14 @@ def parse_args() -> argparse.Namespace:
         "--llm-binding",
         type=str,
         default=get_env_value("LLM_BINDING", "ollama"),
-        choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai", "aws_bedrock"],
+        choices=[
+            "lollms",
+            "ollama",
+            "openai",
+            "openai-ollama",
+            "azure_openai",
+            "aws_bedrock",
+        ],
         help="LLM binding type (default: from env or ollama)",
     )
     parser.add_argument(
@@ -270,10 +279,18 @@ def parse_args() -> argparse.Namespace:
     args.input_dir = os.path.abspath(args.input_dir)
 
     # Inject storage configuration from environment variables
-    args.kv_storage = get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE)
-    args.doc_status_storage = get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE)
-    args.graph_storage = get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE)
-    args.vector_storage = get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE)
+    args.kv_storage = get_env_value(
+        "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
+    )
+    args.doc_status_storage = get_env_value(
+        "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
+    )
+    args.graph_storage = get_env_value(
+        "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
+    )
+    args.vector_storage = get_env_value(
+        "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
+    )
 
     # Get MAX_PARALLEL_INSERT from environment
     args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
@@ -289,8 +306,12 @@ def parse_args() -> argparse.Namespace:
     # Ollama ctx_num
     args.ollama_num_ctx = get_env_value("OLLAMA_NUM_CTX", 32768, int)
 
-    args.llm_binding_host = get_env_value("LLM_BINDING_HOST", get_default_host(args.llm_binding))
-    args.embedding_binding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding))
+    args.llm_binding_host = get_env_value(
+        "LLM_BINDING_HOST", get_default_host(args.llm_binding)
+    )
+    args.embedding_binding_host = get_env_value(
+        "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
+    )
 
     args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
     args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
@@ -304,7 +325,9 @@ def parse_args() -> argparse.Namespace:
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
 
     # Inject LLM cache configuration
-    args.enable_llm_cache_for_extract = get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool)
+    args.enable_llm_cache_for_extract = get_env_value(
+        "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
+    )
     args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)
 
     # Handle Ollama LLM temperature with priority cascade when llm-binding is ollama
@@ -354,24 +377,40 @@ def parse_args() -> argparse.Namespace:
     args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
 
     # Min rerank score configuration
-    args.min_rerank_score = get_env_value("MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float)
+    args.min_rerank_score = get_env_value(
+        "MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float
+    )
 
     # Query configuration
     args.history_turns = get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int)
     args.top_k = get_env_value("TOP_K", DEFAULT_TOP_K, int)
     args.chunk_top_k = get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int)
-    args.max_entity_tokens = get_env_value("MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int)
-    args.max_relation_tokens = get_env_value("MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int)
-    args.max_total_tokens = get_env_value("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int)
-    args.cosine_threshold = get_env_value("COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float)
-    args.related_chunk_number = get_env_value("RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int)
+    args.max_entity_tokens = get_env_value(
+        "MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int
+    )
+    args.max_relation_tokens = get_env_value(
+        "MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int
+    )
+    args.max_total_tokens = get_env_value(
+        "MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int
+    )
+    args.cosine_threshold = get_env_value(
+        "COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float
+    )
+    args.related_chunk_number = get_env_value(
+        "RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int
+    )
 
     # Add missing environment variables for health endpoint
     args.force_llm_summary_on_merge = get_env_value(
         "FORCE_LLM_SUMMARY_ON_MERGE", DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int
     )
-    args.embedding_func_max_async = get_env_value("EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int)
-    args.embedding_batch_num = get_env_value("EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int)
+    args.embedding_func_max_async = get_env_value(
+        "EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int
+    )
+    args.embedding_batch_num = get_env_value(
+        "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
+    )
 
     ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
     ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
@@ -385,7 +424,9 @@ def update_uvicorn_mode_config():
         original_workers = global_args.workers
         global_args.workers = 1
         # Log warning directly here
-        logging.warning(f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1")
+        logging.warning(
+            f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
+        )
 
 
 global_args = parse_args()
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 8ee23826..92349861 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -130,7 +130,9 @@ def create_app(args):
     # Add SSL validation
     if args.ssl:
         if not args.ssl_certfile or not args.ssl_keyfile:
-            raise Exception("SSL certificate and key files must be provided when SSL is enabled")
+            raise Exception(
+                "SSL certificate and key files must be provided when SSL is enabled"
+            )
         if not os.path.exists(args.ssl_certfile):
             raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
         if not os.path.exists(args.ssl_keyfile):
@@ -189,7 +191,8 @@ def create_app(args):
     app_kwargs = {
         "title": "LightRAG Server API",
         "description": (
-            "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + "(With authentication)"
+            "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
+            + "(With authentication)"
             if api_key
             else ""
         ),
@@ -395,7 +398,9 @@ def create_app(args):
     if args.rerank_binding_api_key and args.rerank_binding_host:
         from lightrag.rerank import custom_rerank
 
-        async def server_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
+        async def server_rerank_func(
+            query: str, documents: list, top_n: int = None, **kwargs
+        ):
             """Server rerank function with configuration from environment variables"""
             return await custom_rerank(
                 query=query,
@@ -408,7 +413,9 @@ def create_app(args):
             )
 
         rerank_model_func = server_rerank_func
-        logger.info(f"Rerank model configured: {args.rerank_model} (can be enabled per query)")
+        logger.info(
+            f"Rerank model configured: {args.rerank_model} (can be enabled per query)"
+        )
     else:
         logger.info(
             "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
@@ -417,7 +424,9 @@ def create_app(args):
     # Create ollama_server_infos from command line arguments
     from lightrag.api.config import OllamaServerInfos
 
-    ollama_server_infos = OllamaServerInfos(name=args.simulated_model_name, tag=args.simulated_model_tag)
+    ollama_server_infos = OllamaServerInfos(
+        name=args.simulated_model_name, tag=args.simulated_model_tag
+    )
 
     # Initialize RAG
     if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]:
@@ -430,7 +439,9 @@ def create_app(args):
                 else (
                     ollama_model_complete
                     if args.llm_binding == "ollama"
-                    else bedrock_model_complete if args.llm_binding == "aws_bedrock" else openai_alike_model_complete
+                    else bedrock_model_complete
+                    if args.llm_binding == "aws_bedrock"
+                    else openai_alike_model_complete
                 )
             ),
             llm_model_name=args.llm_model,
@@ -453,7 +464,9 @@ def create_app(args):
             graph_storage=args.graph_storage,
             vector_storage=args.vector_storage,
             doc_status_storage=args.doc_status_storage,
-            vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold},
+            vector_db_storage_cls_kwargs={
+                "cosine_better_than_threshold": args.cosine_threshold
+            },
             enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             enable_llm_cache=args.enable_llm_cache,
             rerank_model_func=rerank_model_func,
@@ -480,7 +493,9 @@ def create_app(args):
             graph_storage=args.graph_storage,
             vector_storage=args.vector_storage,
             doc_status_storage=args.doc_status_storage,
-            vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold},
+            vector_db_storage_cls_kwargs={
+                "cosine_better_than_threshold": args.cosine_threshold
+            },
             enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             enable_llm_cache=args.enable_llm_cache,
             rerank_model_func=rerank_model_func,
@@ -516,7 +531,9 @@ def create_app(args):
 
         if not auth_handler.accounts:
             # Authentication not configured, return guest token
-            guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"})
+            guest_token = auth_handler.create_token(
+                username="guest", role="guest", metadata={"auth_mode": "disabled"}
+            )
             return {
                 "auth_configured": False,
                 "access_token": guest_token,
@@ -542,7 +559,9 @@ def create_app(args):
     async def login(form_data: OAuth2PasswordRequestForm = Depends()):
         if not auth_handler.accounts:
             # Authentication not configured, return guest token
-            guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"})
+            guest_token = auth_handler.create_token(
+                username="guest", role="guest", metadata={"auth_mode": "disabled"}
+            )
             return {
                 "access_token": guest_token,
                 "token_type": "bearer",
@@ -555,10 +574,14 @@ def create_app(args):
             }
         username = form_data.username
         if auth_handler.accounts.get(username) != form_data.password:
-            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials")
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
+            )
 
         # Regular user login
-        user_token = auth_handler.create_token(username=username, role="user", metadata={"auth_mode": "enabled"})
+        user_token = auth_handler.create_token(
+            username=username, role="user", metadata={"auth_mode": "enabled"}
+        )
         return {
             "access_token": user_token,
             "token_type": "bearer",
@@ -607,8 +630,12 @@ def create_app(args):
                 "max_graph_nodes": args.max_graph_nodes,
                 # Rerank configuration (based on whether rerank model is configured)
                 "enable_rerank": rerank_model_func is not None,
-                "rerank_model": args.rerank_model if rerank_model_func is not None else None,
-                "rerank_binding_host": args.rerank_binding_host if rerank_model_func is not None else None,
+                "rerank_model": args.rerank_model
+                if rerank_model_func is not None
+                else None,
+                "rerank_binding_host": args.rerank_binding_host
+                if rerank_model_func is not None
+                else None,
                 # Environment variable status (requested configuration)
                 "summary_language": args.summary_language,
                 "force_llm_summary_on_merge": args.force_llm_summary_on_merge,
@@ -638,11 +665,17 @@ def create_app(args):
             response = await super().get_response(path, scope)
 
             if path.endswith(".html"):
-                response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
+                response.headers["Cache-Control"] = (
+                    "no-cache, no-store, must-revalidate"
+                )
                 response.headers["Pragma"] = "no-cache"
                 response.headers["Expires"] = "0"
-            elif "/assets/" in path:  # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename
-                response.headers["Cache-Control"] = "public, max-age=31536000, immutable"
+            elif (
+                "/assets/" in path
+            ):  # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename
+                response.headers["Cache-Control"] = (
+                    "public, max-age=31536000, immutable"
+                )
             # Add other rules here if needed for non-HTML, non-asset files
 
             # Ensure correct Content-Type
@@ -658,7 +691,9 @@ def create_app(args):
     static_dir.mkdir(exist_ok=True)
     app.mount(
         "/webui",
-        SmartStaticFiles(directory=static_dir, html=True, check_dir=True),  # Use SmartStaticFiles
+        SmartStaticFiles(
+            directory=static_dir, html=True, check_dir=True
+        ),  # Use SmartStaticFiles
         name="webui",
     )
 
@@ -814,7 +849,9 @@ def main():
             }
         )
 
-    print(f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}")
+    print(
+        f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
+    )
     uvicorn.run(**uvicorn_config)
 
 
diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py
index 51c4b895..69d00e2d 100644
--- a/lightrag/llm/bedrock.py
+++ b/lightrag/llm/bedrock.py
@@ -100,10 +100,14 @@ async def bedrock_complete_if_cache(
         "top_p": "topP",
         "stop_sequences": "stopSequences",
     }
-    if inference_params := list(set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])):
+    if inference_params := list(
+        set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
+    ):
         args["inferenceConfig"] = {}
         for param in inference_params:
-            args["inferenceConfig"][inference_params_map.get(param, param)] = kwargs.pop(param)
+            args["inferenceConfig"][inference_params_map.get(param, param)] = (
+                kwargs.pop(param)
+            )
 
     # Import logging for error handling
     import logging
@@ -119,7 +123,9 @@ async def bedrock_complete_if_cache(
             nonlocal client
 
             # Create the client outside the generator to ensure it stays open
-            client = await session.client("bedrock-runtime", region_name=region).__aenter__()
+            client = await session.client(
+                "bedrock-runtime", region_name=region
+            ).__aenter__()
 
             event_stream = None
             iteration_started = False
@@ -158,7 +164,9 @@ async def bedrock_complete_if_cache(
                     try:
                         await event_stream.aclose()
                     except Exception as close_error:
-                        logging.warning(f"Failed to close Bedrock event stream: {close_error}")
+                        logging.warning(
+                            f"Failed to close Bedrock event stream: {close_error}"
+                        )
 
                 raise BedrockError(f"Streaming error: {e}")
 
@@ -173,21 +181,27 @@ async def bedrock_complete_if_cache(
                     try:
                         await event_stream.aclose()
                     except Exception as close_error:
-                        logging.warning(f"Failed to close Bedrock event stream in finally block: {close_error}")
+                        logging.warning(
+                            f"Failed to close Bedrock event stream in finally block: {close_error}"
+                        )
 
                 # Clean up the client
                 if client:
                     try:
                         await client.__aexit__(None, None, None)
                     except Exception as client_close_error:
-                        logging.warning(f"Failed to close Bedrock client: {client_close_error}")
+                        logging.warning(
+                            f"Failed to close Bedrock client: {client_close_error}"
+                        )
 
         # Return the generator that manages its own lifecycle
         return stream_generator()
 
     # For non-streaming responses, use the standard async context manager pattern
     session = aioboto3.Session()
-    async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client:
+    async with session.client(
+        "bedrock-runtime", region_name=region
+    ) as bedrock_async_client:
         try:
             # Use converse for non-streaming responses
             response = await bedrock_async_client.converse(**args, **kwargs)
@@ -257,7 +271,9 @@ async def bedrock_embed(
     region = os.environ.get("AWS_REGION")
 
     session = aioboto3.Session()
-    async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client:
+    async with session.client(
+        "bedrock-runtime", region_name=region
+    ) as bedrock_async_client:
         if (model_provider := model.split(".")[0]) == "amazon":
             embed_texts = []
             for text in texts:
@@ -285,7 +301,9 @@ async def bedrock_embed(
                 embed_texts.append(response_body["embedding"])
 
         elif model_provider == "cohere":
-            body = json.dumps({"texts": texts, "input_type": "search_document", "truncate": "NONE"})
+            body = json.dumps(
+                {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
+            )
 
             response = await bedrock_async_client.invoke_model(
                 model=model,