import copy
import json
import logging
import os

import pipmaster as pm  # Pipmaster for dynamic library install

if not pm.is_installed('aioboto3'):
    pm.install('aioboto3')

from collections.abc import AsyncIterator
from typing import Any, NoReturn, cast

import aioboto3
import numpy as np
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from lightrag.utils import wrap_embedding_func_with_attrs

# Import botocore exceptions for proper exception handling
try:
    from botocore.exceptions import (
        ClientError,
        ReadTimeoutError,
    )
    from botocore.exceptions import (
        ConnectionError as BotocoreConnectionError,
    )
except ImportError:
    # If botocore is not installed, define placeholders
    ClientError = Exception
    BotocoreConnectionError = Exception
    ReadTimeoutError = Exception


class BedrockError(Exception):
    """Generic error for issues related to Amazon Bedrock"""


class BedrockRateLimitError(BedrockError):
    """Error for rate limiting and throttling issues"""


class BedrockConnectionError(BedrockError):
    """Error for network and connection issues"""


class BedrockTimeoutError(BedrockError):
    """Error for timeout issues"""


def _set_env_if_present(key: str, value):
    """Set environment variable only if a non-empty value is provided."""
    if value is not None and value != '':
        os.environ[key] = value


def _handle_bedrock_exception(e: Exception, operation: str = 'Bedrock operation') -> NoReturn:
    """Convert AWS Bedrock exceptions to appropriate custom exceptions.

    Args:
        e: The exception to handle
        operation: Description of the operation for error messages

    Raises:
        BedrockRateLimitError: For rate limiting and throttling issues (retryable)
        BedrockConnectionError: For network and server issues (retryable)
        BedrockTimeoutError: For timeout issues (retryable)
        BedrockError: For validation and other non-retryable errors
    """
    error_message = str(e)

    # Handle botocore ClientError with specific error codes
    if isinstance(e, ClientError):
        # Hoist the response lookups so each field is read once.
        response = cast(ClientError, e).response
        error_info = response.get('Error', {})
        error_code = error_info.get('Code', '')
        error_msg = error_info.get('Message', error_message)

        # Rate limiting and throttling errors (retryable)
        if error_code in [
            'ThrottlingException',
            'ProvisionedThroughputExceededException',
        ]:
            logging.error(f'{operation} rate limit error: {error_msg}')
            raise BedrockRateLimitError(f'Rate limit error: {error_msg}')
        # Server errors (retryable)
        elif error_code in ['ServiceUnavailableException', 'InternalServerException']:
            logging.error(f'{operation} connection error: {error_msg}')
            raise BedrockConnectionError(f'Service error: {error_msg}')
        # Check for 5xx HTTP status codes (retryable)
        elif response.get('ResponseMetadata', {}).get('HTTPStatusCode', 0) >= 500:
            logging.error(f'{operation} server error: {error_msg}')
            raise BedrockConnectionError(f'Server error: {error_msg}')
        # Validation and other client errors (non-retryable)
        else:
            logging.error(f'{operation} client error: {error_msg}')
            raise BedrockError(f'Client error: {error_msg}')

    # Connection errors (retryable)
    elif isinstance(e, BotocoreConnectionError):
        logging.error(f'{operation} connection error: {error_message}')
        raise BedrockConnectionError(f'Connection error: {error_message}')

    # Timeout errors (retryable)
    elif isinstance(e, (ReadTimeoutError, TimeoutError)):
        logging.error(f'{operation} timeout error: {error_message}')
        raise BedrockTimeoutError(f'Timeout error: {error_message}')

    # Custom Bedrock errors (already properly typed) — re-raise unchanged
    elif isinstance(
        e,
        (
            BedrockRateLimitError,
            BedrockConnectionError,
            BedrockTimeoutError,
            BedrockError,
        ),
    ):
        raise

    # Unknown errors (non-retryable)
    else:
        logging.error(f'{operation} unexpected error: {error_message}')
        raise BedrockError(f'Unexpected error: {error_message}')


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(BedrockRateLimitError)
        | retry_if_exception_type(BedrockConnectionError)
        | retry_if_exception_type(BedrockTimeoutError)
    ),
)
async def bedrock_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=None,
    enable_cot: bool = False,
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
    **kwargs,
) -> str | AsyncIterator[str]:
    """Call the Bedrock Converse API, optionally streaming the response.

    Args:
        model: Bedrock model id passed as ``modelId``.
        prompt: The user prompt for this turn.
        system_prompt: Optional system prompt for the conversation.
        history_messages: Prior messages as ``{'role', 'content'}`` dicts.
        enable_cot: Ignored for Bedrock (logged and dropped).
        aws_access_key_id / aws_secret_access_key / aws_session_token:
            Credentials used only when the corresponding env vars are unset.
        **kwargs: Extra options; ``stream``, ``aws_region`` and OpenAI-style
            parameters unsupported by the Converse API are consumed here.

    Returns:
        The completion text, or an async iterator of text chunks when
        ``stream=True`` was passed.

    Raises:
        BedrockRateLimitError / BedrockConnectionError / BedrockTimeoutError:
            Retryable failures (retried by the tenacity decorator).
        BedrockError: Non-retryable failures or invalid responses.
    """
    if history_messages is None:
        history_messages = []
    if enable_cot:
        from lightrag.utils import logger

        logger.debug('enable_cot=True is not supported for Bedrock and will be ignored.')

    # Respect existing env; only set if a non-empty value is available
    access_key = os.environ.get('AWS_ACCESS_KEY_ID') or aws_access_key_id
    secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') or aws_secret_access_key
    session_token = os.environ.get('AWS_SESSION_TOKEN') or aws_session_token
    _set_env_if_present('AWS_ACCESS_KEY_ID', access_key)
    _set_env_if_present('AWS_SECRET_ACCESS_KEY', secret_key)
    _set_env_if_present('AWS_SESSION_TOKEN', session_token)

    # Region handling: prefer env, else kwarg (optional)
    region = os.environ.get('AWS_REGION') or kwargs.pop('aws_region', None)

    kwargs.pop('hashing_kv', None)

    # Capture stream flag (if provided) and remove from kwargs since it's not a Bedrock API parameter
    # We'll use this to determine whether to call converse_stream or converse
    stream = bool(kwargs.pop('stream', False))

    # Remove unsupported args for Bedrock Converse API
    # (fix: 'response_format' was listed twice in the original)
    for k in [
        'response_format',
        'tools',
        'tool_choice',
        'seed',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'logprobs',
        'top_logprobs',
        'max_completion_tokens',
    ]:
        kwargs.pop(k, None)

    # Fix message history format: Converse expects content as a list of blocks
    messages = []
    for history_message in history_messages:
        message = copy.copy(history_message)
        message['content'] = [{'text': message['content']}]
        messages.append(message)

    # Add user prompt
    messages.append({'role': 'user', 'content': [{'text': prompt}]})

    # Initialize Converse API arguments
    args = {'modelId': model, 'messages': messages}

    # Define system prompt
    if system_prompt:
        args['system'] = [{'text': system_prompt}]

    # Map and set up inference parameters (snake_case -> Converse camelCase)
    inference_params_map = {
        'max_tokens': 'maxTokens',
        'top_p': 'topP',
        'stop_sequences': 'stopSequences',
    }
    if inference_params := list(set(kwargs) & {'max_tokens', 'temperature', 'top_p', 'stop_sequences'}):
        args['inferenceConfig'] = {}
        for param in inference_params:
            args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)

    # For streaming responses, we need a different approach to keep the connection open
    if stream:
        # Create a session that will be used throughout the streaming process
        session = aioboto3.Session()
        client = None

        # Define the generator function that will manage the client lifecycle
        async def stream_generator():
            nonlocal client
            # Create the client outside the generator to ensure it stays open
            client = await session.client('bedrock-runtime', region_name=region).__aenter__()
            event_stream = None
            iteration_started = False

            try:
                # Make the API call
                response = await client.converse_stream(**args, **kwargs)
                event_stream = response.get('stream')
                iteration_started = True

                # Process the stream
                async for event in event_stream:
                    # Validate event structure
                    if not event or not isinstance(event, dict):
                        continue
                    if 'contentBlockDelta' in event:
                        delta = event['contentBlockDelta'].get('delta', {})
                        text = delta.get('text')
                        if text:
                            yield text
                    # Handle other event types that might indicate stream end
                    elif 'messageStop' in event:
                        break
            except Exception as e:
                # Convert to appropriate exception type; resource cleanup
                # happens exactly once in the finally block below (the
                # original closed the event stream twice on error).
                _handle_bedrock_exception(e, 'Bedrock streaming')
            finally:
                # Clean up the event stream
                if (
                    iteration_started
                    and event_stream
                    and hasattr(event_stream, 'aclose')
                    and callable(getattr(event_stream, 'aclose', None))
                ):
                    try:
                        await event_stream.aclose()
                    except Exception as close_error:
                        logging.warning(f'Failed to close Bedrock event stream: {close_error}')
                # Clean up the client
                if client:
                    try:
                        await client.__aexit__(None, None, None)
                    except Exception as client_close_error:
                        logging.warning(f'Failed to close Bedrock client: {client_close_error}')

        # Return the generator that manages its own lifecycle
        return stream_generator()

    # For non-streaming responses, use the standard async context manager pattern
    session = aioboto3.Session()
    async with session.client('bedrock-runtime', region_name=region) as bedrock_async_client:
        try:
            # Use converse for non-streaming responses
            response = await bedrock_async_client.converse(**args, **kwargs)

            # Validate response structure
            if (
                not response
                or 'output' not in response
                or 'message' not in response['output']
                or 'content' not in response['output']['message']
                or not response['output']['message']['content']
            ):
                raise BedrockError('Invalid response structure from Bedrock API')

            content = response['output']['message']['content'][0]['text']
            if not content or content.strip() == '':
                raise BedrockError('Received empty content from Bedrock API')

            return content
        except Exception as e:
            # Convert to appropriate exception type
            _handle_bedrock_exception(e, 'Bedrock converse')


# Generic Bedrock completion function
async def bedrock_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str | AsyncIterator[str]:
    """Resolve the model name from ``hashing_kv`` and delegate to
    :func:`bedrock_complete_if_cache`.

    Raises:
        ValueError: If the required ``hashing_kv`` kwarg is missing.
    """
    if history_messages is None:
        history_messages = []
    kwargs.pop('keyword_extraction', None)
    hashing_kv = kwargs.get('hashing_kv')
    if not hashing_kv:
        raise ValueError("'hashing_kv' parameter is required")
    model_name = hashing_kv.global_config['llm_model_name']
    result = await bedrock_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(BedrockRateLimitError)
        | retry_if_exception_type(BedrockConnectionError)
        | retry_if_exception_type(BedrockTimeoutError)
    ),
)
async def bedrock_embed(
    texts: list[str],
    model: str = 'amazon.titan-embed-text-v2:0',
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
) -> np.ndarray:
    """Embed ``texts`` with a Bedrock embedding model (Amazon Titan or Cohere).

    Args:
        texts: The texts to embed.
        model: Bedrock embedding model id; the provider is taken from the
            prefix before the first '.' ('amazon' or 'cohere').
        aws_access_key_id / aws_secret_access_key / aws_session_token:
            Credentials used only when the corresponding env vars are unset.

    Returns:
        A numpy array of embeddings, one row per input text.

    Raises:
        BedrockError: For unsupported models/providers or invalid responses.
    """
    # Respect existing env; only set if a non-empty value is available
    access_key = os.environ.get('AWS_ACCESS_KEY_ID') or aws_access_key_id
    secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') or aws_secret_access_key
    session_token = os.environ.get('AWS_SESSION_TOKEN') or aws_session_token
    _set_env_if_present('AWS_ACCESS_KEY_ID', access_key)
    _set_env_if_present('AWS_SECRET_ACCESS_KEY', secret_key)
    _set_env_if_present('AWS_SESSION_TOKEN', session_token)

    # Region handling: prefer env
    region = os.environ.get('AWS_REGION')

    session = aioboto3.Session()
    async with session.client('bedrock-runtime', region_name=region) as bedrock_async_client:
        try:
            if (model_provider := model.split('.')[0]) == 'amazon':
                # Titan models embed one text per request
                embed_texts = []
                for text in texts:
                    try:
                        if 'v2' in model:
                            body = json.dumps(
                                {
                                    'inputText': text,
                                    # 'dimensions': embedding_dim,
                                    'embeddingTypes': ['float'],
                                }
                            )
                        elif 'v1' in model:
                            body = json.dumps({'inputText': text})
                        else:
                            raise BedrockError(f'Model {model} is not supported!')

                        response = await bedrock_async_client.invoke_model(
                            modelId=model,
                            body=body,
                            accept='application/json',
                            contentType='application/json',
                        )
                        response_body = await response.get('body').json()

                        # Validate response structure
                        if not response_body or 'embedding' not in response_body:
                            raise BedrockError(f'Invalid embedding response structure for text: {text[:50]}...')
                        embedding = response_body['embedding']
                        if not embedding:
                            raise BedrockError(f'Received empty embedding for text: {text[:50]}...')
                        embed_texts.append(embedding)
                    except Exception as e:
                        # Convert to appropriate exception type
                        _handle_bedrock_exception(e, 'Bedrock embedding (amazon, text chunk)')
            elif model_provider == 'cohere':
                try:
                    body = json.dumps(
                        {
                            'texts': texts,
                            'input_type': 'search_document',
                            'truncate': 'NONE',
                        }
                    )
                    # Fix: the Bedrock Runtime parameter is 'modelId', not
                    # 'model' — the original call always failed validation.
                    response = await bedrock_async_client.invoke_model(
                        modelId=model,
                        body=body,
                        accept='application/json',
                        contentType='application/json',
                    )
                    # Fix: the streaming body must be read asynchronously,
                    # matching the amazon branch above.
                    response_body = json.loads(await response.get('body').read())

                    # Validate response structure
                    if not response_body or 'embeddings' not in response_body:
                        raise BedrockError('Invalid embedding response structure from Cohere')
                    embeddings = response_body['embeddings']
                    if not embeddings or len(embeddings) != len(texts):
                        raise BedrockError(
                            f'Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}'
                        )
                    embed_texts = embeddings
                except Exception as e:
                    # Convert to appropriate exception type
                    _handle_bedrock_exception(e, 'Bedrock embedding (cohere)')
            else:
                raise BedrockError(f"Model provider '{model_provider}' is not supported!")

            # Final validation
            if not embed_texts:
                raise BedrockError('No embeddings generated')

            return np.array(embed_texts)
        except Exception as e:
            # Convert to appropriate exception type
            _handle_bedrock_exception(e, 'Bedrock embedding')