diff --git a/cognee/tests/test_delete_default_graph.py b/cognee/tests/test_delete_default_graph.py index 0a7d54e40..ec0a1aa51 100644 --- a/cognee/tests/test_delete_default_graph.py +++ b/cognee/tests/test_delete_default_graph.py @@ -1,20 +1,25 @@ import os import pathlib +import pytest +from unittest.mock import AsyncMock, patch import cognee from cognee.api.v1.datasets import datasets -from cognee.api.v1.visualize.visualize import visualize_graph from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.llm import LLMGateway from cognee.modules.data.methods import get_dataset_data from cognee.modules.engine.operations.setup import setup from cognee.modules.users.methods import get_default_user +from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent from cognee.shared.logging_utils import get_logger logger = get_logger() -async def main(): +@pytest.mark.asyncio +@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock) +async def main(mock_create_structured_output: AsyncMock): data_directory_path = str( pathlib.Path( os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph") @@ -37,9 +42,68 @@ async def main(): assert not await vector_engine.has_collection("EdgeType_relationship_name") assert not await vector_engine.has_collection("Entity_name") + mock_create_structured_output.side_effect = [ + "", # For LLM connection test + KnowledgeGraph( + nodes=[ + Node(id="John", name="John", type="Person", description="John is a person"), + Node( + id="Apple", + name="Apple", + type="Company", + description="Apple is a company", + ), + Node( + id="Food for Hungry", + name="Food for Hungry", + type="Non-profit organization", + description="Food for Hungry is a non-profit organization", + ), + ], + edges=[ + Edge(source_node_id="John", target_node_id="Apple", relationship_name="works_for"), + Edge( + source_node_id="John", + target_node_id="Food for Hungry", + relationship_name="works_for", + ), + ], + ), + KnowledgeGraph( + nodes=[ + Node(id="Marie", name="Marie", type="Person", description="Marie is a person"), + Node( + id="Apple", + name="Apple", + type="Company", + description="Apple is a company", + ), + Node( + id="MacOS", + name="MacOS", + type="Product", + description="MacOS is Apple's operating system", + ), + ], + edges=[ + Edge( + source_node_id="Marie", + target_node_id="Apple", + relationship_name="works_for", + ), + Edge(source_node_id="Marie", target_node_id="MacOS", relationship_name="works_on"), + ], + ), + SummarizedContent(summary="Summary of John's work.", description="Summary of John's work."), + SummarizedContent( + summary="Summary of Marie's work.", description="Summary of Marie's work." + ), + ] + await cognee.add( "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" ) + await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.") cognify_result: dict = await cognee.cognify() @@ -48,27 +112,25 @@ async def main(): dataset_data = await get_dataset_data(dataset_id) added_data = dataset_data[0] - file_path = os.path.join( - pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_full.html" - ) - await visualize_graph(file_path) + # file_path = os.path.join( + # pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_full.html" + # ) + # await visualize_graph(file_path) graph_engine = await get_graph_engine() nodes, edges = await graph_engine.get_graph_data() - assert len(nodes) >= 12 and len(edges) >= 18, "Nodes and edges are not deleted." + assert len(nodes) == 15 and len(edges) == 19, "Number of nodes and edges is not correct." user = await get_default_user() await datasets.delete_data(dataset_id, added_data.id, user) # type: ignore - file_path = os.path.join( - pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_after_delete.html" - ) - await visualize_graph(file_path) + # file_path = os.path.join( + # pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_after_delete.html" + # ) + # await visualize_graph(file_path) nodes, edges = await graph_engine.get_graph_data() - assert len(nodes) >= 8 and len(nodes) < 12 and len(edges) >= 10 and len(edges) < 18, ( - "Nodes and edges are not deleted." - ) + assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted." if __name__ == "__main__": diff --git a/cognee/tests/test_delete_permission.py b/cognee/tests/test_delete_permission.py new file mode 100644 index 000000000..0735b6f75 --- /dev/null +++ b/cognee/tests/test_delete_permission.py @@ -0,0 +1,127 @@ +import os +import pathlib +from typing import List +from uuid import UUID, uuid4 +from pydantic import BaseModel + +import cognee +from cognee.api.v1.datasets import datasets +from cognee.infrastructure.engine import DataPoint +from cognee.modules.data.exceptions.exceptions import UnauthorizedDataAccessError +from cognee.modules.data.methods import create_authorized_dataset +from cognee.modules.engine.operations.setup import setup +from cognee.modules.users.models import User +from cognee.modules.users.methods import create_user +from cognee.modules.users.permissions.methods import authorized_give_permission_on_datasets +from cognee.shared.logging_utils import get_logger +from cognee.tasks.storage import add_data_points + +logger = get_logger() + + +async def main(): + os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "True" + + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_delete_permission" + ) + cognee.config.data_root_directory(data_directory_path) + + cognee_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_delete_permission" + ) + cognee.config.system_root_directory(cognee_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Organization(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class ForProfit(Organization): + name: str = "For-Profit" + metadata: dict = {"index_fields": ["name"]} + + class NonProfit(Organization): + name: str = "Non-Profit" + metadata: dict = {"index_fields": ["name"]} + + class Person(DataPoint): + name: str + works_for: List[Organization] + metadata: dict = {"index_fields": ["name"]} + + companyA = ForProfit(name="Company A") + companyB = NonProfit(name="Company B") + + person1 = Person(name="John", works_for=[companyA, companyB]) + person2 = Person(name="Jane", works_for=[companyB]) + + user1: User = await create_user(email="user1@example.com", password="password123") + user2: User = await create_user(email="user2@example.com", password="password123") + + class CustomData(BaseModel): + id: UUID + + dataset = await create_authorized_dataset(dataset_name="test_dataset", user=user1) + + data1 = CustomData(id=uuid4()) + data2 = CustomData(id=uuid4()) + + await add_data_points( + [person1], + context={ + "user": user1, + "dataset": dataset, + "data": data1, + }, + ) + + await add_data_points( + [person2], + context={ + "user": user1, + "dataset": dataset, + "data": data2, + }, + ) + + from cognee.infrastructure.databases.graph import get_graph_engine + + graph_engine = await get_graph_engine() + + nodes, edges = await graph_engine.get_graph_data() + assert len(nodes) == 4 and len(edges) == 3, ( + "Nodes and edges are not correctly added to the graph." + ) + + is_permission_error_raised = False + try: + await datasets.delete_data(dataset.id, data1.id, user2) + except UnauthorizedDataAccessError: + is_permission_error_raised = True + + assert is_permission_error_raised, "PermissionDeniedError was not raised as expected." + + nodes, edges = await graph_engine.get_graph_data() + assert len(nodes) == 4 and len(edges) == 3, "Graph is changed without permissions." + + await authorized_give_permission_on_datasets(user2.id, [dataset.id], "delete", user1.id) + + await datasets.delete_data(dataset.id, data1.id, user2) + + nodes, edges = await graph_engine.get_graph_data() + assert len(nodes) == 2 and len(edges) == 1, "Nodes and edges are not deleted properly." + + await datasets.delete_data(dataset.id, data2.id, user2) + + nodes, edges = await graph_engine.get_graph_data() + assert len(nodes) == 0 and len(edges) == 0, "Nodes and edges are not deleted." + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main())