feat: Add delete preview for --dataset-name and --all flags

This commit introduces the preview functionality for the `delete` command. The preview displays a summary of what will be deleted before asking for user confirmation.

The feature is fully functional for the following flags:
- `--dataset-name` (short or long form): Correctly counts the number of data entries within the specified dataset.
- `--all`: Correctly counts the total number of datasets, data entries, and users in the system.

The logic for the `--user-id` flag is a work in progress. The current implementation uses a placeholder and needs a method to query a user directly by their ID to be completed.
This commit is contained in:
shehab-badawy 2025-09-26 22:27:32 -04:00
parent de162cb491
commit 9c87a10848
2 changed files with 64 additions and 29 deletions

View file

@ -54,21 +54,15 @@ Be careful with deletion operations as they are irreversible.
) )
) )
if not preview_data or all(value == 0 for value in preview_data.values()): if not preview_data:
fmt.success("No data found to delete.") fmt.success("No data found to delete.")
return return
fmt.echo("You are about to delete:") fmt.echo("You are about to delete:")
if "datasets" in preview_data and preview_data["datasets"] > 0: fmt.echo(
fmt.echo(f"- {preview_data['datasets']} datasets") f"Datasets: {preview_data.datasets}\nEntries: {preview_data.entries}\nUsers: {preview_data.users}"
if "data_entries" in preview_data and preview_data["data_entries"] > 0: )
fmt.echo(f"- {preview_data['data_entries']} data entries")
if "users" in preview_data and preview_data["users"] > 0:
fmt.echo(
f"- {preview_data['users']} {'users' if preview_data['users'] > 1 else 'user'}"
)
fmt.echo("-" * 20) fmt.echo("-" * 20)
fmt.warning("This operation is irreversible!") fmt.warning("This operation is irreversible!")
if not fmt.confirm("Proceed?"): if not fmt.confirm("Proceed?"):
fmt.echo("Deletion cancelled.") fmt.echo("Deletion cancelled.")

View file

@ -1,16 +1,30 @@
from uuid import UUID
from cognee.cli.exceptions import CliCommandException
from cognee.infrastructure.databases.exceptions.exceptions import EntityNotFoundError
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.sql import func from sqlalchemy.sql import func
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Dataset, Data, DatasetData from cognee.modules.data.models import Dataset, Data, DatasetData
from cognee.modules.users.models import User from cognee.modules.users.models import User, DatasetDatabase
from cognee.modules.users.methods import get_user, get_default_user
from dataclasses import dataclass
import cognee.cli.echo as fmt
@dataclass
class DeletionCountsPreview:
    """Summary of how many records a delete operation would remove.

    Every counter defaults to zero, so a freshly constructed preview
    describes "nothing to delete" until the counters are filled in.
    """

    # Number of datasets included in the deletion.
    datasets: int = 0
    # Number of data entries included in the deletion.
    data_entries: int = 0
    # Number of users included in the deletion.
    users: int = 0

    @property
    def entries(self) -> int:
        """Alias for ``data_entries``.

        The shorter name ``entries`` is also used to read and write this
        counter. Without the alias, assigning ``entries`` would create an
        unrelated ad-hoc attribute and leave ``data_entries`` stuck at 0.
        """
        return self.data_entries

    @entries.setter
    def entries(self, value: int) -> None:
        # Keep a single source of truth: store through the declared field so
        # repr/eq/asdict all reflect the value.
        self.data_entries = value
async def get_deletion_counts( async def get_deletion_counts(
dataset_name: str = None, user_id: str = None, all_data: bool = False dataset_name: str = None, user_id: str = None, all_data: bool = False
) -> dict: ) -> DeletionCountsPreview:
""" """
Calculates the number of items that will be deleted based on the provided arguments. Calculates the number of items that will be deleted based on the provided arguments.
""" """
counts = DeletionCountsPreview()
relational_engine = get_relational_engine() relational_engine = get_relational_engine()
async with relational_engine.get_async_session() as session: async with relational_engine.get_async_session() as session:
if dataset_name: if dataset_name:
@ -21,7 +35,10 @@ async def get_deletion_counts(
dataset = dataset_result.scalar_one_or_none() dataset = dataset_result.scalar_one_or_none()
if dataset is None: if dataset is None:
return {"datasets": 0, "data_entries": 0} fmt.error(f"No dataset with this name: {dataset_name}")
raise CliCommandException(
f"No Dataset exists with the name {dataset_name}", error_code=1
)
# Count data entries linked to this dataset # Count data entries linked to this dataset
count_query = ( count_query = (
@ -30,28 +47,52 @@ async def get_deletion_counts(
.where(DatasetData.dataset_id == dataset.id) .where(DatasetData.dataset_id == dataset.id)
) )
data_entry_count = (await session.execute(count_query)).scalar_one() data_entry_count = (await session.execute(count_query)).scalar_one()
counts.users = 1
counts.datasets = 1
counts.entries = data_entry_count
return counts
return {"datasets": 1, "data_entries": data_entry_count} elif all_data:
# Simplified logic: Get total counts directly from the tables.
if all_data: counts.datasets = (
dataset_count = (
await session.execute(select(func.count()).select_from(Dataset)) await session.execute(select(func.count()).select_from(Dataset))
).scalar_one() ).scalar_one()
data_entry_count = ( counts.entries = (
await session.execute(select(func.count()).select_from(Data)) await session.execute(select(func.count()).select_from(Data))
).scalar_one() ).scalar_one()
user_count = ( counts.users = (
await session.execute(select(func.count()).select_from(User)) await session.execute(select(func.count()).select_from(User))
).scalar_one() ).scalar_one()
return { return counts
"datasets": dataset_count,
"data_entries": data_entry_count,
"users": user_count,
}
# Placeholder for user_id logic # Placeholder for user_id logic
if user_id: elif user_id:
# TODO: Implement counting logic for a specific user user = None
return {"datasets": 0, "data_entries": 0, "users": 1} try:
user_uuid = UUID(user_id)
return {} user = await get_user(user_uuid)
except (ValueError, EntityNotFoundError):
# Handles cases where user_id is not a valid UUID or user is not found
fmt.error(f"No user exists with ID {user_id}")
raise CliCommandException(f"No User exists with ID {user_id}", error_code=1)
user = await get_user(user_uuid)
if user:
counts.users = 1
# Find all datasets owned by this user
datasets_query = select(Dataset).where(Dataset.owner_id == user.id)
user_datasets = (await session.execute(datasets_query)).scalars().all()
dataset_count = len(user_datasets)
counts.datasets = dataset_count
if dataset_count > 0:
dataset_ids = [d.id for d in user_datasets]
# Count all data entries across all of the user's datasets
data_count_query = (
select(func.count())
.select_from(DatasetData)
.where(DatasetData.dataset_id.in_(dataset_ids))
)
data_entry_count = (await session.execute(data_count_query)).scalar_one()
counts.entries = data_entry_count
else:
counts.entries = 0
return counts