feat(database): add connect_args support to SqlAlchemyAdapter

- Add optional connect_args parameter to __init__ method
- Support DATABASE_CONNECT_ARGS environment variable for JSON-based configuration
- Enable custom connection parameters for all database engines (SQLite and PostgreSQL)
- Maintain backward compatibility with existing code
- Add proper error handling and validation for environment variable parsing

Signed-off-by: ketanjain7981 <ketan.jain@think41.com>
This commit is contained in:
ketanjain7981 2025-12-02 20:23:42 +05:30
parent c17f838034
commit f9b16e508d
2 changed files with 28 additions and 2 deletions

View file

@ -91,6 +91,15 @@ DB_NAME=cognee_db
#DB_USERNAME=cognee
#DB_PASSWORD=cognee
# -- Advanced: Custom database connection arguments (optional) ---------------
# Pass additional connection parameters as JSON. Useful for SSL, timeouts, etc.
# Examples:
# For PostgreSQL with SSL:
# DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'
# For SQLite with custom timeout:
# DATABASE_CONNECT_ARGS='{"timeout": 60}'
#DATABASE_CONNECT_ARGS='{}'
################################################################################
# 🕸️ Graph Database settings
################################################################################

View file

@ -3,6 +3,7 @@ import asyncio
from os import path
import tempfile
from uuid import UUID
import json
from typing import Optional
from typing import AsyncGenerator, List
from contextlib import asynccontextmanager
@ -29,10 +30,25 @@ class SQLAlchemyAdapter:
functions.
"""
def __init__(self, connection_string: str):
def __init__(self, connection_string: str, connect_args: Optional[dict] = None):
self.db_path: str = None
self.db_uri: str = connection_string
env_connect_args = os.getenv("DATABASE_CONNECT_ARGS")
if env_connect_args:
try:
env_connect_args = json.loads(env_connect_args)
if isinstance(env_connect_args, dict):
if connect_args is None:
connect_args = {}
connect_args.update(env_connect_args)
else:
logger.warning(
f"DATABASE_CONNECT_ARGS is not a valid JSON dictionary: {env_connect_args}"
)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse DATABASE_CONNECT_ARGS as JSON: {e}")
if "sqlite" in connection_string:
[prefix, db_path] = connection_string.split("///")
self.db_path = db_path
@ -53,7 +69,7 @@ class SQLAlchemyAdapter:
self.engine = create_async_engine(
connection_string,
poolclass=NullPool,
connect_args={"timeout": 30},
connect_args={**(connect_args or {}), **{"timeout": 30}},
)
else:
self.engine = create_async_engine(
@ -63,6 +79,7 @@ class SQLAlchemyAdapter:
pool_recycle=280,
pool_pre_ping=True,
pool_timeout=280,
connect_args=connect_args or {},
)
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)