From f9b16e508d3e99dd34ee9e5f3b3ca893303f6faa Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 2 Dec 2025 20:23:42 +0530 Subject: [PATCH 1/7] feat(database): add connect_args support to SqlAlchemyAdapter - Add optional connect_args parameter to __init__ method - Support DATABASE_CONNECT_ARGS environment variable for JSON-based configuration - Enable custom connection parameters for all database engines (SQLite and PostgreSQL) - Maintain backward compatibility with existing code - Add proper error handling and validation for environment variable parsing Signed-off-by: ketanjain7981 --- .env.template | 9 ++++++++ .../sqlalchemy/SqlAlchemyAdapter.py | 21 +++++++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/.env.template b/.env.template index 61853b983..b93117c90 100644 --- a/.env.template +++ b/.env.template @@ -91,6 +91,15 @@ DB_NAME=cognee_db #DB_USERNAME=cognee #DB_PASSWORD=cognee +# -- Advanced: Custom database connection arguments (optional) --------------- +# Pass additional connection parameters as JSON. Useful for SSL, timeouts, etc. +# Examples: +# For PostgreSQL with SSL: +# DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}' +# For SQLite with custom timeout: +# DATABASE_CONNECT_ARGS='{"timeout": 60}' +#DATABASE_CONNECT_ARGS='{}' + ################################################################################ # 🕸️ Graph Database settings ################################################################################ diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 380ce9917..a23c1b297 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -3,6 +3,7 @@ import asyncio from os import path import tempfile from uuid import UUID +import json from typing import Optional from typing import AsyncGenerator, List from contextlib import asynccontextmanager @@ -29,10 +30,25 @@ class SQLAlchemyAdapter: functions. """ - def __init__(self, connection_string: str): + def __init__(self, connection_string: str, connect_args: Optional[dict] = None): self.db_path: str = None self.db_uri: str = connection_string + env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") + if env_connect_args: + try: + env_connect_args = json.loads(env_connect_args) + if isinstance(env_connect_args, dict): + if connect_args is None: + connect_args = {} + connect_args.update(env_connect_args) + else: + logger.warning( + f"DATABASE_CONNECT_ARGS is not a valid JSON dictionary: {env_connect_args}" + ) + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse DATABASE_CONNECT_ARGS as JSON: {e}") + if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") self.db_path = db_path @@ -53,7 +69,7 @@ class SQLAlchemyAdapter: self.engine = create_async_engine( connection_string, poolclass=NullPool, - connect_args={"timeout": 30}, + connect_args={**(connect_args or {}), **{"timeout": 30}}, ) else: self.engine = create_async_engine( @@ -63,6 +79,7 @@ class SQLAlchemyAdapter: pool_recycle=280, pool_pre_ping=True, pool_timeout=280, + connect_args=connect_args or {}, ) self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) From c892265644beaf61b778c10241b4770e00119530 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 2 Dec 2025 21:01:43 +0530 Subject: [PATCH 2/7] fix(database): address CodeRabbit review feedback - Add comprehensive docstring for __init__ method to meet 80% coverage requirement - Fix security issue: remove sensitive data from log messages - Fix merge precedence: programmatic args now correctly override env vars - Fix SQLite timeout order: user-specified timeout now overrides default 30s - Clarify precedence in docstring documentation Signed-off-by: ketanjain7981 --- .../sqlalchemy/SqlAlchemyAdapter.py | 38 +++++++++++++------ 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index a23c1b297..483228bfb 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -31,23 +31,39 @@ class SQLAlchemyAdapter: """ def __init__(self, connection_string: str, connect_args: Optional[dict] = None): + """ + Initialize the SQLAlchemy adapter with connection settings. + + Parameters: + ----------- + connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' + or 'postgresql://user:pass@host:port/db'). + connect_args (Optional[dict]): Optional dictionary of connection arguments to pass to + the database engine. These are driver-specific parameters such as SSL settings, + timeouts, or connection pool options. If DATABASE_CONNECT_ARGS environment variable + is set, those values will be merged with this parameter (programmatic values take + precedence over environment variables). Defaults to None. + + Environment Variables: + ---------------------- + DATABASE_CONNECT_ARGS: Optional JSON string containing connection arguments. + Example: '{"sslmode": "require", "connect_timeout": 10}' + """ self.db_path: str = None self.db_uri: str = connection_string env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") if env_connect_args: try: - env_connect_args = json.loads(env_connect_args) - if isinstance(env_connect_args, dict): - if connect_args is None: - connect_args = {} - connect_args.update(env_connect_args) + parsed_env_args = json.loads(env_connect_args) + if isinstance(parsed_env_args, dict): + # Merge: env vars as base, programmatic args override + merged_args = {**parsed_env_args, **(connect_args or {})} + connect_args = merged_args else: - logger.warning( - f"DATABASE_CONNECT_ARGS is not a valid JSON dictionary: {env_connect_args}" - ) - except json.JSONDecodeError as e: - logger.warning(f"Failed to parse DATABASE_CONNECT_ARGS as JSON: {e}") + logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring") + except json.JSONDecodeError: + logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring") if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") @@ -69,7 +85,7 @@ class SQLAlchemyAdapter: self.engine = create_async_engine( connection_string, poolclass=NullPool, - connect_args={**(connect_args or {}), **{"timeout": 30}}, + connect_args={**{"timeout": 30}, **(connect_args or {})}, ) else: self.engine = create_async_engine( From 3f53534c992cff1b9893f7be08d5658f4dddd630 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 2 Dec 2025 21:12:07 +0530 Subject: [PATCH 3/7] refactor(database): simplify to env var only for connect_args - Remove unused connect_args parameter from __init__ - Programmatic parameter was dead code (never called by users) - Users call get_relational_engine() which doesn't expose connect_args - Keep DATABASE_CONNECT_ARGS env var support (actually used in production) - Simplify implementation and reduce complexity - Update docstring to reflect env-var-only approach - Add production examples to docstring Signed-off-by: ketanjain7981 --- .../sqlalchemy/SqlAlchemyAdapter.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 483228bfb..3e800102a 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -30,7 +30,7 @@ class SQLAlchemyAdapter: functions. """ - def __init__(self, connection_string: str, connect_args: Optional[dict] = None): + def __init__(self, connection_string: str): """ Initialize the SQLAlchemy adapter with connection settings. @@ -38,32 +38,38 @@ class SQLAlchemyAdapter: ----------- connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' or 'postgresql://user:pass@host:port/db'). - connect_args (Optional[dict]): Optional dictionary of connection arguments to pass to - the database engine. These are driver-specific parameters such as SSL settings, - timeouts, or connection pool options. If DATABASE_CONNECT_ARGS environment variable - is set, those values will be merged with this parameter (programmatic values take - precedence over environment variables). Defaults to None. Environment Variables: ---------------------- DATABASE_CONNECT_ARGS: Optional JSON string containing connection arguments. - Example: '{"sslmode": "require", "connect_timeout": 10}' + Allows configuration of driver-specific parameters such as SSL settings, + timeouts, or connection pool options without code changes. + + Examples: + PostgreSQL with SSL: + DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}' + + SQLite with custom timeout: + DATABASE_CONNECT_ARGS='{"timeout": 60}' + + Note: This follows cognee's environment-based configuration pattern and is + the recommended approach for production deployments. """ self.db_path: str = None self.db_uri: str = connection_string + # Parse optional connection arguments from environment variable + connect_args = None env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") if env_connect_args: try: - parsed_env_args = json.loads(env_connect_args) - if isinstance(parsed_env_args, dict): - # Merge: env vars as base, programmatic args override - merged_args = {**parsed_env_args, **(connect_args or {})} - connect_args = merged_args - else: + connect_args = json.loads(env_connect_args) + if not isinstance(connect_args, dict): logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring") + connect_args = None except json.JSONDecodeError: logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring") + connect_args = None if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") From 4f3a1bcf012c5c380823ef7786de48a757505c37 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 2 Dec 2025 23:25:47 +0530 Subject: [PATCH 4/7] test: add unit tests for SQLAlchemyAdapter connection arguments Signed-off-by: ketanjain7981 --- .../sqlalchemy/test_SqlAlchemyAdapter.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py diff --git a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py new file mode 100644 index 000000000..bde5b9855 --- /dev/null +++ b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py @@ -0,0 +1,84 @@ +from unittest.mock import patch +from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import ( + SQLAlchemyAdapter, +) + + +class TestSqlAlchemyAdapter: + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def test_sqlite_default_timeout(self, mock_getenv, mock_create_engine): + """Test that SQLite connection uses default timeout when no env var is set.""" + mock_getenv.return_value = None + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def test_sqlite_with_env_var_timeout(self, mock_getenv, mock_create_engine): + """Test that SQLite connection uses timeout from env var.""" + mock_getenv.return_value = '{"timeout": 60}' + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 60} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def test_sqlite_with_other_env_var_args(self, mock_getenv, mock_create_engine): + """Test that SQLite connection merges default timeout with other args from env var.""" + mock_getenv.return_value = '{"foo": "bar"}' + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def test_sqlite_with_invalid_json_env_var(self, mock_getenv, mock_create_engine): + """Test that SQLite connection uses default timeout when env var has invalid JSON.""" + mock_getenv.return_value = '{"timeout": 60' # Invalid JSON + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def test_postgresql_with_env_var(self, mock_getenv, mock_create_engine): + """Test that PostgreSQL connection uses connect_args from env var.""" + mock_getenv.return_value = '{"sslmode": "require"}' + SQLAlchemyAdapter("postgresql://user:pass@host/db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"sslmode": "require"} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def test_postgresql_without_env_var(self, mock_getenv, mock_create_engine): + """Test that PostgreSQL connection has empty connect_args when no env var is set.""" + mock_getenv.return_value = None + SQLAlchemyAdapter("postgresql://user:pass@host/db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {} From a7da9c7d655e705beceaf376a6cc6f3587be41d7 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 2 Dec 2025 23:35:35 +0530 Subject: [PATCH 5/7] test: verify logger warning for invalid JSON in SQLAlchemyAdapter Signed-off-by: ketanjain7981 --- .../relational/sqlalchemy/test_SqlAlchemyAdapter.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py index bde5b9855..abff77660 100644 --- a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py +++ b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py @@ -47,11 +47,17 @@ class TestSqlAlchemyAdapter: @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) + @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") @patch("os.getenv") - def test_sqlite_with_invalid_json_env_var(self, mock_getenv, mock_create_engine): + def test_sqlite_with_invalid_json_env_var(self, mock_getenv, mock_logger, mock_create_engine): """Test that SQLite connection uses default timeout when env var has invalid JSON.""" mock_getenv.return_value = '{"timeout": 60' # Invalid JSON SQLAlchemyAdapter("sqlite:///test.db") + + mock_logger.warning.assert_called_with( + "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring" + ) + mock_create_engine.assert_called_once() _, kwargs = mock_create_engine.call_args assert "connect_args" in kwargs From f26b490a8f5b3f0bb400396a4fb71771ee325ab0 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Wed, 3 Dec 2025 00:26:44 +0530 Subject: [PATCH 6/7] refactor: improve test isolation and add connect_args precedence Signed-off-by: ketanjain7981 --- .../sqlalchemy/SqlAlchemyAdapter.py | 21 +-- .../sqlalchemy/test_SqlAlchemyAdapter.py | 132 +++++++++++------- 2 files changed, 94 insertions(+), 59 deletions(-) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 3e800102a..c7825041e 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -30,7 +30,7 @@ class SQLAlchemyAdapter: functions. """ - def __init__(self, connection_string: str): + def __init__(self, connection_string: str, connect_args: dict = None): """ Initialize the SQLAlchemy adapter with connection settings. @@ -38,6 +38,8 @@ class SQLAlchemyAdapter: ----------- connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' or 'postgresql://user:pass@host:port/db'). + connect_args (dict, optional): Database driver arguments. These take precedence over + DATABASE_CONNECT_ARGS environment variable. Environment Variables: ---------------------- @@ -59,17 +61,20 @@ class SQLAlchemyAdapter: self.db_uri: str = connection_string # Parse optional connection arguments from environment variable - connect_args = None + env_connect_args_dict = {} env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") if env_connect_args: try: - connect_args = json.loads(env_connect_args) - if not isinstance(connect_args, dict): + parsed_args = json.loads(env_connect_args) + if isinstance(parsed_args, dict): + env_connect_args_dict = parsed_args + else: logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring") - connect_args = None except json.JSONDecodeError: logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring") - connect_args = None + + # Merge environment args with explicit args (explicit args take precedence) + final_connect_args = {**env_connect_args_dict, **(connect_args or {})} if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") @@ -91,7 +96,7 @@ class SQLAlchemyAdapter: self.engine = create_async_engine( connection_string, poolclass=NullPool, - connect_args={**{"timeout": 30}, **(connect_args or {})}, + connect_args={**{"timeout": 30}, **final_connect_args}, ) else: self.engine = create_async_engine( @@ -101,7 +106,7 @@ class SQLAlchemyAdapter: pool_recycle=280, pool_pre_ping=True, pool_timeout=280, - connect_args=connect_args or {}, + connect_args=final_connect_args, ) self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) diff --git a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py index abff77660..57c80ddcf 100644 --- a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py +++ b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py @@ -1,3 +1,4 @@ +import os from unittest.mock import patch from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import ( SQLAlchemyAdapter, @@ -5,86 +6,115 @@ from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter imp class TestSqlAlchemyAdapter: + """Test suite for SqlAlchemyAdapter environment variable handling and connection arguments.""" + @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_sqlite_default_timeout(self, mock_getenv, mock_create_engine): + def test_sqlite_default_timeout(self, mock_create_engine): """Test that SQLite connection uses default timeout when no env var is set.""" - mock_getenv.return_value = None - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} + with patch.dict(os.environ, {}, clear=True): + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_sqlite_with_env_var_timeout(self, mock_getenv, mock_create_engine): + def test_sqlite_with_env_var_timeout(self, mock_create_engine): """Test that SQLite connection uses timeout from env var.""" - mock_getenv.return_value = '{"timeout": 60}' - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 60} + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 60} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_sqlite_with_other_env_var_args(self, mock_getenv, mock_create_engine): + def test_sqlite_with_other_env_var_args(self, mock_create_engine): """Test that SQLite connection merges default timeout with other args from env var.""" - mock_getenv.return_value = '{"foo": "bar"}' - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"} + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"foo": "bar"}'}): + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") - @patch("os.getenv") - def test_sqlite_with_invalid_json_env_var(self, mock_getenv, mock_logger, mock_create_engine): + def test_sqlite_with_invalid_json_env_var(self, mock_logger, mock_create_engine): """Test that SQLite connection uses default timeout when env var has invalid JSON.""" - mock_getenv.return_value = '{"timeout": 60' # Invalid JSON - SQLAlchemyAdapter("sqlite:///test.db") + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON + SQLAlchemyAdapter("sqlite:///test.db") - mock_logger.warning.assert_called_with( - "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring" - ) + mock_logger.warning.assert_called_with( + "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring" + ) - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_postgresql_with_env_var(self, mock_getenv, mock_create_engine): + @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") + def test_sqlite_with_non_dict_json_env_var(self, mock_logger, mock_create_engine): + """Test that SQLite connection uses default timeout when env var is valid JSON but not a dict.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}): + SQLAlchemyAdapter("sqlite:///test.db") + + mock_logger.warning.assert_called_with( + "DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring" + ) + + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + def test_postgresql_with_env_var(self, mock_create_engine): """Test that PostgreSQL connection uses connect_args from env var.""" - mock_getenv.return_value = '{"sslmode": "require"}' - SQLAlchemyAdapter("postgresql://user:pass@host/db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"sslmode": "require"} + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"sslmode": "require"}'}): + SQLAlchemyAdapter("postgresql://user:pass@host/db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"sslmode": "require"} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_postgresql_without_env_var(self, mock_getenv, mock_create_engine): + def test_postgresql_without_env_var(self, mock_create_engine): """Test that PostgreSQL connection has empty connect_args when no env var is set.""" - mock_getenv.return_value = None - SQLAlchemyAdapter("postgresql://user:pass@host/db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {} + with patch.dict(os.environ, {}, clear=True): + SQLAlchemyAdapter("postgresql://user:pass@host/db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + def test_connect_args_precedence(self, mock_create_engine): + """Test that explicit connect_args take precedence over env var args.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): + # Pass explicit connect_args that should override env var + SQLAlchemyAdapter("sqlite:///test.db", connect_args={"timeout": 120}) + + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + # timeout should be 120 (explicit), not 60 (env var) or 30 (default) + assert kwargs["connect_args"] == {"timeout": 120} From e1d313a46b962109644facdb51a34a3023299187 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 9 Dec 2025 10:14:46 +0530 Subject: [PATCH 7/7] move DATABASE_CONNECT_ARGS parsing to RelationalConfig Signed-off-by: ketanjain7981 --- .../databases/relational/config.py | 17 ++- .../relational/create_relational_engine.py | 4 +- .../sqlalchemy/SqlAlchemyAdapter.py | 32 +---- .../sqlalchemy/test_SqlAlchemyAdapter.py | 120 ------------------ .../relational/test_RelationalConfig.py | 69 ++++++++++ 5 files changed, 93 insertions(+), 149 deletions(-) delete mode 100644 cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py create mode 100644 cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py diff --git a/cognee/infrastructure/databases/relational/config.py b/cognee/infrastructure/databases/relational/config.py index ff7c410a1..fae7ca329 100644 --- a/cognee/infrastructure/databases/relational/config.py +++ b/cognee/infrastructure/databases/relational/config.py @@ -1,4 +1,5 @@ import os +import json import pydantic from typing import Union from functools import lru_cache @@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings): db_username: Union[str, None] = None # "cognee" db_password: Union[str, None] = None # "cognee" db_provider: str = "sqlite" + database_connect_args: Union[str, None] = None model_config = SettingsConfigDict(env_file=".env", extra="allow") @@ -30,6 +32,17 @@ class RelationalConfig(BaseSettings): databases_directory_path = os.path.join(base_config.system_root_directory, "databases") self.db_path = databases_directory_path + # Parse database_connect_args if provided as JSON string + if self.database_connect_args and isinstance(self.database_connect_args, str): + try: + parsed_args = json.loads(self.database_connect_args) + if isinstance(parsed_args, dict): + self.database_connect_args = parsed_args + else: + self.database_connect_args = {} + except json.JSONDecodeError: + self.database_connect_args = {} + return self def to_dict(self) -> dict: @@ -40,7 +53,8 @@ class RelationalConfig(BaseSettings): -------- - dict: A dictionary containing database configuration settings including db_path, - db_name, db_host, db_port, db_username, db_password, and db_provider. + db_name, db_host, db_port, db_username, db_password, db_provider, and + database_connect_args. """ return { "db_path": self.db_path, @@ -50,6 +64,7 @@ class RelationalConfig(BaseSettings): "db_username": self.db_username, "db_password": self.db_password, "db_provider": self.db_provider, + "database_connect_args": self.database_connect_args, } diff --git a/cognee/infrastructure/databases/relational/create_relational_engine.py b/cognee/infrastructure/databases/relational/create_relational_engine.py index deaeaa2da..b2a79c818 100644 --- a/cognee/infrastructure/databases/relational/create_relational_engine.py +++ b/cognee/infrastructure/databases/relational/create_relational_engine.py @@ -11,6 +11,7 @@ def create_relational_engine( db_username: str, db_password: str, db_provider: str, + database_connect_args: dict = None, ): """ Create a relational database engine based on the specified parameters. @@ -29,6 +30,7 @@ def create_relational_engine( - db_password (str): The password for database authentication, required for PostgreSQL. - db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres'). + - database_connect_args (dict, optional): Database driver connection arguments. Returns: -------- @@ -51,4 +53,4 @@ def create_relational_engine( "PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality." ) - return SQLAlchemyAdapter(connection_string) + return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index c7825041e..37ceb170d 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -3,7 +3,6 @@ import asyncio from os import path import tempfile from uuid import UUID -import json from typing import Optional from typing import AsyncGenerator, List from contextlib import asynccontextmanager @@ -38,14 +37,9 @@ class SQLAlchemyAdapter: ----------- connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' or 'postgresql://user:pass@host:port/db'). - connect_args (dict, optional): Database driver arguments. These take precedence over - DATABASE_CONNECT_ARGS environment variable. - - Environment Variables: - ---------------------- - DATABASE_CONNECT_ARGS: Optional JSON string containing connection arguments. - Allows configuration of driver-specific parameters such as SSL settings, - timeouts, or connection pool options without code changes. + connect_args (dict, optional): Database driver connection arguments. + Configuration is loaded from RelationalConfig.database_connect_args, which reads + from the DATABASE_CONNECT_ARGS environment variable. Examples: PostgreSQL with SSL: @@ -53,28 +47,12 @@ class SQLAlchemyAdapter: SQLite with custom timeout: DATABASE_CONNECT_ARGS='{"timeout": 60}' - - Note: This follows cognee's environment-based configuration pattern and is - the recommended approach for production deployments. """ self.db_path: str = None self.db_uri: str = connection_string - # Parse optional connection arguments from environment variable - env_connect_args_dict = {} - env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") - if env_connect_args: - try: - parsed_args = json.loads(env_connect_args) - if isinstance(parsed_args, dict): - env_connect_args_dict = parsed_args - else: - logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring") - except json.JSONDecodeError: - logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring") - - # Merge environment args with explicit args (explicit args take precedence) - final_connect_args = {**env_connect_args_dict, **(connect_args or {})} + # Use provided connect_args (already parsed from config) + final_connect_args = connect_args or {} if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") diff --git a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py deleted file mode 100644 index 57c80ddcf..000000000 --- a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py +++ /dev/null @@ -1,120 +0,0 @@ -import os -from unittest.mock import patch -from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import ( - SQLAlchemyAdapter, -) - - -class TestSqlAlchemyAdapter: - """Test suite for SqlAlchemyAdapter environment variable handling and connection arguments.""" - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def test_sqlite_default_timeout(self, mock_create_engine): - """Test that SQLite connection uses default timeout when no env var is set.""" - with patch.dict(os.environ, {}, clear=True): - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def test_sqlite_with_env_var_timeout(self, mock_create_engine): - """Test that SQLite connection uses timeout from env var.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 60} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def test_sqlite_with_other_env_var_args(self, mock_create_engine): - """Test that SQLite connection merges default timeout with other args from env var.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"foo": "bar"}'}): - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") - def test_sqlite_with_invalid_json_env_var(self, mock_logger, mock_create_engine): - """Test that SQLite connection uses default timeout when env var has invalid JSON.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON - SQLAlchemyAdapter("sqlite:///test.db") - - mock_logger.warning.assert_called_with( - "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring" - ) - - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") - def test_sqlite_with_non_dict_json_env_var(self, mock_logger, mock_create_engine): - """Test that SQLite connection uses default timeout when env var is valid JSON but not a dict.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}): - SQLAlchemyAdapter("sqlite:///test.db") - - mock_logger.warning.assert_called_with( - "DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring" - ) - - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def test_postgresql_with_env_var(self, mock_create_engine): - """Test that PostgreSQL connection uses connect_args from env var.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"sslmode": "require"}'}): - SQLAlchemyAdapter("postgresql://user:pass@host/db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"sslmode": "require"} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def test_postgresql_without_env_var(self, mock_create_engine): - """Test that PostgreSQL connection has empty connect_args when no env var is set.""" - with patch.dict(os.environ, {}, clear=True): - SQLAlchemyAdapter("postgresql://user:pass@host/db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def test_connect_args_precedence(self, mock_create_engine): - """Test that explicit connect_args take precedence over env var args.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): - # Pass explicit connect_args that should override env var - SQLAlchemyAdapter("sqlite:///test.db", connect_args={"timeout": 120}) - - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - # timeout should be 120 (explicit), not 60 (env var) or 30 (default) - assert kwargs["connect_args"] == {"timeout": 120} diff --git a/cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py b/cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py new file mode 100644 index 000000000..8bbfc2450 --- /dev/null +++ b/cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py @@ -0,0 +1,69 @@ +import os +from unittest.mock import patch +from cognee.infrastructure.databases.relational.config import RelationalConfig + + +class TestRelationalConfig: + """Test suite for RelationalConfig DATABASE_CONNECT_ARGS parsing.""" + + def test_database_connect_args_valid_json_dict(self): + """Test that DATABASE_CONNECT_ARGS is parsed correctly when it's a valid JSON dict.""" + with patch.dict( + os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require"}'} + ): + config = RelationalConfig() + assert config.database_connect_args == {"timeout": 60, "sslmode": "require"} + + def test_database_connect_args_empty_string(self): + """Test that empty DATABASE_CONNECT_ARGS is handled correctly.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": ""}): + config = RelationalConfig() + assert config.database_connect_args == "" + + def test_database_connect_args_not_set(self): + """Test that missing DATABASE_CONNECT_ARGS results in None.""" + with patch.dict(os.environ, {}, clear=True): + config = RelationalConfig() + assert config.database_connect_args is None + + def test_database_connect_args_invalid_json(self): + """Test that invalid JSON in DATABASE_CONNECT_ARGS results in empty dict.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON + config = RelationalConfig() + assert config.database_connect_args == {} + + def test_database_connect_args_non_dict_json(self): + """Test that non-dict JSON in DATABASE_CONNECT_ARGS results in empty dict.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}): + config = RelationalConfig() + assert config.database_connect_args == {} + + def test_database_connect_args_to_dict(self): + """Test that database_connect_args is included in to_dict() output.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): + config = RelationalConfig() + config_dict = config.to_dict() + assert "database_connect_args" in config_dict + assert config_dict["database_connect_args"] == {"timeout": 60} + + def test_database_connect_args_integer_value(self): + """Test that DATABASE_CONNECT_ARGS with integer values is parsed correctly.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"connect_timeout": 10}'}): + config = RelationalConfig() + assert config.database_connect_args == {"connect_timeout": 10} + + def test_database_connect_args_mixed_types(self): + """Test that DATABASE_CONNECT_ARGS with mixed value types is parsed correctly.""" + with patch.dict( + os.environ, + { + "DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}' + }, + ): + config = RelationalConfig() + assert config.database_connect_args == { + "timeout": 60, + "sslmode": "require", + "retries": 3, + "keepalive": True, + }