From 99643f01dec35a8a4d498fef07bcd138d8ffad22 Mon Sep 17 00:00:00 2001
From: SJ
Date: Wed, 13 Aug 2025 02:08:13 -0500
Subject: [PATCH 1/4] Enhancement: support AWS Bedrock as an LLM binding #1733

---
 lightrag/api/config.py          |  72 +++--------
 lightrag/api/lightrag_server.py | 222 +++++++++++++++++---------------
 lightrag/llm/bedrock.py         | 197 +++++++++++++++++++++++-----
 3 files changed, 299 insertions(+), 192 deletions(-)

diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index 56e506cc..0a16e89b 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -77,9 +77,7 @@ def parse_args() -> argparse.Namespace:
         argparse.Namespace: Parsed arguments
     """
 
-    parser = argparse.ArgumentParser(
-        description="LightRAG FastAPI Server with separate working and input directories"
-    )
+    parser = argparse.ArgumentParser(description="LightRAG FastAPI Server with separate working and input directories")
 
     # Server configuration
     parser.add_argument(
@@ -209,14 +207,14 @@ def parse_args() -> argparse.Namespace:
         "--llm-binding",
         type=str,
         default=get_env_value("LLM_BINDING", "ollama"),
-        choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
+        choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai", "aws_bedrock"],
         help="LLM binding type (default: from env or ollama)",
     )
     parser.add_argument(
         "--embedding-binding",
         type=str,
         default=get_env_value("EMBEDDING_BINDING", "ollama"),
-        choices=["lollms", "ollama", "openai", "azure_openai"],
+        choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
         help="Embedding binding type (default: from env or ollama)",
     )
 
@@ -272,18 +270,10 @@ def parse_args() -> argparse.Namespace:
     args.input_dir = os.path.abspath(args.input_dir)
 
     # Inject storage configuration from environment variables
-    args.kv_storage = get_env_value(
-        "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
-    )
-    args.doc_status_storage = get_env_value(
-        "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
-    )
-    args.graph_storage = get_env_value(
-        "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
-    )
-    args.vector_storage = get_env_value(
-        "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
-    )
+    args.kv_storage = get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE)
+    args.doc_status_storage = get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE)
+    args.graph_storage = get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE)
+    args.vector_storage = get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE)
 
     # Get MAX_PARALLEL_INSERT from environment
     args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
@@ -299,12 +289,8 @@ def parse_args() -> argparse.Namespace:
     # Ollama ctx_num
     args.ollama_num_ctx = get_env_value("OLLAMA_NUM_CTX", 32768, int)
 
-    args.llm_binding_host = get_env_value(
-        "LLM_BINDING_HOST", get_default_host(args.llm_binding)
-    )
-    args.embedding_binding_host = get_env_value(
-        "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
-    )
+    args.llm_binding_host = get_env_value("LLM_BINDING_HOST", get_default_host(args.llm_binding))
+    args.embedding_binding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding))
 
     args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
     args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
@@ -318,9 +304,7 @@ def parse_args() ->
argparse.Namespace: args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) # Inject LLM cache configuration - args.enable_llm_cache_for_extract = get_env_value( - "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool - ) + args.enable_llm_cache_for_extract = get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool) args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool) # Handle Ollama LLM temperature with priority cascade when llm-binding is ollama @@ -370,40 +354,24 @@ def parse_args() -> argparse.Namespace: args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None) # Min rerank score configuration - args.min_rerank_score = get_env_value( - "MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float - ) + args.min_rerank_score = get_env_value("MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float) # Query configuration args.history_turns = get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int) args.top_k = get_env_value("TOP_K", DEFAULT_TOP_K, int) args.chunk_top_k = get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int) - args.max_entity_tokens = get_env_value( - "MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int - ) - args.max_relation_tokens = get_env_value( - "MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int - ) - args.max_total_tokens = get_env_value( - "MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int - ) - args.cosine_threshold = get_env_value( - "COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float - ) - args.related_chunk_number = get_env_value( - "RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int - ) + args.max_entity_tokens = get_env_value("MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int) + args.max_relation_tokens = get_env_value("MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int) + args.max_total_tokens = get_env_value("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int) + args.cosine_threshold = get_env_value("COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float) + args.related_chunk_number = get_env_value("RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int) # Add missing environment variables for health endpoint args.force_llm_summary_on_merge = get_env_value( "FORCE_LLM_SUMMARY_ON_MERGE", DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int ) - args.embedding_func_max_async = get_env_value( - "EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int - ) - args.embedding_batch_num = get_env_value( - "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int - ) + args.embedding_func_max_async = get_env_value("EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int) + args.embedding_batch_num = get_env_value("EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int) ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag @@ -417,9 +385,7 @@ def update_uvicorn_mode_config(): original_workers = global_args.workers global_args.workers = 1 # Log warning directly here - logging.warning( - f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1" - ) + logging.warning(f"In uvicorn mode, workers parameter was set to {original_workers}. 
Forcing workers=1") global_args = parse_args() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 699f59ac..8ee23826 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -106,6 +106,7 @@ def create_app(args): "openai", "openai-ollama", "azure_openai", + "aws_bedrock", ]: raise Exception("llm binding not supported") @@ -114,6 +115,7 @@ def create_app(args): "ollama", "openai", "azure_openai", + "aws_bedrock", "jina", ]: raise Exception("embedding binding not supported") @@ -128,9 +130,7 @@ def create_app(args): # Add SSL validation if args.ssl: if not args.ssl_certfile or not args.ssl_keyfile: - raise Exception( - "SSL certificate and key files must be provided when SSL is enabled" - ) + raise Exception("SSL certificate and key files must be provided when SSL is enabled") if not os.path.exists(args.ssl_certfile): raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") if not os.path.exists(args.ssl_keyfile): @@ -188,10 +188,11 @@ def create_app(args): # Initialize FastAPI app_kwargs = { "title": "LightRAG Server API", - "description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation" - + "(With authentication)" - if api_key - else "", + "description": ( + "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + "(With authentication)" + if api_key + else "" + ), "version": __api_version__, "openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL "docs_url": "/docs", # Explicitly set docs URL @@ -244,9 +245,9 @@ def create_app(args): azure_openai_complete_if_cache, azure_openai_embed, ) + if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock": + from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama": - from lightrag.llm.openai import openai_complete_if_cache - from lightrag.llm.ollama import ollama_embed from lightrag.llm.binding_options import OllamaEmbeddingOptions if args.embedding_binding == "jina": from lightrag.llm.jina import jina_embed @@ -312,41 +313,80 @@ def create_app(args): **kwargs, ) + async def bedrock_model_complete( + prompt, + system_prompt=None, + history_messages=None, + keyword_extraction=False, + **kwargs, + ) -> str: + keyword_extraction = kwargs.pop("keyword_extraction", None) + if keyword_extraction: + kwargs["response_format"] = GPTKeywordExtractionFormat + if history_messages is None: + history_messages = [] + + # Use global temperature for Bedrock + kwargs["temperature"] = args.temperature + + return await bedrock_complete_if_cache( + args.llm_model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + embedding_func = EmbeddingFunc( embedding_dim=args.embedding_dim, - func=lambda texts: lollms_embed( - texts, - embed_model=args.embedding_model, - host=args.embedding_binding_host, - api_key=args.embedding_binding_api_key, - ) - if args.embedding_binding == "lollms" - else ollama_embed( - texts, - embed_model=args.embedding_model, - host=args.embedding_binding_host, - api_key=args.embedding_binding_api_key, - options=OllamaEmbeddingOptions.options_dict(args), - ) - if args.embedding_binding == "ollama" - else azure_openai_embed( - texts, - model=args.embedding_model, # no host is used for openai, - api_key=args.embedding_binding_api_key, - ) - if args.embedding_binding == "azure_openai" - else jina_embed( - texts, - dimensions=args.embedding_dim, - 
base_url=args.embedding_binding_host, - api_key=args.embedding_binding_api_key, - ) - if args.embedding_binding == "jina" - else openai_embed( - texts, - model=args.embedding_model, - base_url=args.embedding_binding_host, - api_key=args.embedding_binding_api_key, + func=lambda texts: ( + lollms_embed( + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, + api_key=args.embedding_binding_api_key, + ) + if args.embedding_binding == "lollms" + else ( + ollama_embed( + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, + api_key=args.embedding_binding_api_key, + options=OllamaEmbeddingOptions.options_dict(args), + ) + if args.embedding_binding == "ollama" + else ( + azure_openai_embed( + texts, + model=args.embedding_model, # no host is used for openai, + api_key=args.embedding_binding_api_key, + ) + if args.embedding_binding == "azure_openai" + else ( + bedrock_embed( + texts, + model=args.embedding_model, + ) + if args.embedding_binding == "aws_bedrock" + else ( + jina_embed( + texts, + dimensions=args.embedding_dim, + base_url=args.embedding_binding_host, + api_key=args.embedding_binding_api_key, + ) + if args.embedding_binding == "jina" + else openai_embed( + texts, + model=args.embedding_model, + base_url=args.embedding_binding_host, + api_key=args.embedding_binding_api_key, + ) + ) + ) + ) + ) ), ) @@ -355,9 +395,7 @@ def create_app(args): if args.rerank_binding_api_key and args.rerank_binding_host: from lightrag.rerank import custom_rerank - async def server_rerank_func( - query: str, documents: list, top_n: int = None, **kwargs - ): + async def server_rerank_func(query: str, documents: list, top_n: int = None, **kwargs): """Server rerank function with configuration from environment variables""" return await custom_rerank( query=query, @@ -370,9 +408,7 @@ def create_app(args): ) rerank_model_func = server_rerank_func - logger.info( - f"Rerank model configured: {args.rerank_model} (can be enabled per query)" - ) + logger.info(f"Rerank model configured: {args.rerank_model} (can be enabled per query)") else: logger.info( "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking." 
@@ -381,41 +417,43 @@ def create_app(args): # Create ollama_server_infos from command line arguments from lightrag.api.config import OllamaServerInfos - ollama_server_infos = OllamaServerInfos( - name=args.simulated_model_name, tag=args.simulated_model_tag - ) + ollama_server_infos = OllamaServerInfos(name=args.simulated_model_name, tag=args.simulated_model_tag) # Initialize RAG - if args.llm_binding in ["lollms", "ollama", "openai"]: + if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]: rag = LightRAG( working_dir=args.working_dir, workspace=args.workspace, - llm_model_func=lollms_model_complete - if args.llm_binding == "lollms" - else ollama_model_complete - if args.llm_binding == "ollama" - else openai_alike_model_complete, + llm_model_func=( + lollms_model_complete + if args.llm_binding == "lollms" + else ( + ollama_model_complete + if args.llm_binding == "ollama" + else bedrock_model_complete if args.llm_binding == "aws_bedrock" else openai_alike_model_complete + ) + ), llm_model_name=args.llm_model, llm_model_max_async=args.max_async, summary_max_tokens=args.max_tokens, chunk_token_size=int(args.chunk_size), chunk_overlap_token_size=int(args.chunk_overlap_size), - llm_model_kwargs={ - "host": args.llm_binding_host, - "timeout": args.timeout, - "options": OllamaLLMOptions.options_dict(args), - "api_key": args.llm_binding_api_key, - } - if args.llm_binding == "lollms" or args.llm_binding == "ollama" - else {}, + llm_model_kwargs=( + { + "host": args.llm_binding_host, + "timeout": args.timeout, + "options": OllamaLLMOptions.options_dict(args), + "api_key": args.llm_binding_api_key, + } + if args.llm_binding == "lollms" or args.llm_binding == "ollama" + else {} + ), embedding_func=embedding_func, kv_storage=args.kv_storage, graph_storage=args.graph_storage, vector_storage=args.vector_storage, doc_status_storage=args.doc_status_storage, - vector_db_storage_cls_kwargs={ - "cosine_better_than_threshold": args.cosine_threshold - }, + vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold}, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, rerank_model_func=rerank_model_func, @@ -442,9 +480,7 @@ def create_app(args): graph_storage=args.graph_storage, vector_storage=args.vector_storage, doc_status_storage=args.doc_status_storage, - vector_db_storage_cls_kwargs={ - "cosine_better_than_threshold": args.cosine_threshold - }, + vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold}, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, rerank_model_func=rerank_model_func, @@ -480,9 +516,7 @@ def create_app(args): if not auth_handler.accounts: # Authentication not configured, return guest token - guest_token = auth_handler.create_token( - username="guest", role="guest", metadata={"auth_mode": "disabled"} - ) + guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"}) return { "auth_configured": False, "access_token": guest_token, @@ -508,9 +542,7 @@ def create_app(args): async def login(form_data: OAuth2PasswordRequestForm = Depends()): if not auth_handler.accounts: # Authentication not configured, return guest token - guest_token = auth_handler.create_token( - username="guest", role="guest", metadata={"auth_mode": "disabled"} - ) + guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"}) return { 
"access_token": guest_token, "token_type": "bearer", @@ -523,14 +555,10 @@ def create_app(args): } username = form_data.username if auth_handler.accounts.get(username) != form_data.password: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials" - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials") # Regular user login - user_token = auth_handler.create_token( - username=username, role="user", metadata={"auth_mode": "enabled"} - ) + user_token = auth_handler.create_token(username=username, role="user", metadata={"auth_mode": "enabled"}) return { "access_token": user_token, "token_type": "bearer", @@ -579,12 +607,8 @@ def create_app(args): "max_graph_nodes": args.max_graph_nodes, # Rerank configuration (based on whether rerank model is configured) "enable_rerank": rerank_model_func is not None, - "rerank_model": args.rerank_model - if rerank_model_func is not None - else None, - "rerank_binding_host": args.rerank_binding_host - if rerank_model_func is not None - else None, + "rerank_model": args.rerank_model if rerank_model_func is not None else None, + "rerank_binding_host": args.rerank_binding_host if rerank_model_func is not None else None, # Environment variable status (requested configuration) "summary_language": args.summary_language, "force_llm_summary_on_merge": args.force_llm_summary_on_merge, @@ -614,17 +638,11 @@ def create_app(args): response = await super().get_response(path, scope) if path.endswith(".html"): - response.headers["Cache-Control"] = ( - "no-cache, no-store, must-revalidate" - ) + response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate" response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" - elif ( - "/assets/" in path - ): # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename - response.headers["Cache-Control"] = ( - "public, max-age=31536000, immutable" - ) + elif "/assets/" in path: # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename + response.headers["Cache-Control"] = "public, max-age=31536000, immutable" # Add other rules here if needed for non-HTML, non-asset files # Ensure correct Content-Type @@ -640,9 +658,7 @@ def create_app(args): static_dir.mkdir(exist_ok=True) app.mount( "/webui", - SmartStaticFiles( - directory=static_dir, html=True, check_dir=True - ), # Use SmartStaticFiles + SmartStaticFiles(directory=static_dir, html=True, check_dir=True), # Use SmartStaticFiles name="webui", ) @@ -798,9 +814,7 @@ def main(): } ) - print( - f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}" - ) + print(f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}") uvicorn.run(**uvicorn_config) diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py index 1640abbb..51c4b895 100644 --- a/lightrag/llm/bedrock.py +++ b/lightrag/llm/bedrock.py @@ -15,11 +15,25 @@ from tenacity import ( retry_if_exception_type, ) +import sys + +if sys.version_info < (3, 9): + from typing import AsyncIterator +else: + from collections.abc import AsyncIterator +from typing import Union + class BedrockError(Exception): """Generic error for issues related to Amazon Bedrock""" +def _set_env_if_present(key: str, value): + """Set environment variable only if a non-empty value is provided.""" + if value is not None and value != "": + os.environ[key] = value + + @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=60), @@ -34,17 
+48,35 @@ async def bedrock_complete_if_cache( aws_secret_access_key=None, aws_session_token=None, **kwargs, -) -> str: - os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( - "AWS_ACCESS_KEY_ID", aws_access_key_id - ) - os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( - "AWS_SECRET_ACCESS_KEY", aws_secret_access_key - ) - os.environ["AWS_SESSION_TOKEN"] = os.environ.get( - "AWS_SESSION_TOKEN", aws_session_token - ) +) -> Union[str, AsyncIterator[str]]: + # Respect existing env; only set if a non-empty value is available + access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id + secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key + session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token + _set_env_if_present("AWS_ACCESS_KEY_ID", access_key) + _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key) + _set_env_if_present("AWS_SESSION_TOKEN", session_token) + # Region handling: prefer env, else kwarg (optional) + region = os.environ.get("AWS_REGION") or kwargs.pop("aws_region", None) kwargs.pop("hashing_kv", None) + # Capture stream flag (if provided) and remove from kwargs since it's not a Bedrock API parameter + # We'll use this to determine whether to call converse_stream or converse + stream = bool(kwargs.pop("stream", False)) + # Remove unsupported args for Bedrock Converse API + for k in [ + "response_format", + "tools", + "tool_choice", + "seed", + "presence_penalty", + "frequency_penalty", + "n", + "logprobs", + "top_logprobs", + "max_completion_tokens", + "response_format", + ]: + kwargs.pop(k, None) # Fix message history format messages = [] for history_message in history_messages: @@ -68,30 +100,126 @@ async def bedrock_complete_if_cache( "top_p": "topP", "stop_sequences": "stopSequences", } - if inference_params := list( - set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"]) - ): + if inference_params := list(set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])): args["inferenceConfig"] = {} for param in inference_params: - args["inferenceConfig"][inference_params_map.get(param, param)] = ( - kwargs.pop(param) - ) + args["inferenceConfig"][inference_params_map.get(param, param)] = kwargs.pop(param) - # Call model via Converse API + # Import logging for error handling + import logging + + # For streaming responses, we need a different approach to keep the connection open + if stream: + # Create a session that will be used throughout the streaming process + session = aioboto3.Session() + client = None + + # Define the generator function that will manage the client lifecycle + async def stream_generator(): + nonlocal client + + # Create the client outside the generator to ensure it stays open + client = await session.client("bedrock-runtime", region_name=region).__aenter__() + event_stream = None + iteration_started = False + + try: + # Make the API call + response = await client.converse_stream(**args, **kwargs) + event_stream = response.get("stream") + iteration_started = True + + # Process the stream + async for event in event_stream: + # Validate event structure + if not event or not isinstance(event, dict): + continue + + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"].get("delta", {}) + text = delta.get("text") + if text: + yield text + # Handle other event types that might indicate stream end + elif "messageStop" in event: + break + + except Exception as e: + # Log the specific error for debugging + logging.error(f"Bedrock streaming error: {e}") + + 
# Try to clean up resources if possible + if ( + iteration_started + and event_stream + and hasattr(event_stream, "aclose") + and callable(getattr(event_stream, "aclose", None)) + ): + try: + await event_stream.aclose() + except Exception as close_error: + logging.warning(f"Failed to close Bedrock event stream: {close_error}") + + raise BedrockError(f"Streaming error: {e}") + + finally: + # Clean up the event stream + if ( + iteration_started + and event_stream + and hasattr(event_stream, "aclose") + and callable(getattr(event_stream, "aclose", None)) + ): + try: + await event_stream.aclose() + except Exception as close_error: + logging.warning(f"Failed to close Bedrock event stream in finally block: {close_error}") + + # Clean up the client + if client: + try: + await client.__aexit__(None, None, None) + except Exception as client_close_error: + logging.warning(f"Failed to close Bedrock client: {client_close_error}") + + # Return the generator that manages its own lifecycle + return stream_generator() + + # For non-streaming responses, use the standard async context manager pattern session = aioboto3.Session() - async with session.client("bedrock-runtime") as bedrock_async_client: + async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client: try: + # Use converse for non-streaming responses response = await bedrock_async_client.converse(**args, **kwargs) - except Exception as e: - raise BedrockError(e) - return response["output"]["message"]["content"][0]["text"] + # Validate response structure + if ( + not response + or "output" not in response + or "message" not in response["output"] + or "content" not in response["output"]["message"] + or not response["output"]["message"]["content"] + ): + raise BedrockError("Invalid response structure from Bedrock API") + + content = response["output"]["message"]["content"][0]["text"] + + if not content or content.strip() == "": + raise BedrockError("Received empty content from Bedrock API") + + return content + + except Exception as e: + if isinstance(e, BedrockError): + raise + else: + raise BedrockError(f"Bedrock API error: {e}") # Generic Bedrock completion function async def bedrock_complete( prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: +) -> Union[str, AsyncIterator[str]]: kwargs.pop("keyword_extraction", None) model_name = kwargs["hashing_kv"].global_config["llm_model_name"] result = await bedrock_complete_if_cache( @@ -117,18 +245,19 @@ async def bedrock_embed( aws_secret_access_key=None, aws_session_token=None, ) -> np.ndarray: - os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( - "AWS_ACCESS_KEY_ID", aws_access_key_id - ) - os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( - "AWS_SECRET_ACCESS_KEY", aws_secret_access_key - ) - os.environ["AWS_SESSION_TOKEN"] = os.environ.get( - "AWS_SESSION_TOKEN", aws_session_token - ) + # Respect existing env; only set if a non-empty value is available + access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id + secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key + session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token + _set_env_if_present("AWS_ACCESS_KEY_ID", access_key) + _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key) + _set_env_if_present("AWS_SESSION_TOKEN", session_token) + + # Region handling: prefer env + region = os.environ.get("AWS_REGION") session = aioboto3.Session() - async with session.client("bedrock-runtime") as bedrock_async_client: + async 
with session.client("bedrock-runtime", region_name=region) as bedrock_async_client: if (model_provider := model.split(".")[0]) == "amazon": embed_texts = [] for text in texts: @@ -156,9 +285,7 @@ async def bedrock_embed( embed_texts.append(response_body["embedding"]) elif model_provider == "cohere": - body = json.dumps( - {"texts": texts, "input_type": "search_document", "truncate": "NONE"} - ) + body = json.dumps({"texts": texts, "input_type": "search_document", "truncate": "NONE"}) response = await bedrock_async_client.invoke_model( model=model, From f7ca9ae16a26a81d9b4b188f7a4a1e071b6419a1 Mon Sep 17 00:00:00 2001 From: SJ Date: Fri, 15 Aug 2025 22:21:34 +0000 Subject: [PATCH 2/4] Ruff formatted --- lightrag/api/config.py | 77 +++++++++++++++++++++++++-------- lightrag/api/lightrag_server.py | 75 ++++++++++++++++++++++++-------- lightrag/llm/bedrock.py | 36 +++++++++++---- 3 files changed, 142 insertions(+), 46 deletions(-) diff --git a/lightrag/api/config.py b/lightrag/api/config.py index 0a16e89b..01d0dd75 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -77,7 +77,9 @@ def parse_args() -> argparse.Namespace: argparse.Namespace: Parsed arguments """ - parser = argparse.ArgumentParser(description="LightRAG FastAPI Server with separate working and input directories") + parser = argparse.ArgumentParser( + description="LightRAG FastAPI Server with separate working and input directories" + ) # Server configuration parser.add_argument( @@ -207,7 +209,14 @@ def parse_args() -> argparse.Namespace: "--llm-binding", type=str, default=get_env_value("LLM_BINDING", "ollama"), - choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai", "aws_bedrock"], + choices=[ + "lollms", + "ollama", + "openai", + "openai-ollama", + "azure_openai", + "aws_bedrock", + ], help="LLM binding type (default: from env or ollama)", ) parser.add_argument( @@ -270,10 +279,18 @@ def parse_args() -> argparse.Namespace: args.input_dir = os.path.abspath(args.input_dir) # Inject storage configuration from environment variables - args.kv_storage = get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE) - args.doc_status_storage = get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE) - args.graph_storage = get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE) - args.vector_storage = get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE) + args.kv_storage = get_env_value( + "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE + ) + args.doc_status_storage = get_env_value( + "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE + ) + args.graph_storage = get_env_value( + "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE + ) + args.vector_storage = get_env_value( + "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE + ) # Get MAX_PARALLEL_INSERT from environment args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int) @@ -289,8 +306,12 @@ def parse_args() -> argparse.Namespace: # Ollama ctx_num args.ollama_num_ctx = get_env_value("OLLAMA_NUM_CTX", 32768, int) - args.llm_binding_host = get_env_value("LLM_BINDING_HOST", get_default_host(args.llm_binding)) - args.embedding_binding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)) + args.llm_binding_host = get_env_value( + "LLM_BINDING_HOST", get_default_host(args.llm_binding) + ) + args.embedding_binding_host = get_env_value( + 
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding) + ) args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None) args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "") @@ -304,7 +325,9 @@ def parse_args() -> argparse.Namespace: args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) # Inject LLM cache configuration - args.enable_llm_cache_for_extract = get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool) + args.enable_llm_cache_for_extract = get_env_value( + "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool + ) args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool) # Handle Ollama LLM temperature with priority cascade when llm-binding is ollama @@ -354,24 +377,40 @@ def parse_args() -> argparse.Namespace: args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None) # Min rerank score configuration - args.min_rerank_score = get_env_value("MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float) + args.min_rerank_score = get_env_value( + "MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float + ) # Query configuration args.history_turns = get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int) args.top_k = get_env_value("TOP_K", DEFAULT_TOP_K, int) args.chunk_top_k = get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int) - args.max_entity_tokens = get_env_value("MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int) - args.max_relation_tokens = get_env_value("MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int) - args.max_total_tokens = get_env_value("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int) - args.cosine_threshold = get_env_value("COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float) - args.related_chunk_number = get_env_value("RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int) + args.max_entity_tokens = get_env_value( + "MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int + ) + args.max_relation_tokens = get_env_value( + "MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int + ) + args.max_total_tokens = get_env_value( + "MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int + ) + args.cosine_threshold = get_env_value( + "COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float + ) + args.related_chunk_number = get_env_value( + "RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int + ) # Add missing environment variables for health endpoint args.force_llm_summary_on_merge = get_env_value( "FORCE_LLM_SUMMARY_ON_MERGE", DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int ) - args.embedding_func_max_async = get_env_value("EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int) - args.embedding_batch_num = get_env_value("EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int) + args.embedding_func_max_async = get_env_value( + "EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int + ) + args.embedding_batch_num = get_env_value( + "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int + ) ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag @@ -385,7 +424,9 @@ def update_uvicorn_mode_config(): original_workers = global_args.workers global_args.workers = 1 # Log warning directly here - logging.warning(f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1") + logging.warning( + f"In uvicorn mode, workers parameter was set to {original_workers}. 
Forcing workers=1" + ) global_args = parse_args() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8ee23826..92349861 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -130,7 +130,9 @@ def create_app(args): # Add SSL validation if args.ssl: if not args.ssl_certfile or not args.ssl_keyfile: - raise Exception("SSL certificate and key files must be provided when SSL is enabled") + raise Exception( + "SSL certificate and key files must be provided when SSL is enabled" + ) if not os.path.exists(args.ssl_certfile): raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") if not os.path.exists(args.ssl_keyfile): @@ -189,7 +191,8 @@ def create_app(args): app_kwargs = { "title": "LightRAG Server API", "description": ( - "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + "(With authentication)" + "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + + "(With authentication)" if api_key else "" ), @@ -395,7 +398,9 @@ def create_app(args): if args.rerank_binding_api_key and args.rerank_binding_host: from lightrag.rerank import custom_rerank - async def server_rerank_func(query: str, documents: list, top_n: int = None, **kwargs): + async def server_rerank_func( + query: str, documents: list, top_n: int = None, **kwargs + ): """Server rerank function with configuration from environment variables""" return await custom_rerank( query=query, @@ -408,7 +413,9 @@ def create_app(args): ) rerank_model_func = server_rerank_func - logger.info(f"Rerank model configured: {args.rerank_model} (can be enabled per query)") + logger.info( + f"Rerank model configured: {args.rerank_model} (can be enabled per query)" + ) else: logger.info( "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking." 
@@ -417,7 +424,9 @@ def create_app(args): # Create ollama_server_infos from command line arguments from lightrag.api.config import OllamaServerInfos - ollama_server_infos = OllamaServerInfos(name=args.simulated_model_name, tag=args.simulated_model_tag) + ollama_server_infos = OllamaServerInfos( + name=args.simulated_model_name, tag=args.simulated_model_tag + ) # Initialize RAG if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]: @@ -430,7 +439,9 @@ def create_app(args): else ( ollama_model_complete if args.llm_binding == "ollama" - else bedrock_model_complete if args.llm_binding == "aws_bedrock" else openai_alike_model_complete + else bedrock_model_complete + if args.llm_binding == "aws_bedrock" + else openai_alike_model_complete ) ), llm_model_name=args.llm_model, @@ -453,7 +464,9 @@ def create_app(args): graph_storage=args.graph_storage, vector_storage=args.vector_storage, doc_status_storage=args.doc_status_storage, - vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold}, + vector_db_storage_cls_kwargs={ + "cosine_better_than_threshold": args.cosine_threshold + }, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, rerank_model_func=rerank_model_func, @@ -480,7 +493,9 @@ def create_app(args): graph_storage=args.graph_storage, vector_storage=args.vector_storage, doc_status_storage=args.doc_status_storage, - vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold}, + vector_db_storage_cls_kwargs={ + "cosine_better_than_threshold": args.cosine_threshold + }, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, rerank_model_func=rerank_model_func, @@ -516,7 +531,9 @@ def create_app(args): if not auth_handler.accounts: # Authentication not configured, return guest token - guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"}) + guest_token = auth_handler.create_token( + username="guest", role="guest", metadata={"auth_mode": "disabled"} + ) return { "auth_configured": False, "access_token": guest_token, @@ -542,7 +559,9 @@ def create_app(args): async def login(form_data: OAuth2PasswordRequestForm = Depends()): if not auth_handler.accounts: # Authentication not configured, return guest token - guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"}) + guest_token = auth_handler.create_token( + username="guest", role="guest", metadata={"auth_mode": "disabled"} + ) return { "access_token": guest_token, "token_type": "bearer", @@ -555,10 +574,14 @@ def create_app(args): } username = form_data.username if auth_handler.accounts.get(username) != form_data.password: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials" + ) # Regular user login - user_token = auth_handler.create_token(username=username, role="user", metadata={"auth_mode": "enabled"}) + user_token = auth_handler.create_token( + username=username, role="user", metadata={"auth_mode": "enabled"} + ) return { "access_token": user_token, "token_type": "bearer", @@ -607,8 +630,12 @@ def create_app(args): "max_graph_nodes": args.max_graph_nodes, # Rerank configuration (based on whether rerank model is configured) "enable_rerank": rerank_model_func is not None, - "rerank_model": args.rerank_model if 
rerank_model_func is not None else None, - "rerank_binding_host": args.rerank_binding_host if rerank_model_func is not None else None, + "rerank_model": args.rerank_model + if rerank_model_func is not None + else None, + "rerank_binding_host": args.rerank_binding_host + if rerank_model_func is not None + else None, # Environment variable status (requested configuration) "summary_language": args.summary_language, "force_llm_summary_on_merge": args.force_llm_summary_on_merge, @@ -638,11 +665,17 @@ def create_app(args): response = await super().get_response(path, scope) if path.endswith(".html"): - response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate" + response.headers["Cache-Control"] = ( + "no-cache, no-store, must-revalidate" + ) response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" - elif "/assets/" in path: # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename - response.headers["Cache-Control"] = "public, max-age=31536000, immutable" + elif ( + "/assets/" in path + ): # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename + response.headers["Cache-Control"] = ( + "public, max-age=31536000, immutable" + ) # Add other rules here if needed for non-HTML, non-asset files # Ensure correct Content-Type @@ -658,7 +691,9 @@ def create_app(args): static_dir.mkdir(exist_ok=True) app.mount( "/webui", - SmartStaticFiles(directory=static_dir, html=True, check_dir=True), # Use SmartStaticFiles + SmartStaticFiles( + directory=static_dir, html=True, check_dir=True + ), # Use SmartStaticFiles name="webui", ) @@ -814,7 +849,9 @@ def main(): } ) - print(f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}") + print( + f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}" + ) uvicorn.run(**uvicorn_config) diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py index 51c4b895..69d00e2d 100644 --- a/lightrag/llm/bedrock.py +++ b/lightrag/llm/bedrock.py @@ -100,10 +100,14 @@ async def bedrock_complete_if_cache( "top_p": "topP", "stop_sequences": "stopSequences", } - if inference_params := list(set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])): + if inference_params := list( + set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"]) + ): args["inferenceConfig"] = {} for param in inference_params: - args["inferenceConfig"][inference_params_map.get(param, param)] = kwargs.pop(param) + args["inferenceConfig"][inference_params_map.get(param, param)] = ( + kwargs.pop(param) + ) # Import logging for error handling import logging @@ -119,7 +123,9 @@ async def bedrock_complete_if_cache( nonlocal client # Create the client outside the generator to ensure it stays open - client = await session.client("bedrock-runtime", region_name=region).__aenter__() + client = await session.client( + "bedrock-runtime", region_name=region + ).__aenter__() event_stream = None iteration_started = False @@ -158,7 +164,9 @@ async def bedrock_complete_if_cache( try: await event_stream.aclose() except Exception as close_error: - logging.warning(f"Failed to close Bedrock event stream: {close_error}") + logging.warning( + f"Failed to close Bedrock event stream: {close_error}" + ) raise BedrockError(f"Streaming error: {e}") @@ -173,21 +181,27 @@ async def bedrock_complete_if_cache( try: await event_stream.aclose() except Exception as close_error: - logging.warning(f"Failed to close Bedrock event stream in finally block: {close_error}") + 
logging.warning( + f"Failed to close Bedrock event stream in finally block: {close_error}" + ) # Clean up the client if client: try: await client.__aexit__(None, None, None) except Exception as client_close_error: - logging.warning(f"Failed to close Bedrock client: {client_close_error}") + logging.warning( + f"Failed to close Bedrock client: {client_close_error}" + ) # Return the generator that manages its own lifecycle return stream_generator() # For non-streaming responses, use the standard async context manager pattern session = aioboto3.Session() - async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client: + async with session.client( + "bedrock-runtime", region_name=region + ) as bedrock_async_client: try: # Use converse for non-streaming responses response = await bedrock_async_client.converse(**args, **kwargs) @@ -257,7 +271,9 @@ async def bedrock_embed( region = os.environ.get("AWS_REGION") session = aioboto3.Session() - async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client: + async with session.client( + "bedrock-runtime", region_name=region + ) as bedrock_async_client: if (model_provider := model.split(".")[0]) == "amazon": embed_texts = [] for text in texts: @@ -285,7 +301,9 @@ async def bedrock_embed( embed_texts.append(response_body["embedding"]) elif model_provider == "cohere": - body = json.dumps({"texts": texts, "input_type": "search_document", "truncate": "NONE"}) + body = json.dumps( + {"texts": texts, "input_type": "search_document", "truncate": "NONE"} + ) response = await bedrock_async_client.invoke_model( model=model, From 1ed77a2e53485d08b3b50086582f0316915682cc Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 17 Aug 2025 02:13:50 +0800 Subject: [PATCH 3/4] Remove openai-ollama binding from LightRAG level args --- lightrag/api/lightrag_server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 92349861..c3384181 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -104,7 +104,6 @@ def create_app(args): "lollms", "ollama", "openai", - "openai-ollama", "azure_openai", "aws_bedrock", ]: @@ -250,7 +249,7 @@ def create_app(args): ) if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock": from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed - if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama": + if args.embedding_binding == "ollama": from lightrag.llm.binding_options import OllamaEmbeddingOptions if args.embedding_binding == "jina": from lightrag.llm.jina import jina_embed From da7e4b79e515bc4d3048c52a3c366e9a2a4e2217 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 17 Aug 2025 02:23:14 +0800 Subject: [PATCH 4/4] Update documentation in README files --- env.example | 2 +- lightrag/api/README-zh.md | 2 ++ lightrag/api/README.md | 6 ++++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/env.example b/env.example index 187a7064..e16a17fb 100644 --- a/env.example +++ b/env.example @@ -123,7 +123,7 @@ MAX_PARALLEL_INSERT=2 ########################################################### ### LLM Configuration -### LLM_BINDING type: openai, ollama, lollms, azure_openai +### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock ########################################################### ### LLM temperature setting for all llm binding (openai, azure_openai, ollama) # TEMPERATURE=1.0 diff --git 
a/lightrag/api/README-zh.md b/lightrag/api/README-zh.md
index 6fe5f86c..f5000a29 100644
--- a/lightrag/api/README-zh.md
+++ b/lightrag/api/README-zh.md
@@ -40,6 +40,7 @@ LightRAG 需要同时集成 LLM(大型语言模型)和嵌入模型以有效
 * lollms
 * openai 或 openai 兼容
 * azure_openai
+* aws_bedrock
 
 建议使用环境变量来配置 LightRAG 服务器。项目根目录中有一个名为 `env.example` 的示例环境变量文件。请将此文件复制到启动目录并重命名为 `.env`。之后,您可以在 `.env` 文件中修改与 LLM 和嵌入模型相关的参数。需要注意的是,LightRAG 服务器每次启动时都会将 `.env` 中的环境变量加载到系统环境变量中。**LightRAG 服务器会优先使用系统环境变量中的设置**。
 
@@ -357,6 +358,7 @@ LightRAG 支持绑定到各种 LLM/嵌入后端:
 * openai 和 openai 兼容
 * azure_openai
 * lollms
+* aws_bedrock
 
 使用环境变量 `LLM_BINDING` 或 CLI 参数 `--llm-binding` 选择 LLM 后端类型。使用环境变量 `EMBEDDING_BINDING` 或 CLI 参数 `--embedding-binding` 选择嵌入后端类型。
 
diff --git a/lightrag/api/README.md b/lightrag/api/README.md
index ce27baff..18396fb6 100644
--- a/lightrag/api/README.md
+++ b/lightrag/api/README.md
@@ -40,6 +40,7 @@ LightRAG necessitates the integration of both an LLM (Large Language Model) and
 * lollms
 * openai or openai compatible
 * azure_openai
+* aws_bedrock
 
 It is recommended to use environment variables to configure the LightRAG Server. There is an example environment variable file named `env.example` in the root directory of the project. Please copy this file to the startup directory and rename it to `.env`. After that, you can modify the parameters related to the LLM and Embedding models in the `.env` file. It is important to note that the LightRAG Server will load the environment variables from `.env` into the system environment variables each time it starts. **LightRAG Server will prioritize the settings in the system environment variables to .env file**.
 
@@ -360,6 +361,7 @@ LightRAG supports binding to various LLM/Embedding backends:
 * openai & openai compatible
 * azure_openai
 * lollms
+* aws_bedrock
 
 Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select the LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select the Embedding backend type.
 
@@ -459,8 +461,8 @@ You cannot change storage implementation selection after adding documents to Lig
 | --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
 | --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
 | --cosine-threshold | 0.4 | The cosine threshold for nodes and relation retrieval, works with top-k to control the retrieval of nodes and relations. |
-| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) |
-| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) |
+| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai, aws_bedrock) |
+| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai, aws_bedrock) |
 | --auto-scan-at-startup| - | Scan input directory for new files and start indexing |
 
 ### Additional Ollama Binding Options
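
For reference, a minimal usage sketch of the `aws_bedrock` binding introduced by this patch series. It calls the helpers in `lightrag/llm/bedrock.py` directly rather than going through the server; the model IDs are placeholders to swap for models enabled in your Bedrock account, and AWS credentials/region are assumed to already be exported (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION`), which is how the module picks them up.

```python
# Sketch only: exercises the Bedrock helpers added by these patches.
# Assumes AWS credentials/region are already set in the environment and that
# the model IDs below (placeholders) are enabled in your Bedrock account.
import asyncio

from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed


async def main() -> None:
    # Non-streaming completion via the Bedrock Converse API
    answer = await bedrock_complete_if_cache(
        "anthropic.claude-3-5-sonnet-20240620-v1:0",  # placeholder model ID
        "Summarize what LightRAG does in one sentence.",
        system_prompt="You are a concise assistant.",
        max_tokens=256,
        temperature=0.7,
    )
    print(answer)

    # Embeddings: returns a numpy array, one row per input text
    vectors = await bedrock_embed(
        ["LightRAG combines knowledge graphs with vector retrieval."],
        model="amazon.titan-embed-text-v2:0",  # placeholder embedding model ID
    )
    print(vectors.shape)


asyncio.run(main())
```

When running the server instead, the same binding is selected with `LLM_BINDING=aws_bedrock` and `EMBEDDING_BINDING=aws_bedrock` (or the corresponding `--llm-binding` / `--embedding-binding` flags), as documented in the README changes above.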