feat(database): add connect_args support to SqlAlchemyAdapter (#1861)

- Add optional connect_args parameter to __init__ method
- Support DATABASE_CONNECT_ARGS environment variable for JSON-based
configuration
- Enable custom connection parameters for all database engines (SQLite
and PostgreSQL)
- Maintain backward compatibility with existing code
- Add proper error handling and validation for environment variable
parsing

<!-- .github/pull_request_template.md -->
## Description
The intent of this PR is to make the database initialization more
flexible and configurable. To do this, the system supports a new
DATABASE_CONNECT_ARGS environment variable that takes a JSON-based
configuration. This enhancement allows custom connection parameters
to be passed to any supported database engine, including SQLite and
PostgreSQL. To guarantee that the environment variable is parsed
securely and consistently, error handling and validation are added:
a value that is not valid JSON, or that is not a JSON object, falls
back to an empty set of connection arguments.

## Type of Change
<!-- Please check the relevant option -->
- [x] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [x] Breaking change (fix or feature that would cause existing
functionality to change)
- [x] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Advanced database connection configuration through the optional
DATABASE_CONNECT_ARGS environment variable, supporting custom settings
such as SSL certificates and timeout configurations.
* Custom connection arguments can now be passed to relational database
adapters.

* **Tests**
* Comprehensive unit test suite for database connection argument parsing
and validation.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Vasilije 2025-12-16 14:50:27 +01:00 committed by GitHub
commit 412b6467da
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 121 additions and 4 deletions

View file

@ -91,6 +91,15 @@ DB_NAME=cognee_db
#DB_USERNAME=cognee
#DB_PASSWORD=cognee
# -- Advanced: Custom database connection arguments (optional) ---------------
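# Pass additional connection parameters as a JSON object. Useful for SSL, timeouts, etc.
# Invalid JSON or a non-object value (e.g. a list) is ignored and no extra arguments are passed.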
# Examples:
# For PostgreSQL with SSL:
# DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'
# For SQLite with custom timeout:
# DATABASE_CONNECT_ARGS='{"timeout": 60}'
#DATABASE_CONNECT_ARGS='{}'
################################################################################
# 🕸️ Graph Database settings
################################################################################

View file

@ -1,4 +1,5 @@
import os
import json
import pydantic
from typing import Union
from functools import lru_cache
@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings):
db_username: Union[str, None] = None # "cognee"
db_password: Union[str, None] = None # "cognee"
db_provider: str = "sqlite"
database_connect_args: Union[str, None] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@ -30,6 +32,17 @@ class RelationalConfig(BaseSettings):
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
self.db_path = databases_directory_path
# Parse database_connect_args if provided as JSON string
if self.database_connect_args and isinstance(self.database_connect_args, str):
try:
parsed_args = json.loads(self.database_connect_args)
if isinstance(parsed_args, dict):
self.database_connect_args = parsed_args
else:
self.database_connect_args = {}
except json.JSONDecodeError:
self.database_connect_args = {}
return self
def to_dict(self) -> dict:
@ -40,7 +53,8 @@ class RelationalConfig(BaseSettings):
--------
- dict: A dictionary containing database configuration settings including db_path,
db_name, db_host, db_port, db_username, db_password, and db_provider.
db_name, db_host, db_port, db_username, db_password, db_provider, and
database_connect_args.
"""
return {
"db_path": self.db_path,
@ -50,6 +64,7 @@ class RelationalConfig(BaseSettings):
"db_username": self.db_username,
"db_password": self.db_password,
"db_provider": self.db_provider,
"database_connect_args": self.database_connect_args,
}

View file

@ -11,6 +11,7 @@ def create_relational_engine(
db_username: str,
db_password: str,
db_provider: str,
database_connect_args: dict = None,
):
"""
Create a relational database engine based on the specified parameters.
@ -29,6 +30,7 @@ def create_relational_engine(
- db_password (str): The password for database authentication, required for
PostgreSQL.
- db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres').
- database_connect_args (dict, optional): Database driver connection arguments.
Returns:
--------
@ -51,4 +53,4 @@ def create_relational_engine(
"PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality."
)
return SQLAlchemyAdapter(connection_string)
return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args)

View file

@ -29,10 +29,31 @@ class SQLAlchemyAdapter:
functions.
"""
def __init__(self, connection_string: str):
def __init__(self, connection_string: str, connect_args: dict = None):
"""
Initialize the SQLAlchemy adapter with connection settings.
Parameters:
-----------
connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
or 'postgresql://user:pass@host:port/db').
connect_args (dict, optional): Database driver connection arguments.
Configuration is loaded from RelationalConfig.database_connect_args, which reads
from the DATABASE_CONNECT_ARGS environment variable.
Examples:
PostgreSQL with SSL:
DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'
SQLite with custom timeout:
DATABASE_CONNECT_ARGS='{"timeout": 60}'
"""
self.db_path: str = None
self.db_uri: str = connection_string
# Use provided connect_args (already parsed from config)
final_connect_args = connect_args or {}
if "sqlite" in connection_string:
[prefix, db_path] = connection_string.split("///")
self.db_path = db_path
@ -53,7 +74,7 @@ class SQLAlchemyAdapter:
self.engine = create_async_engine(
connection_string,
poolclass=NullPool,
connect_args={"timeout": 30},
connect_args={**{"timeout": 30}, **final_connect_args},
)
else:
self.engine = create_async_engine(
@ -63,6 +84,7 @@ class SQLAlchemyAdapter:
pool_recycle=280,
pool_pre_ping=True,
pool_timeout=280,
connect_args=final_connect_args,
)
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)

View file

@ -0,0 +1,69 @@
import os
from unittest.mock import patch
from cognee.infrastructure.databases.relational.config import RelationalConfig
class TestRelationalConfig:
    """Verify how RelationalConfig interprets the DATABASE_CONNECT_ARGS variable."""

    @staticmethod
    def _connect_args_from_env(raw_value):
        """Return database_connect_args of a config built with the variable set to raw_value."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": raw_value}):
            return RelationalConfig().database_connect_args

    def test_database_connect_args_valid_json_dict(self):
        """A JSON object is decoded into the equivalent dict."""
        parsed = self._connect_args_from_env('{"timeout": 60, "sslmode": "require"}')
        assert parsed == {"timeout": 60, "sslmode": "require"}

    def test_database_connect_args_empty_string(self):
        """An empty DATABASE_CONNECT_ARGS value is kept as the empty string."""
        parsed = self._connect_args_from_env("")
        assert parsed == ""

    def test_database_connect_args_not_set(self):
        """With the environment cleared, database_connect_args is None."""
        with patch.dict(os.environ, {}, clear=True):
            settings = RelationalConfig()
            assert settings.database_connect_args is None

    def test_database_connect_args_invalid_json(self):
        """Malformed JSON (here: a missing closing brace) yields an empty dict."""
        parsed = self._connect_args_from_env('{"timeout": 60')
        assert parsed == {}

    def test_database_connect_args_non_dict_json(self):
        """Valid JSON that is not an object (here: a list) yields an empty dict."""
        parsed = self._connect_args_from_env('["list", "instead", "of", "dict"]')
        assert parsed == {}

    def test_database_connect_args_to_dict(self):
        """to_dict() exposes the parsed value under the database_connect_args key."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
            exported = RelationalConfig().to_dict()
            assert "database_connect_args" in exported
            assert exported["database_connect_args"] == {"timeout": 60}

    def test_database_connect_args_integer_value(self):
        """Integer values in the JSON object stay integers after parsing."""
        parsed = self._connect_args_from_env('{"connect_timeout": 10}')
        assert parsed == {"connect_timeout": 10}

    def test_database_connect_args_mixed_types(self):
        """Integers, strings and booleans in one JSON object are all decoded to Python types."""
        raw_value = '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}'
        expected = {
            "timeout": 60,
            "sslmode": "require",
            "retries": 3,
            "keepalive": True,
        }
        assert self._connect_args_from_env(raw_value) == expected