Adding AWS Bedrock support as an LLM provider

Signed-off-by: xdurawa <xavierdurawa@gmail.com>
This commit is contained in:
xdurawa 2025-09-01 20:47:26 -04:00
parent bf482ef91f
commit 2a46208569
15 changed files with 394 additions and 8 deletions

@@ -0,0 +1,28 @@
name: test | bedrock | api key
on:
workflow_call:
jobs:
test-bedrock-api-key:
name: Run Bedrock API Key Test
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run Bedrock API Key Test
env:
LLM_PROVIDER: "bedrock"
LLM_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
LLM_MODEL: "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
AWS_REGION_NAME: "us-east-1"
EMBEDDING_PROVIDER: "bedrock"
EMBEDDING_MODEL: "amazon.titan-embed-text-v1"
EMBEDDING_DIMENSIONS: "1536"
run: poetry run python ./examples/python/simple_example.py

@@ -0,0 +1,29 @@
name: test | bedrock | aws credentials
on:
workflow_call:
jobs:
test-bedrock-aws-credentials:
name: Run Bedrock AWS Credentials Test
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run Bedrock AWS Credentials Test
env:
LLM_PROVIDER: "bedrock"
LLM_MODEL: "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: "us-east-1"
EMBEDDING_PROVIDER: "cohere"
EMBEDDING_MODEL: "cohere.embed-english-v3"
EMBEDDING_DIMENSIONS: "1024"
run: poetry run python ./examples/python/simple_example.py

@@ -0,0 +1,37 @@
name: test | bedrock | aws profile
on:
workflow_call:
jobs:
test-bedrock-aws-profile:
name: Run Bedrock AWS Profile Test
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Configure AWS Profile
run: |
mkdir -p ~/.aws
cat > ~/.aws/credentials << EOF
[bedrock-test]
aws_access_key_id = ${{ secrets.AWS_ACCESS_KEY_ID }}
aws_secret_access_key = ${{ secrets.AWS_SECRET_ACCESS_KEY }}
EOF
- name: Run Bedrock AWS Profile Test
env:
LLM_PROVIDER: "bedrock"
LLM_MODEL: "us.anthropic.claude-3-5-haiku-20241022-v1:0"
AWS_PROFILE_NAME: "bedrock-test"
AWS_REGION_NAME: "us-east-1"
EMBEDDING_PROVIDER: "bedrock"
EMBEDDING_MODEL: "amazon.titan-embed-text-v2:0"
EMBEDDING_DIMENSIONS: "1024"
run: poetry run python ./examples/python/simple_example.py

@@ -110,6 +110,24 @@ jobs:
uses: ./.github/workflows/test_gemini.yml
secrets: inherit
bedrock-tests:
name: Bedrock Tests
needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/test_bedrock_api_key.yml
secrets: inherit
bedrock-aws-credentials-tests:
name: Bedrock AWS Credentials Tests
needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/test_bedrock_aws_credentials.yml
secrets: inherit
bedrock-aws-profile-tests:
name: Bedrock AWS Profile Tests
needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/test_bedrock_aws_profile.yml
secrets: inherit
# Ollama tests moved to the end
ollama-tests:
name: Ollama Tests
@@ -143,6 +161,9 @@ jobs:
db-examples-tests,
mcp-test,
gemini-tests,
bedrock-tests,
bedrock-aws-credentials-tests,
bedrock-aws-profile-tests,
ollama-tests,
relational-db-migration-tests,
docker-compose-test,
@@ -163,6 +184,9 @@ jobs:
"${{ needs.db-examples-tests.result }}" == "success" &&
"${{ needs.relational-db-migration-tests.result }}" == "success" &&
"${{ needs.gemini-tests.result }}" == "success" &&
"${{ needs.bedrock-tests.result }}" == "success" &&
"${{ needs.bedrock-aws-credentials-tests.result }}" == "success" &&
"${{ needs.bedrock-aws-profile-tests.result }}" == "success" &&
"${{ needs.docker-compose-test.result }}" == "success" &&
"${{ needs.docker-ci-test.result }}" == "success" &&
"${{ needs.ollama-tests.result }}" == "success" ]]; then

@@ -122,6 +122,16 @@ os.environ["LLM_API_KEY"] = "YOUR OPENAI_API_KEY"
```
You can also set the variables by creating a .env file, using our <a href="https://github.com/topoteretes/cognee/blob/main/.env.template">template</a>.
**Supported LLM Providers:** OpenAI (default), Anthropic, Gemini, Ollama, AWS Bedrock
**For AWS Bedrock:** Set `LLM_PROVIDER="bedrock"` and use one of three authentication methods:
- API Key: `LLM_API_KEY="your_bedrock_api_key"`
- AWS Credentials: `AWS_ACCESS_KEY_ID` + `AWS_SECRET_ACCESS_KEY` (+ `AWS_SESSION_TOKEN` if needed)
- AWS Profile: `AWS_PROFILE_NAME="your_profile"`
Use an [inference profile](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html#API_runtime_InvokeModel_Example_5:~:text=Use%20an%20inference%20profile%20in%20model%20invocation) for the model IDs. This usually means prepending a region prefix such as `us.` (or another region's prefix) to the model ID (e.g., `us.anthropic.claude-3-5-sonnet-20241022-v2:0`). See [AWS Bedrock models](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html).
For more info on using different LLM providers, check out our <a href="https://docs.cognee.ai">documentation</a>.
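For a quick sanity check, the three authentication methods map to environment variables as sketched below (the helper name and placeholder values are illustrative, not part of cognee):

```python
def bedrock_env(auth: str) -> dict:
    """Return the environment variables for one of the three Bedrock
    authentication methods described above (illustrative sketch)."""
    common = {
        "LLM_PROVIDER": "bedrock",
        "LLM_MODEL": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
        "AWS_REGION_NAME": "us-east-1",
    }
    if auth == "api_key":
        return {**common, "LLM_API_KEY": "your_bedrock_api_key"}
    if auth == "credentials":
        return {
            **common,
            "AWS_ACCESS_KEY_ID": "your_aws_access_key",
            "AWS_SECRET_ACCESS_KEY": "your_aws_secret_key",
        }
    if auth == "profile":
        return {**common, "AWS_PROFILE_NAME": "your_profile"}
    raise ValueError(f"unknown auth method: {auth}")
```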

@@ -25,6 +25,14 @@ uv sync
## Setup LLM
Add environment variables to `.env` file.
If you use the OpenAI provider, set just the model and API key.
**Supported LLM Providers:**
- OpenAI (default)
- Anthropic
- Gemini
- Ollama
- AWS Bedrock
```
LLM_PROVIDER=""
LLM_MODEL=""
@@ -39,6 +47,36 @@ EMBEDDING_API_KEY=""
EMBEDDING_API_VERSION=""
```
**For AWS Bedrock, you have three authentication options:**
1. **API Key (Bearer Token):**
```
LLM_PROVIDER="bedrock"
LLM_API_KEY="your_bedrock_api_key"
LLM_MODEL="us.anthropic.claude-3-5-sonnet-20241022-v2:0"
AWS_REGION_NAME="us-east-1"
```
2. **AWS Credentials:**
```
LLM_PROVIDER="bedrock"
LLM_MODEL="us.anthropic.claude-3-5-sonnet-20241022-v2:0"
AWS_ACCESS_KEY_ID="your_aws_access_key"
AWS_SECRET_ACCESS_KEY="your_aws_secret_key"
# Optional, for temporary credentials:
AWS_SESSION_TOKEN="your_session_token"
AWS_REGION_NAME="us-east-1"
```
3. **AWS Profile:**
```
LLM_PROVIDER="bedrock"
LLM_MODEL="us.anthropic.claude-3-5-sonnet-20241022-v2:0"
AWS_PROFILE_NAME="your_aws_profile"
AWS_REGION_NAME="us-east-1"
```
**Note:** For Bedrock models, use an [inference profile](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html#API_runtime_InvokeModel_Example_5:~:text=Use%20an%20inference%20profile%20in%20model%20invocation) for `LLM_MODEL`. This usually means prepending a region prefix such as `us.` (or another region's prefix) to the model ID (e.g., `us.anthropic.claude-3-5-sonnet-20241022-v2:0`). See [AWS Bedrock models](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for available models.
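The prefix rule amounts to a one-line transformation (illustrative helper, not part of cognee):

```python
def to_inference_profile(model_id: str, region_prefix: str = "us") -> str:
    """Prepend a region prefix to a base Bedrock model ID, yielding a
    cross-region inference profile ID. Idempotent if already prefixed."""
    if model_id.startswith(f"{region_prefix}."):
        return model_id
    return f"{region_prefix}.{model_id}"
```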
Activate the Python environment:
```
source .venv/bin/activate

@@ -127,7 +127,7 @@ async def add(
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
Optional:
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama"
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "bedrock"
- LLM_MODEL: Model name (default: "gpt-4o-mini")
- DEFAULT_USER_EMAIL: Custom default user email
- DEFAULT_USER_PASSWORD: Custom default user password

@@ -154,7 +154,7 @@ async def search(
- LLM_API_KEY: API key for your LLM provider
Optional:
- LLM_PROVIDER, LLM_MODEL: Configure LLM for search responses
- LLM_PROVIDER, LLM_MODEL: Configure LLM for search responses (supports: openai, anthropic, gemini, ollama, bedrock)
- VECTOR_DB_PROVIDER: Must match what was used during cognify
- GRAPH_DATABASE_PROVIDER: Must match what was used during cognify

@@ -27,6 +27,12 @@ class LLMConfig(BaseSettings):
- embedding_rate_limit_enabled
- embedding_rate_limit_requests
- embedding_rate_limit_interval
- aws_access_key_id (Bedrock)
- aws_secret_access_key (Bedrock)
- aws_session_token (Bedrock)
- aws_region_name (Bedrock)
- aws_profile_name (Bedrock)
- aws_bedrock_runtime_endpoint (Bedrock)
Public methods include:
- ensure_env_vars_for_ollama
@@ -63,6 +69,14 @@
fallback_endpoint: str = ""
fallback_model: str = ""
# AWS Bedrock configuration
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
aws_session_token: Optional[str] = None
aws_region_name: str = "us-east-1"
aws_profile_name: Optional[str] = None
aws_bedrock_runtime_endpoint: Optional[str] = None
baml_registry: ClassVar[ClientRegistry] = ClientRegistry()
model_config = SettingsConfigDict(env_file=".env", extra="allow")
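The mapping from environment variables to these AWS fields can be sketched without pydantic (a plain-dict stand-in for the BaseSettings behavior above, not cognee's actual implementation):

```python
def load_bedrock_settings(env: dict) -> dict:
    """Map AWS_* environment variables to the Bedrock config fields above,
    applying the same us-east-1 region default (illustrative sketch)."""
    return {
        "aws_access_key_id": env.get("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": env.get("AWS_SECRET_ACCESS_KEY"),
        "aws_session_token": env.get("AWS_SESSION_TOKEN"),
        "aws_region_name": env.get("AWS_REGION_NAME", "us-east-1"),
        "aws_profile_name": env.get("AWS_PROFILE_NAME"),
    }
```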

@@ -0,0 +1,6 @@
"""Bedrock LLM adapter module."""
from .adapter import BedrockAdapter
__all__ = ["BedrockAdapter"]

@@ -0,0 +1,161 @@
import litellm
import instructor
from typing import Type, Optional
from pydantic import BaseModel
from litellm.exceptions import ContentPolicyViolationError
from instructor.exceptions import InstructorRetryException
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
rate_limit_sync,
sleep_and_retry_async,
sleep_and_retry_sync,
)
from cognee.modules.observability.get_observe import get_observe
observe = get_observe()
class BedrockAdapter(LLMInterface):
"""
Adapter for AWS Bedrock API with support for three authentication methods:
1. API Key (Bearer Token)
2. AWS Credentials (access key + secret key)
3. AWS Profile (boto3 credential chain)
"""
name = "Bedrock"
    model: str
    api_key: Optional[str]
    aws_access_key_id: Optional[str]
    aws_secret_access_key: Optional[str]
    aws_region_name: str
    aws_profile_name: Optional[str]
MAX_RETRIES = 5
def __init__(
self,
model: str,
        api_key: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: str = "us-east-1",
        aws_profile_name: Optional[str] = None,
        aws_bedrock_runtime_endpoint: Optional[str] = None,
max_tokens: int = 16384,
streaming: bool = False,
):
self.aclient = instructor.from_litellm(litellm.acompletion)
self.client = instructor.from_litellm(litellm.completion)
self.model = model
self.api_key = api_key
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.aws_bedrock_runtime_endpoint = aws_bedrock_runtime_endpoint
self.max_tokens = max_tokens
self.streaming = streaming
def _create_bedrock_request(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> dict:
"""Create Bedrock request with authentication and enhanced JSON formatting."""
enhanced_system_prompt = f"""{system_prompt}
IMPORTANT: You must respond with valid JSON only. Do not include any text before or after the JSON. The response must be a valid JSON object that can be parsed directly."""
request_params = {
"model": self.model,
"custom_llm_provider": "bedrock",
"drop_params": True,
"messages": [
{"role": "user", "content": text_input},
{"role": "system", "content": enhanced_system_prompt},
],
"response_model": response_model,
"max_retries": self.MAX_RETRIES,
"max_tokens": self.max_tokens,
"stream": self.streaming,
}
# Add authentication parameters
if self.api_key:
request_params["api_key"] = self.api_key
elif self.aws_access_key_id and self.aws_secret_access_key:
request_params["aws_access_key_id"] = self.aws_access_key_id
request_params["aws_secret_access_key"] = self.aws_secret_access_key
if self.aws_session_token:
request_params["aws_session_token"] = self.aws_session_token
elif self.aws_profile_name:
request_params["aws_profile_name"] = self.aws_profile_name
# Add optional parameters
if self.aws_region_name:
request_params["aws_region_name"] = self.aws_region_name
if self.aws_bedrock_runtime_endpoint:
request_params["aws_bedrock_runtime_endpoint"] = self.aws_bedrock_runtime_endpoint
return request_params
@observe(as_type="generation")
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""Generate structured output from AWS Bedrock API."""
try:
request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
return await self.aclient.chat.completions.create(**request_params)
except (
ContentPolicyViolationError,
InstructorRetryException,
) as error:
if (
isinstance(error, InstructorRetryException)
and "content management policy" not in str(error).lower()
):
raise error
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
)
@observe
@sleep_and_retry_sync()
@rate_limit_sync
def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""Generate structured output from AWS Bedrock API (synchronous)."""
request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
return self.client.chat.completions.create(**request_params)
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""Format and display the prompt for a user query."""
if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise InvalidValueError(message="No system prompt path provided.")
system_prompt = LLMGateway.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
if system_prompt
else None
)
return formatted_prompt
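The request builder above selects exactly one authentication method, in order of precedence: API key, then explicit credentials (with an optional session token), then a named profile. A standalone sketch of that precedence, outside cognee, for illustration:

```python
from typing import Optional

def select_auth_params(
    api_key: Optional[str] = None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_session_token: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
) -> dict:
    """Mirror the adapter's authentication precedence: API key first,
    then explicit credentials, then a named profile."""
    params: dict = {}
    if api_key:
        params["api_key"] = api_key
    elif aws_access_key_id and aws_secret_access_key:
        params["aws_access_key_id"] = aws_access_key_id
        params["aws_secret_access_key"] = aws_secret_access_key
        if aws_session_token:
            params["aws_session_token"] = aws_session_token
    elif aws_profile_name:
        params["aws_profile_name"] = aws_profile_name
    return params
```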

@@ -20,6 +20,7 @@ class LLMProvider(Enum):
- ANTHROPIC: Represents the Anthropic provider.
- CUSTOM: Represents a custom provider option.
- GEMINI: Represents the Gemini provider.
- BEDROCK: Represents the AWS Bedrock provider.
"""
OPENAI = "openai"
@@ -27,6 +28,7 @@ class LLMProvider(Enum):
ANTHROPIC = "anthropic"
CUSTOM = "custom"
GEMINI = "gemini"
BEDROCK = "bedrock"
def get_llm_client():
@@ -137,5 +139,23 @@ def get_llm_client():
streaming=llm_config.llm_streaming,
)
elif provider == LLMProvider.BEDROCK:
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock.adapter import (
BedrockAdapter,
)
return BedrockAdapter(
model=llm_config.llm_model,
api_key=llm_config.llm_api_key,
aws_access_key_id=llm_config.aws_access_key_id,
aws_secret_access_key=llm_config.aws_secret_access_key,
aws_session_token=llm_config.aws_session_token,
aws_region_name=llm_config.aws_region_name,
aws_profile_name=llm_config.aws_profile_name,
aws_bedrock_runtime_endpoint=llm_config.aws_bedrock_runtime_endpoint,
max_tokens=max_tokens,
streaming=llm_config.llm_streaming,
)
else:
raise InvalidValueError(message=f"Unsupported LLM provider: {provider}")

@@ -15,10 +15,11 @@ class ModelName(Enum):
ollama = "ollama"
anthropic = "anthropic"
gemini = "gemini"
bedrock = "bedrock"
class LLMConfig(BaseModel):
api_key: str
api_key: Optional[str]
model: str
provider: str
endpoint: Optional[str]
@@ -72,6 +73,10 @@ def get_settings() -> SettingsDict:
"value": "gemini",
"label": "Gemini",
},
{
"value": "bedrock",
"label": "AWS Bedrock",
},
]
return SettingsDict.model_validate(
@@ -134,6 +139,20 @@ def get_settings() -> SettingsDict:
"label": "Gemini 2.0 Flash",
},
],
"bedrock": [
{
"value": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
"label": "Claude 3.5 Sonnet",
},
{
"value": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
"label": "Claude 3.5 Haiku",
},
{
"value": "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
"label": "Claude 3.5 Sonnet (June)",
},
],
},
},
vector_db={

@@ -34,7 +34,7 @@ dependencies = [
"sqlalchemy>=2.0.39,<3.0.0",
"aiosqlite>=0.20.0,<1.0.0",
"tiktoken>=0.8.0,<1.0.0",
"litellm>=1.71.0, <2.0.0",
"litellm>=1.76.0, <2.0.0",
"instructor>=1.9.1,<2.0.0",
"langfuse>=2.32.0,<3",
"filetype>=1.2.0,<2.0.0",

uv.lock generated
@@ -1060,7 +1060,7 @@ requires-dist = [
{ name = "langfuse", specifier = ">=2.32.0,<3" },
{ name = "langsmith", marker = "extra == 'langchain'", specifier = ">=0.2.3,<1.0.0" },
{ name = "limits", specifier = ">=4.4.1,<5" },
{ name = "litellm", specifier = ">=1.71.0,<2.0.0" },
{ name = "litellm", specifier = ">=1.76.0,<2.0.0" },
{ name = "llama-index-core", marker = "extra == 'llama-index'", specifier = ">=0.12.11,<0.13" },
{ name = "matplotlib", specifier = ">=3.8.3,<4" },
{ name = "mistral-common", marker = "extra == 'mistral'", specifier = ">=1.5.2,<2" },
@@ -3552,7 +3552,7 @@ wheels = [
[[package]]
name = "litellm"
version = "1.75.8"
version = "1.76.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
@@ -3567,9 +3567,9 @@ dependencies = [
{ name = "tiktoken" },
{ name = "tokenizers" },
]
sdist = { url = "https://files.pythonhosted.org/packages/8d/4e/48e3d6de19afe713223e3bc7009a2003501420de2a5d823c569cefbd9731/litellm-1.75.8.tar.gz", hash = "sha256:92061bd263ff8c33c8fff70ba92cd046adb7ea041a605826a915d108742fe59e", size = 10140384, upload-time = "2025-08-16T21:42:24.23Z" }
sdist = { url = "https://files.pythonhosted.org/packages/c9/e8/47c791c3d2cb4397ddece90840aecfc6cdc4a003f039dde42d7c861f4709/litellm-1.76.0.tar.gz", hash = "sha256:d26d12333135edd72af60e0e310284dac3b079f4d7c47c79dfbb2430b9b4b421", size = 10170569, upload-time = "2025-08-24T05:14:01.176Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5e/82/c4d00fbeafd93c00dab6ea03f33cadd6a97adeb720ba1d89fc319e5cb10b/litellm-1.75.8-py3-none-any.whl", hash = "sha256:0bf004488df8506381ec6e35e1486e2870e8d578a7c3f2427cd497558ce07a2e", size = 8916305, upload-time = "2025-08-16T21:42:21.387Z" },
{ url = "https://files.pythonhosted.org/packages/86/f2/891b4b6c09021046d7f5bcff57178b18f352a67b032d35cc693d79b38620/litellm-1.76.0-py3-none-any.whl", hash = "sha256:357464242fc1eeda384810c9e334e48ad67a50ecd30cf61e86c15f89e2f2e0b4", size = 8953112, upload-time = "2025-08-24T05:13:58.642Z" },
]
[[package]]