move DATABASE_CONNECT_ARGS parsing to RelationalConfig

Signed-off-by: ketanjain7981 <ketan.jain@think41.com>
This commit is contained in:
ketanjain7981 2025-12-09 10:14:46 +05:30
parent 654a573454
commit e1d313a46b
5 changed files with 93 additions and 149 deletions

View file

@ -1,4 +1,5 @@
import os
import json
import pydantic
from typing import Union
from functools import lru_cache
@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings):
db_username: Union[str, None] = None # "cognee"
db_password: Union[str, None] = None # "cognee"
db_provider: str = "sqlite"
database_connect_args: Union[str, None] = None
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@ -30,6 +32,17 @@ class RelationalConfig(BaseSettings):
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
self.db_path = databases_directory_path
# Parse database_connect_args if provided as JSON string
if self.database_connect_args and isinstance(self.database_connect_args, str):
try:
parsed_args = json.loads(self.database_connect_args)
if isinstance(parsed_args, dict):
self.database_connect_args = parsed_args
else:
self.database_connect_args = {}
except json.JSONDecodeError:
self.database_connect_args = {}
return self
def to_dict(self) -> dict:
@ -40,7 +53,8 @@ class RelationalConfig(BaseSettings):
--------
- dict: A dictionary containing database configuration settings including db_path,
db_name, db_host, db_port, db_username, db_password, and db_provider.
db_name, db_host, db_port, db_username, db_password, db_provider, and
database_connect_args.
"""
return {
"db_path": self.db_path,
@ -50,6 +64,7 @@ class RelationalConfig(BaseSettings):
"db_username": self.db_username,
"db_password": self.db_password,
"db_provider": self.db_provider,
"database_connect_args": self.database_connect_args,
}

View file

@ -11,6 +11,7 @@ def create_relational_engine(
db_username: str,
db_password: str,
db_provider: str,
database_connect_args: dict = None,
):
"""
Create a relational database engine based on the specified parameters.
@ -29,6 +30,7 @@ def create_relational_engine(
- db_password (str): The password for database authentication, required for
PostgreSQL.
- db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres').
- database_connect_args (dict, optional): Database driver connection arguments.
Returns:
--------
@ -51,4 +53,4 @@ def create_relational_engine(
"PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality."
)
return SQLAlchemyAdapter(connection_string)
return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args)

View file

@ -3,7 +3,6 @@ import asyncio
from os import path
import tempfile
from uuid import UUID
import json
from typing import Optional
from typing import AsyncGenerator, List
from contextlib import asynccontextmanager
@ -38,14 +37,9 @@ class SQLAlchemyAdapter:
-----------
connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
or 'postgresql://user:pass@host:port/db').
connect_args (dict, optional): Database driver arguments. These take precedence over
DATABASE_CONNECT_ARGS environment variable.
Environment Variables:
----------------------
DATABASE_CONNECT_ARGS: Optional JSON string containing connection arguments.
Allows configuration of driver-specific parameters such as SSL settings,
timeouts, or connection pool options without code changes.
connect_args (dict, optional): Database driver connection arguments.
Configuration is loaded from RelationalConfig.database_connect_args, which reads
from the DATABASE_CONNECT_ARGS environment variable.
Examples:
PostgreSQL with SSL:
@ -53,28 +47,12 @@ class SQLAlchemyAdapter:
SQLite with custom timeout:
DATABASE_CONNECT_ARGS='{"timeout": 60}'
Note: This follows cognee's environment-based configuration pattern and is
the recommended approach for production deployments.
"""
self.db_path: str = None
self.db_uri: str = connection_string
# Parse optional connection arguments from environment variable
env_connect_args_dict = {}
env_connect_args = os.getenv("DATABASE_CONNECT_ARGS")
if env_connect_args:
try:
parsed_args = json.loads(env_connect_args)
if isinstance(parsed_args, dict):
env_connect_args_dict = parsed_args
else:
logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring")
except json.JSONDecodeError:
logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring")
# Merge environment args with explicit args (explicit args take precedence)
final_connect_args = {**env_connect_args_dict, **(connect_args or {})}
# Use provided connect_args (already parsed from config)
final_connect_args = connect_args or {}
if "sqlite" in connection_string:
[prefix, db_path] = connection_string.split("///")

View file

@ -1,120 +0,0 @@
import os
from unittest.mock import patch
from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import (
SQLAlchemyAdapter,
)
class TestSqlAlchemyAdapter:
"""Test suite for SqlAlchemyAdapter environment variable handling and connection arguments."""
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
def test_sqlite_default_timeout(self, mock_create_engine):
"""Test that SQLite connection uses default timeout when no env var is set."""
with patch.dict(os.environ, {}, clear=True):
SQLAlchemyAdapter("sqlite:///test.db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 30}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
def test_sqlite_with_env_var_timeout(self, mock_create_engine):
"""Test that SQLite connection uses timeout from env var."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
SQLAlchemyAdapter("sqlite:///test.db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 60}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
def test_sqlite_with_other_env_var_args(self, mock_create_engine):
"""Test that SQLite connection merges default timeout with other args from env var."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"foo": "bar"}'}):
SQLAlchemyAdapter("sqlite:///test.db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
@patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger")
def test_sqlite_with_invalid_json_env_var(self, mock_logger, mock_create_engine):
"""Test that SQLite connection uses default timeout when env var has invalid JSON."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON
SQLAlchemyAdapter("sqlite:///test.db")
mock_logger.warning.assert_called_with(
"Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring"
)
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 30}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
@patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger")
def test_sqlite_with_non_dict_json_env_var(self, mock_logger, mock_create_engine):
"""Test that SQLite connection uses default timeout when env var is valid JSON but not a dict."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}):
SQLAlchemyAdapter("sqlite:///test.db")
mock_logger.warning.assert_called_with(
"DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring"
)
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"timeout": 30}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
def test_postgresql_with_env_var(self, mock_create_engine):
"""Test that PostgreSQL connection uses connect_args from env var."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"sslmode": "require"}'}):
SQLAlchemyAdapter("postgresql://user:pass@host/db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {"sslmode": "require"}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
def test_postgresql_without_env_var(self, mock_create_engine):
"""Test that PostgreSQL connection has empty connect_args when no env var is set."""
with patch.dict(os.environ, {}, clear=True):
SQLAlchemyAdapter("postgresql://user:pass@host/db")
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
assert kwargs["connect_args"] == {}
@patch(
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
)
def test_connect_args_precedence(self, mock_create_engine):
"""Test that explicit connect_args take precedence over env var args."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
# Pass explicit connect_args that should override env var
SQLAlchemyAdapter("sqlite:///test.db", connect_args={"timeout": 120})
mock_create_engine.assert_called_once()
_, kwargs = mock_create_engine.call_args
assert "connect_args" in kwargs
# timeout should be 120 (explicit), not 60 (env var) or 30 (default)
assert kwargs["connect_args"] == {"timeout": 120}

View file

@ -0,0 +1,69 @@
import os
from unittest.mock import patch
from cognee.infrastructure.databases.relational.config import RelationalConfig
class TestRelationalConfig:
"""Test suite for RelationalConfig DATABASE_CONNECT_ARGS parsing."""
def test_database_connect_args_valid_json_dict(self):
"""Test that DATABASE_CONNECT_ARGS is parsed correctly when it's a valid JSON dict."""
with patch.dict(
os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require"}'}
):
config = RelationalConfig()
assert config.database_connect_args == {"timeout": 60, "sslmode": "require"}
def test_database_connect_args_empty_string(self):
"""Test that empty DATABASE_CONNECT_ARGS is handled correctly."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": ""}):
config = RelationalConfig()
assert config.database_connect_args == ""
def test_database_connect_args_not_set(self):
"""Test that missing DATABASE_CONNECT_ARGS results in None."""
with patch.dict(os.environ, {}, clear=True):
config = RelationalConfig()
assert config.database_connect_args is None
def test_database_connect_args_invalid_json(self):
"""Test that invalid JSON in DATABASE_CONNECT_ARGS results in empty dict."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON
config = RelationalConfig()
assert config.database_connect_args == {}
def test_database_connect_args_non_dict_json(self):
"""Test that non-dict JSON in DATABASE_CONNECT_ARGS results in empty dict."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}):
config = RelationalConfig()
assert config.database_connect_args == {}
def test_database_connect_args_to_dict(self):
"""Test that database_connect_args is included in to_dict() output."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
config = RelationalConfig()
config_dict = config.to_dict()
assert "database_connect_args" in config_dict
assert config_dict["database_connect_args"] == {"timeout": 60}
def test_database_connect_args_integer_value(self):
"""Test that DATABASE_CONNECT_ARGS with integer values is parsed correctly."""
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"connect_timeout": 10}'}):
config = RelationalConfig()
assert config.database_connect_args == {"connect_timeout": 10}
def test_database_connect_args_mixed_types(self):
"""Test that DATABASE_CONNECT_ARGS with mixed value types is parsed correctly."""
with patch.dict(
os.environ,
{
"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}'
},
):
config = RelationalConfig()
assert config.database_connect_args == {
"timeout": 60,
"sslmode": "require",
"retries": 3,
"keepalive": True,
}