import copy
import os
import json
import logging

import pipmaster as pm  # Pipmaster for dynamic library install

if not pm.is_installed("aioboto3"):
    pm.install("aioboto3")
import aioboto3
import numpy as np
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

import sys
from lightrag.utils import wrap_embedding_func_with_attrs

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator
from typing import Union

# Import botocore exceptions for proper exception handling
try:
    from botocore.exceptions import (
        ClientError,
        ConnectionError as BotocoreConnectionError,
        ReadTimeoutError,
    )
except ImportError:
    # If botocore is not installed, define placeholders
    ClientError = Exception
    BotocoreConnectionError = Exception
    ReadTimeoutError = Exception


class BedrockError(Exception):
    """Generic error for issues related to Amazon Bedrock"""


class BedrockRateLimitError(BedrockError):
    """Error for rate limiting and throttling issues"""


class BedrockConnectionError(BedrockError):
    """Error for network and connection issues"""


class BedrockTimeoutError(BedrockError):
    """Error for timeout issues"""


def _set_env_if_present(key: str, value):
    """Set environment variable only if a non-empty value is provided."""
    if value is not None and value != "":
        os.environ[key] = value


def _handle_bedrock_exception(e: Exception, operation: str = "Bedrock API") -> None:
    """Convert AWS Bedrock exceptions to appropriate custom exceptions.

    This function never returns normally: every branch raises.

    Args:
        e: The exception to handle
        operation: Description of the operation for error messages

    Raises:
        BedrockRateLimitError: For rate limiting and throttling issues (retryable)
        BedrockConnectionError: For network and server issues (retryable)
        BedrockTimeoutError: For timeout issues (retryable)
        BedrockError: For validation and other non-retryable errors
    """
    # Custom Bedrock errors are already properly typed. They are checked first
    # so they can never be misclassified by the checks below, and re-raised
    # explicitly so this helper does not depend on being called from inside an
    # active ``except`` block (which a bare ``raise`` would require).
    if isinstance(e, BedrockError):
        raise e

    error_message = str(e)

    # Handle botocore ClientError with specific error codes
    if isinstance(e, ClientError):
        error_code = e.response.get("Error", {}).get("Code", "")
        error_msg = e.response.get("Error", {}).get("Message", error_message)

        # Rate limiting and throttling errors (retryable)
        if error_code in [
            "ThrottlingException",
            "ProvisionedThroughputExceededException",
        ]:
            logging.error(f"{operation} rate limit error: {error_msg}")
            raise BedrockRateLimitError(f"Rate limit error: {error_msg}") from e

        # Server errors (retryable)
        elif error_code in ["ServiceUnavailableException", "InternalServerException"]:
            logging.error(f"{operation} connection error: {error_msg}")
            raise BedrockConnectionError(f"Service error: {error_msg}") from e

        # Check for 5xx HTTP status codes (retryable)
        elif e.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0) >= 500:
            logging.error(f"{operation} server error: {error_msg}")
            raise BedrockConnectionError(f"Server error: {error_msg}") from e

        # Validation and other client errors (non-retryable)
        else:
            logging.error(f"{operation} client error: {error_msg}")
            raise BedrockError(f"Client error: {error_msg}") from e

    # Connection errors (retryable)
    elif isinstance(e, BotocoreConnectionError):
        logging.error(f"{operation} connection error: {error_message}")
        raise BedrockConnectionError(f"Connection error: {error_message}") from e

    # Timeout errors (retryable)
    elif isinstance(e, (ReadTimeoutError, TimeoutError)):
        logging.error(f"{operation} timeout error: {error_message}")
        raise BedrockTimeoutError(f"Timeout error: {error_message}") from e

    # Unknown errors (non-retryable)
    else:
        logging.error(f"{operation} unexpected error: {error_message}")
        raise BedrockError(f"Unexpected error: {error_message}") from e


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(BedrockRateLimitError)
        | retry_if_exception_type(BedrockConnectionError)
        | retry_if_exception_type(BedrockTimeoutError)
    ),
)
async def bedrock_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=None,
    enable_cot: bool = False,
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    """Call the Bedrock Converse API and return the generated text.

    Args:
        model: Bedrock model identifier, sent as ``modelId``.
        prompt: User prompt appended after the history messages.
        system_prompt: Optional system prompt.
        history_messages: Optional list of ``{"role", "content"}`` dicts whose
            ``content`` is wrapped into the Converse ``[{"text": ...}]`` form.
            The caller's dicts are not modified. ``None`` means no history.
        enable_cot: Not supported for Bedrock; only logged and ignored.
        aws_access_key_id: Used only if ``AWS_ACCESS_KEY_ID`` is not already
            set in the environment.
        aws_secret_access_key: Used only if ``AWS_SECRET_ACCESS_KEY`` is not
            already set in the environment.
        aws_session_token: Used only if ``AWS_SESSION_TOKEN`` is not already
            set in the environment.
        **kwargs: ``stream`` selects ``converse_stream``; ``aws_region`` is the
            fallback region when ``AWS_REGION`` is unset; ``max_tokens``,
            ``temperature``, ``top_p`` and ``stop_sequences`` are moved into
            ``inferenceConfig``; a fixed list of unsupported arguments is
            dropped; anything left is forwarded to the Converse call.

    Returns:
        The response text, or (when ``stream`` is truthy) an async generator
        yielding text chunks.

    Raises:
        BedrockError: Or one of its subclasses, for any failure of the call.

    Note:
        In streaming mode the generator is returned before any request is
        made, so errors raised while iterating it happen outside this
        function's retry decorator.
    """
    if enable_cot:
        logging.debug(
            "enable_cot=True is not supported for Bedrock and will be ignored."
        )

    # Respect existing env; only set if a non-empty value is available
    access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
    session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
    _set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
    _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
    _set_env_if_present("AWS_SESSION_TOKEN", session_token)

    # Region handling: prefer env, else kwarg (optional).
    # The kwarg is always popped, even when the env var wins; otherwise it
    # would stay in kwargs and be forwarded to the Converse call below.
    aws_region = kwargs.pop("aws_region", None)
    region = os.environ.get("AWS_REGION") or aws_region

    kwargs.pop("hashing_kv", None)

    # Capture stream flag (if provided) and remove from kwargs since it's not a Bedrock API parameter
    # We'll use this to determine whether to call converse_stream or converse
    stream = bool(kwargs.pop("stream", False))

    # Remove unsupported args for Bedrock Converse API
    for k in [
        "response_format",
        "tools",
        "tool_choice",
        "seed",
        "presence_penalty",
        "frequency_penalty",
        "n",
        "logprobs",
        "top_logprobs",
        "max_completion_tokens",
    ]:
        kwargs.pop(k, None)

    # Fix message history format (shallow-copy so the caller's dicts are untouched)
    messages = []
    for history_message in history_messages or []:
        message = copy.copy(history_message)
        message["content"] = [{"text": message["content"]}]
        messages.append(message)

    # Add user prompt
    messages.append({"role": "user", "content": [{"text": prompt}]})

    # Initialize Converse API arguments
    args = {"modelId": model, "messages": messages}

    # Define system prompt
    if system_prompt:
        args["system"] = [{"text": system_prompt}]

    # Map and set up inference parameters
    inference_params_map = {
        "max_tokens": "maxTokens",
        "top_p": "topP",
        "stop_sequences": "stopSequences",
    }
    if inference_params := list(
        set(kwargs) & {"max_tokens", "temperature", "top_p", "stop_sequences"}
    ):
        args["inferenceConfig"] = {}
        for param in inference_params:
            # Names without a mapping (e.g. "temperature") are passed as-is
            args["inferenceConfig"][inference_params_map.get(param, param)] = (
                kwargs.pop(param)
            )

    # For streaming responses, we need a different approach to keep the connection open
    if stream:
        # Create a session that will be used throughout the streaming process
        session = aioboto3.Session()

        # Define the generator function that will manage the client lifecycle
        async def stream_generator():
            # The client is entered manually so that it stays open for as long
            # as the generator is being consumed; it is closed in ``finally``.
            client = await session.client(
                "bedrock-runtime", region_name=region
            ).__aenter__()
            event_stream = None
            try:
                # Make the API call
                response = await client.converse_stream(**args, **kwargs)
                event_stream = response.get("stream")

                # Process the stream
                async for event in event_stream:
                    # Validate event structure
                    if not event or not isinstance(event, dict):
                        continue

                    if "contentBlockDelta" in event:
                        delta = event["contentBlockDelta"].get("delta", {})
                        text = delta.get("text")
                        if text:
                            yield text
                    # Handle other event types that might indicate stream end
                    elif "messageStop" in event:
                        break

            except Exception as e:
                # Convert to appropriate exception type; resource cleanup
                # happens exactly once, in the ``finally`` block below.
                _handle_bedrock_exception(e, "Bedrock streaming")

            finally:
                # Clean up the event stream
                if event_stream and callable(getattr(event_stream, "aclose", None)):
                    try:
                        await event_stream.aclose()
                    except Exception as close_error:
                        logging.warning(
                            f"Failed to close Bedrock event stream in finally block: {close_error}"
                        )

                # Clean up the client
                try:
                    await client.__aexit__(None, None, None)
                except Exception as client_close_error:
                    logging.warning(
                        f"Failed to close Bedrock client: {client_close_error}"
                    )

        # Return the generator that manages its own lifecycle
        return stream_generator()

    # For non-streaming responses, use the standard async context manager pattern
    session = aioboto3.Session()
    async with session.client(
        "bedrock-runtime", region_name=region
    ) as bedrock_async_client:
        try:
            # Use converse for non-streaming responses
            response = await bedrock_async_client.converse(**args, **kwargs)

            # Validate response structure
            if (
                not response
                or "output" not in response
                or "message" not in response["output"]
                or "content" not in response["output"]["message"]
                or not response["output"]["message"]["content"]
            ):
                raise BedrockError("Invalid response structure from Bedrock API")

            content = response["output"]["message"]["content"][0]["text"]

            if not content or content.strip() == "":
                raise BedrockError("Received empty content from Bedrock API")

            return content

        except Exception as e:
            # Convert to appropriate exception type
            _handle_bedrock_exception(e, "Bedrock converse")


# Generic Bedrock completion function
async def bedrock_complete(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
    """Complete ``prompt`` with the model named in the ``hashing_kv`` global config.

    Args:
        prompt: User prompt.
        system_prompt: Optional system prompt.
        history_messages: Optional prior messages; ``None`` means no history.
        keyword_extraction: Accepted for interface compatibility; not used here.
        **kwargs: Must contain ``hashing_kv``, whose
            ``global_config["llm_model_name"]`` selects the model. All kwargs
            are forwarded to ``bedrock_complete_if_cache``.

    Returns:
        Whatever ``bedrock_complete_if_cache`` returns for these arguments.

    Raises:
        KeyError: If ``hashing_kv`` is missing from ``kwargs``.
    """
    kwargs.pop("keyword_extraction", None)
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    result = await bedrock_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(BedrockRateLimitError)
        | retry_if_exception_type(BedrockConnectionError)
        | retry_if_exception_type(BedrockTimeoutError)
    ),
)
async def bedrock_embed(
    texts: list[str],
    model: str = "amazon.titan-embed-text-v2:0",
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
) -> np.ndarray:
    """Embed ``texts`` with an Amazon or Cohere model through Bedrock.

    The provider is the part of ``model`` before the first ``"."``. Amazon
    models are invoked once per text; Cohere models receive all texts in a
    single request.

    Args:
        texts: Strings to embed.
        model: Bedrock model identifier.
        aws_access_key_id: Used only if ``AWS_ACCESS_KEY_ID`` is not already
            set in the environment.
        aws_secret_access_key: Used only if ``AWS_SECRET_ACCESS_KEY`` is not
            already set in the environment.
        aws_session_token: Used only if ``AWS_SESSION_TOKEN`` is not already
            set in the environment.

    Returns:
        ``np.array`` built from the list of returned embeddings, in the same
        order as ``texts``.

    Raises:
        BedrockError: Or one of its subclasses, for unsupported models or
            providers, malformed or empty responses, and API failures.
    """
    # Respect existing env; only set if a non-empty value is available
    access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
    session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
    _set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
    _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
    _set_env_if_present("AWS_SESSION_TOKEN", session_token)

    # Region handling: prefer env
    region = os.environ.get("AWS_REGION")

    session = aioboto3.Session()
    async with session.client(
        "bedrock-runtime", region_name=region
    ) as bedrock_async_client:
        try:
            if (model_provider := model.split(".")[0]) == "amazon":
                embed_texts = []
                for text in texts:
                    try:
                        if "v2" in model:
                            body = json.dumps(
                                {
                                    "inputText": text,
                                    # 'dimensions': embedding_dim,
                                    "embeddingTypes": ["float"],
                                }
                            )
                        elif "v1" in model:
                            body = json.dumps({"inputText": text})
                        else:
                            raise BedrockError(f"Model {model} is not supported!")

                        response = await bedrock_async_client.invoke_model(
                            modelId=model,
                            body=body,
                            accept="application/json",
                            contentType="application/json",
                        )

                        response_body = await response.get("body").json()

                        # Validate response structure
                        if not response_body or "embedding" not in response_body:
                            raise BedrockError(
                                f"Invalid embedding response structure for text: {text[:50]}..."
                            )

                        embedding = response_body["embedding"]
                        if not embedding:
                            raise BedrockError(
                                f"Received empty embedding for text: {text[:50]}..."
                            )

                        embed_texts.append(embedding)

                    except Exception as e:
                        # Convert to appropriate exception type
                        _handle_bedrock_exception(
                            e, "Bedrock embedding (amazon, text chunk)"
                        )

            elif model_provider == "cohere":
                try:
                    body = json.dumps(
                        {
                            "texts": texts,
                            "input_type": "search_document",
                            "truncate": "NONE",
                        }
                    )

                    # ``modelId`` is the parameter name invoke_model expects
                    # (same as the amazon branch above).
                    response = await bedrock_async_client.invoke_model(
                        modelId=model,
                        body=body,
                        accept="application/json",
                        contentType="application/json",
                    )

                    # The async client's body must be awaited before parsing
                    response_body = json.loads(await response.get("body").read())

                    # Validate response structure
                    if not response_body or "embeddings" not in response_body:
                        raise BedrockError(
                            "Invalid embedding response structure from Cohere"
                        )

                    embeddings = response_body["embeddings"]
                    if not embeddings or len(embeddings) != len(texts):
                        raise BedrockError(
                            f"Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}"
                        )

                    embed_texts = embeddings

                except Exception as e:
                    # Convert to appropriate exception type
                    _handle_bedrock_exception(e, "Bedrock embedding (cohere)")

            else:
                raise BedrockError(
                    f"Model provider '{model_provider}' is not supported!"
                )

            # Final validation
            if not embed_texts:
                raise BedrockError("No embeddings generated")

            return np.array(embed_texts)

        except Exception as e:
            # Convert to appropriate exception type; errors already converted
            # by the inner handlers pass through unchanged.
            _handle_bedrock_exception(e, "Bedrock embedding")