Sqlite works, made fixes in config so it becomes a basis, added a few mods on top
This commit is contained in:
parent
653fe049b4
commit
b0b9c31102
3 changed files with 79 additions and 124 deletions
|
|
@ -1,8 +1,8 @@
|
||||||
|
"""Configuration for cognee - cognitive architecture framework."""
|
||||||
import os
|
import os
|
||||||
import json
|
|
||||||
import configparser
|
import configparser
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
@ -16,8 +16,8 @@ load_dotenv(dotenv_path=dotenv_path)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
# Paths and Directories
|
""" Configuration for cognee - cognitive architecture framework. """
|
||||||
memgpt_dir: str = field(
|
cognee_dir: str = field(
|
||||||
default_factory=lambda: os.getenv("COG_ARCH_DIR", "cognitive_achitecture")
|
default_factory=lambda: os.getenv("COG_ARCH_DIR", "cognitive_achitecture")
|
||||||
)
|
)
|
||||||
config_path: str = field(
|
config_path: str = field(
|
||||||
|
|
@ -26,7 +26,15 @@ class Config:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
vectordb: str = "lancedb"
|
db_path = Path(__file__).resolve().parent / "database/data"
|
||||||
|
|
||||||
|
vectordb: str = "weaviate"
|
||||||
|
db_type: str = os.getenv("DB_TYPE", "postgres")
|
||||||
|
db_name: str = os.getenv("DB_NAME", "cognee.db")
|
||||||
|
db_host: str = os.getenv("DB_HOST", "localhost")
|
||||||
|
db_port: str = os.getenv("DB_PORT", "5432")
|
||||||
|
db_user: str = os.getenv("DB_USER", "cognee")
|
||||||
|
db_password: str = os.getenv("DB_PASSWORD", "cognee")
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
model: str = "gpt-4-1106-preview"
|
model: str = "gpt-4-1106-preview"
|
||||||
|
|
|
||||||
|
|
@ -1,104 +1,91 @@
|
||||||
|
"""This module provides functionalities for creating and managing databases."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import psycopg2
|
from contextlib import asynccontextmanager
|
||||||
from dotenv import load_dotenv
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
from relationaldb.database import Base
|
from relationaldb.models import memory, metadatas, operation, sessions, user, docs
|
||||||
from sqlalchemy import create_engine, text
|
from sqlalchemy import create_engine, text
|
||||||
|
from dotenv import load_dotenv
|
||||||
from relationaldb.models import memory
|
from relationaldb.database import (
|
||||||
from relationaldb.models import metadatas
|
Base,DatabaseConfig)
|
||||||
from relationaldb.models import operation
|
from cognitive_architecture.config import Config
|
||||||
from relationaldb.models import sessions
|
config = Config()
|
||||||
from relationaldb.models import user
|
config.load()
|
||||||
from relationaldb.models import docs
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from relationaldb.database import (
|
|
||||||
Base,
|
|
||||||
) # Assuming all models are imported within this module
|
|
||||||
from relationaldb.database import (
|
|
||||||
DatabaseConfig,
|
|
||||||
) # Assuming DatabaseConfig is defined as before
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseManager:
|
class DatabaseManager:
|
||||||
|
"""Manages database creation, deletion, and table initialization."""
|
||||||
def __init__(self, config: DatabaseConfig):
|
def __init__(self, config: DatabaseConfig):
|
||||||
|
"""Initialize the DatabaseManager with a given configuration."""
|
||||||
self.config = config
|
self.config = config
|
||||||
self.engine = create_engine(config.get_sqlalchemy_database_url())
|
self.engine = create_async_engine(config.get_sqlalchemy_database_url(), echo=True)
|
||||||
self.db_type = config.db_type
|
self.db_type = config.db_type
|
||||||
|
|
||||||
@contextmanager
|
@asynccontextmanager
|
||||||
def get_connection(self):
|
async def get_connection(self):
|
||||||
|
"""Initialize the DatabaseManager with a given configuration."""
|
||||||
if self.db_type in ["sqlite", "duckdb"]:
|
if self.db_type in ["sqlite", "duckdb"]:
|
||||||
# For SQLite and DuckDB, the engine itself manages connections
|
# For SQLite and DuckDB, the engine itself manages connections
|
||||||
yield self.engine
|
yield self.engine
|
||||||
else:
|
else:
|
||||||
connection = self.engine.connect()
|
async with self.engine.connect() as connection:
|
||||||
try:
|
|
||||||
yield connection
|
yield connection
|
||||||
finally:
|
|
||||||
connection.close()
|
|
||||||
|
|
||||||
def database_exists(self, db_name):
|
async def database_exists(self, db_name):
|
||||||
|
"""Check if a database exists."""
|
||||||
if self.db_type in ["sqlite", "duckdb"]:
|
if self.db_type in ["sqlite", "duckdb"]:
|
||||||
# For SQLite and DuckDB, check if the database file exists
|
# For SQLite and DuckDB, check if the database file exists
|
||||||
return os.path.exists(db_name)
|
return os.path.exists(db_name)
|
||||||
else:
|
else:
|
||||||
query = text(f"SELECT 1 FROM pg_database WHERE datname='{db_name}'")
|
query = text(f"SELECT 1 FROM pg_database WHERE datname='{db_name}'")
|
||||||
with self.get_connection() as connection:
|
async with self.get_connection() as connection:
|
||||||
result = connection.execute(query).fetchone()
|
result = await connection.execute(query)
|
||||||
return result is not None
|
return await result.fetchone() is not None
|
||||||
|
|
||||||
def create_database(self, db_name):
|
async def create_database(self, db_name):
|
||||||
|
"""Create a new database."""
|
||||||
if self.db_type not in ["sqlite", "duckdb"]:
|
if self.db_type not in ["sqlite", "duckdb"]:
|
||||||
# For databases like PostgreSQL, create the database explicitly
|
# For databases like PostgreSQL, create the database explicitly
|
||||||
with self.get_connection() as connection:
|
async with self.get_connection() as connection:
|
||||||
connection.execution_options(isolation_level="AUTOCOMMIT")
|
await connection.execute(text(f"CREATE DATABASE {db_name}"))
|
||||||
connection.execute(f"CREATE DATABASE {db_name}")
|
|
||||||
|
|
||||||
def drop_database(self, db_name):
|
async def drop_database(self, db_name):
|
||||||
|
"""Drop an existing database."""
|
||||||
if self.db_type in ["sqlite", "duckdb"]:
|
if self.db_type in ["sqlite", "duckdb"]:
|
||||||
# For SQLite and DuckDB, simply remove the database file
|
# For SQLite and DuckDB, simply remove the database file
|
||||||
os.remove(db_name)
|
os.remove(db_name)
|
||||||
else:
|
else:
|
||||||
with self.get_connection() as connection:
|
async with self.get_connection() as connection:
|
||||||
connection.execution_options(isolation_level="AUTOCOMMIT")
|
await connection.execute(text(f"DROP DATABASE IF EXISTS {db_name}"))
|
||||||
connection.execute(f"DROP DATABASE IF EXISTS {db_name}")
|
|
||||||
|
|
||||||
def create_tables(self):
|
|
||||||
Base.metadata.create_all(bind=self.engine)
|
|
||||||
|
|
||||||
|
async def create_tables(self):
|
||||||
|
"""Create tables based on the SQLAlchemy Base metadata."""
|
||||||
|
async with self.engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Example usage with SQLite
|
async def main():
|
||||||
config = DatabaseConfig(db_type="sqlite", db_name="mydatabase.db")
|
"""Runs as a part of startup docker scripts to create the database and tables."""
|
||||||
|
|
||||||
# For DuckDB, you would set db_type to 'duckdb' and provide the database file name
|
dbconfig = DatabaseConfig(db_type=config.db_type, db_name=config.db_name)
|
||||||
# config = DatabaseConfig(db_type='duckdb', db_name='mydatabase.duckdb')
|
db_manager = DatabaseManager(config=dbconfig)
|
||||||
|
database_name = dbconfig.db_name
|
||||||
|
|
||||||
db_manager = DatabaseManager(config=config)
|
if not await db_manager.database_exists(database_name):
|
||||||
|
print(f"Database {database_name} does not exist. Creating...")
|
||||||
|
await db_manager.create_database(database_name)
|
||||||
|
print(f"Database {database_name} created successfully.")
|
||||||
|
|
||||||
database_name = config.db_name
|
await db_manager.create_tables()
|
||||||
|
|
||||||
if not db_manager.database_exists(database_name):
|
|
||||||
logger.info(f"Database {database_name} does not exist. Creating...")
|
|
||||||
db_manager.create_database(database_name)
|
|
||||||
logger.info(f"Database {database_name} created successfully.")
|
|
||||||
|
|
||||||
db_manager.create_tables()
|
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
#
|
#
|
||||||
# def create_admin_engine(username, password, host, database_name):
|
# def create_admin_engine(username, password, host, database_name):
|
||||||
# admin_url = f"postgresql://{username}:{password}@{host}:5432/{database_name}"
|
# admin_url = f"postgresql://{username}:{password}@{host}:5432/{database_name}"
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,11 @@
|
||||||
import json
|
"""Database configuration and connection."""
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from sqlalchemy.exc import OperationalError
|
|
||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from cognitive_architecture.config import Config
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
@ -19,6 +16,7 @@ MAX_RETRIES = 3
|
||||||
RETRY_DELAY = 5
|
RETRY_DELAY = 5
|
||||||
|
|
||||||
class DatabaseConfig:
|
class DatabaseConfig:
|
||||||
|
"""Configuration for the database connection."""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
db_type=None,
|
db_type=None,
|
||||||
|
|
@ -27,37 +25,26 @@ class DatabaseConfig:
|
||||||
user=None,
|
user=None,
|
||||||
password=None,
|
password=None,
|
||||||
port=None,
|
port=None,
|
||||||
config_file=None,
|
|
||||||
):
|
):
|
||||||
if config_file:
|
self.config = Config()
|
||||||
self.load_from_file(config_file)
|
self.config.load()
|
||||||
else:
|
self.base_path = Path(self.config.db_path)
|
||||||
# Load default values from environment variables or use provided values
|
# Load default values from environment variables or use provided values
|
||||||
self.db_type = db_type or os.getenv("DB_TYPE", "sqlite")
|
self.db_type = db_type or self.config.db_type
|
||||||
self.db_name = db_name or os.getenv("DB_NAME", "database.db")
|
self.db_name = db_name or self.config.db_name
|
||||||
self.host = host or os.getenv("DB_HOST", "localhost")
|
self.host = host or self.config.db_host
|
||||||
self.user = user or os.getenv("DB_USER", "user")
|
self.user = user or self.config.db_user
|
||||||
self.password = password or os.getenv("DB_PASSWORD", "password")
|
self.password = password or self.config.db_password
|
||||||
self.port = port or os.getenv("DB_PORT", "5432")
|
self.port = port or self.config.db_port
|
||||||
|
|
||||||
|
|
||||||
def load_from_file(self, file_path):
|
|
||||||
with open(file_path, "r") as file:
|
|
||||||
config = json.load(file)
|
|
||||||
self.db_type = config.get("db_type", "sqlite")
|
|
||||||
self.db_name = config.get("db_name", "database.db")
|
|
||||||
self.host = config.get("host", "localhost")
|
|
||||||
self.user = config.get("user", "user")
|
|
||||||
self.password = config.get("password", "password")
|
|
||||||
self.port = config.get("port", "5432")
|
|
||||||
|
|
||||||
def get_sqlalchemy_database_url(self):
|
def get_sqlalchemy_database_url(self):
|
||||||
|
"""Get the SQLAlchemy database URL based on the configuration."""
|
||||||
|
db_path = (self.base_path / self.db_name).absolute()
|
||||||
if self.db_type == "sqlite":
|
if self.db_type == "sqlite":
|
||||||
db_path = Path(self.db_name).absolute() # Ensure the path is absolute
|
|
||||||
return f"sqlite+aiosqlite:///{db_path}" # SQLite uses file path
|
return f"sqlite+aiosqlite:///{db_path}" # SQLite uses file path
|
||||||
elif self.db_type == "duckdb":
|
elif self.db_type == "duckdb":
|
||||||
db_path = Path(
|
|
||||||
self.db_name
|
|
||||||
).absolute() # Ensure the path is absolute for DuckDB as well
|
|
||||||
return f"duckdb+aiosqlite:///{db_path}"
|
return f"duckdb+aiosqlite:///{db_path}"
|
||||||
elif self.db_type == "postgresql":
|
elif self.db_type == "postgresql":
|
||||||
# Ensure optional parameters are handled gracefully
|
# Ensure optional parameters are handled gracefully
|
||||||
|
|
@ -95,6 +82,7 @@ Base = declarative_base()
|
||||||
# Use asynccontextmanager to define an async context manager
|
# Use asynccontextmanager to define an async context manager
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_db():
|
async def get_db():
|
||||||
|
"""Provide a database session to the context."""
|
||||||
db = AsyncSessionLocal()
|
db = AsyncSessionLocal()
|
||||||
try:
|
try:
|
||||||
yield db
|
yield db
|
||||||
|
|
@ -102,31 +90,3 @@ async def get_db():
|
||||||
await db.close()
|
await db.close()
|
||||||
|
|
||||||
|
|
||||||
#
|
|
||||||
# if os.environ.get('AWS_ENV') == 'prd' or os.environ.get('AWS_ENV') == 'dev':
|
|
||||||
# host = os.environ.get('POSTGRES_HOST')
|
|
||||||
# username = os.environ.get('POSTGRES_USER')
|
|
||||||
# password = os.environ.get('POSTGRES_PASSWORD')
|
|
||||||
# database_name = os.environ.get('POSTGRES_DB')
|
|
||||||
# elif os.environ.get('AWS_ENV') == 'local':
|
|
||||||
# host = os.environ.get('POSTGRES_HOST')
|
|
||||||
# username = os.environ.get('POSTGRES_USER')
|
|
||||||
# password = os.environ.get('POSTGRES_PASSWORD')
|
|
||||||
# database_name = os.environ.get('POSTGRES_DB')
|
|
||||||
# else:
|
|
||||||
# host = os.environ.get('POSTGRES_HOST')
|
|
||||||
# username = os.environ.get('POSTGRES_USER')
|
|
||||||
# password = os.environ.get('POSTGRES_PASSWORD')
|
|
||||||
# database_name = os.environ.get('POSTGRES_DB')
|
|
||||||
#
|
|
||||||
# # host = config.postgres_host
|
|
||||||
# # username = config.postgres_user
|
|
||||||
# # password = config.postgres_password
|
|
||||||
# # database_name = config.postgres_db
|
|
||||||
#
|
|
||||||
#
|
|
||||||
#
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# # Use the asyncpg driver for async operation
|
|
||||||
# SQLALCHEMY_DATABASE_URL = f"postgresql+asyncpg://{username}:{password}@{host}:5432/{database_name}"
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue