move DATABASE_CONNECT_ARGS parsing to RelationalConfig
Signed-off-by: ketanjain7981 <ketan.jain@think41.com>
This commit is contained in:
parent
654a573454
commit
e1d313a46b
5 changed files with 93 additions and 149 deletions
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
import pydantic
|
import pydantic
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings):
|
||||||
db_username: Union[str, None] = None # "cognee"
|
db_username: Union[str, None] = None # "cognee"
|
||||||
db_password: Union[str, None] = None # "cognee"
|
db_password: Union[str, None] = None # "cognee"
|
||||||
db_provider: str = "sqlite"
|
db_provider: str = "sqlite"
|
||||||
|
database_connect_args: Union[str, None] = None
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
|
|
@ -30,6 +32,17 @@ class RelationalConfig(BaseSettings):
|
||||||
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
||||||
self.db_path = databases_directory_path
|
self.db_path = databases_directory_path
|
||||||
|
|
||||||
|
# Parse database_connect_args if provided as JSON string
|
||||||
|
if self.database_connect_args and isinstance(self.database_connect_args, str):
|
||||||
|
try:
|
||||||
|
parsed_args = json.loads(self.database_connect_args)
|
||||||
|
if isinstance(parsed_args, dict):
|
||||||
|
self.database_connect_args = parsed_args
|
||||||
|
else:
|
||||||
|
self.database_connect_args = {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
self.database_connect_args = {}
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
|
|
@ -40,7 +53,8 @@ class RelationalConfig(BaseSettings):
|
||||||
--------
|
--------
|
||||||
|
|
||||||
- dict: A dictionary containing database configuration settings including db_path,
|
- dict: A dictionary containing database configuration settings including db_path,
|
||||||
db_name, db_host, db_port, db_username, db_password, and db_provider.
|
db_name, db_host, db_port, db_username, db_password, db_provider, and
|
||||||
|
database_connect_args.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"db_path": self.db_path,
|
"db_path": self.db_path,
|
||||||
|
|
@ -50,6 +64,7 @@ class RelationalConfig(BaseSettings):
|
||||||
"db_username": self.db_username,
|
"db_username": self.db_username,
|
||||||
"db_password": self.db_password,
|
"db_password": self.db_password,
|
||||||
"db_provider": self.db_provider,
|
"db_provider": self.db_provider,
|
||||||
|
"database_connect_args": self.database_connect_args,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ def create_relational_engine(
|
||||||
db_username: str,
|
db_username: str,
|
||||||
db_password: str,
|
db_password: str,
|
||||||
db_provider: str,
|
db_provider: str,
|
||||||
|
database_connect_args: dict = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a relational database engine based on the specified parameters.
|
Create a relational database engine based on the specified parameters.
|
||||||
|
|
@ -29,6 +30,7 @@ def create_relational_engine(
|
||||||
- db_password (str): The password for database authentication, required for
|
- db_password (str): The password for database authentication, required for
|
||||||
PostgreSQL.
|
PostgreSQL.
|
||||||
- db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres').
|
- db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres').
|
||||||
|
- database_connect_args (dict, optional): Database driver connection arguments.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
|
@ -51,4 +53,4 @@ def create_relational_engine(
|
||||||
"PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality."
|
"PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality."
|
||||||
)
|
)
|
||||||
|
|
||||||
return SQLAlchemyAdapter(connection_string)
|
return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ import asyncio
|
||||||
from os import path
|
from os import path
|
||||||
import tempfile
|
import tempfile
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
import json
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import AsyncGenerator, List
|
from typing import AsyncGenerator, List
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
@ -38,14 +37,9 @@ class SQLAlchemyAdapter:
|
||||||
-----------
|
-----------
|
||||||
connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
|
connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
|
||||||
or 'postgresql://user:pass@host:port/db').
|
or 'postgresql://user:pass@host:port/db').
|
||||||
connect_args (dict, optional): Database driver arguments. These take precedence over
|
connect_args (dict, optional): Database driver connection arguments.
|
||||||
DATABASE_CONNECT_ARGS environment variable.
|
Configuration is loaded from RelationalConfig.database_connect_args, which reads
|
||||||
|
from the DATABASE_CONNECT_ARGS environment variable.
|
||||||
Environment Variables:
|
|
||||||
----------------------
|
|
||||||
DATABASE_CONNECT_ARGS: Optional JSON string containing connection arguments.
|
|
||||||
Allows configuration of driver-specific parameters such as SSL settings,
|
|
||||||
timeouts, or connection pool options without code changes.
|
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
PostgreSQL with SSL:
|
PostgreSQL with SSL:
|
||||||
|
|
@ -53,28 +47,12 @@ class SQLAlchemyAdapter:
|
||||||
|
|
||||||
SQLite with custom timeout:
|
SQLite with custom timeout:
|
||||||
DATABASE_CONNECT_ARGS='{"timeout": 60}'
|
DATABASE_CONNECT_ARGS='{"timeout": 60}'
|
||||||
|
|
||||||
Note: This follows cognee's environment-based configuration pattern and is
|
|
||||||
the recommended approach for production deployments.
|
|
||||||
"""
|
"""
|
||||||
self.db_path: str = None
|
self.db_path: str = None
|
||||||
self.db_uri: str = connection_string
|
self.db_uri: str = connection_string
|
||||||
|
|
||||||
# Parse optional connection arguments from environment variable
|
# Use provided connect_args (already parsed from config)
|
||||||
env_connect_args_dict = {}
|
final_connect_args = connect_args or {}
|
||||||
env_connect_args = os.getenv("DATABASE_CONNECT_ARGS")
|
|
||||||
if env_connect_args:
|
|
||||||
try:
|
|
||||||
parsed_args = json.loads(env_connect_args)
|
|
||||||
if isinstance(parsed_args, dict):
|
|
||||||
env_connect_args_dict = parsed_args
|
|
||||||
else:
|
|
||||||
logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring")
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring")
|
|
||||||
|
|
||||||
# Merge environment args with explicit args (explicit args take precedence)
|
|
||||||
final_connect_args = {**env_connect_args_dict, **(connect_args or {})}
|
|
||||||
|
|
||||||
if "sqlite" in connection_string:
|
if "sqlite" in connection_string:
|
||||||
[prefix, db_path] = connection_string.split("///")
|
[prefix, db_path] = connection_string.split("///")
|
||||||
|
|
|
||||||
|
|
@ -1,120 +0,0 @@
|
||||||
import os
|
|
||||||
from unittest.mock import patch
|
|
||||||
from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import (
|
|
||||||
SQLAlchemyAdapter,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSqlAlchemyAdapter:
|
|
||||||
"""Test suite for SqlAlchemyAdapter environment variable handling and connection arguments."""
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
|
||||||
)
|
|
||||||
def test_sqlite_default_timeout(self, mock_create_engine):
|
|
||||||
"""Test that SQLite connection uses default timeout when no env var is set."""
|
|
||||||
with patch.dict(os.environ, {}, clear=True):
|
|
||||||
SQLAlchemyAdapter("sqlite:///test.db")
|
|
||||||
mock_create_engine.assert_called_once()
|
|
||||||
_, kwargs = mock_create_engine.call_args
|
|
||||||
assert "connect_args" in kwargs
|
|
||||||
assert kwargs["connect_args"] == {"timeout": 30}
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
|
||||||
)
|
|
||||||
def test_sqlite_with_env_var_timeout(self, mock_create_engine):
|
|
||||||
"""Test that SQLite connection uses timeout from env var."""
|
|
||||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
|
|
||||||
SQLAlchemyAdapter("sqlite:///test.db")
|
|
||||||
mock_create_engine.assert_called_once()
|
|
||||||
_, kwargs = mock_create_engine.call_args
|
|
||||||
assert "connect_args" in kwargs
|
|
||||||
assert kwargs["connect_args"] == {"timeout": 60}
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
|
||||||
)
|
|
||||||
def test_sqlite_with_other_env_var_args(self, mock_create_engine):
|
|
||||||
"""Test that SQLite connection merges default timeout with other args from env var."""
|
|
||||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"foo": "bar"}'}):
|
|
||||||
SQLAlchemyAdapter("sqlite:///test.db")
|
|
||||||
mock_create_engine.assert_called_once()
|
|
||||||
_, kwargs = mock_create_engine.call_args
|
|
||||||
assert "connect_args" in kwargs
|
|
||||||
assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"}
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
|
||||||
)
|
|
||||||
@patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger")
|
|
||||||
def test_sqlite_with_invalid_json_env_var(self, mock_logger, mock_create_engine):
|
|
||||||
"""Test that SQLite connection uses default timeout when env var has invalid JSON."""
|
|
||||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON
|
|
||||||
SQLAlchemyAdapter("sqlite:///test.db")
|
|
||||||
|
|
||||||
mock_logger.warning.assert_called_with(
|
|
||||||
"Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring"
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_create_engine.assert_called_once()
|
|
||||||
_, kwargs = mock_create_engine.call_args
|
|
||||||
assert "connect_args" in kwargs
|
|
||||||
assert kwargs["connect_args"] == {"timeout": 30}
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
|
||||||
)
|
|
||||||
@patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger")
|
|
||||||
def test_sqlite_with_non_dict_json_env_var(self, mock_logger, mock_create_engine):
|
|
||||||
"""Test that SQLite connection uses default timeout when env var is valid JSON but not a dict."""
|
|
||||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}):
|
|
||||||
SQLAlchemyAdapter("sqlite:///test.db")
|
|
||||||
|
|
||||||
mock_logger.warning.assert_called_with(
|
|
||||||
"DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring"
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_create_engine.assert_called_once()
|
|
||||||
_, kwargs = mock_create_engine.call_args
|
|
||||||
assert "connect_args" in kwargs
|
|
||||||
assert kwargs["connect_args"] == {"timeout": 30}
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
|
||||||
)
|
|
||||||
def test_postgresql_with_env_var(self, mock_create_engine):
|
|
||||||
"""Test that PostgreSQL connection uses connect_args from env var."""
|
|
||||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"sslmode": "require"}'}):
|
|
||||||
SQLAlchemyAdapter("postgresql://user:pass@host/db")
|
|
||||||
mock_create_engine.assert_called_once()
|
|
||||||
_, kwargs = mock_create_engine.call_args
|
|
||||||
assert "connect_args" in kwargs
|
|
||||||
assert kwargs["connect_args"] == {"sslmode": "require"}
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
|
||||||
)
|
|
||||||
def test_postgresql_without_env_var(self, mock_create_engine):
|
|
||||||
"""Test that PostgreSQL connection has empty connect_args when no env var is set."""
|
|
||||||
with patch.dict(os.environ, {}, clear=True):
|
|
||||||
SQLAlchemyAdapter("postgresql://user:pass@host/db")
|
|
||||||
mock_create_engine.assert_called_once()
|
|
||||||
_, kwargs = mock_create_engine.call_args
|
|
||||||
assert "connect_args" in kwargs
|
|
||||||
assert kwargs["connect_args"] == {}
|
|
||||||
|
|
||||||
@patch(
|
|
||||||
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine"
|
|
||||||
)
|
|
||||||
def test_connect_args_precedence(self, mock_create_engine):
|
|
||||||
"""Test that explicit connect_args take precedence over env var args."""
|
|
||||||
with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}):
|
|
||||||
# Pass explicit connect_args that should override env var
|
|
||||||
SQLAlchemyAdapter("sqlite:///test.db", connect_args={"timeout": 120})
|
|
||||||
|
|
||||||
mock_create_engine.assert_called_once()
|
|
||||||
_, kwargs = mock_create_engine.call_args
|
|
||||||
assert "connect_args" in kwargs
|
|
||||||
# timeout should be 120 (explicit), not 60 (env var) or 30 (default)
|
|
||||||
assert kwargs["connect_args"] == {"timeout": 120}
|
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
from cognee.infrastructure.databases.relational.config import RelationalConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestRelationalConfig:
    """Test suite for RelationalConfig DATABASE_CONNECT_ARGS parsing.

    RelationalConfig reads DATABASE_CONNECT_ARGS into ``database_connect_args``
    and, when the value is a non-empty string, replaces it with the decoded
    JSON object. Anything that is not a JSON object (invalid JSON, or valid
    JSON of another type) is replaced with an empty dict.
    """

    ENV_VAR = "DATABASE_CONNECT_ARGS"

    def _config_with(self, raw_value):
        """Return a RelationalConfig built while DATABASE_CONNECT_ARGS is ``raw_value``."""
        with patch.dict(os.environ, {self.ENV_VAR: raw_value}):
            return RelationalConfig()

    def test_database_connect_args_valid_json_dict(self):
        """Test that DATABASE_CONNECT_ARGS is parsed correctly when it's a valid JSON dict."""
        config = self._config_with('{"timeout": 60, "sslmode": "require"}')
        assert config.database_connect_args == {"timeout": 60, "sslmode": "require"}

    def test_database_connect_args_empty_string(self):
        """Test that empty DATABASE_CONNECT_ARGS is handled correctly."""
        # An empty string is falsy, so the parsing branch is skipped and the
        # raw value is kept as-is rather than being turned into a dict.
        config = self._config_with("")
        assert config.database_connect_args == ""

    def test_database_connect_args_not_set(self):
        """Test that missing DATABASE_CONNECT_ARGS results in None."""
        with patch.dict(os.environ, {}, clear=True):
            # RelationalConfig is declared with env_file=".env"; clearing
            # os.environ alone does not stop a local .env file from supplying
            # DATABASE_CONNECT_ARGS. _env_file=None disables dotenv loading for
            # this instance so the test only sees the (empty) environment.
            config = RelationalConfig(_env_file=None)
        assert config.database_connect_args is None

    def test_database_connect_args_invalid_json(self):
        """Test that invalid JSON in DATABASE_CONNECT_ARGS results in empty dict."""
        config = self._config_with('{"timeout": 60')  # Invalid JSON: missing closing brace
        assert config.database_connect_args == {}

    def test_database_connect_args_non_dict_json(self):
        """Test that non-dict JSON in DATABASE_CONNECT_ARGS results in empty dict."""
        config = self._config_with('["list", "instead", "of", "dict"]')
        assert config.database_connect_args == {}

    def test_database_connect_args_to_dict(self):
        """Test that database_connect_args is included in to_dict() output."""
        config_dict = self._config_with('{"timeout": 60}').to_dict()
        assert "database_connect_args" in config_dict
        assert config_dict["database_connect_args"] == {"timeout": 60}

    def test_database_connect_args_integer_value(self):
        """Test that DATABASE_CONNECT_ARGS with integer values is parsed correctly."""
        config = self._config_with('{"connect_timeout": 10}')
        assert config.database_connect_args == {"connect_timeout": 10}

    def test_database_connect_args_mixed_types(self):
        """Test that DATABASE_CONNECT_ARGS with mixed value types is parsed correctly."""
        config = self._config_with(
            '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}'
        )
        # JSON `true` must come back as the Python bool True, not a string.
        assert config.database_connect_args == {
            "timeout": 60,
            "sslmode": "require",
            "retries": 3,
            "keepalive": True,
        }
|
||||||
Loading…
Add table
Reference in a new issue