diff --git a/env.example b/env.example
index 187a7064..e16a17fb 100644
--- a/env.example
+++ b/env.example
@@ -123,7 +123,7 @@ MAX_PARALLEL_INSERT=2

 ###########################################################
 ### LLM Configuration
-### LLM_BINDING type: openai, ollama, lollms, azure_openai
+### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
 ###########################################################
 ### LLM temperature setting for all llm binding (openai, azure_openai, ollama)
 # TEMPERATURE=1.0
diff --git a/lightrag/api/README-zh.md b/lightrag/api/README-zh.md
index 174b4538..b74e4d12 100644
--- a/lightrag/api/README-zh.md
+++ b/lightrag/api/README-zh.md
@@ -40,6 +40,7 @@ LightRAG 需要同时集成 LLM(大型语言模型)和嵌入模型以有效
 * lollms
 * openai 或 openai 兼容
 * azure_openai
+* aws_bedrock

 建议使用环境变量来配置 LightRAG 服务器。项目根目录中有一个名为 `env.example` 的示例环境变量文件。请将此文件复制到启动目录并重命名为 `.env`。之后,您可以在 `.env` 文件中修改与 LLM 和嵌入模型相关的参数。需要注意的是,LightRAG 服务器每次启动时都会将 `.env` 中的环境变量加载到系统环境变量中。**LightRAG 服务器会优先使用系统环境变量中的设置**。

@@ -359,6 +360,7 @@ LightRAG 支持绑定到各种 LLM/嵌入后端:
 * openai 和 openai 兼容
 * azure_openai
 * lollms
+* aws_bedrock

 使用环境变量 `LLM_BINDING` 或 CLI 参数 `--llm-binding` 选择 LLM 后端类型。使用环境变量 `EMBEDDING_BINDING` 或 CLI 参数 `--embedding-binding` 选择嵌入后端类型。

diff --git a/lightrag/api/README.md b/lightrag/api/README.md
index 48d2c011..da59b38f 100644
--- a/lightrag/api/README.md
+++ b/lightrag/api/README.md
@@ -40,6 +40,7 @@ LightRAG necessitates the integration of both an LLM (Large Language Model) and
 * lollms
 * openai or openai compatible
 * azure_openai
+* aws_bedrock

 It is recommended to use environment variables to configure the LightRAG Server. There is an example environment variable file named `env.example` in the root directory of the project. Please copy this file to the startup directory and rename it to `.env`. After that, you can modify the parameters related to the LLM and Embedding models in the `.env` file. It is important to note that the LightRAG Server will load the environment variables from `.env` into the system environment variables each time it starts. **LightRAG Server will prioritize the settings in the system environment variables to .env file**.

@@ -362,6 +363,7 @@ LightRAG supports binding to various LLM/Embedding backends:
 * openai & openai compatible
 * azure_openai
 * lollms
+* aws_bedrock

 Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select the LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select the Embedding backend type.

@@ -461,8 +463,8 @@ You cannot change storage implementation selection after adding documents to Lig
 | --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
 | --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
 | --cosine-threshold | 0.4 | The cosine threshold for nodes and relation retrieval, works with top-k to control the retrieval of nodes and relations. |
-| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) |
-| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) |
+| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai, aws_bedrock) |
+| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai, aws_bedrock) |
 | --auto-scan-at-startup| - | Scan input directory for new files and start indexing |

 ### Additional Ollama Binding Options
diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index 56e506cc..01d0dd75 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -209,14 +209,21 @@ def parse_args() -> argparse.Namespace:
         "--llm-binding",
         type=str,
         default=get_env_value("LLM_BINDING", "ollama"),
-        choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
+        choices=[
+            "lollms",
+            "ollama",
+            "openai",
+            "openai-ollama",
+            "azure_openai",
+            "aws_bedrock",
+        ],
         help="LLM binding type (default: from env or ollama)",
     )
     parser.add_argument(
         "--embedding-binding",
         type=str,
         default=get_env_value("EMBEDDING_BINDING", "ollama"),
-        choices=["lollms", "ollama", "openai", "azure_openai"],
+        choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
         help="Embedding binding type (default: from env or ollama)",
    )
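Editor's note: for anyone trying the new binding, here is a minimal sketch of selecting it through environment variables before launching the server (the `--llm-binding` / `--embedding-binding` CLI flags work the same way). Only `LLM_BINDING`, `EMBEDDING_BINDING`, and `AWS_REGION` appear in this patch; the other variable names and the Bedrock model IDs below are illustrative assumptions, not values taken from this diff.

```python
# Illustrative sketch only -- not part of this patch.
import os

os.environ["LLM_BINDING"] = "aws_bedrock"
os.environ["EMBEDDING_BINDING"] = "aws_bedrock"
os.environ["AWS_REGION"] = "us-east-1"  # read by the new bedrock.py code
# Assumed variable names / model IDs; adjust to your deployment:
os.environ["LLM_MODEL"] = "anthropic.claude-3-haiku-20240307-v1:0"
os.environ["EMBEDDING_MODEL"] = "amazon.titan-embed-text-v2:0"
# Credentials come from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY /
# AWS_SESSION_TOKEN, or from the default boto3 credential chain.
```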
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 699f59ac..c3384181 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -104,8 +104,8 @@ def create_app(args):
         "lollms",
         "ollama",
         "openai",
-        "openai-ollama",
         "azure_openai",
+        "aws_bedrock",
     ]:
         raise Exception("llm binding not supported")

@@ -114,6 +114,7 @@ def create_app(args):
         "ollama",
         "openai",
         "azure_openai",
+        "aws_bedrock",
         "jina",
     ]:
         raise Exception("embedding binding not supported")

@@ -188,10 +189,12 @@ def create_app(args):
     # Initialize FastAPI
     app_kwargs = {
         "title": "LightRAG Server API",
-        "description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
-        + "(With authentication)"
-        if api_key
-        else "",
+        "description": (
+            "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
+            + "(With authentication)"
+            if api_key
+            else ""
+        ),
         "version": __api_version__,
         "openapi_url": "/openapi.json",  # Explicitly set OpenAPI schema URL
         "docs_url": "/docs",  # Explicitly set docs URL
@@ -244,9 +247,9 @@ def create_app(args):
             azure_openai_complete_if_cache,
             azure_openai_embed,
         )
-    if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama":
-        from lightrag.llm.openai import openai_complete_if_cache
-        from lightrag.llm.ollama import ollama_embed
+    if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock":
+        from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed
+    if args.embedding_binding == "ollama":
         from lightrag.llm.binding_options import OllamaEmbeddingOptions
     if args.embedding_binding == "jina":
         from lightrag.llm.jina import jina_embed
@@ -312,41 +315,80 @@ def create_app(args):
             **kwargs,
         )

+    async def bedrock_model_complete(
+        prompt,
+        system_prompt=None,
+        history_messages=None,
+        keyword_extraction=False,
+        **kwargs,
+    ) -> str:
+        keyword_extraction = kwargs.pop("keyword_extraction", None)
+        if keyword_extraction:
+            kwargs["response_format"] = GPTKeywordExtractionFormat
+        if history_messages is None:
+            history_messages = []
+
+        # Use global temperature for Bedrock
+        kwargs["temperature"] = args.temperature
+
+        return await bedrock_complete_if_cache(
+            args.llm_model,
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs,
+        )
+
     embedding_func = EmbeddingFunc(
         embedding_dim=args.embedding_dim,
-        func=lambda texts: lollms_embed(
-            texts,
-            embed_model=args.embedding_model,
-            host=args.embedding_binding_host,
-            api_key=args.embedding_binding_api_key,
-        )
-        if args.embedding_binding == "lollms"
-        else ollama_embed(
-            texts,
-            embed_model=args.embedding_model,
-            host=args.embedding_binding_host,
-            api_key=args.embedding_binding_api_key,
-            options=OllamaEmbeddingOptions.options_dict(args),
-        )
-        if args.embedding_binding == "ollama"
-        else azure_openai_embed(
-            texts,
-            model=args.embedding_model,  # no host is used for openai,
-            api_key=args.embedding_binding_api_key,
-        )
-        if args.embedding_binding == "azure_openai"
-        else jina_embed(
-            texts,
-            dimensions=args.embedding_dim,
-            base_url=args.embedding_binding_host,
-            api_key=args.embedding_binding_api_key,
-        )
-        if args.embedding_binding == "jina"
-        else openai_embed(
-            texts,
-            model=args.embedding_model,
-            base_url=args.embedding_binding_host,
-            api_key=args.embedding_binding_api_key,
+        func=lambda texts: (
+            lollms_embed(
+                texts,
+                embed_model=args.embedding_model,
+                host=args.embedding_binding_host,
+                api_key=args.embedding_binding_api_key,
+            )
+            if args.embedding_binding == "lollms"
+            else (
+                ollama_embed(
+                    texts,
+                    embed_model=args.embedding_model,
+                    host=args.embedding_binding_host,
+                    api_key=args.embedding_binding_api_key,
+                    options=OllamaEmbeddingOptions.options_dict(args),
+                )
+                if args.embedding_binding == "ollama"
+                else (
+                    azure_openai_embed(
+                        texts,
+                        model=args.embedding_model,  # no host is used for openai,
+                        api_key=args.embedding_binding_api_key,
+                    )
+                    if args.embedding_binding == "azure_openai"
+                    else (
+                        bedrock_embed(
+                            texts,
+                            model=args.embedding_model,
+                        )
+                        if args.embedding_binding == "aws_bedrock"
+                        else (
+                            jina_embed(
+                                texts,
+                                dimensions=args.embedding_dim,
+                                base_url=args.embedding_binding_host,
+                                api_key=args.embedding_binding_api_key,
+                            )
+                            if args.embedding_binding == "jina"
+                            else openai_embed(
+                                texts,
+                                model=args.embedding_model,
+                                base_url=args.embedding_binding_host,
+                                api_key=args.embedding_binding_api_key,
+                            )
+                        )
+                    )
+                )
+            )
         ),
     )
@@ -386,28 +428,36 @@ def create_app(args):
     )

     # Initialize RAG
-    if args.llm_binding in ["lollms", "ollama", "openai"]:
+    if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]:
         rag = LightRAG(
             working_dir=args.working_dir,
             workspace=args.workspace,
-            llm_model_func=lollms_model_complete
-            if args.llm_binding == "lollms"
-            else ollama_model_complete
-            if args.llm_binding == "ollama"
-            else openai_alike_model_complete,
+            llm_model_func=(
+                lollms_model_complete
+                if args.llm_binding == "lollms"
+                else (
+                    ollama_model_complete
+                    if args.llm_binding == "ollama"
+                    else bedrock_model_complete
+                    if args.llm_binding == "aws_bedrock"
+                    else openai_alike_model_complete
+                )
+            ),
             llm_model_name=args.llm_model,
             llm_model_max_async=args.max_async,
             summary_max_tokens=args.max_tokens,
             chunk_token_size=int(args.chunk_size),
             chunk_overlap_token_size=int(args.chunk_overlap_size),
-            llm_model_kwargs={
-                "host": args.llm_binding_host,
-                "timeout": args.timeout,
-                "options": OllamaLLMOptions.options_dict(args),
-                "api_key": args.llm_binding_api_key,
-            }
-            if args.llm_binding == "lollms" or args.llm_binding == "ollama"
-            else {},
+            llm_model_kwargs=(
+                {
+                    "host": args.llm_binding_host,
+                    "timeout": args.timeout,
+                    "options": OllamaLLMOptions.options_dict(args),
+                    "api_key": args.llm_binding_api_key,
+                }
+                if args.llm_binding == "lollms" or args.llm_binding == "ollama"
+                else {}
+            ),
             embedding_func=embedding_func,
             kv_storage=args.kv_storage,
             graph_storage=args.graph_storage,
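Editor's note: the server wiring above can be mirrored when using LightRAG core directly. A minimal sketch follows, assuming the usual `LightRAG` / `lightrag.utils.EmbeddingFunc` imports and the `bedrock_complete_if_cache` / `bedrock_embed` helpers added below; the model IDs and embedding dimension are placeholders, and storage initialization plus document insertion are omitted.

```python
# Illustrative sketch only -- mirrors the server wiring above, not part of this patch.
from lightrag import LightRAG
from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed
from lightrag.utils import EmbeddingFunc


async def bedrock_llm(prompt, system_prompt=None, history_messages=None, **kwargs):
    # Assumed Bedrock model ID; any Converse-capable model should work.
    return await bedrock_complete_if_cache(
        "anthropic.claude-3-haiku-20240307-v1:0",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages or [],
        **kwargs,
    )


rag = LightRAG(
    working_dir="./rag_storage",
    llm_model_func=bedrock_llm,
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,  # assumed dimension of the chosen embedding model
        func=lambda texts: bedrock_embed(
            texts,
            model="amazon.titan-embed-text-v2:0",  # assumed model ID
        ),
    ),
)
# Storage initialization (e.g. awaiting rag.initialize_storages()) and
# document insertion are omitted from this sketch.
```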
diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py
index 1640abbb..69d00e2d 100644
--- a/lightrag/llm/bedrock.py
+++ b/lightrag/llm/bedrock.py
@@ -15,11 +15,25 @@ from tenacity import (
     retry_if_exception_type,
 )

+import sys
+
+if sys.version_info < (3, 9):
+    from typing import AsyncIterator
+else:
+    from collections.abc import AsyncIterator
+from typing import Union
+

 class BedrockError(Exception):
     """Generic error for issues related to Amazon Bedrock"""


+def _set_env_if_present(key: str, value):
+    """Set environment variable only if a non-empty value is provided."""
+    if value is not None and value != "":
+        os.environ[key] = value
+
+
 @retry(
     stop=stop_after_attempt(5),
     wait=wait_exponential(multiplier=1, max=60),
@@ -34,17 +48,35 @@ async def bedrock_complete_if_cache(
     aws_secret_access_key=None,
     aws_session_token=None,
     **kwargs,
-) -> str:
-    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
-        "AWS_ACCESS_KEY_ID", aws_access_key_id
-    )
-    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
-        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
-    )
-    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
-        "AWS_SESSION_TOKEN", aws_session_token
-    )
+) -> Union[str, AsyncIterator[str]]:
+    # Respect existing env; only set if a non-empty value is available
+    access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
+    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
+    session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
+    _set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
+    _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
+    _set_env_if_present("AWS_SESSION_TOKEN", session_token)
+    # Region handling: prefer env, else kwarg (optional)
+    region = os.environ.get("AWS_REGION") or kwargs.pop("aws_region", None)

     kwargs.pop("hashing_kv", None)
+    # Capture stream flag (if provided) and remove it from kwargs since it's not a Bedrock API parameter
+    # We'll use this to determine whether to call converse_stream or converse
+    stream = bool(kwargs.pop("stream", False))
+    # Remove unsupported args for Bedrock Converse API
+    for k in [
+        "response_format",
+        "tools",
+        "tool_choice",
+        "seed",
+        "presence_penalty",
+        "frequency_penalty",
+        "n",
+        "logprobs",
+        "top_logprobs",
+        "max_completion_tokens",
+    ]:
+        kwargs.pop(k, None)

     # Fix message history format
     messages = []
     for history_message in history_messages:
@@ -77,21 +109,131 @@ async def bedrock_complete_if_cache(
             kwargs.pop(param)
         )

-    # Call model via Converse API
-    session = aioboto3.Session()
-    async with session.client("bedrock-runtime") as bedrock_async_client:
-        try:
-            response = await bedrock_async_client.converse(**args, **kwargs)
-        except Exception as e:
-            raise BedrockError(e)
+    # Import logging for error handling
+    import logging

-    return response["output"]["message"]["content"][0]["text"]
+    # For streaming responses, we need a different approach to keep the connection open
+    if stream:
+        # Create a session that will be used throughout the streaming process
+        session = aioboto3.Session()
+        client = None
+
+        # Define the generator function that will manage the client lifecycle
+        async def stream_generator():
+            nonlocal client
+
+            # Create the client outside the generator to ensure it stays open
+            client = await session.client(
+                "bedrock-runtime", region_name=region
+            ).__aenter__()
+            event_stream = None
+            iteration_started = False
+
+            try:
+                # Make the API call
+                response = await client.converse_stream(**args, **kwargs)
+                event_stream = response.get("stream")
+                iteration_started = True
+
+                # Process the stream
+                async for event in event_stream:
+                    # Validate event structure
+                    if not event or not isinstance(event, dict):
+                        continue
+
+                    if "contentBlockDelta" in event:
+                        delta = event["contentBlockDelta"].get("delta", {})
+                        text = delta.get("text")
+                        if text:
+                            yield text
+                    # Handle other event types that might indicate stream end
+                    elif "messageStop" in event:
+                        break
+
+            except Exception as e:
+                # Log the specific error for debugging
+                logging.error(f"Bedrock streaming error: {e}")
+
+                # Try to clean up resources if possible
+                if (
+                    iteration_started
+                    and event_stream
+                    and hasattr(event_stream, "aclose")
+                    and callable(getattr(event_stream, "aclose", None))
+                ):
+                    try:
+                        await event_stream.aclose()
+                    except Exception as close_error:
+                        logging.warning(
+                            f"Failed to close Bedrock event stream: {close_error}"
+                        )
+
+                raise BedrockError(f"Streaming error: {e}")
+
+            finally:
+                # Clean up the event stream
+                if (
+                    iteration_started
+                    and event_stream
+                    and hasattr(event_stream, "aclose")
+                    and callable(getattr(event_stream, "aclose", None))
+                ):
+                    try:
+                        await event_stream.aclose()
+                    except Exception as close_error:
+                        logging.warning(
+                            f"Failed to close Bedrock event stream in finally block: {close_error}"
+                        )
+
+                # Clean up the client
+                if client:
+                    try:
+                        await client.__aexit__(None, None, None)
+                    except Exception as client_close_error:
+                        logging.warning(
+                            f"Failed to close Bedrock client: {client_close_error}"
+                        )
+
+        # Return the generator that manages its own lifecycle
+        return stream_generator()
+
+    # For non-streaming responses, use the standard async context manager pattern
+    session = aioboto3.Session()
+    async with session.client(
+        "bedrock-runtime", region_name=region
+    ) as bedrock_async_client:
+        try:
+            # Use converse for non-streaming responses
+            response = await bedrock_async_client.converse(**args, **kwargs)
+
+            # Validate response structure
+            if (
+                not response
+                or "output" not in response
+                or "message" not in response["output"]
+                or "content" not in response["output"]["message"]
+                or not response["output"]["message"]["content"]
+            ):
+                raise BedrockError("Invalid response structure from Bedrock API")
+
+            content = response["output"]["message"]["content"][0]["text"]
+
+            if not content or content.strip() == "":
+                raise BedrockError("Received empty content from Bedrock API")
+
+            return content
+
+        except Exception as e:
+            if isinstance(e, BedrockError):
+                raise
+            else:
+                raise BedrockError(f"Bedrock API error: {e}")


 # Generic Bedrock completion function
 async def bedrock_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
     kwargs.pop("keyword_extraction", None)
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     result = await bedrock_complete_if_cache(
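Editor's note: with `stream=True`, `bedrock_complete_if_cache` now returns an async iterator of text chunks instead of a single string. A minimal consumption sketch (the model ID is an assumption; credentials and region are resolved from the environment as handled above):

```python
# Illustrative sketch only -- consuming the new streaming path, not part of this patch.
import asyncio

from lightrag.llm.bedrock import bedrock_complete_if_cache


async def main():
    chunks = await bedrock_complete_if_cache(
        "anthropic.claude-3-haiku-20240307-v1:0",  # assumed model ID
        "Summarize LightRAG in one sentence.",
        stream=True,  # returns an AsyncIterator[str] instead of a str
    )
    async for text in chunks:
        print(text, end="", flush=True)


asyncio.run(main())
```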
@@ -117,18 +259,21 @@ async def bedrock_embed(
     aws_secret_access_key=None,
     aws_session_token=None,
 ) -> np.ndarray:
-    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
-        "AWS_ACCESS_KEY_ID", aws_access_key_id
-    )
-    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
-        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
-    )
-    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
-        "AWS_SESSION_TOKEN", aws_session_token
-    )
+    # Respect existing env; only set if a non-empty value is available
+    access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
+    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
+    session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
+    _set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
+    _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
+    _set_env_if_present("AWS_SESSION_TOKEN", session_token)
+
+    # Region handling: prefer env
+    region = os.environ.get("AWS_REGION")

     session = aioboto3.Session()
-    async with session.client("bedrock-runtime") as bedrock_async_client:
+    async with session.client(
+        "bedrock-runtime", region_name=region
+    ) as bedrock_async_client:
         if (model_provider := model.split(".")[0]) == "amazon":
             embed_texts = []
             for text in texts: