fix: check vector collections in delete tests

This commit is contained in:
Boris Arzentar 2025-11-03 15:16:44 +01:00
parent 42eb898f3a
commit 5efd0a4fb6
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
2 changed files with 179 additions and 4 deletions

View file

@ -38,6 +38,9 @@ async def main(mock_create_structured_output: AsyncMock):
assert not await vector_engine.has_collection("EdgeType_relationship_name")
assert not await vector_engine.has_collection("Entity_name")
assert not await vector_engine.has_collection("DocumentChunk_text")
assert not await vector_engine.has_collection("TextSummary_text")
assert not await vector_engine.has_collection("TextDocument_text")
mock_create_structured_output.side_effect = [
"", # For LLM connection test
@ -107,7 +110,8 @@ async def main(mock_create_structured_output: AsyncMock):
dataset_id = list(cognify_result.keys())[0]
dataset_data = await get_dataset_data(dataset_id)
added_data = dataset_data[0]
added_data_1 = dataset_data[0]
added_data_2 = dataset_data[1]
# file_path = os.path.join(
# pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_full.html"
@ -115,11 +119,24 @@ async def main(mock_create_structured_output: AsyncMock):
# await visualize_graph(file_path)
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 15 and len(edges) == 19, "Number of nodes and edges is not correct."
initial_nodes, initial_edges = await graph_engine.get_graph_data()
assert len(initial_nodes) == 15 and len(initial_edges) == 19, (
"Number of nodes and edges is not correct."
)
initial_nodes_by_vector_collection = {}
for node in initial_nodes:
node_data = node[1]
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
if collection_name not in initial_nodes_by_vector_collection:
initial_nodes_by_vector_collection[collection_name] = []
initial_nodes_by_vector_collection[collection_name].append(node)
initial_node_ids = set([node[0] for node in initial_nodes])
user = await get_default_user()
await datasets.delete_data(dataset_id, added_data.id, user) # type: ignore
await datasets.delete_data(dataset_id, added_data_1.id, user) # type: ignore
# file_path = os.path.join(
# pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_after_delete.html"
@ -128,6 +145,43 @@ async def main(mock_create_structured_output: AsyncMock):
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted."
assert not any(node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes), "Nodes are not deleted."
after_first_delete_node_ids = set([node[0] for node in nodes])
after_delete_nodes_by_vector_collection = {}
for node in initial_nodes:
node_data = node[1]
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
if collection_name not in after_delete_nodes_by_vector_collection:
after_delete_nodes_by_vector_collection[collection_name] = []
after_delete_nodes_by_vector_collection[collection_name].append(node)
removed_node_ids = initial_node_ids - after_first_delete_node_ids
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids]
if query_node_ids:
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
await datasets.delete_data(dataset_id, added_data_2.id, user) # type: ignore
final_nodes, final_edges = await graph_engine.get_graph_data()
assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted."
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
query_node_ids = [node[0] for node in initial_nodes]
if query_node_ids:
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
query_edge_ids = [edge[0] for edge in initial_edges]
vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
if __name__ == "__main__":

View file

@ -0,0 +1,121 @@
import os
import pathlib
import time
import pytest
from unittest.mock import AsyncMock, patch
import cognee
from cognee.api.v1.datasets import datasets
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm import LLMGateway
from cognee.modules.data.methods import get_dataset_data
from cognee.modules.engine.operations.setup import setup
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
from cognee.shared.logging_utils import get_logger
logger = get_logger()
async def main():
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_non_mocked"
)
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".cognee_system/test_delete_default_graph_non_mocked"
)
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await setup()
vector_engine = get_vector_engine()
assert not await vector_engine.has_collection("EdgeType_relationship_name")
assert not await vector_engine.has_collection("Entity_name")
assert not await vector_engine.has_collection("DocumentChunk_text")
assert not await vector_engine.has_collection("TextSummary_text")
assert not await vector_engine.has_collection("TextDocument_text")
await cognee.add(
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
)
await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.")
cognify_result: dict = await cognee.cognify()
dataset_id = list(cognify_result.keys())[0]
dataset_data = await get_dataset_data(dataset_id)
added_data_1 = dataset_data[0]
added_data_2 = dataset_data[1]
graph_engine = await get_graph_engine()
initial_nodes, initial_edges = await graph_engine.get_graph_data()
assert len(initial_nodes) >= 15 and len(initial_edges) >= 19, (
"Number of nodes and edges is not correct."
)
initial_nodes_by_vector_collection = {}
for node in initial_nodes:
node_data = node[1]
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
if collection_name not in initial_nodes_by_vector_collection:
initial_nodes_by_vector_collection[collection_name] = []
initial_nodes_by_vector_collection[collection_name].append(node)
initial_node_ids = set([node[0] for node in initial_nodes])
user = await get_default_user()
await datasets.delete_data(dataset_id, added_data_1.id, user) # type: ignore
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) >= 9 and len(nodes) <= 11 and len(edges) >= 10 and len(edges) <= 12, "Nodes and edges are not deleted."
assert not any(node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes), "Nodes are not deleted."
after_first_delete_node_ids = set([node[0] for node in nodes])
after_delete_nodes_by_vector_collection = {}
for node in initial_nodes:
node_data = node[1]
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
if collection_name not in after_delete_nodes_by_vector_collection:
after_delete_nodes_by_vector_collection[collection_name] = []
after_delete_nodes_by_vector_collection[collection_name].append(node)
removed_node_ids = initial_node_ids - after_first_delete_node_ids
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids]
if query_node_ids:
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
await datasets.delete_data(dataset_id, added_data_2.id, user) # type: ignore
final_nodes, final_edges = await graph_engine.get_graph_data()
assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted."
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
query_node_ids = [node[0] for node in initial_nodes]
if query_node_ids:
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
query_edge_ids = [edge[0] for edge in initial_edges]
vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
if __name__ == "__main__":
import asyncio
asyncio.run(main())