From de1ba5cd7c55b8e43aeaadd76aa1589ef0a5af2c Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Fri, 1 Nov 2024 12:55:20 +0100 Subject: [PATCH] feat: Add cascade deletion for datasets and data Added cascade deletion so when a dataset or data is deleted the connection in the dataset_data table is also deleted Feature #COG-455 --- .../sqlalchemy/SqlAlchemyAdapter.py | 20 +++++++++++++++---- cognee/modules/data/models/Data.py | 2 ++ cognee/modules/data/models/Dataset.py | 2 ++ cognee/modules/data/models/DatasetData.py | 4 ++-- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index edde07565..febfe1931 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -89,10 +89,22 @@ class SQLAlchemyAdapter(): """ Delete data in given table based on id. Table must have an id Column. """ - async with self.get_async_session() as session: - TableModel = await self.get_table(table_name, schema_name) - await session.execute(TableModel.delete().where(TableModel.c.id == data_id)) - await session.commit() + if self.engine.dialect.name == "sqlite": + async with self.get_async_session() as session: + TableModel = await self.get_table(table_name, schema_name) + + # Foreign key constraints are disabled by default in SQLite (for backwards compatibility), + # so must be enabled for each database connection/session separately. + await session.execute(text("PRAGMA foreign_keys = ON;")) + + await session.execute(TableModel.delete().where(TableModel.c.id == data_id)) + await session.commit() + else: + async with self.get_async_session() as session: + TableModel = await self.get_table(table_name, schema_name) + await session.execute(TableModel.delete().where(TableModel.c.id == data_id)) + await session.commit() + async def get_table(self, table_name: str, schema_name: Optional[str] = "public") -> Table: """ diff --git a/cognee/modules/data/models/Data.py b/cognee/modules/data/models/Data.py index feb9e3bff..064521539 100644 --- a/cognee/modules/data/models/Data.py +++ b/cognee/modules/data/models/Data.py @@ -20,9 +20,11 @@ class Data(Base): updated_at = Column(DateTime(timezone = True), onupdate = lambda: datetime.now(timezone.utc)) datasets: Mapped[List["Dataset"]] = relationship( + "Dataset", secondary = DatasetData.__tablename__, back_populates = "data", lazy = "noload", + cascade="all, delete" ) def to_json(self) -> dict: diff --git a/cognee/modules/data/models/Dataset.py b/cognee/modules/data/models/Dataset.py index 7e35ce982..5cf5d2351 100644 --- a/cognee/modules/data/models/Dataset.py +++ b/cognee/modules/data/models/Dataset.py @@ -19,9 +19,11 @@ class Dataset(Base): owner_id = Column(UUID, index = True) data: Mapped[List["Data"]] = relationship( + "Data", secondary = DatasetData.__tablename__, back_populates = "datasets", lazy = "noload", + cascade="all, delete" ) def to_json(self) -> dict: diff --git a/cognee/modules/data/models/DatasetData.py b/cognee/modules/data/models/DatasetData.py index b156d8d37..ed9d3c64c 100644 --- a/cognee/modules/data/models/DatasetData.py +++ b/cognee/modules/data/models/DatasetData.py @@ -7,5 +7,5 @@ class DatasetData(Base): created_at = Column(DateTime(timezone = True), default = lambda: datetime.now(timezone.utc)) - dataset_id = Column(UUID, ForeignKey("datasets.id"), primary_key = True) - data_id = Column(UUID, ForeignKey("data.id"), primary_key = True) + dataset_id = Column(UUID, ForeignKey("datasets.id", ondelete="CASCADE"), primary_key = True) + data_id = Column(UUID, ForeignKey("data.id", ondelete="CASCADE"), primary_key = True)