Enhancement: support AWS Bedrock as an LLM binding #1733
This commit is contained in:
parent 5b0e26d9da
commit 99643f01de

3 changed files with 299 additions and 192 deletions
@@ -77,9 +77,7 @@ def parse_args() -> argparse.Namespace:
        argparse.Namespace: Parsed arguments
    """

    parser = argparse.ArgumentParser(
        description="LightRAG FastAPI Server with separate working and input directories"
    )
    parser = argparse.ArgumentParser(description="LightRAG FastAPI Server with separate working and input directories")

    # Server configuration
    parser.add_argument(
@@ -209,14 +207,14 @@ def parse_args() -> argparse.Namespace:
        "--llm-binding",
        type=str,
        default=get_env_value("LLM_BINDING", "ollama"),
        choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
        choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai", "aws_bedrock"],
        help="LLM binding type (default: from env or ollama)",
    )
    parser.add_argument(
        "--embedding-binding",
        type=str,
        default=get_env_value("EMBEDDING_BINDING", "ollama"),
        choices=["lollms", "ollama", "openai", "azure_openai"],
        choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
        help="Embedding binding type (default: from env or ollama)",
    )
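With aws_bedrock added to both choice lists, the new binding is selected through the same LLM_BINDING / EMBEDDING_BINDING environment variables (or the matching CLI flags) as the existing bindings. A minimal sketch, assuming the parse_args shown here is importable from lightrag.api.config:

import os

# Hypothetical env-driven selection of the new binding.
os.environ["LLM_BINDING"] = "aws_bedrock"
os.environ["EMBEDDING_BINDING"] = "aws_bedrock"

from lightrag.api.config import parse_args  # assumed module path for parse_args

args = parse_args()
assert args.llm_binding == "aws_bedrock"
assert args.embedding_binding == "aws_bedrock"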
@@ -272,18 +270,10 @@ def parse_args() -> argparse.Namespace:
    args.input_dir = os.path.abspath(args.input_dir)

    # Inject storage configuration from environment variables
    args.kv_storage = get_env_value(
        "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
    )
    args.doc_status_storage = get_env_value(
        "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
    )
    args.graph_storage = get_env_value(
        "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
    )
    args.vector_storage = get_env_value(
        "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
    )
    args.kv_storage = get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE)
    args.doc_status_storage = get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE)
    args.graph_storage = get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE)
    args.vector_storage = get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE)

    # Get MAX_PARALLEL_INSERT from environment
    args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
@@ -299,12 +289,8 @@ def parse_args() -> argparse.Namespace:
    # Ollama ctx_num
    args.ollama_num_ctx = get_env_value("OLLAMA_NUM_CTX", 32768, int)

    args.llm_binding_host = get_env_value(
        "LLM_BINDING_HOST", get_default_host(args.llm_binding)
    )
    args.embedding_binding_host = get_env_value(
        "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
    )
    args.llm_binding_host = get_env_value("LLM_BINDING_HOST", get_default_host(args.llm_binding))
    args.embedding_binding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding))
    args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
    args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
@@ -318,9 +304,7 @@ def parse_args() -> argparse.Namespace:
    args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)

    # Inject LLM cache configuration
    args.enable_llm_cache_for_extract = get_env_value(
        "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
    )
    args.enable_llm_cache_for_extract = get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool)
    args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)

    # Handle Ollama LLM temperature with priority cascade when llm-binding is ollama
@@ -370,40 +354,24 @@ def parse_args() -> argparse.Namespace:
    args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)

    # Min rerank score configuration
    args.min_rerank_score = get_env_value(
        "MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float
    )
    args.min_rerank_score = get_env_value("MIN_RERANK_SCORE", DEFAULT_MIN_RERANK_SCORE, float)

    # Query configuration
    args.history_turns = get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int)
    args.top_k = get_env_value("TOP_K", DEFAULT_TOP_K, int)
    args.chunk_top_k = get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int)
    args.max_entity_tokens = get_env_value(
        "MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int
    )
    args.max_relation_tokens = get_env_value(
        "MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int
    )
    args.max_total_tokens = get_env_value(
        "MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int
    )
    args.cosine_threshold = get_env_value(
        "COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float
    )
    args.related_chunk_number = get_env_value(
        "RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int
    )
    args.max_entity_tokens = get_env_value("MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int)
    args.max_relation_tokens = get_env_value("MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int)
    args.max_total_tokens = get_env_value("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int)
    args.cosine_threshold = get_env_value("COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float)
    args.related_chunk_number = get_env_value("RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int)

    # Add missing environment variables for health endpoint
    args.force_llm_summary_on_merge = get_env_value(
        "FORCE_LLM_SUMMARY_ON_MERGE", DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int
    )
    args.embedding_func_max_async = get_env_value(
        "EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int
    )
    args.embedding_batch_num = get_env_value(
        "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
    )
    args.embedding_func_max_async = get_env_value("EMBEDDING_FUNC_MAX_ASYNC", DEFAULT_EMBEDDING_FUNC_MAX_ASYNC, int)
    args.embedding_batch_num = get_env_value("EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int)

    ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
    ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
@@ -417,9 +385,7 @@ def update_uvicorn_mode_config():
        original_workers = global_args.workers
        global_args.workers = 1
        # Log warning directly here
        logging.warning(
            f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
        )
        logging.warning(f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1")


global_args = parse_args()
@@ -106,6 +106,7 @@ def create_app(args):
        "openai",
        "openai-ollama",
        "azure_openai",
        "aws_bedrock",
    ]:
        raise Exception("llm binding not supported")
@@ -114,6 +115,7 @@ def create_app(args):
        "ollama",
        "openai",
        "azure_openai",
        "aws_bedrock",
        "jina",
    ]:
        raise Exception("embedding binding not supported")
@@ -128,9 +130,7 @@ def create_app(args):
    # Add SSL validation
    if args.ssl:
        if not args.ssl_certfile or not args.ssl_keyfile:
            raise Exception(
                "SSL certificate and key files must be provided when SSL is enabled"
            )
            raise Exception("SSL certificate and key files must be provided when SSL is enabled")
        if not os.path.exists(args.ssl_certfile):
            raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
        if not os.path.exists(args.ssl_keyfile):
@@ -188,10 +188,11 @@ def create_app(args):
    # Initialize FastAPI
    app_kwargs = {
        "title": "LightRAG Server API",
        "description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
        + "(With authentication)"
        if api_key
        else "",
        "description": (
            "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + "(With authentication)"
            if api_key
            else ""
        ),
        "version": __api_version__,
        "openapi_url": "/openapi.json",  # Explicitly set OpenAPI schema URL
        "docs_url": "/docs",  # Explicitly set docs URL
@@ -244,9 +245,9 @@ def create_app(args):
            azure_openai_complete_if_cache,
            azure_openai_embed,
        )
    if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock":
        from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed
    if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama":
        from lightrag.llm.openai import openai_complete_if_cache
        from lightrag.llm.ollama import ollama_embed
        from lightrag.llm.binding_options import OllamaEmbeddingOptions
    if args.embedding_binding == "jina":
        from lightrag.llm.jina import jina_embed
@@ -312,41 +313,80 @@ def create_app(args):
            **kwargs,
        )

    async def bedrock_model_complete(
        prompt,
        system_prompt=None,
        history_messages=None,
        keyword_extraction=False,
        **kwargs,
    ) -> str:
        keyword_extraction = kwargs.pop("keyword_extraction", None)
        if keyword_extraction:
            kwargs["response_format"] = GPTKeywordExtractionFormat
        if history_messages is None:
            history_messages = []

        # Use global temperature for Bedrock
        kwargs["temperature"] = args.temperature

        return await bedrock_complete_if_cache(
            args.llm_model,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )

    embedding_func = EmbeddingFunc(
        embedding_dim=args.embedding_dim,
        func=lambda texts: lollms_embed(
            texts,
            embed_model=args.embedding_model,
            host=args.embedding_binding_host,
            api_key=args.embedding_binding_api_key,
        )
        if args.embedding_binding == "lollms"
        else ollama_embed(
            texts,
            embed_model=args.embedding_model,
            host=args.embedding_binding_host,
            api_key=args.embedding_binding_api_key,
            options=OllamaEmbeddingOptions.options_dict(args),
        )
        if args.embedding_binding == "ollama"
        else azure_openai_embed(
            texts,
            model=args.embedding_model,  # no host is used for openai,
            api_key=args.embedding_binding_api_key,
        )
        if args.embedding_binding == "azure_openai"
        else jina_embed(
            texts,
            dimensions=args.embedding_dim,
            base_url=args.embedding_binding_host,
            api_key=args.embedding_binding_api_key,
        )
        if args.embedding_binding == "jina"
        else openai_embed(
            texts,
            model=args.embedding_model,
            base_url=args.embedding_binding_host,
            api_key=args.embedding_binding_api_key,
        func=lambda texts: (
            lollms_embed(
                texts,
                embed_model=args.embedding_model,
                host=args.embedding_binding_host,
                api_key=args.embedding_binding_api_key,
            )
            if args.embedding_binding == "lollms"
            else (
                ollama_embed(
                    texts,
                    embed_model=args.embedding_model,
                    host=args.embedding_binding_host,
                    api_key=args.embedding_binding_api_key,
                    options=OllamaEmbeddingOptions.options_dict(args),
                )
                if args.embedding_binding == "ollama"
                else (
                    azure_openai_embed(
                        texts,
                        model=args.embedding_model,  # no host is used for openai,
                        api_key=args.embedding_binding_api_key,
                    )
                    if args.embedding_binding == "azure_openai"
                    else (
                        bedrock_embed(
                            texts,
                            model=args.embedding_model,
                        )
                        if args.embedding_binding == "aws_bedrock"
                        else (
                            jina_embed(
                                texts,
                                dimensions=args.embedding_dim,
                                base_url=args.embedding_binding_host,
                                api_key=args.embedding_binding_api_key,
                            )
                            if args.embedding_binding == "jina"
                            else openai_embed(
                                texts,
                                model=args.embedding_model,
                                base_url=args.embedding_binding_host,
                                api_key=args.embedding_binding_api_key,
                            )
                        )
                    )
                )
            )
        ),
    )
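For the aws_bedrock branch, the lambda above reduces to a plain bedrock_embed call. A standalone sketch of that branch (the embedding model ID is a placeholder; AWS credentials and region are expected in the usual AWS_* environment variables):

import asyncio

from lightrag.llm.bedrock import bedrock_embed

async def demo():
    # Same call the aws_bedrock branch of embedding_func makes, outside create_app
    vectors = await bedrock_embed(
        ["What is LightRAG?"],
        model="amazon.titan-embed-text-v2:0",  # placeholder embedding model ID
    )
    print(vectors.shape)  # numpy array: (number of texts, embedding dimension)

asyncio.run(demo())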
@@ -355,9 +395,7 @@ def create_app(args):
    if args.rerank_binding_api_key and args.rerank_binding_host:
        from lightrag.rerank import custom_rerank

        async def server_rerank_func(
            query: str, documents: list, top_n: int = None, **kwargs
        ):
        async def server_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
            """Server rerank function with configuration from environment variables"""
            return await custom_rerank(
                query=query,
@@ -370,9 +408,7 @@ def create_app(args):
            )

        rerank_model_func = server_rerank_func
        logger.info(
            f"Rerank model configured: {args.rerank_model} (can be enabled per query)"
        )
        logger.info(f"Rerank model configured: {args.rerank_model} (can be enabled per query)")
    else:
        logger.info(
            "Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
@@ -381,41 +417,43 @@ def create_app(args):
    # Create ollama_server_infos from command line arguments
    from lightrag.api.config import OllamaServerInfos

    ollama_server_infos = OllamaServerInfos(
        name=args.simulated_model_name, tag=args.simulated_model_tag
    )
    ollama_server_infos = OllamaServerInfos(name=args.simulated_model_name, tag=args.simulated_model_tag)

    # Initialize RAG
    if args.llm_binding in ["lollms", "ollama", "openai"]:
    if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]:
        rag = LightRAG(
            working_dir=args.working_dir,
            workspace=args.workspace,
            llm_model_func=lollms_model_complete
            if args.llm_binding == "lollms"
            else ollama_model_complete
            if args.llm_binding == "ollama"
            else openai_alike_model_complete,
            llm_model_func=(
                lollms_model_complete
                if args.llm_binding == "lollms"
                else (
                    ollama_model_complete
                    if args.llm_binding == "ollama"
                    else bedrock_model_complete if args.llm_binding == "aws_bedrock" else openai_alike_model_complete
                )
            ),
            llm_model_name=args.llm_model,
            llm_model_max_async=args.max_async,
            summary_max_tokens=args.max_tokens,
            chunk_token_size=int(args.chunk_size),
            chunk_overlap_token_size=int(args.chunk_overlap_size),
            llm_model_kwargs={
                "host": args.llm_binding_host,
                "timeout": args.timeout,
                "options": OllamaLLMOptions.options_dict(args),
                "api_key": args.llm_binding_api_key,
            }
            if args.llm_binding == "lollms" or args.llm_binding == "ollama"
            else {},
            llm_model_kwargs=(
                {
                    "host": args.llm_binding_host,
                    "timeout": args.timeout,
                    "options": OllamaLLMOptions.options_dict(args),
                    "api_key": args.llm_binding_api_key,
                }
                if args.llm_binding == "lollms" or args.llm_binding == "ollama"
                else {}
            ),
            embedding_func=embedding_func,
            kv_storage=args.kv_storage,
            graph_storage=args.graph_storage,
            vector_storage=args.vector_storage,
            doc_status_storage=args.doc_status_storage,
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": args.cosine_threshold
            },
            vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold},
            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
            enable_llm_cache=args.enable_llm_cache,
            rerank_model_func=rerank_model_func,
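Outside the API server, the same Bedrock helpers can be wired into LightRAG directly. A minimal sketch, not taken from this diff: the model IDs, embedding dimension, and the EmbeddingFunc import path are assumptions, and AWS credentials are expected in the environment:

import asyncio

from lightrag import LightRAG
from lightrag.llm.bedrock import bedrock_complete, bedrock_embed
from lightrag.utils import EmbeddingFunc  # assumed location of EmbeddingFunc

async def build_rag() -> LightRAG:
    rag = LightRAG(
        working_dir="./rag_storage",
        llm_model_func=bedrock_complete,
        llm_model_name="anthropic.claude-3-haiku-20240307-v1:0",  # placeholder model ID
        embedding_func=EmbeddingFunc(
            embedding_dim=1024,  # must match the chosen embedding model's output size
            func=lambda texts: bedrock_embed(
                texts, model="amazon.titan-embed-text-v2:0"  # placeholder model ID
            ),
        ),
    )
    # Depending on the LightRAG version, storages may also need explicit initialization here.
    return rag

rag = asyncio.run(build_rag())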
@@ -442,9 +480,7 @@ def create_app(args):
            graph_storage=args.graph_storage,
            vector_storage=args.vector_storage,
            doc_status_storage=args.doc_status_storage,
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": args.cosine_threshold
            },
            vector_db_storage_cls_kwargs={"cosine_better_than_threshold": args.cosine_threshold},
            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
            enable_llm_cache=args.enable_llm_cache,
            rerank_model_func=rerank_model_func,
@@ -480,9 +516,7 @@ def create_app(args):

        if not auth_handler.accounts:
            # Authentication not configured, return guest token
            guest_token = auth_handler.create_token(
                username="guest", role="guest", metadata={"auth_mode": "disabled"}
            )
            guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"})
            return {
                "auth_configured": False,
                "access_token": guest_token,
@@ -508,9 +542,7 @@ def create_app(args):
    async def login(form_data: OAuth2PasswordRequestForm = Depends()):
        if not auth_handler.accounts:
            # Authentication not configured, return guest token
            guest_token = auth_handler.create_token(
                username="guest", role="guest", metadata={"auth_mode": "disabled"}
            )
            guest_token = auth_handler.create_token(username="guest", role="guest", metadata={"auth_mode": "disabled"})
            return {
                "access_token": guest_token,
                "token_type": "bearer",
@@ -523,14 +555,10 @@ def create_app(args):
            }
        username = form_data.username
        if auth_handler.accounts.get(username) != form_data.password:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
            )
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials")

        # Regular user login
        user_token = auth_handler.create_token(
            username=username, role="user", metadata={"auth_mode": "enabled"}
        )
        user_token = auth_handler.create_token(username=username, role="user", metadata={"auth_mode": "enabled"})
        return {
            "access_token": user_token,
            "token_type": "bearer",
@@ -579,12 +607,8 @@ def create_app(args):
            "max_graph_nodes": args.max_graph_nodes,
            # Rerank configuration (based on whether rerank model is configured)
            "enable_rerank": rerank_model_func is not None,
            "rerank_model": args.rerank_model
            if rerank_model_func is not None
            else None,
            "rerank_binding_host": args.rerank_binding_host
            if rerank_model_func is not None
            else None,
            "rerank_model": args.rerank_model if rerank_model_func is not None else None,
            "rerank_binding_host": args.rerank_binding_host if rerank_model_func is not None else None,
            # Environment variable status (requested configuration)
            "summary_language": args.summary_language,
            "force_llm_summary_on_merge": args.force_llm_summary_on_merge,
@@ -614,17 +638,11 @@ def create_app(args):
            response = await super().get_response(path, scope)

            if path.endswith(".html"):
                response.headers["Cache-Control"] = (
                    "no-cache, no-store, must-revalidate"
                )
                response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
                response.headers["Pragma"] = "no-cache"
                response.headers["Expires"] = "0"
            elif (
                "/assets/" in path
            ):  # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename
                response.headers["Cache-Control"] = (
                    "public, max-age=31536000, immutable"
                )
            elif "/assets/" in path:  # Assets (JS, CSS, images, fonts) generated by Vite with hash in filename
                response.headers["Cache-Control"] = "public, max-age=31536000, immutable"
            # Add other rules here if needed for non-HTML, non-asset files

            # Ensure correct Content-Type
@@ -640,9 +658,7 @@ def create_app(args):
        static_dir.mkdir(exist_ok=True)
        app.mount(
            "/webui",
            SmartStaticFiles(
                directory=static_dir, html=True, check_dir=True
            ),  # Use SmartStaticFiles
            SmartStaticFiles(directory=static_dir, html=True, check_dir=True),  # Use SmartStaticFiles
            name="webui",
        )
@@ -798,9 +814,7 @@ def main():
            }
        )

        print(
            f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
        )
        print(f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}")
        uvicorn.run(**uvicorn_config)
@@ -15,11 +15,25 @@ from tenacity import (
    retry_if_exception_type,
)

import sys

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator
from typing import Union


class BedrockError(Exception):
    """Generic error for issues related to Amazon Bedrock"""


def _set_env_if_present(key: str, value):
    """Set environment variable only if a non-empty value is provided."""
    if value is not None and value != "":
        os.environ[key] = value


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, max=60),
@@ -34,17 +48,35 @@ async def bedrock_complete_if_cache(
    aws_secret_access_key=None,
    aws_session_token=None,
    **kwargs,
) -> str:
    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
        "AWS_ACCESS_KEY_ID", aws_access_key_id
    )
    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
    )
    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
        "AWS_SESSION_TOKEN", aws_session_token
    )
) -> Union[str, AsyncIterator[str]]:
    # Respect existing env; only set if a non-empty value is available
    access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
    session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
    _set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
    _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
    _set_env_if_present("AWS_SESSION_TOKEN", session_token)
    # Region handling: prefer env, else kwarg (optional)
    region = os.environ.get("AWS_REGION") or kwargs.pop("aws_region", None)
    kwargs.pop("hashing_kv", None)
    # Capture stream flag (if provided) and remove from kwargs since it's not a Bedrock API parameter
    # We'll use this to determine whether to call converse_stream or converse
    stream = bool(kwargs.pop("stream", False))
    # Remove unsupported args for Bedrock Converse API
    for k in [
        "response_format",
        "tools",
        "tool_choice",
        "seed",
        "presence_penalty",
        "frequency_penalty",
        "n",
        "logprobs",
        "top_logprobs",
        "max_completion_tokens",
        "response_format",
    ]:
        kwargs.pop(k, None)
    # Fix message history format
    messages = []
    for history_message in history_messages:
@@ -68,30 +100,126 @@ async def bedrock_complete_if_cache(
        "top_p": "topP",
        "stop_sequences": "stopSequences",
    }
    if inference_params := list(
        set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
    ):
    if inference_params := list(set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])):
        args["inferenceConfig"] = {}
        for param in inference_params:
            args["inferenceConfig"][inference_params_map.get(param, param)] = (
                kwargs.pop(param)
            )
            args["inferenceConfig"][inference_params_map.get(param, param)] = kwargs.pop(param)

    # Call model via Converse API
    # Import logging for error handling
    import logging

    # For streaming responses, we need a different approach to keep the connection open
    if stream:
        # Create a session that will be used throughout the streaming process
        session = aioboto3.Session()
        client = None

        # Define the generator function that will manage the client lifecycle
        async def stream_generator():
            nonlocal client

            # Create the client outside the generator to ensure it stays open
            client = await session.client("bedrock-runtime", region_name=region).__aenter__()
            event_stream = None
            iteration_started = False

            try:
                # Make the API call
                response = await client.converse_stream(**args, **kwargs)
                event_stream = response.get("stream")
                iteration_started = True

                # Process the stream
                async for event in event_stream:
                    # Validate event structure
                    if not event or not isinstance(event, dict):
                        continue

                    if "contentBlockDelta" in event:
                        delta = event["contentBlockDelta"].get("delta", {})
                        text = delta.get("text")
                        if text:
                            yield text
                    # Handle other event types that might indicate stream end
                    elif "messageStop" in event:
                        break

            except Exception as e:
                # Log the specific error for debugging
                logging.error(f"Bedrock streaming error: {e}")

                # Try to clean up resources if possible
                if (
                    iteration_started
                    and event_stream
                    and hasattr(event_stream, "aclose")
                    and callable(getattr(event_stream, "aclose", None))
                ):
                    try:
                        await event_stream.aclose()
                    except Exception as close_error:
                        logging.warning(f"Failed to close Bedrock event stream: {close_error}")

                raise BedrockError(f"Streaming error: {e}")

            finally:
                # Clean up the event stream
                if (
                    iteration_started
                    and event_stream
                    and hasattr(event_stream, "aclose")
                    and callable(getattr(event_stream, "aclose", None))
                ):
                    try:
                        await event_stream.aclose()
                    except Exception as close_error:
                        logging.warning(f"Failed to close Bedrock event stream in finally block: {close_error}")

                # Clean up the client
                if client:
                    try:
                        await client.__aexit__(None, None, None)
                    except Exception as client_close_error:
                        logging.warning(f"Failed to close Bedrock client: {client_close_error}")

        # Return the generator that manages its own lifecycle
        return stream_generator()

    # For non-streaming responses, use the standard async context manager pattern
    session = aioboto3.Session()
    async with session.client("bedrock-runtime") as bedrock_async_client:
    async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client:
        try:
            # Use converse for non-streaming responses
            response = await bedrock_async_client.converse(**args, **kwargs)
        except Exception as e:
            raise BedrockError(e)

        return response["output"]["message"]["content"][0]["text"]
            # Validate response structure
            if (
                not response
                or "output" not in response
                or "message" not in response["output"]
                or "content" not in response["output"]["message"]
                or not response["output"]["message"]["content"]
            ):
                raise BedrockError("Invalid response structure from Bedrock API")

            content = response["output"]["message"]["content"][0]["text"]

            if not content or content.strip() == "":
                raise BedrockError("Received empty content from Bedrock API")

            return content

        except Exception as e:
            if isinstance(e, BedrockError):
                raise
            else:
                raise BedrockError(f"Bedrock API error: {e}")


# Generic Bedrock completion function
async def bedrock_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
) -> Union[str, AsyncIterator[str]]:
    kwargs.pop("keyword_extraction", None)
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    result = await bedrock_complete_if_cache(
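With the stream flag handled this way, callers get an async iterator of text deltas instead of a string when they pass stream=True. A minimal consumption sketch (the model ID is a placeholder; AWS credentials and region come from the environment):

import asyncio

from lightrag.llm.bedrock import bedrock_complete_if_cache

async def demo():
    chunks = await bedrock_complete_if_cache(
        "anthropic.claude-3-haiku-20240307-v1:0",  # placeholder Bedrock model ID
        "Write one sentence about knowledge graphs.",
        stream=True,  # returns an async iterator instead of a str
    )
    async for piece in chunks:
        print(piece, end="", flush=True)

asyncio.run(demo())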
@@ -117,18 +245,19 @@ async def bedrock_embed(
    aws_secret_access_key=None,
    aws_session_token=None,
) -> np.ndarray:
    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
        "AWS_ACCESS_KEY_ID", aws_access_key_id
    )
    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
    )
    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
        "AWS_SESSION_TOKEN", aws_session_token
    )
    # Respect existing env; only set if a non-empty value is available
    access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
    session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
    _set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
    _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
    _set_env_if_present("AWS_SESSION_TOKEN", session_token)

    # Region handling: prefer env
    region = os.environ.get("AWS_REGION")

    session = aioboto3.Session()
    async with session.client("bedrock-runtime") as bedrock_async_client:
    async with session.client("bedrock-runtime", region_name=region) as bedrock_async_client:
        if (model_provider := model.split(".")[0]) == "amazon":
            embed_texts = []
            for text in texts:
@@ -156,9 +285,7 @@ async def bedrock_embed(

                embed_texts.append(response_body["embedding"])
        elif model_provider == "cohere":
            body = json.dumps(
                {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
            )
            body = json.dumps({"texts": texts, "input_type": "search_document", "truncate": "NONE"})

            response = await bedrock_async_client.invoke_model(
                model=model,
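Both helpers resolve credentials and region purely from the process environment (or whatever the default boto3 credential chain provides). A minimal sketch of the variables they consult; the values are placeholders, and an AWS profile or instance role works as well:

import os

# Placeholders only; in practice these come from your AWS account or profile.
os.environ.setdefault("AWS_ACCESS_KEY_ID", "AKIA...")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "...")
os.environ.setdefault("AWS_SESSION_TOKEN", "...")  # only needed for temporary credentials
os.environ.setdefault("AWS_REGION", "us-east-1")   # region used when creating the bedrock-runtime client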