feat: add bedrock as supported llm provider (#1830)
<!-- .github/pull_request_template.md -->
## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->
Added support for AWS Bedrock and the models available through it.
This started as a contributor PR that was never finished; I have now
polished it up and made it work.
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Added AWS Bedrock as a new LLM provider with support for multiple
authentication methods.
* Integrated three new AI models: Claude 4.5 Sonnet, Claude 4.5 Haiku,
and Amazon Nova Lite.
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
commit
69e36cc834
7 changed files with 289 additions and 2 deletions
90
.github/workflows/test_llms.yml
vendored
90
.github/workflows/test_llms.yml
vendored
|
|
@ -84,3 +84,93 @@ jobs:
|
||||||
EMBEDDING_DIMENSIONS: "3072"
|
EMBEDDING_DIMENSIONS: "3072"
|
||||||
EMBEDDING_MAX_TOKENS: "8191"
|
EMBEDDING_MAX_TOKENS: "8191"
|
||||||
run: uv run python ./examples/python/simple_example.py
|
run: uv run python ./examples/python/simple_example.py
|
||||||
|
|
||||||
|
test-bedrock-api-key:
|
||||||
|
name: Run Bedrock API Key Test
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
extra-dependencies: "aws"
|
||||||
|
|
||||||
|
- name: Run Bedrock API Key Simple Example
|
||||||
|
env:
|
||||||
|
LLM_PROVIDER: "bedrock"
|
||||||
|
LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
|
||||||
|
LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||||
|
LLM_MAX_TOKENS: "16384"
|
||||||
|
AWS_REGION_NAME: "eu-west-1"
|
||||||
|
EMBEDDING_PROVIDER: "bedrock"
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
|
||||||
|
EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
|
||||||
|
EMBEDDING_DIMENSIONS: "1024"
|
||||||
|
EMBEDDING_MAX_TOKENS: "8191"
|
||||||
|
run: uv run python ./examples/python/simple_example.py
|
||||||
|
|
||||||
|
test-bedrock-aws-credentials:
|
||||||
|
name: Run Bedrock AWS Credentials Test
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
extra-dependencies: "aws"
|
||||||
|
|
||||||
|
- name: Run Bedrock AWS Credentials Simple Example
|
||||||
|
env:
|
||||||
|
LLM_PROVIDER: "bedrock"
|
||||||
|
LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||||
|
LLM_MAX_TOKENS: "16384"
|
||||||
|
AWS_REGION_NAME: "eu-west-1"
|
||||||
|
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||||
|
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||||
|
EMBEDDING_PROVIDER: "bedrock"
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
|
||||||
|
EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
|
||||||
|
EMBEDDING_DIMENSIONS: "1024"
|
||||||
|
EMBEDDING_MAX_TOKENS: "8191"
|
||||||
|
run: uv run python ./examples/python/simple_example.py
|
||||||
|
|
||||||
|
test-bedrock-aws-profile:
|
||||||
|
name: Run Bedrock AWS Profile Test
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
extra-dependencies: "aws"
|
||||||
|
|
||||||
|
- name: Configure AWS Profile
|
||||||
|
run: |
|
||||||
|
mkdir -p ~/.aws
|
||||||
|
cat > ~/.aws/credentials << EOF
|
||||||
|
[bedrock-test]
|
||||||
|
aws_access_key_id = ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||||
|
aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||||
|
EOF
|
||||||
|
|
||||||
|
- name: Run Bedrock AWS Profile Simple Example
|
||||||
|
env:
|
||||||
|
LLM_PROVIDER: "bedrock"
|
||||||
|
LLM_MODEL: "eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||||
|
LLM_MAX_TOKENS: "16384"
|
||||||
|
AWS_PROFILE_NAME: "bedrock-test"
|
||||||
|
AWS_REGION_NAME: "eu-west-1"
|
||||||
|
EMBEDDING_PROVIDER: "bedrock"
|
||||||
|
EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
|
||||||
|
EMBEDDING_DIMENSIONS: "1024"
|
||||||
|
EMBEDDING_MAX_TOKENS: "8191"
|
||||||
|
run: uv run python ./examples/python/simple_example.py
|
||||||
|
|
@ -155,7 +155,7 @@ async def add(
|
||||||
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
|
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
|
||||||
|
|
||||||
Optional:
|
Optional:
|
||||||
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
|
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral", "bedrock"
|
||||||
- LLM_MODEL: Model name (default: "gpt-5-mini")
|
- LLM_MODEL: Model name (default: "gpt-5-mini")
|
||||||
- DEFAULT_USER_EMAIL: Custom default user email
|
- DEFAULT_USER_EMAIL: Custom default user email
|
||||||
- DEFAULT_USER_PASSWORD: Custom default user password
|
- DEFAULT_USER_PASSWORD: Custom default user password
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,8 @@ class S3Config(BaseSettings):
|
||||||
aws_access_key_id: Optional[str] = None
|
aws_access_key_id: Optional[str] = None
|
||||||
aws_secret_access_key: Optional[str] = None
|
aws_secret_access_key: Optional[str] = None
|
||||||
aws_session_token: Optional[str] = None
|
aws_session_token: Optional[str] = None
|
||||||
|
aws_profile_name: Optional[str] = None
|
||||||
|
aws_bedrock_runtime_endpoint: Optional[str] = None
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
"""Bedrock LLM adapter module."""
|
||||||
|
|
||||||
|
from .adapter import BedrockAdapter
|
||||||
|
|
||||||
|
__all__ = ["BedrockAdapter"]
|
||||||
|
|
@ -0,0 +1,153 @@
|
||||||
|
import litellm
|
||||||
|
import instructor
|
||||||
|
from typing import Type
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from litellm.exceptions import ContentPolicyViolationError
|
||||||
|
from instructor.exceptions import InstructorRetryException
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
|
LLMInterface,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.llm.exceptions import (
|
||||||
|
ContentPolicyFilterError,
|
||||||
|
MissingSystemPromptPathError,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.files.storage.s3_config import get_s3_config
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
||||||
|
rate_limit_async,
|
||||||
|
rate_limit_sync,
|
||||||
|
sleep_and_retry_async,
|
||||||
|
sleep_and_retry_sync,
|
||||||
|
)
|
||||||
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
|
|
||||||
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockAdapter(LLMInterface):
|
||||||
|
"""
|
||||||
|
Adapter for AWS Bedrock API with support for three authentication methods:
|
||||||
|
1. API Key (Bearer Token)
|
||||||
|
2. AWS Credentials (access key + secret key)
|
||||||
|
3. AWS Profile (boto3 credential chain)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "Bedrock"
|
||||||
|
model: str
|
||||||
|
api_key: str
|
||||||
|
default_instructor_mode = "json_schema_mode"
|
||||||
|
|
||||||
|
MAX_RETRIES = 5
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
api_key: str = None,
|
||||||
|
max_completion_tokens: int = 16384,
|
||||||
|
streaming: bool = False,
|
||||||
|
instructor_mode: str = None,
|
||||||
|
):
|
||||||
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
|
|
||||||
|
self.aclient = instructor.from_litellm(
|
||||||
|
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
||||||
|
)
|
||||||
|
self.client = instructor.from_litellm(litellm.completion)
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key
|
||||||
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
self.streaming = streaming
|
||||||
|
|
||||||
|
def _create_bedrock_request(
|
||||||
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
|
) -> dict:
|
||||||
|
"""Create Bedrock request with authentication."""
|
||||||
|
|
||||||
|
request_params = {
|
||||||
|
"model": self.model,
|
||||||
|
"custom_llm_provider": "bedrock",
|
||||||
|
"drop_params": True,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": text_input},
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
],
|
||||||
|
"response_model": response_model,
|
||||||
|
"max_retries": self.MAX_RETRIES,
|
||||||
|
"max_completion_tokens": self.max_completion_tokens,
|
||||||
|
"stream": self.streaming,
|
||||||
|
}
|
||||||
|
|
||||||
|
s3_config = get_s3_config()
|
||||||
|
|
||||||
|
# Add authentication parameters
|
||||||
|
if self.api_key:
|
||||||
|
request_params["api_key"] = self.api_key
|
||||||
|
elif s3_config.aws_access_key_id and s3_config.aws_secret_access_key:
|
||||||
|
request_params["aws_access_key_id"] = s3_config.aws_access_key_id
|
||||||
|
request_params["aws_secret_access_key"] = s3_config.aws_secret_access_key
|
||||||
|
if s3_config.aws_session_token:
|
||||||
|
request_params["aws_session_token"] = s3_config.aws_session_token
|
||||||
|
elif s3_config.aws_profile_name:
|
||||||
|
request_params["aws_profile_name"] = s3_config.aws_profile_name
|
||||||
|
|
||||||
|
if s3_config.aws_region:
|
||||||
|
request_params["aws_region_name"] = s3_config.aws_region
|
||||||
|
|
||||||
|
# Add optional parameters
|
||||||
|
if s3_config.aws_bedrock_runtime_endpoint:
|
||||||
|
request_params["aws_bedrock_runtime_endpoint"] = s3_config.aws_bedrock_runtime_endpoint
|
||||||
|
|
||||||
|
return request_params
|
||||||
|
|
||||||
|
@observe(as_type="generation")
|
||||||
|
@sleep_and_retry_async()
|
||||||
|
@rate_limit_async
|
||||||
|
async def acreate_structured_output(
|
||||||
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
|
) -> BaseModel:
|
||||||
|
"""Generate structured output from AWS Bedrock API."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
|
||||||
|
return await self.aclient.chat.completions.create(**request_params)
|
||||||
|
|
||||||
|
except (
|
||||||
|
ContentPolicyViolationError,
|
||||||
|
InstructorRetryException,
|
||||||
|
) as error:
|
||||||
|
if (
|
||||||
|
isinstance(error, InstructorRetryException)
|
||||||
|
and "content management policy" not in str(error).lower()
|
||||||
|
):
|
||||||
|
raise error
|
||||||
|
|
||||||
|
raise ContentPolicyFilterError(
|
||||||
|
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@observe
|
||||||
|
@sleep_and_retry_sync()
|
||||||
|
@rate_limit_sync
|
||||||
|
def create_structured_output(
|
||||||
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
|
) -> BaseModel:
|
||||||
|
"""Generate structured output from AWS Bedrock API (synchronous)."""
|
||||||
|
|
||||||
|
request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
|
||||||
|
return self.client.chat.completions.create(**request_params)
|
||||||
|
|
||||||
|
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||||
|
"""Format and display the prompt for a user query."""
|
||||||
|
if not text_input:
|
||||||
|
text_input = "No user input provided."
|
||||||
|
if not system_prompt:
|
||||||
|
raise MissingSystemPromptPathError()
|
||||||
|
system_prompt = LLMGateway.read_query_prompt(system_prompt)
|
||||||
|
|
||||||
|
formatted_prompt = (
|
||||||
|
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||||
|
if system_prompt
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return formatted_prompt
|
||||||
|
|
@ -24,6 +24,7 @@ class LLMProvider(Enum):
|
||||||
- CUSTOM: Represents a custom provider option.
|
- CUSTOM: Represents a custom provider option.
|
||||||
- GEMINI: Represents the Gemini provider.
|
- GEMINI: Represents the Gemini provider.
|
||||||
- MISTRAL: Represents the Mistral AI provider.
|
- MISTRAL: Represents the Mistral AI provider.
|
||||||
|
- BEDROCK: Represents the AWS Bedrock provider.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
|
|
@ -32,6 +33,7 @@ class LLMProvider(Enum):
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
GEMINI = "gemini"
|
GEMINI = "gemini"
|
||||||
MISTRAL = "mistral"
|
MISTRAL = "mistral"
|
||||||
|
BEDROCK = "bedrock"
|
||||||
|
|
||||||
|
|
||||||
def get_llm_client(raise_api_key_error: bool = True):
|
def get_llm_client(raise_api_key_error: bool = True):
|
||||||
|
|
@ -154,7 +156,7 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
)
|
)
|
||||||
|
|
||||||
elif provider == LLMProvider.MISTRAL:
|
elif provider == LLMProvider.MISTRAL:
|
||||||
if llm_config.llm_api_key is None:
|
if llm_config.llm_api_key is None and raise_api_key_error:
|
||||||
raise LLMAPIKeyNotSetError()
|
raise LLMAPIKeyNotSetError()
|
||||||
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
|
||||||
|
|
@ -169,5 +171,21 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif provider == LLMProvider.BEDROCK:
|
||||||
|
# if llm_config.llm_api_key is None and raise_api_key_error:
|
||||||
|
# raise LLMAPIKeyNotSetError()
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock.adapter import (
|
||||||
|
BedrockAdapter,
|
||||||
|
)
|
||||||
|
|
||||||
|
return BedrockAdapter(
|
||||||
|
model=llm_config.llm_model,
|
||||||
|
api_key=llm_config.llm_api_key,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
streaming=llm_config.llm_streaming,
|
||||||
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise UnsupportedLLMProviderError(provider)
|
raise UnsupportedLLMProviderError(provider)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ class ModelName(Enum):
|
||||||
anthropic = "anthropic"
|
anthropic = "anthropic"
|
||||||
gemini = "gemini"
|
gemini = "gemini"
|
||||||
mistral = "mistral"
|
mistral = "mistral"
|
||||||
|
bedrock = "bedrock"
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig(BaseModel):
|
class LLMConfig(BaseModel):
|
||||||
|
|
@ -77,6 +78,10 @@ def get_settings() -> SettingsDict:
|
||||||
"value": "mistral",
|
"value": "mistral",
|
||||||
"label": "Mistral",
|
"label": "Mistral",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"value": "bedrock",
|
||||||
|
"label": "Bedrock",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
return SettingsDict.model_validate(
|
return SettingsDict.model_validate(
|
||||||
|
|
@ -157,6 +162,20 @@ def get_settings() -> SettingsDict:
|
||||||
"label": "Mistral Large 2.1",
|
"label": "Mistral Large 2.1",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"bedrock": [
|
||||||
|
{
|
||||||
|
"value": "eu.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||||
|
"label": "Claude 4.5 Sonnet",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"value": "eu.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||||
|
"label": "Claude 4.5 Haiku",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"value": "eu.amazon.nova-lite-v1:0",
|
||||||
|
"label": "Amazon Nova Lite",
|
||||||
|
},
|
||||||
|
],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
vector_db={
|
vector_db={
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue