refactor: improve test isolation and add connect_args precedence
Signed-off-by: ketanjain7981 <ketan.jain@think41.com>
This commit is contained in:
parent
1f98d50870
commit
f26b490a8f
2 changed files with 94 additions and 59 deletions
|
|
@ -30,7 +30,7 @@ class SQLAlchemyAdapter:
|
|||
functions.
|
||||
"""
|
||||
|
||||
def __init__(self, connection_string: str):
|
||||
def __init__(self, connection_string: str, connect_args: dict = None):
|
||||
"""
|
||||
Initialize the SQLAlchemy adapter with connection settings.
|
||||
|
||||
|
|
@ -38,6 +38,8 @@ class SQLAlchemyAdapter:
|
|||
-----------
|
||||
connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
|
||||
or 'postgresql://user:pass@host:port/db').
|
||||
connect_args (dict, optional): Database driver arguments. These take precedence over
|
||||
DATABASE_CONNECT_ARGS environment variable.
|
||||
|
||||
Environment Variables:
|
||||
----------------------
|
||||
|
|
@ -59,17 +61,20 @@ class SQLAlchemyAdapter:
|
|||
self.db_uri: str = connection_string
|
||||
|
||||
# Parse optional connection arguments from environment variable
|
||||
connect_args = None
|
||||
env_connect_args_dict = {}
|
||||
env_connect_args = os.getenv("DATABASE_CONNECT_ARGS")
|
||||
if env_connect_args:
|
||||
try:
|
||||
connect_args = json.loads(env_connect_args)
|
||||
if not isinstance(connect_args, dict):
|
||||
parsed_args = json.loads(env_connect_args)
|
||||
if isinstance(parsed_args, dict):
|
||||
env_connect_args_dict = parsed_args
|
||||
else:
|
||||
logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring")
|
||||
connect_args = None
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring")
|
||||
connect_args = None
|
||||
|
||||
# Merge environment args with explicit args (explicit args take precedence)
|
||||
final_connect_args = {**env_connect_args_dict, **(connect_args or {})}
|
||||
|
||||
if "sqlite" in connection_string:
|
||||
[prefix, db_path] = connection_string.split("///")
|
||||
|
|
@ -91,7 +96,7 @@ class SQLAlchemyAdapter:
|
|||
self.engine = create_async_engine(
|
||||
connection_string,
|
||||
poolclass=NullPool,
|
||||
connect_args={**{"timeout": 30}, **(connect_args or {})},
|
||||
connect_args={**{"timeout": 30}, **final_connect_args},
|
||||
)
|
||||
else:
|
||||
self.engine = create_async_engine(
|
||||
|
|
@ -101,7 +106,7 @@ class SQLAlchemyAdapter:
|
|||
pool_recycle=280,
|
||||
pool_pre_ping=True,
|
||||
pool_timeout=280,
|
||||
connect_args=connect_args or {},
|
||||
connect_args=final_connect_args,
|
||||
)
|
||||
|
||||
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
from unittest.mock import patch
|
||||
from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import (
|
||||
SQLAlchemyAdapter,
|
||||
|
|
@ -5,86 +6,115 @@ from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter imp
|
|||
|
||||
|
||||
class TestSqlAlchemyAdapter:
|
||||
"""Test suite for SqlAlchemyAdapter environment variable handling and connection arguments."""
|
||||
|
||||
@patch(
|
||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
||||
)
|
||||
@patch("os.getenv")
|
||||
def test_sqlite_default_timeout(self, mock_getenv, mock_create_engine):
|
||||
def test_sqlite_default_timeout(self, mock_create_engine):
|
||||
"""Test that SQLite connection uses default timeout when no env var is set."""
|
||||
mock_getenv.return_value = None
|
||||
SQLAlchemyAdapter("sqlite:///test.db")
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {"timeout": 30}
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
SQLAlchemyAdapter("sqlite:///test.db")
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {"timeout": 30}
|
||||
|
||||
@patch(
|
||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
||||
)
|
||||
@patch("os.getenv")
|
||||
def test_sqlite_with_env_var_timeout(self, mock_getenv, mock_create_engine):
|
||||
def test_sqlite_with_env_var_timeout(self, mock_create_engine):
|
||||
"""Test that SQLite connection uses timeout from env var."""
|
||||
mock_getenv.return_value = '{"timeout": 60}'
|
||||
SQLAlchemyAdapter("sqlite:///test.db")
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {"timeout": 60}
|
||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
|
||||
SQLAlchemyAdapter("sqlite:///test.db")
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {"timeout": 60}
|
||||
|
||||
@patch(
|
||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
||||
)
|
||||
@patch("os.getenv")
|
||||
def test_sqlite_with_other_env_var_args(self, mock_getenv, mock_create_engine):
|
||||
def test_sqlite_with_other_env_var_args(self, mock_create_engine):
|
||||
"""Test that SQLite connection merges default timeout with other args from env var."""
|
||||
mock_getenv.return_value = '{"foo": "bar"}'
|
||||
SQLAlchemyAdapter("sqlite:///test.db")
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"}
|
||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"foo": "bar"}'}):
|
||||
SQLAlchemyAdapter("sqlite:///test.db")
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"}
|
||||
|
||||
@patch(
|
||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
||||
)
|
||||
@patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger")
|
||||
@patch("os.getenv")
|
||||
def test_sqlite_with_invalid_json_env_var(self, mock_getenv, mock_logger, mock_create_engine):
|
||||
def test_sqlite_with_invalid_json_env_var(self, mock_logger, mock_create_engine):
|
||||
"""Test that SQLite connection uses default timeout when env var has invalid JSON."""
|
||||
mock_getenv.return_value = '{"timeout": 60' # Invalid JSON
|
||||
SQLAlchemyAdapter("sqlite:///test.db")
|
||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON
|
||||
SQLAlchemyAdapter("sqlite:///test.db")
|
||||
|
||||
mock_logger.warning.assert_called_with(
|
||||
"Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring"
|
||||
)
|
||||
mock_logger.warning.assert_called_with(
|
||||
"Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring"
|
||||
)
|
||||
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {"timeout": 30}
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {"timeout": 30}
|
||||
|
||||
@patch(
|
||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
||||
)
|
||||
@patch("os.getenv")
|
||||
def test_postgresql_with_env_var(self, mock_getenv, mock_create_engine):
|
||||
@patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger")
|
||||
def test_sqlite_with_non_dict_json_env_var(self, mock_logger, mock_create_engine):
|
||||
"""Test that SQLite connection uses default timeout when env var is valid JSON but not a dict."""
|
||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}):
|
||||
SQLAlchemyAdapter("sqlite:///test.db")
|
||||
|
||||
mock_logger.warning.assert_called_with(
|
||||
"DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring"
|
||||
)
|
||||
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {"timeout": 30}
|
||||
|
||||
@patch(
|
||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
||||
)
|
||||
def test_postgresql_with_env_var(self, mock_create_engine):
|
||||
"""Test that PostgreSQL connection uses connect_args from env var."""
|
||||
mock_getenv.return_value = '{"sslmode": "require"}'
|
||||
SQLAlchemyAdapter("postgresql://user:pass@host/db")
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {"sslmode": "require"}
|
||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"sslmode": "require"}'}):
|
||||
SQLAlchemyAdapter("postgresql://user:pass@host/db")
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {"sslmode": "require"}
|
||||
|
||||
@patch(
|
||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
||||
)
|
||||
@patch("os.getenv")
|
||||
def test_postgresql_without_env_var(self, mock_getenv, mock_create_engine):
|
||||
def test_postgresql_without_env_var(self, mock_create_engine):
|
||||
"""Test that PostgreSQL connection has empty connect_args when no env var is set."""
|
||||
mock_getenv.return_value = None
|
||||
SQLAlchemyAdapter("postgresql://user:pass@host/db")
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {}
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
SQLAlchemyAdapter("postgresql://user:pass@host/db")
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
assert kwargs["connect_args"] == {}
|
||||
|
||||
@patch(
|
||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
||||
)
|
||||
def test_connect_args_precedence(self, mock_create_engine):
|
||||
"""Test that explicit connect_args take precedence over env var args."""
|
||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
|
||||
# Pass explicit connect_args that should override env var
|
||||
SQLAlchemyAdapter("sqlite:///test.db", connect_args={"timeout": 120})
|
||||
|
||||
mock_create_engine.assert_called_once()
|
||||
_, kwargs = mock_create_engine.call_args
|
||||
assert "connect_args" in kwargs
|
||||
# timeout should be 120 (explicit), not 60 (env var) or 30 (default)
|
||||
assert kwargs["connect_args"] == {"timeout": 120}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue