diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index 56e506cc..0a16e89b 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -77,9 +77,7 @@ def parse_args() -> argparse.Namespace:
         argparse.Namespace: Parsed arguments
     """
 
-    parser = argparse.ArgumentParser(
-        description="LightRAG FastAPI Server with separate working and input directories"
-    )
+    parser = argparse.ArgumentParser(description="LightRAG FastAPI Server with separate working and input directories")
 
     # Server configuration
     parser.add_argument(
@@ -209,14 +207,14 @@ def parse_args() -> argparse.Namespace:
         "--llm-binding",
         type=str,
         default=get_env_value("LLM_BINDING", "ollama"),
-        choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
+        choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai", "aws_bedrock"],
         help="LLM binding type (default: from env or ollama)",
     )
     parser.add_argument(
         "--embedding-binding",
        type=str,
         default=get_env_value("EMBEDDING_BINDING", "ollama"),
-        choices=["lollms", "ollama", "openai", "azure_openai"],
+        choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
         help="Embedding binding type (default: from env or ollama)",
     )
 
@@ -272,18 +270,10 @@ def parse_args() -> argparse.Namespace:
     args.input_dir = os.path.abspath(args.input_dir)
 
     # Inject storage configuration from environment variables
-    args.kv_storage = get_env_value(
-        "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
-    )
-    args.doc_status_storage = get_env_value(
-        "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
-    )
-    args.graph_storage = get_env_value(
-        "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
-    )
-    args.vector_storage = get_env_value(
-        "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
-    )
+    args.kv_storage = get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE)
+    args.doc_status_storage = get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE)
+    args.graph_storage = get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE)
+    args.vector_storage = get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE)
 
     # Get MAX_PARALLEL_INSERT from environment
     args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
@@ -299,12 +289,8 @@ def parse_args() -> argparse.Namespace:
     # Ollama ctx_num
     args.ollama_num_ctx = get_env_value("OLLAMA_NUM_CTX", 32768, int)
 
-    args.llm_binding_host = get_env_value(
-        "LLM_BINDING_HOST", get_default_host(args.llm_binding)
-    )
-    args.embedding_binding_host = get_env_value(
-        "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
-    )
+    args.llm_binding_host = get_env_value("LLM_BINDING_HOST", get_default_host(args.llm_binding))
+    args.embedding_binding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding))
 
     args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
     args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
@@ -318,9 +304,7 @@ def parse_args() -> argparse.Namespace:
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
 
     # Inject LLM cache configuration
-    args.enable_llm_cache_for_extract = get_env_value(
-        "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
-    )
+    args.enable_llm_cache_for_extract = get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool)
     args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)
 
     # Handle Ollama LLM temperature with priority cascade when llm-binding is ollama
@@ -370,40 +354,24 @@ def parse_args() -> argparse.Namespace:
     args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
 
     # Min rerank score configuration
-    args.min_rerank_score = get_env_value(
-        "MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float
-    )
+    args.min_rerank_score = get_env_value("MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float)
 
     # Query configuration
     args.history_turns = get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int)
     args.top_k = get_env_value("TOP_K", DEFAULT_TOP_K, int)
     args.chunk_top_k = get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int)
-    args.max_entity_tokens = get_env_value(
-        "MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int
-    )
-    args.max_relation_tokens = get_env_value(
-        "MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int
-    )
-    args.max_total_tokens = get_env_value(
-        "MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int
-    )
-    args.cosine_threshold = get_env_value(
-        "COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float
-    )
-    args.related_chunk_number = get_env_value(
-        "RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int
-    )
+    args.max_entity_tokens = get_env_value("MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int)
+    args.max_relation_tokens = get_env_value("MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int)
+    args.max_total_tokens = get_env_value("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int)
+    args.cosine_threshold = get_env_value("COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float)
+    args.related_chunk_number = get_env_value("RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int)
 
     # Add missing environment variables for health endpoint
     args.force_llm_summary_on_merge = get_env_value(
         "FORCE_LLM_SUMMARY_ON_MERGE", DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int
     )
-    args.embedding_func_max_async = get_env_value(
-        "EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int
-    )
-    args.embedding_batch_num = get_env_value(
-        "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
-    )
+    args.embedding_func_max_async = get_env_value("EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int)
+    args.embedding_batch_num = get_env_value("EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int)
 
     ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
     ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
@@ -417,9 +385,7 @@ def update_uvicorn_mode_config():
         original_workers = global_args.workers
         global_args.workers = 1
         # Log warning directly here
-        logging.warning(
-            f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
-        )
+        logging.warning(f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1")
 
 
 global_args = parse_args()
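Reviewer note (not part of the patch): with the two new choices above, aws_bedrock can be selected purely through the LLM_BINDING / EMBEDDING_BINDING environment variables that parse_args() already resolves via get_env_value(). A minimal sketch, assuming nothing beyond what this file shows:

```python
# Sketch only: select the new binding the same way parse_args() does, via env vars.
import os

os.environ["LLM_BINDING"] = "aws_bedrock"
os.environ["EMBEDDING_BINDING"] = "aws_bedrock"

from lightrag.api.config import parse_args

args = parse_args()
assert args.llm_binding == "aws_bedrock"
assert args.embedding_binding == "aws_bedrock"
```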
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 699f59ac..8ee23826 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -106,6 +106,7 @@ def create_app(args):
         "openai",
         "openai-ollama",
         "azure_openai",
+        "aws_bedrock",
     ]:
         raise Exception("llm binding not supported")
 
@@ -114,6 +115,7 @@ def create_app(args):
         "ollama",
         "openai",
         "azure_openai",
+        "aws_bedrock",
         "jina",
     ]:
         raise Exception("embedding binding not supported")
 
@@ -128,9 +130,7 @@ def create_app(args):
     # Add SSL validation
     if args.ssl:
         if not args.ssl_certfile or not args.ssl_keyfile:
-            raise Exception(
-                "SSL certificate and key files must be provided when SSL is enabled"
-            )
+            raise Exception("SSL certificate and key files must be provided when SSL is enabled")
         if not os.path.exists(args.ssl_certfile):
             raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
         if not os.path.exists(args.ssl_keyfile):
@@ -188,10 +188,11 @@ def create_app(args):
     # Initialize FastAPI
     app_kwargs = {
         "title": "LightRAG Server API",
-        "description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
-        + "(With authentication)"
-        if api_key
-        else "",
+        "description": (
+            "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + "(With authentication)"
+            if api_key
+            else ""
+        ),
         "version": __api_version__,
         "openapi_url": "/openapi.json",  # Explicitly set OpenAPI schema URL
         "docs_url": "/docs",  # Explicitly set docs URL
@@ -244,9 +245,9 @@ def create_app(args):
             azure_openai_complete_if_cache,
             azure_openai_embed,
         )
+    if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock":
+        from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed
     if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama":
-        from lightrag.llm.openai import openai_complete_if_cache
-        from lightrag.llm.ollama import ollama_embed
         from lightrag.llm.binding_options import OllamaEmbeddingOptions
     if args.embedding_binding == "jina":
         from lightrag.llm.jina import jina_embed
@@ -312,41 +313,80 @@ def create_app(args):
             **kwargs,
         )
 
+    async def bedrock_model_complete(
+        prompt,
+        system_prompt=None,
+        history_messages=None,
+        keyword_extraction=False,
+        **kwargs,
+    ) -> str:
+        keyword_extraction = kwargs.pop("keyword_extraction", None)
+        if keyword_extraction:
+            kwargs["response_format"] = GPTKeywordExtractionFormat
+        if history_messages is None:
+            history_messages = []
+
+        # Use global temperature for Bedrock
+        kwargs["temperature"] = args.temperature
+
+        return await bedrock_complete_if_cache(
+            args.llm_model,
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs,
+        )
+
     embedding_func = EmbeddingFunc(
         embedding_dim=args.embedding_dim,
-        func=lambda texts: lollms_embed(
-            texts,
-            embed_model=args.embedding_model,
-            host=args.embedding_binding_host,
-            api_key=args.embedding_binding_api_key,
-        )
-        if args.embedding_binding == "lollms"
-        else ollama_embed(
-            texts,
-            embed_model=args.embedding_model,
-            host=args.embedding_binding_host,
-            api_key=args.embedding_binding_api_key,
-            options=OllamaEmbeddingOptions.options_dict(args),
-        )
-        if args.embedding_binding == "ollama"
-        else azure_openai_embed(
-            texts,
-            model=args.embedding_model,  # no host is used for openai,
-            api_key=args.embedding_binding_api_key,
-        )
-        if args.embedding_binding == "azure_openai"
-        else jina_embed(
-            texts,
-            dimensions=args.embedding_dim,
-            base_url=args.embedding_binding_host,
-            api_key=args.embedding_binding_api_key,
-        )
-        if args.embedding_binding == "jina"
-        else openai_embed(
-            texts,
-            model=args.embedding_model,
-            base_url=args.embedding_binding_host,
-            api_key=args.embedding_binding_api_key,
+        func=lambda texts: (
+            lollms_embed(
+                texts,
+                embed_model=args.embedding_model,
+                host=args.embedding_binding_host,
+                api_key=args.embedding_binding_api_key,
+            )
+            if args.embedding_binding == "lollms"
+            else (
+                ollama_embed(
+                    texts,
+                    embed_model=args.embedding_model,
+                    host=args.embedding_binding_host,
+                    api_key=args.embedding_binding_api_key,
+                    options=OllamaEmbeddingOptions.options_dict(args),
+                )
+                if args.embedding_binding == "ollama"
+                else (
+                    azure_openai_embed(
+                        texts,
+                        model=args.embedding_model,  # no host is used for openai,
+                        api_key=args.embedding_binding_api_key,
+                    )
+                    if args.embedding_binding == "azure_openai"
+                    else (
+                        bedrock_embed(
+                            texts,
+                            model=args.embedding_model,
+                        )
+                        if args.embedding_binding == "aws_bedrock"
+                        else (
+                            jina_embed(
+                                texts,
+                                dimensions=args.embedding_dim,
+                                base_url=args.embedding_binding_host,
+                                api_key=args.embedding_binding_api_key,
+                            )
+                            if args.embedding_binding == "jina"
+                            else openai_embed(
+                                texts,
+                                model=args.embedding_model,
+                                base_url=args.embedding_binding_host,
+                                api_key=args.embedding_binding_api_key,
+                            )
+                        )
+                    )
+                )
+            )
         ),
     )
 
@@ -355,9 +395,7 @@ def create_app(args):
     if args.rerank_binding_api_key and args.rerank_binding_host:
         from lightrag.rerank import custom_rerank
 
-        async def server_rerank_func(
-            query: str, documents: list, top_n: int = None, **kwargs
-        ):
+        async def server_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
            """Server rerank function with configuration from environment variables"""
             return await custom_rerank(
                 query=query,
@@ -370,9 +408,7 @@ def create_app(args):
             )
 
         rerank_model_func = server_rerank_func
-        logger.info(
-            f"Rerank model configured: {args.rerank_model} (can be enabled per query)"
-        )
+        logger.info(f"Rerank model configured: {args.rerank_model} (can be enabled per query)")
     else:
         logger.info(
             "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
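Reviewer note (not part of the patch): the aws_bedrock branch added to embedding_func above reduces to a plain bedrock_embed(texts, model=args.embedding_model) call into lightrag/llm/bedrock.py (patched below); credentials and region come from the AWS_* environment variables. A rough smoke-test sketch, assuming boto3-style credentials are already configured and using an illustrative Titan model id rather than anything this patch pins:

```python
# Sketch only: call the embedding entry point the same way the server's lambda does.
import asyncio

from lightrag.llm.bedrock import bedrock_embed


async def main():
    # Illustrative model id; AWS credentials/region are read from the environment.
    vectors = await bedrock_embed(["hello lightrag"], model="amazon.titan-embed-text-v2:0")
    print(vectors.shape)  # bedrock_embed returns an np.ndarray of shape (n_texts, embedding_dim)


asyncio.run(main())
```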
@@ -381,41 +417,43 @@ def create_app(args):
     # Create ollama_server_infos from command line arguments
     from lightrag.api.config import OllamaServerInfos
 
-    ollama_server_infos = OllamaServerInfos(
-        name=args.simulated_model_name, tag=args.simulated_model_tag
-    )
+    ollama_server_infos = OllamaServerInfos(name=args.simulated_model_name, tag=args.simulated_model_tag)
 
     # Initialize RAG
-    if args.llm_binding in ["lollms", "ollama", "openai"]:
+    if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]:
         rag = LightRAG(
             working_dir=args.working_dir,
             workspace=args.workspace,
-            llm_model_func=lollms_model_complete
-            if args.llm_binding == "lollms"
-            else ollama_model_complete
-            if args.llm_binding == "ollama"
-            else openai_alike_model_complete,
+            llm_model_func=(
+                lollms_model_complete
+                if args.llm_binding == "lollms"
+                else (
+                    ollama_model_complete
+                    if args.llm_binding == "ollama"
+                    else bedrock_model_complete if args.llm_binding == "aws_bedrock" else openai_alike_model_complete
+                )
+            ),
             llm_model_name=args.llm_model,
             llm_model_max_async=args.max_async,
             summary_max_tokens=args.max_tokens,
             chunk_token_size=int(args.chunk_size),
             chunk_overlap_token_size=int(args.chunk_overlap_size),
-            llm_model_kwargs={
-                "host": args.llm_binding_host,
-                "timeout": args.timeout,
-                "options": OllamaLLMOptions.options_dict(args),
-                "api_key": args.llm_binding_api_key,
-            }
-            if args.llm_binding == "lollms" or args.llm_binding == "ollama"
-            else {},
+            llm_model_kwargs=(
+                {
+                    "host": args.llm_binding_host,
+                    "timeout": args.timeout,
+                    "options": OllamaLLMOptions.options_dict(args),
+                    "api_key": args.llm_binding_api_key,
+                }
+                if args.llm_binding == "lollms" or args.llm_binding == "ollama"
+                else {}
+            ),
             embedding_func=embedding_func,
             kv_storage=args.kv_storage,
             graph_storage=args.graph_storage,
             vector_storage=args.vector_storage,
             doc_status_storage=args.doc_status_storage,
-            vector_db_storage_cls_kwargs={
-                "cosine_better_than_threshold": args.cosine_threshold
-            },
+            vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold},
             enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             enable_llm_cache=args.enable_llm_cache,
             rerank_model_func=rerank_model_func,
@@ -442,9 +480,7 @@ def create_app(args):
             graph_storage=args.graph_storage,
             vector_storage=args.vector_storage,
             doc_status_storage=args.doc_status_storage,
-            vector_db_storage_cls_kwargs={
-                "cosine_better_than_threshold": args.cosine_threshold
-            },
+            vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold},
             enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
             enable_llm_cache=args.enable_llm_cache,
             rerank_model_func=rerank_model_func,
@@ -480,9 +516,7 @@ def create_app(args):
         if not auth_handler.accounts:
             # Authentication not configured, return guest token
-            guest_token = auth_handler.create_token(
-                username="guest", role="guest", metadata={"auth_mode": "disabled"}
-            )
+            guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"})
             return {
                 "auth_configured": False,
                 "access_token": guest_token,
@@ -508,9 +542,7 @@ def create_app(args):
     async def login(form_data: OAuth2PasswordRequestForm = Depends()):
         if not auth_handler.accounts:
             # Authentication not configured, return guest token
-            guest_token = auth_handler.create_token(
-                username="guest", role="guest", metadata={"auth_mode": "disabled"}
-            )
+            guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"})
             return {
                 "access_token": guest_token,
                 "token_type": "bearer",
@@ -523,14 +555,10 @@ def create_app(args):
             }
 
         username = form_data.username
         if auth_handler.accounts.get(username) != form_data.password:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
-            )
+            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials")
 
         # Regular user login
-        user_token = auth_handler.create_token(
-            username=username, role="user", metadata={"auth_mode": "enabled"}
-        )
+        user_token = auth_handler.create_token(username=username, role="user", metadata={"auth_mode": "enabled"})
 
         return {
             "access_token": user_token,
             "token_type": "bearer",
@@ -579,12 +607,8 @@ def create_app(args):
             "max_graph_nodes": args.max_graph_nodes,
             # Rerank configuration (based on whether rerank model is configured)
             "enable_rerank": rerank_model_func is not None,
-            "rerank_model": args.rerank_model
-            if rerank_model_func is not None
-            else None,
-            "rerank_binding_host": args.rerank_binding_host
-            if rerank_model_func is not None
-            else None,
+            "rerank_model": args.rerank_model if rerank_model_func is not None else None,
+            "rerank_binding_host": args.rerank_binding_host if rerank_model_func is not None else None,
             # Environment variable status (requested configuration)
             "summary_language": args.summary_language,
             "force_llm_summary_on_merge": args.force_llm_summary_on_merge,
@@ -614,17 +638,11 @@ def create_app(args):
             response = await super().get_response(path, scope)
 
             if path.endswith(".html"):
-                response.headers["Cache-Control"] = (
-                    "no-cache, no-store, must-revalidate"
-                )
+                response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
                 response.headers["Pragma"] = "no-cache"
                 response.headers["Expires"] = "0"
-            elif (
-                "/assets/" in path
-            ):  # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename
-                response.headers["Cache-Control"] = (
-                    "public, max-age=31536000, immutable"
-                )
+            elif "/assets/" in path:  # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename
+                response.headers["Cache-Control"] = "public, max-age=31536000, immutable"
             # Add other rules here if needed for non-HTML, non-asset files
 
             # Ensure correct Content-Type
@@ -640,9 +658,7 @@ def create_app(args):
     static_dir.mkdir(exist_ok=True)
     app.mount(
         "/webui",
-        SmartStaticFiles(
-            directory=static_dir, html=True, check_dir=True
-        ),  # Use SmartStaticFiles
+        SmartStaticFiles(directory=static_dir, html=True, check_dir=True),  # Use SmartStaticFiles
         name="webui",
     )
 
@@ -798,9 +814,7 @@ def main():
             }
         )
 
-    print(
-        f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
-    )
+    print(f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}")
 
     uvicorn.run(**uvicorn_config)
diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py
index 1640abbb..51c4b895 100644
--- a/lightrag/llm/bedrock.py
+++ b/lightrag/llm/bedrock.py
@@ -15,11 +15,25 @@ from tenacity import (
     retry_if_exception_type,
 )
 
+import sys
+
+if sys.version_info < (3, 9):
+    from typing import AsyncIterator
+else:
+    from collections.abc import AsyncIterator
+from typing import Union
+
 
 class BedrockError(Exception):
     """Generic error for issues related to Amazon Bedrock"""
 
 
+def _set_env_if_present(key: str, value):
+    """Set environment variable only if a non-empty value is provided."""
+    if value is not None and value != "":
+        os.environ[key] = value
+
+
 @retry(
     stop=stop_after_attempt(5),
     wait=wait_exponential(multiplier=1, max=60),
@@ -34,17 +48,35 @@ async def bedrock_complete_if_cache(
     aws_secret_access_key=None,
     aws_session_token=None,
     **kwargs,
-) -> str:
-    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
-        "AWS_ACCESS_KEY_ID", aws_access_key_id
-    )
-    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
-        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
-    )
-    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
-        "AWS_SESSION_TOKEN", aws_session_token
-    )
+) -> Union[str, AsyncIterator[str]]:
+    # Respect existing env; only set if a non-empty value is available
+    access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
+    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
+    session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
+    _set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
+    _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
+    _set_env_if_present("AWS_SESSION_TOKEN", session_token)
+    # Region handling: prefer env, else kwarg (optional)
+    region = os.environ.get("AWS_REGION") or kwargs.pop("aws_region", None)
     kwargs.pop("hashing_kv", None)
+    # Capture stream flag (if provided) and remove from kwargs since it's not a Bedrock API parameter
+    # We'll use this to determine whether to call converse_stream or converse
+    stream = bool(kwargs.pop("stream", False))
+    # Remove unsupported args for Bedrock Converse API
+    for k in [
+        "response_format",
+        "tools",
+        "tool_choice",
+        "seed",
+        "presence_penalty",
+        "frequency_penalty",
+        "n",
+        "logprobs",
+        "top_logprobs",
+        "max_completion_tokens",
+    ]:
+        kwargs.pop(k, None)
 
     # Fix message history format
     messages = []
     for history_message in history_messages:
@@ -68,30 +100,126 @@ async def bedrock_complete_if_cache(
         "top_p": "topP",
         "stop_sequences": "stopSequences",
     }
-    if inference_params := list(
-        set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
-    ):
+    if inference_params := list(set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])):
         args["inferenceConfig"] = {}
         for param in inference_params:
-            args["inferenceConfig"][inference_params_map.get(param, param)] = (
-                kwargs.pop(param)
-            )
+            args["inferenceConfig"][inference_params_map.get(param, param)] = kwargs.pop(param)
 
-    # Call model via Converse API
+    # Import logging for error handling
+    import logging
+
+    # For streaming responses, we need a different approach to keep the connection open
+    if stream:
+        # Create a session that will be used throughout the streaming process
+        session = aioboto3.Session()
+        client = None
+
+        # Define the generator function that will manage the client lifecycle
+        async def stream_generator():
+            nonlocal client
+
+            # Enter the client context manually so it stays open for the lifetime of the stream
+            client = await session.client("bedrock-runtime", region_name=region).__aenter__()
+            event_stream = None
+            iteration_started = False
+
+            try:
+                # Make the API call
+                response = await client.converse_stream(**args, **kwargs)
+                event_stream = response.get("stream")
+                iteration_started = True
+
+                # Process the stream
+                async for event in event_stream:
+                    # Validate event structure
+                    if not event or not isinstance(event, dict):
+                        continue
+
+                    if "contentBlockDelta" in event:
+                        delta = event["contentBlockDelta"].get("delta", {})
+                        text = delta.get("text")
+                        if text:
+                            yield text
+                    # Handle other event types that might indicate stream end
+                    elif "messageStop" in event:
+                        break
+
+            except Exception as e:
+                # Log the specific error for debugging
+                logging.error(f"Bedrock streaming error: {e}")
+
+                # Try to clean up resources if possible
+                if (
+                    iteration_started
+                    and event_stream
+                    and hasattr(event_stream, "aclose")
+                    and callable(getattr(event_stream, "aclose", None))
+                ):
+                    try:
+                        await event_stream.aclose()
+                    except Exception as close_error:
+                        logging.warning(f"Failed to close Bedrock event stream: {close_error}")
+
+                raise BedrockError(f"Streaming error: {e}")
+
+            finally:
+                # Clean up the event stream
+                if (
+                    iteration_started
+                    and event_stream
+                    and hasattr(event_stream, "aclose")
+                    and callable(getattr(event_stream, "aclose", None))
+                ):
+                    try:
+                        await event_stream.aclose()
+                    except Exception as close_error:
+                        logging.warning(f"Failed to close Bedrock event stream in finally block: {close_error}")
+
+                # Clean up the client
+                if client:
+                    try:
+                        await client.__aexit__(None, None, None)
+                    except Exception as client_close_error:
+                        logging.warning(f"Failed to close Bedrock client: {client_close_error}")
+
+        # Return the generator that manages its own lifecycle
+        return stream_generator()
+
+    # For non-streaming responses, use the standard async context manager pattern
     session = aioboto3.Session()
-    async with session.client("bedrock-runtime") as bedrock_async_client:
+    async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client:
         try:
+            # Use converse for non-streaming responses
             response = await bedrock_async_client.converse(**args, **kwargs)
-        except Exception as e:
-            raise BedrockError(e)
 
-    return response["output"]["message"]["content"][0]["text"]
+            # Validate response structure
+            if (
+                not response
+                or "output" not in response
+                or "message" not in response["output"]
+                or "content" not in response["output"]["message"]
+                or not response["output"]["message"]["content"]
+            ):
+                raise BedrockError("Invalid response structure from Bedrock API")
+
+            content = response["output"]["message"]["content"][0]["text"]
+
+            if not content or content.strip() == "":
+                raise BedrockError("Received empty content from Bedrock API")
+
+            return content
+
+        except Exception as e:
+            if isinstance(e, BedrockError):
+                raise
+            else:
+                raise BedrockError(f"Bedrock API error: {e}")
 
 
 # Generic Bedrock completion function
 async def bedrock_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
     kwargs.pop("keyword_extraction", None)
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     result = await bedrock_complete_if_cache(
@@ -117,18 +245,19 @@ async def bedrock_embed(
     aws_secret_access_key=None,
     aws_session_token=None,
 ) -> np.ndarray:
-    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
-        "AWS_ACCESS_KEY_ID", aws_access_key_id
-    )
-    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
-        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
-    )
-    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
-        "AWS_SESSION_TOKEN", aws_session_token
-    )
+    # Respect existing env; only set if a non-empty value is available
+    access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
+    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
+    session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
+    _set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
+    _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
+    _set_env_if_present("AWS_SESSION_TOKEN", session_token)
+
+    # Region handling: prefer env
+    region = os.environ.get("AWS_REGION")
 
     session = aioboto3.Session()
-    async with session.client("bedrock-runtime") as bedrock_async_client:
+    async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client:
         if (model_provider := model.split(".")[0]) == "amazon":
             embed_texts = []
             for text in texts:
@@ -156,9 +285,7 @@ async def bedrock_embed(
                 embed_texts.append(response_body["embedding"])
 
         elif model_provider == "cohere":
-            body = json.dumps(
-                {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
-            )
+            body = json.dumps({"texts": texts, "input_type": "search_document", "truncate": "NONE"})
 
             response = await bedrock_async_client.invoke_model(
                 model=model,
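Reviewer note (not part of the patch): with the new stream flag, bedrock_complete_if_cache returns an AsyncIterator[str] instead of a str, so callers need to branch on the result type. A consumption sketch, assuming AWS credentials and region are set in the environment; the Anthropic model id is illustrative, not a default introduced here:

```python
# Sketch only: consuming the streaming vs. non-streaming result shapes added above.
import asyncio

from lightrag.llm.bedrock import bedrock_complete_if_cache


async def main():
    result = await bedrock_complete_if_cache(
        "anthropic.claude-3-haiku-20240307-v1:0",  # illustrative model id
        "Summarize LightRAG in one sentence.",
        system_prompt="You are a concise assistant.",
        stream=True,
        max_tokens=128,
    )
    if isinstance(result, str):
        print(result)  # non-streaming path returns the full completion
    else:
        async for chunk in result:  # streaming path yields text deltas as they arrive
            print(chunk, end="", flush=True)


asyncio.run(main())
```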