diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index a23c1b297..483228bfb 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -31,23 +31,39 @@ class SQLAlchemyAdapter: """ def __init__(self, connection_string: str, connect_args: Optional[dict] = None): + """ + Initialize the SQLAlchemy adapter with connection settings. + + Parameters: + ----------- + connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' + or 'postgresql://user:pass@host:port/db'). + connect_args (Optional[dict]): Optional dictionary of connection arguments to pass to + the database engine. These are driver-specific parameters such as SSL settings, + timeouts, or connection pool options. If DATABASE_CONNECT_ARGS environment variable + is set, those values will be merged with this parameter (programmatic values take + precedence over environment variables). Defaults to None. + + Environment Variables: + ---------------------- + DATABASE_CONNECT_ARGS: Optional JSON string containing connection arguments. + Example: '{"sslmode": "require", "connect_timeout": 10}' + """ self.db_path: str = None self.db_uri: str = connection_string env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") if env_connect_args: try: - env_connect_args = json.loads(env_connect_args) - if isinstance(env_connect_args, dict): - if connect_args is None: - connect_args = {} - connect_args.update(env_connect_args) + parsed_env_args = json.loads(env_connect_args) + if isinstance(parsed_env_args, dict): + # Merge: env vars as base, programmatic args override + merged_args = {**parsed_env_args, **(connect_args or {})} + connect_args = merged_args else: - logger.warning( - f"DATABASE_CONNECT_ARGS is not a valid JSON dictionary: {env_connect_args}" - ) - except json.JSONDecodeError as e: - logger.warning(f"Failed to parse DATABASE_CONNECT_ARGS as JSON: {e}") + logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring") + except json.JSONDecodeError: + logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring") if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") @@ -69,7 +85,7 @@ class SQLAlchemyAdapter: self.engine = create_async_engine( connection_string, poolclass=NullPool, - connect_args={**(connect_args or {}), **{"timeout": 30}}, + connect_args={**{"timeout": 30}, **(connect_args or {})}, ) else: self.engine = create_async_engine(