# Source: cognee/cognee/modules/data/methods/get_deletion_counts.py
# (92 lines, 3.7 KiB, Python)

from uuid import UUID
from cognee.cli.exceptions import CliCommandException
from cognee.infrastructure.databases.exceptions.exceptions import EntityNotFoundError
from sqlalchemy import select
from sqlalchemy.sql import func
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Dataset, Data, DatasetData
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_user
from dataclasses import dataclass
@dataclass
class DeletionCountsPreview:
    """
    Counts of the items a delete operation would remove, for preview purposes.

    Attributes:
        datasets: Number of datasets in the deletion scope.
        data_entries: Number of data entries in the deletion scope.
        users: Number of users counted for the deletion scope.
    """

    datasets: int = 0
    data_entries: int = 0
    users: int = 0

    # `get_deletion_counts` historically wrote its data-entry count to an
    # undeclared `entries` attribute, which left `data_entries` at 0. This
    # alias routes reads and writes of `entries` to the real field so both
    # names always agree. Unannotated, so dataclass does not treat it as a field.
    @property
    def entries(self) -> int:
        """Alias for ``data_entries`` kept for backward compatibility."""
        return self.data_entries

    @entries.setter
    def entries(self, value: int) -> None:
        self.data_entries = value
async def get_deletion_counts(
    dataset_name: str = None, user_id: str = None, all_data: bool = False
) -> DeletionCountsPreview:
    """
    Calculate how many items a delete operation would remove, without deleting anything.

    Exactly one scope is evaluated, checked in this order:

    1. ``dataset_name``: the dataset with that name and the data entries linked to it.
    2. ``all_data``: total row counts of the Dataset, Data and User tables.
    3. ``user_id``: every dataset owned by that user and the data entries linked to them.

    If none of them is given, an all-zero preview is returned.

    Args:
        dataset_name: Name of the dataset whose deletion is being previewed.
        user_id: UUID (string form, or a ``UUID`` instance) of the user whose
            deletion is being previewed.
        all_data: When true (and no ``dataset_name`` is given), preview deleting everything.

    Returns:
        A ``DeletionCountsPreview`` with ``datasets``, ``data_entries`` and ``users`` filled in.

    Raises:
        CliCommandException: If no dataset named ``dataset_name`` exists, or if
            ``user_id`` is not a valid UUID or matches no user.
    """
    relational_engine = get_relational_engine()
    async with relational_engine.get_async_session() as session:
        if dataset_name:
            datasets, data_entries, users = await _count_dataset_scope(session, dataset_name)
        elif all_data:
            datasets, data_entries, users = await _count_all_scope(session)
        elif user_id:
            datasets, data_entries, users = await _count_user_scope(session, user_id)
        else:
            # No scope selected: nothing would be deleted.
            datasets = data_entries = users = 0

    counts = DeletionCountsPreview(datasets=datasets, data_entries=data_entries, users=users)
    # Also expose the count under `entries`, the attribute name this function
    # historically set, so callers that read `.entries` keep working.
    counts.entries = data_entries
    return counts


async def _count_dataset_scope(session, dataset_name):
    """Return (datasets, data_entries, users) for deleting the dataset named ``dataset_name``."""
    dataset_result = await session.execute(select(Dataset).where(Dataset.name == dataset_name))
    # NOTE(review): scalar_one_or_none() raises if more than one dataset has this
    # name (e.g. owned by different users). Confirm names are unique here.
    dataset = dataset_result.scalar_one_or_none()
    if dataset is None:
        raise CliCommandException(
            f"No Dataset exists with the name {dataset_name}", error_code=1
        )

    # Count the data entries linked to this dataset.
    count_query = (
        select(func.count())
        .select_from(DatasetData)
        .where(DatasetData.dataset_id == dataset.id)
    )
    data_entry_count = (await session.execute(count_query)).scalar_one()
    # One dataset; one user is reported for this scope (kept from the original logic).
    return 1, data_entry_count, 1


async def _count_all_scope(session):
    """Return (datasets, data_entries, users) as total row counts of Dataset, Data and User."""
    datasets = (await session.execute(select(func.count()).select_from(Dataset))).scalar_one()
    data_entries = (await session.execute(select(func.count()).select_from(Data))).scalar_one()
    users = (await session.execute(select(func.count()).select_from(User))).scalar_one()
    return datasets, data_entries, users


async def _count_user_scope(session, user_id):
    """Return (datasets, data_entries, users) for deleting everything owned by ``user_id``."""
    try:
        # Accept an already-parsed UUID as well as its string form.
        user_uuid = user_id if isinstance(user_id, UUID) else UUID(user_id)
        user = await get_user(user_uuid)
    except (ValueError, EntityNotFoundError) as err:
        raise CliCommandException(f"No User exists with ID {user_id}", error_code=1) from err

    # Only the ids are needed, to count the datasets and to filter the link table.
    dataset_ids = (
        (await session.execute(select(Dataset.id).where(Dataset.owner_id == user.id)))
        .scalars()
        .all()
    )
    if not dataset_ids:
        return 0, 0, 1

    # Count all data entries across all of the user's datasets.
    data_count_query = (
        select(func.count())
        .select_from(DatasetData)
        .where(DatasetData.dataset_id.in_(dataset_ids))
    )
    data_entry_count = (await session.execute(data_count_query)).scalar_one()
    return len(dataset_ids), data_entry_count, 1