From a89dad328ef19b07354255680040a9ffa8909e75 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 17 Nov 2025 22:04:30 +0100 Subject: [PATCH] fix: add detailed tests for delete --- .../utils/get_or_create_dataset_database.py | 4 +- cognee/tests/test_delete_custom_graph.py | 53 ++- cognee/tests/test_delete_default_graph.py | 206 +++++++--- ...delete_default_graph_with_legacy_data_1.py | 369 ++++++++++++------ ...delete_default_graph_with_legacy_data_2.py | 266 +++++++++---- .../assert_edges_vector_index_not_present.py | 23 ++ .../assert_edges_vector_index_present.py | 28 ++ .../utils/assert_graph_edges_not_present.py | 21 + .../tests/utils/assert_graph_edges_present.py | 21 + .../utils/assert_graph_nodes_not_present.py | 16 + .../tests/utils/assert_graph_nodes_present.py | 14 + .../assert_nodes_vector_index_not_present.py | 28 ++ .../assert_nodes_vector_index_present.py | 28 ++ cognee/tests/utils/extract_entities.py | 45 +++ cognee/tests/utils/extract_relationships.py | 55 +++ cognee/tests/utils/extract_summary.py | 12 + .../utils/filter_overlapping_entities.py | 26 ++ .../utils/filter_overlapping_relationships.py | 33 ++ cognee/tests/utils/get_contains_edge_text.py | 9 + cognee/tests/utils/isolate_relationships.py | 20 + 20 files changed, 1011 insertions(+), 266 deletions(-) create mode 100644 cognee/tests/utils/assert_edges_vector_index_not_present.py create mode 100644 cognee/tests/utils/assert_edges_vector_index_present.py create mode 100644 cognee/tests/utils/assert_graph_edges_not_present.py create mode 100644 cognee/tests/utils/assert_graph_edges_present.py create mode 100644 cognee/tests/utils/assert_graph_nodes_not_present.py create mode 100644 cognee/tests/utils/assert_graph_nodes_present.py create mode 100644 cognee/tests/utils/assert_nodes_vector_index_not_present.py create mode 100644 cognee/tests/utils/assert_nodes_vector_index_present.py create mode 100644 cognee/tests/utils/extract_entities.py create mode 100644 
cognee/tests/utils/extract_relationships.py create mode 100644 cognee/tests/utils/extract_summary.py create mode 100644 cognee/tests/utils/filter_overlapping_entities.py create mode 100644 cognee/tests/utils/filter_overlapping_relationships.py create mode 100644 cognee/tests/utils/get_contains_edge_text.py create mode 100644 cognee/tests/utils/isolate_relationships.py diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py index 3684bb100..5e0c62aac 100644 --- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py @@ -6,7 +6,7 @@ from sqlalchemy import select from sqlalchemy.exc import IntegrityError from cognee.base_config import get_base_config -from cognee.modules.data.methods import create_dataset +from cognee.modules.data.methods import create_authorized_dataset from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.vector import get_vectordb_config from cognee.infrastructure.databases.graph.config import get_graph_config @@ -66,7 +66,7 @@ async def get_or_create_dataset_database( async with db_engine.get_async_session() as session: # Create dataset if it doesn't exist if isinstance(dataset, str): - dataset = await create_dataset(dataset, user, session) + dataset = await create_authorized_dataset(dataset, user) # Try to fetch an existing row first stmt = select(DatasetDatabase).where( diff --git a/cognee/tests/test_delete_custom_graph.py b/cognee/tests/test_delete_custom_graph.py index a9faed078..f8b1edd04 100644 --- a/cognee/tests/test_delete_custom_graph.py +++ b/cognee/tests/test_delete_custom_graph.py @@ -10,6 +10,7 @@ from cognee.context_global_variables import set_database_global_context_variable from cognee.infrastructure.engine import DataPoint from cognee.modules.data.methods import 
create_authorized_dataset from cognee.modules.engine.operations.setup import setup +from cognee.modules.engine.utils import generate_node_id from cognee.modules.users.models import User from cognee.modules.users.methods import get_default_user from cognee.shared.logging_utils import get_logger @@ -53,11 +54,11 @@ async def main(): works_for: List[Organization] metadata: dict = {"index_fields": ["name"]} - companyA = ForProfit(name="Company A") - companyB = NonProfit(name="Company B") + companyA = ForProfit(id=generate_node_id("Company A"), name="Company A") + companyB = NonProfit(id=generate_node_id("Company B"), name="Company B") - person1 = Person(name="John", works_for=[companyA, companyB]) - person2 = Person(name="Jane", works_for=[companyB]) + person1 = Person(id=generate_node_id("John"), name="John", works_for=[companyA, companyB]) + person2 = Person(id=generate_node_id("Jane"), name="Jane", works_for=[companyB]) user: User = await get_default_user() # type: ignore @@ -93,15 +94,59 @@ async def main(): graph_engine = await get_graph_engine() nodes, edges = await graph_engine.get_graph_data() + + # Initial check assert len(nodes) == 4 and len(edges) == 3, ( "Nodes and edges are not correctly added to the graph." ) + nodes_by_id = {node[0]: node[1] for node in nodes} + + assert str(generate_node_id("John")) in nodes_by_id, "John node not present in the graph." + assert str(generate_node_id("Jane")) in nodes_by_id, "Jane node not present in the graph." + assert str(generate_node_id("Company A")) in nodes_by_id, ( + "Company A node not present in the graph." + ) + assert str(generate_node_id("Company B")) in nodes_by_id, ( + "Company B node not present in the graph." + ) + + edges_by_ids = {f"{edge[0]}_{edge[2]}_{edge[1]}": edge[3] for edge in edges} + + assert ( + f"{str(generate_node_id('John'))}_works_for_{str(generate_node_id('Company A'))}" + in edges_by_ids + ), "Edge between John and Company A not present in the graph." 
+    assert (
+        f"{str(generate_node_id('John'))}_works_for_{str(generate_node_id('Company B'))}"
+        in edges_by_ids
+    ), "Edge between John and Company B not present in the graph."
+    assert (
+        f"{str(generate_node_id('Jane'))}_works_for_{str(generate_node_id('Company B'))}"
+        in edges_by_ids
+    ), "Edge between Jane and Company B not present in the graph."
+
+    # First data deletion
     await datasets.delete_data(dataset.id, data1.id, user)
 
     nodes, edges = await graph_engine.get_graph_data()
     assert len(nodes) == 2 and len(edges) == 1, "Nodes and edges are not deleted properly."
 
+    nodes_by_id = {node[0]: node[1] for node in nodes}
+
+    assert str(generate_node_id("Jane")) in nodes_by_id, "Jane node not present in the graph."
+    assert str(generate_node_id("Company B")) in nodes_by_id, (
+        "Company B node not present in the graph."
+    )
+
+    edges_by_ids = {f"{edge[0]}_{edge[2]}_{edge[1]}": edge[3] for edge in edges}
+
+    assert (
+        f"{str(generate_node_id('Jane'))}_works_for_{str(generate_node_id('Company B'))}"
+        in edges_by_ids
+    ), "Edge between Jane and Company B not present in the graph."
+ + # Second data deletion await datasets.delete_data(dataset.id, data2.id, user) nodes, edges = await graph_engine.get_graph_data() diff --git a/cognee/tests/test_delete_default_graph.py b/cognee/tests/test_delete_default_graph.py index d6236500c..1ffcf1fda 100644 --- a/cognee/tests/test_delete_default_graph.py +++ b/cognee/tests/test_delete_default_graph.py @@ -1,17 +1,40 @@ import os -import pathlib import pytest +import pathlib +from uuid import NAMESPACE_OID, uuid5 from unittest.mock import AsyncMock, patch import cognee from cognee.api.v1.datasets import datasets +from cognee.context_global_variables import set_database_global_context_variables from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.llm import LLMGateway +from cognee.modules.chunking.models.DocumentChunk import DocumentChunk +from cognee.modules.data.processing.document_types.TextDocument import TextDocument +from cognee.modules.engine.models import Entity from cognee.modules.engine.operations.setup import setup from cognee.modules.users.methods import get_default_user from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent from cognee.shared.logging_utils import get_logger +from cognee.tests.utils.assert_edges_vector_index_not_present import ( + assert_edges_vector_index_not_present, +) +from cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present +from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present +from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present +from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present +from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present +from cognee.tests.utils.assert_nodes_vector_index_not_present import ( + assert_nodes_vector_index_not_present, +) +from 
cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present +from cognee.tests.utils.extract_entities import extract_entities +from cognee.tests.utils.extract_relationships import extract_relationships +from cognee.tests.utils.extract_summary import extract_summary +from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities +from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships +from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text +from cognee.tests.utils.isolate_relationships import isolate_relationships logger = get_logger() @@ -107,92 +130,159 @@ async def main(mock_create_structured_output: AsyncMock): mock_create_structured_output.side_effect = mock_llm_output + user = await get_default_user() + + await set_database_global_context_variables("main_dataset", user.id) + vector_engine = get_vector_engine() - assert not await vector_engine.has_collection("EdgeType_relationship_name") assert not await vector_engine.has_collection("Entity_name") assert not await vector_engine.has_collection("DocumentChunk_text") assert not await vector_engine.has_collection("TextSummary_text") assert not await vector_engine.has_collection("TextDocument_text") + assert not await vector_engine.has_collection("EdgeType_relationship_name") - add_john_result = await cognee.add( - "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" - ) + johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" + add_john_result = await cognee.add(johns_text) johns_data_id = add_john_result.data_ingestion_info[0]["data_id"] - add_marie_result = await cognee.add( - "Marie works for Apple as well. She is a software engineer on MacOS project." - ) + maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project." 
+ add_marie_result = await cognee.add(maries_text) maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"] cognify_result: dict = await cognee.cognify() dataset_id = list(cognify_result.keys())[0] - graph_engine = await get_graph_engine() - initial_nodes, initial_edges = await graph_engine.get_graph_data() - assert len(initial_nodes) == 15 and len(initial_edges) == 19, ( - "Number of nodes and edges is not correct." + johns_document = TextDocument( + id=johns_data_id, + name="John's Work", + raw_data_location="johns_data_location", + external_metadata="", + ) + johns_chunk = DocumentChunk( + id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"), + text=johns_text, + chunk_size=14, + chunk_index=0, + cut_type="sentence_end", + is_part_of=johns_document, + ) + johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore + + maries_document = TextDocument( + id=maries_data_id, + name="Maries's Work", + raw_data_location="maries_data_location", + external_metadata="", + ) + maries_chunk = DocumentChunk( + id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"), + text=maries_text, + chunk_size=14, + chunk_index=0, + cut_type="sentence_end", + is_part_of=maries_document, + ) + maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore + + johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore + maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore + (overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities( + johns_entities, maries_entities ) - initial_nodes_by_vector_collection = {} + johns_data = [ + johns_document, + johns_chunk, + johns_summary, + *johns_entities, + ] + maries_data = [ + maries_document, + maries_chunk, + maries_summary, + *maries_entities, + ] - for node in initial_nodes: - node_data = node[1] - collection_name = node_data["type"] + "_" + 
node_data["metadata"]["index_fields"][0] - if collection_name not in initial_nodes_by_vector_collection: - initial_nodes_by_vector_collection[collection_name] = [] - initial_nodes_by_vector_collection[collection_name].append(node) + # Assert data points presence in the graph, vector collections and nodes table + await assert_graph_nodes_present(johns_data + maries_data + overlapping_entities) + await assert_nodes_vector_index_present(johns_data + maries_data + overlapping_entities) - initial_node_ids = set([node[0] for node in initial_nodes]) + johns_relationships = extract_relationships( + johns_chunk, + mock_llm_output("John", "", KnowledgeGraph), # type: ignore + ) + maries_relationships = extract_relationships( + maries_chunk, + mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore + ) + (overlapping_relationships, johns_relationships, maries_relationships) = ( + filter_overlapping_relationships(johns_relationships, maries_relationships) + ) - user = await get_default_user() + johns_relationships = [ + (johns_chunk.id, johns_document.id, "is_part_of"), + (johns_summary.id, johns_chunk.id, "made_from"), + *johns_relationships, + ] + johns_edge_text_relationships = [ + (johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description)) + for entity in johns_entities + if isinstance(entity, Entity) + ] + maries_relationships = [ + (maries_chunk.id, maries_document.id, "is_part_of"), + (maries_summary.id, maries_chunk.id, "made_from"), + *maries_relationships, + ] + maries_edge_text_relationships = [ + (maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description)) + for entity in maries_entities + if isinstance(entity, Entity) + ] + + expected_relationships = johns_relationships + maries_relationships + overlapping_relationships + + await assert_graph_edges_present(expected_relationships) + + await assert_edges_vector_index_present( + expected_relationships + johns_edge_text_relationships + 
maries_edge_text_relationships + ) + + # Delete John's data from cognee await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore - nodes, edges = await graph_engine.get_graph_data() - assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted." - assert not any( - node[1]["name"] == "john" or node[1]["name"] == "food for hungry" - for node in nodes - if "name" in node[1] - ), "Nodes are not deleted." + # Assert data points presence in the graph, vector collections and nodes table + await assert_graph_nodes_present(maries_data + overlapping_entities) + await assert_nodes_vector_index_present(maries_data + overlapping_entities) - after_first_delete_node_ids = set([node[0] for node in nodes]) + await assert_graph_nodes_not_present(johns_data) + await assert_nodes_vector_index_not_present(johns_data) - after_delete_nodes_by_vector_collection = {} - for node in initial_nodes: - node_data = node[1] - collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0] - if collection_name not in after_delete_nodes_by_vector_collection: - after_delete_nodes_by_vector_collection[collection_name] = [] - after_delete_nodes_by_vector_collection[collection_name].append(node) + # Assert relationships presence in the graph, vector collections and nodes table + await assert_graph_edges_present(maries_relationships + overlapping_relationships) + await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships) - vector_engine = get_vector_engine() + await assert_graph_edges_not_present(johns_relationships) - removed_node_ids = initial_node_ids - after_first_delete_node_ids - - for collection_name, initial_nodes in initial_nodes_by_vector_collection.items(): - query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids] - - if query_node_ids: - vector_items = await vector_engine.retrieve(collection_name, query_node_ids) - assert len(vector_items) == 0, "Vector items are not 
deleted." + strictly_johns_relationships = isolate_relationships(johns_relationships, maries_relationships) + # We check only by relationship name and we need edges that are created by John's data and no other. + await assert_edges_vector_index_not_present( + strictly_johns_relationships + johns_edge_text_relationships + ) + # Delete Marie's data from cognee await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore - final_nodes, final_edges = await graph_engine.get_graph_data() - assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted." + await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities) + await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities) - for collection_name, initial_nodes in initial_nodes_by_vector_collection.items(): - query_node_ids = [node[0] for node in initial_nodes] + # Assert relationships presence in the graph, vector collections and nodes table + await assert_graph_edges_not_present( + johns_relationships + maries_relationships + overlapping_relationships + ) - if query_node_ids: - vector_items = await vector_engine.retrieve(collection_name, query_node_ids) - assert len(vector_items) == 0, "Vector items are not deleted." - - query_edge_ids = [edge[0] for edge in initial_edges] - - vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids) - assert len(vector_items) == 0, "Vector items are not deleted." 
+ await assert_edges_vector_index_not_present(maries_relationships) if __name__ == "__main__": diff --git a/cognee/tests/test_delete_default_graph_with_legacy_data_1.py b/cognee/tests/test_delete_default_graph_with_legacy_data_1.py index 042618bf6..e0e0ab4c0 100644 --- a/cognee/tests/test_delete_default_graph_with_legacy_data_1.py +++ b/cognee/tests/test_delete_default_graph_with_legacy_data_1.py @@ -1,12 +1,12 @@ import os -import json +import pytest import pathlib from uuid import NAMESPACE_OID, uuid5 -import pytest from unittest.mock import AsyncMock, patch import cognee from cognee.api.v1.datasets import datasets +from cognee.context_global_variables import set_database_global_context_variables from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph import get_graph_engine @@ -17,17 +17,34 @@ from cognee.modules.data.models import Data from cognee.modules.engine.models import Entity, EntityType from cognee.modules.data.processing.document_types import TextDocument from cognee.modules.engine.operations.setup import setup -from cognee.modules.engine.utils import generate_edge_id, generate_node_id -from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model +from cognee.modules.engine.utils import generate_node_id +from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger from cognee.modules.pipelines.models import DataItemStatus from cognee.modules.users.methods import get_default_user from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent from cognee.tasks.storage import index_data_points, index_graph_edges - -from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger +from cognee.tests.utils.assert_edges_vector_index_not_present import ( + assert_edges_vector_index_not_present, +) +from 
cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present +from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present +from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present +from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present +from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present +from cognee.tests.utils.assert_nodes_vector_index_not_present import ( + assert_nodes_vector_index_not_present, +) +from cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present +from cognee.tests.utils.extract_entities import extract_entities +from cognee.tests.utils.extract_relationships import extract_relationships +from cognee.tests.utils.extract_summary import extract_summary +from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities +from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships +from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text +from cognee.tests.utils.isolate_relationships import isolate_relationships -async def get_nodes_and_edges(): +def create_nodes_and_edges(): document = TextDocument( id=uuid5(NAMESPACE_OID, "text_test.txt"), name="text_test.txt", @@ -73,15 +90,8 @@ async def get_nodes_and_edges(): name="amazon s3", description="A storage service provided by Amazon Web Services that allows storing graph data.", ) - document_chunk.contains = [ - graph_database, - neptune_analytics_entity, - neptune_database_entity, - storage, - storage_entity, - ] - data_points = [ + nodes_data = [ document, document_chunk, graph_database, @@ -91,39 +101,71 @@ async def get_nodes_and_edges(): storage_entity, ] - nodes = [] - edges = [] + edges_data = [ + ( + document_chunk.id, + storage_entity.id, + "contains", + { + "relationship_name": "contains", + }, + ), + ( + 
storage_entity.id, + storage.id, + "is_a", + { + "relationship_name": "is_a", + }, + ), + ( + document_chunk.id, + neptune_database_entity.id, + "contains", + { + "relationship_name": "contains", + }, + ), + ( + neptune_database_entity.id, + graph_database.id, + "is_a", + { + "relationship_name": "is_a", + }, + ), + ( + document_chunk.id, + document.id, + "is_part_of", + { + "relationship_name": "is_part_of", + }, + ), + ( + document_chunk.id, + neptune_analytics_entity.id, + "contains", + { + "relationship_name": "contains", + }, + ), + ( + neptune_analytics_entity.id, + graph_database.id, + "is_a", + { + "relationship_name": "is_a", + }, + ), + ] - added_nodes = {} - added_edges = {} - visited_properties = {} - - results = await asyncio.gather( - *[ - get_graph_from_model( - data_point, - added_nodes=added_nodes, - added_edges=added_edges, - visited_properties=visited_properties, - ) - for data_point in data_points - ] - ) - - for result_nodes, result_edges in results: - nodes.extend(result_nodes) - edges.extend(result_edges) - - nodes, edges = deduplicate_nodes_and_edges(nodes, edges) - - return nodes, edges + return nodes_data, edges_data @pytest.mark.asyncio @patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock) async def main(mock_create_structured_output: AsyncMock): - os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "False" - data_directory_path = os.path.join( pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_1" ) @@ -139,6 +181,9 @@ async def main(mock_create_structured_output: AsyncMock): await cognee.prune.prune_system(metadata=True) await setup() + user = await get_default_user() + await set_database_global_context_variables("main_dataset", user.id) + vector_engine = get_vector_engine() assert not await vector_engine.has_collection("EdgeType_relationship_name") @@ -147,9 +192,8 @@ async def main(mock_create_structured_output: AsyncMock): assert not await 
vector_engine.has_collection("TextSummary_text") assert not await vector_engine.has_collection("TextDocument_text") - user = await get_default_user() - - old_nodes, old_edges = await add_mocked_legacy_data(user) + # Add legacy data to the system + __, legacy_data_points, legacy_relationships = await create_mocked_legacy_data(user) def mock_llm_output(text_input: str, system_prompt: str, response_model): if text_input == "test": # LLM connection test @@ -225,109 +269,188 @@ async def main(mock_create_structured_output: AsyncMock): mock_create_structured_output.side_effect = mock_llm_output - add_john_result = await cognee.add( - "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" - ) + johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" + add_john_result = await cognee.add(johns_text) johns_data_id = add_john_result.data_ingestion_info[0]["data_id"] - add_marie_result = await cognee.add( - "Marie works for Apple as well. She is a software engineer on MacOS project." - ) + maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project." + add_marie_result = await cognee.add(maries_text) maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"] cognify_result: dict = await cognee.cognify() dataset_id = list(cognify_result.keys())[0] - graph_engine = await get_graph_engine() - initial_nodes, initial_edges = await graph_engine.get_graph_data() - assert len(initial_nodes) == 22 and len(initial_edges) == 25, ( - "Number of nodes and edges is not correct." 
+ johns_document = TextDocument( + id=johns_data_id, + name="John's Work", + raw_data_location="johns_data_location", + external_metadata="", + ) + johns_chunk = DocumentChunk( + id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"), + text=johns_text, + chunk_size=14, + chunk_index=0, + cut_type="sentence_end", + is_part_of=johns_document, + ) + johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore + + maries_document = TextDocument( + id=maries_data_id, + name="Maries's Work", + raw_data_location="maries_data_location", + external_metadata="", + ) + maries_chunk = DocumentChunk( + id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"), + text=maries_text, + chunk_size=14, + chunk_index=0, + cut_type="sentence_end", + is_part_of=maries_document, + ) + maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore + + johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore + maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore + (overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities( + johns_entities, maries_entities ) - initial_nodes_by_vector_collection = {} + johns_data = [ + johns_document, + johns_chunk, + johns_summary, + *johns_entities, + ] + maries_data = [ + maries_document, + maries_chunk, + maries_summary, + *maries_entities, + ] - for node in initial_nodes: - node_data = node[1] - node_metadata = node_data["metadata"] - node_metadata = json.loads(node_metadata) if type(node_metadata) is str else node_metadata - collection_name = node_data["type"] + "_" + node_metadata["index_fields"][0] - if collection_name not in initial_nodes_by_vector_collection: - initial_nodes_by_vector_collection[collection_name] = [] - initial_nodes_by_vector_collection[collection_name].append(node) + expected_data_points = johns_data + maries_data + overlapping_entities + 
legacy_data_points - initial_node_ids = set([node[0] for node in initial_nodes]) + # Assert data points presence in the graph, vector collections and nodes table + await assert_graph_nodes_present(expected_data_points) + await assert_nodes_vector_index_present(expected_data_points) + johns_relationships = extract_relationships( + johns_chunk, + mock_llm_output("John", "", KnowledgeGraph), # type: ignore + ) + maries_relationships = extract_relationships( + maries_chunk, + mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore + ) + (overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = ( + filter_overlapping_relationships( + johns_relationships, maries_relationships, legacy_relationships + ) + ) + + johns_relationships = [ + (johns_chunk.id, johns_document.id, "is_part_of"), + (johns_summary.id, johns_chunk.id, "made_from"), + *johns_relationships, + ] + johns_edge_text_relationships = [ + (johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description)) + for entity in johns_entities + if isinstance(entity, Entity) + ] + maries_relationships = [ + (maries_chunk.id, maries_document.id, "is_part_of"), + (maries_summary.id, maries_chunk.id, "made_from"), + *maries_relationships, + ] + maries_edge_text_relationships = [ + (maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description)) + for entity in maries_entities + if isinstance(entity, Entity) + ] + + expected_relationships = ( + johns_relationships + + maries_relationships + + overlapping_relationships + + legacy_relationships + ) + + await assert_graph_edges_present(expected_relationships) + + await assert_edges_vector_index_present( + expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships + ) + + # Delete John's data await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore - nodes, edges = await graph_engine.get_graph_data() - assert len(nodes) == 16 and len(edges) == 16, 
"Nodes and edges are not deleted." - assert not any( - node[1]["name"] == "john" or node[1]["name"] == "food for hungry" - for node in nodes - if "name" in node[1] - ), "Nodes are not deleted." + # Assert data points presence in the graph, vector collections and nodes table + await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points) + await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points) - after_first_delete_node_ids = set([node[0] for node in nodes]) + await assert_graph_nodes_not_present(johns_data) + await assert_nodes_vector_index_not_present(johns_data) - after_delete_nodes_by_vector_collection = {} - for node in initial_nodes: - node_data = node[1] - node_metadata = node_data["metadata"] - node_metadata = json.loads(node_metadata) if type(node_metadata) is str else node_metadata - collection_name = node_data["type"] + "_" + node_metadata["index_fields"][0] - if collection_name not in after_delete_nodes_by_vector_collection: - after_delete_nodes_by_vector_collection[collection_name] = [] - after_delete_nodes_by_vector_collection[collection_name].append(node) + # Assert relationships presence in the graph, vector collections and nodes table + await assert_graph_edges_present( + maries_relationships + overlapping_relationships + legacy_relationships + ) + await assert_edges_vector_index_present( + maries_relationships + maries_edge_text_relationships + legacy_relationships + ) - vector_engine = get_vector_engine() + await assert_graph_edges_not_present(johns_relationships) - removed_node_ids = initial_node_ids - after_first_delete_node_ids - - for collection_name, initial_nodes in initial_nodes_by_vector_collection.items(): - query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids] - - if query_node_ids: - vector_items = await vector_engine.retrieve(collection_name, query_node_ids) - assert len(vector_items) == 0, "Vector items are not deleted." 
+ strictly_johns_relationships = isolate_relationships( + johns_relationships, maries_relationships, legacy_relationships + ) + # We check only by relationship name and we need edges that are created by John's data and no other. + await assert_edges_vector_index_not_present( + strictly_johns_relationships + johns_edge_text_relationships + ) + # Delete Marie's data await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore - final_nodes, final_edges = await graph_engine.get_graph_data() - assert len(final_nodes) == 7 and len(final_edges) == 6, "Nodes and edges are not deleted." + # Assert data points presence in the graph, vector collections and nodes table + await assert_graph_nodes_present(legacy_data_points) + await assert_nodes_vector_index_present(legacy_data_points) - old_nodes_by_vector_collection = {} - for node in old_nodes: - node_metadata = node.metadata - collection_name = node.type + "_" + node_metadata["index_fields"][0] - if collection_name not in old_nodes_by_vector_collection: - old_nodes_by_vector_collection[collection_name] = [] - old_nodes_by_vector_collection[collection_name].append(node) + await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities) + await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities) - for collection_name, old_nodes in old_nodes_by_vector_collection.items(): - query_node_ids = [str(node.id) for node in old_nodes] + # Assert relationships presence in the graph, vector collections and nodes table + await assert_graph_edges_present(legacy_relationships) + await assert_edges_vector_index_present(legacy_relationships) - if query_node_ids: - vector_items = await vector_engine.retrieve(collection_name, query_node_ids) - assert len(vector_items) == len(old_nodes), "Vector items are not deleted." 
+ await assert_graph_edges_not_present( + johns_relationships + maries_relationships + overlapping_relationships + ) - query_edge_ids = list(set([str(generate_edge_id(edge[2])) for edge in old_edges])) - - vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids) - assert len(vector_items) == len(query_edge_ids), "Vector items are not deleted." + strictly_maries_relationships = isolate_relationships( + maries_relationships, legacy_relationships + ) + # We check only by relationship name and we need edges that are created by legacy data and no other. + if strictly_maries_relationships: + await assert_edges_vector_index_not_present(strictly_maries_relationships) -async def add_mocked_legacy_data(user): +async def create_mocked_legacy_data(user): graph_engine = await get_graph_engine() - old_nodes, old_edges = await get_nodes_and_edges() - old_document = old_nodes[0] + legacy_nodes, legacy_edges = create_nodes_and_edges() + legacy_document = legacy_nodes[0] - await graph_engine.add_nodes(old_nodes) - await graph_engine.add_edges(old_edges) + await graph_engine.add_nodes(legacy_nodes) + await graph_engine.add_edges(legacy_edges) - await index_data_points(old_nodes) - await index_graph_edges(old_edges) + await index_data_points(legacy_nodes) + await index_graph_edges(legacy_edges) - await record_data_in_legacy_ledger(old_nodes, old_edges, user) + await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user) db_engine = get_relational_engine() @@ -335,12 +458,12 @@ async def add_mocked_legacy_data(user): async with db_engine.get_async_session() as session: old_data = Data( - id=old_document.id, - name=old_document.name, + id=legacy_document.id, + name=legacy_document.name, extension="txt", - raw_data_location=old_document.raw_data_location, - external_metadata=old_document.external_metadata, - mime_type=old_document.mime_type, + raw_data_location=legacy_document.raw_data_location, + 
external_metadata=legacy_document.external_metadata, + mime_type=legacy_document.mime_type, owner_id=user.id, pipeline_status={ "cognify_pipeline": { @@ -355,7 +478,7 @@ async def add_mocked_legacy_data(user): await session.commit() - return old_nodes, old_edges + return legacy_document, legacy_nodes, legacy_edges if __name__ == "__main__": diff --git a/cognee/tests/test_delete_default_graph_with_legacy_data_2.py b/cognee/tests/test_delete_default_graph_with_legacy_data_2.py index 51c962768..7c21cde72 100644 --- a/cognee/tests/test_delete_default_graph_with_legacy_data_2.py +++ b/cognee/tests/test_delete_default_graph_with_legacy_data_2.py @@ -1,11 +1,12 @@ import os +import pytest import pathlib from uuid import NAMESPACE_OID, uuid5 -import pytest from unittest.mock import AsyncMock, patch import cognee from cognee.api.v1.datasets import datasets +from cognee.context_global_variables import set_database_global_context_variables from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph import get_graph_engine @@ -16,15 +17,34 @@ from cognee.modules.data.models import Data from cognee.modules.engine.models import Entity, EntityType from cognee.modules.data.processing.document_types import TextDocument from cognee.modules.engine.operations.setup import setup -from cognee.modules.engine.utils import generate_edge_id, generate_node_id +from cognee.modules.engine.utils import generate_node_id from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger from cognee.modules.pipelines.models import DataItemStatus from cognee.modules.users.methods import get_default_user from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent from cognee.tasks.storage import index_data_points, index_graph_edges +from cognee.tests.utils.assert_edges_vector_index_not_present import ( + 
assert_edges_vector_index_not_present, +) +from cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present +from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present +from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present +from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present +from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present +from cognee.tests.utils.assert_nodes_vector_index_not_present import ( + assert_nodes_vector_index_not_present, +) +from cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present +from cognee.tests.utils.extract_entities import extract_entities +from cognee.tests.utils.extract_relationships import extract_relationships +from cognee.tests.utils.extract_summary import extract_summary +from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities +from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships +from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text +from cognee.tests.utils.isolate_relationships import isolate_relationships -def get_nodes_and_edges(): +def create_nodes_and_edges(): document = TextDocument( id=uuid5(NAMESPACE_OID, "text_test.txt"), name="text_test.txt", @@ -146,8 +166,6 @@ def get_nodes_and_edges(): @pytest.mark.asyncio @patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock) async def main(mock_create_structured_output: AsyncMock): - os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "False" - data_directory_path = os.path.join( pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_2" ) @@ -163,6 +181,9 @@ async def main(mock_create_structured_output: AsyncMock): await cognee.prune.prune_system(metadata=True) await setup() + user = await get_default_user() + await 
set_database_global_context_variables("main_dataset", user.id) + vector_engine = get_vector_engine() assert not await vector_engine.has_collection("EdgeType_relationship_name") @@ -171,9 +192,10 @@ async def main(mock_create_structured_output: AsyncMock): assert not await vector_engine.has_collection("TextSummary_text") assert not await vector_engine.has_collection("TextDocument_text") - user = await get_default_user() - - old_document, old_nodes, old_edges = await add_mocked_legacy_data(user) + # Add legacy data to the system + legacy_document, legacy_data_points, legacy_relationships = await create_mocked_legacy_data( + user + ) def mock_llm_output(text_input: str, system_prompt: str, response_model): if text_input == "test": # LLM connection test @@ -249,100 +271,186 @@ async def main(mock_create_structured_output: AsyncMock): mock_create_structured_output.side_effect = mock_llm_output - add_john_result = await cognee.add( - "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" - ) + johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" + add_john_result = await cognee.add(johns_text) johns_data_id = add_john_result.data_ingestion_info[0]["data_id"] - await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.") + maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project." + add_marie_result = await cognee.add(maries_text) + maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"] cognify_result: dict = await cognee.cognify() dataset_id = list(cognify_result.keys())[0] - graph_engine = await get_graph_engine() - initial_nodes, initial_edges = await graph_engine.get_graph_data() - assert len(initial_nodes) == 22 and len(initial_edges) == 26, ( - "Number of nodes and edges is not correct." 
+ johns_document = TextDocument( + id=johns_data_id, + name="John's Work", + raw_data_location="johns_data_location", + external_metadata="", + ) + johns_chunk = DocumentChunk( + id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"), + text=johns_text, + chunk_size=14, + chunk_index=0, + cut_type="sentence_end", + is_part_of=johns_document, + ) + johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore + + maries_document = TextDocument( + id=maries_data_id, + name="Maries's Work", + raw_data_location="maries_data_location", + external_metadata="", + ) + maries_chunk = DocumentChunk( + id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"), + text=maries_text, + chunk_size=14, + chunk_index=0, + cut_type="sentence_end", + is_part_of=maries_document, + ) + maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore + + johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore + maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore + (overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities( + johns_entities, maries_entities ) - initial_nodes_by_vector_collection = {} + johns_data = [ + johns_document, + johns_chunk, + johns_summary, + *johns_entities, + ] + maries_data = [ + maries_document, + maries_chunk, + maries_summary, + *maries_entities, + ] - for node in initial_nodes: - node_data = node[1] - collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0] - if collection_name not in initial_nodes_by_vector_collection: - initial_nodes_by_vector_collection[collection_name] = [] - initial_nodes_by_vector_collection[collection_name].append(node) + expected_data_points = johns_data + maries_data + overlapping_entities + legacy_data_points - initial_node_ids = set([node[0] for node in initial_nodes]) + # Assert data points presence in the graph, 
vector collections and nodes table + await assert_graph_nodes_present(expected_data_points) + await assert_nodes_vector_index_present(expected_data_points) + johns_relationships = extract_relationships( + johns_chunk, + mock_llm_output("John", "", KnowledgeGraph), # type: ignore + ) + maries_relationships = extract_relationships( + maries_chunk, + mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore + ) + (overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = ( + filter_overlapping_relationships( + johns_relationships, maries_relationships, legacy_relationships + ) + ) + + johns_relationships = [ + (johns_chunk.id, johns_document.id, "is_part_of"), + (johns_summary.id, johns_chunk.id, "made_from"), + *johns_relationships, + ] + johns_edge_text_relationships = [ + (johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description)) + for entity in johns_entities + if isinstance(entity, Entity) + ] + maries_relationships = [ + (maries_chunk.id, maries_document.id, "is_part_of"), + (maries_summary.id, maries_chunk.id, "made_from"), + *maries_relationships, + ] + maries_edge_text_relationships = [ + (maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description)) + for entity in maries_entities + if isinstance(entity, Entity) + ] + + expected_relationships = ( + johns_relationships + + maries_relationships + + overlapping_relationships + + legacy_relationships + ) + + await assert_graph_edges_present(expected_relationships) + + await assert_edges_vector_index_present( + expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships + ) + + # Delete John's data await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore - nodes, edges = await graph_engine.get_graph_data() - assert len(nodes) == 16 and len(edges) == 17, "Nodes and edges are not deleted." 
- assert not any( - node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes - ), "Nodes are not deleted." + # Assert data points presence in the graph, vector collections and nodes table + await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points) + await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points) - after_first_delete_node_ids = set([node[0] for node in nodes]) + await assert_graph_nodes_not_present(johns_data) + await assert_nodes_vector_index_not_present(johns_data) - after_delete_nodes_by_vector_collection = {} - for node in initial_nodes: - node_data = node[1] - collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0] - if collection_name not in after_delete_nodes_by_vector_collection: - after_delete_nodes_by_vector_collection[collection_name] = [] - after_delete_nodes_by_vector_collection[collection_name].append(node) + # Assert relationships presence in the graph, vector collections and nodes table + await assert_graph_edges_present( + maries_relationships + overlapping_relationships + legacy_relationships + ) + await assert_edges_vector_index_present( + maries_relationships + maries_edge_text_relationships + legacy_relationships + ) - vector_engine = get_vector_engine() + await assert_graph_edges_not_present(johns_relationships) - removed_node_ids = initial_node_ids - after_first_delete_node_ids + strictly_johns_relationships = isolate_relationships( + johns_relationships, maries_relationships, legacy_relationships + ) + # We check only by relationship name and we need edges that are created by John's data and no other. 
+ await assert_edges_vector_index_not_present( + strictly_johns_relationships + johns_edge_text_relationships + ) - for collection_name, initial_nodes in initial_nodes_by_vector_collection.items(): - query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids] + # Delete legacy data + await datasets.delete_data(dataset_id, legacy_document.id, user) # type: ignore - if query_node_ids: - vector_items = await vector_engine.retrieve(collection_name, query_node_ids) - assert len(vector_items) == 0, "Vector items are not deleted." + # Assert data points presence in the graph, vector collections and nodes table + await assert_graph_nodes_present(maries_data + overlapping_entities) + await assert_nodes_vector_index_present(maries_data + overlapping_entities) - # Delete old document - await datasets.delete_data(dataset_id, old_document.id, user) # type: ignore + await assert_graph_nodes_not_present(johns_data + legacy_data_points) + await assert_nodes_vector_index_not_present(johns_data + legacy_data_points) - final_nodes, final_edges = await graph_engine.get_graph_data() - assert len(final_nodes) == 9 and len(final_edges) == 10, "Nodes and edges are not deleted." 
+ # Assert relationships presence in the graph, vector collections and nodes table + await assert_graph_edges_present(maries_relationships + overlapping_relationships) + await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships) - old_nodes_by_vector_collection = {} - for node in old_nodes: - collection_name = node.type + "_" + node.metadata["index_fields"][0] - if collection_name not in old_nodes_by_vector_collection: - old_nodes_by_vector_collection[collection_name] = [] - old_nodes_by_vector_collection[collection_name].append(node) + await assert_graph_edges_not_present(johns_relationships + legacy_relationships) - for collection_name, old_nodes in old_nodes_by_vector_collection.items(): - query_node_ids = [str(node.id) for node in old_nodes] - - if query_node_ids: - vector_items = await vector_engine.retrieve(collection_name, query_node_ids) - assert len(vector_items) == 0, "Vector items are not deleted." - - query_edge_ids = list(set([str(generate_edge_id(edge[2])) for edge in old_edges])) - - vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids) - assert len(vector_items) == len(query_edge_ids), "Vector items are not deleted." + strictly_legacy_relationships = isolate_relationships( + legacy_relationships, maries_relationships + ) + # We check only by relationship name and we need edges that are created by legacy data and no other. 
+ if strictly_legacy_relationships: + await assert_edges_vector_index_not_present(strictly_legacy_relationships) -async def add_mocked_legacy_data(user): +async def create_mocked_legacy_data(user): graph_engine = await get_graph_engine() - old_nodes, old_edges = get_nodes_and_edges() - old_document = old_nodes[0] + legacy_nodes, legacy_edges = create_nodes_and_edges() + legacy_document = legacy_nodes[0] - await graph_engine.add_nodes(old_nodes) - await graph_engine.add_edges(old_edges) + await graph_engine.add_nodes(legacy_nodes) + await graph_engine.add_edges(legacy_edges) - await index_data_points(old_nodes) - await index_graph_edges(old_edges) + await index_data_points(legacy_nodes) + await index_graph_edges(legacy_edges) - await record_data_in_legacy_ledger(old_nodes, old_edges, user) + await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user) db_engine = get_relational_engine() @@ -350,12 +458,12 @@ async def add_mocked_legacy_data(user): async with db_engine.get_async_session() as session: old_data = Data( - id=old_document.id, - name=old_document.name, + id=legacy_document.id, + name=legacy_document.name, extension="txt", - raw_data_location=old_document.raw_data_location, - external_metadata=old_document.external_metadata, - mime_type=old_document.mime_type, + raw_data_location=legacy_document.raw_data_location, + external_metadata=legacy_document.external_metadata, + mime_type=legacy_document.mime_type, owner_id=user.id, pipeline_status={ "cognify_pipeline": { @@ -370,7 +478,7 @@ async def add_mocked_legacy_data(user): await session.commit() - return old_document, old_nodes, old_edges + return legacy_document, legacy_nodes, legacy_edges if __name__ == "__main__": diff --git a/cognee/tests/utils/assert_edges_vector_index_not_present.py b/cognee/tests/utils/assert_edges_vector_index_not_present.py new file mode 100644 index 000000000..84261f83b --- /dev/null +++ b/cognee/tests/utils/assert_edges_vector_index_not_present.py @@ -0,0 +1,23 @@ +from 
uuid import UUID +from typing import List, Tuple +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.engine.utils import generate_edge_id + + +async def assert_edges_vector_index_not_present(relationships: List[Tuple[UUID, UUID, str]]): + vector_engine = get_vector_engine() + + query_edge_ids = { + str(generate_edge_id(relationship[2])): relationship[2] for relationship in relationships + } + + vector_items = await vector_engine.retrieve( + "EdgeType_relationship_name", list(query_edge_ids.keys()) + ) + + vector_items_by_id = {str(vector_item.id): vector_item for vector_item in vector_items} + + for relationship_id, relationship_name in query_edge_ids.items(): + assert relationship_id not in vector_items_by_id, ( + f"Relationship '{relationship_name}' still present in the vector store." + ) diff --git a/cognee/tests/utils/assert_edges_vector_index_present.py b/cognee/tests/utils/assert_edges_vector_index_present.py new file mode 100644 index 000000000..c4f2b6dbe --- /dev/null +++ b/cognee/tests/utils/assert_edges_vector_index_present.py @@ -0,0 +1,28 @@ +from uuid import UUID +from typing import List, Tuple +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.engine.utils import generate_edge_id + + +async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str]]): + vector_engine = get_vector_engine() + + query_edge_ids = { + str(generate_edge_id(relationship[2])): relationship[2] for relationship in relationships + } + + vector_items = await vector_engine.retrieve( + "EdgeType_relationship_name", list(query_edge_ids.keys()) + ) + + vector_items_by_id = {str(vector_item.id): vector_item for vector_item in vector_items} + + for relationship_id, relationship_name in query_edge_ids.items(): + assert relationship_id in vector_items_by_id, ( + f"Relationship '{relationship_name}' not found in vector store." 
+ ) + + vector_relationship = vector_items_by_id[relationship_id] + assert vector_relationship.payload["text"] == relationship_name, ( + f"Vectorized edge '{vector_relationship.payload['text']}' does not match the relationship text '{relationship_name}'." + ) diff --git a/cognee/tests/utils/assert_graph_edges_not_present.py b/cognee/tests/utils/assert_graph_edges_not_present.py new file mode 100644 index 000000000..03e22ae62 --- /dev/null +++ b/cognee/tests/utils/assert_graph_edges_not_present.py @@ -0,0 +1,21 @@ +from uuid import UUID +from typing import List, Tuple + +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.engine.models.DataPoint import DataPoint + + +async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str]]): + graph_engine = await get_graph_engine() + nodes, edges = await graph_engine.get_graph_data() + + nodes_by_id = {str(node[0]): node[1] for node in nodes} + + edge_ids = set([f"{str(edge[0])}_{edge[2]}_{str(edge[1])}" for edge in edges]) + + for relationship in relationships: + relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}" + relationship_name = relationship[2] + assert relationship_id not in edge_ids, ( + f"Edge '{relationship_name}' still present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database." 
+ ) diff --git a/cognee/tests/utils/assert_graph_edges_present.py b/cognee/tests/utils/assert_graph_edges_present.py new file mode 100644 index 000000000..ae14c75cf --- /dev/null +++ b/cognee/tests/utils/assert_graph_edges_present.py @@ -0,0 +1,21 @@ +from uuid import UUID +from typing import List, Tuple + +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.engine.models.DataPoint import DataPoint + + +async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]]): + graph_engine = await get_graph_engine() + nodes, edges = await graph_engine.get_graph_data() + + nodes_by_id = {str(node[0]): node[1] for node in nodes} + + edge_ids = set([f"{str(edge[0])}_{edge[2]}_{str(edge[1])}" for edge in edges]) + + for relationship in relationships: + relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}" + relationship_name = relationship[2] + assert relationship_id in edge_ids, ( + f"Edge '{relationship_name}' not present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database." 
+ ) diff --git a/cognee/tests/utils/assert_graph_nodes_not_present.py b/cognee/tests/utils/assert_graph_nodes_not_present.py new file mode 100644 index 000000000..112b8be03 --- /dev/null +++ b/cognee/tests/utils/assert_graph_nodes_not_present.py @@ -0,0 +1,16 @@ +from typing import List +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.engine.models.DataPoint import DataPoint + + +async def assert_graph_nodes_not_present(data_points: List[DataPoint]): + graph_engine = await get_graph_engine() + nodes, __ = await graph_engine.get_graph_data() + + node_ids = set(node[0] for node in nodes) + + for data_point in data_points: + node_name = getattr(data_point, "label", getattr(data_point, "name", data_point.id)) + assert str(data_point.id) not in node_ids, ( + f"Node '{node_name}' is present in graph database." + ) diff --git a/cognee/tests/utils/assert_graph_nodes_present.py b/cognee/tests/utils/assert_graph_nodes_present.py new file mode 100644 index 000000000..0508db28b --- /dev/null +++ b/cognee/tests/utils/assert_graph_nodes_present.py @@ -0,0 +1,14 @@ +from typing import List +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.engine.models.DataPoint import DataPoint + + +async def assert_graph_nodes_present(data_points: List[DataPoint]): + graph_engine = await get_graph_engine() + nodes, __ = await graph_engine.get_graph_data() + + node_ids = set(node[0] for node in nodes) + + for data_point in data_points: + node_name = getattr(data_point, "label", getattr(data_point, "name", data_point.id)) + assert str(data_point.id) in node_ids, f"Node '{node_name}' not found in graph database." 
diff --git a/cognee/tests/utils/assert_nodes_vector_index_not_present.py b/cognee/tests/utils/assert_nodes_vector_index_not_present.py new file mode 100644 index 000000000..5fa510aba --- /dev/null +++ b/cognee/tests/utils/assert_nodes_vector_index_not_present.py @@ -0,0 +1,28 @@ +from typing import List +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.engine.models.DataPoint import DataPoint + + +async def assert_nodes_vector_index_not_present(data_points: List[DataPoint]): + vector_engine = get_vector_engine() + + data_points_by_vector_collection = {} + + for data_point in data_points: + node_metadata = data_point.metadata or {} + collection_name = data_point.type + "_" + node_metadata["index_fields"][0] + + if collection_name not in data_points_by_vector_collection: + data_points_by_vector_collection[collection_name] = [] + + data_points_by_vector_collection[collection_name].append(data_point) + + for collection_name, collection_data_points in data_points_by_vector_collection.items(): + query_data_point_ids = set([str(data_point.id) for data_point in collection_data_points]) + + vector_items = await vector_engine.retrieve(collection_name, list(query_data_point_ids)) + + for vector_item in vector_items: + assert str(vector_item.id) not in query_data_point_ids, ( + f"{vector_item.payload['text']} is still present in the vector store." 
+ ) diff --git a/cognee/tests/utils/assert_nodes_vector_index_present.py b/cognee/tests/utils/assert_nodes_vector_index_present.py new file mode 100644 index 000000000..ffc6c65f4 --- /dev/null +++ b/cognee/tests/utils/assert_nodes_vector_index_present.py @@ -0,0 +1,28 @@ +from typing import List +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.engine.models.DataPoint import DataPoint + + +async def assert_nodes_vector_index_present(data_points: List[DataPoint]): + vector_engine = get_vector_engine() + + data_points_by_vector_collection = {} + + for data_point in data_points: + node_metadata = data_point.metadata or {} + collection_name = data_point.type + "_" + node_metadata["index_fields"][0] + + if collection_name not in data_points_by_vector_collection: + data_points_by_vector_collection[collection_name] = [] + + data_points_by_vector_collection[collection_name].append(data_point) + + for collection_name, collection_data_points in data_points_by_vector_collection.items(): + query_data_point_ids = set([str(data_point.id) for data_point in collection_data_points]) + + vector_items = await vector_engine.retrieve(collection_name, list(query_data_point_ids)) + + for vector_item in vector_items: + assert str(vector_item.id) in query_data_point_ids, ( + f"{vector_item.payload['text']} is not present in the vector store." 
+ ) diff --git a/cognee/tests/utils/extract_entities.py b/cognee/tests/utils/extract_entities.py new file mode 100644 index 000000000..541d90faf --- /dev/null +++ b/cognee/tests/utils/extract_entities.py @@ -0,0 +1,45 @@ +from cognee.modules.engine.models import Entity, EntityType +from cognee.modules.engine.utils import generate_node_id, generate_node_name +from cognee.shared.data_models import KnowledgeGraph + + +def extract_entities(graph: KnowledgeGraph, cache: dict = {}): + entities = [] + entity_types = [] + + for node in graph.nodes: + node_id = generate_node_id(node.id) + + if node_id not in cache: + entity = Entity( + id=node_id, + name=generate_node_name(node.id), + type=node.type, + description=node.description, + ontology_valid=False, + ) + cache[node_id] = entity + else: + entity = cache[node_id] + + entities.append(entity) + + node_type = node.type + type_node_id = generate_node_id(node_type) + if type_node_id not in cache: + type_node_name = generate_node_name(node_type) + + type_node = EntityType( + id=type_node_id, + name=type_node_name, + type=type_node_name, + description=type_node_name, + ontology_valid=False, + ) + cache[type_node_id] = type_node + else: + type_node = cache[type_node_id] + + entity_types.append(type_node) + + return entities + entity_types diff --git a/cognee/tests/utils/extract_relationships.py b/cognee/tests/utils/extract_relationships.py new file mode 100644 index 000000000..e3f7849df --- /dev/null +++ b/cognee/tests/utils/extract_relationships.py @@ -0,0 +1,55 @@ +from cognee.shared.data_models import KnowledgeGraph +from cognee.modules.chunking.models.DocumentChunk import DocumentChunk +from cognee.modules.engine.utils import generate_edge_id, generate_node_id + + +def extract_relationships(document_chunk: DocumentChunk, graph: KnowledgeGraph, cache: dict = {}): + relationships = [] + + for edge in graph.edges: + edge_id = f"{edge.source_node_id}_{edge.relationship_name}_{edge.target_node_id}" + + if edge_id not in cache: 
+ relationship = ( + generate_edge_id(edge.source_node_id), + generate_edge_id(edge.target_node_id), + edge.relationship_name, + ) + cache[edge_id] = relationship + else: + relationship = cache[edge_id] + + relationships.append(relationship) + + for node in graph.nodes: + node_id = generate_node_id(node.id) + type_node_id = generate_node_id(node.type) + type_edge_id = f"{str(node_id)}_is_a_{str(type_node_id)}" + + if type_edge_id not in cache: + relationship = ( + node_id, + type_node_id, + "is_a", + ) + cache[type_edge_id] = relationship + else: + relationship = cache[type_edge_id] + + relationships.append(relationship) + + chunk_edge_id = f"{str(document_chunk.id)}_contains_{str(node_id)}" + + if chunk_edge_id not in cache: + relationship = ( + document_chunk.id, + node_id, + "contains", + ) + cache[chunk_edge_id] = relationship + else: + relationship = cache[chunk_edge_id] + + relationships.append(relationship) + + return relationships diff --git a/cognee/tests/utils/extract_summary.py b/cognee/tests/utils/extract_summary.py new file mode 100644 index 000000000..be9125fff --- /dev/null +++ b/cognee/tests/utils/extract_summary.py @@ -0,0 +1,12 @@ +from uuid import uuid5 +from cognee.modules.chunking.models import DocumentChunk +from cognee.shared.data_models import SummarizedContent +from cognee.tasks.summarization.models import TextSummary + + +def extract_summary(document_chunk: DocumentChunk, summary=SummarizedContent) -> TextSummary: + return TextSummary( + id=uuid5(document_chunk.id, "TextSummary"), + text=summary.summary, + made_from=document_chunk, + ) diff --git a/cognee/tests/utils/filter_overlapping_entities.py b/cognee/tests/utils/filter_overlapping_entities.py new file mode 100644 index 000000000..40c6b0936 --- /dev/null +++ b/cognee/tests/utils/filter_overlapping_entities.py @@ -0,0 +1,26 @@ +def filter_overlapping_entities(*entity_groups): + entity_count = {} + overlapping_entities = [] + + for group in entity_groups: + for entity in group: + if not 
entity.id in entity_count: + entity_count[entity.id] = 1 + else: + entity_count[entity.id] += 1 + + index = 0 + grouped_entities = [] + for group in entity_groups: + grouped_entities.append([]) + + for entity in group: + if entity_count[entity.id] == 1: + grouped_entities[index].append(entity) + else: + if entity not in overlapping_entities: + overlapping_entities.append(entity) + + index += 1 + + return overlapping_entities, *grouped_entities diff --git a/cognee/tests/utils/filter_overlapping_relationships.py b/cognee/tests/utils/filter_overlapping_relationships.py new file mode 100644 index 000000000..064e0503c --- /dev/null +++ b/cognee/tests/utils/filter_overlapping_relationships.py @@ -0,0 +1,33 @@ +from cognee.modules.engine.utils import generate_node_id + + +def filter_overlapping_relationships(*relationship_groups): + relationship_count = {} + overlapping_relationships = [] + + for group in relationship_groups: + for relationship in group: + relationship_id = f"{relationship[0]}_{relationship[2]}_{relationship[1]}" + + if not relationship_id in relationship_count: + relationship_count[relationship_id] = 1 + else: + relationship_count[relationship_id] += 1 + + index = 0 + grouped_relationships = [] + for group in relationship_groups: + grouped_relationships.append([]) + + for relationship in group: + relationship_id = f"{relationship[0]}_{relationship[2]}_{relationship[1]}" + + if relationship_count[relationship_id] == 1: + grouped_relationships[index].append(relationship) + else: + if relationship not in overlapping_relationships: + overlapping_relationships.append(relationship) + + index += 1 + + return overlapping_relationships, *grouped_relationships diff --git a/cognee/tests/utils/get_contains_edge_text.py b/cognee/tests/utils/get_contains_edge_text.py new file mode 100644 index 000000000..478ea30fe --- /dev/null +++ b/cognee/tests/utils/get_contains_edge_text.py @@ -0,0 +1,9 @@ +def get_contains_edge_text(entity_name: str, entity_description: str) -> 
str: + edge_text = "; ".join( + [ + "relationship_name: contains", + f"entity_name: {entity_name}", + f"entity_description: {entity_description}", + ] + ) + return edge_text diff --git a/cognee/tests/utils/isolate_relationships.py b/cognee/tests/utils/isolate_relationships.py new file mode 100644 index 000000000..ca94f99ec --- /dev/null +++ b/cognee/tests/utils/isolate_relationships.py @@ -0,0 +1,20 @@ +def isolate_relationships(source_relationships, *other_relationships): + final_relationships = [] + cache = {relationship[2]: 1 for relationship in source_relationships} + duplicated_relationships = {} + + for relationships in other_relationships: + for relationship in relationships: + if relationship[2] not in cache: + cache[relationship[2]] = 0 + + cache[relationship[2]] += 1 + + if cache[relationship[2]] == 2: + duplicated_relationships[relationship[2]] = True + + for relationship in source_relationships: + if relationship[2] not in duplicated_relationships: + final_relationships.append(relationship) + + return final_relationships