fix: add delete tests
This commit is contained in:
parent
94d2ca01a7
commit
4b067be34b
2 changed files with 203 additions and 14 deletions
|
|
@ -1,20 +1,25 @@
|
|||
import os
|
||||
import pathlib
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.datasets import datasets
|
||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
from cognee.modules.data.methods import get_dataset_data
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def main():
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
|
||||
async def main(mock_create_structured_output: AsyncMock):
|
||||
data_directory_path = str(
|
||||
pathlib.Path(
|
||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph")
|
||||
|
|
@ -37,9 +42,68 @@ async def main():
|
|||
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
||||
assert not await vector_engine.has_collection("Entity_name")
|
||||
|
||||
mock_create_structured_output.side_effect = [
|
||||
"", # For LLM connection test
|
||||
KnowledgeGraph(
|
||||
nodes=[
|
||||
Node(id="John", name="John", type="Person", description="John is a person"),
|
||||
Node(
|
||||
id="Apple",
|
||||
name="Apple",
|
||||
type="Company",
|
||||
description="Apple is a company",
|
||||
),
|
||||
Node(
|
||||
id="Food for Hungry",
|
||||
name="Food for Hungry",
|
||||
type="Non-profit organization",
|
||||
description="Food for Hungry is a non-profit organization",
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
Edge(source_node_id="John", target_node_id="Apple", relationship_name="works_for"),
|
||||
Edge(
|
||||
source_node_id="John",
|
||||
target_node_id="Food for Hungry",
|
||||
relationship_name="works_for",
|
||||
),
|
||||
],
|
||||
),
|
||||
KnowledgeGraph(
|
||||
nodes=[
|
||||
Node(id="Marie", name="Marie", type="Person", description="Marie is a person"),
|
||||
Node(
|
||||
id="Apple",
|
||||
name="Apple",
|
||||
type="Company",
|
||||
description="Apple is a company",
|
||||
),
|
||||
Node(
|
||||
id="MacOS",
|
||||
name="MacOS",
|
||||
type="Product",
|
||||
description="MacOS is Apple's operating system",
|
||||
),
|
||||
],
|
||||
edges=[
|
||||
Edge(
|
||||
source_node_id="Marie",
|
||||
target_node_id="Apple",
|
||||
relationship_name="works_for",
|
||||
),
|
||||
Edge(source_node_id="Marie", target_node_id="MacOS", relationship_name="works_on"),
|
||||
],
|
||||
),
|
||||
SummarizedContent(summary="Summary of John's work.", description="Summary of John's work."),
|
||||
SummarizedContent(
|
||||
summary="Summary of Marie's work.", description="Summary of Marie's work."
|
||||
),
|
||||
]
|
||||
|
||||
await cognee.add(
|
||||
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||
)
|
||||
|
||||
await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.")
|
||||
|
||||
cognify_result: dict = await cognee.cognify()
|
||||
|
|
@ -48,27 +112,25 @@ async def main():
|
|||
dataset_data = await get_dataset_data(dataset_id)
|
||||
added_data = dataset_data[0]
|
||||
|
||||
file_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_full.html"
|
||||
)
|
||||
await visualize_graph(file_path)
|
||||
# file_path = os.path.join(
|
||||
# pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_full.html"
|
||||
# )
|
||||
# await visualize_graph(file_path)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) >= 12 and len(edges) >= 18, "Nodes and edges are not deleted."
|
||||
assert len(nodes) == 15 and len(edges) == 19, "Number of nodes and edges is not correct."
|
||||
|
||||
user = await get_default_user()
|
||||
await datasets.delete_data(dataset_id, added_data.id, user) # type: ignore
|
||||
|
||||
file_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_after_delete.html"
|
||||
)
|
||||
await visualize_graph(file_path)
|
||||
# file_path = os.path.join(
|
||||
# pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_after_delete.html"
|
||||
# )
|
||||
# await visualize_graph(file_path)
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) >= 8 and len(nodes) < 12 and len(edges) >= 10 and len(edges) < 18, (
|
||||
"Nodes and edges are not deleted."
|
||||
)
|
||||
assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
127
cognee/tests/test_delete_permission.py
Normal file
127
cognee/tests/test_delete_permission.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
import os
|
||||
import pathlib
|
||||
from typing import List
|
||||
from uuid import UUID, uuid4
|
||||
from pydantic import BaseModel
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.datasets import datasets
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.exceptions.exceptions import UnauthorizedDataAccessError
|
||||
from cognee.modules.data.methods import create_authorized_dataset
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import create_user
|
||||
from cognee.modules.users.permissions.methods import authorized_give_permission_on_datasets
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.tasks.storage import add_data_points
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def main():
|
||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "True"
|
||||
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_delete_permission"
|
||||
)
|
||||
cognee.config.data_root_directory(data_directory_path)
|
||||
|
||||
cognee_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".cognee_system/test_delete_permission"
|
||||
)
|
||||
cognee.config.system_root_directory(cognee_directory_path)
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
class Organization(DataPoint):
|
||||
name: str
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
class ForProfit(Organization):
|
||||
name: str = "For-Profit"
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
class NonProfit(Organization):
|
||||
name: str = "Non-Profit"
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
class Person(DataPoint):
|
||||
name: str
|
||||
works_for: List[Organization]
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
companyA = ForProfit(name="Company A")
|
||||
companyB = NonProfit(name="Company B")
|
||||
|
||||
person1 = Person(name="John", works_for=[companyA, companyB])
|
||||
person2 = Person(name="Jane", works_for=[companyB])
|
||||
|
||||
user1: User = await create_user(email="user1@example.com", password="password123")
|
||||
user2: User = await create_user(email="user2@example.com", password="password123")
|
||||
|
||||
class CustomData(BaseModel):
|
||||
id: UUID
|
||||
|
||||
dataset = await create_authorized_dataset(dataset_name="test_dataset", user=user1)
|
||||
|
||||
data1 = CustomData(id=uuid4())
|
||||
data2 = CustomData(id=uuid4())
|
||||
|
||||
await add_data_points(
|
||||
[person1],
|
||||
context={
|
||||
"user": user1,
|
||||
"dataset": dataset,
|
||||
"data": data1,
|
||||
},
|
||||
)
|
||||
|
||||
await add_data_points(
|
||||
[person2],
|
||||
context={
|
||||
"user": user1,
|
||||
"dataset": dataset,
|
||||
"data": data2,
|
||||
},
|
||||
)
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 4 and len(edges) == 3, (
|
||||
"Nodes and edges are not correctly added to the graph."
|
||||
)
|
||||
|
||||
is_permission_error_raised = False
|
||||
try:
|
||||
await datasets.delete_data(dataset.id, data1.id, user2)
|
||||
except UnauthorizedDataAccessError:
|
||||
is_permission_error_raised = True
|
||||
|
||||
assert is_permission_error_raised, "PermissionDeniedError was not raised as expected."
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 4 and len(edges) == 3, "Graph is changed without permissions."
|
||||
|
||||
await authorized_give_permission_on_datasets(user2.id, [dataset.id], "delete", user1.id)
|
||||
|
||||
await datasets.delete_data(dataset.id, data1.id, user2)
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 2 and len(edges) == 1, "Nodes and edges are not deleted properly."
|
||||
|
||||
await datasets.delete_data(dataset.id, data2.id, user2)
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 0 and len(edges) == 0, "Nodes and edges are not deleted."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
Loading…
Add table
Reference in a new issue