move DATABASE_CONNECT_ARGS parsing to RelationalConfig

Signed-off-by: ketanjain7981 <ketan.jain@think41.com>
This commit is contained in:
ketanjain7981 2025-12-09 10:14:46 +05:30
parent 654a573454
commit e1d313a46b
5 changed files with 93 additions and 149 deletions

View file

@ -1,4 +1,5 @@
import os import os
import json
import pydantic import pydantic
from typing import Union from typing import Union
from functools import lru_cache from functools import lru_cache
@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings):
db_username: Union[str, None] = None # "cognee" db_username: Union[str, None] = None # "cognee"
db_password: Union[str, None] = None # "cognee" db_password: Union[str, None] = None # "cognee"
db_provider: str = "sqlite" db_provider: str = "sqlite"
database_connect_args: Union[str, None] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow") model_config = SettingsConfigDict(env_file=".env", extra="allow")
@ -30,6 +32,17 @@ class RelationalConfig(BaseSettings):
databases_directory_path = os.path.join(base_config.system_root_directory, "databases") databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
self.db_path = databases_directory_path self.db_path = databases_directory_path
# Parse database_connect_args if provided as JSON string
if self.database_connect_args and isinstance(self.database_connect_args, str):
try:
parsed_args = json.loads(self.database_connect_args)
if isinstance(parsed_args, dict):
self.database_connect_args = parsed_args
else:
self.database_connect_args = {}
except json.JSONDecodeError:
self.database_connect_args = {}
return self return self
def to_dict(self) -> dict: def to_dict(self) -> dict:
@ -40,7 +53,8 @@ class RelationalConfig(BaseSettings):
-------- --------
- dict: A dictionary containing database configuration settings including db_path, - dict: A dictionary containing database configuration settings including db_path,
db_name, db_host, db_port, db_username, db_password, and db_provider. db_name, db_host, db_port, db_username, db_password, db_provider, and
database_connect_args.
""" """
return { return {
"db_path": self.db_path, "db_path": self.db_path,
@ -50,6 +64,7 @@ class RelationalConfig(BaseSettings):
"db_username": self.db_username, "db_username": self.db_username,
"db_password": self.db_password, "db_password": self.db_password,
"db_provider": self.db_provider, "db_provider": self.db_provider,
"database_connect_args": self.database_connect_args,
} }

View file

@ -11,6 +11,7 @@ def create_relational_engine(
db_username: str, db_username: str,
db_password: str, db_password: str,
db_provider: str, db_provider: str,
database_connect_args: dict = None,
): ):
""" """
Create a relational database engine based on the specified parameters. Create a relational database engine based on the specified parameters.
@ -29,6 +30,7 @@ def create_relational_engine(
- db_password (str): The password for database authentication, required for - db_password (str): The password for database authentication, required for
PostgreSQL. PostgreSQL.
- db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres'). - db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres').
- database_connect_args (dict, optional): Database driver connection arguments.
Returns: Returns:
-------- --------
@ -51,4 +53,4 @@ def create_relational_engine(
"PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality." "PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality."
) )
return SQLAlchemyAdapter(connection_string) return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args)

View file

@ -3,7 +3,6 @@ import asyncio
from os import path from os import path
import tempfile import tempfile
from uuid import UUID from uuid import UUID
import json
from typing import Optional from typing import Optional
from typing import AsyncGenerator, List from typing import AsyncGenerator, List
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -38,14 +37,9 @@ class SQLAlchemyAdapter:
----------- -----------
connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
or 'postgresql://user:pass@host:port/db'). or 'postgresql://user:pass@host:port/db').
connect_args (dict, optional): Database driver arguments. These take precedence over connect_args (dict, optional): Database driver connection arguments.
DATABASE_CONNECT_ARGS environment variable. Configuration is loaded from RelationalConfig.database_connect_args, which reads
from the DATABASE_CONNECT_ARGS environment variable.
Environment Variables:
----------------------
DATABASE_CONNECT_ARGS: Optional JSON string containing connection arguments.
Allows configuration of driver-specific parameters such as SSL settings,
timeouts, or connection pool options without code changes.
Examples: Examples:
PostgreSQL with SSL: PostgreSQL with SSL:
@ -53,28 +47,12 @@ class SQLAlchemyAdapter:
SQLite with custom timeout: SQLite with custom timeout:
DATABASE_CONNECT_ARGS='{"timeout": 60}' DATABASE_CONNECT_ARGS='{"timeout": 60}'
Note: This follows cognee's environment-based configuration pattern and is
the recommended approach for production deployments.
""" """
self.db_path: str = None self.db_path: str = None
self.db_uri: str = connection_string self.db_uri: str = connection_string
# Parse optional connection arguments from environment variable # Use provided connect_args (already parsed from config)
env_connect_args_dict = {} final_connect_args = connect_args or {}
env_connect_args = os.getenv("DATABASE_CONNECT_ARGS")
if env_connect_args:
try:
parsed_args = json.loads(env_connect_args)
if isinstance(parsed_args, dict):
env_connect_args_dict = parsed_args
else:
logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring")
except json.JSONDecodeError:
logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring")
# Merge environment args with explicit args (explicit args take precedence)
final_connect_args = {**env_connect_args_dict, **(connect_args or {})}
if "sqlite" in connection_string: if "sqlite" in connection_string:
[prefix, db_path] = connection_string.split("///") [prefix, db_path] = connection_string.split("///")

View file

@ -1,120 +0,0 @@
import os
from unittest.mock import patch
from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import (
SQLAlchemyAdapter,
)
class TestSqlAlchemyAdapter:
    """Checks which ``connect_args`` SQLAlchemyAdapter hands to the engine factory.

    Every test replaces ``create_async_engine`` with a mock, builds an adapter,
    and inspects the keyword arguments of the single recorded call.
    """

    # Dotted paths of the module attributes swapped out for mocks below.
    _ENGINE_FACTORY = "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
    _LOGGER = "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger"

    @staticmethod
    def _passed_connect_args(engine_factory):
        """Return the ``connect_args`` keyword of the one engine-factory call.

        Asserts that the factory was called exactly once and that the call
        carried a ``connect_args`` keyword.
        """
        engine_factory.assert_called_once()
        call_kwargs = engine_factory.call_args[1]
        assert "connect_args" in call_kwargs
        return call_kwargs["connect_args"]

    @patch(_ENGINE_FACTORY)
    def test_sqlite_default_timeout(self, mock_create_engine):
        """SQLite with a fully cleared environment gets ``{"timeout": 30}``."""
        with patch.dict(os.environ, {}, clear=True):
            SQLAlchemyAdapter("sqlite:///test.db")

        expected = {"timeout": 30}
        assert self._passed_connect_args(mock_create_engine) == expected

    @patch(_ENGINE_FACTORY)
    def test_sqlite_with_env_var_timeout(self, mock_create_engine):
        """A timeout given through the env var is the one passed for SQLite."""
        env = {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}
        with patch.dict(os.environ, env):
            SQLAlchemyAdapter("sqlite:///test.db")

        expected = {"timeout": 60}
        assert self._passed_connect_args(mock_create_engine) == expected

    @patch(_ENGINE_FACTORY)
    def test_sqlite_with_other_env_var_args(self, mock_create_engine):
        """Unrelated env-var keys are passed alongside the 30 second timeout."""
        env = {"DATABASE_CONNECT_ARGS": '{"foo": "bar"}'}
        with patch.dict(os.environ, env):
            SQLAlchemyAdapter("sqlite:///test.db")

        expected = {"timeout": 30, "foo": "bar"}
        assert self._passed_connect_args(mock_create_engine) == expected

    @patch(_ENGINE_FACTORY)
    @patch(_LOGGER)
    def test_sqlite_with_invalid_json_env_var(self, mock_logger, mock_create_engine):
        """Malformed JSON in the env var is warned about and then ignored."""
        env = {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}  # closing brace missing on purpose
        with patch.dict(os.environ, env):
            SQLAlchemyAdapter("sqlite:///test.db")

        mock_logger.warning.assert_called_with(
            "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring"
        )
        expected = {"timeout": 30}
        assert self._passed_connect_args(mock_create_engine) == expected

    @patch(_ENGINE_FACTORY)
    @patch(_LOGGER)
    def test_sqlite_with_non_dict_json_env_var(self, mock_logger, mock_create_engine):
        """Valid JSON that is not an object is warned about and then ignored."""
        env = {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}
        with patch.dict(os.environ, env):
            SQLAlchemyAdapter("sqlite:///test.db")

        mock_logger.warning.assert_called_with(
            "DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring"
        )
        expected = {"timeout": 30}
        assert self._passed_connect_args(mock_create_engine) == expected

    @patch(_ENGINE_FACTORY)
    def test_postgresql_with_env_var(self, mock_create_engine):
        """PostgreSQL receives exactly the arguments supplied by the env var."""
        env = {"DATABASE_CONNECT_ARGS": '{"sslmode": "require"}'}
        with patch.dict(os.environ, env):
            SQLAlchemyAdapter("postgresql://user:pass@host/db")

        expected = {"sslmode": "require"}
        assert self._passed_connect_args(mock_create_engine) == expected

    @patch(_ENGINE_FACTORY)
    def test_postgresql_without_env_var(self, mock_create_engine):
        """PostgreSQL with a cleared environment gets an empty ``connect_args``."""
        with patch.dict(os.environ, {}, clear=True):
            SQLAlchemyAdapter("postgresql://user:pass@host/db")

        assert self._passed_connect_args(mock_create_engine) == {}

    @patch(_ENGINE_FACTORY)
    def test_connect_args_precedence(self, mock_create_engine):
        """An explicit ``connect_args`` value wins over the env var value."""
        env = {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}
        with patch.dict(os.environ, env):
            # The explicit keyword below is expected to override the env var.
            SQLAlchemyAdapter("sqlite:///test.db", connect_args={"timeout": 120})

        # 120 is the explicit value; 60 would be the env var, 30 the default.
        expected = {"timeout": 120}
        assert self._passed_connect_args(mock_create_engine) == expected

View file

@ -0,0 +1,69 @@
import os
from unittest.mock import patch
from cognee.infrastructure.databases.relational.config import RelationalConfig
class TestRelationalConfig:
    """Test suite for RelationalConfig DATABASE_CONNECT_ARGS parsing.

    Each test sets (or clears) the DATABASE_CONNECT_ARGS environment variable
    with ``patch.dict`` and checks the resulting ``database_connect_args``
    attribute of a freshly constructed RelationalConfig.
    """

    def test_database_connect_args_valid_json_dict(self):
        """A JSON object in DATABASE_CONNECT_ARGS is parsed into a dict."""
        with patch.dict(
            os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require"}'}
        ):
            config = RelationalConfig()
            assert config.database_connect_args == {"timeout": 60, "sslmode": "require"}

    def test_database_connect_args_empty_string(self):
        """An empty DATABASE_CONNECT_ARGS is left untouched as an empty string."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": ""}):
            config = RelationalConfig()
            # Parsing is only attempted for truthy values, so "" is kept as-is.
            assert config.database_connect_args == ""

    def test_database_connect_args_not_set(self):
        """A missing DATABASE_CONNECT_ARGS results in None."""
        with patch.dict(os.environ, {}, clear=True):
            # RelationalConfig also loads ".env"; clearing os.environ alone does
            # not stop a local .env that defines DATABASE_CONNECT_ARGS from
            # leaking into this test, so dotenv loading is disabled explicitly.
            config = RelationalConfig(_env_file=None)
            assert config.database_connect_args is None

    def test_database_connect_args_invalid_json(self):
        """Malformed JSON in DATABASE_CONNECT_ARGS results in an empty dict."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}):  # Invalid JSON
            config = RelationalConfig()
            assert config.database_connect_args == {}

    def test_database_connect_args_non_dict_json(self):
        """Valid JSON that is not an object results in an empty dict."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}):
            config = RelationalConfig()
            assert config.database_connect_args == {}

    def test_database_connect_args_to_dict(self):
        """The parsed database_connect_args is included in to_dict() output."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
            config = RelationalConfig()
            config_dict = config.to_dict()
            assert "database_connect_args" in config_dict
            assert config_dict["database_connect_args"] == {"timeout": 60}

    def test_database_connect_args_integer_value(self):
        """Integer values inside DATABASE_CONNECT_ARGS stay integers after parsing."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"connect_timeout": 10}'}):
            config = RelationalConfig()
            assert config.database_connect_args == {"connect_timeout": 10}

    def test_database_connect_args_mixed_types(self):
        """Mixed value types (int, str, bool) are each parsed to their Python type."""
        with patch.dict(
            os.environ,
            {
                "DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}'
            },
        ):
            config = RelationalConfig()
            assert config.database_connect_args == {
                "timeout": 60,
                "sslmode": "require",
                "retries": 3,
                "keepalive": True,
            }