refactor: improve test isolation and add connect_args precedence

Signed-off-by: ketanjain7981 <ketan.jain@think41.com>
This commit is contained in:
ketanjain7981 2025-12-03 00:26:44 +05:30
parent 1f98d50870
commit f26b490a8f
2 changed files with 94 additions and 59 deletions

View file

@ -30,7 +30,7 @@ class SQLAlchemyAdapter:
functions.
"""
def __init__(self, connection_string: str):
def __init__(self, connection_string: str, connect_args: dict = None):
"""
Initialize the SQLAlchemy adapter with connection settings.
@ -38,6 +38,8 @@ class SQLAlchemyAdapter:
-----------
connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
or 'postgresql://user:pass@host:port/db').
connect_args (dict, optional): Database driver arguments. These take precedence over
the DATABASE_CONNECT_ARGS environment variable.
Environment Variables:
----------------------
@ -59,17 +61,20 @@ class SQLAlchemyAdapter:
self.db_uri: str = connection_string
# Parse optional connection arguments from environment variable
connect_args = None
env_connect_args_dict = {}
env_connect_args = os.getenv("DATABASE_CONNECT_ARGS")
if env_connect_args:
try:
connect_args = json.loads(env_connect_args)
if not isinstance(connect_args, dict):
parsed_args = json.loads(env_connect_args)
if isinstance(parsed_args, dict):
env_connect_args_dict = parsed_args
else:
logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring")
connect_args = None
except json.JSONDecodeError:
logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring")
connect_args = None
# Merge environment args with explicit args (explicit args take precedence)
final_connect_args = {**env_connect_args_dict, **(connect_args or {})}
if "sqlite" in connection_string:
[prefix, db_path] = connection_string.split("///")
@ -91,7 +96,7 @@ class SQLAlchemyAdapter:
self.engine = create_async_engine(
connection_string,
poolclass=NullPool,
connect_args={**{"timeout": 30}, **(connect_args or {})},
connect_args={**{"timeout": 30}, **final_connect_args},
)
else:
self.engine = create_async_engine(
@ -101,7 +106,7 @@ class SQLAlchemyAdapter:
pool_recycle=280,
pool_pre_ping=True,
pool_timeout=280,
connect_args=connect_args or {},
connect_args=final_connect_args,
)
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)

View file

@ -1,3 +1,4 @@
import os
from unittest.mock import patch
from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import (
SQLAlchemyAdapter,
@ -5,86 +6,115 @@ from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter imp
class TestSqlAlchemyAdapter:
"""Test suite for SqlAlchemyAdapter environment variable handling and connection arguments."""
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
@patch("os.getenv")
def test_sqlite_default_timeout(self, mock_getenv, mock_create_engine):
def test_sqlite_default_timeout(self, mock_create_engine):
"""Test that SQLite connection uses default timeout when no env var is set."""
mock_getenv.return_value = None
SQLAlchemyAdapter("sqlite:///test.db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 30}
with patch.dict(os.environ, {}, clear=True):
SQLAlchemyAdapter("sqlite:///test.db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 30}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
@patch("os.getenv")
def test_sqlite_with_env_var_timeout(self, mock_getenv, mock_create_engine):
def test_sqlite_with_env_var_timeout(self, mock_create_engine):
"""Test that SQLite connection uses timeout from env var."""
mock_getenv.return_value = '{"timeout": 60}'
SQLAlchemyAdapter("sqlite:///test.db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 60}
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
SQLAlchemyAdapter("sqlite:///test.db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 60}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
@patch("os.getenv")
def test_sqlite_with_other_env_var_args(self, mock_getenv, mock_create_engine):
def test_sqlite_with_other_env_var_args(self, mock_create_engine):
"""Test that SQLite connection merges default timeout with other args from env var."""
mock_getenv.return_value = '{"foo": "bar"}'
SQLAlchemyAdapter("sqlite:///test.db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"}
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"foo": "bar"}'}):
SQLAlchemyAdapter("sqlite:///test.db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
@patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger")
@patch("os.getenv")
def test_sqlite_with_invalid_json_env_var(self, mock_getenv, mock_logger, mock_create_engine):
def test_sqlite_with_invalid_json_env_var(self, mock_logger, mock_create_engine):
"""Test that SQLite connection uses default timeout when env var has invalid JSON."""
mock_getenv.return_value = '{"timeout": 60' # Invalid JSON
SQLAlchemyAdapter("sqlite:///test.db")
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON
SQLAlchemyAdapter("sqlite:///test.db")
mock_logger.warning.assert_called_with(
"Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring"
)
mock_logger.warning.assert_called_with(
"Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring"
)
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 30}
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 30}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
@patch("os.getenv")
def test_postgresql_with_env_var(self, mock_getenv, mock_create_engine):
@patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger")
def test_sqlite_with_non_dict_json_env_var(self, mock_logger, mock_create_engine):
"""Test that SQLite connection uses default timeout when env var is valid JSON but not a dict."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}):
SQLAlchemyAdapter("sqlite:///test.db")
mock_logger.warning.assert_called_with(
"DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring"
)
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 30}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
def test_postgresql_with_env_var(self, mock_create_engine):
"""Test that PostgreSQL connection uses connect_args from env var."""
mock_getenv.return_value = '{"sslmode": "require"}'
SQLAlchemyAdapter("postgresql://user:pass@host/db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"sslmode": "require"}
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"sslmode": "require"}'}):
SQLAlchemyAdapter("postgresql://user:pass@host/db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"sslmode": "require"}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
@patch("os.getenv")
def test_postgresql_without_env_var(self, mock_getenv, mock_create_engine):
def test_postgresql_without_env_var(self, mock_create_engine):
"""Test that PostgreSQL connection has empty connect_args when no env var is set."""
mock_getenv.return_value = None
SQLAlchemyAdapter("postgresql://user:pass@host/db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {}
with patch.dict(os.environ, {}, clear=True):
SQLAlchemyAdapter("postgresql://user:pass@host/db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
def test_connect_args_precedence(self, mock_create_engine):
"""Test that explicit connect_args take precedence over env var args."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
# Pass explicit connect_args that should override env var
SQLAlchemyAdapter("sqlite:///test.db", connect_args={"timeout": 120})
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
# timeout should be 120 (explicit), not 60 (env var) or 30 (default)
assert kwargs["connect_args"] == {"timeout": 120}