feat(database): add connect_args support to SqlAlchemyAdapter
- Add optional connect_args parameter to __init__ method - Support DATABASE_CONNECT_ARGS environment variable for JSON-based configuration - Enable custom connection parameters for all database engines (SQLite and PostgreSQL) - Maintain backward compatibility with existing code - Add proper error handling and validation for environment variable parsing Signed-off-by: ketanjain7981 <ketan.jain@think41.com>
This commit is contained in:
parent
c17f838034
commit
f9b16e508d
2 changed files with 28 additions and 2 deletions
|
|
@ -91,6 +91,15 @@ DB_NAME=cognee_db
|
|||
#DB_USERNAME=cognee
|
||||
#DB_PASSWORD=cognee
|
||||
|
||||
# -- Advanced: Custom database connection arguments (optional) ---------------
|
||||
# Pass additional connection parameters as JSON. Useful for SSL, timeouts, etc.
|
||||
# Examples:
|
||||
# For PostgreSQL with SSL:
|
||||
# DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'
|
||||
# For SQLite with custom timeout:
|
||||
# DATABASE_CONNECT_ARGS='{"timeout": 60}'
|
||||
#DATABASE_CONNECT_ARGS='{}'
|
||||
|
||||
################################################################################
|
||||
# 🕸️ Graph Database settings
|
||||
################################################################################
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import asyncio
|
|||
from os import path
|
||||
import tempfile
|
||||
from uuid import UUID
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import AsyncGenerator, List
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
@ -29,10 +30,25 @@ class SQLAlchemyAdapter:
|
|||
functions.
|
||||
"""
|
||||
|
||||
def __init__(self, connection_string: str):
|
||||
def __init__(self, connection_string: str, connect_args: Optional[dict] = None):
|
||||
self.db_path: str = None
|
||||
self.db_uri: str = connection_string
|
||||
|
||||
env_connect_args = os.getenv("DATABASE_CONNECT_ARGS")
|
||||
if env_connect_args:
|
||||
try:
|
||||
env_connect_args = json.loads(env_connect_args)
|
||||
if isinstance(env_connect_args, dict):
|
||||
if connect_args is None:
|
||||
connect_args = {}
|
||||
connect_args.update(env_connect_args)
|
||||
else:
|
||||
logger.warning(
|
||||
f"DATABASE_CONNECT_ARGS is not a valid JSON dictionary: {env_connect_args}"
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse DATABASE_CONNECT_ARGS as JSON: {e}")
|
||||
|
||||
if "sqlite" in connection_string:
|
||||
[prefix, db_path] = connection_string.split("///")
|
||||
self.db_path = db_path
|
||||
|
|
@ -53,7 +69,7 @@ class SQLAlchemyAdapter:
|
|||
self.engine = create_async_engine(
|
||||
connection_string,
|
||||
poolclass=NullPool,
|
||||
connect_args={"timeout": 30},
|
||||
connect_args={**(connect_args or {}), **{"timeout": 30}},
|
||||
)
|
||||
else:
|
||||
self.engine = create_async_engine(
|
||||
|
|
@ -63,6 +79,7 @@ class SQLAlchemyAdapter:
|
|||
pool_recycle=280,
|
||||
pool_pre_ping=True,
|
||||
pool_timeout=280,
|
||||
connect_args=connect_args or {},
|
||||
)
|
||||
|
||||
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue