From 17632a5becb4dbde857914417a15c2d2e56a4bf6 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 13 Nov 2025 23:40:45 +0100 Subject: [PATCH] fix: add delete support for single db graphs --- .github/workflows/e2e_tests.yml | 95 +++++++- cognee/api/v1/datasets/datasets.py | 20 +- .../is_backend_access_control_enabled.py | 6 +- cognee/modules/graph/methods/__init__.py | 8 +- .../methods/delete_data_nodes_and_edges.py | 128 +++++++---- .../methods/delete_dataset_nodes_and_edges.py | 128 +++++++---- .../graph/methods/get_data_related_edges.py | 22 ++ .../graph/methods/get_data_related_nodes.py | 17 ++ .../methods/get_dataset_related_edges.py | 25 +- .../methods/get_dataset_related_nodes.py | 25 +- cognee/tests/test_delete_dataset.py | 212 +++++++++++++++++ cognee/tests/test_delete_two_users_graph.py | 214 ++++++++++++++++++ 12 files changed, 805 insertions(+), 95 deletions(-) create mode 100644 cognee/tests/test_delete_dataset.py create mode 100644 cognee/tests/test_delete_two_users_graph.py diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index 16da929e1..43df1dfd2 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -387,8 +387,8 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_delete_custom_graph.py - test-deletion-on-default-graph_with_legacy_data_1: - name: Delete default graph with legacy data test 1 + test-deletion-on-default-graph_with_legacy_data_1_default: + name: Delete default graph with legacy data test 1 in Kuzu case runs-on: ubuntu-22.04 steps: - name: Check out @@ -414,6 +414,39 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_1.py + test-deletion-on-default-graph_with_legacy_data_1_neo4j: + name: Delete default graph with legacy data test 1 in Neo4j case + runs-on: ubuntu-22.04 + steps: + - name: Check out + uses: 
actions/checkout@master + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Setup Neo4j with GDS + uses: ./.github/actions/setup_neo4j + id: neo4j + + - name: Run delete dataset test in Neo4j case + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + GRAPH_DATABASE_PROVIDER: "neo4j" + GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }} + GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }} + GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }} + run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_1.py + test-deletion-on-default-graph_with_legacy_data_2: name: Delete default graph with legacy data test 2 runs-on: ubuntu-22.04 @@ -441,6 +474,64 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_2.py + test-delete-dataset-default: + name: Delete dataset in Kuzu graph case + runs-on: ubuntu-22.04 + steps: + - name: Check out + uses: actions/checkout@master + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run delete dataset test in Kuzu case + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ 
secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/test_delete_dataset.py + + test-delete-dataset-neo4j: + name: Delete dataset in Neo4j graph case + runs-on: ubuntu-22.04 + steps: + - name: Check out + uses: actions/checkout@master + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Setup Neo4j with GDS + uses: ./.github/actions/setup_neo4j + id: neo4j + + - name: Run delete dataset test in Neo4j case + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + GRAPH_DATABASE_PROVIDER: "neo4j" + GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }} + GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }} + GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }} + run: uv run python ./cognee/tests/test_delete_dataset.py + test-graph-edges: name: Test graph edge ingestion runs-on: ubuntu-22.04 diff --git a/cognee/api/v1/datasets/datasets.py b/cognee/api/v1/datasets/datasets.py index 0f2b5dcd5..b5cdb2a6c 100644 --- a/cognee/api/v1/datasets/datasets.py +++ b/cognee/api/v1/datasets/datasets.py @@ -4,7 +4,7 @@ from typing import Optional from cognee.modules.users.models import User from cognee.modules.users.methods import get_default_user from cognee.modules.users.exceptions import PermissionDeniedError -from cognee.modules.data.methods import has_dataset_data +from cognee.modules.data.methods import get_dataset_data, has_dataset_data from cognee.modules.data.methods import get_authorized_dataset, get_authorized_existing_datasets from cognee.modules.data.exceptions.exceptions import UnauthorizedDataAccessError from 
cognee.modules.graph.methods import ( @@ -41,12 +41,11 @@ class datasets: return await get_dataset_data(dataset.id) @staticmethod - async def has_data(dataset_id: str) -> bool: - from cognee.modules.data.methods import get_dataset + async def has_data(dataset_id: str, user: Optional[User] = None) -> bool: + if not user: + user = await get_default_user() - user = await get_default_user() - - dataset = await get_dataset(user.id, dataset_id) + dataset = await get_authorized_dataset(user.id, dataset_id) return await has_dataset_data(dataset.id) @@ -56,7 +55,7 @@ class datasets: @staticmethod async def delete_dataset(dataset_id: UUID, user: Optional[User] = None): - from cognee.modules.data.methods import delete_dataset + from cognee.modules.data.methods import delete_data, delete_dataset if not user: user = await get_default_user() @@ -68,6 +67,11 @@ class datasets: await delete_dataset_nodes_and_edges(dataset_id, user.id) + dataset_data = await get_dataset_data(dataset.id) + + for data in dataset_data: + await delete_data(data) + return await delete_dataset(dataset) @staticmethod @@ -108,7 +112,7 @@ class datasets: if not user: user = await get_default_user() - user_datasets = await get_authorized_existing_datasets([], "read", user) + user_datasets = await get_authorized_existing_datasets([], "delete", user) for dataset in user_datasets: await datasets.delete_dataset(dataset.id, user) diff --git a/cognee/infrastructure/environment/config/is_backend_access_control_enabled.py b/cognee/infrastructure/environment/config/is_backend_access_control_enabled.py index 5181e4bb5..4a21f2635 100644 --- a/cognee/infrastructure/environment/config/is_backend_access_control_enabled.py +++ b/cognee/infrastructure/environment/config/is_backend_access_control_enabled.py @@ -8,7 +8,7 @@ VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"] GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"] -def multi_user_support_possible(): +def is_multi_user_support_possible(): graph_db_config 
= get_graph_context_config() vector_db_config = get_vectordb_context_config() return ( @@ -22,10 +22,10 @@ def is_backend_access_control_enabled(): if backend_access_control is None: # If backend access control is not defined in environment variables, # enable it by default if graph and vector DBs can support it, otherwise disable it - return multi_user_support_possible() + return is_multi_user_support_possible() elif backend_access_control.lower() == "true": # If enabled, ensure that the current graph and vector DBs can support it - multi_user_support = multi_user_support_possible() + multi_user_support = is_multi_user_support_possible() if not multi_user_support: raise EnvironmentError( "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control." diff --git a/cognee/modules/graph/methods/__init__.py b/cognee/modules/graph/methods/__init__.py index 5cde7f79c..4e82e86ec 100644 --- a/cognee/modules/graph/methods/__init__.py +++ b/cognee/modules/graph/methods/__init__.py @@ -5,14 +5,14 @@ from .upsert_nodes import upsert_nodes from .has_data_related_nodes import has_data_related_nodes -from .get_data_related_nodes import get_data_related_nodes -from .get_data_related_edges import get_data_related_edges +from .get_data_related_nodes import get_data_related_nodes, get_global_data_related_nodes +from .get_data_related_edges import get_data_related_edges, get_global_data_related_edges from .delete_data_related_nodes import delete_data_related_nodes from .delete_data_related_edges import delete_data_related_edges from .delete_data_nodes_and_edges import delete_data_nodes_and_edges -from .get_dataset_related_nodes import get_dataset_related_nodes -from .get_dataset_related_edges import get_dataset_related_edges +from .get_dataset_related_nodes import get_dataset_related_nodes, get_global_dataset_related_nodes +from 
.get_dataset_related_edges import get_dataset_related_edges, get_global_dataset_related_edges from .delete_dataset_related_nodes import delete_dataset_related_nodes from .delete_dataset_related_edges import delete_dataset_related_edges from .delete_dataset_nodes_and_edges import delete_dataset_nodes_and_edges diff --git a/cognee/modules/graph/methods/delete_data_nodes_and_edges.py b/cognee/modules/graph/methods/delete_data_nodes_and_edges.py index 1ce56a313..b2501e43b 100644 --- a/cognee/modules/graph/methods/delete_data_nodes_and_edges.py +++ b/cognee/modules/graph/methods/delete_data_nodes_and_edges.py @@ -3,6 +3,9 @@ from typing import Dict, List from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine +from cognee.infrastructure.environment.config.is_backend_access_control_enabled import ( + is_multi_user_support_possible, +) from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger from cognee.modules.graph.methods import ( @@ -10,52 +13,101 @@ from cognee.modules.graph.methods import ( delete_data_related_nodes, get_data_related_nodes, get_data_related_edges, + get_global_data_related_nodes, + get_global_data_related_edges, ) async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id: UUID) -> None: - affected_nodes = await get_data_related_nodes(dataset_id, data_id) + if is_multi_user_support_possible(): + affected_nodes = await get_data_related_nodes(dataset_id, data_id) - if len(affected_nodes) == 0: - return + if len(affected_nodes) == 0: + return - is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id) + is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id) - affected_relationships = await get_data_related_edges(dataset_id, data_id) - 
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id) + affected_relationships = await get_data_related_edges(dataset_id, data_id) + is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id) - non_legacy_nodes = [ - node for index, node in enumerate(affected_nodes) if not is_legacy_node[index] - ] - - graph_engine = await get_graph_engine() - await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes]) - - affected_vector_collections: Dict[str, List] = {} - for node in non_legacy_nodes: - for indexed_field in node.indexed_fields: - collection_name = f"{node.type}_{indexed_field}" - if collection_name not in affected_vector_collections: - affected_vector_collections[collection_name] = [] - affected_vector_collections[collection_name].append(node) - - vector_engine = get_vector_engine() - for affected_collection, non_legacy_nodes in affected_vector_collections.items(): - await vector_engine.delete_data_points( - affected_collection, [str(node.slug) for node in non_legacy_nodes] - ) - - if len(affected_relationships) > 0: - non_legacy_relationships = [ - edge - for index, edge in enumerate(affected_relationships) - if not is_legacy_relationship[index] + non_legacy_nodes = [ + node for index, node in enumerate(affected_nodes) if not is_legacy_node[index] ] - await vector_engine.delete_data_points( - "EdgeType_relationship_name", - [str(relationship.slug) for relationship in non_legacy_relationships], - ) + graph_engine = await get_graph_engine() + await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes]) - await delete_data_related_nodes(data_id) - await delete_data_related_edges(data_id) + affected_vector_collections: Dict[str, List] = {} + for node in non_legacy_nodes: + for indexed_field in node.indexed_fields: + collection_name = f"{node.type}_{indexed_field}" + if collection_name not in affected_vector_collections: + 
affected_vector_collections[collection_name] = [] + affected_vector_collections[collection_name].append(node) + + vector_engine = get_vector_engine() + for affected_collection, non_legacy_nodes in affected_vector_collections.items(): + await vector_engine.delete_data_points( + affected_collection, [str(node.slug) for node in non_legacy_nodes] + ) + + if len(affected_relationships) > 0: + non_legacy_relationships = [ + edge + for index, edge in enumerate(affected_relationships) + if not is_legacy_relationship[index] + ] + + await vector_engine.delete_data_points( + "EdgeType_relationship_name", + [str(relationship.slug) for relationship in non_legacy_relationships], + ) + + await delete_data_related_nodes(data_id) + await delete_data_related_edges(data_id) + else: + affected_nodes = await get_global_data_related_nodes(data_id) + + if len(affected_nodes) == 0: + return + + is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id) + + affected_relationships = await get_global_data_related_edges(data_id) + is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id) + + non_legacy_nodes = [ + node for index, node in enumerate(affected_nodes) if not is_legacy_node[index] + ] + + graph_engine = await get_graph_engine() + await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes]) + + affected_vector_collections: Dict[str, List] = {} + for node in non_legacy_nodes: + for indexed_field in node.indexed_fields: + collection_name = f"{node.type}_{indexed_field}" + if collection_name not in affected_vector_collections: + affected_vector_collections[collection_name] = [] + affected_vector_collections[collection_name].append(node) + + vector_engine = get_vector_engine() + for affected_collection, non_legacy_nodes in affected_vector_collections.items(): + await vector_engine.delete_data_points( + affected_collection, [str(node.slug) for node in non_legacy_nodes] + ) + + if len(affected_relationships) > 0: + 
non_legacy_relationships = [ + edge + for index, edge in enumerate(affected_relationships) + if not is_legacy_relationship[index] + ] + + await vector_engine.delete_data_points( + "EdgeType_relationship_name", + [str(relationship.slug) for relationship in non_legacy_relationships], + ) + + await delete_data_related_nodes(data_id) + await delete_data_related_edges(data_id) diff --git a/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py b/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py index 391816a13..c58afd9e7 100644 --- a/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py +++ b/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py @@ -3,6 +3,9 @@ from typing import Dict, List from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine +from cognee.infrastructure.environment.config.is_backend_access_control_enabled import ( + is_multi_user_support_possible, +) from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger from cognee.modules.graph.methods import ( @@ -10,52 +13,101 @@ from cognee.modules.graph.methods import ( delete_dataset_related_nodes, get_dataset_related_nodes, get_dataset_related_edges, + get_global_dataset_related_nodes, + get_global_dataset_related_edges, ) async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> None: - affected_nodes = await get_dataset_related_nodes(dataset_id) + if is_multi_user_support_possible(): + affected_nodes = await get_dataset_related_nodes(dataset_id) - if len(affected_nodes) == 0: - return + if len(affected_nodes) == 0: + return - is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id) + is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id) - affected_relationships = await 
get_dataset_related_edges(dataset_id) - is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id) + affected_relationships = await get_dataset_related_edges(dataset_id) + is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id) - non_legacy_nodes = [ - node for index, node in enumerate(affected_nodes) if not is_legacy_node[index] - ] - - graph_engine = await get_graph_engine() - await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes]) - - affected_vector_collections: Dict[str, List] = {} - for node in non_legacy_nodes: - for indexed_field in node.indexed_fields: - collection_name = f"{node.type}_{indexed_field}" - if collection_name not in affected_vector_collections: - affected_vector_collections[collection_name] = [] - affected_vector_collections[collection_name].append(node) - - vector_engine = get_vector_engine() - for affected_collection, non_legacy_nodes in affected_vector_collections.items(): - await vector_engine.delete_data_points( - affected_collection, [node.id for node in non_legacy_nodes] - ) - - if len(affected_relationships) > 0: - non_legacy_relationships = [ - edge - for index, edge in enumerate(affected_relationships) - if not is_legacy_relationship[index] + non_legacy_nodes = [ + node for index, node in enumerate(affected_nodes) if not is_legacy_node[index] ] - await vector_engine.delete_data_points( - "EdgeType_relationship_name", - [str(relationship.slug) for relationship in non_legacy_relationships], - ) + graph_engine = await get_graph_engine() + await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes]) - await delete_dataset_related_nodes(dataset_id) - await delete_dataset_related_edges(dataset_id) + affected_vector_collections: Dict[str, List] = {} + for node in non_legacy_nodes: + for indexed_field in node.indexed_fields: + collection_name = f"{node.type}_{indexed_field}" + if collection_name not in affected_vector_collections: + 
affected_vector_collections[collection_name] = [] + affected_vector_collections[collection_name].append(node) + + vector_engine = get_vector_engine() + for affected_collection, non_legacy_nodes in affected_vector_collections.items(): + await vector_engine.delete_data_points( + affected_collection, [node.id for node in non_legacy_nodes] + ) + + if len(affected_relationships) > 0: + non_legacy_relationships = [ + edge + for index, edge in enumerate(affected_relationships) + if not is_legacy_relationship[index] + ] + + await vector_engine.delete_data_points( + "EdgeType_relationship_name", + [str(relationship.slug) for relationship in non_legacy_relationships], + ) + + await delete_dataset_related_nodes(dataset_id) + await delete_dataset_related_edges(dataset_id) + else: + affected_nodes = await get_global_dataset_related_nodes(dataset_id) + + if len(affected_nodes) == 0: + return + + is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id) + + affected_relationships = await get_global_dataset_related_edges(dataset_id) + is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id) + + non_legacy_nodes = [ + node for index, node in enumerate(affected_nodes) if not is_legacy_node[index] + ] + + graph_engine = await get_graph_engine() + await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes]) + + affected_vector_collections: Dict[str, List] = {} + for node in non_legacy_nodes: + for indexed_field in node.indexed_fields: + collection_name = f"{node.type}_{indexed_field}" + if collection_name not in affected_vector_collections: + affected_vector_collections[collection_name] = [] + affected_vector_collections[collection_name].append(node) + + vector_engine = get_vector_engine() + for affected_collection, non_legacy_nodes in affected_vector_collections.items(): + await vector_engine.delete_data_points( + affected_collection, [str(node.slug) for node in non_legacy_nodes] + ) + + if 
len(affected_relationships) > 0: + non_legacy_relationships = [ + edge + for index, edge in enumerate(affected_relationships) + if not is_legacy_relationship[index] + ] + + await vector_engine.delete_data_points( + "EdgeType_relationship_name", + [str(relationship.slug) for relationship in non_legacy_relationships], + ) + + await delete_dataset_related_nodes(dataset_id) + await delete_dataset_related_edges(dataset_id) diff --git a/cognee/modules/graph/methods/get_data_related_edges.py b/cognee/modules/graph/methods/get_data_related_edges.py index 326543b84..89cc50e1e 100644 --- a/cognee/modules/graph/methods/get_data_related_edges.py +++ b/cognee/modules/graph/methods/get_data_related_edges.py @@ -29,3 +29,25 @@ async def get_data_related_edges(dataset_id: UUID, data_id: UUID, session: Async data_related_edges = await session.scalars(query_statement) return data_related_edges.all() + + +@with_async_session +async def get_global_data_related_edges(data_id: UUID, session: AsyncSession): + EdgeAlias = aliased(Edge) + + subq = select(EdgeAlias.id).where( + and_( + EdgeAlias.slug == Edge.slug, + EdgeAlias.data_id != data_id, + ) + ) + + query_statement = select(Edge).where( + and_( + Edge.data_id == data_id, + ~exists(subq), + ) + ) + + data_related_edges = await session.scalars(query_statement) + return data_related_edges.all() diff --git a/cognee/modules/graph/methods/get_data_related_nodes.py b/cognee/modules/graph/methods/get_data_related_nodes.py index df4f3f09b..895f2e055 100644 --- a/cognee/modules/graph/methods/get_data_related_nodes.py +++ b/cognee/modules/graph/methods/get_data_related_nodes.py @@ -25,3 +25,20 @@ async def get_data_related_nodes(dataset_id: UUID, data_id: UUID, session: Async data_related_nodes = await session.scalars(query_statement) return data_related_nodes.all() + + +@with_async_session +async def get_global_data_related_nodes(data_id: UUID, session: AsyncSession): + NodeAlias = aliased(Node) + + subq = select(NodeAlias.id).where( + and_( 
+ NodeAlias.slug == Node.slug, + NodeAlias.data_id != data_id, + ) + ) + + query_statement = select(Node).where(and_(Node.data_id == data_id, ~exists(subq))) + + data_related_nodes = await session.scalars(query_statement) + return data_related_nodes.all() diff --git a/cognee/modules/graph/methods/get_dataset_related_edges.py b/cognee/modules/graph/methods/get_dataset_related_edges.py index a67c022e4..b9764f8ee 100644 --- a/cognee/modules/graph/methods/get_dataset_related_edges.py +++ b/cognee/modules/graph/methods/get_dataset_related_edges.py @@ -1,5 +1,6 @@ from uuid import UUID -from sqlalchemy import select +from sqlalchemy.orm import aliased +from sqlalchemy import select, and_, exists from sqlalchemy.ext.asyncio import AsyncSession from cognee.infrastructure.databases.relational import with_async_session @@ -13,3 +14,25 @@ async def get_dataset_related_edges(dataset_id: UUID, session: AsyncSession): select(Edge).where(Edge.dataset_id == dataset_id).distinct(Edge.relationship_name) ) ).all() + + +@with_async_session +async def get_global_dataset_related_edges(dataset_id: UUID, session: AsyncSession): + EdgeAlias = aliased(Edge) + + subq = select(EdgeAlias.id).where( + and_( + EdgeAlias.slug == Edge.slug, + EdgeAlias.dataset_id != dataset_id, + ) + ) + + query_statement = select(Edge).where( + and_( + Edge.dataset_id == dataset_id, + ~exists(subq), + ) + ) + + related_edges = await session.scalars(query_statement) + return related_edges.all() diff --git a/cognee/modules/graph/methods/get_dataset_related_nodes.py b/cognee/modules/graph/methods/get_dataset_related_nodes.py index 818209b19..60245ad81 100644 --- a/cognee/modules/graph/methods/get_dataset_related_nodes.py +++ b/cognee/modules/graph/methods/get_dataset_related_nodes.py @@ -1,5 +1,6 @@ from uuid import UUID -from sqlalchemy import select +from sqlalchemy.orm import aliased +from sqlalchemy import exists, and_, select from sqlalchemy.ext.asyncio import AsyncSession from 
cognee.infrastructure.databases.relational import with_async_session @@ -12,3 +13,25 @@ async def get_dataset_related_nodes(dataset_id: UUID, session: AsyncSession): data_related_nodes = await session.scalars(query_statement) return data_related_nodes.all() + + +@with_async_session +async def get_global_dataset_related_nodes(dataset_id: UUID, session: AsyncSession): + NodeAlias = aliased(Node) + + subq = select(NodeAlias.id).where( + and_( + NodeAlias.slug == Node.slug, + NodeAlias.dataset_id != dataset_id, + ) + ) + + query_statement = select(Node).where( + and_( + Node.dataset_id == dataset_id, + ~exists(subq), + ) + ) + + related_nodes = await session.scalars(query_statement) + return related_nodes.all() diff --git a/cognee/tests/test_delete_dataset.py b/cognee/tests/test_delete_dataset.py new file mode 100644 index 000000000..9ac6c8a97 --- /dev/null +++ b/cognee/tests/test_delete_dataset.py @@ -0,0 +1,212 @@ +import os +import pathlib +import pytest +from unittest.mock import AsyncMock, patch + +import cognee +from cognee.api.v1.datasets import datasets +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.llm import LLMGateway +from cognee.modules.engine.operations.setup import setup +from cognee.modules.users.methods import create_user, get_default_user +from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent +from cognee.shared.logging_utils import get_logger + +logger = get_logger() + + +@pytest.mark.asyncio +@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock) +async def main(mock_create_structured_output: AsyncMock): + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_delete_dataset_two_users_graph" + ) + cognee.config.data_root_directory(data_directory_path) + + cognee_directory_path = os.path.join( + pathlib.Path(__file__).parent, 
".cognee_system/test_delete_dataset_two_users_graph" + ) + cognee.config.system_root_directory(cognee_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + def mock_llm_output(text_input: str, system_prompt: str, response_model): + if text_input == "test": # LLM connection test + return "test" + + if "John" in text_input and response_model == SummarizedContent: + return SummarizedContent( + summary="Summary of John's work.", description="Summary of John's work." + ) + + if "Marie" in text_input and response_model == SummarizedContent: + return SummarizedContent( + summary="Summary of Marie's work.", description="Summary of Marie's work." + ) + + if "Marie" in text_input and response_model == KnowledgeGraph: + return KnowledgeGraph( + nodes=[ + Node(id="Marie", name="Marie", type="Person", description="Marie is a person"), + Node( + id="Apple", + name="Apple", + type="Company", + description="Apple is a company", + ), + Node( + id="MacOS", + name="MacOS", + type="Product", + description="MacOS is Apple's operating system", + ), + ], + edges=[ + Edge( + source_node_id="Marie", + target_node_id="Apple", + relationship_name="works_for", + ), + Edge( + source_node_id="Marie", target_node_id="MacOS", relationship_name="works_on" + ), + ], + ) + + if "John" in text_input and response_model == KnowledgeGraph: + return KnowledgeGraph( + nodes=[ + Node(id="John", name="John", type="Person", description="John is a person"), + Node( + id="Apple", + name="Apple", + type="Company", + description="Apple is a company", + ), + Node( + id="Food for Hungry", + name="Food for Hungry", + type="Non-profit organization", + description="Food for Hungry is a non-profit organization", + ), + ], + edges=[ + Edge( + source_node_id="John", target_node_id="Apple", relationship_name="works_for" + ), + Edge( + source_node_id="John", + target_node_id="Food for Hungry", + relationship_name="works_for", + ), + ], + ) + + 
mock_create_structured_output.side_effect = mock_llm_output + + vector_engine = get_vector_engine() + + assert not await vector_engine.has_collection("EdgeType_relationship_name") + assert not await vector_engine.has_collection("Entity_name") + assert not await vector_engine.has_collection("DocumentChunk_text") + assert not await vector_engine.has_collection("TextSummary_text") + assert not await vector_engine.has_collection("TextDocument_text") + + new_user = await create_user( + email="example@user.com", + password="mypassword", + is_superuser=True, + is_active=True, + is_verified=True, + auto_login=True, + ) + + await cognee.add( + "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" + ) + + await cognee.add( + "Marie works for Apple as well. She is a software engineer on MacOS project.", + user=new_user, + ) + + cognify_result: dict = await cognee.cognify() + johns_dataset_id = list(cognify_result.keys())[0] + + cognify_result: dict = await cognee.cognify(user=new_user) + maries_dataset_id = list(cognify_result.keys())[0] + + graph_engine = await get_graph_engine() + initial_nodes, initial_edges = await graph_engine.get_graph_data() + assert len(initial_nodes) == 15 and len(initial_edges) == 19, ( + "Number of nodes and edges is not correct." 
+ ) + + initial_nodes_by_vector_collection = {} + + for node in initial_nodes: + node_data = node[1] + collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0] + if collection_name not in initial_nodes_by_vector_collection: + initial_nodes_by_vector_collection[collection_name] = [] + initial_nodes_by_vector_collection[collection_name].append(node) + + initial_node_ids = set([node[0] for node in initial_nodes]) + + default_user = await get_default_user() + await datasets.delete_dataset(johns_dataset_id, default_user) # type: ignore + + nodes, edges = await graph_engine.get_graph_data() + assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted." + assert not any( + node[1]["name"] == "john" or node[1]["name"] == "food for hungry" + for node in nodes + if "name" in node[1] + ), "Nodes are not deleted." + + after_first_delete_node_ids = set([node[0] for node in nodes]) + + after_delete_nodes_by_vector_collection = {} + for node in initial_nodes: + node_data = node[1] + collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0] + if collection_name not in after_delete_nodes_by_vector_collection: + after_delete_nodes_by_vector_collection[collection_name] = [] + after_delete_nodes_by_vector_collection[collection_name].append(node) + + vector_engine = get_vector_engine() + + removed_node_ids = initial_node_ids - after_first_delete_node_ids + + for collection_name, initial_nodes in initial_nodes_by_vector_collection.items(): + query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids] + + if query_node_ids: + vector_items = await vector_engine.retrieve(collection_name, query_node_ids) + assert len(vector_items) == 0, "Vector items are not deleted." 
import os
import pathlib
import pytest
from unittest.mock import AsyncMock, patch

import cognee
from cognee.api.v1.datasets import datasets
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm import LLMGateway
from cognee.modules.engine.operations.setup import setup
from cognee.modules.users.methods import create_user, get_default_user
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
from cognee.shared.logging_utils import get_logger

logger = get_logger()


@pytest.mark.asyncio
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
async def main(mock_create_structured_output: AsyncMock):
    """End-to-end deletion test over a graph shared by two users.

    John's text is ingested by the default user and Marie's by a second user;
    after cognify, ``datasets.delete_data`` is called once per user and the
    test asserts that each call removes exactly that user's nodes/edges from
    the graph store and the matching entries from every vector collection,
    with both stores fully empty after the second delete. The LLM gateway is
    mocked so the extracted graphs are deterministic.
    """
    # Isolate this test's storage directories from other tests.
    data_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".data_storage/test_delete_two_users_graph"
    )
    cognee.config.data_root_directory(data_directory_path)

    cognee_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".cognee_system/test_delete_two_users_graph"
    )
    cognee.config.system_root_directory(cognee_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    def mock_llm_output(text_input: str, system_prompt: str, response_model):
        # Deterministic stand-in for the LLM: canned summaries/graphs keyed on
        # the input text and the requested response model.
        if text_input == "test":  # LLM connection test
            return "test"

        if "John" in text_input and response_model == SummarizedContent:
            return SummarizedContent(
                summary="Summary of John's work.", description="Summary of John's work."
            )

        if "Marie" in text_input and response_model == SummarizedContent:
            return SummarizedContent(
                summary="Summary of Marie's work.", description="Summary of Marie's work."
            )

        if "Marie" in text_input and response_model == KnowledgeGraph:
            return KnowledgeGraph(
                nodes=[
                    Node(id="Marie", name="Marie", type="Person", description="Marie is a person"),
                    Node(
                        id="Apple",
                        name="Apple",
                        type="Company",
                        description="Apple is a company",
                    ),
                    Node(
                        id="MacOS",
                        name="MacOS",
                        type="Product",
                        description="MacOS is Apple's operating system",
                    ),
                ],
                edges=[
                    Edge(
                        source_node_id="Marie",
                        target_node_id="Apple",
                        relationship_name="works_for",
                    ),
                    Edge(
                        source_node_id="Marie", target_node_id="MacOS", relationship_name="works_on"
                    ),
                ],
            )

        if "John" in text_input and response_model == KnowledgeGraph:
            return KnowledgeGraph(
                nodes=[
                    Node(id="John", name="John", type="Person", description="John is a person"),
                    Node(
                        id="Apple",
                        name="Apple",
                        type="Company",
                        description="Apple is a company",
                    ),
                    Node(
                        id="Food for Hungry",
                        name="Food for Hungry",
                        type="Non-profit organization",
                        description="Food for Hungry is a non-profit organization",
                    ),
                ],
                edges=[
                    Edge(
                        source_node_id="John", target_node_id="Apple", relationship_name="works_for"
                    ),
                    Edge(
                        source_node_id="John",
                        target_node_id="Food for Hungry",
                        relationship_name="works_for",
                    ),
                ],
            )

    mock_create_structured_output.side_effect = mock_llm_output

    vector_engine = get_vector_engine()

    # A freshly pruned system must not have any vector collections yet.
    assert not await vector_engine.has_collection("EdgeType_relationship_name")
    assert not await vector_engine.has_collection("Entity_name")
    assert not await vector_engine.has_collection("DocumentChunk_text")
    assert not await vector_engine.has_collection("TextSummary_text")
    assert not await vector_engine.has_collection("TextDocument_text")

    new_user = await create_user(
        email="example@user.com",
        password="mypassword",
        is_superuser=True,
        is_active=True,
        is_verified=True,
        auto_login=True,
    )

    add_john_result = await cognee.add(
        "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
    )
    johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]

    add_marie_result = await cognee.add(
        "Marie works for Apple as well. She is a software engineer on MacOS project.",
        user=new_user,
    )
    maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]

    cognify_result: dict = await cognee.cognify()
    johns_dataset_id = list(cognify_result.keys())[0]

    cognify_result = await cognee.cognify(user=new_user)
    maries_dataset_id = list(cognify_result.keys())[0]

    graph_engine = await get_graph_engine()
    initial_nodes, initial_edges = await graph_engine.get_graph_data()
    assert len(initial_nodes) == 15 and len(initial_edges) == 19, (
        "Number of nodes and edges is not correct."
    )

    # Group the initial nodes by the vector collection they were indexed into;
    # the collection name is "<node type>_<first index field>".
    initial_nodes_by_vector_collection = {}
    for node in initial_nodes:
        node_data = node[1]
        collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
        initial_nodes_by_vector_collection.setdefault(collection_name, []).append(node)

    initial_node_ids = set(node[0] for node in initial_nodes)

    default_user = await get_default_user()
    await datasets.delete_data(johns_dataset_id, johns_data_id, default_user)  # type: ignore

    nodes, edges = await graph_engine.get_graph_data()
    assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted."
    assert not any(
        node[1]["name"] == "john" or node[1]["name"] == "food for hungry"
        for node in nodes
        if "name" in node[1]
    ), "Nodes are not deleted."

    after_first_delete_node_ids = set(node[0] for node in nodes)

    # NOTE(review): the original built an "after delete" grouping here, but it
    # iterated the PRE-delete node list and the result was never read — removed
    # as dead code.

    vector_engine = get_vector_engine()

    removed_node_ids = initial_node_ids - after_first_delete_node_ids

    # Every graph node removed by the first delete must also be gone from its
    # vector collection.
    for collection_name, collection_nodes in initial_nodes_by_vector_collection.items():
        query_node_ids = [node[0] for node in collection_nodes if node[0] in removed_node_ids]

        if query_node_ids:
            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
            assert len(vector_items) == 0, "Vector items are not deleted."

    await datasets.delete_data(maries_dataset_id, maries_data_id, new_user)  # type: ignore

    final_nodes, final_edges = await graph_engine.get_graph_data()
    assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted."

    # After both deletes, every vector collection must be empty of all ids
    # that existed before the deletes.
    for collection_name, collection_nodes in initial_nodes_by_vector_collection.items():
        query_node_ids = [node[0] for node in collection_nodes]

        if query_node_ids:
            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
            assert len(vector_items) == 0, "Vector items are not deleted."

    # NOTE(review): assumes edge[0] of a get_graph_data() edge tuple is the id
    # used when indexing "EdgeType_relationship_name" — confirm against the
    # graph engine's edge tuple layout.
    query_edge_ids = [edge[0] for edge in initial_edges]

    vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
    assert len(vector_items) == 0, "Vector items are not deleted."


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())