add sqlalchemy engine
This commit is contained in:
parent
866270bb5a
commit
77e8c1b1d5
5 changed files with 165 additions and 9 deletions
|
|
@ -11,9 +11,11 @@ class RelationalConfig(BaseSettings):
|
||||||
db_port: str = "5432"
|
db_port: str = "5432"
|
||||||
db_user: str = "cognee"
|
db_user: str = "cognee"
|
||||||
db_password: str = "cognee"
|
db_password: str = "cognee"
|
||||||
database_engine: object = create_relational_engine(db_path, db_name)
|
db_provider: str = "duckdb"
|
||||||
|
database_engine: object = create_relational_engine(db_path, db_name, db_provider)
|
||||||
db_file_path: str = os.path.join(db_path, db_name)
|
db_file_path: str = os.path.join(db_path, db_name)
|
||||||
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
model_config = SettingsConfigDict(env_file = ".env", extra = "allow")
|
||||||
|
|
||||||
def create_engine(self):
|
def create_engine(self):
|
||||||
|
|
@ -29,6 +31,7 @@ class RelationalConfig(BaseSettings):
|
||||||
"db_user": self.db_user,
|
"db_user": self.db_user,
|
||||||
"db_password": self.db_password,
|
"db_password": self.db_password,
|
||||||
"db_engine": self.database_engine,
|
"db_engine": self.database_engine,
|
||||||
|
"db_provider": self.db_provider,
|
||||||
}
|
}
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,33 @@
|
||||||
from cognee.infrastructure.files.storage import LocalStorage
|
from enum import Enum
|
||||||
from cognee.infrastructure.databases.relational import DuckDBAdapter
|
|
||||||
|
|
||||||
def create_relational_engine(db_path: str, db_name: str):
|
from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||||
|
from cognee.infrastructure.files.storage import LocalStorage
|
||||||
|
from cognee.infrastructure.databases.relational import DuckDBAdapter, get_relationaldb_config
|
||||||
|
|
||||||
|
|
||||||
|
class DBProvider(Enum):
    """Supported relational database backends."""
    DUCKDB = "duckdb"
    POSTGRES = "postgres"


def create_relational_engine(db_path: str, db_name: str, db_provider: str):
    """Create the relational database adapter for the requested provider.

    Args:
        db_path: Directory the database file lives in (created if missing).
        db_name: Name of the database (file name for file-based backends).
        db_provider: Provider key; must be a ``DBProvider`` value
            ("duckdb" or "postgres").

    Returns:
        A ``DuckDBAdapter`` or ``SQLAlchemyAdapter`` instance, or ``None`` if
        ``db_provider`` maps to a provider with no branch below (cannot happen
        with the current two-member enum).

    Raises:
        ValueError: If ``db_provider`` is not a known ``DBProvider`` value.
    """
    LocalStorage.ensure_directory_exists(db_path)

    # Bug fix: the original re-read the config here via get_relationaldb_config()
    # (a circular dependency -- the config itself calls this factory) and then
    # used its `llm_provider` attribute, which is the LLM provider, not the
    # relational DB provider. Use the db_provider argument that callers already
    # pass in.
    provider = DBProvider(db_provider)

    if provider == DBProvider.DUCKDB:
        return DuckDBAdapter(
            db_name = db_name,
            db_path = db_path,
        )
    elif provider == DBProvider.POSTGRES:
        return SQLAlchemyAdapter(
            db_name = db_name,
            db_path = db_path,
            db_type = db_provider,
        )
|
||||||
|
|
@ -0,0 +1,125 @@
|
||||||
|
import os
|
||||||
|
from sqlalchemy import create_engine, MetaData, Table, Column, String, Boolean, TIMESTAMP, text
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
class SQLAlchemyAdapter():
    """Relational database adapter backed by a SQLAlchemy engine.

    Connects to a file-based database located at ``<db_path>/<db_name>`` using
    the URL scheme given by ``db_type`` (e.g. ``"sqlite"``); a session factory
    is prepared but most methods run raw SQL on a plain connection.
    """

    def __init__(self, db_type: str, db_path: str, db_name: str):
        # Absolute path to the database file; also used by delete_database().
        self.db_location = os.path.abspath(os.path.join(db_path, db_name))
        self.engine = create_engine(f"{db_type}:///{self.db_location}")
        self.Session = sessionmaker(bind=self.engine)

    def get_datasets(self):
        """Return schema names that represent user datasets.

        Filters out staging schemas (names ending in "staging") and the
        internal "cognee" schema.
        """
        with self.engine.connect() as connection:
            result = connection.execute(text("SELECT DISTINCT schema_name FROM information_schema.tables;"))
            # NOTE(review): row['schema_name'] relies on SQLAlchemy 1.x-style
            # string indexing of result rows; confirm against the pinned
            # SQLAlchemy version (2.0 requires row.mapping or named attributes).
            tables = [row['schema_name'] for row in result]
            return [
                schema_name
                for schema_name in tables
                if not schema_name.endswith("staging") and schema_name != "cognee"
            ]

    def get_files_metadata(self, dataset_name: str):
        """Return file-metadata rows for a dataset as a list of dicts.

        ``dataset_name`` is interpolated into the SQL and must come from a
        trusted source (identifiers cannot be bound as parameters).
        """
        with self.engine.connect() as connection:
            result = connection.execute(text(f"SELECT id, name, file_path, extension, mime_type FROM {dataset_name}.file_metadata;"))
            return [dict(row) for row in result]

    def create_table(self, schema_name: str, table_name: str, table_config: list[dict]):
        """Create ``schema_name.table_name`` if it does not exist.

        ``table_config`` is a list of ``{"name": ..., "type": ...}`` column
        specs. Identifiers are interpolated into the SQL and must be trusted.
        """
        fields_query_parts = [f"{item['name']} {item['type']}" for item in table_config]
        with self.engine.connect() as connection:
            connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema_name};"))
            connection.execute(text(f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ({', '.join(fields_query_parts)});"))

    def delete_table(self, table_name: str):
        """Drop ``table_name`` if it exists (identifier must be trusted)."""
        with self.engine.connect() as connection:
            connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))

    def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
        """Insert rows into ``schema_name.table_name``.

        All dicts in ``data`` must share the keys of ``data[0]``.

        Bug fix: the original emitted one ``(:col, ...)`` group per row but
        reused identical bind-parameter names in every group, so multi-row
        inserts could not bind distinct values. A single VALUES group executed
        with a list of parameter dicts lets SQLAlchemy run it as executemany.
        """
        columns = ", ".join(data[0].keys())
        placeholders = ", ".join(f":{key}" for key in data[0].keys())
        insert_query = text(f"INSERT INTO {schema_name}.{table_name} ({columns}) VALUES ({placeholders});")
        with self.engine.connect() as connection:
            connection.execute(insert_query, data)

    def get_data(self, table_name: str, filters: dict = None):
        """Fetch rows from ``table_name`` and map ``data_id`` -> ``status``.

        ``filters`` maps column name to either a scalar (equality) or a list
        (SQL ``IN``). Returns a dict built from the ``data_id`` and ``status``
        columns of the result.

        Bug fix: the original generated ``:key0 ... :keyN`` placeholders for
        list-valued filters but then passed the raw ``filters`` dict as
        parameters, so those binds were never supplied; list values are now
        expanded into the parameter dict.
        """
        with self.engine.connect() as connection:
            query = f"SELECT * FROM {table_name}"
            if filters:
                params = {}
                conditions = []
                for key, value in filters.items():
                    if isinstance(value, list):
                        names = [f"{key}{i}" for i in range(len(value))]
                        conditions.append(f"{key} IN ({', '.join(':' + name for name in names)})")
                        params.update(dict(zip(names, value)))
                    else:
                        conditions.append(f"{key} = :{key}")
                        params[key] = value
                query += f" WHERE {' AND '.join(conditions)};"
                results = connection.execute(text(query), params)
            else:
                query += ";"
                results = connection.execute(text(query))
            # NOTE(review): assumes the table has data_id and status columns,
            # and 1.x-style string indexing on rows -- confirm both.
            return {result["data_id"]: result["status"] for result in results}

    def execute_query(self, query):
        """Run an arbitrary SQL string and return rows as dicts."""
        with self.engine.connect() as connection:
            result = connection.execute(text(query))
            return [dict(row) for row in result]

    def load_cognify_data(self, data):
        """Create the ``cognify`` table if needed and insert ``data`` rows.

        ``data`` is a list of dicts supplying at least ``document_id`` and
        ``layer_id``; the remaining columns take their defaults.
        """
        metadata = MetaData()
        cognify_table = Table(
            'cognify', metadata,
            Column('document_id', String),
            Column('layer_id', String),
            Column('created_at', TIMESTAMP, server_default=text('CURRENT_TIMESTAMP')),
            Column('updated_at', TIMESTAMP, nullable=True, default=None),
            Column('processed', Boolean, default=False),
            Column('document_id_target', String, nullable=True)
        )
        metadata.create_all(self.engine)
        # NOTE(review): text(':param') inside .values() delegates binding to
        # the executemany call below; unusual but kept as-is -- verify against
        # the pinned SQLAlchemy version.
        insert_query = cognify_table.insert().values(document_id=text(':document_id'), layer_id=text(':layer_id'))
        with self.engine.connect() as connection:
            connection.execute(insert_query, data)

    def fetch_cognify_data(self, excluded_document_id: str):
        """Return unprocessed cognify rows (excluding one document) and mark them processed.

        Creates the ``cognify`` table if it does not exist, selects rows where
        ``processed`` is FALSE and ``document_id != excluded_document_id``,
        flips ``processed`` to TRUE for the returned rows, and returns them as
        a list of dicts.
        """
        with self.engine.connect() as connection:
            connection.execute(text("""
                CREATE TABLE IF NOT EXISTS cognify (
                    document_id STRING,
                    layer_id STRING,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT NULL,
                    processed BOOLEAN DEFAULT FALSE,
                    document_id_target STRING NULL
                );
            """))
            query = text("""
                SELECT document_id, layer_id, created_at, updated_at, processed
                FROM cognify
                WHERE document_id != :excluded_document_id AND processed = FALSE;
            """)
            records = connection.execute(query, {'excluded_document_id': excluded_document_id}).fetchall()
            if records:
                document_ids = tuple(record['document_id'] for record in records)
                # NOTE(review): "IN :document_ids" with a plain tuple requires
                # driver/SQLAlchemy support for expanding binds -- confirm, or
                # switch to bindparam(..., expanding=True).
                update_query = text("UPDATE cognify SET processed = TRUE WHERE document_id IN :document_ids;")
                connection.execute(update_query, {'document_ids': document_ids})
            return [dict(record) for record in records]

    def delete_cognify_data(self):
        """Empty and drop the ``cognify`` table (creating it first so both statements succeed)."""
        with self.engine.connect() as connection:
            connection.execute(text("""
                CREATE TABLE IF NOT EXISTS cognify (
                    document_id STRING,
                    layer_id STRING,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT NULL,
                    processed BOOLEAN DEFAULT FALSE,
                    document_id_target STRING NULL
                );
            """))
            connection.execute(text("DELETE FROM cognify;"))
            connection.execute(text("DROP TABLE cognify;"))

    def delete_database(self):
        """Remove the database file (and its write-ahead log, if present) from disk."""
        from cognee.infrastructure.files.storage import LocalStorage

        LocalStorage.remove(self.db_location)
        if LocalStorage.file_exists(self.db_location + ".wal"):
            LocalStorage.remove(self.db_location + ".wal")
|
||||||
|
|
@ -5,6 +5,9 @@ from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.relational import get_relationaldb_config
|
||||||
|
from cognee.infrastructure.databases.relational.create_relational_engine import create_relational_engine
|
||||||
|
|
||||||
DATABASE_URL = "sqlite+aiosqlite:///./test.db"
|
DATABASE_URL = "sqlite+aiosqlite:///./test.db"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -17,8 +20,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
llm_config = get_relationaldb_config()
|
||||||
|
|
||||||
engine = create_async_engine(DATABASE_URL)
|
|
||||||
|
engine = create_relational_engine(llm_config.db_path, llm_config.db_name, llm_config.db_provider)
|
||||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue