diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py index ccfbb4f7..f6871422 100644 --- a/lightrag/llm/bedrock.py +++ b/lightrag/llm/bedrock.py @@ -1,6 +1,7 @@ import copy import os import json +import logging import pipmaster as pm # Pipmaster for dynamic library install @@ -24,21 +25,121 @@ else: from collections.abc import AsyncIterator from typing import Union +# Import botocore exceptions for proper exception handling +try: + from botocore.exceptions import ( + ClientError, + ConnectionError as BotocoreConnectionError, + ReadTimeoutError, + ) +except ImportError: + # If botocore is not installed, define placeholders + ClientError = Exception + BotocoreConnectionError = Exception + ReadTimeoutError = Exception + class BedrockError(Exception): """Generic error for issues related to Amazon Bedrock""" +class BedrockRateLimitError(BedrockError): + """Error for rate limiting and throttling issues""" + + +class BedrockConnectionError(BedrockError): + """Error for network and connection issues""" + + +class BedrockTimeoutError(BedrockError): + """Error for timeout issues""" + + def _set_env_if_present(key: str, value): """Set environment variable only if a non-empty value is provided.""" if value is not None and value != "": os.environ[key] = value +def _handle_bedrock_exception(e: Exception, operation: str = "Bedrock API") -> None: + """Convert AWS Bedrock exceptions to appropriate custom exceptions. + + Args: + e: The exception to handle + operation: Description of the operation for error messages + + Raises: + BedrockRateLimitError: For rate limiting and throttling issues (retryable) + BedrockConnectionError: For network and server issues (retryable) + BedrockTimeoutError: For timeout issues (retryable) + BedrockError: For validation and other non-retryable errors + """ + error_message = str(e) + + # Handle botocore ClientError with specific error codes + if isinstance(e, ClientError): + error_code = e.response.get("Error", {}).get("Code", "") + error_msg = e.response.get("Error", {}).get("Message", error_message) + + # Rate limiting and throttling errors (retryable) + if error_code in [ + "ThrottlingException", + "ProvisionedThroughputExceededException", + ]: + logging.error(f"{operation} rate limit error: {error_msg}") + raise BedrockRateLimitError(f"Rate limit error: {error_msg}") + + # Server errors (retryable) + elif error_code in ["ServiceUnavailableException", "InternalServerException"]: + logging.error(f"{operation} connection error: {error_msg}") + raise BedrockConnectionError(f"Service error: {error_msg}") + + # Check for 5xx HTTP status codes (retryable) + elif e.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0) >= 500: + logging.error(f"{operation} server error: {error_msg}") + raise BedrockConnectionError(f"Server error: {error_msg}") + + # Validation and other client errors (non-retryable) + else: + logging.error(f"{operation} client error: {error_msg}") + raise BedrockError(f"Client error: {error_msg}") + + # Connection errors (retryable) + elif isinstance(e, BotocoreConnectionError): + logging.error(f"{operation} connection error: {error_message}") + raise BedrockConnectionError(f"Connection error: {error_message}") + + # Timeout errors (retryable) + elif isinstance(e, (ReadTimeoutError, TimeoutError)): + logging.error(f"{operation} timeout error: {error_message}") + raise BedrockTimeoutError(f"Timeout error: {error_message}") + + # Custom Bedrock errors (already properly typed) + elif isinstance( + e, + ( + BedrockRateLimitError, + BedrockConnectionError, + BedrockTimeoutError, + BedrockError, + ), + ): + raise + + # Unknown errors (non-retryable) + else: + logging.error(f"{operation} unexpected error: {error_message}") + raise BedrockError(f"Unexpected error: {error_message}") + + @retry( stop=stop_after_attempt(5), - wait=wait_exponential(multiplier=1, max=60), - retry=retry_if_exception_type((BedrockError)), + wait=wait_exponential(multiplier=1, min=4, max=60), + retry=( + retry_if_exception_type(BedrockRateLimitError) + | retry_if_exception_type(BedrockConnectionError) + | retry_if_exception_type(BedrockTimeoutError) + ), ) async def bedrock_complete_if_cache( model, @@ -159,9 +260,6 @@ async def bedrock_complete_if_cache( break except Exception as e: - # Log the specific error for debugging - logging.error(f"Bedrock streaming error: {e}") - # Try to clean up resources if possible if ( iteration_started @@ -176,7 +274,8 @@ async def bedrock_complete_if_cache( f"Failed to close Bedrock event stream: {close_error}" ) - raise BedrockError(f"Streaming error: {e}") + # Convert to appropriate exception type + _handle_bedrock_exception(e, "Bedrock streaming") finally: # Clean up the event stream @@ -232,10 +331,8 @@ async def bedrock_complete_if_cache( return content except Exception as e: - if isinstance(e, BedrockError): - raise - else: - raise BedrockError(f"Bedrock API error: {e}") + # Convert to appropriate exception type + _handle_bedrock_exception(e, "Bedrock converse") # Generic Bedrock completion function @@ -255,11 +352,15 @@ async def bedrock_complete( @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) -# @retry( -# stop=stop_after_attempt(3), -# wait=wait_exponential(multiplier=1, min=4, max=10), -# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions -# ) +@retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=60), + retry=( + retry_if_exception_type(BedrockRateLimitError) + | retry_if_exception_type(BedrockConnectionError) + | retry_if_exception_type(BedrockTimeoutError) + ), +) async def bedrock_embed( texts: list[str], model: str = "amazon.titan-embed-text-v2:0", @@ -282,48 +383,101 @@ async def bedrock_embed( async with session.client( "bedrock-runtime", region_name=region ) as bedrock_async_client: - if (model_provider := model.split(".")[0]) == "amazon": - embed_texts = [] - for text in texts: - if "v2" in model: + try: + if (model_provider := model.split(".")[0]) == "amazon": + embed_texts = [] + for text in texts: + try: + if "v2" in model: + body = json.dumps( + { + "inputText": text, + # 'dimensions': embedding_dim, + "embeddingTypes": ["float"], + } + ) + elif "v1" in model: + body = json.dumps({"inputText": text}) + else: + raise BedrockError(f"Model {model} is not supported!") + + response = await bedrock_async_client.invoke_model( + modelId=model, + body=body, + accept="application/json", + contentType="application/json", + ) + + response_body = await response.get("body").json() + + # Validate response structure + if not response_body or "embedding" not in response_body: + raise BedrockError( + f"Invalid embedding response structure for text: {text[:50]}..." + ) + + embedding = response_body["embedding"] + if not embedding: + raise BedrockError( + f"Received empty embedding for text: {text[:50]}..." + ) + + embed_texts.append(embedding) + + except Exception as e: + # Convert to appropriate exception type + _handle_bedrock_exception( + e, "Bedrock embedding (amazon, text chunk)" + ) + + elif model_provider == "cohere": + try: body = json.dumps( { - "inputText": text, - # 'dimensions': embedding_dim, - "embeddingTypes": ["float"], + "texts": texts, + "input_type": "search_document", + "truncate": "NONE", } ) - elif "v1" in model: - body = json.dumps({"inputText": text}) - else: - raise ValueError(f"Model {model} is not supported!") - response = await bedrock_async_client.invoke_model( - modelId=model, - body=body, - accept="application/json", - contentType="application/json", + response = await bedrock_async_client.invoke_model( + model=model, + body=body, + accept="application/json", + contentType="application/json", + ) + + response_body = json.loads(response.get("body").read()) + + # Validate response structure + if not response_body or "embeddings" not in response_body: + raise BedrockError( + "Invalid embedding response structure from Cohere" + ) + + embeddings = response_body["embeddings"] + if not embeddings or len(embeddings) != len(texts): + raise BedrockError( + f"Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}" + ) + + embed_texts = embeddings + + except Exception as e: + # Convert to appropriate exception type + _handle_bedrock_exception(e, "Bedrock embedding (cohere)") + + else: + raise BedrockError( + f"Model provider '{model_provider}' is not supported!" ) - response_body = await response.get("body").json() + # Final validation + if not embed_texts: + raise BedrockError("No embeddings generated") - embed_texts.append(response_body["embedding"]) - elif model_provider == "cohere": - body = json.dumps( - {"texts": texts, "input_type": "search_document", "truncate": "NONE"} - ) + return np.array(embed_texts) - response = await bedrock_async_client.invoke_model( - model=model, - body=body, - accept="application/json", - contentType="application/json", - ) - - response_body = json.loads(response.get("body").read()) - - embed_texts = response_body["embeddings"] - else: - raise ValueError(f"Model provider '{model_provider}' is not supported!") - - return np.array(embed_texts) + except Exception as e: + # Convert to appropriate exception type + _handle_bedrock_exception(e, "Bedrock embedding")