diff --git a/.env.template b/.env.template index 7defaee09..fe168cf91 100644 --- a/.env.template +++ b/.env.template @@ -91,6 +91,15 @@ DB_NAME=cognee_db #DB_USERNAME=cognee #DB_PASSWORD=cognee +# -- Advanced: Custom database connection arguments (optional) --------------- +# Pass additional connection parameters as JSON. Useful for SSL, timeouts, etc. +# Examples: +# For PostgreSQL with SSL: +# DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}' +# For SQLite with custom timeout: +# DATABASE_CONNECT_ARGS='{"timeout": 60}' +#DATABASE_CONNECT_ARGS='{}' + ################################################################################ # 🕸️ Graph Database settings ################################################################################ diff --git a/cognee/infrastructure/databases/relational/config.py b/cognee/infrastructure/databases/relational/config.py index ff7c410a1..fae7ca329 100644 --- a/cognee/infrastructure/databases/relational/config.py +++ b/cognee/infrastructure/databases/relational/config.py @@ -1,4 +1,5 @@ import os +import json import pydantic from typing import Union from functools import lru_cache @@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings): db_username: Union[str, None] = None # "cognee" db_password: Union[str, None] = None # "cognee" db_provider: str = "sqlite" + database_connect_args: Union[str, None] = None model_config = SettingsConfigDict(env_file=".env", extra="allow") @@ -30,6 +32,17 @@ class RelationalConfig(BaseSettings): databases_directory_path = os.path.join(base_config.system_root_directory, "databases") self.db_path = databases_directory_path + # Parse database_connect_args if provided as JSON string + if self.database_connect_args and isinstance(self.database_connect_args, str): + try: + parsed_args = json.loads(self.database_connect_args) + if isinstance(parsed_args, dict): + self.database_connect_args = parsed_args + else: + self.database_connect_args = {} + except json.JSONDecodeError: + self.database_connect_args = {} + return self def to_dict(self) -> dict: @@ -40,7 +53,8 @@ class RelationalConfig(BaseSettings): -------- - dict: A dictionary containing database configuration settings including db_path, - db_name, db_host, db_port, db_username, db_password, and db_provider. + db_name, db_host, db_port, db_username, db_password, db_provider, and + database_connect_args. """ return { "db_path": self.db_path, @@ -50,6 +64,7 @@ class RelationalConfig(BaseSettings): "db_username": self.db_username, "db_password": self.db_password, "db_provider": self.db_provider, + "database_connect_args": self.database_connect_args, } diff --git a/cognee/infrastructure/databases/relational/create_relational_engine.py b/cognee/infrastructure/databases/relational/create_relational_engine.py index deaeaa2da..b2a79c818 100644 --- a/cognee/infrastructure/databases/relational/create_relational_engine.py +++ b/cognee/infrastructure/databases/relational/create_relational_engine.py @@ -11,6 +11,7 @@ def create_relational_engine( db_username: str, db_password: str, db_provider: str, + database_connect_args: dict = None, ): """ Create a relational database engine based on the specified parameters. @@ -29,6 +30,7 @@ def create_relational_engine( - db_password (str): The password for database authentication, required for PostgreSQL. - db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres'). + - database_connect_args (dict, optional): Database driver connection arguments. Returns: -------- @@ -51,4 +53,4 @@ def create_relational_engine( "PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality." ) - return SQLAlchemyAdapter(connection_string) + return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 380ce9917..37ceb170d 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -29,10 +29,31 @@ class SQLAlchemyAdapter: functions. """ - def __init__(self, connection_string: str): + def __init__(self, connection_string: str, connect_args: dict = None): + """ + Initialize the SQLAlchemy adapter with connection settings. + + Parameters: + ----------- + connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' + or 'postgresql://user:pass@host:port/db'). + connect_args (dict, optional): Database driver connection arguments. + Configuration is loaded from RelationalConfig.database_connect_args, which reads + from the DATABASE_CONNECT_ARGS environment variable. + + Examples: + PostgreSQL with SSL: + DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}' + + SQLite with custom timeout: + DATABASE_CONNECT_ARGS='{"timeout": 60}' + """ self.db_path: str = None self.db_uri: str = connection_string + # Use provided connect_args (already parsed from config) + final_connect_args = connect_args or {} + if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") self.db_path = db_path @@ -53,7 +74,7 @@ class SQLAlchemyAdapter: self.engine = create_async_engine( connection_string, poolclass=NullPool, - connect_args={"timeout": 30}, + connect_args={**{"timeout": 30}, **final_connect_args}, ) else: self.engine = create_async_engine( @@ -63,6 +84,7 @@ class SQLAlchemyAdapter: pool_recycle=280, pool_pre_ping=True, pool_timeout=280, + connect_args=final_connect_args, ) self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) diff --git a/cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py b/cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py new file mode 100644 index 000000000..8bbfc2450 --- /dev/null +++ b/cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py @@ -0,0 +1,69 @@ +import os +from unittest.mock import patch +from cognee.infrastructure.databases.relational.config import RelationalConfig + + +class TestRelationalConfig: + """Test suite for RelationalConfig DATABASE_CONNECT_ARGS parsing.""" + + def test_database_connect_args_valid_json_dict(self): + """Test that DATABASE_CONNECT_ARGS is parsed correctly when it's a valid JSON dict.""" + with patch.dict( + os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require"}'} + ): + config = RelationalConfig() + assert config.database_connect_args == {"timeout": 60, "sslmode": "require"} + + def test_database_connect_args_empty_string(self): + """Test that empty DATABASE_CONNECT_ARGS is handled correctly.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": ""}): + config = RelationalConfig() + assert config.database_connect_args == "" + + def test_database_connect_args_not_set(self): + """Test that missing DATABASE_CONNECT_ARGS results in None.""" + with patch.dict(os.environ, {}, clear=True): + config = RelationalConfig() + assert config.database_connect_args is None + + def test_database_connect_args_invalid_json(self): + """Test that invalid JSON in DATABASE_CONNECT_ARGS results in empty dict.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON + config = RelationalConfig() + assert config.database_connect_args == {} + + def test_database_connect_args_non_dict_json(self): + """Test that non-dict JSON in DATABASE_CONNECT_ARGS results in empty dict.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}): + config = RelationalConfig() + assert config.database_connect_args == {} + + def test_database_connect_args_to_dict(self): + """Test that database_connect_args is included in to_dict() output.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): + config = RelationalConfig() + config_dict = config.to_dict() + assert "database_connect_args" in config_dict + assert config_dict["database_connect_args"] == {"timeout": 60} + + def test_database_connect_args_integer_value(self): + """Test that DATABASE_CONNECT_ARGS with integer values is parsed correctly.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"connect_timeout": 10}'}): + config = RelationalConfig() + assert config.database_connect_args == {"connect_timeout": 10} + + def test_database_connect_args_mixed_types(self): + """Test that DATABASE_CONNECT_ARGS with mixed value types is parsed correctly.""" + with patch.dict( + os.environ, + { + "DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}' + }, + ): + config = RelationalConfig() + assert config.database_connect_args == { + "timeout": 60, + "sslmode": "require", + "retries": 3, + "keepalive": True, + }