feat(database): add connect_args support to SqlAlchemyAdapter
- Add optional connect_args parameter to __init__ method - Support DATABASE_CONNECT_ARGS environment variable for JSON-based configuration - Enable custom connection parameters for all database engines (SQLite and PostgreSQL) - Maintain backward compatibility with existing code - Add proper error handling and validation for environment variable parsing Signed-off-by: ketanjain7981 <ketan.jain@think41.com>
This commit is contained in:
parent
c17f838034
commit
f9b16e508d
2 changed files with 28 additions and 2 deletions
|
|
@ -91,6 +91,15 @@ DB_NAME=cognee_db
|
||||||
#DB_USERNAME=cognee
|
#DB_USERNAME=cognee
|
||||||
#DB_PASSWORD=cognee
|
#DB_PASSWORD=cognee
|
||||||
|
|
||||||
|
# -- Advanced: Custom database connection arguments (optional) ---------------
|
||||||
|
# Pass additional connection parameters as JSON. Useful for SSL, timeouts, etc.
|
||||||
|
# Examples:
|
||||||
|
# For PostgreSQL with SSL:
|
||||||
|
# DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'
|
||||||
|
# For SQLite with custom timeout:
|
||||||
|
# DATABASE_CONNECT_ARGS='{"timeout": 60}'
|
||||||
|
#DATABASE_CONNECT_ARGS='{}'
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# 🕸️ Graph Database settings
|
# 🕸️ Graph Database settings
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import asyncio
|
||||||
from os import path
|
from os import path
|
||||||
import tempfile
|
import tempfile
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
import json
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import AsyncGenerator, List
|
from typing import AsyncGenerator, List
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
@ -29,10 +30,25 @@ class SQLAlchemyAdapter:
|
||||||
functions.
|
functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, connection_string: str):
|
def __init__(self, connection_string: str, connect_args: Optional[dict] = None):
|
||||||
self.db_path: str = None
|
self.db_path: str = None
|
||||||
self.db_uri: str = connection_string
|
self.db_uri: str = connection_string
|
||||||
|
|
||||||
|
env_connect_args = os.getenv("DATABASE_CONNECT_ARGS")
|
||||||
|
if env_connect_args:
|
||||||
|
try:
|
||||||
|
env_connect_args = json.loads(env_connect_args)
|
||||||
|
if isinstance(env_connect_args, dict):
|
||||||
|
if connect_args is None:
|
||||||
|
connect_args = {}
|
||||||
|
connect_args.update(env_connect_args)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"DATABASE_CONNECT_ARGS is not a valid JSON dictionary: {env_connect_args}"
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"Failed to parse DATABASE_CONNECT_ARGS as JSON: {e}")
|
||||||
|
|
||||||
if "sqlite" in connection_string:
|
if "sqlite" in connection_string:
|
||||||
[prefix, db_path] = connection_string.split("///")
|
[prefix, db_path] = connection_string.split("///")
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
|
|
@ -53,7 +69,7 @@ class SQLAlchemyAdapter:
|
||||||
self.engine = create_async_engine(
|
self.engine = create_async_engine(
|
||||||
connection_string,
|
connection_string,
|
||||||
poolclass=NullPool,
|
poolclass=NullPool,
|
||||||
connect_args={"timeout": 30},
|
connect_args={**(connect_args or {}), **{"timeout": 30}},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.engine = create_async_engine(
|
self.engine = create_async_engine(
|
||||||
|
|
@ -63,6 +79,7 @@ class SQLAlchemyAdapter:
|
||||||
pool_recycle=280,
|
pool_recycle=280,
|
||||||
pool_pre_ping=True,
|
pool_pre_ping=True,
|
||||||
pool_timeout=280,
|
pool_timeout=280,
|
||||||
|
connect_args=connect_args or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue