Merge branch 'bedrock-support'
This commit is contained in:
commit
3a7310873c
6 changed files with 295 additions and 89 deletions
|
|
@ -123,7 +123,7 @@ MAX_PARALLEL_INSERT=2
|
||||||
|
|
||||||
###########################################################
|
###########################################################
|
||||||
### LLM Configuration
|
### LLM Configuration
|
||||||
### LLM_BINDING type: openai, ollama, lollms, azure_openai
|
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
|
||||||
###########################################################
|
###########################################################
|
||||||
### LLM temperature setting for all llm binding (openai, azure_openai, ollama)
|
### LLM temperature setting for all llm binding (openai, azure_openai, ollama)
|
||||||
# TEMPERATURE=1.0
|
# TEMPERATURE=1.0
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ LightRAG 需要同时集成 LLM(大型语言模型)和嵌入模型以有效
|
||||||
* lollms
|
* lollms
|
||||||
* openai 或 openai 兼容
|
* openai 或 openai 兼容
|
||||||
* azure_openai
|
* azure_openai
|
||||||
|
* aws_bedrock
|
||||||
|
|
||||||
建议使用环境变量来配置 LightRAG 服务器。项目根目录中有一个名为 `env.example` 的示例环境变量文件。请将此文件复制到启动目录并重命名为 `.env`。之后,您可以在 `.env` 文件中修改与 LLM 和嵌入模型相关的参数。需要注意的是,LightRAG 服务器每次启动时都会将 `.env` 中的环境变量加载到系统环境变量中。**LightRAG 服务器会优先使用系统环境变量中的设置**。
|
建议使用环境变量来配置 LightRAG 服务器。项目根目录中有一个名为 `env.example` 的示例环境变量文件。请将此文件复制到启动目录并重命名为 `.env`。之后,您可以在 `.env` 文件中修改与 LLM 和嵌入模型相关的参数。需要注意的是,LightRAG 服务器每次启动时都会将 `.env` 中的环境变量加载到系统环境变量中。**LightRAG 服务器会优先使用系统环境变量中的设置**。
|
||||||
|
|
||||||
|
|
@ -359,6 +360,7 @@ LightRAG 支持绑定到各种 LLM/嵌入后端:
|
||||||
* openai 和 openai 兼容
|
* openai 和 openai 兼容
|
||||||
* azure_openai
|
* azure_openai
|
||||||
* lollms
|
* lollms
|
||||||
|
* aws_bedrock
|
||||||
|
|
||||||
使用环境变量 `LLM_BINDING` 或 CLI 参数 `--llm-binding` 选择 LLM 后端类型。使用环境变量 `EMBEDDING_BINDING` 或 CLI 参数 `--embedding-binding` 选择嵌入后端类型。
|
使用环境变量 `LLM_BINDING` 或 CLI 参数 `--llm-binding` 选择 LLM 后端类型。使用环境变量 `EMBEDDING_BINDING` 或 CLI 参数 `--embedding-binding` 选择嵌入后端类型。
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ LightRAG necessitates the integration of both an LLM (Large Language Model) and
|
||||||
* lollms
|
* lollms
|
||||||
* openai or openai compatible
|
* openai or openai compatible
|
||||||
* azure_openai
|
* azure_openai
|
||||||
|
* aws_bedrock
|
||||||
|
|
||||||
It is recommended to use environment variables to configure the LightRAG Server. There is an example environment variable file named `env.example` in the root directory of the project. Please copy this file to the startup directory and rename it to `.env`. After that, you can modify the parameters related to the LLM and Embedding models in the `.env` file. It is important to note that the LightRAG Server will load the environment variables from `.env` into the system environment variables each time it starts. **LightRAG Server will prioritize the settings in the system environment variables to .env file**.
|
It is recommended to use environment variables to configure the LightRAG Server. There is an example environment variable file named `env.example` in the root directory of the project. Please copy this file to the startup directory and rename it to `.env`. After that, you can modify the parameters related to the LLM and Embedding models in the `.env` file. It is important to note that the LightRAG Server will load the environment variables from `.env` into the system environment variables each time it starts. **LightRAG Server will prioritize the settings in the system environment variables to .env file**.
|
||||||
|
|
||||||
|
|
@ -362,6 +363,7 @@ LightRAG supports binding to various LLM/Embedding backends:
|
||||||
* openai & openai compatible
|
* openai & openai compatible
|
||||||
* azure_openai
|
* azure_openai
|
||||||
* lollms
|
* lollms
|
||||||
|
* aws_bedrock
|
||||||
|
|
||||||
Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select the LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select the Embedding backend type.
|
Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select the LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select the Embedding backend type.
|
||||||
|
|
||||||
|
|
@ -461,8 +463,8 @@ You cannot change storage implementation selection after adding documents to Lig
|
||||||
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
|
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
|
||||||
| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
|
| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
|
||||||
| --cosine-threshold | 0.4 | The cosine threshold for nodes and relation retrieval, works with top-k to control the retrieval of nodes and relations. |
|
| --cosine-threshold | 0.4 | The cosine threshold for nodes and relation retrieval, works with top-k to control the retrieval of nodes and relations. |
|
||||||
| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai) |
|
| --llm-binding | ollama | LLM binding type (lollms, ollama, openai, openai-ollama, azure_openai, aws_bedrock) |
|
||||||
| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai) |
|
| --embedding-binding | ollama | Embedding binding type (lollms, ollama, openai, azure_openai, aws_bedrock) |
|
||||||
| --auto-scan-at-startup| - | Scan input directory for new files and start indexing |
|
| --auto-scan-at-startup| - | Scan input directory for new files and start indexing |
|
||||||
|
|
||||||
### Additional Ollama Binding Options
|
### Additional Ollama Binding Options
|
||||||
|
|
|
||||||
|
|
@ -209,14 +209,21 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--llm-binding",
|
"--llm-binding",
|
||||||
type=str,
|
type=str,
|
||||||
default=get_env_value("LLM_BINDING", "ollama"),
|
default=get_env_value("LLM_BINDING", "ollama"),
|
||||||
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
|
choices=[
|
||||||
|
"lollms",
|
||||||
|
"ollama",
|
||||||
|
"openai",
|
||||||
|
"openai-ollama",
|
||||||
|
"azure_openai",
|
||||||
|
"aws_bedrock",
|
||||||
|
],
|
||||||
help="LLM binding type (default: from env or ollama)",
|
help="LLM binding type (default: from env or ollama)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--embedding-binding",
|
"--embedding-binding",
|
||||||
type=str,
|
type=str,
|
||||||
default=get_env_value("EMBEDDING_BINDING", "ollama"),
|
default=get_env_value("EMBEDDING_BINDING", "ollama"),
|
||||||
choices=["lollms", "ollama", "openai", "azure_openai"],
|
choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
|
||||||
help="Embedding binding type (default: from env or ollama)",
|
help="Embedding binding type (default: from env or ollama)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -104,8 +104,8 @@ def create_app(args):
|
||||||
"lollms",
|
"lollms",
|
||||||
"ollama",
|
"ollama",
|
||||||
"openai",
|
"openai",
|
||||||
"openai-ollama",
|
|
||||||
"azure_openai",
|
"azure_openai",
|
||||||
|
"aws_bedrock",
|
||||||
]:
|
]:
|
||||||
raise Exception("llm binding not supported")
|
raise Exception("llm binding not supported")
|
||||||
|
|
||||||
|
|
@ -114,6 +114,7 @@ def create_app(args):
|
||||||
"ollama",
|
"ollama",
|
||||||
"openai",
|
"openai",
|
||||||
"azure_openai",
|
"azure_openai",
|
||||||
|
"aws_bedrock",
|
||||||
"jina",
|
"jina",
|
||||||
]:
|
]:
|
||||||
raise Exception("embedding binding not supported")
|
raise Exception("embedding binding not supported")
|
||||||
|
|
@ -188,10 +189,12 @@ def create_app(args):
|
||||||
# Initialize FastAPI
|
# Initialize FastAPI
|
||||||
app_kwargs = {
|
app_kwargs = {
|
||||||
"title": "LightRAG Server API",
|
"title": "LightRAG Server API",
|
||||||
"description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
|
"description": (
|
||||||
+ "(With authentication)"
|
"Providing API for LightRAG core, Web UI and Ollama Model Emulation"
|
||||||
if api_key
|
+ "(With authentication)"
|
||||||
else "",
|
if api_key
|
||||||
|
else ""
|
||||||
|
),
|
||||||
"version": __api_version__,
|
"version": __api_version__,
|
||||||
"openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
|
"openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
|
||||||
"docs_url": "/docs", # Explicitly set docs URL
|
"docs_url": "/docs", # Explicitly set docs URL
|
||||||
|
|
@ -244,9 +247,9 @@ def create_app(args):
|
||||||
azure_openai_complete_if_cache,
|
azure_openai_complete_if_cache,
|
||||||
azure_openai_embed,
|
azure_openai_embed,
|
||||||
)
|
)
|
||||||
if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama":
|
if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock":
|
||||||
from lightrag.llm.openai import openai_complete_if_cache
|
from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed
|
||||||
from lightrag.llm.ollama import ollama_embed
|
if args.embedding_binding == "ollama":
|
||||||
from lightrag.llm.binding_options import OllamaEmbeddingOptions
|
from lightrag.llm.binding_options import OllamaEmbeddingOptions
|
||||||
if args.embedding_binding == "jina":
|
if args.embedding_binding == "jina":
|
||||||
from lightrag.llm.jina import jina_embed
|
from lightrag.llm.jina import jina_embed
|
||||||
|
|
@ -312,41 +315,80 @@ def create_app(args):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def bedrock_model_complete(
|
||||||
|
prompt,
|
||||||
|
system_prompt=None,
|
||||||
|
history_messages=None,
|
||||||
|
keyword_extraction=False,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
||||||
|
if keyword_extraction:
|
||||||
|
kwargs["response_format"] = GPTKeywordExtractionFormat
|
||||||
|
if history_messages is None:
|
||||||
|
history_messages = []
|
||||||
|
|
||||||
|
# Use global temperature for Bedrock
|
||||||
|
kwargs["temperature"] = args.temperature
|
||||||
|
|
||||||
|
return await bedrock_complete_if_cache(
|
||||||
|
args.llm_model,
|
||||||
|
prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
embedding_func = EmbeddingFunc(
|
embedding_func = EmbeddingFunc(
|
||||||
embedding_dim=args.embedding_dim,
|
embedding_dim=args.embedding_dim,
|
||||||
func=lambda texts: lollms_embed(
|
func=lambda texts: (
|
||||||
texts,
|
lollms_embed(
|
||||||
embed_model=args.embedding_model,
|
texts,
|
||||||
host=args.embedding_binding_host,
|
embed_model=args.embedding_model,
|
||||||
api_key=args.embedding_binding_api_key,
|
host=args.embedding_binding_host,
|
||||||
)
|
api_key=args.embedding_binding_api_key,
|
||||||
if args.embedding_binding == "lollms"
|
)
|
||||||
else ollama_embed(
|
if args.embedding_binding == "lollms"
|
||||||
texts,
|
else (
|
||||||
embed_model=args.embedding_model,
|
ollama_embed(
|
||||||
host=args.embedding_binding_host,
|
texts,
|
||||||
api_key=args.embedding_binding_api_key,
|
embed_model=args.embedding_model,
|
||||||
options=OllamaEmbeddingOptions.options_dict(args),
|
host=args.embedding_binding_host,
|
||||||
)
|
api_key=args.embedding_binding_api_key,
|
||||||
if args.embedding_binding == "ollama"
|
options=OllamaEmbeddingOptions.options_dict(args),
|
||||||
else azure_openai_embed(
|
)
|
||||||
texts,
|
if args.embedding_binding == "ollama"
|
||||||
model=args.embedding_model, # no host is used for openai,
|
else (
|
||||||
api_key=args.embedding_binding_api_key,
|
azure_openai_embed(
|
||||||
)
|
texts,
|
||||||
if args.embedding_binding == "azure_openai"
|
model=args.embedding_model, # no host is used for openai,
|
||||||
else jina_embed(
|
api_key=args.embedding_binding_api_key,
|
||||||
texts,
|
)
|
||||||
dimensions=args.embedding_dim,
|
if args.embedding_binding == "azure_openai"
|
||||||
base_url=args.embedding_binding_host,
|
else (
|
||||||
api_key=args.embedding_binding_api_key,
|
bedrock_embed(
|
||||||
)
|
texts,
|
||||||
if args.embedding_binding == "jina"
|
model=args.embedding_model,
|
||||||
else openai_embed(
|
)
|
||||||
texts,
|
if args.embedding_binding == "aws_bedrock"
|
||||||
model=args.embedding_model,
|
else (
|
||||||
base_url=args.embedding_binding_host,
|
jina_embed(
|
||||||
api_key=args.embedding_binding_api_key,
|
texts,
|
||||||
|
dimensions=args.embedding_dim,
|
||||||
|
base_url=args.embedding_binding_host,
|
||||||
|
api_key=args.embedding_binding_api_key,
|
||||||
|
)
|
||||||
|
if args.embedding_binding == "jina"
|
||||||
|
else openai_embed(
|
||||||
|
texts,
|
||||||
|
model=args.embedding_model,
|
||||||
|
base_url=args.embedding_binding_host,
|
||||||
|
api_key=args.embedding_binding_api_key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -386,28 +428,36 @@ def create_app(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize RAG
|
# Initialize RAG
|
||||||
if args.llm_binding in ["lollms", "ollama", "openai"]:
|
if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]:
|
||||||
rag = LightRAG(
|
rag = LightRAG(
|
||||||
working_dir=args.working_dir,
|
working_dir=args.working_dir,
|
||||||
workspace=args.workspace,
|
workspace=args.workspace,
|
||||||
llm_model_func=lollms_model_complete
|
llm_model_func=(
|
||||||
if args.llm_binding == "lollms"
|
lollms_model_complete
|
||||||
else ollama_model_complete
|
if args.llm_binding == "lollms"
|
||||||
if args.llm_binding == "ollama"
|
else (
|
||||||
else openai_alike_model_complete,
|
ollama_model_complete
|
||||||
|
if args.llm_binding == "ollama"
|
||||||
|
else bedrock_model_complete
|
||||||
|
if args.llm_binding == "aws_bedrock"
|
||||||
|
else openai_alike_model_complete
|
||||||
|
)
|
||||||
|
),
|
||||||
llm_model_name=args.llm_model,
|
llm_model_name=args.llm_model,
|
||||||
llm_model_max_async=args.max_async,
|
llm_model_max_async=args.max_async,
|
||||||
summary_max_tokens=args.max_tokens,
|
summary_max_tokens=args.max_tokens,
|
||||||
chunk_token_size=int(args.chunk_size),
|
chunk_token_size=int(args.chunk_size),
|
||||||
chunk_overlap_token_size=int(args.chunk_overlap_size),
|
chunk_overlap_token_size=int(args.chunk_overlap_size),
|
||||||
llm_model_kwargs={
|
llm_model_kwargs=(
|
||||||
"host": args.llm_binding_host,
|
{
|
||||||
"timeout": args.timeout,
|
"host": args.llm_binding_host,
|
||||||
"options": OllamaLLMOptions.options_dict(args),
|
"timeout": args.timeout,
|
||||||
"api_key": args.llm_binding_api_key,
|
"options": OllamaLLMOptions.options_dict(args),
|
||||||
}
|
"api_key": args.llm_binding_api_key,
|
||||||
if args.llm_binding == "lollms" or args.llm_binding == "ollama"
|
}
|
||||||
else {},
|
if args.llm_binding == "lollms" or args.llm_binding == "ollama"
|
||||||
|
else {}
|
||||||
|
),
|
||||||
embedding_func=embedding_func,
|
embedding_func=embedding_func,
|
||||||
kv_storage=args.kv_storage,
|
kv_storage=args.kv_storage,
|
||||||
graph_storage=args.graph_storage,
|
graph_storage=args.graph_storage,
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,25 @@ from tenacity import (
|
||||||
retry_if_exception_type,
|
retry_if_exception_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if sys.version_info < (3, 9):
|
||||||
|
from typing import AsyncIterator
|
||||||
|
else:
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
class BedrockError(Exception):
|
class BedrockError(Exception):
|
||||||
"""Generic error for issues related to Amazon Bedrock"""
|
"""Generic error for issues related to Amazon Bedrock"""
|
||||||
|
|
||||||
|
|
||||||
|
def _set_env_if_present(key: str, value):
|
||||||
|
"""Set environment variable only if a non-empty value is provided."""
|
||||||
|
if value is not None and value != "":
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(5),
|
stop=stop_after_attempt(5),
|
||||||
wait=wait_exponential(multiplier=1, max=60),
|
wait=wait_exponential(multiplier=1, max=60),
|
||||||
|
|
@ -34,17 +48,35 @@ async def bedrock_complete_if_cache(
|
||||||
aws_secret_access_key=None,
|
aws_secret_access_key=None,
|
||||||
aws_session_token=None,
|
aws_session_token=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> Union[str, AsyncIterator[str]]:
|
||||||
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
# Respect existing env; only set if a non-empty value is available
|
||||||
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
|
||||||
)
|
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
|
||||||
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
|
||||||
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
_set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
|
||||||
)
|
_set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
|
||||||
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
_set_env_if_present("AWS_SESSION_TOKEN", session_token)
|
||||||
"AWS_SESSION_TOKEN", aws_session_token
|
# Region handling: prefer env, else kwarg (optional)
|
||||||
)
|
region = os.environ.get("AWS_REGION") or kwargs.pop("aws_region", None)
|
||||||
kwargs.pop("hashing_kv", None)
|
kwargs.pop("hashing_kv", None)
|
||||||
|
# Capture stream flag (if provided) and remove from kwargs since it's not a Bedrock API parameter
|
||||||
|
# We'll use this to determine whether to call converse_stream or converse
|
||||||
|
stream = bool(kwargs.pop("stream", False))
|
||||||
|
# Remove unsupported args for Bedrock Converse API
|
||||||
|
for k in [
|
||||||
|
"response_format",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
"seed",
|
||||||
|
"presence_penalty",
|
||||||
|
"frequency_penalty",
|
||||||
|
"n",
|
||||||
|
"logprobs",
|
||||||
|
"top_logprobs",
|
||||||
|
"max_completion_tokens",
|
||||||
|
"response_format",
|
||||||
|
]:
|
||||||
|
kwargs.pop(k, None)
|
||||||
# Fix message history format
|
# Fix message history format
|
||||||
messages = []
|
messages = []
|
||||||
for history_message in history_messages:
|
for history_message in history_messages:
|
||||||
|
|
@ -77,21 +109,131 @@ async def bedrock_complete_if_cache(
|
||||||
kwargs.pop(param)
|
kwargs.pop(param)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call model via Converse API
|
# Import logging for error handling
|
||||||
session = aioboto3.Session()
|
import logging
|
||||||
async with session.client("bedrock-runtime") as bedrock_async_client:
|
|
||||||
try:
|
|
||||||
response = await bedrock_async_client.converse(**args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
raise BedrockError(e)
|
|
||||||
|
|
||||||
return response["output"]["message"]["content"][0]["text"]
|
# For streaming responses, we need a different approach to keep the connection open
|
||||||
|
if stream:
|
||||||
|
# Create a session that will be used throughout the streaming process
|
||||||
|
session = aioboto3.Session()
|
||||||
|
client = None
|
||||||
|
|
||||||
|
# Define the generator function that will manage the client lifecycle
|
||||||
|
async def stream_generator():
|
||||||
|
nonlocal client
|
||||||
|
|
||||||
|
# Create the client outside the generator to ensure it stays open
|
||||||
|
client = await session.client(
|
||||||
|
"bedrock-runtime", region_name=region
|
||||||
|
).__aenter__()
|
||||||
|
event_stream = None
|
||||||
|
iteration_started = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Make the API call
|
||||||
|
response = await client.converse_stream(**args, **kwargs)
|
||||||
|
event_stream = response.get("stream")
|
||||||
|
iteration_started = True
|
||||||
|
|
||||||
|
# Process the stream
|
||||||
|
async for event in event_stream:
|
||||||
|
# Validate event structure
|
||||||
|
if not event or not isinstance(event, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "contentBlockDelta" in event:
|
||||||
|
delta = event["contentBlockDelta"].get("delta", {})
|
||||||
|
text = delta.get("text")
|
||||||
|
if text:
|
||||||
|
yield text
|
||||||
|
# Handle other event types that might indicate stream end
|
||||||
|
elif "messageStop" in event:
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Log the specific error for debugging
|
||||||
|
logging.error(f"Bedrock streaming error: {e}")
|
||||||
|
|
||||||
|
# Try to clean up resources if possible
|
||||||
|
if (
|
||||||
|
iteration_started
|
||||||
|
and event_stream
|
||||||
|
and hasattr(event_stream, "aclose")
|
||||||
|
and callable(getattr(event_stream, "aclose", None))
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
await event_stream.aclose()
|
||||||
|
except Exception as close_error:
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to close Bedrock event stream: {close_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
raise BedrockError(f"Streaming error: {e}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up the event stream
|
||||||
|
if (
|
||||||
|
iteration_started
|
||||||
|
and event_stream
|
||||||
|
and hasattr(event_stream, "aclose")
|
||||||
|
and callable(getattr(event_stream, "aclose", None))
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
await event_stream.aclose()
|
||||||
|
except Exception as close_error:
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to close Bedrock event stream in finally block: {close_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up the client
|
||||||
|
if client:
|
||||||
|
try:
|
||||||
|
await client.__aexit__(None, None, None)
|
||||||
|
except Exception as client_close_error:
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to close Bedrock client: {client_close_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the generator that manages its own lifecycle
|
||||||
|
return stream_generator()
|
||||||
|
|
||||||
|
# For non-streaming responses, use the standard async context manager pattern
|
||||||
|
session = aioboto3.Session()
|
||||||
|
async with session.client(
|
||||||
|
"bedrock-runtime", region_name=region
|
||||||
|
) as bedrock_async_client:
|
||||||
|
try:
|
||||||
|
# Use converse for non-streaming responses
|
||||||
|
response = await bedrock_async_client.converse(**args, **kwargs)
|
||||||
|
|
||||||
|
# Validate response structure
|
||||||
|
if (
|
||||||
|
not response
|
||||||
|
or "output" not in response
|
||||||
|
or "message" not in response["output"]
|
||||||
|
or "content" not in response["output"]["message"]
|
||||||
|
or not response["output"]["message"]["content"]
|
||||||
|
):
|
||||||
|
raise BedrockError("Invalid response structure from Bedrock API")
|
||||||
|
|
||||||
|
content = response["output"]["message"]["content"][0]["text"]
|
||||||
|
|
||||||
|
if not content or content.strip() == "":
|
||||||
|
raise BedrockError("Received empty content from Bedrock API")
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, BedrockError):
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
raise BedrockError(f"Bedrock API error: {e}")
|
||||||
|
|
||||||
|
|
||||||
# Generic Bedrock completion function
|
# Generic Bedrock completion function
|
||||||
async def bedrock_complete(
|
async def bedrock_complete(
|
||||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||||
) -> str:
|
) -> Union[str, AsyncIterator[str]]:
|
||||||
kwargs.pop("keyword_extraction", None)
|
kwargs.pop("keyword_extraction", None)
|
||||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||||
result = await bedrock_complete_if_cache(
|
result = await bedrock_complete_if_cache(
|
||||||
|
|
@ -117,18 +259,21 @@ async def bedrock_embed(
|
||||||
aws_secret_access_key=None,
|
aws_secret_access_key=None,
|
||||||
aws_session_token=None,
|
aws_session_token=None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
# Respect existing env; only set if a non-empty value is available
|
||||||
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
|
||||||
)
|
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
|
||||||
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
|
||||||
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
_set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
|
||||||
)
|
_set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
|
||||||
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
_set_env_if_present("AWS_SESSION_TOKEN", session_token)
|
||||||
"AWS_SESSION_TOKEN", aws_session_token
|
|
||||||
)
|
# Region handling: prefer env
|
||||||
|
region = os.environ.get("AWS_REGION")
|
||||||
|
|
||||||
session = aioboto3.Session()
|
session = aioboto3.Session()
|
||||||
async with session.client("bedrock-runtime") as bedrock_async_client:
|
async with session.client(
|
||||||
|
"bedrock-runtime", region_name=region
|
||||||
|
) as bedrock_async_client:
|
||||||
if (model_provider := model.split(".")[0]) == "amazon":
|
if (model_provider := model.split(".")[0]) == "amazon":
|
||||||
embed_texts = []
|
embed_texts = []
|
||||||
for text in texts:
|
for text in texts:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue