fix: add delete support for single db graphs
This commit is contained in:
parent
468205fed1
commit
17632a5bec
12 changed files with 805 additions and 95 deletions
95
.github/workflows/e2e_tests.yml
vendored
95
.github/workflows/e2e_tests.yml
vendored
|
|
@ -387,8 +387,8 @@ jobs:
|
||||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
run: uv run python ./cognee/tests/test_delete_custom_graph.py
|
run: uv run python ./cognee/tests/test_delete_custom_graph.py
|
||||||
|
|
||||||
test-deletion-on-default-graph_with_legacy_data_1:
|
test-deletion-on-default-graph_with_legacy_data_1_default:
|
||||||
name: Delete default graph with legacy data test 1
|
name: Delete default graph with legacy data test 1 in Kuzu case
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Check out
|
- name: Check out
|
||||||
|
|
@ -414,6 +414,39 @@ jobs:
|
||||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_1.py
|
run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_1.py
|
||||||
|
|
||||||
|
test-deletion-on-default-graph_with_legacy_data_1_neo4j:
|
||||||
|
name: Delete default graph with legacy data test 1 in Neo4j case
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out
|
||||||
|
uses: actions/checkout@master
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
|
||||||
|
- name: Setup Neo4j with GDS
|
||||||
|
uses: ./.github/actions/setup_neo4j
|
||||||
|
id: neo4j
|
||||||
|
|
||||||
|
- name: Run delete dataset test in Neo4j case
|
||||||
|
env:
|
||||||
|
ENV: 'dev'
|
||||||
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
|
GRAPH_DATABASE_PROVIDER: "neo4j"
|
||||||
|
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||||
|
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||||
|
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||||
|
run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_1.py
|
||||||
|
|
||||||
test-deletion-on-default-graph_with_legacy_data_2:
|
test-deletion-on-default-graph_with_legacy_data_2:
|
||||||
name: Delete default graph with legacy data test 2
|
name: Delete default graph with legacy data test 2
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
|
|
@ -441,6 +474,64 @@ jobs:
|
||||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_2.py
|
run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_2.py
|
||||||
|
|
||||||
|
test-delete-dataset-default:
|
||||||
|
name: Delete dataset in Kuzu graph case
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out
|
||||||
|
uses: actions/checkout@master
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
|
||||||
|
- name: Run delete dataset test in Kuzu case
|
||||||
|
env:
|
||||||
|
ENV: 'dev'
|
||||||
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
|
run: uv run python ./cognee/tests/test_delete_dataset.py
|
||||||
|
|
||||||
|
test-delete-dataset-neo4j:
|
||||||
|
name: Delete dataset in Neo4j graph case
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Check out
|
||||||
|
uses: actions/checkout@master
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
|
||||||
|
- name: Setup Neo4j with GDS
|
||||||
|
uses: ./.github/actions/setup_neo4j
|
||||||
|
id: neo4j
|
||||||
|
|
||||||
|
- name: Run delete dataset test in Neo4j case
|
||||||
|
env:
|
||||||
|
ENV: 'dev'
|
||||||
|
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||||
|
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||||
|
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||||
|
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||||
|
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||||
|
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||||
|
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||||
|
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||||
|
GRAPH_DATABASE_PROVIDER: "neo4j"
|
||||||
|
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||||
|
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||||
|
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||||
|
run: uv run python ./cognee/tests/test_delete_dataset.py
|
||||||
|
|
||||||
test-graph-edges:
|
test-graph-edges:
|
||||||
name: Test graph edge ingestion
|
name: Test graph edge ingestion
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from typing import Optional
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.users.exceptions import PermissionDeniedError
|
from cognee.modules.users.exceptions import PermissionDeniedError
|
||||||
from cognee.modules.data.methods import has_dataset_data
|
from cognee.modules.data.methods import get_dataset_data, has_dataset_data
|
||||||
from cognee.modules.data.methods import get_authorized_dataset, get_authorized_existing_datasets
|
from cognee.modules.data.methods import get_authorized_dataset, get_authorized_existing_datasets
|
||||||
from cognee.modules.data.exceptions.exceptions import UnauthorizedDataAccessError
|
from cognee.modules.data.exceptions.exceptions import UnauthorizedDataAccessError
|
||||||
from cognee.modules.graph.methods import (
|
from cognee.modules.graph.methods import (
|
||||||
|
|
@ -41,12 +41,11 @@ class datasets:
|
||||||
return await get_dataset_data(dataset.id)
|
return await get_dataset_data(dataset.id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def has_data(dataset_id: str) -> bool:
|
async def has_data(dataset_id: str, user: Optional[User] = None) -> bool:
|
||||||
from cognee.modules.data.methods import get_dataset
|
if not user:
|
||||||
|
user = await get_default_user()
|
||||||
|
|
||||||
user = await get_default_user()
|
dataset = await get_authorized_dataset(user.id, dataset_id)
|
||||||
|
|
||||||
dataset = await get_dataset(user.id, dataset_id)
|
|
||||||
|
|
||||||
return await has_dataset_data(dataset.id)
|
return await has_dataset_data(dataset.id)
|
||||||
|
|
||||||
|
|
@ -56,7 +55,7 @@ class datasets:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def delete_dataset(dataset_id: UUID, user: Optional[User] = None):
|
async def delete_dataset(dataset_id: UUID, user: Optional[User] = None):
|
||||||
from cognee.modules.data.methods import delete_dataset
|
from cognee.modules.data.methods import delete_data, delete_dataset
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
@ -68,6 +67,11 @@ class datasets:
|
||||||
|
|
||||||
await delete_dataset_nodes_and_edges(dataset_id, user.id)
|
await delete_dataset_nodes_and_edges(dataset_id, user.id)
|
||||||
|
|
||||||
|
dataset_data = await get_dataset_data(dataset.id)
|
||||||
|
|
||||||
|
for data in dataset_data:
|
||||||
|
await delete_data(data)
|
||||||
|
|
||||||
return await delete_dataset(dataset)
|
return await delete_dataset(dataset)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -108,7 +112,7 @@ class datasets:
|
||||||
if not user:
|
if not user:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
user_datasets = await get_authorized_existing_datasets([], "read", user)
|
user_datasets = await get_authorized_existing_datasets([], "delete", user)
|
||||||
|
|
||||||
for dataset in user_datasets:
|
for dataset in user_datasets:
|
||||||
await datasets.delete_dataset(dataset.id, user)
|
await datasets.delete_dataset(dataset.id, user)
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ VECTOR_DBS_WITH_MULTI_USER_SUPPORT = ["lancedb", "falkor"]
|
||||||
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
|
GRAPH_DBS_WITH_MULTI_USER_SUPPORT = ["kuzu", "falkor"]
|
||||||
|
|
||||||
|
|
||||||
def multi_user_support_possible():
|
def is_multi_user_support_possible():
|
||||||
graph_db_config = get_graph_context_config()
|
graph_db_config = get_graph_context_config()
|
||||||
vector_db_config = get_vectordb_context_config()
|
vector_db_config = get_vectordb_context_config()
|
||||||
return (
|
return (
|
||||||
|
|
@ -22,10 +22,10 @@ def is_backend_access_control_enabled():
|
||||||
if backend_access_control is None:
|
if backend_access_control is None:
|
||||||
# If backend access control is not defined in environment variables,
|
# If backend access control is not defined in environment variables,
|
||||||
# enable it by default if graph and vector DBs can support it, otherwise disable it
|
# enable it by default if graph and vector DBs can support it, otherwise disable it
|
||||||
return multi_user_support_possible()
|
return is_multi_user_support_possible()
|
||||||
elif backend_access_control.lower() == "true":
|
elif backend_access_control.lower() == "true":
|
||||||
# If enabled, ensure that the current graph and vector DBs can support it
|
# If enabled, ensure that the current graph and vector DBs can support it
|
||||||
multi_user_support = multi_user_support_possible()
|
multi_user_support = is_multi_user_support_possible()
|
||||||
if not multi_user_support:
|
if not multi_user_support:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
||||||
|
|
|
||||||
|
|
@ -5,14 +5,14 @@ from .upsert_nodes import upsert_nodes
|
||||||
|
|
||||||
from .has_data_related_nodes import has_data_related_nodes
|
from .has_data_related_nodes import has_data_related_nodes
|
||||||
|
|
||||||
from .get_data_related_nodes import get_data_related_nodes
|
from .get_data_related_nodes import get_data_related_nodes, get_global_data_related_nodes
|
||||||
from .get_data_related_edges import get_data_related_edges
|
from .get_data_related_edges import get_data_related_edges, get_global_data_related_edges
|
||||||
from .delete_data_related_nodes import delete_data_related_nodes
|
from .delete_data_related_nodes import delete_data_related_nodes
|
||||||
from .delete_data_related_edges import delete_data_related_edges
|
from .delete_data_related_edges import delete_data_related_edges
|
||||||
from .delete_data_nodes_and_edges import delete_data_nodes_and_edges
|
from .delete_data_nodes_and_edges import delete_data_nodes_and_edges
|
||||||
|
|
||||||
from .get_dataset_related_nodes import get_dataset_related_nodes
|
from .get_dataset_related_nodes import get_dataset_related_nodes, get_global_dataset_related_nodes
|
||||||
from .get_dataset_related_edges import get_dataset_related_edges
|
from .get_dataset_related_edges import get_dataset_related_edges, get_global_dataset_related_edges
|
||||||
from .delete_dataset_related_nodes import delete_dataset_related_nodes
|
from .delete_dataset_related_nodes import delete_dataset_related_nodes
|
||||||
from .delete_dataset_related_edges import delete_dataset_related_edges
|
from .delete_dataset_related_edges import delete_dataset_related_edges
|
||||||
from .delete_dataset_nodes_and_edges import delete_dataset_nodes_and_edges
|
from .delete_dataset_nodes_and_edges import delete_dataset_nodes_and_edges
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,9 @@ from typing import Dict, List
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
|
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
|
||||||
|
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
|
||||||
|
is_multi_user_support_possible,
|
||||||
|
)
|
||||||
from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
|
from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
|
||||||
from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger
|
from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger
|
||||||
from cognee.modules.graph.methods import (
|
from cognee.modules.graph.methods import (
|
||||||
|
|
@ -10,52 +13,101 @@ from cognee.modules.graph.methods import (
|
||||||
delete_data_related_nodes,
|
delete_data_related_nodes,
|
||||||
get_data_related_nodes,
|
get_data_related_nodes,
|
||||||
get_data_related_edges,
|
get_data_related_edges,
|
||||||
|
get_global_data_related_nodes,
|
||||||
|
get_global_data_related_edges,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id: UUID) -> None:
|
async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id: UUID) -> None:
|
||||||
affected_nodes = await get_data_related_nodes(dataset_id, data_id)
|
if is_multi_user_support_possible():
|
||||||
|
affected_nodes = await get_data_related_nodes(dataset_id, data_id)
|
||||||
|
|
||||||
if len(affected_nodes) == 0:
|
if len(affected_nodes) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
||||||
|
|
||||||
affected_relationships = await get_data_related_edges(dataset_id, data_id)
|
affected_relationships = await get_data_related_edges(dataset_id, data_id)
|
||||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
||||||
|
|
||||||
non_legacy_nodes = [
|
non_legacy_nodes = [
|
||||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||||
]
|
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
|
||||||
await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes])
|
|
||||||
|
|
||||||
affected_vector_collections: Dict[str, List] = {}
|
|
||||||
for node in non_legacy_nodes:
|
|
||||||
for indexed_field in node.indexed_fields:
|
|
||||||
collection_name = f"{node.type}_{indexed_field}"
|
|
||||||
if collection_name not in affected_vector_collections:
|
|
||||||
affected_vector_collections[collection_name] = []
|
|
||||||
affected_vector_collections[collection_name].append(node)
|
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
|
||||||
for affected_collection, non_legacy_nodes in affected_vector_collections.items():
|
|
||||||
await vector_engine.delete_data_points(
|
|
||||||
affected_collection, [str(node.slug) for node in non_legacy_nodes]
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(affected_relationships) > 0:
|
|
||||||
non_legacy_relationships = [
|
|
||||||
edge
|
|
||||||
for index, edge in enumerate(affected_relationships)
|
|
||||||
if not is_legacy_relationship[index]
|
|
||||||
]
|
]
|
||||||
|
|
||||||
await vector_engine.delete_data_points(
|
graph_engine = await get_graph_engine()
|
||||||
"EdgeType_relationship_name",
|
await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes])
|
||||||
[str(relationship.slug) for relationship in non_legacy_relationships],
|
|
||||||
)
|
|
||||||
|
|
||||||
await delete_data_related_nodes(data_id)
|
affected_vector_collections: Dict[str, List] = {}
|
||||||
await delete_data_related_edges(data_id)
|
for node in non_legacy_nodes:
|
||||||
|
for indexed_field in node.indexed_fields:
|
||||||
|
collection_name = f"{node.type}_{indexed_field}"
|
||||||
|
if collection_name not in affected_vector_collections:
|
||||||
|
affected_vector_collections[collection_name] = []
|
||||||
|
affected_vector_collections[collection_name].append(node)
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
for affected_collection, non_legacy_nodes in affected_vector_collections.items():
|
||||||
|
await vector_engine.delete_data_points(
|
||||||
|
affected_collection, [str(node.slug) for node in non_legacy_nodes]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(affected_relationships) > 0:
|
||||||
|
non_legacy_relationships = [
|
||||||
|
edge
|
||||||
|
for index, edge in enumerate(affected_relationships)
|
||||||
|
if not is_legacy_relationship[index]
|
||||||
|
]
|
||||||
|
|
||||||
|
await vector_engine.delete_data_points(
|
||||||
|
"EdgeType_relationship_name",
|
||||||
|
[str(relationship.slug) for relationship in non_legacy_relationships],
|
||||||
|
)
|
||||||
|
|
||||||
|
await delete_data_related_nodes(data_id)
|
||||||
|
await delete_data_related_edges(data_id)
|
||||||
|
else:
|
||||||
|
affected_nodes = await get_global_data_related_nodes(data_id)
|
||||||
|
|
||||||
|
if len(affected_nodes) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
||||||
|
|
||||||
|
affected_relationships = await get_global_data_related_edges(data_id)
|
||||||
|
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
||||||
|
|
||||||
|
non_legacy_nodes = [
|
||||||
|
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||||
|
]
|
||||||
|
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes])
|
||||||
|
|
||||||
|
affected_vector_collections: Dict[str, List] = {}
|
||||||
|
for node in non_legacy_nodes:
|
||||||
|
for indexed_field in node.indexed_fields:
|
||||||
|
collection_name = f"{node.type}_{indexed_field}"
|
||||||
|
if collection_name not in affected_vector_collections:
|
||||||
|
affected_vector_collections[collection_name] = []
|
||||||
|
affected_vector_collections[collection_name].append(node)
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
for affected_collection, non_legacy_nodes in affected_vector_collections.items():
|
||||||
|
await vector_engine.delete_data_points(
|
||||||
|
affected_collection, [str(node.slug) for node in non_legacy_nodes]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(affected_relationships) > 0:
|
||||||
|
non_legacy_relationships = [
|
||||||
|
edge
|
||||||
|
for index, edge in enumerate(affected_relationships)
|
||||||
|
if not is_legacy_relationship[index]
|
||||||
|
]
|
||||||
|
|
||||||
|
await vector_engine.delete_data_points(
|
||||||
|
"EdgeType_relationship_name",
|
||||||
|
[str(relationship.slug) for relationship in non_legacy_relationships],
|
||||||
|
)
|
||||||
|
|
||||||
|
await delete_data_related_nodes(data_id)
|
||||||
|
await delete_data_related_edges(data_id)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,9 @@ from typing import Dict, List
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
|
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
|
||||||
|
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
|
||||||
|
is_multi_user_support_possible,
|
||||||
|
)
|
||||||
from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger
|
from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger
|
||||||
from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
|
from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
|
||||||
from cognee.modules.graph.methods import (
|
from cognee.modules.graph.methods import (
|
||||||
|
|
@ -10,52 +13,101 @@ from cognee.modules.graph.methods import (
|
||||||
delete_dataset_related_nodes,
|
delete_dataset_related_nodes,
|
||||||
get_dataset_related_nodes,
|
get_dataset_related_nodes,
|
||||||
get_dataset_related_edges,
|
get_dataset_related_edges,
|
||||||
|
get_global_dataset_related_nodes,
|
||||||
|
get_global_dataset_related_edges,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> None:
|
async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> None:
|
||||||
affected_nodes = await get_dataset_related_nodes(dataset_id)
|
if is_multi_user_support_possible():
|
||||||
|
affected_nodes = await get_dataset_related_nodes(dataset_id)
|
||||||
|
|
||||||
if len(affected_nodes) == 0:
|
if len(affected_nodes) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
||||||
|
|
||||||
affected_relationships = await get_dataset_related_edges(dataset_id)
|
affected_relationships = await get_dataset_related_edges(dataset_id)
|
||||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
||||||
|
|
||||||
non_legacy_nodes = [
|
non_legacy_nodes = [
|
||||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||||
]
|
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
|
||||||
await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes])
|
|
||||||
|
|
||||||
affected_vector_collections: Dict[str, List] = {}
|
|
||||||
for node in non_legacy_nodes:
|
|
||||||
for indexed_field in node.indexed_fields:
|
|
||||||
collection_name = f"{node.type}_{indexed_field}"
|
|
||||||
if collection_name not in affected_vector_collections:
|
|
||||||
affected_vector_collections[collection_name] = []
|
|
||||||
affected_vector_collections[collection_name].append(node)
|
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
|
||||||
for affected_collection, non_legacy_nodes in affected_vector_collections.items():
|
|
||||||
await vector_engine.delete_data_points(
|
|
||||||
affected_collection, [node.id for node in non_legacy_nodes]
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(affected_relationships) > 0:
|
|
||||||
non_legacy_relationships = [
|
|
||||||
edge
|
|
||||||
for index, edge in enumerate(affected_relationships)
|
|
||||||
if not is_legacy_relationship[index]
|
|
||||||
]
|
]
|
||||||
|
|
||||||
await vector_engine.delete_data_points(
|
graph_engine = await get_graph_engine()
|
||||||
"EdgeType_relationship_name",
|
await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes])
|
||||||
[str(relationship.slug) for relationship in non_legacy_relationships],
|
|
||||||
)
|
|
||||||
|
|
||||||
await delete_dataset_related_nodes(dataset_id)
|
affected_vector_collections: Dict[str, List] = {}
|
||||||
await delete_dataset_related_edges(dataset_id)
|
for node in non_legacy_nodes:
|
||||||
|
for indexed_field in node.indexed_fields:
|
||||||
|
collection_name = f"{node.type}_{indexed_field}"
|
||||||
|
if collection_name not in affected_vector_collections:
|
||||||
|
affected_vector_collections[collection_name] = []
|
||||||
|
affected_vector_collections[collection_name].append(node)
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
for affected_collection, non_legacy_nodes in affected_vector_collections.items():
|
||||||
|
await vector_engine.delete_data_points(
|
||||||
|
affected_collection, [node.id for node in non_legacy_nodes]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(affected_relationships) > 0:
|
||||||
|
non_legacy_relationships = [
|
||||||
|
edge
|
||||||
|
for index, edge in enumerate(affected_relationships)
|
||||||
|
if not is_legacy_relationship[index]
|
||||||
|
]
|
||||||
|
|
||||||
|
await vector_engine.delete_data_points(
|
||||||
|
"EdgeType_relationship_name",
|
||||||
|
[str(relationship.slug) for relationship in non_legacy_relationships],
|
||||||
|
)
|
||||||
|
|
||||||
|
await delete_dataset_related_nodes(dataset_id)
|
||||||
|
await delete_dataset_related_edges(dataset_id)
|
||||||
|
else:
|
||||||
|
affected_nodes = await get_global_dataset_related_nodes(dataset_id)
|
||||||
|
|
||||||
|
if len(affected_nodes) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
||||||
|
|
||||||
|
affected_relationships = await get_global_dataset_related_edges(dataset_id)
|
||||||
|
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
||||||
|
|
||||||
|
non_legacy_nodes = [
|
||||||
|
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||||
|
]
|
||||||
|
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes])
|
||||||
|
|
||||||
|
affected_vector_collections: Dict[str, List] = {}
|
||||||
|
for node in non_legacy_nodes:
|
||||||
|
for indexed_field in node.indexed_fields:
|
||||||
|
collection_name = f"{node.type}_{indexed_field}"
|
||||||
|
if collection_name not in affected_vector_collections:
|
||||||
|
affected_vector_collections[collection_name] = []
|
||||||
|
affected_vector_collections[collection_name].append(node)
|
||||||
|
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
for affected_collection, non_legacy_nodes in affected_vector_collections.items():
|
||||||
|
await vector_engine.delete_data_points(
|
||||||
|
affected_collection, [str(node.slug) for node in non_legacy_nodes]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(affected_relationships) > 0:
|
||||||
|
non_legacy_relationships = [
|
||||||
|
edge
|
||||||
|
for index, edge in enumerate(affected_relationships)
|
||||||
|
if not is_legacy_relationship[index]
|
||||||
|
]
|
||||||
|
|
||||||
|
await vector_engine.delete_data_points(
|
||||||
|
"EdgeType_relationship_name",
|
||||||
|
[str(relationship.slug) for relationship in non_legacy_relationships],
|
||||||
|
)
|
||||||
|
|
||||||
|
await delete_dataset_related_nodes(dataset_id)
|
||||||
|
await delete_dataset_related_edges(dataset_id)
|
||||||
|
|
|
||||||
|
|
@ -29,3 +29,25 @@ async def get_data_related_edges(dataset_id: UUID, data_id: UUID, session: Async
|
||||||
|
|
||||||
data_related_edges = await session.scalars(query_statement)
|
data_related_edges = await session.scalars(query_statement)
|
||||||
return data_related_edges.all()
|
return data_related_edges.all()
|
||||||
|
|
||||||
|
|
||||||
|
@with_async_session
|
||||||
|
async def get_global_data_related_edges(data_id: UUID, session: AsyncSession):
|
||||||
|
EdgeAlias = aliased(Edge)
|
||||||
|
|
||||||
|
subq = select(EdgeAlias.id).where(
|
||||||
|
and_(
|
||||||
|
EdgeAlias.slug == Edge.slug,
|
||||||
|
EdgeAlias.data_id != data_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
query_statement = select(Edge).where(
|
||||||
|
and_(
|
||||||
|
Edge.data_id == data_id,
|
||||||
|
~exists(subq),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
data_related_edges = await session.scalars(query_statement)
|
||||||
|
return data_related_edges.all()
|
||||||
|
|
|
||||||
|
|
@ -25,3 +25,20 @@ async def get_data_related_nodes(dataset_id: UUID, data_id: UUID, session: Async
|
||||||
|
|
||||||
data_related_nodes = await session.scalars(query_statement)
|
data_related_nodes = await session.scalars(query_statement)
|
||||||
return data_related_nodes.all()
|
return data_related_nodes.all()
|
||||||
|
|
||||||
|
|
||||||
|
@with_async_session
|
||||||
|
async def get_global_data_related_nodes(data_id: UUID, session: AsyncSession):
|
||||||
|
NodeAlias = aliased(Node)
|
||||||
|
|
||||||
|
subq = select(NodeAlias.id).where(
|
||||||
|
and_(
|
||||||
|
NodeAlias.slug == Node.slug,
|
||||||
|
NodeAlias.data_id != data_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
query_statement = select(Node).where(and_(Node.data_id == data_id, ~exists(subq)))
|
||||||
|
|
||||||
|
data_related_nodes = await session.scalars(query_statement)
|
||||||
|
return data_related_nodes.all()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from sqlalchemy import select
|
from sqlalchemy.orm import aliased
|
||||||
|
from sqlalchemy import select, and_, exists
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import with_async_session
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
|
@ -13,3 +14,25 @@ async def get_dataset_related_edges(dataset_id: UUID, session: AsyncSession):
|
||||||
select(Edge).where(Edge.dataset_id == dataset_id).distinct(Edge.relationship_name)
|
select(Edge).where(Edge.dataset_id == dataset_id).distinct(Edge.relationship_name)
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
|
|
||||||
|
@with_async_session
|
||||||
|
async def get_global_dataset_related_edges(dataset_id: UUID, session: AsyncSession):
|
||||||
|
EdgeAlias = aliased(Edge)
|
||||||
|
|
||||||
|
subq = select(EdgeAlias.id).where(
|
||||||
|
and_(
|
||||||
|
EdgeAlias.slug == Edge.slug,
|
||||||
|
EdgeAlias.dataset_id != dataset_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
query_statement = select(Edge).where(
|
||||||
|
and_(
|
||||||
|
Edge.dataset_id == dataset_id,
|
||||||
|
~exists(subq),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
related_edges = await session.scalars(query_statement)
|
||||||
|
return related_edges.all()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from sqlalchemy import select
|
from sqlalchemy.orm import aliased
|
||||||
|
from sqlalchemy import exists, and_, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import with_async_session
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
|
@ -12,3 +13,25 @@ async def get_dataset_related_nodes(dataset_id: UUID, session: AsyncSession):
|
||||||
|
|
||||||
data_related_nodes = await session.scalars(query_statement)
|
data_related_nodes = await session.scalars(query_statement)
|
||||||
return data_related_nodes.all()
|
return data_related_nodes.all()
|
||||||
|
|
||||||
|
|
||||||
|
@with_async_session
async def get_global_dataset_related_nodes(dataset_id: UUID, session: AsyncSession):
    """Return the nodes of a dataset whose slug is not used by any other dataset.

    Args:
        dataset_id: Identifier of the dataset whose nodes are inspected.
        session: Async SQLAlchemy session used to run the query
            (presumably supplied by ``with_async_session`` when omitted -
            TODO confirm against the decorator).

    Returns:
        A list of ``Node`` rows that belong to ``dataset_id`` and for which no
        ``Node`` row with the same ``slug`` exists under a different dataset.
    """
    other_node = aliased(Node)

    # Correlated subquery: rows with the same slug as the outer Node that are
    # recorded under a different dataset.
    shared_elsewhere = (
        select(other_node.id)
        .where(other_node.slug == Node.slug)
        .where(other_node.dataset_id != dataset_id)
    )

    # Keep only this dataset's nodes that have no such counterpart.
    exclusive_nodes = (
        select(Node)
        .where(Node.dataset_id == dataset_id)
        .where(~exists(shared_elsewhere))
    )

    result = await session.scalars(exclusive_nodes)
    return result.all()
|
||||||
|
|
|
||||||
212
cognee/tests/test_delete_dataset.py
Normal file
212
cognee/tests/test_delete_dataset.py
Normal file
|
|
@ -0,0 +1,212 @@
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.api.v1.datasets import datasets
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
from cognee.modules.engine.operations.setup import setup
|
||||||
|
from cognee.modules.users.methods import create_user, get_default_user
|
||||||
|
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
async def main(mock_create_structured_output: AsyncMock):
    """Check that deleting whole datasets removes their graph and vector data.

    Two users each add one document; both documents mention the entity
    "Apple". The LLM is mocked so the extracted graphs are fixed. The test
    deletes the default user's dataset, checks that only part of the graph
    and vector data disappears, then deletes the second user's dataset and
    checks that nothing is left.

    Args:
        mock_create_structured_output: AsyncMock injected by ``patch.object``
            in place of ``LLMGateway.acreate_structured_output``.
    """
    data_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".data_storage/test_delete_dataset_two_users_graph"
    )
    cognee.config.data_root_directory(data_directory_path)

    cognee_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".cognee_system/test_delete_dataset_two_users_graph"
    )
    cognee.config.system_root_directory(cognee_directory_path)

    # Start from a clean slate so the node/edge counts below are exact.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    def mock_llm_output(text_input: str, system_prompt: str, response_model):
        """Return canned LLM output chosen by the input text and response model."""
        if text_input == "test":  # LLM connection test
            return "test"

        if "John" in text_input and response_model == SummarizedContent:
            return SummarizedContent(
                summary="Summary of John's work.", description="Summary of John's work."
            )

        if "Marie" in text_input and response_model == SummarizedContent:
            return SummarizedContent(
                summary="Summary of Marie's work.", description="Summary of Marie's work."
            )

        if "Marie" in text_input and response_model == KnowledgeGraph:
            return KnowledgeGraph(
                nodes=[
                    Node(id="Marie", name="Marie", type="Person", description="Marie is a person"),
                    Node(
                        id="Apple",
                        name="Apple",
                        type="Company",
                        description="Apple is a company",
                    ),
                    Node(
                        id="MacOS",
                        name="MacOS",
                        type="Product",
                        description="MacOS is Apple's operating system",
                    ),
                ],
                edges=[
                    Edge(
                        source_node_id="Marie",
                        target_node_id="Apple",
                        relationship_name="works_for",
                    ),
                    Edge(
                        source_node_id="Marie", target_node_id="MacOS", relationship_name="works_on"
                    ),
                ],
            )

        if "John" in text_input and response_model == KnowledgeGraph:
            return KnowledgeGraph(
                nodes=[
                    Node(id="John", name="John", type="Person", description="John is a person"),
                    Node(
                        id="Apple",
                        name="Apple",
                        type="Company",
                        description="Apple is a company",
                    ),
                    Node(
                        id="Food for Hungry",
                        name="Food for Hungry",
                        type="Non-profit organization",
                        description="Food for Hungry is a non-profit organization",
                    ),
                ],
                edges=[
                    Edge(
                        source_node_id="John", target_node_id="Apple", relationship_name="works_for"
                    ),
                    Edge(
                        source_node_id="John",
                        target_node_id="Food for Hungry",
                        relationship_name="works_for",
                    ),
                ],
            )

    mock_create_structured_output.side_effect = mock_llm_output

    vector_engine = get_vector_engine()

    # No vector collection may exist before anything is ingested.
    assert not await vector_engine.has_collection("EdgeType_relationship_name")
    assert not await vector_engine.has_collection("Entity_name")
    assert not await vector_engine.has_collection("DocumentChunk_text")
    assert not await vector_engine.has_collection("TextSummary_text")
    assert not await vector_engine.has_collection("TextDocument_text")

    new_user = await create_user(
        email="example@user.com",
        password="mypassword",
        is_superuser=True,
        is_active=True,
        is_verified=True,
        auto_login=True,
    )

    # John's document is added without an explicit user; Marie's is added as new_user.
    await cognee.add(
        "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
    )

    await cognee.add(
        "Marie works for Apple as well. She is a software engineer on MacOS project.",
        user=new_user,
    )

    johns_cognify_result: dict = await cognee.cognify()
    johns_dataset_id = list(johns_cognify_result.keys())[0]

    maries_cognify_result: dict = await cognee.cognify(user=new_user)
    maries_dataset_id = list(maries_cognify_result.keys())[0]

    graph_engine = await get_graph_engine()
    initial_nodes, initial_edges = await graph_engine.get_graph_data()
    assert len(initial_nodes) == 15 and len(initial_edges) == 19, (
        "Number of nodes and edges is not correct."
    )

    # Group the initial nodes by the vector collection they are indexed in:
    # "<node type>_<first index field>".
    initial_nodes_by_vector_collection = {}
    for node in initial_nodes:
        node_data = node[1]
        collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
        initial_nodes_by_vector_collection.setdefault(collection_name, []).append(node)

    initial_node_ids = {node[0] for node in initial_nodes}

    default_user = await get_default_user()
    await datasets.delete_dataset(johns_dataset_id, default_user)  # type: ignore

    nodes, edges = await graph_engine.get_graph_data()
    assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted."
    # NOTE(review): names are compared in lowercase - presumably the pipeline
    # normalizes entity names before storing them; confirm.
    assert not any(
        node[1]["name"] == "john" or node[1]["name"] == "food for hungry"
        for node in nodes
        if "name" in node[1]
    ), "Nodes are not deleted."

    after_first_delete_node_ids = {node[0] for node in nodes}
    removed_node_ids = initial_node_ids - after_first_delete_node_ids

    # Engine is fetched again after the delete, as the original test did.
    vector_engine = get_vector_engine()

    # Every node removed from the graph must also be gone from its vector collection.
    for collection_name, collection_nodes in initial_nodes_by_vector_collection.items():
        query_node_ids = [node[0] for node in collection_nodes if node[0] in removed_node_ids]

        if query_node_ids:
            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
            assert len(vector_items) == 0, "Vector items are not deleted."

    await datasets.delete_dataset(maries_dataset_id, new_user)  # type: ignore

    final_nodes, final_edges = await graph_engine.get_graph_data()
    assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted."

    # With both datasets deleted, no initial node may remain in any collection.
    for collection_name, collection_nodes in initial_nodes_by_vector_collection.items():
        query_node_ids = [node[0] for node in collection_nodes]

        if query_node_ids:
            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
            assert len(vector_items) == 0, "Vector items are not deleted."

    query_edge_ids = [edge[0] for edge in initial_edges]

    vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
    assert len(vector_items) == 0, "Vector items are not deleted."


if __name__ == "__main__":
    import asyncio

    # The patch decorator supplies the mock argument, so main() takes none here.
    asyncio.run(main())
|
||||||
214
cognee/tests/test_delete_two_users_graph.py
Normal file
214
cognee/tests/test_delete_two_users_graph.py
Normal file
|
|
@ -0,0 +1,214 @@
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import cognee
|
||||||
|
from cognee.api.v1.datasets import datasets
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
from cognee.modules.engine.operations.setup import setup
|
||||||
|
from cognee.modules.users.methods import create_user, get_default_user
|
||||||
|
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
async def main(mock_create_structured_output: AsyncMock):
    """Check that deleting single data items removes their graph and vector data.

    Two users each add one document; both documents mention the entity
    "Apple". The LLM is mocked so the extracted graphs are fixed. The test
    deletes the default user's data item, checks that only part of the graph
    and vector data disappears, then deletes the second user's data item and
    checks that nothing is left.

    Args:
        mock_create_structured_output: AsyncMock injected by ``patch.object``
            in place of ``LLMGateway.acreate_structured_output``.
    """
    data_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".data_storage/test_delete_two_users_graph"
    )
    cognee.config.data_root_directory(data_directory_path)

    cognee_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".cognee_system/test_delete_two_users_graph"
    )
    cognee.config.system_root_directory(cognee_directory_path)

    # Start from a clean slate so the node/edge counts below are exact.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    def mock_llm_output(text_input: str, system_prompt: str, response_model):
        """Return canned LLM output chosen by the input text and response model."""
        if text_input == "test":  # LLM connection test
            return "test"

        if "John" in text_input and response_model == SummarizedContent:
            return SummarizedContent(
                summary="Summary of John's work.", description="Summary of John's work."
            )

        if "Marie" in text_input and response_model == SummarizedContent:
            return SummarizedContent(
                summary="Summary of Marie's work.", description="Summary of Marie's work."
            )

        if "Marie" in text_input and response_model == KnowledgeGraph:
            return KnowledgeGraph(
                nodes=[
                    Node(id="Marie", name="Marie", type="Person", description="Marie is a person"),
                    Node(
                        id="Apple",
                        name="Apple",
                        type="Company",
                        description="Apple is a company",
                    ),
                    Node(
                        id="MacOS",
                        name="MacOS",
                        type="Product",
                        description="MacOS is Apple's operating system",
                    ),
                ],
                edges=[
                    Edge(
                        source_node_id="Marie",
                        target_node_id="Apple",
                        relationship_name="works_for",
                    ),
                    Edge(
                        source_node_id="Marie", target_node_id="MacOS", relationship_name="works_on"
                    ),
                ],
            )

        if "John" in text_input and response_model == KnowledgeGraph:
            return KnowledgeGraph(
                nodes=[
                    Node(id="John", name="John", type="Person", description="John is a person"),
                    Node(
                        id="Apple",
                        name="Apple",
                        type="Company",
                        description="Apple is a company",
                    ),
                    Node(
                        id="Food for Hungry",
                        name="Food for Hungry",
                        type="Non-profit organization",
                        description="Food for Hungry is a non-profit organization",
                    ),
                ],
                edges=[
                    Edge(
                        source_node_id="John", target_node_id="Apple", relationship_name="works_for"
                    ),
                    Edge(
                        source_node_id="John",
                        target_node_id="Food for Hungry",
                        relationship_name="works_for",
                    ),
                ],
            )

    mock_create_structured_output.side_effect = mock_llm_output

    vector_engine = get_vector_engine()

    # No vector collection may exist before anything is ingested.
    assert not await vector_engine.has_collection("EdgeType_relationship_name")
    assert not await vector_engine.has_collection("Entity_name")
    assert not await vector_engine.has_collection("DocumentChunk_text")
    assert not await vector_engine.has_collection("TextSummary_text")
    assert not await vector_engine.has_collection("TextDocument_text")

    new_user = await create_user(
        email="example@user.com",
        password="mypassword",
        is_superuser=True,
        is_active=True,
        is_verified=True,
        auto_login=True,
    )

    # John's document is added without an explicit user; Marie's is added as new_user.
    add_john_result = await cognee.add(
        "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
    )
    johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]

    add_marie_result = await cognee.add(
        "Marie works for Apple as well. She is a software engineer on MacOS project.",
        user=new_user,
    )
    maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]

    johns_cognify_result: dict = await cognee.cognify()
    johns_dataset_id = list(johns_cognify_result.keys())[0]

    maries_cognify_result: dict = await cognee.cognify(user=new_user)
    maries_dataset_id = list(maries_cognify_result.keys())[0]

    graph_engine = await get_graph_engine()
    initial_nodes, initial_edges = await graph_engine.get_graph_data()
    assert len(initial_nodes) == 15 and len(initial_edges) == 19, (
        "Number of nodes and edges is not correct."
    )

    # Group the initial nodes by the vector collection they are indexed in:
    # "<node type>_<first index field>".
    initial_nodes_by_vector_collection = {}
    for node in initial_nodes:
        node_data = node[1]
        collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
        initial_nodes_by_vector_collection.setdefault(collection_name, []).append(node)

    initial_node_ids = {node[0] for node in initial_nodes}

    default_user = await get_default_user()
    await datasets.delete_data(johns_dataset_id, johns_data_id, default_user)  # type: ignore

    nodes, edges = await graph_engine.get_graph_data()
    assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted."
    # NOTE(review): names are compared in lowercase - presumably the pipeline
    # normalizes entity names before storing them; confirm.
    assert not any(
        node[1]["name"] == "john" or node[1]["name"] == "food for hungry"
        for node in nodes
        if "name" in node[1]
    ), "Nodes are not deleted."

    after_first_delete_node_ids = {node[0] for node in nodes}
    removed_node_ids = initial_node_ids - after_first_delete_node_ids

    # Engine is fetched again after the delete, as the original test did.
    vector_engine = get_vector_engine()

    # Every node removed from the graph must also be gone from its vector collection.
    for collection_name, collection_nodes in initial_nodes_by_vector_collection.items():
        query_node_ids = [node[0] for node in collection_nodes if node[0] in removed_node_ids]

        if query_node_ids:
            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
            assert len(vector_items) == 0, "Vector items are not deleted."

    await datasets.delete_data(maries_dataset_id, maries_data_id, new_user)  # type: ignore

    final_nodes, final_edges = await graph_engine.get_graph_data()
    assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted."

    # With both data items deleted, no initial node may remain in any collection.
    for collection_name, collection_nodes in initial_nodes_by_vector_collection.items():
        query_node_ids = [node[0] for node in collection_nodes]

        if query_node_ids:
            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
            assert len(vector_items) == 0, "Vector items are not deleted."

    query_edge_ids = [edge[0] for edge in initial_edges]

    vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
    assert len(vector_items) == 0, "Vector items are not deleted."


if __name__ == "__main__":
    import asyncio

    # The patch decorator supplies the mock argument, so main() takes none here.
    asyncio.run(main())
|
||||||
Loading…
Add table
Reference in a new issue