From f26b490a8f5b3f0bb400396a4fb71771ee325ab0 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Wed, 3 Dec 2025 00:26:44 +0530 Subject: [PATCH] refactor: improve test isolation and add connect_args precedence Signed-off-by: ketanjain7981 --- .../sqlalchemy/SqlAlchemyAdapter.py | 21 +-- .../sqlalchemy/test_SqlAlchemyAdapter.py | 132 +++++++++++------- 2 files changed, 94 insertions(+), 59 deletions(-) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 3e800102a..c7825041e 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -30,7 +30,7 @@ class SQLAlchemyAdapter: functions. """ - def __init__(self, connection_string: str): + def __init__(self, connection_string: str, connect_args: dict = None): """ Initialize the SQLAlchemy adapter with connection settings. @@ -38,6 +38,8 @@ class SQLAlchemyAdapter: ----------- connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' or 'postgresql://user:pass@host:port/db'). + connect_args (dict, optional): Database driver arguments. These take precedence over + DATABASE_CONNECT_ARGS environment variable. Environment Variables: ---------------------- @@ -59,17 +61,20 @@ class SQLAlchemyAdapter: self.db_uri: str = connection_string # Parse optional connection arguments from environment variable - connect_args = None + env_connect_args_dict = {} env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") if env_connect_args: try: - connect_args = json.loads(env_connect_args) - if not isinstance(connect_args, dict): + parsed_args = json.loads(env_connect_args) + if isinstance(parsed_args, dict): + env_connect_args_dict = parsed_args + else: logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring") - connect_args = None except json.JSONDecodeError: logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring") - connect_args = None + + # Merge environment args with explicit args (explicit args take precedence) + final_connect_args = {**env_connect_args_dict, **(connect_args or {})} if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") @@ -91,7 +96,7 @@ class SQLAlchemyAdapter: self.engine = create_async_engine( connection_string, poolclass=NullPool, - connect_args={**{"timeout": 30}, **(connect_args or {})}, + connect_args={**{"timeout": 30}, **final_connect_args}, ) else: self.engine = create_async_engine( @@ -101,7 +106,7 @@ class SQLAlchemyAdapter: pool_recycle=280, pool_pre_ping=True, pool_timeout=280, - connect_args=connect_args or {}, + connect_args=final_connect_args, ) self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) diff --git a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py index abff77660..57c80ddcf 100644 --- a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py +++ b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py @@ -1,3 +1,4 @@ +import os from unittest.mock import patch from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import ( SQLAlchemyAdapter, @@ -5,86 +6,115 @@ from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter imp class TestSqlAlchemyAdapter: + """Test suite for SqlAlchemyAdapter environment variable handling and connection arguments.""" + @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_sqlite_default_timeout(self, mock_getenv, mock_create_engine): + def test_sqlite_default_timeout(self, mock_create_engine): """Test that SQLite connection uses default timeout when no env var is set.""" - mock_getenv.return_value = None - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} + with patch.dict(os.environ, {}, clear=True): + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_sqlite_with_env_var_timeout(self, mock_getenv, mock_create_engine): + def test_sqlite_with_env_var_timeout(self, mock_create_engine): """Test that SQLite connection uses timeout from env var.""" - mock_getenv.return_value = '{"timeout": 60}' - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 60} + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 60} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_sqlite_with_other_env_var_args(self, mock_getenv, mock_create_engine): + def test_sqlite_with_other_env_var_args(self, mock_create_engine): """Test that SQLite connection merges default timeout with other args from env var.""" - mock_getenv.return_value = '{"foo": "bar"}' - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"} + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"foo": "bar"}'}): + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") - @patch("os.getenv") - def test_sqlite_with_invalid_json_env_var(self, mock_getenv, mock_logger, mock_create_engine): + def test_sqlite_with_invalid_json_env_var(self, mock_logger, mock_create_engine): """Test that SQLite connection uses default timeout when env var has invalid JSON.""" - mock_getenv.return_value = '{"timeout": 60' # Invalid JSON - SQLAlchemyAdapter("sqlite:///test.db") + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON + SQLAlchemyAdapter("sqlite:///test.db") - mock_logger.warning.assert_called_with( - "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring" - ) + mock_logger.warning.assert_called_with( + "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring" + ) - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_postgresql_with_env_var(self, mock_getenv, mock_create_engine): + @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") + def test_sqlite_with_non_dict_json_env_var(self, mock_logger, mock_create_engine): + """Test that SQLite connection uses default timeout when env var is valid JSON but not a dict.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}): + SQLAlchemyAdapter("sqlite:///test.db") + + mock_logger.warning.assert_called_with( + "DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring" + ) + + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + def test_postgresql_with_env_var(self, mock_create_engine): """Test that PostgreSQL connection uses connect_args from env var.""" - mock_getenv.return_value = '{"sslmode": "require"}' - SQLAlchemyAdapter("postgresql://user:pass@host/db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"sslmode": "require"} + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"sslmode": "require"}'}): + SQLAlchemyAdapter("postgresql://user:pass@host/db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"sslmode": "require"} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_postgresql_without_env_var(self, mock_getenv, mock_create_engine): + def test_postgresql_without_env_var(self, mock_create_engine): """Test that PostgreSQL connection has empty connect_args when no env var is set.""" - mock_getenv.return_value = None - SQLAlchemyAdapter("postgresql://user:pass@host/db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {} + with patch.dict(os.environ, {}, clear=True): + SQLAlchemyAdapter("postgresql://user:pass@host/db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + def test_connect_args_precedence(self, mock_create_engine): + """Test that explicit connect_args take precedence over env var args.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): + # Pass explicit connect_args that should override env var + SQLAlchemyAdapter("sqlite:///test.db", connect_args={"timeout": 120}) + + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + # timeout should be 120 (explicit), not 60 (env var) or 30 (default) + assert kwargs["connect_args"] == {"timeout": 120}