Add S3 storage client and API routes for document management: - Implement s3_routes.py with file upload, download, delete endpoints - Enhance s3_client.py with improved error handling and operations - Add S3 browser UI component with file viewing and management - Implement FileViewer and PDFViewer components for storage preview - Add Resizable and Sheet UI components for layout control Update backend infrastructure: - Add bulk operations and parameterized queries to postgres_impl.py - Enhance document routes with improved type hints - Update API server registration for new S3 routes - Refine upload routes and utility functions Modernize web UI: - Integrate S3 browser into main application layout - Update localization files for storage UI strings - Add storage settings to application configuration - Sync package dependencies and lock files Remove obsolete reproduction script: - Delete reproduce_citation.py (replaced by test suite) Update configuration: - Enhance pyrightconfig.json for stricter type checking
461 lines
17 KiB
Python
461 lines
17 KiB
Python
import copy
|
|
import json
|
|
import logging
|
|
import os
|
|
|
|
import pipmaster as pm # Pipmaster for dynamic library install
|
|
|
|
if not pm.is_installed('aioboto3'):
|
|
pm.install('aioboto3')
|
|
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any, NoReturn, cast
|
|
|
|
import aioboto3
|
|
import numpy as np
|
|
from tenacity import (
|
|
retry,
|
|
retry_if_exception_type,
|
|
stop_after_attempt,
|
|
wait_exponential,
|
|
)
|
|
|
|
from lightrag.utils import wrap_embedding_func_with_attrs
|
|
|
|
# Import botocore exceptions for proper exception handling
|
|
try:
|
|
from botocore.exceptions import (
|
|
ClientError,
|
|
ReadTimeoutError,
|
|
)
|
|
from botocore.exceptions import (
|
|
ConnectionError as BotocoreConnectionError,
|
|
)
|
|
except ImportError:
|
|
# If botocore is not installed, define placeholders
|
|
ClientError = Exception
|
|
BotocoreConnectionError = Exception
|
|
ReadTimeoutError = Exception
|
|
|
|
|
|
class BedrockError(Exception):
    """Base exception for Amazon Bedrock issues; raised directly for non-retryable errors."""
|
|
|
|
|
|
class BedrockRateLimitError(BedrockError):
    """Rate limiting / throttling error (retryable by the tenacity policies below)."""
|
|
|
|
|
|
class BedrockConnectionError(BedrockError):
    """Network / server-side connection error (retryable by the tenacity policies below)."""
|
|
|
|
|
|
class BedrockTimeoutError(BedrockError):
    """Timeout error (retryable by the tenacity policies below)."""
|
|
|
|
|
|
def _set_env_if_present(key: str, value):
|
|
"""Set environment variable only if a non-empty value is provided."""
|
|
if value is not None and value != '':
|
|
os.environ[key] = value
|
|
|
|
|
|
def _handle_bedrock_exception(e: Exception, operation: str = 'Bedrock operation') -> NoReturn:
    """Convert AWS Bedrock exceptions to appropriate custom exceptions.

    Branch order matters: when botocore is not installed, ``ClientError`` is
    aliased to ``Exception`` and the first ``isinstance`` check matches every
    exception, so do not reorder the checks.

    Args:
        e: The exception to handle
        operation: Description of the operation for error messages

    Raises:
        BedrockRateLimitError: For rate limiting and throttling issues (retryable)
        BedrockConnectionError: For network and server issues (retryable)
        BedrockTimeoutError: For timeout issues (retryable)
        BedrockError: For validation and other non-retryable errors
    """
    error_message = str(e)

    # Handle botocore ClientError with specific error codes
    if isinstance(e, ClientError):
        # Hoist the nested response lookups once; cast() is a runtime no-op,
        # so dropping the repeated casts does not change behavior.
        response = cast(ClientError, e).response
        error_info = response.get('Error', {})
        error_code = error_info.get('Code', '')
        error_msg = error_info.get('Message', error_message)

        # Rate limiting and throttling errors (retryable)
        if error_code in (
            'ThrottlingException',
            'ProvisionedThroughputExceededException',
        ):
            logging.error(f'{operation} rate limit error: {error_msg}')
            raise BedrockRateLimitError(f'Rate limit error: {error_msg}')

        # Server errors (retryable)
        elif error_code in ('ServiceUnavailableException', 'InternalServerException'):
            logging.error(f'{operation} connection error: {error_msg}')
            raise BedrockConnectionError(f'Service error: {error_msg}')

        # Check for 5xx HTTP status codes (retryable)
        elif response.get('ResponseMetadata', {}).get('HTTPStatusCode', 0) >= 500:
            logging.error(f'{operation} server error: {error_msg}')
            raise BedrockConnectionError(f'Server error: {error_msg}')

        # Validation and other client errors (non-retryable)
        else:
            logging.error(f'{operation} client error: {error_msg}')
            raise BedrockError(f'Client error: {error_msg}')

    # Connection errors (retryable)
    elif isinstance(e, BotocoreConnectionError):
        logging.error(f'{operation} connection error: {error_message}')
        raise BedrockConnectionError(f'Connection error: {error_message}')

    # Timeout errors (retryable)
    elif isinstance(e, (ReadTimeoutError, TimeoutError)):
        logging.error(f'{operation} timeout error: {error_message}')
        raise BedrockTimeoutError(f'Timeout error: {error_message}')

    # Custom Bedrock errors are already properly typed: re-raise unchanged
    elif isinstance(
        e,
        (
            BedrockRateLimitError,
            BedrockConnectionError,
            BedrockTimeoutError,
            BedrockError,
        ),
    ):
        raise

    # Unknown errors (non-retryable)
    else:
        logging.error(f'{operation} unexpected error: {error_message}')
        raise BedrockError(f'Unexpected error: {error_message}')
|
|
|
|
|
|
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(BedrockRateLimitError)
        | retry_if_exception_type(BedrockConnectionError)
        | retry_if_exception_type(BedrockTimeoutError)
    ),
)
async def bedrock_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=None,
    enable_cot: bool = False,
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
    **kwargs,
) -> str | AsyncIterator[str]:
    """Call the Amazon Bedrock Converse API for a chat completion.

    Args:
        model: Bedrock model id passed as ``modelId``.
        prompt: User prompt appended as the final message.
        system_prompt: Optional system instruction.
        history_messages: Prior chat messages as OpenAI-style dicts with
            ``role`` and string ``content``; converted to Converse format.
        enable_cot: Ignored for Bedrock (logged at debug level).
        aws_access_key_id / aws_secret_access_key / aws_session_token:
            Optional credentials; existing environment values take precedence.
        **kwargs: Extra parameters; OpenAI-only options are stripped, known
            inference parameters are mapped into ``inferenceConfig``, and
            ``stream=True`` selects ``converse_stream``.

    Returns:
        The completion text, or an async iterator of text chunks when
        ``stream=True``.

    Raises:
        BedrockError (or one of its retryable subclasses) on API failure.
    """
    if history_messages is None:
        history_messages = []
    if enable_cot:
        from lightrag.utils import logger

        logger.debug('enable_cot=True is not supported for Bedrock and will be ignored.')
    # Respect existing env; only set if a non-empty value is available
    access_key = os.environ.get('AWS_ACCESS_KEY_ID') or aws_access_key_id
    secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') or aws_secret_access_key
    session_token = os.environ.get('AWS_SESSION_TOKEN') or aws_session_token
    _set_env_if_present('AWS_ACCESS_KEY_ID', access_key)
    _set_env_if_present('AWS_SECRET_ACCESS_KEY', secret_key)
    _set_env_if_present('AWS_SESSION_TOKEN', session_token)
    # Region handling: prefer env, else kwarg (optional)
    region = os.environ.get('AWS_REGION') or kwargs.pop('aws_region', None)
    kwargs.pop('hashing_kv', None)
    # Capture stream flag (if provided) and remove from kwargs since it's not a Bedrock API parameter
    # We'll use this to determine whether to call converse_stream or converse
    stream = bool(kwargs.pop('stream', False))
    # Remove unsupported args for Bedrock Converse API
    # (fix: 'response_format' was listed twice in the original)
    for k in (
        'response_format',
        'tools',
        'tool_choice',
        'seed',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'logprobs',
        'top_logprobs',
        'max_completion_tokens',
    ):
        kwargs.pop(k, None)

    # Convert history to Converse message format: content becomes a list of text blocks
    messages = []
    for history_message in history_messages:
        message = copy.copy(history_message)
        message['content'] = [{'text': message['content']}]
        messages.append(message)

    # Add user prompt
    messages.append({'role': 'user', 'content': [{'text': prompt}]})

    # Initialize Converse API arguments
    args = {'modelId': model, 'messages': messages}

    # Define system prompt
    if system_prompt:
        args['system'] = [{'text': system_prompt}]

    # Map OpenAI-style inference parameter names to Converse API names
    inference_params_map = {
        'max_tokens': 'maxTokens',
        'top_p': 'topP',
        'stop_sequences': 'stopSequences',
    }
    if inference_params := list(set(kwargs) & {'max_tokens', 'temperature', 'top_p', 'stop_sequences'}):
        args['inferenceConfig'] = {}
        for param in inference_params:
            # 'temperature' has no mapping and passes through unchanged
            args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)

    # For streaming responses, we need a different approach to keep the connection open
    if stream:
        # Create a session that will be used throughout the streaming process
        session = aioboto3.Session()
        client = None

        # Define the generator function that will manage the client lifecycle
        async def stream_generator():
            nonlocal client

            # Enter the client context manually so it stays open for the
            # whole lifetime of the generator (an `async with` here would
            # close it as soon as this function returned the generator).
            client = await session.client('bedrock-runtime', region_name=region).__aenter__()
            event_stream = None
            iteration_started = False

            try:
                # Make the API call
                response = await client.converse_stream(**args, **kwargs)
                event_stream = response.get('stream')
                iteration_started = True

                # Process the stream
                async for event in event_stream:
                    # Validate event structure
                    if not event or not isinstance(event, dict):
                        continue

                    if 'contentBlockDelta' in event:
                        delta = event['contentBlockDelta'].get('delta', {})
                        text = delta.get('text')
                        if text:
                            yield text
                    # Handle other event types that might indicate stream end
                    elif 'messageStop' in event:
                        break

            except Exception as e:
                # Convert to appropriate exception type. Stream/client cleanup
                # is handled once in `finally` (the original closed the event
                # stream here AND in `finally`, re-closing a closed stream).
                _handle_bedrock_exception(e, 'Bedrock streaming')

            finally:
                # Clean up the event stream
                if iteration_started and event_stream and callable(getattr(event_stream, 'aclose', None)):
                    try:
                        await event_stream.aclose()
                    except Exception as close_error:
                        logging.warning(f'Failed to close Bedrock event stream in finally block: {close_error}')

                # Clean up the client
                if client:
                    try:
                        await client.__aexit__(None, None, None)
                    except Exception as client_close_error:
                        logging.warning(f'Failed to close Bedrock client: {client_close_error}')

        # Return the generator that manages its own lifecycle
        return stream_generator()

    # For non-streaming responses, use the standard async context manager pattern
    session = aioboto3.Session()
    async with session.client('bedrock-runtime', region_name=region) as bedrock_async_client:
        try:
            # Use converse for non-streaming responses
            response = await bedrock_async_client.converse(**args, **kwargs)

            # Validate response structure
            if (
                not response
                or 'output' not in response
                or 'message' not in response['output']
                or 'content' not in response['output']['message']
                or not response['output']['message']['content']
            ):
                raise BedrockError('Invalid response structure from Bedrock API')

            content = response['output']['message']['content'][0]['text']

            if not content or content.strip() == '':
                raise BedrockError('Received empty content from Bedrock API')

            return content

        except Exception as e:
            # Convert to appropriate exception type
            _handle_bedrock_exception(e, 'Bedrock converse')
|
|
|
|
|
|
# Generic Bedrock completion function
|
|
async def bedrock_complete(
|
|
prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
|
|
) -> str | AsyncIterator[str]:
|
|
if history_messages is None:
|
|
history_messages = []
|
|
kwargs.pop('keyword_extraction', None)
|
|
hashing_kv = kwargs.get('hashing_kv')
|
|
if not hashing_kv:
|
|
raise ValueError("'hashing_kv' parameter is required")
|
|
model_name = hashing_kv.global_config['llm_model_name']
|
|
result = await bedrock_complete_if_cache(
|
|
model_name,
|
|
prompt,
|
|
system_prompt=system_prompt,
|
|
history_messages=history_messages,
|
|
**kwargs,
|
|
)
|
|
return result
|
|
|
|
|
|
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(BedrockRateLimitError)
        | retry_if_exception_type(BedrockConnectionError)
        | retry_if_exception_type(BedrockTimeoutError)
    ),
)
async def bedrock_embed(
    texts: list[str],
    model: str = 'amazon.titan-embed-text-v2:0',
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
) -> np.ndarray:
    """Embed *texts* with an Amazon Bedrock embedding model.

    Supports the Amazon Titan family (v1/v2, one request per text) and the
    Cohere family (a single batched request). The provider is taken from the
    model id prefix before the first '.'.

    Args:
        texts: Input strings to embed.
        model: Bedrock model id.
        aws_access_key_id / aws_secret_access_key / aws_session_token:
            Optional credentials; existing environment values take precedence.

    Returns:
        np.ndarray of shape (len(texts), embedding_dim).

    Raises:
        BedrockError (or one of its retryable subclasses) on API failure,
        unsupported model/provider, or an invalid response.
    """
    # Respect existing env; only set if a non-empty value is available
    access_key = os.environ.get('AWS_ACCESS_KEY_ID') or aws_access_key_id
    secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') or aws_secret_access_key
    session_token = os.environ.get('AWS_SESSION_TOKEN') or aws_session_token
    _set_env_if_present('AWS_ACCESS_KEY_ID', access_key)
    _set_env_if_present('AWS_SECRET_ACCESS_KEY', secret_key)
    _set_env_if_present('AWS_SESSION_TOKEN', session_token)

    # Region handling: prefer env
    region = os.environ.get('AWS_REGION')

    session = aioboto3.Session()
    async with session.client('bedrock-runtime', region_name=region) as bedrock_async_client:
        try:
            if (model_provider := model.split('.')[0]) == 'amazon':
                embed_texts = []
                for text in texts:
                    try:
                        if 'v2' in model:
                            body = json.dumps(
                                {
                                    'inputText': text,
                                    # 'dimensions': embedding_dim,
                                    'embeddingTypes': ['float'],
                                }
                            )
                        elif 'v1' in model:
                            body = json.dumps({'inputText': text})
                        else:
                            raise BedrockError(f'Model {model} is not supported!')

                        response = await bedrock_async_client.invoke_model(
                            modelId=model,
                            body=body,
                            accept='application/json',
                            contentType='application/json',
                        )

                        response_body = await response.get('body').json()

                        # Validate response structure
                        if not response_body or 'embedding' not in response_body:
                            raise BedrockError(f'Invalid embedding response structure for text: {text[:50]}...')

                        embedding = response_body['embedding']
                        if not embedding:
                            raise BedrockError(f'Received empty embedding for text: {text[:50]}...')

                        embed_texts.append(embedding)

                    except Exception as e:
                        # Convert to appropriate exception type
                        _handle_bedrock_exception(e, 'Bedrock embedding (amazon, text chunk)')

            elif model_provider == 'cohere':
                try:
                    body = json.dumps(
                        {
                            'texts': texts,
                            'input_type': 'search_document',
                            'truncate': 'NONE',
                        }
                    )

                    # Fix: invoke_model takes 'modelId' (the original passed
                    # 'model=', which boto3 rejects as an unknown parameter).
                    response = await bedrock_async_client.invoke_model(
                        modelId=model,
                        body=body,
                        accept='application/json',
                        contentType='application/json',
                    )

                    # Fix: under aioboto3 the body's read() is a coroutine and
                    # must be awaited (the original fed the coroutine object
                    # straight into json.loads).
                    response_body = json.loads(await response.get('body').read())

                    # Validate response structure
                    if not response_body or 'embeddings' not in response_body:
                        raise BedrockError('Invalid embedding response structure from Cohere')

                    embeddings = response_body['embeddings']
                    if not embeddings or len(embeddings) != len(texts):
                        raise BedrockError(
                            f'Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}'
                        )

                    embed_texts = embeddings

                except Exception as e:
                    # Convert to appropriate exception type
                    _handle_bedrock_exception(e, 'Bedrock embedding (cohere)')

            else:
                raise BedrockError(f"Model provider '{model_provider}' is not supported!")

            # Final validation
            if not embed_texts:
                raise BedrockError('No embeddings generated')

            return np.array(embed_texts)

        except Exception as e:
            # Convert to appropriate exception type
            _handle_bedrock_exception(e, 'Bedrock embedding')