fix: update failing tests and refactor delete_preview implementation
This commit is contained in:
parent
d5dd6c2fc2
commit
a92f4bdf3f
4 changed files with 60 additions and 43 deletions
|
|
@ -46,13 +46,17 @@ Be careful with deletion operations as they are irreversible.
|
||||||
if not args.force:
|
if not args.force:
|
||||||
# --- START PREVIEW LOGIC ---
|
# --- START PREVIEW LOGIC ---
|
||||||
fmt.echo("Gathering data for preview...")
|
fmt.echo("Gathering data for preview...")
|
||||||
preview_data = asyncio.run(
|
try:
|
||||||
get_deletion_counts(
|
preview_data = asyncio.run(
|
||||||
dataset_name=args.dataset_name,
|
get_deletion_counts(
|
||||||
user_id=args.user_id,
|
dataset_name=args.dataset_name,
|
||||||
all_data=args.all,
|
user_id=args.user_id,
|
||||||
|
all_data=args.all,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
except CliCommandException as e:
|
||||||
|
fmt.error(f"Error occured when fetching preview data: {str(e)}")
|
||||||
|
return
|
||||||
|
|
||||||
if not preview_data:
|
if not preview_data:
|
||||||
fmt.success("No data found to delete.")
|
fmt.success("No data found to delete.")
|
||||||
|
|
@ -63,22 +67,27 @@ Be careful with deletion operations as they are irreversible.
|
||||||
f"Datasets: {preview_data.datasets}\nEntries: {preview_data.entries}\nUsers: {preview_data.users}"
|
f"Datasets: {preview_data.datasets}\nEntries: {preview_data.entries}\nUsers: {preview_data.users}"
|
||||||
)
|
)
|
||||||
fmt.echo("-" * 20)
|
fmt.echo("-" * 20)
|
||||||
fmt.warning("This operation is irreversible!")
|
|
||||||
if not fmt.confirm("Proceed?"):
|
|
||||||
fmt.echo("Deletion cancelled.")
|
|
||||||
return
|
|
||||||
# --- END PREVIEW LOGIC ---
|
# --- END PREVIEW LOGIC ---
|
||||||
|
|
||||||
# Build operation message for success/failure logging
|
# Build operation message for success/failure logging
|
||||||
if args.all:
|
if args.all:
|
||||||
|
confirm_msg = "Delete ALL data from cognee?"
|
||||||
operation = "all data"
|
operation = "all data"
|
||||||
elif args.dataset_name:
|
elif args.dataset_name:
|
||||||
|
confirm_msg = f"Delete dataset '{args.dataset_name}'?"
|
||||||
operation = f"dataset '{args.dataset_name}'"
|
operation = f"dataset '{args.dataset_name}'"
|
||||||
elif args.user_id:
|
elif args.user_id:
|
||||||
|
confirm_msg = f"Delete all data for user '{args.user_id}'?"
|
||||||
operation = f"data for user '{args.user_id}'"
|
operation = f"data for user '{args.user_id}'"
|
||||||
else:
|
else:
|
||||||
operation = "data"
|
operation = "data"
|
||||||
|
|
||||||
|
if not args.force:
|
||||||
|
fmt.warning("This operation is irreversible!")
|
||||||
|
if not fmt.confirm(confirm_msg):
|
||||||
|
fmt.echo("Deletion cancelled.")
|
||||||
|
return
|
||||||
|
|
||||||
fmt.echo(f"Deleting {operation}...")
|
fmt.echo(f"Deleting {operation}...")
|
||||||
|
|
||||||
# Run the async delete function
|
# Run the async delete function
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,9 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.modules.data.models import Dataset, Data, DatasetData
|
from cognee.modules.data.models import Dataset, Data, DatasetData
|
||||||
from cognee.modules.users.models import User, DatasetDatabase
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.methods import get_user, get_default_user
|
from cognee.modules.users.methods import get_user
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import cognee.cli.echo as fmt
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -35,7 +34,6 @@ async def get_deletion_counts(
|
||||||
dataset = dataset_result.scalar_one_or_none()
|
dataset = dataset_result.scalar_one_or_none()
|
||||||
|
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
fmt.error(f"No dataset with this name: {dataset_name}")
|
|
||||||
raise CliCommandException(
|
raise CliCommandException(
|
||||||
f"No Dataset exists with the name {dataset_name}", error_code=1
|
f"No Dataset exists with the name {dataset_name}", error_code=1
|
||||||
)
|
)
|
||||||
|
|
@ -72,27 +70,23 @@ async def get_deletion_counts(
|
||||||
user_uuid = UUID(user_id)
|
user_uuid = UUID(user_id)
|
||||||
user = await get_user(user_uuid)
|
user = await get_user(user_uuid)
|
||||||
except (ValueError, EntityNotFoundError):
|
except (ValueError, EntityNotFoundError):
|
||||||
# Handles cases where user_id is not a valid UUID or user is not found
|
|
||||||
fmt.error(f"No user exists with ID {user_id}")
|
|
||||||
raise CliCommandException(f"No User exists with ID {user_id}", error_code=1)
|
raise CliCommandException(f"No User exists with ID {user_id}", error_code=1)
|
||||||
user = await get_user(user_uuid)
|
counts.users = 1
|
||||||
if user:
|
# Find all datasets owned by this user
|
||||||
counts.users = 1
|
datasets_query = select(Dataset).where(Dataset.owner_id == user.id)
|
||||||
# Find all datasets owned by this user
|
user_datasets = (await session.execute(datasets_query)).scalars().all()
|
||||||
datasets_query = select(Dataset).where(Dataset.owner_id == user.id)
|
dataset_count = len(user_datasets)
|
||||||
user_datasets = (await session.execute(datasets_query)).scalars().all()
|
counts.datasets = dataset_count
|
||||||
dataset_count = len(user_datasets)
|
if dataset_count > 0:
|
||||||
counts.datasets = dataset_count
|
dataset_ids = [d.id for d in user_datasets]
|
||||||
if dataset_count > 0:
|
# Count all data entries across all of the user's datasets
|
||||||
dataset_ids = [d.id for d in user_datasets]
|
data_count_query = (
|
||||||
# Count all data entries across all of the user's datasets
|
select(func.count())
|
||||||
data_count_query = (
|
.select_from(DatasetData)
|
||||||
select(func.count())
|
.where(DatasetData.dataset_id.in_(dataset_ids))
|
||||||
.select_from(DatasetData)
|
)
|
||||||
.where(DatasetData.dataset_id.in_(dataset_ids))
|
data_entry_count = (await session.execute(data_count_query)).scalar_one()
|
||||||
)
|
counts.entries = data_entry_count
|
||||||
data_entry_count = (await session.execute(data_count_query)).scalar_one()
|
else:
|
||||||
counts.entries = data_entry_count
|
counts.entries = 0
|
||||||
else:
|
return counts
|
||||||
counts.entries = 0
|
|
||||||
return counts
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,8 @@ from cognee.cli.commands.search_command import SearchCommand
|
||||||
from cognee.cli.commands.cognify_command import CognifyCommand
|
from cognee.cli.commands.cognify_command import CognifyCommand
|
||||||
from cognee.cli.commands.delete_command import DeleteCommand
|
from cognee.cli.commands.delete_command import DeleteCommand
|
||||||
from cognee.cli.commands.config_command import ConfigCommand
|
from cognee.cli.commands.config_command import ConfigCommand
|
||||||
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
from cognee.cli.exceptions import CliCommandException
|
||||||
|
from cognee.modules.data.methods.get_deletion_counts import DeletionCountsPreview
|
||||||
|
|
||||||
|
|
||||||
# Mock asyncio.run to properly handle coroutines
|
# Mock asyncio.run to properly handle coroutines
|
||||||
|
|
@ -282,13 +283,18 @@ class TestDeleteCommand:
|
||||||
assert "all" in actions
|
assert "all" in actions
|
||||||
assert "force" in actions
|
assert "force" in actions
|
||||||
|
|
||||||
|
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
|
||||||
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
||||||
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
|
@patch("cognee.cli.commands.delete_command.asyncio.run", side_effect=_mock_run)
|
||||||
def test_execute_delete_dataset_with_confirmation(self, mock_asyncio_run, mock_confirm):
|
def test_execute_delete_dataset_with_confirmation(
|
||||||
|
self, mock_asyncio_run, mock_confirm, mock_get_deletion_counts
|
||||||
|
):
|
||||||
"""Test execute delete dataset with user confirmation"""
|
"""Test execute delete dataset with user confirmation"""
|
||||||
# Mock the cognee module
|
# Mock the cognee module
|
||||||
mock_cognee = MagicMock()
|
mock_cognee = MagicMock()
|
||||||
mock_cognee.delete = AsyncMock()
|
mock_cognee.delete = AsyncMock()
|
||||||
|
mock_get_deletion_counts = AsyncMock()
|
||||||
|
mock_get_deletion_counts.return_value = DeletionCountsPreview()
|
||||||
|
|
||||||
with patch.dict(sys.modules, {"cognee": mock_cognee}):
|
with patch.dict(sys.modules, {"cognee": mock_cognee}):
|
||||||
command = DeleteCommand()
|
command = DeleteCommand()
|
||||||
|
|
@ -301,13 +307,16 @@ class TestDeleteCommand:
|
||||||
command.execute(args)
|
command.execute(args)
|
||||||
|
|
||||||
mock_confirm.assert_called_once_with(f"Delete dataset '{args.dataset_name}'?")
|
mock_confirm.assert_called_once_with(f"Delete dataset '{args.dataset_name}'?")
|
||||||
mock_asyncio_run.assert_called_once()
|
assert mock_asyncio_run.call_count == 2
|
||||||
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
|
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
|
||||||
mock_cognee.delete.assert_awaited_once_with(dataset_name="test_dataset", user_id=None)
|
mock_cognee.delete.assert_awaited_once_with(dataset_name="test_dataset", user_id=None)
|
||||||
|
|
||||||
|
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
|
||||||
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
||||||
def test_execute_delete_cancelled(self, mock_confirm):
|
def test_execute_delete_cancelled(self, mock_confirm, mock_get_deletion_counts):
|
||||||
"""Test execute when user cancels deletion"""
|
"""Test execute when user cancels deletion"""
|
||||||
|
mock_get_deletion_counts = AsyncMock()
|
||||||
|
mock_get_deletion_counts.return_value = DeletionCountsPreview()
|
||||||
command = DeleteCommand()
|
command = DeleteCommand()
|
||||||
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)
|
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from cognee.cli.commands.cognify_command import CognifyCommand
|
||||||
from cognee.cli.commands.delete_command import DeleteCommand
|
from cognee.cli.commands.delete_command import DeleteCommand
|
||||||
from cognee.cli.commands.config_command import ConfigCommand
|
from cognee.cli.commands.config_command import ConfigCommand
|
||||||
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
from cognee.cli.exceptions import CliCommandException, CliCommandInnerException
|
||||||
|
from cognee.modules.data.methods.get_deletion_counts import DeletionCountsPreview
|
||||||
|
|
||||||
|
|
||||||
# Mock asyncio.run to properly handle coroutines
|
# Mock asyncio.run to properly handle coroutines
|
||||||
|
|
@ -396,13 +397,17 @@ class TestDeleteCommandEdgeCases:
|
||||||
command.execute(args)
|
command.execute(args)
|
||||||
|
|
||||||
mock_confirm.assert_called_once_with("Delete ALL data from cognee?")
|
mock_confirm.assert_called_once_with("Delete ALL data from cognee?")
|
||||||
mock_asyncio_run.assert_called_once()
|
assert mock_asyncio_run.call_count == 2
|
||||||
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
|
assert asyncio.iscoroutine(mock_asyncio_run.call_args[0][0])
|
||||||
mock_cognee.delete.assert_awaited_once_with(dataset_name=None, user_id="test_user")
|
mock_cognee.delete.assert_awaited_once_with(dataset_name=None, user_id="test_user")
|
||||||
|
|
||||||
|
@patch("cognee.cli.commands.delete_command.get_deletion_counts")
|
||||||
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
@patch("cognee.cli.commands.delete_command.fmt.confirm")
|
||||||
def test_delete_confirmation_keyboard_interrupt(self, mock_confirm):
|
def test_delete_confirmation_keyboard_interrupt(self, mock_confirm, mock_get_deletion_counts):
|
||||||
"""Test delete command when user interrupts confirmation"""
|
"""Test delete command when user interrupts confirmation"""
|
||||||
|
mock_get_deletion_counts = AsyncMock()
|
||||||
|
mock_get_deletion_counts.return_value = DeletionCountsPreview()
|
||||||
|
|
||||||
command = DeleteCommand()
|
command = DeleteCommand()
|
||||||
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)
|
args = argparse.Namespace(dataset_name="test_dataset", user_id=None, all=False, force=False)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue