Sqlite works, made fixes in config so it becomes a basis, added a few mods on top
This commit is contained in:
parent
653fe049b4
commit
b0b9c31102
3 changed files with 79 additions and 124 deletions
|
|
@ -1,8 +1,8 @@
|
|||
"""Configuration for cognee - cognitive architecture framework."""
|
||||
import os
|
||||
import json
|
||||
import configparser
|
||||
import uuid
|
||||
from typing import Optional, List, Dict, Any
|
||||
from typing import Optional, Dict, Any
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
|
@ -16,8 +16,8 @@ load_dotenv(dotenv_path=dotenv_path)
|
|||
|
||||
@dataclass
|
||||
class Config:
|
||||
# Paths and Directories
|
||||
memgpt_dir: str = field(
|
||||
""" Configuration for cognee - cognitive architecture framework. """
|
||||
cognee_dir: str = field(
|
||||
default_factory=lambda: os.getenv("COG_ARCH_DIR", "cognitive_achitecture")
|
||||
)
|
||||
config_path: str = field(
|
||||
|
|
@ -26,7 +26,15 @@ class Config:
|
|||
)
|
||||
)
|
||||
|
||||
vectordb: str = "lancedb"
|
||||
db_path = Path(__file__).resolve().parent / "database/data"
|
||||
|
||||
vectordb: str = "weaviate"
|
||||
db_type: str = os.getenv("DB_TYPE", "postgres")
|
||||
db_name: str = os.getenv("DB_NAME", "cognee.db")
|
||||
db_host: str = os.getenv("DB_HOST", "localhost")
|
||||
db_port: str = os.getenv("DB_PORT", "5432")
|
||||
db_user: str = os.getenv("DB_USER", "cognee")
|
||||
db_password: str = os.getenv("DB_PASSWORD", "cognee")
|
||||
|
||||
# Model parameters
|
||||
model: str = "gpt-4-1106-preview"
|
||||
|
|
|
|||
|
|
@ -1,104 +1,91 @@
|
|||
"""This module provides functionalities for creating and managing databases."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import psycopg2
|
||||
from dotenv import load_dotenv
|
||||
from relationaldb.database import Base
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from relationaldb.models import memory, metadatas, operation, sessions, user, docs
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
from relationaldb.models import memory
|
||||
from relationaldb.models import metadatas
|
||||
from relationaldb.models import operation
|
||||
from relationaldb.models import sessions
|
||||
from relationaldb.models import user
|
||||
from relationaldb.models import docs
|
||||
from dotenv import load_dotenv
|
||||
from relationaldb.database import (
|
||||
Base,DatabaseConfig)
|
||||
from cognitive_architecture.config import Config
|
||||
config = Config()
|
||||
config.load()
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
import os
|
||||
import logging
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from contextlib import contextmanager
|
||||
from dotenv import load_dotenv
|
||||
from relationaldb.database import (
|
||||
Base,
|
||||
) # Assuming all models are imported within this module
|
||||
from relationaldb.database import (
|
||||
DatabaseConfig,
|
||||
) # Assuming DatabaseConfig is defined as before
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""Manages database creation, deletion, and table initialization."""
|
||||
def __init__(self, config: DatabaseConfig):
|
||||
"""Initialize the DatabaseManager with a given configuration."""
|
||||
self.config = config
|
||||
self.engine = create_engine(config.get_sqlalchemy_database_url())
|
||||
self.engine = create_async_engine(config.get_sqlalchemy_database_url(), echo=True)
|
||||
self.db_type = config.db_type
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self):
|
||||
@asynccontextmanager
|
||||
async def get_connection(self):
|
||||
"""Initialize the DatabaseManager with a given configuration."""
|
||||
if self.db_type in ["sqlite", "duckdb"]:
|
||||
# For SQLite and DuckDB, the engine itself manages connections
|
||||
yield self.engine
|
||||
else:
|
||||
connection = self.engine.connect()
|
||||
try:
|
||||
async with self.engine.connect() as connection:
|
||||
yield connection
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
def database_exists(self, db_name):
|
||||
async def database_exists(self, db_name):
|
||||
"""Check if a database exists."""
|
||||
if self.db_type in ["sqlite", "duckdb"]:
|
||||
# For SQLite and DuckDB, check if the database file exists
|
||||
return os.path.exists(db_name)
|
||||
else:
|
||||
query = text(f"SELECT 1 FROM pg_database WHERE datname='{db_name}'")
|
||||
with self.get_connection() as connection:
|
||||
result = connection.execute(query).fetchone()
|
||||
return result is not None
|
||||
async with self.get_connection() as connection:
|
||||
result = await connection.execute(query)
|
||||
return await result.fetchone() is not None
|
||||
|
||||
def create_database(self, db_name):
|
||||
async def create_database(self, db_name):
|
||||
"""Create a new database."""
|
||||
if self.db_type not in ["sqlite", "duckdb"]:
|
||||
# For databases like PostgreSQL, create the database explicitly
|
||||
with self.get_connection() as connection:
|
||||
connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
connection.execute(f"CREATE DATABASE {db_name}")
|
||||
async with self.get_connection() as connection:
|
||||
await connection.execute(text(f"CREATE DATABASE {db_name}"))
|
||||
|
||||
def drop_database(self, db_name):
|
||||
async def drop_database(self, db_name):
|
||||
"""Drop an existing database."""
|
||||
if self.db_type in ["sqlite", "duckdb"]:
|
||||
# For SQLite and DuckDB, simply remove the database file
|
||||
os.remove(db_name)
|
||||
else:
|
||||
with self.get_connection() as connection:
|
||||
connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
connection.execute(f"DROP DATABASE IF EXISTS {db_name}")
|
||||
|
||||
def create_tables(self):
|
||||
Base.metadata.create_all(bind=self.engine)
|
||||
async with self.get_connection() as connection:
|
||||
await connection.execute(text(f"DROP DATABASE IF EXISTS {db_name}"))
|
||||
|
||||
async def create_tables(self):
|
||||
"""Create tables based on the SQLAlchemy Base metadata."""
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage with SQLite
|
||||
config = DatabaseConfig(db_type="sqlite", db_name="mydatabase.db")
|
||||
async def main():
|
||||
"""Runs as a part of startup docker scripts to create the database and tables."""
|
||||
|
||||
# For DuckDB, you would set db_type to 'duckdb' and provide the database file name
|
||||
# config = DatabaseConfig(db_type='duckdb', db_name='mydatabase.duckdb')
|
||||
dbconfig = DatabaseConfig(db_type=config.db_type, db_name=config.db_name)
|
||||
db_manager = DatabaseManager(config=dbconfig)
|
||||
database_name = dbconfig.db_name
|
||||
|
||||
db_manager = DatabaseManager(config=config)
|
||||
if not await db_manager.database_exists(database_name):
|
||||
print(f"Database {database_name} does not exist. Creating...")
|
||||
await db_manager.create_database(database_name)
|
||||
print(f"Database {database_name} created successfully.")
|
||||
|
||||
database_name = config.db_name
|
||||
|
||||
if not db_manager.database_exists(database_name):
|
||||
logger.info(f"Database {database_name} does not exist. Creating...")
|
||||
db_manager.create_database(database_name)
|
||||
logger.info(f"Database {database_name} created successfully.")
|
||||
|
||||
db_manager.create_tables()
|
||||
await db_manager.create_tables()
|
||||
|
||||
asyncio.run(main())
|
||||
#
|
||||
# def create_admin_engine(username, password, host, database_name):
|
||||
# admin_url = f"postgresql://{username}:{password}@{host}:5432/{database_name}"
|
||||
|
|
|
|||
|
|
@ -1,14 +1,11 @@
|
|||
import json
|
||||
import os
|
||||
"""Database configuration and connection."""
|
||||
from pathlib import Path
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.exc import OperationalError
|
||||
import asyncio
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
from cognitive_architecture.config import Config
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
|
@ -19,6 +16,7 @@ MAX_RETRIES = 3
|
|||
RETRY_DELAY = 5
|
||||
|
||||
class DatabaseConfig:
|
||||
"""Configuration for the database connection."""
|
||||
def __init__(
|
||||
self,
|
||||
db_type=None,
|
||||
|
|
@ -27,37 +25,26 @@ class DatabaseConfig:
|
|||
user=None,
|
||||
password=None,
|
||||
port=None,
|
||||
config_file=None,
|
||||
):
|
||||
if config_file:
|
||||
self.load_from_file(config_file)
|
||||
else:
|
||||
# Load default values from environment variables or use provided values
|
||||
self.db_type = db_type or os.getenv("DB_TYPE", "sqlite")
|
||||
self.db_name = db_name or os.getenv("DB_NAME", "database.db")
|
||||
self.host = host or os.getenv("DB_HOST", "localhost")
|
||||
self.user = user or os.getenv("DB_USER", "user")
|
||||
self.password = password or os.getenv("DB_PASSWORD", "password")
|
||||
self.port = port or os.getenv("DB_PORT", "5432")
|
||||
self.config = Config()
|
||||
self.config.load()
|
||||
self.base_path = Path(self.config.db_path)
|
||||
# Load default values from environment variables or use provided values
|
||||
self.db_type = db_type or self.config.db_type
|
||||
self.db_name = db_name or self.config.db_name
|
||||
self.host = host or self.config.db_host
|
||||
self.user = user or self.config.db_user
|
||||
self.password = password or self.config.db_password
|
||||
self.port = port or self.config.db_port
|
||||
|
||||
|
||||
def load_from_file(self, file_path):
|
||||
with open(file_path, "r") as file:
|
||||
config = json.load(file)
|
||||
self.db_type = config.get("db_type", "sqlite")
|
||||
self.db_name = config.get("db_name", "database.db")
|
||||
self.host = config.get("host", "localhost")
|
||||
self.user = config.get("user", "user")
|
||||
self.password = config.get("password", "password")
|
||||
self.port = config.get("port", "5432")
|
||||
|
||||
def get_sqlalchemy_database_url(self):
|
||||
"""Get the SQLAlchemy database URL based on the configuration."""
|
||||
db_path = (self.base_path / self.db_name).absolute()
|
||||
if self.db_type == "sqlite":
|
||||
db_path = Path(self.db_name).absolute() # Ensure the path is absolute
|
||||
return f"sqlite+aiosqlite:///{db_path}" # SQLite uses file path
|
||||
elif self.db_type == "duckdb":
|
||||
db_path = Path(
|
||||
self.db_name
|
||||
).absolute() # Ensure the path is absolute for DuckDB as well
|
||||
return f"duckdb+aiosqlite:///{db_path}"
|
||||
elif self.db_type == "postgresql":
|
||||
# Ensure optional parameters are handled gracefully
|
||||
|
|
@ -95,6 +82,7 @@ Base = declarative_base()
|
|||
# Use asynccontextmanager to define an async context manager
|
||||
@asynccontextmanager
|
||||
async def get_db():
|
||||
"""Provide a database session to the context."""
|
||||
db = AsyncSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
|
|
@ -102,31 +90,3 @@ async def get_db():
|
|||
await db.close()
|
||||
|
||||
|
||||
#
|
||||
# if os.environ.get('AWS_ENV') == 'prd' or os.environ.get('AWS_ENV') == 'dev':
|
||||
# host = os.environ.get('POSTGRES_HOST')
|
||||
# username = os.environ.get('POSTGRES_USER')
|
||||
# password = os.environ.get('POSTGRES_PASSWORD')
|
||||
# database_name = os.environ.get('POSTGRES_DB')
|
||||
# elif os.environ.get('AWS_ENV') == 'local':
|
||||
# host = os.environ.get('POSTGRES_HOST')
|
||||
# username = os.environ.get('POSTGRES_USER')
|
||||
# password = os.environ.get('POSTGRES_PASSWORD')
|
||||
# database_name = os.environ.get('POSTGRES_DB')
|
||||
# else:
|
||||
# host = os.environ.get('POSTGRES_HOST')
|
||||
# username = os.environ.get('POSTGRES_USER')
|
||||
# password = os.environ.get('POSTGRES_PASSWORD')
|
||||
# database_name = os.environ.get('POSTGRES_DB')
|
||||
#
|
||||
# # host = config.postgres_host
|
||||
# # username = config.postgres_user
|
||||
# # password = config.postgres_password
|
||||
# # database_name = config.postgres_db
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
# # Use the asyncpg driver for async operation
|
||||
# SQLALCHEMY_DATABASE_URL = f"postgresql+asyncpg://{username}:{password}@{host}:5432/{database_name}"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue