diff --git a/cognitive_architecture/config.py b/cognitive_architecture/config.py index 600626a2e..322a60455 100644 --- a/cognitive_architecture/config.py +++ b/cognitive_architecture/config.py @@ -1,8 +1,8 @@ +"""Configuration for cognee - cognitive architecture framework.""" import os -import json import configparser import uuid -from typing import Optional, List, Dict, Any +from typing import Optional, Dict, Any from dataclasses import dataclass, field from pathlib import Path from dotenv import load_dotenv @@ -16,8 +16,8 @@ load_dotenv(dotenv_path=dotenv_path) @dataclass class Config: - # Paths and Directories - memgpt_dir: str = field( + """ Configuration for cognee - cognitive architecture framework. """ + cognee_dir: str = field( default_factory=lambda: os.getenv("COG_ARCH_DIR", "cognitive_achitecture") ) config_path: str = field( @@ -26,7 +26,15 @@ class Config: ) ) - vectordb: str = "lancedb" + db_path = Path(__file__).resolve().parent / "database/data" + + vectordb: str = "weaviate" + db_type: str = os.getenv("DB_TYPE", "postgres") + db_name: str = os.getenv("DB_NAME", "cognee.db") + db_host: str = os.getenv("DB_HOST", "localhost") + db_port: str = os.getenv("DB_PORT", "5432") + db_user: str = os.getenv("DB_USER", "cognee") + db_password: str = os.getenv("DB_PASSWORD", "cognee") # Model parameters model: str = "gpt-4-1106-preview" diff --git a/cognitive_architecture/database/create_database.py b/cognitive_architecture/database/create_database.py index 55978958b..29ac69a6e 100644 --- a/cognitive_architecture/database/create_database.py +++ b/cognitive_architecture/database/create_database.py @@ -1,104 +1,91 @@ +"""This module provides functionalities for creating and managing databases.""" + +import asyncio import os import logging -import psycopg2 -from dotenv import load_dotenv -from relationaldb.database import Base +from contextlib import asynccontextmanager +from sqlalchemy.ext.asyncio import create_async_engine +from relationaldb.models import memory, metadatas, operation, sessions, user, docs from sqlalchemy import create_engine, text - -from relationaldb.models import memory -from relationaldb.models import metadatas -from relationaldb.models import operation -from relationaldb.models import sessions -from relationaldb.models import user -from relationaldb.models import docs +from dotenv import load_dotenv +from relationaldb.database import ( + Base,DatabaseConfig) +from cognitive_architecture.config import Config +config = Config() +config.load() load_dotenv() logger = logging.getLogger(__name__) -import os -import logging -from sqlalchemy import create_engine, text -from sqlalchemy.exc import SQLAlchemyError -from contextlib import contextmanager -from dotenv import load_dotenv -from relationaldb.database import ( - Base, -) # Assuming all models are imported within this module -from relationaldb.database import ( - DatabaseConfig, -) # Assuming DatabaseConfig is defined as before - -load_dotenv() -logger = logging.getLogger(__name__) class DatabaseManager: + """Manages database creation, deletion, and table initialization.""" def __init__(self, config: DatabaseConfig): + """Initialize the DatabaseManager with a given configuration.""" self.config = config - self.engine = create_engine(config.get_sqlalchemy_database_url()) + self.engine = create_async_engine(config.get_sqlalchemy_database_url(), echo=True) self.db_type = config.db_type - @contextmanager - def get_connection(self): + @asynccontextmanager + async def get_connection(self): + """Initialize the DatabaseManager with a given configuration.""" if self.db_type in ["sqlite", "duckdb"]: # For SQLite and DuckDB, the engine itself manages connections yield self.engine else: - connection = self.engine.connect() - try: + async with self.engine.connect() as connection: yield connection - finally: - connection.close() - def database_exists(self, db_name): + async def database_exists(self, db_name): + """Check if a database exists.""" if self.db_type in ["sqlite", "duckdb"]: # For SQLite and DuckDB, check if the database file exists return os.path.exists(db_name) else: query = text(f"SELECT 1 FROM pg_database WHERE datname='{db_name}'") - with self.get_connection() as connection: - result = connection.execute(query).fetchone() - return result is not None + async with self.get_connection() as connection: + result = await connection.execute(query) + return await result.fetchone() is not None - def create_database(self, db_name): + async def create_database(self, db_name): + """Create a new database.""" if self.db_type not in ["sqlite", "duckdb"]: # For databases like PostgreSQL, create the database explicitly - with self.get_connection() as connection: - connection.execution_options(isolation_level="AUTOCOMMIT") - connection.execute(f"CREATE DATABASE {db_name}") + async with self.get_connection() as connection: + await connection.execute(text(f"CREATE DATABASE {db_name}")) - def drop_database(self, db_name): + async def drop_database(self, db_name): + """Drop an existing database.""" if self.db_type in ["sqlite", "duckdb"]: # For SQLite and DuckDB, simply remove the database file os.remove(db_name) else: - with self.get_connection() as connection: - connection.execution_options(isolation_level="AUTOCOMMIT") - connection.execute(f"DROP DATABASE IF EXISTS {db_name}") - - def create_tables(self): - Base.metadata.create_all(bind=self.engine) + async with self.get_connection() as connection: + await connection.execute(text(f"DROP DATABASE IF EXISTS {db_name}")) + async def create_tables(self): + """Create tables based on the SQLAlchemy Base metadata.""" + async with self.engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) if __name__ == "__main__": - # Example usage with SQLite - config = DatabaseConfig(db_type="sqlite", db_name="mydatabase.db") + async def main(): + """Runs as a part of startup docker scripts to create the database and tables.""" - # For DuckDB, you would set db_type to 'duckdb' and provide the database file name - # config = DatabaseConfig(db_type='duckdb', db_name='mydatabase.duckdb') + dbconfig = DatabaseConfig(db_type=config.db_type, db_name=config.db_name) + db_manager = DatabaseManager(config=dbconfig) + database_name = dbconfig.db_name - db_manager = DatabaseManager(config=config) + if not await db_manager.database_exists(database_name): + print(f"Database {database_name} does not exist. Creating...") + await db_manager.create_database(database_name) + print(f"Database {database_name} created successfully.") - database_name = config.db_name - - if not db_manager.database_exists(database_name): - logger.info(f"Database {database_name} does not exist. Creating...") - db_manager.create_database(database_name) - logger.info(f"Database {database_name} created successfully.") - - db_manager.create_tables() + await db_manager.create_tables() + asyncio.run(main()) # # def create_admin_engine(username, password, host, database_name): # admin_url = f"postgresql://{username}:{password}@{host}:5432/{database_name}" diff --git a/cognitive_architecture/database/relationaldb/database.py b/cognitive_architecture/database/relationaldb/database.py index 1533d9f9b..ba34bfd26 100644 --- a/cognitive_architecture/database/relationaldb/database.py +++ b/cognitive_architecture/database/relationaldb/database.py @@ -1,14 +1,11 @@ -import json -import os +"""Database configuration and connection.""" from pathlib import Path - +from contextlib import asynccontextmanager from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import declarative_base, sessionmaker -from contextlib import asynccontextmanager -from sqlalchemy.exc import OperationalError -import asyncio -import sys from dotenv import load_dotenv +from cognitive_architecture.config import Config + load_dotenv() @@ -19,6 +16,7 @@ MAX_RETRIES = 3 RETRY_DELAY = 5 class DatabaseConfig: + """Configuration for the database connection.""" def __init__( self, db_type=None, @@ -27,37 +25,26 @@ class DatabaseConfig: user=None, password=None, port=None, - config_file=None, ): - if config_file: - self.load_from_file(config_file) - else: - # Load default values from environment variables or use provided values - self.db_type = db_type or os.getenv("DB_TYPE", "sqlite") - self.db_name = db_name or os.getenv("DB_NAME", "database.db") - self.host = host or os.getenv("DB_HOST", "localhost") - self.user = user or os.getenv("DB_USER", "user") - self.password = password or os.getenv("DB_PASSWORD", "password") - self.port = port or os.getenv("DB_PORT", "5432") + self.config = Config() + self.config.load() + self.base_path = Path(self.config.db_path) + # Load default values from environment variables or use provided values + self.db_type = db_type or self.config.db_type + self.db_name = db_name or self.config.db_name + self.host = host or self.config.db_host + self.user = user or self.config.db_user + self.password = password or self.config.db_password + self.port = port or self.config.db_port + - def load_from_file(self, file_path): - with open(file_path, "r") as file: - config = json.load(file) - self.db_type = config.get("db_type", "sqlite") - self.db_name = config.get("db_name", "database.db") - self.host = config.get("host", "localhost") - self.user = config.get("user", "user") - self.password = config.get("password", "password") - self.port = config.get("port", "5432") def get_sqlalchemy_database_url(self): + """Get the SQLAlchemy database URL based on the configuration.""" + db_path = (self.base_path / self.db_name).absolute() if self.db_type == "sqlite": - db_path = Path(self.db_name).absolute() # Ensure the path is absolute return f"sqlite+aiosqlite:///{db_path}" # SQLite uses file path elif self.db_type == "duckdb": - db_path = Path( - self.db_name - ).absolute() # Ensure the path is absolute for DuckDB as well return f"duckdb+aiosqlite:///{db_path}" elif self.db_type == "postgresql": # Ensure optional parameters are handled gracefully @@ -95,6 +82,7 @@ Base = declarative_base() # Use asynccontextmanager to define an async context manager @asynccontextmanager async def get_db(): + """Provide a database session to the context.""" db = AsyncSessionLocal() try: yield db @@ -102,31 +90,3 @@ async def get_db(): await db.close() -# -# if os.environ.get('AWS_ENV') == 'prd' or os.environ.get('AWS_ENV') == 'dev': -# host = os.environ.get('POSTGRES_HOST') -# username = os.environ.get('POSTGRES_USER') -# password = os.environ.get('POSTGRES_PASSWORD') -# database_name = os.environ.get('POSTGRES_DB') -# elif os.environ.get('AWS_ENV') == 'local': -# host = os.environ.get('POSTGRES_HOST') -# username = os.environ.get('POSTGRES_USER') -# password = os.environ.get('POSTGRES_PASSWORD') -# database_name = os.environ.get('POSTGRES_DB') -# else: -# host = os.environ.get('POSTGRES_HOST') -# username = os.environ.get('POSTGRES_USER') -# password = os.environ.get('POSTGRES_PASSWORD') -# database_name = os.environ.get('POSTGRES_DB') -# -# # host = config.postgres_host -# # username = config.postgres_user -# # password = config.postgres_password -# # database_name = config.postgres_db -# -# -# -# -# -# # Use the asyncpg driver for async operation -# SQLALCHEMY_DATABASE_URL = f"postgresql+asyncpg://{username}:{password}@{host}:5432/{database_name}"