move DATABASE_CONNECT_ARGS parsing to RelationalConfig
Signed-off-by: ketanjain7981 <ketan.jain@think41.com>
This commit is contained in:
parent
654a573454
commit
e1d313a46b
5 changed files with 93 additions and 149 deletions
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import json
|
||||
import pydantic
|
||||
from typing import Union
|
||||
from functools import lru_cache
|
||||
|
|
@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings):
|
|||
db_username: Union[str, None] = None # "cognee"
|
||||
db_password: Union[str, None] = None # "cognee"
|
||||
db_provider: str = "sqlite"
|
||||
database_connect_args: Union[str, None] = None
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
|
|
@ -30,6 +32,17 @@ class RelationalConfig(BaseSettings):
|
|||
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
||||
self.db_path = databases_directory_path
|
||||
|
||||
# Parse database_connect_args if provided as JSON string
|
||||
if self.database_connect_args and isinstance(self.database_connect_args, str):
|
||||
try:
|
||||
parsed_args = json.loads(self.database_connect_args)
|
||||
if isinstance(parsed_args, dict):
|
||||
self.database_connect_args = parsed_args
|
||||
else:
|
||||
self.database_connect_args = {}
|
||||
except json.JSONDecodeError:
|
||||
self.database_connect_args = {}
|
||||
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
|
|
@ -40,7 +53,8 @@ class RelationalConfig(BaseSettings):
|
|||
--------
|
||||
|
||||
- dict: A dictionary containing database configuration settings including db_path,
|
||||
db_name, db_host, db_port, db_username, db_password, and db_provider.
|
||||
db_name, db_host, db_port, db_username, db_password, db_provider, and
|
||||
database_connect_args.
|
||||
"""
|
||||
return {
|
||||
"db_path": self.db_path,
|
||||
|
|
@ -50,6 +64,7 @@ class RelationalConfig(BaseSettings):
|
|||
"db_username": self.db_username,
|
||||
"db_password": self.db_password,
|
||||
"db_provider": self.db_provider,
|
||||
"database_connect_args": self.database_connect_args,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ def create_relational_engine(
|
|||
db_username: str,
|
||||
db_password: str,
|
||||
db_provider: str,
|
||||
database_connect_args: dict = None,
|
||||
):
|
||||
"""
|
||||
Create a relational database engine based on the specified parameters.
|
||||
|
|
@ -29,6 +30,7 @@ def create_relational_engine(
|
|||
- db_password (str): The password for database authentication, required for
|
||||
PostgreSQL.
|
||||
- db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres').
|
||||
- database_connect_args (dict, optional): Database driver connection arguments.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
|
|
@ -51,4 +53,4 @@ def create_relational_engine(
|
|||
"PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality."
|
||||
)
|
||||
|
||||
return SQLAlchemyAdapter(connection_string)
|
||||
return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import asyncio
|
|||
from os import path
|
||||
import tempfile
|
||||
from uuid import UUID
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import AsyncGenerator, List
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
@ -38,14 +37,9 @@ class SQLAlchemyAdapter:
|
|||
-----------
|
||||
connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
|
||||
or 'postgresql://user:pass@host:port/db').
|
||||
connect_args (dict, optional): Database driver arguments. These take precedence over
|
||||
DATABASE_CONNECT_ARGS environment variable.
|
||||
|
||||
Environment Variables:
|
||||
----------------------
|
||||
DATABASE_CONNECT_ARGS: Optional JSON string containing connection arguments.
|
||||
Allows configuration of driver-specific parameters such as SSL settings,
|
||||
timeouts, or connection pool options without code changes.
|
||||
connect_args (dict, optional): Database driver connection arguments.
|
||||
Configuration is loaded from RelationalConfig.database_connect_args, which reads
|
||||
from the DATABASE_CONNECT_ARGS environment variable.
|
||||
|
||||
Examples:
|
||||
PostgreSQL with SSL:
|
||||
|
|
@ -53,28 +47,12 @@ class SQLAlchemyAdapter:
|
|||
|
||||
SQLite with custom timeout:
|
||||
DATABASE_CONNECT_ARGS='{"timeout": 60}'
|
||||
|
||||
Note: This follows cognee's environment-based configuration pattern and is
|
||||
the recommended approach for production deployments.
|
||||
"""
|
||||
self.db_path: str = None
|
||||
self.db_uri: str = connection_string
|
||||
|
||||
# Parse optional connection arguments from environment variable
|
||||
env_connect_args_dict = {}
|
||||
env_connect_args = os.getenv("DATABASE_CONNECT_ARGS")
|
||||
if env_connect_args:
|
||||
try:
|
||||
parsed_args = json.loads(env_connect_args)
|
||||
if isinstance(parsed_args, dict):
|
||||
env_connect_args_dict = parsed_args
|
||||
else:
|
||||
logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring")
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring")
|
||||
|
||||
# Merge environment args with explicit args (explicit args take precedence)
|
||||
final_connect_args = {**env_connect_args_dict, **(connect_args or {})}
|
||||
# Use provided connect_args (already parsed from config)
|
||||
final_connect_args = connect_args or {}
|
||||
|
||||
if "sqlite" in connection_string:
|
||||
[prefix, db_path] = connection_string.split("///")
|
||||
|
|
|
|||
|
|
@ -1,120 +0,0 @@
|
|||
import os
|
||||
from unittest.mock import patch
|
||||
from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import (
|
||||
SQLAlchemyAdapter,
|
||||
)
|
||||
|
||||
|
||||
class TestSqlAlchemyAdapter:
    """Checks which ``connect_args`` SQLAlchemyAdapter hands to the engine factory.

    Each test patches ``create_async_engine`` in the adapter module, builds an
    adapter under a controlled ``os.environ``, and inspects the keyword arguments
    recorded by the patched factory.
    """

    # Patch targets, kept in one place because every test decorates with them.
    _ENGINE_FACTORY = (
        "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
    )
    _ADAPTER_LOGGER = "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger"

    @staticmethod
    def _captured_connect_args(engine_factory):
        """Return the ``connect_args`` keyword from the factory's single recorded call."""
        engine_factory.assert_called_once()
        call_kwargs = engine_factory.call_args[1]
        assert "connect_args" in call_kwargs
        return call_kwargs["connect_args"]

    @patch(_ENGINE_FACTORY)
    def test_sqlite_default_timeout(self, mock_create_engine):
        """SQLite with an empty environment is expected to get only ``timeout=30``."""
        with patch.dict(os.environ, {}, clear=True):
            SQLAlchemyAdapter("sqlite:///test.db")

        assert self._captured_connect_args(mock_create_engine) == {"timeout": 30}

    @patch(_ENGINE_FACTORY)
    def test_sqlite_with_env_var_timeout(self, mock_create_engine):
        """A timeout supplied through the env var is expected to replace the default."""
        env = {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}
        with patch.dict(os.environ, env):
            SQLAlchemyAdapter("sqlite:///test.db")

        assert self._captured_connect_args(mock_create_engine) == {"timeout": 60}

    @patch(_ENGINE_FACTORY)
    def test_sqlite_with_other_env_var_args(self, mock_create_engine):
        """Unrelated env var keys are expected to be merged next to the default timeout."""
        env = {"DATABASE_CONNECT_ARGS": '{"foo": "bar"}'}
        with patch.dict(os.environ, env):
            SQLAlchemyAdapter("sqlite:///test.db")

        assert self._captured_connect_args(mock_create_engine) == {"timeout": 30, "foo": "bar"}

    @patch(_ENGINE_FACTORY)
    @patch(_ADAPTER_LOGGER)
    def test_sqlite_with_invalid_json_env_var(self, mock_logger, mock_create_engine):
        """Malformed JSON is expected to trigger a warning and leave the default timeout."""
        env = {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}  # Invalid JSON
        with patch.dict(os.environ, env):
            SQLAlchemyAdapter("sqlite:///test.db")

        # The warning is checked first, matching the order of the engine-call checks below.
        mock_logger.warning.assert_called_with(
            "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring"
        )
        assert self._captured_connect_args(mock_create_engine) == {"timeout": 30}

    @patch(_ENGINE_FACTORY)
    @patch(_ADAPTER_LOGGER)
    def test_sqlite_with_non_dict_json_env_var(self, mock_logger, mock_create_engine):
        """Valid JSON that is not an object is expected to be ignored with a warning."""
        env = {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}
        with patch.dict(os.environ, env):
            SQLAlchemyAdapter("sqlite:///test.db")

        mock_logger.warning.assert_called_with(
            "DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring"
        )
        assert self._captured_connect_args(mock_create_engine) == {"timeout": 30}

    @patch(_ENGINE_FACTORY)
    def test_postgresql_with_env_var(self, mock_create_engine):
        """PostgreSQL is expected to receive exactly the arguments from the env var."""
        env = {"DATABASE_CONNECT_ARGS": '{"sslmode": "require"}'}
        with patch.dict(os.environ, env):
            SQLAlchemyAdapter("postgresql://user:pass@host/db")

        assert self._captured_connect_args(mock_create_engine) == {"sslmode": "require"}

    @patch(_ENGINE_FACTORY)
    def test_postgresql_without_env_var(self, mock_create_engine):
        """PostgreSQL with an empty environment is expected to get empty ``connect_args``."""
        with patch.dict(os.environ, {}, clear=True):
            SQLAlchemyAdapter("postgresql://user:pass@host/db")

        assert self._captured_connect_args(mock_create_engine) == {}

    @patch(_ENGINE_FACTORY)
    def test_connect_args_precedence(self, mock_create_engine):
        """Explicit ``connect_args`` are expected to win over the env var value."""
        env = {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}
        with patch.dict(os.environ, env):
            # Pass explicit connect_args that should override env var
            SQLAlchemyAdapter("sqlite:///test.db", connect_args={"timeout": 120})

        # timeout should be 120 (explicit), not 60 (env var) or 30 (default)
        assert self._captured_connect_args(mock_create_engine) == {"timeout": 120}
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
import os
|
||||
from unittest.mock import patch
|
||||
from cognee.infrastructure.databases.relational.config import RelationalConfig
|
||||
|
||||
|
||||
class TestRelationalConfig:
    """Test suite for RelationalConfig DATABASE_CONNECT_ARGS parsing.

    Each test sets (or clears) the DATABASE_CONNECT_ARGS environment variable,
    instantiates RelationalConfig, and checks the resulting
    ``database_connect_args`` value.
    """

    def test_database_connect_args_valid_json_dict(self):
        """A JSON object in DATABASE_CONNECT_ARGS is parsed into a dict."""
        with patch.dict(
            os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require"}'}
        ):
            config = RelationalConfig()
            assert config.database_connect_args == {"timeout": 60, "sslmode": "require"}

    def test_database_connect_args_empty_string(self):
        """An empty DATABASE_CONNECT_ARGS is kept as an empty string (no JSON parsing)."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": ""}):
            config = RelationalConfig()
            assert config.database_connect_args == ""

    def test_database_connect_args_not_set(self):
        """A missing DATABASE_CONNECT_ARGS results in None."""
        with patch.dict(os.environ, {}, clear=True):
            # Clearing os.environ is not enough: the settings class also reads a
            # ".env" file, so disable dotenv loading to keep this test independent
            # of whatever a developer has in their working directory.
            config = RelationalConfig(_env_file=None)
            assert config.database_connect_args is None

    def test_database_connect_args_invalid_json(self):
        """Malformed JSON in DATABASE_CONNECT_ARGS results in an empty dict."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}):  # Invalid JSON
            config = RelationalConfig()
            assert config.database_connect_args == {}

    def test_database_connect_args_non_dict_json(self):
        """Valid JSON that is not an object results in an empty dict."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}):
            config = RelationalConfig()
            assert config.database_connect_args == {}

    def test_database_connect_args_to_dict(self):
        """The parsed database_connect_args is included in the to_dict() output."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
            config = RelationalConfig()
            config_dict = config.to_dict()
            assert "database_connect_args" in config_dict
            assert config_dict["database_connect_args"] == {"timeout": 60}

    def test_database_connect_args_integer_value(self):
        """Integer values in DATABASE_CONNECT_ARGS stay integers after parsing."""
        with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"connect_timeout": 10}'}):
            config = RelationalConfig()
            assert config.database_connect_args == {"connect_timeout": 10}

    def test_database_connect_args_mixed_types(self):
        """String, integer and boolean JSON values map to their Python equivalents."""
        with patch.dict(
            os.environ,
            {
                "DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}'
            },
        ):
            config = RelationalConfig()
            assert config.database_connect_args == {
                "timeout": 60,
                "sslmode": "require",
                "retries": 3,
                "keepalive": True,
            }
|
||||
Loading…
Add table
Reference in a new issue