feat(database): add connect_args support to SqlAlchemyAdapter

- Add optional connect_args parameter to __init__ method
- Support DATABASE_CONNECT_ARGS environment variable for JSON-based configuration
- Enable custom connection parameters for all database engines (SQLite and PostgreSQL)
- Maintain backward compatibility with existing code
- Add proper error handling and validation for environment variable parsing

Signed-off-by: ketanjain7981 <ketan.jain@think41.com>
This commit is contained in:
ketanjain7981 2025-12-02 20:23:42 +05:30
parent c17f838034
commit f9b16e508d
2 changed files with 28 additions and 2 deletions

View file

@ -91,6 +91,15 @@ DB_NAME=cognee_db
#DB_USERNAME=cognee #DB_USERNAME=cognee
#DB_PASSWORD=cognee #DB_PASSWORD=cognee
# -- Advanced: Custom database connection arguments (optional) ---------------
# Pass additional connection parameters as JSON. Useful for SSL, timeouts, etc.
# Examples:
# For PostgreSQL with SSL:
# DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'
# For SQLite with custom timeout:
# DATABASE_CONNECT_ARGS='{"timeout": 60}'
#DATABASE_CONNECT_ARGS='{}'
################################################################################ ################################################################################
# 🕸️ Graph Database settings # 🕸️ Graph Database settings
################################################################################ ################################################################################

View file

@ -3,6 +3,7 @@ import asyncio
from os import path from os import path
import tempfile import tempfile
from uuid import UUID from uuid import UUID
import json
from typing import Optional from typing import Optional
from typing import AsyncGenerator, List from typing import AsyncGenerator, List
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -29,10 +30,25 @@ class SQLAlchemyAdapter:
functions. functions.
""" """
def __init__(self, connection_string: str): def __init__(self, connection_string: str, connect_args: Optional[dict] = None):
self.db_path: str = None self.db_path: str = None
self.db_uri: str = connection_string self.db_uri: str = connection_string
env_connect_args = os.getenv("DATABASE_CONNECT_ARGS")
if env_connect_args:
try:
env_connect_args = json.loads(env_connect_args)
if isinstance(env_connect_args, dict):
if connect_args is None:
connect_args = {}
connect_args.update(env_connect_args)
else:
logger.warning(
f"DATABASE_CONNECT_ARGS is not a valid JSON dictionary: {env_connect_args}"
)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse DATABASE_CONNECT_ARGS as JSON: {e}")
if "sqlite" in connection_string: if "sqlite" in connection_string:
[prefix, db_path] = connection_string.split("///") [prefix, db_path] = connection_string.split("///")
self.db_path = db_path self.db_path = db_path
@ -53,7 +69,7 @@ class SQLAlchemyAdapter:
self.engine = create_async_engine( self.engine = create_async_engine(
connection_string, connection_string,
poolclass=NullPool, poolclass=NullPool,
connect_args={"timeout": 30}, connect_args={**(connect_args or {}), **{"timeout": 30}},
) )
else: else:
self.engine = create_async_engine( self.engine = create_async_engine(
@ -63,6 +79,7 @@ class SQLAlchemyAdapter:
pool_recycle=280, pool_recycle=280,
pool_pre_ping=True, pool_pre_ping=True,
pool_timeout=280, pool_timeout=280,
connect_args=connect_args or {},
) )
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)