diff --git a/cognee/modules/graph/methods/legacy_delete.py b/cognee/modules/graph/methods/legacy_delete.py index d7cc7d4e9..dffacd0a0 100644 --- a/cognee/modules/graph/methods/legacy_delete.py +++ b/cognee/modules/graph/methods/legacy_delete.py @@ -54,7 +54,7 @@ async def legacy_delete(data: Data, mode: str = "soft"): async def delete_document_subgraph(document_id: UUID, mode: str = "soft"): """Delete a document and all its related nodes in the correct order.""" graph_db = await get_graph_engine() - subgraph = await graph_db.get_document_subgraph(document_id) + subgraph = await graph_db.get_document_subgraph(str(document_id)) if not subgraph: raise DocumentSubgraphNotFoundError(f"Document not found with id: {document_id}") diff --git a/cognee/tests/test_delete_default_graph_with_legacy_data_1.py b/cognee/tests/test_delete_default_graph_with_legacy_data_1.py index 02edd4e48..9fae87874 100644 --- a/cognee/tests/test_delete_default_graph_with_legacy_data_1.py +++ b/cognee/tests/test_delete_default_graph_with_legacy_data_1.py @@ -19,6 +19,8 @@ from cognee.modules.data.processing.document_types import TextDocument from cognee.modules.engine.operations.setup import setup from cognee.modules.engine.utils import generate_node_id from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger +from cognee.modules.graph.utils.deduplicate_nodes_and_edges import deduplicate_nodes_and_edges +from cognee.modules.graph.utils.get_graph_from_model import get_graph_from_model from cognee.modules.pipelines.models import DataItemStatus from cognee.modules.users.methods import get_default_user from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent @@ -43,7 +45,7 @@ from cognee.tests.utils.filter_overlapping_relationships import filter_overlappi from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text -def create_nodes_and_edges(): +def create_legacy_data_points(): document = TextDocument( id=uuid5(NAMESPACE_OID, "text_test.txt"), name="text_test.txt", @@ -72,11 +74,13 @@ def create_nodes_and_edges(): id=generate_node_id("neptune analytics"), name="neptune analytics", description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.", + is_a=graph_database, ) neptune_database_entity = Entity( id=generate_node_id("amazon neptune database"), name="amazon neptune database", description="A popular managed graph database that complements Neptune Analytics.", + is_a=graph_database, ) storage = EntityType( @@ -88,11 +92,10 @@ def create_nodes_and_edges(): id=generate_node_id("amazon s3"), name="amazon s3", description="A storage service provided by Amazon Web Services that allows storing graph data.", + is_a=storage, ) - nodes_data = [ - document, - document_chunk, + entities = [ graph_database, neptune_analytics_entity, neptune_database_entity, @@ -100,66 +103,14 @@ def create_nodes_and_edges(): storage_entity, ] - edges_data = [ - ( - document_chunk.id, - storage_entity.id, - "contains", - { - "relationship_name": "contains", - }, - ), - ( - storage_entity.id, - storage.id, - "is_a", - { - "relationship_name": "is_a", - }, - ), - ( - document_chunk.id, - neptune_database_entity.id, - "contains", - { - "relationship_name": "contains", - }, - ), - ( - neptune_database_entity.id, - graph_database.id, - "is_a", - { - "relationship_name": "is_a", - }, - ), - ( - document_chunk.id, - document.id, - "is_part_of", - { - "relationship_name": "is_part_of", - }, - ), - ( - document_chunk.id, - neptune_analytics_entity.id, - "contains", - { - "relationship_name": "contains", - }, - ), - ( - neptune_analytics_entity.id, - graph_database.id, - "is_a", - { - "relationship_name": "is_a", - }, - ), + document_chunk.contains = entities + + data_points = [ + document, + document_chunk, ] - return nodes_data, edges_data + return data_points @pytest.mark.asyncio @@ -441,13 +392,38 @@ async def main(mock_create_structured_output: AsyncMock): async def create_mocked_legacy_data(user): graph_engine = await get_graph_engine() - legacy_nodes, legacy_edges = create_nodes_and_edges() - legacy_document = legacy_nodes[0] + legacy_data_points = create_legacy_data_points() + legacy_document = legacy_data_points[0] - await graph_engine.add_nodes(legacy_nodes) - await graph_engine.add_edges(legacy_edges) + nodes = [] + edges = [] - nodes_by_id = {node.id: node for node in legacy_nodes} + added_nodes = {} + added_edges = {} + visited_properties = {} + + results = await asyncio.gather( + *[ + get_graph_from_model( + data_point, + added_nodes=added_nodes, + added_edges=added_edges, + visited_properties=visited_properties, + ) + for data_point in legacy_data_points + ] + ) + + for result_nodes, result_edges in results: + nodes.extend(result_nodes) + edges.extend(result_edges) + + graph_nodes, graph_edges = deduplicate_nodes_and_edges(nodes, edges) + + await graph_engine.add_nodes(graph_nodes) + await graph_engine.add_edges(graph_edges) + + nodes_by_id = {node.id: node for node in graph_nodes} def format_relationship_name(relationship): if relationship[2] == "contains": @@ -455,7 +431,7 @@ async def create_mocked_legacy_data(user): return get_contains_edge_text(node.name, node.description) return relationship[2] - await index_data_points(legacy_nodes) + await index_data_points(graph_nodes) await index_graph_edges( [ ( @@ -467,11 +443,11 @@ async def create_mocked_legacy_data(user): "relationship_name": format_relationship_name(edge), }, ) - for edge in legacy_edges + for edge in graph_edges ] # type: ignore ) - await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user) + await record_data_in_legacy_ledger(graph_nodes, graph_edges, user) db_engine = get_relational_engine() @@ -499,7 +475,7 @@ async def create_mocked_legacy_data(user): await session.commit() - return legacy_document, legacy_nodes, legacy_edges + return legacy_document, graph_nodes, graph_edges if __name__ == "__main__": diff --git a/cognee/tests/test_delete_default_graph_with_legacy_data_2.py b/cognee/tests/test_delete_default_graph_with_legacy_data_2.py index db59d2a60..13824bff6 100644 --- a/cognee/tests/test_delete_default_graph_with_legacy_data_2.py +++ b/cognee/tests/test_delete_default_graph_with_legacy_data_2.py @@ -20,6 +20,8 @@ from cognee.modules.data.processing.document_types import TextDocument from cognee.modules.engine.operations.setup import setup from cognee.modules.engine.utils import generate_node_id from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger +from cognee.modules.graph.utils.deduplicate_nodes_and_edges import deduplicate_nodes_and_edges +from cognee.modules.graph.utils.get_graph_from_model import get_graph_from_model from cognee.modules.pipelines.models import DataItemStatus from cognee.modules.users.methods import get_default_user from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent @@ -44,7 +46,7 @@ from cognee.tests.utils.filter_overlapping_relationships import filter_overlappi from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text -def create_nodes_and_edges(): +def create_legacy_data_points(): document = TextDocument( id=uuid5(NAMESPACE_OID, "text_test.txt"), name="text_test.txt", @@ -73,11 +75,13 @@ def create_nodes_and_edges(): id=generate_node_id("neptune analytics"), name="neptune analytics", description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.", + is_a=graph_database, ) neptune_database_entity = Entity( id=generate_node_id("amazon neptune database"), name="amazon neptune database", description="A popular managed graph database that complements Neptune Analytics.", + is_a=graph_database, ) storage = EntityType( @@ -89,11 +93,10 @@ def create_nodes_and_edges(): id=generate_node_id("amazon s3"), name="amazon s3", description="A storage service provided by Amazon Web Services that allows storing graph data.", + is_a=storage, ) - nodes_data = [ - document, - document_chunk, + entities = [ graph_database, neptune_analytics_entity, neptune_database_entity, @@ -101,66 +104,14 @@ def create_nodes_and_edges(): storage_entity, ] - edges_data = [ - ( - document_chunk.id, - storage_entity.id, - "contains", - { - "relationship_name": "contains", - }, - ), - ( - storage_entity.id, - storage.id, - "is_a", - { - "relationship_name": "is_a", - }, - ), - ( - document_chunk.id, - neptune_database_entity.id, - "contains", - { - "relationship_name": "contains", - }, - ), - ( - neptune_database_entity.id, - graph_database.id, - "is_a", - { - "relationship_name": "is_a", - }, - ), - ( - document_chunk.id, - document.id, - "is_part_of", - { - "relationship_name": "is_part_of", - }, - ), - ( - document_chunk.id, - neptune_analytics_entity.id, - "contains", - { - "relationship_name": "contains", - }, - ), - ( - neptune_analytics_entity.id, - graph_database.id, - "is_a", - { - "relationship_name": "is_a", - }, - ), + document_chunk.contains = entities + + data_points = [ + document, + document_chunk, ] - return nodes_data, edges_data + return data_points @pytest.mark.asyncio @@ -436,13 +387,38 @@ async def main(mock_create_structured_output: AsyncMock): async def create_mocked_legacy_data(user): graph_engine = await get_graph_engine() - legacy_nodes, legacy_edges = create_nodes_and_edges() - legacy_document = legacy_nodes[0] + legacy_data_points = create_legacy_data_points() + legacy_document = legacy_data_points[0] - await graph_engine.add_nodes(legacy_nodes) - await graph_engine.add_edges(legacy_edges) + nodes = [] + edges = [] - nodes_by_id = {node.id: node for node in legacy_nodes} + added_nodes = {} + added_edges = {} + visited_properties = {} + + results = await asyncio.gather( + *[ + get_graph_from_model( + data_point, + added_nodes=added_nodes, + added_edges=added_edges, + visited_properties=visited_properties, + ) + for data_point in legacy_data_points + ] + ) + + for result_nodes, result_edges in results: + nodes.extend(result_nodes) + edges.extend(result_edges) + + graph_nodes, graph_edges = deduplicate_nodes_and_edges(nodes, edges) + + await graph_engine.add_nodes(graph_nodes) + await graph_engine.add_edges(graph_edges) + + nodes_by_id = {node.id: node for node in graph_nodes} def format_relationship_name(relationship): if relationship[2] == "contains": @@ -450,7 +426,7 @@ async def create_mocked_legacy_data(user): return get_contains_edge_text(node.name, node.description) return relationship[2] - await index_data_points(legacy_nodes) + await index_data_points(graph_nodes) await index_graph_edges( [ ( @@ -462,11 +438,11 @@ async def create_mocked_legacy_data(user): "relationship_name": format_relationship_name(edge), }, ) - for edge in legacy_edges + for edge in graph_edges ] # type: ignore ) - await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user) + await record_data_in_legacy_ledger(graph_nodes, graph_edges, user) db_engine = get_relational_engine() @@ -494,7 +470,7 @@ async def create_mocked_legacy_data(user): await session.commit() - return legacy_document, legacy_nodes, legacy_edges + return legacy_document, graph_nodes, graph_edges if __name__ == "__main__": diff --git a/cognee/tests/utils/filter_overlapping_entities.py b/cognee/tests/utils/filter_overlapping_entities.py index 40c6b0936..dc0afb9c6 100644 --- a/cognee/tests/utils/filter_overlapping_entities.py +++ b/cognee/tests/utils/filter_overlapping_entities.py @@ -4,7 +4,7 @@ def filter_overlapping_entities(*entity_groups): for group in entity_groups: for entity in group: - if not entity.id in entity_count: + if entity.id not in entity_count: entity_count[entity.id] = 1 else: entity_count[entity.id] += 1 diff --git a/cognee/tests/utils/filter_overlapping_relationships.py b/cognee/tests/utils/filter_overlapping_relationships.py index 064e0503c..6bdd6e1a8 100644 --- a/cognee/tests/utils/filter_overlapping_relationships.py +++ b/cognee/tests/utils/filter_overlapping_relationships.py @@ -9,7 +9,7 @@ def filter_overlapping_relationships(*relationship_groups): for relationship in group: relationship_id = f"{relationship[0]}_{relationship[2]}_{relationship[1]}" - if not relationship_id in relationship_count: + if relationship_id not in relationship_count: relationship_count[relationship_id] = 1 else: relationship_count[relationship_id] += 1