From 79983c25eeefeeb387ba5c19a59251bdd0949cbd Mon Sep 17 00:00:00 2001
From: Boris Arzentar
Date: Tue, 4 Nov 2025 14:39:02 +0100
Subject: [PATCH] fix: handle data deletion with backwards compatibility

---
 .github/workflows/e2e_tests.yml | 54 +++
 cognee/api/v1/datasets/datasets.py | 18 +-
 .../v1/delete/routers/get_delete_router.py | 3 +-
 .../vector/lancedb/LanceDBAdapter.py | 6 +-
 .../modules/engine/utils/generate_node_id.py | 4 +-
 .../graph/legacy/GraphRelationshipLedger.py | 40 ++
 .../legacy/has_edges_in_legacy_ledger.py | 49 +++
 .../legacy/has_nodes_in_legacy_ledger.py | 36 ++
 .../legacy/record_data_in_legacy_ledger.py | 38 ++
 cognee/modules/graph/methods/__init__.py | 4 +
 .../methods/delete_data_nodes_and_edges.py | 36 +-
 .../graph/methods/has_data_related_nodes.py | 16 +
 cognee/modules/graph/methods/legacy_delete.py | 94 +++++
 cognee/tests/test_delete_default_graph.py | 31 +-
 .../test_delete_default_graph_non_mocked.py | 32 +-
 ...delete_default_graph_with_legacy_data_1.py | 372 ++++++++++++++++++
 ...delete_default_graph_with_legacy_data_2.py | 372 ++++++++++++++++++
 17 files changed, 1148 insertions(+), 57 deletions(-)
 create mode 100644 cognee/modules/graph/legacy/GraphRelationshipLedger.py
 create mode 100644 cognee/modules/graph/legacy/has_edges_in_legacy_ledger.py
 create mode 100644 cognee/modules/graph/legacy/has_nodes_in_legacy_ledger.py
 create mode 100644 cognee/modules/graph/legacy/record_data_in_legacy_ledger.py
 create mode 100644 cognee/modules/graph/methods/has_data_related_nodes.py
 create mode 100644 cognee/modules/graph/methods/legacy_delete.py
 create mode 100644 cognee/tests/test_delete_default_graph_with_legacy_data_1.py
 create mode 100644 cognee/tests/test_delete_default_graph_with_legacy_data_2.py

diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml
index bfa596855..47af8de4a 100644
--- a/.github/workflows/e2e_tests.yml
+++ b/.github/workflows/e2e_tests.yml
@@ -346,6 +346,60 @@ jobs:
           EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
         run: uv run python ./cognee/tests/test_delete_custom_graph.py
 
+  test-deletion-on-default-graph_with_legacy_data_1:
+    name: Delete default graph with legacy data test 1
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: '3.11.x'
+
+      - name: Run deletion on default graph with legacy data
+        env:
+          ENV: 'dev'
+          LLM_MODEL: ${{ secrets.LLM_MODEL }}
+          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+        run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_1.py
+
+  test-deletion-on-default-graph_with_legacy_data_2:
+    name: Delete default graph with legacy data test 2
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: '3.11.x'
+
+      - name: Run deletion on default graph with legacy data
+        env:
+          ENV: 'dev'
+          LLM_MODEL: ${{ secrets.LLM_MODEL }}
+          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+          
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_2.py + test-graph-edges: name: Test graph edge ingestion runs-on: ubuntu-22.04 diff --git a/cognee/api/v1/datasets/datasets.py b/cognee/api/v1/datasets/datasets.py index a29953cbd..5a6293240 100644 --- a/cognee/api/v1/datasets/datasets.py +++ b/cognee/api/v1/datasets/datasets.py @@ -7,7 +7,12 @@ from cognee.modules.users.exceptions import PermissionDeniedError from cognee.modules.data.methods import has_dataset_data from cognee.modules.data.methods import get_authorized_dataset, get_authorized_existing_datasets from cognee.modules.data.exceptions.exceptions import UnauthorizedDataAccessError -from cognee.modules.graph.methods import delete_data_nodes_and_edges, delete_dataset_nodes_and_edges +from cognee.modules.graph.methods import ( + delete_data_nodes_and_edges, + delete_dataset_nodes_and_edges, + has_data_related_nodes, + legacy_delete, +) from cognee.modules.ingestion import discover_directory_datasets from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status @@ -66,7 +71,9 @@ class datasets: return await delete_dataset(dataset) @staticmethod - async def delete_data(dataset_id: UUID, data_id: UUID, user: Optional[User] = None): + async def delete_data( + dataset_id: UUID, data_id: UUID, user: Optional[User] = None, mode: str = "soft" + ): from cognee.modules.data.methods import delete_data, get_data if not user: @@ -81,7 +88,7 @@ class datasets: if not data: # If data is not found in the system, user is using a custom graph model. - await delete_data_nodes_and_edges(dataset_id, data_id) + await delete_data_nodes_and_edges(dataset_id, data_id, user.id) return data_datasets = data.datasets @@ -89,7 +96,10 @@ class datasets: if not data or not any([dataset.id == dataset_id for dataset in data_datasets]): raise UnauthorizedDataAccessError(f"Data {data_id} not accessible.") - await delete_data_nodes_and_edges(dataset_id, data.id) + if not await has_data_related_nodes(dataset_id, data_id): + await legacy_delete(data, mode) + else: + await delete_data_nodes_and_edges(dataset_id, data_id, user.id) await delete_data(data) diff --git a/cognee/api/v1/delete/routers/get_delete_router.py b/cognee/api/v1/delete/routers/get_delete_router.py index 8a9a0729e..23f447127 100644 --- a/cognee/api/v1/delete/routers/get_delete_router.py +++ b/cognee/api/v1/delete/routers/get_delete_router.py @@ -54,9 +54,10 @@ def get_delete_router() -> APIRouter: try: result = await datasets.delete_data( - data_id=data_id, dataset_id=dataset_id, + data_id=data_id, user=user, + mode=mode, ) return result diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 6abc5abfc..543f1bed0 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -206,12 +206,12 @@ class LanceDBAdapter(VectorDBInterface): collection = await self.get_collection(collection_name) if len(data_point_ids) == 1: - results = await collection.query().where(f"id = '{data_point_ids[0]}'") + query = collection.query().where(f"id = '{data_point_ids[0]}'") else: - results = await collection.query().where(f"id IN {tuple(data_point_ids)}") + query = collection.query().where(f"id IN {tuple(data_point_ids)}") 
# Convert query results to list format
-        results_list = results.to_list() if hasattr(results, "to_list") else list(results)
+        results_list = await query.to_list()
 
         return [
             ScoredResult(
diff --git a/cognee/modules/engine/utils/generate_node_id.py b/cognee/modules/engine/utils/generate_node_id.py
index 489a88875..4beee0416 100644
--- a/cognee/modules/engine/utils/generate_node_id.py
+++ b/cognee/modules/engine/utils/generate_node_id.py
@@ -1,5 +1,5 @@
-from uuid import NAMESPACE_OID, uuid5
+from uuid import NAMESPACE_OID, UUID, uuid5
 
 
-def generate_node_id(node_id: str) -> str:
+def generate_node_id(node_id: str) -> UUID:
     return uuid5(NAMESPACE_OID, node_id.lower().replace(" ", "_").replace("'", ""))
diff --git a/cognee/modules/graph/legacy/GraphRelationshipLedger.py b/cognee/modules/graph/legacy/GraphRelationshipLedger.py
new file mode 100644
index 000000000..0b03483e6
--- /dev/null
+++ b/cognee/modules/graph/legacy/GraphRelationshipLedger.py
@@ -0,0 +1,40 @@
+from uuid import uuid5, NAMESPACE_OID
+from datetime import datetime, timezone
+from sqlalchemy import UUID, Column, DateTime, String, Index
+
+from cognee.infrastructure.databases.relational import Base
+
+
+class GraphRelationshipLedger(Base):
+    __tablename__ = "graph_relationship_ledger"
+
+    id = Column(
+        UUID,
+        primary_key=True,
+        default=lambda: uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
+    )
+    source_node_id = Column(UUID, nullable=False)
+    destination_node_id = Column(UUID, nullable=False)
+    creator_function = Column(String, nullable=False)
+    node_label = Column(String, nullable=True)
+    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
+    deleted_at = Column(DateTime(timezone=True), nullable=True)
+    user_id = Column(UUID, nullable=True)
+
+    # Create indexes
+    __table_args__ = (
+        Index("idx_graph_relationship_id", "id"),
+        Index("idx_graph_relationship_ledger_source_node_id", "source_node_id"),
+        Index("idx_graph_relationship_ledger_destination_node_id", "destination_node_id"),
+    )
+
+    def to_json(self) -> dict:
+        return {
+            "id": str(self.id),
+            "source_node_id": str(self.source_node_id),
+            "destination_node_id": str(self.destination_node_id),
+            "creator_function": self.creator_function,
+            "created_at": self.created_at.isoformat(),
+            "deleted_at": self.deleted_at.isoformat() if self.deleted_at else None,
+            "user_id": str(self.user_id),
+        }
diff --git a/cognee/modules/graph/legacy/has_edges_in_legacy_ledger.py b/cognee/modules/graph/legacy/has_edges_in_legacy_ledger.py
new file mode 100644
index 000000000..7857a8e25
--- /dev/null
+++ b/cognee/modules/graph/legacy/has_edges_in_legacy_ledger.py
@@ -0,0 +1,49 @@
+from uuid import UUID
+from typing import List
+from sqlalchemy import and_, or_, select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from cognee.infrastructure.databases.relational import with_async_session
+from cognee.modules.graph.models import Edge, Node
+from .GraphRelationshipLedger import GraphRelationshipLedger
+
+
+@with_async_session
+async def has_edges_in_legacy_ledger(edges: List[Edge], user_id: UUID, session: AsyncSession):
+    if len(edges) == 0:
+        return []
+
+    query = select(GraphRelationshipLedger).where(
+        and_(
+            GraphRelationshipLedger.user_id == user_id,
+            or_(
+                *[
+                    GraphRelationshipLedger.creator_function.ilike(f"%{edge.relationship_name}")
+                    for edge in edges
+                ]
+            ),
+        )
+    )
+
+    legacy_edges = (await session.scalars(query)).all()
+
+    legacy_edge_names = set([edge.creator_function.split(".", 1)[-1] for edge in legacy_edges])
+
+    return 
[edge.relationship_name in legacy_edge_names for edge in edges] + + +@with_async_session +async def get_node_ids(edges: List[Edge], session: AsyncSession): + node_slugs = [] + + for edge in edges: + node_slugs.append(edge.source_node_id) + node_slugs.append(edge.destination_node_id) + + query = select(Node).where(Node.slug.in_(node_slugs)) + + nodes = (await session.scalars(query)).all() + + node_ids = {node.slug: node.id for node in nodes} + + return node_ids diff --git a/cognee/modules/graph/legacy/has_nodes_in_legacy_ledger.py b/cognee/modules/graph/legacy/has_nodes_in_legacy_ledger.py new file mode 100644 index 000000000..a84b57d75 --- /dev/null +++ b/cognee/modules/graph/legacy/has_nodes_in_legacy_ledger.py @@ -0,0 +1,36 @@ +from typing import List +from uuid import UUID +from sqlalchemy import and_, or_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from cognee.infrastructure.databases.relational import with_async_session +from cognee.modules.graph.models import Node +from .GraphRelationshipLedger import GraphRelationshipLedger + + +@with_async_session +async def has_nodes_in_legacy_ledger(nodes: List[Node], user_id: UUID, session: AsyncSession): + node_ids = [node.slug for node in nodes] + + query = select( + GraphRelationshipLedger.source_node_id, + GraphRelationshipLedger.destination_node_id, + ).where( + and_( + GraphRelationshipLedger.user_id == user_id, + or_( + GraphRelationshipLedger.source_node_id.in_(node_ids), + GraphRelationshipLedger.destination_node_id.in_(node_ids), + ), + ) + ) + + legacy_nodes = await session.execute(query) + entries = legacy_nodes.all() + + found_ids = set() + for entry in entries: + found_ids.add(entry.source_node_id) + found_ids.add(entry.destination_node_id) + + return [node_id in found_ids for node_id in node_ids] diff --git a/cognee/modules/graph/legacy/record_data_in_legacy_ledger.py b/cognee/modules/graph/legacy/record_data_in_legacy_ledger.py new file mode 100644 index 000000000..d3efbb99b --- /dev/null +++ b/cognee/modules/graph/legacy/record_data_in_legacy_ledger.py @@ -0,0 +1,38 @@ +from uuid import UUID +from typing import Dict, List, Tuple +from sqlalchemy.ext.asyncio import AsyncSession + +from cognee.infrastructure.databases.relational import with_async_session +from cognee.infrastructure.engine.models.DataPoint import DataPoint +from cognee.modules.users.models.User import User +from .GraphRelationshipLedger import GraphRelationshipLedger + + +@with_async_session +async def record_data_in_legacy_ledger( + nodes: List[DataPoint], + edges: List[Tuple[UUID, UUID, str, Dict]], + user: User, + session: AsyncSession, +) -> None: + relationships = [ + GraphRelationshipLedger( + source_node_id=node.id, + destination_node_id=node.id, + creator_function="add_nodes", + user_id=user.id, + ) + for node in nodes + ] + [ + GraphRelationshipLedger( + source_node_id=edge[0], + destination_node_id=edge[1], + creator_function=f"add_edges.{edge[2]}", + user_id=user.id, + ) + for edge in edges + ] + + session.add_all(relationships) + + await session.commit() diff --git a/cognee/modules/graph/methods/__init__.py b/cognee/modules/graph/methods/__init__.py index 33d50deb7..5cde7f79c 100644 --- a/cognee/modules/graph/methods/__init__.py +++ b/cognee/modules/graph/methods/__init__.py @@ -3,6 +3,8 @@ from .get_formatted_graph_data import get_formatted_graph_data from .upsert_edges import upsert_edges from .upsert_nodes import upsert_nodes +from .has_data_related_nodes import has_data_related_nodes + from .get_data_related_nodes import 
get_data_related_nodes
 from .get_data_related_edges import get_data_related_edges
 from .delete_data_related_nodes import delete_data_related_nodes
@@ -14,3 +16,5 @@ from .get_dataset_related_edges import get_dataset_related_edges
 from .delete_dataset_related_nodes import delete_dataset_related_nodes
 from .delete_dataset_related_edges import delete_dataset_related_edges
 from .delete_dataset_nodes_and_edges import delete_dataset_nodes_and_edges
+
+from .legacy_delete import legacy_delete
diff --git a/cognee/modules/graph/methods/delete_data_nodes_and_edges.py b/cognee/modules/graph/methods/delete_data_nodes_and_edges.py
index 799bd6b3f..d96d8f265 100644
--- a/cognee/modules/graph/methods/delete_data_nodes_and_edges.py
+++ b/cognee/modules/graph/methods/delete_data_nodes_and_edges.py
@@ -3,6 +3,8 @@ from typing import Dict, List
 
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
+from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
+from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger
 from cognee.modules.graph.methods import (
     delete_data_related_edges,
     delete_data_related_nodes,
@@ -11,17 +13,26 @@ from cognee.modules.graph.methods import (
 )
 
 
-async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID) -> None:
+async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id: UUID) -> None:
     affected_nodes = await get_data_related_nodes(dataset_id, data_id)
 
     if len(affected_nodes) == 0:
         return
 
+    is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
+
+    affected_relationships = await get_data_related_edges(dataset_id, data_id)
+    is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
+
+    non_legacy_nodes = [
+        node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
+    ]
+
     graph_engine = await get_graph_engine()
 
-    await graph_engine.delete_nodes([str(node.slug) for node in affected_nodes])
+    await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes])
 
     affected_vector_collections: Dict[str, List] = {}
 
-    for node in affected_nodes:
+    for node in non_legacy_nodes:
         for indexed_field in node.indexed_fields:
             collection_name = f"{node.type}_{indexed_field}"
             if collection_name not in affected_vector_collections:
@@ -29,17 +40,22 @@ async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID) -> None:
             affected_vector_collections[collection_name].append(node)
 
     vector_engine = get_vector_engine()
-    for affected_collection, affected_nodes in affected_vector_collections.items():
+    for affected_collection, collection_nodes in affected_vector_collections.items():
         await vector_engine.delete_data_points(
-            affected_collection, [node.slug for node in affected_nodes]
+            affected_collection, [str(node.slug) for node in collection_nodes]
         )
 
-    affected_relationships = await get_data_related_edges(dataset_id, data_id)
+    if len(affected_relationships) > 0:
+        non_legacy_relationships = [
+            edge
+            for index, edge in enumerate(affected_relationships)
+            if not is_legacy_relationship[index]
+        ]
 
-    await vector_engine.delete_data_points(
-        "EdgeType_relationship_name",
-        [edge.slug for edge in affected_relationships],
-    )
+        await vector_engine.delete_data_points(
+            "EdgeType_relationship_name",
+            [str(relationship.slug) for relationship in non_legacy_relationships],
+        )
 
     await delete_data_related_nodes(data_id)
     await 
delete_data_related_edges(data_id)
diff --git a/cognee/modules/graph/methods/has_data_related_nodes.py b/cognee/modules/graph/methods/has_data_related_nodes.py
new file mode 100644
index 000000000..f3c1e29b4
--- /dev/null
+++ b/cognee/modules/graph/methods/has_data_related_nodes.py
@@ -0,0 +1,16 @@
+from uuid import UUID
+from sqlalchemy import and_, select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from cognee.infrastructure.databases.relational import with_async_session
+from cognee.modules.graph.models import Node
+
+
+@with_async_session
+async def has_data_related_nodes(dataset_id: UUID, data_id: UUID, session: AsyncSession):
+    query_statement = (
+        select(Node).where(and_(Node.data_id == data_id, Node.dataset_id == dataset_id)).limit(1)
+    )
+
+    data_related_node = await session.scalar(query_statement)
+    return data_related_node is not None
diff --git a/cognee/modules/graph/methods/legacy_delete.py b/cognee/modules/graph/methods/legacy_delete.py
new file mode 100644
index 000000000..7c499662e
--- /dev/null
+++ b/cognee/modules/graph/methods/legacy_delete.py
@@ -0,0 +1,94 @@
+from uuid import UUID
+
+from cognee.api.v1.exceptions.exceptions import DocumentSubgraphNotFoundError
+from cognee.infrastructure.databases.graph import get_graph_engine
+from cognee.infrastructure.engine import DataPoint
+from cognee.modules.data.models import Data
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
+
+
+logger = get_logger()
+
+
+async def legacy_delete(data: Data, mode: str = "soft"):
+    """Delete a single document's graph subgraph and vector entries (legacy data path)."""
+
+    # Delete from graph database
+    deleted_node_ids = await delete_document_subgraph(data.id, mode)
+
+    # Delete from vector database
+    vector_engine = get_vector_engine()
+
+    # Determine vector collections dynamically
+    subclasses = get_all_subclasses(DataPoint)
+    vector_collections = []
+
+    for subclass in subclasses:
+        index_fields = subclass.model_fields["metadata"].default.get("index_fields", [])
+        for field_name in index_fields:
+            vector_collections.append(f"{subclass.__name__}_{field_name}")
+
+    # If no collections found, use default collections
+    if not vector_collections:
+        vector_collections = [
+            "DocumentChunk_text",
+            "EdgeType_relationship_name",
+            "EntityType_name",
+            "Entity_name",
+            "TextDocument_name",
+            "TextSummary_text",
+        ]
+
+    # Delete records from each vector collection that exists
+    for collection in vector_collections:
+        if await vector_engine.has_collection(collection):
+            await vector_engine.delete_data_points(
+                collection, [str(node_id) for node_id in deleted_node_ids]
+            )
+
+
+async def delete_document_subgraph(document_id: UUID, mode: str = "soft"):
+    """Delete a document and all its related nodes in the correct order."""
+    graph_db = await get_graph_engine()
+    subgraph = await graph_db.get_document_subgraph(document_id)
+    if not subgraph:
+        raise DocumentSubgraphNotFoundError(f"Document not found with id: {document_id}")
+
+    # Delete in the correct order to maintain graph integrity
+    deletion_order = [
+        ("orphan_entities", "orphaned entities"),
+        ("orphan_types", "orphaned entity types"),
+        (
+            "made_from_nodes",
+            "made_from nodes",
+        ),  # Move before chunks since summaries are connected to chunks
+        ("chunks", "document chunks"),
+        ("document", "document"),
+    ]
+
+    deleted_node_ids = []
+    for key, description in deletion_order:
+        nodes = subgraph[key]
+        if nodes:
+            for 
node in nodes: + node_id = node["id"] + await graph_db.delete_node(node_id) + deleted_node_ids.append(node_id) + + # If hard mode, also delete degree-one nodes + if mode == "hard": + # Get and delete degree one entity nodes + degree_one_entity_nodes = await graph_db.get_degree_one_nodes("Entity") + for node in degree_one_entity_nodes: + await graph_db.delete_node(node["id"]) + deleted_node_ids.append(node["id"]) + + # Get and delete degree one entity types + degree_one_entity_types = await graph_db.get_degree_one_nodes("EntityType") + for node in degree_one_entity_types: + await graph_db.delete_node(node["id"]) + deleted_node_ids.append(node["id"]) + + return deleted_node_ids diff --git a/cognee/tests/test_delete_default_graph.py b/cognee/tests/test_delete_default_graph.py index 061292b48..ea75f2496 100644 --- a/cognee/tests/test_delete_default_graph.py +++ b/cognee/tests/test_delete_default_graph.py @@ -8,7 +8,6 @@ from cognee.api.v1.datasets import datasets from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.llm import LLMGateway -from cognee.modules.data.methods import get_dataset_data from cognee.modules.engine.operations.setup import setup from cognee.modules.users.methods import get_default_user from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent @@ -100,24 +99,19 @@ async def main(mock_create_structured_output: AsyncMock): ), ] - await cognee.add( + add_john_result = await cognee.add( "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" ) + johns_data_id = add_john_result.data_ingestion_info[0]["data_id"] - await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.") + add_marie_result = await cognee.add( + "Marie works for Apple as well. She is a software engineer on MacOS project." + ) + maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"] cognify_result: dict = await cognee.cognify() dataset_id = list(cognify_result.keys())[0] - dataset_data = await get_dataset_data(dataset_id) - added_data_1 = dataset_data[0] - added_data_2 = dataset_data[1] - - # file_path = os.path.join( - # pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_full.html" - # ) - # await visualize_graph(file_path) - graph_engine = await get_graph_engine() initial_nodes, initial_edges = await graph_engine.get_graph_data() assert len(initial_nodes) == 15 and len(initial_edges) == 19, ( @@ -136,16 +130,13 @@ async def main(mock_create_structured_output: AsyncMock): initial_node_ids = set([node[0] for node in initial_nodes]) user = await get_default_user() - await datasets.delete_data(dataset_id, added_data_1.id, user) # type: ignore - - # file_path = os.path.join( - # pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_after_delete.html" - # ) - # await visualize_graph(file_path) + await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore nodes, edges = await graph_engine.get_graph_data() assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted." - assert not any(node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes), "Nodes are not deleted." + assert not any( + node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes + ), "Nodes are not deleted." 
after_first_delete_node_ids = set([node[0] for node in nodes]) @@ -166,7 +157,7 @@ async def main(mock_create_structured_output: AsyncMock): vector_items = await vector_engine.retrieve(collection_name, query_node_ids) assert len(vector_items) == 0, "Vector items are not deleted." - await datasets.delete_data(dataset_id, added_data_2.id, user) # type: ignore + await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore final_nodes, final_edges = await graph_engine.get_graph_data() assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted." diff --git a/cognee/tests/test_delete_default_graph_non_mocked.py b/cognee/tests/test_delete_default_graph_non_mocked.py index b267f7c7f..b2d167a70 100644 --- a/cognee/tests/test_delete_default_graph_non_mocked.py +++ b/cognee/tests/test_delete_default_graph_non_mocked.py @@ -1,18 +1,12 @@ import os import pathlib -import time -import pytest -from unittest.mock import AsyncMock, patch import cognee from cognee.api.v1.datasets import datasets from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.llm import LLMGateway -from cognee.modules.data.methods import get_dataset_data from cognee.modules.engine.operations.setup import setup from cognee.modules.users.methods import get_default_user -from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent from cognee.shared.logging_utils import get_logger logger = get_logger() @@ -41,22 +35,22 @@ async def main(): assert not await vector_engine.has_collection("TextSummary_text") assert not await vector_engine.has_collection("TextDocument_text") - await cognee.add( + add_result = await cognee.add( "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" ) + johns_data_id = add_result.data_ingestion_info[0]["data_id"] - await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.") + add_result = await cognee.add( + "Marie works for Apple as well. She is a software engineer on MacOS project." + ) + maries_data_id = add_result.data_ingestion_info[0]["data_id"] cognify_result: dict = await cognee.cognify() dataset_id = list(cognify_result.keys())[0] - dataset_data = await get_dataset_data(dataset_id) - added_data_1 = dataset_data[0] - added_data_2 = dataset_data[1] - graph_engine = await get_graph_engine() initial_nodes, initial_edges = await graph_engine.get_graph_data() - assert len(initial_nodes) >= 15 and len(initial_edges) >= 19, ( + assert len(initial_nodes) >= 14 and len(initial_edges) >= 18, ( "Number of nodes and edges is not correct." ) @@ -72,11 +66,15 @@ async def main(): initial_node_ids = set([node[0] for node in initial_nodes]) user = await get_default_user() - await datasets.delete_data(dataset_id, added_data_1.id, user) # type: ignore + await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore nodes, edges = await graph_engine.get_graph_data() - assert len(nodes) >= 9 and len(nodes) <= 11 and len(edges) >= 10 and len(edges) <= 12, "Nodes and edges are not deleted." - assert not any(node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes), "Nodes are not deleted." + assert len(nodes) >= 9 and len(nodes) <= 11 and len(edges) >= 10 and len(edges) <= 12, ( + "Nodes and edges are not deleted." 
+ ) + assert not any( + node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes + ), "Nodes are not deleted." after_first_delete_node_ids = set([node[0] for node in nodes]) @@ -97,7 +95,7 @@ async def main(): vector_items = await vector_engine.retrieve(collection_name, query_node_ids) assert len(vector_items) == 0, "Vector items are not deleted." - await datasets.delete_data(dataset_id, added_data_2.id, user) # type: ignore + await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore final_nodes, final_edges = await graph_engine.get_graph_data() assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted." diff --git a/cognee/tests/test_delete_default_graph_with_legacy_data_1.py b/cognee/tests/test_delete_default_graph_with_legacy_data_1.py new file mode 100644 index 000000000..dc331199b --- /dev/null +++ b/cognee/tests/test_delete_default_graph_with_legacy_data_1.py @@ -0,0 +1,372 @@ +import os +import pathlib +from uuid import NAMESPACE_OID, uuid5 +import pytest +from unittest.mock import AsyncMock, patch + +import cognee +from cognee.api.v1.datasets import datasets +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.llm import LLMGateway +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.methods import create_authorized_dataset +from cognee.modules.data.models import Data +from cognee.modules.engine.models import Entity, EntityType +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.engine.operations.setup import setup +from cognee.modules.engine.utils import generate_edge_id, generate_node_id +from cognee.modules.pipelines.models import DataItemStatus +from cognee.modules.users.methods import get_default_user +from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent +from cognee.tasks.storage import index_data_points, index_graph_edges + +from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger + + +def get_nodes_and_edges(): + document = TextDocument( + id=uuid5(NAMESPACE_OID, "text_test.txt"), + name="text_test.txt", + raw_data_location="git/cognee/examples/database_examples/data_storage/data/text_test.txt", + external_metadata="{}", + mime_type="text/plain", + ) + document_chunk = DocumentChunk( + id=uuid5( + NAMESPACE_OID, + "Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ", + ), + text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. 
You can also load graph data that's \n stored in Amazon S3.\n ", + chunk_size=187, + chunk_index=0, + cut_type="paragraph_end", + is_part_of=document, + ) + + graph_database = EntityType( + id=uuid5(NAMESPACE_OID, "graph_database"), + name="graph database", + description="graph database", + ) + neptune_analytics_entity = Entity( + id=generate_node_id("neptune analytics"), + name="neptune analytics", + description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.", + ) + neptune_database_entity = Entity( + id=generate_node_id("amazon neptune database"), + name="amazon neptune database", + description="A popular managed graph database that complements Neptune Analytics.", + ) + + storage = EntityType( + id=generate_node_id("storage"), + name="storage", + description="storage", + ) + storage_entity = Entity( + id=generate_node_id("amazon s3"), + name="amazon s3", + description="A storage service provided by Amazon Web Services that allows storing graph data.", + ) + + nodes_data = [ + document, + document_chunk, + graph_database, + neptune_analytics_entity, + neptune_database_entity, + storage, + storage_entity, + ] + + edges_data = [ + ( + document_chunk.id, + storage_entity.id, + "contains", + { + "relationship_name": "contains", + }, + ), + ( + storage_entity.id, + storage.id, + "is_a", + { + "relationship_name": "is_a", + }, + ), + ( + document_chunk.id, + neptune_database_entity.id, + "contains", + { + "relationship_name": "contains", + }, + ), + ( + neptune_database_entity.id, + graph_database.id, + "is_a", + { + "relationship_name": "is_a", + }, + ), + ( + document_chunk.id, + document.id, + "is_part_of", + { + "relationship_name": "is_part_of", + }, + ), + ( + document_chunk.id, + neptune_analytics_entity.id, + "contains", + { + "relationship_name": "contains", + }, + ), + ( + neptune_analytics_entity.id, + graph_database.id, + "is_a", + { + "relationship_name": "is_a", + }, + ), + ] + + return nodes_data, edges_data + + +@pytest.mark.asyncio +@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock) +async def main(mock_create_structured_output: AsyncMock): + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_1" + ) + cognee.config.data_root_directory(data_directory_path) + + cognee_directory_path = os.path.join( + pathlib.Path(__file__).parent, + ".cognee_system/test_delete_default_graph_with_legacy_graph_1", + ) + cognee.config.system_root_directory(cognee_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + vector_engine = get_vector_engine() + + assert not await vector_engine.has_collection("EdgeType_relationship_name") + assert not await vector_engine.has_collection("Entity_name") + assert not await vector_engine.has_collection("DocumentChunk_text") + assert not await vector_engine.has_collection("TextSummary_text") + assert not await vector_engine.has_collection("TextDocument_text") + + user = await get_default_user() + + graph_engine = await get_graph_engine() + old_nodes, old_edges = get_nodes_and_edges() + old_document = old_nodes[0] + + await graph_engine.add_nodes(old_nodes) + await graph_engine.add_edges(old_edges) + + await index_data_points(old_nodes) + await index_graph_edges(old_edges) + + await record_data_in_legacy_ledger(old_nodes, old_edges, user) + + db_engine = get_relational_engine() + + dataset = await 
create_authorized_dataset("main_dataset", user) + + async with db_engine.get_async_session() as session: + old_data = Data( + id=old_document.id, + name=old_document.name, + extension="txt", + raw_data_location=old_document.raw_data_location, + external_metadata=old_document.external_metadata, + mime_type=old_document.mime_type, + owner_id=user.id, + pipeline_status={ + "cognify_pipeline": { + str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED, + } + }, + ) + session.add(old_data) + + dataset.data.append(old_data) + session.add(dataset) + + await session.commit() + + def mock_llm_output(text_input: str, system_prompt: str, response_model): + if text_input == "test": # LLM connection test + return "test" + + if "John" in text_input and response_model == SummarizedContent: + return SummarizedContent( + summary="Summary of John's work.", description="Summary of John's work." + ) + + if "Marie" in text_input and response_model == SummarizedContent: + return SummarizedContent( + summary="Summary of Marie's work.", description="Summary of Marie's work." + ) + + if "Marie" in text_input and response_model == KnowledgeGraph: + return KnowledgeGraph( + nodes=[ + Node(id="Marie", name="Marie", type="Person", description="Marie is a person"), + Node( + id="Apple", + name="Apple", + type="Company", + description="Apple is a company", + ), + Node( + id="MacOS", + name="MacOS", + type="Product", + description="MacOS is Apple's operating system", + ), + ], + edges=[ + Edge( + source_node_id="Marie", + target_node_id="Apple", + relationship_name="works_for", + ), + Edge( + source_node_id="Marie", target_node_id="MacOS", relationship_name="works_on" + ), + ], + ) + + if "John" in text_input and response_model == KnowledgeGraph: + return KnowledgeGraph( + nodes=[ + Node(id="John", name="John", type="Person", description="John is a person"), + Node( + id="Apple", + name="Apple", + type="Company", + description="Apple is a company", + ), + Node( + id="Food for Hungry", + name="Food for Hungry", + type="Non-profit organization", + description="Food for Hungry is a non-profit organization", + ), + ], + edges=[ + Edge( + source_node_id="John", target_node_id="Apple", relationship_name="works_for" + ), + Edge( + source_node_id="John", + target_node_id="Food for Hungry", + relationship_name="works_for", + ), + ], + ) + + mock_create_structured_output.side_effect = mock_llm_output + + add_john_result = await cognee.add( + "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" + ) + johns_data_id = add_john_result.data_ingestion_info[0]["data_id"] + + add_marie_result = await cognee.add( + "Marie works for Apple as well. She is a software engineer on MacOS project." + ) + maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"] + + cognify_result: dict = await cognee.cognify() + dataset_id = list(cognify_result.keys())[0] + + graph_engine = await get_graph_engine() + initial_nodes, initial_edges = await graph_engine.get_graph_data() + assert len(initial_nodes) == 22 and len(initial_edges) == 26, ( + "Number of nodes and edges is not correct." 
+    )
+
+    initial_nodes_by_vector_collection = {}
+
+    for node in initial_nodes:
+        node_data = node[1]
+        collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
+        if collection_name not in initial_nodes_by_vector_collection:
+            initial_nodes_by_vector_collection[collection_name] = []
+        initial_nodes_by_vector_collection[collection_name].append(node)
+
+    initial_node_ids = set([node[0] for node in initial_nodes])
+
+    await datasets.delete_data(dataset_id, johns_data_id, user)  # type: ignore
+
+    nodes, edges = await graph_engine.get_graph_data()
+    assert len(nodes) == 16 and len(edges) == 17, "Nodes and edges are not deleted."
+    assert not any(
+        node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes
+    ), "Nodes are not deleted."
+
+    after_first_delete_node_ids = set([node[0] for node in nodes])
+
+    after_delete_nodes_by_vector_collection = {}
+    for node in nodes:
+        node_data = node[1]
+        collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
+        if collection_name not in after_delete_nodes_by_vector_collection:
+            after_delete_nodes_by_vector_collection[collection_name] = []
+        after_delete_nodes_by_vector_collection[collection_name].append(node)
+
+    removed_node_ids = initial_node_ids - after_first_delete_node_ids
+
+    for collection_name, collection_nodes in initial_nodes_by_vector_collection.items():
+        query_node_ids = [node[0] for node in collection_nodes if node[0] in removed_node_ids]
+
+        if query_node_ids:
+            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
+            assert len(vector_items) == 0, "Vector items are not deleted."
+
+    await datasets.delete_data(dataset_id, maries_data_id, user)  # type: ignore
+
+    final_nodes, final_edges = await graph_engine.get_graph_data()
+    assert len(final_nodes) == 7 and len(final_edges) == 7, "Nodes and edges are not deleted."
+
+    old_nodes_by_vector_collection = {}
+    for node in old_nodes:
+        collection_name = node.type + "_" + node.metadata["index_fields"][0]
+        if collection_name not in old_nodes_by_vector_collection:
+            old_nodes_by_vector_collection[collection_name] = []
+        old_nodes_by_vector_collection[collection_name].append(node)
+
+    for collection_name, collection_nodes in old_nodes_by_vector_collection.items():
+        query_node_ids = [str(node.id) for node in collection_nodes]
+
+        if query_node_ids:
+            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
+            assert len(vector_items) == len(collection_nodes), "Legacy vector items should be preserved."
+
+    query_edge_ids = list(set([str(generate_edge_id(edge[2])) for edge in old_edges]))
+
+    vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
+    assert len(vector_items) == len(query_edge_ids), "Legacy edge vector items should be preserved."
+ + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/cognee/tests/test_delete_default_graph_with_legacy_data_2.py b/cognee/tests/test_delete_default_graph_with_legacy_data_2.py new file mode 100644 index 000000000..f9ec9434f --- /dev/null +++ b/cognee/tests/test_delete_default_graph_with_legacy_data_2.py @@ -0,0 +1,372 @@ +import os +import pathlib +from uuid import NAMESPACE_OID, uuid5 +import pytest +from unittest.mock import AsyncMock, patch + +import cognee +from cognee.api.v1.datasets import datasets +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.llm import LLMGateway +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.methods import create_authorized_dataset +from cognee.modules.data.models import Data +from cognee.modules.engine.models import Entity, EntityType +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.engine.operations.setup import setup +from cognee.modules.engine.utils import generate_edge_id, generate_node_id +from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger +from cognee.modules.pipelines.models import DataItemStatus +from cognee.modules.users.methods import get_default_user +from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent +from cognee.tasks.storage import index_data_points, index_graph_edges + + +def get_nodes_and_edges(): + document = TextDocument( + id=uuid5(NAMESPACE_OID, "text_test.txt"), + name="text_test.txt", + raw_data_location="git/cognee/examples/database_examples/data_storage/data/text_test.txt", + external_metadata="{}", + mime_type="text/plain", + ) + document_chunk = DocumentChunk( + id=uuid5( + NAMESPACE_OID, + "Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ", + ), + text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. 
You can also load graph data that's \n stored in Amazon S3.\n ", + chunk_size=187, + chunk_index=0, + cut_type="paragraph_end", + is_part_of=document, + ) + + graph_database = EntityType( + id=uuid5(NAMESPACE_OID, "graph_database"), + name="graph database", + description="graph database", + ) + neptune_analytics_entity = Entity( + id=generate_node_id("neptune analytics"), + name="neptune analytics", + description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.", + ) + neptune_database_entity = Entity( + id=generate_node_id("amazon neptune database"), + name="amazon neptune database", + description="A popular managed graph database that complements Neptune Analytics.", + ) + + storage = EntityType( + id=generate_node_id("storage"), + name="storage", + description="storage", + ) + storage_entity = Entity( + id=generate_node_id("amazon s3"), + name="amazon s3", + description="A storage service provided by Amazon Web Services that allows storing graph data.", + ) + + nodes_data = [ + document, + document_chunk, + graph_database, + neptune_analytics_entity, + neptune_database_entity, + storage, + storage_entity, + ] + + edges_data = [ + ( + document_chunk.id, + storage_entity.id, + "contains", + { + "relationship_name": "contains", + }, + ), + ( + storage_entity.id, + storage.id, + "is_a", + { + "relationship_name": "is_a", + }, + ), + ( + document_chunk.id, + neptune_database_entity.id, + "contains", + { + "relationship_name": "contains", + }, + ), + ( + neptune_database_entity.id, + graph_database.id, + "is_a", + { + "relationship_name": "is_a", + }, + ), + ( + document_chunk.id, + document.id, + "is_part_of", + { + "relationship_name": "is_part_of", + }, + ), + ( + document_chunk.id, + neptune_analytics_entity.id, + "contains", + { + "relationship_name": "contains", + }, + ), + ( + neptune_analytics_entity.id, + graph_database.id, + "is_a", + { + "relationship_name": "is_a", + }, + ), + ] + + return nodes_data, edges_data + + +@pytest.mark.asyncio +@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock) +async def main(mock_create_structured_output: AsyncMock): + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_2" + ) + cognee.config.data_root_directory(data_directory_path) + + cognee_directory_path = os.path.join( + pathlib.Path(__file__).parent, + ".cognee_system/test_delete_default_graph_with_legacy_graph_2", + ) + cognee.config.system_root_directory(cognee_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + vector_engine = get_vector_engine() + + assert not await vector_engine.has_collection("EdgeType_relationship_name") + assert not await vector_engine.has_collection("Entity_name") + assert not await vector_engine.has_collection("DocumentChunk_text") + assert not await vector_engine.has_collection("TextSummary_text") + assert not await vector_engine.has_collection("TextDocument_text") + + user = await get_default_user() + + graph_engine = await get_graph_engine() + old_nodes, old_edges = get_nodes_and_edges() + old_document = old_nodes[0] + + await graph_engine.add_nodes(old_nodes) + await graph_engine.add_edges(old_edges) + + await index_data_points(old_nodes) + await index_graph_edges(old_edges) + + await record_data_in_legacy_ledger(old_nodes, old_edges, user) + + db_engine = get_relational_engine() + + dataset = await 
create_authorized_dataset("main_dataset", user) + + async with db_engine.get_async_session() as session: + old_data = Data( + id=old_document.id, + name=old_document.name, + extension="txt", + raw_data_location=old_document.raw_data_location, + external_metadata=old_document.external_metadata, + mime_type=old_document.mime_type, + owner_id=user.id, + pipeline_status={ + "cognify_pipeline": { + str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED, + } + }, + ) + session.add(old_data) + + dataset.data.append(old_data) + session.add(dataset) + + await session.commit() + + def mock_llm_output(text_input: str, system_prompt: str, response_model): + if text_input == "test": # LLM connection test + return "test" + + if "John" in text_input and response_model == SummarizedContent: + return SummarizedContent( + summary="Summary of John's work.", description="Summary of John's work." + ) + + if "Marie" in text_input and response_model == SummarizedContent: + return SummarizedContent( + summary="Summary of Marie's work.", description="Summary of Marie's work." + ) + + if "Marie" in text_input and response_model == KnowledgeGraph: + return KnowledgeGraph( + nodes=[ + Node(id="Marie", name="Marie", type="Person", description="Marie is a person"), + Node( + id="Apple", + name="Apple", + type="Company", + description="Apple is a company", + ), + Node( + id="MacOS", + name="MacOS", + type="Product", + description="MacOS is Apple's operating system", + ), + ], + edges=[ + Edge( + source_node_id="Marie", + target_node_id="Apple", + relationship_name="works_for", + ), + Edge( + source_node_id="Marie", target_node_id="MacOS", relationship_name="works_on" + ), + ], + ) + + if "John" in text_input and response_model == KnowledgeGraph: + return KnowledgeGraph( + nodes=[ + Node(id="John", name="John", type="Person", description="John is a person"), + Node( + id="Apple", + name="Apple", + type="Company", + description="Apple is a company", + ), + Node( + id="Food for Hungry", + name="Food for Hungry", + type="Non-profit organization", + description="Food for Hungry is a non-profit organization", + ), + ], + edges=[ + Edge( + source_node_id="John", target_node_id="Apple", relationship_name="works_for" + ), + Edge( + source_node_id="John", + target_node_id="Food for Hungry", + relationship_name="works_for", + ), + ], + ) + + mock_create_structured_output.side_effect = mock_llm_output + + add_john_result = await cognee.add( + "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'" + ) + johns_data_id = add_john_result.data_ingestion_info[0]["data_id"] + + add_marie_result = await cognee.add( + "Marie works for Apple as well. She is a software engineer on MacOS project." + ) + # maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"] + + cognify_result: dict = await cognee.cognify() + dataset_id = list(cognify_result.keys())[0] + + graph_engine = await get_graph_engine() + initial_nodes, initial_edges = await graph_engine.get_graph_data() + assert len(initial_nodes) == 22 and len(initial_edges) == 26, ( + "Number of nodes and edges is not correct." 
+    )
+
+    initial_nodes_by_vector_collection = {}
+
+    for node in initial_nodes:
+        node_data = node[1]
+        collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
+        if collection_name not in initial_nodes_by_vector_collection:
+            initial_nodes_by_vector_collection[collection_name] = []
+        initial_nodes_by_vector_collection[collection_name].append(node)
+
+    initial_node_ids = set([node[0] for node in initial_nodes])
+
+    await datasets.delete_data(dataset_id, johns_data_id, user)  # type: ignore
+
+    nodes, edges = await graph_engine.get_graph_data()
+    assert len(nodes) == 16 and len(edges) == 17, "Nodes and edges are not deleted."
+    assert not any(
+        node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes
+    ), "Nodes are not deleted."
+
+    after_first_delete_node_ids = set([node[0] for node in nodes])
+
+    after_delete_nodes_by_vector_collection = {}
+    for node in nodes:
+        node_data = node[1]
+        collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
+        if collection_name not in after_delete_nodes_by_vector_collection:
+            after_delete_nodes_by_vector_collection[collection_name] = []
+        after_delete_nodes_by_vector_collection[collection_name].append(node)
+
+    removed_node_ids = initial_node_ids - after_first_delete_node_ids
+
+    for collection_name, collection_nodes in initial_nodes_by_vector_collection.items():
+        query_node_ids = [node[0] for node in collection_nodes if node[0] in removed_node_ids]
+
+        if query_node_ids:
+            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
+            assert len(vector_items) == 0, "Vector items are not deleted."
+
+    # Delete old document
+    await datasets.delete_data(dataset_id, old_document.id, user)  # type: ignore
+
+    final_nodes, final_edges = await graph_engine.get_graph_data()
+    assert len(final_nodes) == 9 and len(final_edges) == 10, "Nodes and edges are not deleted."
+
+    old_nodes_by_vector_collection = {}
+    for node in old_nodes:
+        collection_name = node.type + "_" + node.metadata["index_fields"][0]
+        if collection_name not in old_nodes_by_vector_collection:
+            old_nodes_by_vector_collection[collection_name] = []
+        old_nodes_by_vector_collection[collection_name].append(node)
+
+    for collection_name, collection_nodes in old_nodes_by_vector_collection.items():
+        query_node_ids = [str(node.id) for node in collection_nodes]
+
+        if query_node_ids:
+            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
+            assert len(vector_items) == 0, "Vector items are not deleted."
+
+    query_edge_ids = list(set([str(generate_edge_id(edge[2])) for edge in old_edges]))
+
+    vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
+    assert len(vector_items) == len(query_edge_ids), "Shared edge type vector items should be preserved."
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(main())
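Reviewer note (not part of the patch): below is a minimal usage sketch of the deletion API this change introduces. `datasets.delete_data` now accepts a `mode` argument. Data items that have no rows in the new per-data Node bookkeeping table (i.e., ingested before node/edge tracking existed) are routed to `legacy_delete`; tracked items go through `delete_data_nodes_and_edges`, which skips nodes and edges flagged in the legacy relationship ledger. The snippet only exercises the public API shown in the patch; the ingested text and variable names are illustrative.

import asyncio

import cognee
from cognee.api.v1.datasets import datasets
from cognee.modules.users.methods import get_default_user


async def delete_example():
    user = await get_default_user()

    # Ingest one data item and build the graph for it.
    add_result = await cognee.add("John works for Apple.")
    data_id = add_result.data_ingestion_info[0]["data_id"]

    cognify_result: dict = await cognee.cognify()
    dataset_id = list(cognify_result.keys())[0]

    # "soft" (the default) removes the document's nodes, edges, and vector
    # entries; legacy-only data falls back to legacy_delete.
    await datasets.delete_data(dataset_id, data_id, user, mode="soft")

    # "hard" additionally removes degree-one Entity and EntityType nodes,
    # but only on the legacy path (see delete_document_subgraph).
    # await datasets.delete_data(dataset_id, data_id, user, mode="hard")


if __name__ == "__main__":
    asyncio.run(delete_example())

Note on the `mode` default: keeping "soft" as the default preserves the pre-existing REST behavior, since the delete router forwards `mode` unchanged and older clients never send it.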