fix: handle data deletion with backwards compatibility
This commit is contained in:
parent
5efd0a4fb6
commit
79983c25ee
17 changed files with 1148 additions and 57 deletions
54
.github/workflows/e2e_tests.yml
vendored
54
.github/workflows/e2e_tests.yml
vendored
|
|
@ -346,6 +346,60 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_delete_custom_graph.py
|
||||
|
||||
test-deletion-on-default-graph_with_legacy_data_1:
|
||||
name: Delete default graph with legacy data test 1
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Run deletion on default graph with legacy data 1
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_1.py
|
||||
|
||||
test-deletion-on-default-graph_with_legacy_data_2:
|
||||
name: Delete default graph with legacy data test 2
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Run deletion on default graph with legacy data 2
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_delete_default_graph_with_legacy_data_2.py
|
||||
|
||||
test-graph-edges:
|
||||
name: Test graph edge ingestion
|
||||
runs-on: ubuntu-22.04
|
||||
|
|
|
|||
|
|
@ -7,7 +7,12 @@ from cognee.modules.users.exceptions import PermissionDeniedError
|
|||
from cognee.modules.data.methods import has_dataset_data
|
||||
from cognee.modules.data.methods import get_authorized_dataset, get_authorized_existing_datasets
|
||||
from cognee.modules.data.exceptions.exceptions import UnauthorizedDataAccessError
|
||||
from cognee.modules.graph.methods import delete_data_nodes_and_edges, delete_dataset_nodes_and_edges
|
||||
from cognee.modules.graph.methods import (
|
||||
delete_data_nodes_and_edges,
|
||||
delete_dataset_nodes_and_edges,
|
||||
has_data_related_nodes,
|
||||
legacy_delete,
|
||||
)
|
||||
from cognee.modules.ingestion import discover_directory_datasets
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
|
||||
|
|
@ -66,7 +71,9 @@ class datasets:
|
|||
return await delete_dataset(dataset)
|
||||
|
||||
@staticmethod
|
||||
async def delete_data(dataset_id: UUID, data_id: UUID, user: Optional[User] = None):
|
||||
async def delete_data(
|
||||
dataset_id: UUID, data_id: UUID, user: Optional[User] = None, mode: str = "soft"
|
||||
):
|
||||
from cognee.modules.data.methods import delete_data, get_data
|
||||
|
||||
if not user:
|
||||
|
|
@ -81,7 +88,7 @@ class datasets:
|
|||
|
||||
if not data:
|
||||
# If data is not found in the system, user is using a custom graph model.
|
||||
await delete_data_nodes_and_edges(dataset_id, data_id)
|
||||
await delete_data_nodes_and_edges(dataset_id, data_id, user.id)
|
||||
return
|
||||
|
||||
data_datasets = data.datasets
|
||||
|
|
@ -89,7 +96,10 @@ class datasets:
|
|||
if not data or not any([dataset.id == dataset_id for dataset in data_datasets]):
|
||||
raise UnauthorizedDataAccessError(f"Data {data_id} not accessible.")
|
||||
|
||||
await delete_data_nodes_and_edges(dataset_id, data.id)
|
||||
if not await has_data_related_nodes(dataset_id, data_id):
|
||||
await legacy_delete(data, mode)
|
||||
else:
|
||||
await delete_data_nodes_and_edges(dataset_id, data_id, user.id)
|
||||
|
||||
await delete_data(data)
|
||||
|
||||
|
|
|
|||
|
|
@ -54,9 +54,10 @@ def get_delete_router() -> APIRouter:
|
|||
|
||||
try:
|
||||
result = await datasets.delete_data(
|
||||
data_id=data_id,
|
||||
dataset_id=dataset_id,
|
||||
data_id=data_id,
|
||||
user=user,
|
||||
mode=mode,
|
||||
)
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -206,12 +206,12 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
collection = await self.get_collection(collection_name)
|
||||
|
||||
if len(data_point_ids) == 1:
|
||||
results = await collection.query().where(f"id = '{data_point_ids[0]}'")
|
||||
query = collection.query().where(f"id = '{data_point_ids[0]}'")
|
||||
else:
|
||||
results = await collection.query().where(f"id IN {tuple(data_point_ids)}")
|
||||
query = collection.query().where(f"id IN {tuple(data_point_ids)}")
|
||||
|
||||
# Convert query results to list format
|
||||
results_list = results.to_list() if hasattr(results, "to_list") else list(results)
|
||||
results_list = await query.to_list()
|
||||
|
||||
return [
|
||||
ScoredResult(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from uuid import NAMESPACE_OID, uuid5
|
||||
from uuid import NAMESPACE_OID, UUID, uuid5
|
||||
|
||||
|
||||
def generate_node_id(node_id: str) -> str:
|
||||
def generate_node_id(node_id: str) -> UUID:
|
||||
return uuid5(NAMESPACE_OID, node_id.lower().replace(" ", "_").replace("'", ""))
|
||||
|
|
|
|||
40
cognee/modules/graph/legacy/GraphRelationshipLedger.py
Normal file
40
cognee/modules/graph/legacy/GraphRelationshipLedger.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
from uuid import uuid5, NAMESPACE_OID
from datetime import datetime, timezone
from sqlalchemy import UUID, Column, DateTime, String, Index

from cognee.infrastructure.databases.relational import Base


class GraphRelationshipLedger(Base):
    """Ledger row recording one graph relationship (node or edge) and the
    function that created it; used to recognize legacy-format graph data
    during deletion.
    """

    __tablename__ = "graph_relationship_ledger"

    # Primary key derived from the creation timestamp string via uuid5.
    id = Column(
        UUID,
        primary_key=True,
        default=lambda: uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
    )
    source_node_id = Column(UUID, nullable=False)
    destination_node_id = Column(UUID, nullable=False)
    # Name of the function that wrote this row, e.g. "add_nodes" or
    # "add_edges.<relationship_name>".
    creator_function = Column(String, nullable=False)
    node_label = Column(String, nullable=True)
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    deleted_at = Column(DateTime(timezone=True), nullable=True)
    user_id = Column(UUID, nullable=True)

    # Create indexes
    __table_args__ = (
        Index("idx_graph_relationship_id", "id"),
        Index("idx_graph_relationship_ledger_source_node_id", "source_node_id"),
        Index("idx_graph_relationship_ledger_destination_node_id", "destination_node_id"),
    )

    def to_json(self) -> dict:
        """Serialize the row to a JSON-compatible dict.

        Fix: the previous version read ``self.parent_id`` / ``self.child_id``,
        attributes that do not exist on this model, so serialization raised
        AttributeError. Read the actual column attributes instead.
        """
        return {
            "id": str(self.id),
            "source_node_id": str(self.source_node_id),
            "destination_node_id": str(self.destination_node_id),
            "creator_function": self.creator_function,
            "created_at": self.created_at.isoformat(),
            "deleted_at": self.deleted_at.isoformat() if self.deleted_at else None,
            "user_id": str(self.user_id),
        }
|
||||
49
cognee/modules/graph/legacy/has_edges_in_legacy_ledger.py
Normal file
49
cognee/modules/graph/legacy/has_edges_in_legacy_ledger.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from uuid import UUID
|
||||
from typing import List
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Edge, Node
|
||||
from .GraphRelationshipLedger import GraphRelationshipLedger
|
||||
|
||||
|
||||
@with_async_session
async def has_edges_in_legacy_ledger(edges: List[Edge], user_id: UUID, session: AsyncSession):
    """Return, for each edge, whether it was recorded in the legacy relationship ledger.

    Parameters:
        edges: edges to check; matched by ``relationship_name`` against the
            ledger's ``creator_function`` suffix ("add_edges.<relationship_name>").
        user_id: only ledger rows belonging to this user are considered.
        session: async DB session (injected by ``with_async_session``).

    Returns:
        A list of booleans aligned with ``edges``.
    """
    if not edges:
        return []

    query = select(GraphRelationshipLedger).where(
        and_(
            GraphRelationshipLedger.user_id == user_id,
            or_(
                *[
                    GraphRelationshipLedger.creator_function.ilike(f"%{edge.relationship_name}")
                    for edge in edges
                ]
            ),
        )
    )

    legacy_edges = (await session.scalars(query)).all()

    # creator_function is "add_edges.<relationship_name>" for edge rows; take
    # everything after the first dot. Using maxsplit=1 with [-1] (instead of
    # split(".")[1]) avoids IndexError when a matched row has no dot at all
    # (e.g. an "add_nodes" row caught by the suffix ilike) and keeps
    # relationship names that themselves contain dots intact.
    legacy_edge_names = {
        ledger_row.creator_function.split(".", 1)[-1] for ledger_row in legacy_edges
    }

    return [edge.relationship_name in legacy_edge_names for edge in edges]
|
||||
|
||||
|
||||
@with_async_session
async def get_node_ids(edges: List[Edge], session: AsyncSession):
    """Map the node slugs referenced by ``edges`` (both endpoints) to Node ids.

    Returns:
        ``{slug: node_id}`` for every Node row whose slug appears as a source
        or destination of one of the given edges.
    """
    referenced_slugs = [
        slug
        for edge in edges
        for slug in (edge.source_node_id, edge.destination_node_id)
    ]

    statement = select(Node).where(Node.slug.in_(referenced_slugs))
    matching_nodes = (await session.scalars(statement)).all()

    return {matched_node.slug: matched_node.id for matched_node in matching_nodes}
|
||||
36
cognee/modules/graph/legacy/has_nodes_in_legacy_ledger.py
Normal file
36
cognee/modules/graph/legacy/has_nodes_in_legacy_ledger.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
from typing import List
|
||||
from uuid import UUID
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Node
|
||||
from .GraphRelationshipLedger import GraphRelationshipLedger
|
||||
|
||||
|
||||
@with_async_session
async def has_nodes_in_legacy_ledger(nodes: List[Node], user_id: UUID, session: AsyncSession):
    """Return, for each node, whether its slug appears as an endpoint of any
    legacy-ledger row owned by ``user_id``.

    Returns:
        A list of booleans aligned with ``nodes``.
    """
    slugs = [node.slug for node in nodes]

    statement = select(
        GraphRelationshipLedger.source_node_id,
        GraphRelationshipLedger.destination_node_id,
    ).where(
        and_(
            GraphRelationshipLedger.user_id == user_id,
            or_(
                GraphRelationshipLedger.source_node_id.in_(slugs),
                GraphRelationshipLedger.destination_node_id.in_(slugs),
            ),
        )
    )

    result = await session.execute(statement)

    # Collect every endpoint id that occurs in a matching ledger row.
    ledger_endpoint_ids = set()
    for row in result.all():
        ledger_endpoint_ids.update((row.source_node_id, row.destination_node_id))

    return [slug in ledger_endpoint_ids for slug in slugs]
|
||||
38
cognee/modules/graph/legacy/record_data_in_legacy_ledger.py
Normal file
38
cognee/modules/graph/legacy/record_data_in_legacy_ledger.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from uuid import UUID
|
||||
from typing import Dict, List, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.modules.users.models.User import User
|
||||
from .GraphRelationshipLedger import GraphRelationshipLedger
|
||||
|
||||
|
||||
@with_async_session
async def record_data_in_legacy_ledger(
    nodes: List[DataPoint],
    edges: List[Tuple[UUID, UUID, str, Dict]],
    user: User,
    session: AsyncSession,
) -> None:
    """Write one ledger row per node and per edge so later deletions can
    recognize this data as legacy-format.

    Node rows use the node id for both endpoints with creator_function
    "add_nodes"; edge rows store the edge endpoints with creator_function
    "add_edges.<relationship_name>".
    """
    ledger_rows = []

    for node in nodes:
        ledger_rows.append(
            GraphRelationshipLedger(
                source_node_id=node.id,
                destination_node_id=node.id,
                creator_function="add_nodes",
                user_id=user.id,
            )
        )

    for edge in edges:
        # edge is (source_id, destination_id, relationship_name, properties).
        ledger_rows.append(
            GraphRelationshipLedger(
                source_node_id=edge[0],
                destination_node_id=edge[1],
                creator_function=f"add_edges.{edge[2]}",
                user_id=user.id,
            )
        )

    session.add_all(ledger_rows)

    await session.commit()
|
||||
|
|
@ -3,6 +3,8 @@ from .get_formatted_graph_data import get_formatted_graph_data
|
|||
from .upsert_edges import upsert_edges
|
||||
from .upsert_nodes import upsert_nodes
|
||||
|
||||
from .has_data_related_nodes import has_data_related_nodes
|
||||
|
||||
from .get_data_related_nodes import get_data_related_nodes
|
||||
from .get_data_related_edges import get_data_related_edges
|
||||
from .delete_data_related_nodes import delete_data_related_nodes
|
||||
|
|
@ -14,3 +16,5 @@ from .get_dataset_related_edges import get_dataset_related_edges
|
|||
from .delete_dataset_related_nodes import delete_dataset_related_nodes
|
||||
from .delete_dataset_related_edges import delete_dataset_related_edges
|
||||
from .delete_dataset_nodes_and_edges import delete_dataset_nodes_and_edges
|
||||
|
||||
from .legacy_delete import legacy_delete
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from typing import Dict, List
|
|||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
|
||||
from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
|
||||
from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger
|
||||
from cognee.modules.graph.methods import (
|
||||
delete_data_related_edges,
|
||||
delete_data_related_nodes,
|
||||
|
|
@ -11,17 +13,26 @@ from cognee.modules.graph.methods import (
|
|||
)
|
||||
|
||||
|
||||
async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID) -> None:
|
||||
async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id: UUID) -> None:
|
||||
affected_nodes = await get_data_related_nodes(dataset_id, data_id)
|
||||
|
||||
if len(affected_nodes) == 0:
|
||||
return
|
||||
|
||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
||||
|
||||
affected_relationships = await get_data_related_edges(dataset_id, data_id)
|
||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
||||
|
||||
non_legacy_nodes = [
|
||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||
]
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.delete_nodes([str(node.slug) for node in affected_nodes])
|
||||
await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes])
|
||||
|
||||
affected_vector_collections: Dict[str, List] = {}
|
||||
for node in affected_nodes:
|
||||
for node in non_legacy_nodes:
|
||||
for indexed_field in node.indexed_fields:
|
||||
collection_name = f"{node.type}_{indexed_field}"
|
||||
if collection_name not in affected_vector_collections:
|
||||
|
|
@ -29,17 +40,22 @@ async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID) -> None:
|
|||
affected_vector_collections[collection_name].append(node)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
for affected_collection, affected_nodes in affected_vector_collections.items():
|
||||
for affected_collection, non_legacy_nodes in affected_vector_collections.items():
|
||||
await vector_engine.delete_data_points(
|
||||
affected_collection, [node.slug for node in affected_nodes]
|
||||
affected_collection, [str(node.slug) for node in non_legacy_nodes]
|
||||
)
|
||||
|
||||
affected_relationships = await get_data_related_edges(dataset_id, data_id)
|
||||
if len(affected_relationships) > 0:
|
||||
non_legacy_relationships = [
|
||||
edge
|
||||
for index, edge in enumerate(affected_relationships)
|
||||
if not is_legacy_relationship[index]
|
||||
]
|
||||
|
||||
await vector_engine.delete_data_points(
|
||||
"EdgeType_relationship_name",
|
||||
[edge.slug for edge in affected_relationships],
|
||||
)
|
||||
await vector_engine.delete_data_points(
|
||||
"EdgeType_relationship_name",
|
||||
[str(relationship.slug) for relationship in non_legacy_relationships],
|
||||
)
|
||||
|
||||
await delete_data_related_nodes(data_id)
|
||||
await delete_data_related_edges(data_id)
|
||||
|
|
|
|||
16
cognee/modules/graph/methods/has_data_related_nodes.py
Normal file
16
cognee/modules/graph/methods/has_data_related_nodes.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Node
|
||||
|
||||
|
||||
@with_async_session
async def has_data_related_nodes(dataset_id: UUID, data_id: UUID, session: AsyncSession) -> bool:
    """Return True if at least one Node row belongs to ``data_id`` within
    ``dataset_id``.

    Parameters:
        dataset_id: dataset the node must belong to.
        data_id: data item the node must belong to.
        session: async DB session (injected by ``with_async_session``).
    """
    query_statement = (
        select(Node).where(and_(Node.data_id == data_id, Node.dataset_id == dataset_id)).limit(1)
    )

    data_related_node = await session.scalar(query_statement)
    # `is not None` instead of `!= None`: identity is the PEP 8-mandated way to
    # test for None, and `!=` would go through the fetched object's rich
    # comparison instead of a plain presence check.
    return data_related_node is not None
|
||||
94
cognee/modules/graph/methods/legacy_delete.py
Normal file
94
cognee/modules/graph/methods/legacy_delete.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
from uuid import UUID
|
||||
|
||||
from cognee.api.v1.exceptions.exceptions import DocumentSubgraphNotFoundError
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def legacy_delete(data: Data, mode: str = "soft"):
    """Delete a legacy-format document: remove its subgraph from the graph
    database, then purge the corresponding records from the vector store.

    Parameters:
        data: data record whose document subgraph (keyed by ``data.id``)
            should be removed.
        mode: "soft" (default) or "hard"; "hard" additionally removes
            degree-one Entity/EntityType nodes (see delete_document_subgraph).
    """
    # Remove the subgraph first; the ids of the deleted nodes tell us which
    # vector records to purge afterwards.
    deleted_node_ids = await delete_document_subgraph(data.id, mode)

    vector_engine = get_vector_engine()

    # Derive collection names from every DataPoint subclass's declared index
    # fields, using the "<ClassName>_<field>" naming convention.
    collection_names = []
    for datapoint_class in get_all_subclasses(DataPoint):
        indexed_fields = datapoint_class.model_fields["metadata"].default.get("index_fields", [])
        collection_names.extend(
            f"{datapoint_class.__name__}_{field_name}" for field_name in indexed_fields
        )

    # Fall back to the well-known default collections when discovery found none.
    if not collection_names:
        collection_names = [
            "DocumentChunk_text",
            "EdgeType_relationship_name",
            "EntityType_name",
            "Entity_name",
            "TextDocument_name",
            "TextSummary_text",
        ]

    # Delete records from each vector collection that exists.
    stringified_ids = [str(node_id) for node_id in deleted_node_ids]
    for collection_name in collection_names:
        if await vector_engine.has_collection(collection_name):
            await vector_engine.delete_data_points(collection_name, stringified_ids)
|
||||
|
||||
|
||||
async def delete_document_subgraph(document_id: UUID, mode: str = "soft"):
    """Delete a document's subgraph from the graph database.

    Nodes are removed in dependency order (orphans first, the document last)
    to keep the graph consistent mid-deletion. In "hard" mode, Entity and
    EntityType nodes left with a single connection are removed as well.

    Returns:
        The ids of every node that was deleted.

    Raises:
        DocumentSubgraphNotFoundError: when no subgraph exists for the id.
    """
    graph_db = await get_graph_engine()

    subgraph = await graph_db.get_document_subgraph(document_id)
    if not subgraph:
        raise DocumentSubgraphNotFoundError(f"Document not found with id: {document_id}")

    # Deletion order matters: "made_from" nodes (summaries) hang off chunks,
    # so they must go before the chunks; the document itself goes last.
    ordered_keys = [
        "orphan_entities",
        "orphan_types",
        "made_from_nodes",
        "chunks",
        "document",
    ]

    deleted_node_ids = []
    for subgraph_key in ordered_keys:
        for node in subgraph[subgraph_key] or []:
            await graph_db.delete_node(node["id"])
            deleted_node_ids.append(node["id"])

    if mode == "hard":
        # Also drop Entity / EntityType nodes that only have one remaining link.
        for node_type in ("Entity", "EntityType"):
            for node in await graph_db.get_degree_one_nodes(node_type):
                await graph_db.delete_node(node["id"])
                deleted_node_ids.append(node["id"])

    return deleted_node_ids
|
||||
|
|
@ -8,7 +8,6 @@ from cognee.api.v1.datasets import datasets
|
|||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
from cognee.modules.data.methods import get_dataset_data
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||
|
|
@ -100,24 +99,19 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
),
|
||||
]
|
||||
|
||||
await cognee.add(
|
||||
add_john_result = await cognee.add(
|
||||
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||
)
|
||||
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
|
||||
|
||||
await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.")
|
||||
add_marie_result = await cognee.add(
|
||||
"Marie works for Apple as well. She is a software engineer on MacOS project."
|
||||
)
|
||||
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
|
||||
|
||||
cognify_result: dict = await cognee.cognify()
|
||||
dataset_id = list(cognify_result.keys())[0]
|
||||
|
||||
dataset_data = await get_dataset_data(dataset_id)
|
||||
added_data_1 = dataset_data[0]
|
||||
added_data_2 = dataset_data[1]
|
||||
|
||||
# file_path = os.path.join(
|
||||
# pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_full.html"
|
||||
# )
|
||||
# await visualize_graph(file_path)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
initial_nodes, initial_edges = await graph_engine.get_graph_data()
|
||||
assert len(initial_nodes) == 15 and len(initial_edges) == 19, (
|
||||
|
|
@ -136,16 +130,13 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
initial_node_ids = set([node[0] for node in initial_nodes])
|
||||
|
||||
user = await get_default_user()
|
||||
await datasets.delete_data(dataset_id, added_data_1.id, user) # type: ignore
|
||||
|
||||
# file_path = os.path.join(
|
||||
# pathlib.Path(__file__).parent, ".artifacts", "graph_visualization_after_delete.html"
|
||||
# )
|
||||
# await visualize_graph(file_path)
|
||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted."
|
||||
assert not any(node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes), "Nodes are not deleted."
|
||||
assert not any(
|
||||
node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes
|
||||
), "Nodes are not deleted."
|
||||
|
||||
after_first_delete_node_ids = set([node[0] for node in nodes])
|
||||
|
||||
|
|
@ -166,7 +157,7 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
||||
|
||||
await datasets.delete_data(dataset_id, added_data_2.id, user) # type: ignore
|
||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||
|
||||
final_nodes, final_edges = await graph_engine.get_graph_data()
|
||||
assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted."
|
||||
|
|
|
|||
|
|
@ -1,18 +1,12 @@
|
|||
import os
|
||||
import pathlib
|
||||
import time
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.datasets import datasets
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
from cognee.modules.data.methods import get_dataset_data
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -41,22 +35,22 @@ async def main():
|
|||
assert not await vector_engine.has_collection("TextSummary_text")
|
||||
assert not await vector_engine.has_collection("TextDocument_text")
|
||||
|
||||
await cognee.add(
|
||||
add_result = await cognee.add(
|
||||
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||
)
|
||||
johns_data_id = add_result.data_ingestion_info[0]["data_id"]
|
||||
|
||||
await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.")
|
||||
add_result = await cognee.add(
|
||||
"Marie works for Apple as well. She is a software engineer on MacOS project."
|
||||
)
|
||||
maries_data_id = add_result.data_ingestion_info[0]["data_id"]
|
||||
|
||||
cognify_result: dict = await cognee.cognify()
|
||||
dataset_id = list(cognify_result.keys())[0]
|
||||
|
||||
dataset_data = await get_dataset_data(dataset_id)
|
||||
added_data_1 = dataset_data[0]
|
||||
added_data_2 = dataset_data[1]
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
initial_nodes, initial_edges = await graph_engine.get_graph_data()
|
||||
assert len(initial_nodes) >= 15 and len(initial_edges) >= 19, (
|
||||
assert len(initial_nodes) >= 14 and len(initial_edges) >= 18, (
|
||||
"Number of nodes and edges is not correct."
|
||||
)
|
||||
|
||||
|
|
@ -72,11 +66,15 @@ async def main():
|
|||
initial_node_ids = set([node[0] for node in initial_nodes])
|
||||
|
||||
user = await get_default_user()
|
||||
await datasets.delete_data(dataset_id, added_data_1.id, user) # type: ignore
|
||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) >= 9 and len(nodes) <= 11 and len(edges) >= 10 and len(edges) <= 12, "Nodes and edges are not deleted."
|
||||
assert not any(node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes), "Nodes are not deleted."
|
||||
assert len(nodes) >= 9 and len(nodes) <= 11 and len(edges) >= 10 and len(edges) <= 12, (
|
||||
"Nodes and edges are not deleted."
|
||||
)
|
||||
assert not any(
|
||||
node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes
|
||||
), "Nodes are not deleted."
|
||||
|
||||
after_first_delete_node_ids = set([node[0] for node in nodes])
|
||||
|
||||
|
|
@ -97,7 +95,7 @@ async def main():
|
|||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
||||
|
||||
await datasets.delete_data(dataset_id, added_data_2.id, user) # type: ignore
|
||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||
|
||||
final_nodes, final_edges = await graph_engine.get_graph_data()
|
||||
assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted."
|
||||
|
|
|
|||
372
cognee/tests/test_delete_default_graph_with_legacy_data_1.py
Normal file
372
cognee/tests/test_delete_default_graph_with_legacy_data_1.py
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
import os
|
||||
import pathlib
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.datasets import datasets
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.data.methods import create_authorized_dataset
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.engine.models import Entity, EntityType
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.engine.utils import generate_edge_id, generate_node_id
|
||||
from cognee.modules.pipelines.models import DataItemStatus
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||
from cognee.tasks.storage import index_data_points, index_graph_edges
|
||||
|
||||
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
||||
|
||||
|
||||
def get_nodes_and_edges():
    """Construct the legacy knowledge-graph fixture.

    Builds one text document, its single chunk, and four entities / entity
    types extracted from it.

    Returns:
        tuple: ``(nodes, edges)`` where the document is always the FIRST node
        (callers rely on ``nodes[0]`` being the document) and each edge is a
        ``(source_id, target_id, relationship_name, properties)`` tuple in the
        shape the graph engine expects.
    """
    # The chunk text doubles as the seed for the chunk's deterministic UUID,
    # so it is bound once and reused for both.
    chunk_text = "Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n "

    legacy_document = TextDocument(
        id=uuid5(NAMESPACE_OID, "text_test.txt"),
        name="text_test.txt",
        raw_data_location="git/cognee/examples/database_examples/data_storage/data/text_test.txt",
        external_metadata="{}",
        mime_type="text/plain",
    )
    legacy_chunk = DocumentChunk(
        id=uuid5(NAMESPACE_OID, chunk_text),
        text=chunk_text,
        chunk_size=187,
        chunk_index=0,
        cut_type="paragraph_end",
        is_part_of=legacy_document,
    )

    graph_database_type = EntityType(
        id=uuid5(NAMESPACE_OID, "graph_database"),
        name="graph database",
        description="graph database",
    )
    neptune_analytics = Entity(
        id=generate_node_id("neptune analytics"),
        name="neptune analytics",
        description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
    )
    neptune_database = Entity(
        id=generate_node_id("amazon neptune database"),
        name="amazon neptune database",
        description="A popular managed graph database that complements Neptune Analytics.",
    )

    storage_type = EntityType(
        id=generate_node_id("storage"),
        name="storage",
        description="storage",
    )
    s3_entity = Entity(
        id=generate_node_id("amazon s3"),
        name="amazon s3",
        description="A storage service provided by Amazon Web Services that allows storing graph data.",
    )

    nodes_data = [
        legacy_document,
        legacy_chunk,
        graph_database_type,
        neptune_analytics,
        neptune_database,
        storage_type,
        s3_entity,
    ]

    def edge(source_id, target_id, relationship_name):
        # The relationship name is carried both positionally and in the
        # properties payload, mirroring what the graph engine stores.
        return (
            source_id,
            target_id,
            relationship_name,
            {"relationship_name": relationship_name},
        )

    edges_data = [
        edge(legacy_chunk.id, s3_entity.id, "contains"),
        edge(s3_entity.id, storage_type.id, "is_a"),
        edge(legacy_chunk.id, neptune_database.id, "contains"),
        edge(neptune_database.id, graph_database_type.id, "is_a"),
        edge(legacy_chunk.id, legacy_document.id, "is_part_of"),
        edge(legacy_chunk.id, neptune_analytics.id, "contains"),
        edge(neptune_analytics.id, graph_database_type.id, "is_a"),
    ]

    return nodes_data, edges_data
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
async def main(mock_create_structured_output: AsyncMock):
    """End-to-end deletion test on the default graph with legacy ledger data.

    Seeds the graph/vector/relational stores with a legacy document recorded
    in the legacy ledger, ingests and cognifies two new documents with a
    mocked LLM, then verifies that deleting the new documents removes their
    graph nodes/edges and vector entries while ALL legacy data survives.

    Fixes vs. previous revision:
    - removed a dead, mis-built grouping dict that was populated from the
      pre-delete node list (copy-paste slip) and never read;
    - loop variables no longer shadow ``initial_nodes`` / ``old_nodes``;
    - assertion messages on the legacy-preservation checks now describe what
      is actually being asserted.
    """
    data_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_1"
    )
    cognee.config.data_root_directory(data_directory_path)

    cognee_directory_path = os.path.join(
        pathlib.Path(__file__).parent,
        ".cognee_system/test_delete_default_graph_with_legacy_graph_1",
    )
    cognee.config.system_root_directory(cognee_directory_path)

    # Start from a clean slate.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    vector_engine = get_vector_engine()

    # Sanity check: no vector collections may exist before ingestion.
    for collection in (
        "EdgeType_relationship_name",
        "Entity_name",
        "DocumentChunk_text",
        "TextSummary_text",
        "TextDocument_text",
    ):
        assert not await vector_engine.has_collection(collection)

    user = await get_default_user()

    # Seed graph + vector stores with legacy data and record it in the legacy
    # ledger so deletion can distinguish it from newly ingested data.
    graph_engine = await get_graph_engine()
    old_nodes, old_edges = get_nodes_and_edges()
    old_document = old_nodes[0]

    await graph_engine.add_nodes(old_nodes)
    await graph_engine.add_edges(old_edges)

    await index_data_points(old_nodes)
    await index_graph_edges(old_edges)

    await record_data_in_legacy_ledger(old_nodes, old_edges, user)

    db_engine = get_relational_engine()

    dataset = await create_authorized_dataset("main_dataset", user)

    # Register the legacy document in the relational store as if it had
    # already been processed by the cognify pipeline.
    async with db_engine.get_async_session() as session:
        old_data = Data(
            id=old_document.id,
            name=old_document.name,
            extension="txt",
            raw_data_location=old_document.raw_data_location,
            external_metadata=old_document.external_metadata,
            mime_type=old_document.mime_type,
            owner_id=user.id,
            pipeline_status={
                "cognify_pipeline": {
                    str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED,
                }
            },
        )
        session.add(old_data)

        dataset.data.append(old_data)
        session.add(dataset)

        await session.commit()

    def mock_llm_output(text_input: str, system_prompt: str, response_model):
        # Deterministic LLM stand-in: routes on the input text and on the
        # structured-output model the caller requested. Unmatched inputs
        # fall through and return None (unchanged behavior).
        if text_input == "test":  # LLM connection test
            return "test"

        if "John" in text_input and response_model == SummarizedContent:
            return SummarizedContent(
                summary="Summary of John's work.", description="Summary of John's work."
            )

        if "Marie" in text_input and response_model == SummarizedContent:
            return SummarizedContent(
                summary="Summary of Marie's work.", description="Summary of Marie's work."
            )

        if "Marie" in text_input and response_model == KnowledgeGraph:
            return KnowledgeGraph(
                nodes=[
                    Node(id="Marie", name="Marie", type="Person", description="Marie is a person"),
                    Node(
                        id="Apple",
                        name="Apple",
                        type="Company",
                        description="Apple is a company",
                    ),
                    Node(
                        id="MacOS",
                        name="MacOS",
                        type="Product",
                        description="MacOS is Apple's operating system",
                    ),
                ],
                edges=[
                    Edge(
                        source_node_id="Marie",
                        target_node_id="Apple",
                        relationship_name="works_for",
                    ),
                    Edge(
                        source_node_id="Marie", target_node_id="MacOS", relationship_name="works_on"
                    ),
                ],
            )

        if "John" in text_input and response_model == KnowledgeGraph:
            return KnowledgeGraph(
                nodes=[
                    Node(id="John", name="John", type="Person", description="John is a person"),
                    Node(
                        id="Apple",
                        name="Apple",
                        type="Company",
                        description="Apple is a company",
                    ),
                    Node(
                        id="Food for Hungry",
                        name="Food for Hungry",
                        type="Non-profit organization",
                        description="Food for Hungry is a non-profit organization",
                    ),
                ],
                edges=[
                    Edge(
                        source_node_id="John", target_node_id="Apple", relationship_name="works_for"
                    ),
                    Edge(
                        source_node_id="John",
                        target_node_id="Food for Hungry",
                        relationship_name="works_for",
                    ),
                ],
            )

    mock_create_structured_output.side_effect = mock_llm_output

    add_john_result = await cognee.add(
        "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
    )
    johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]

    add_marie_result = await cognee.add(
        "Marie works for Apple as well. She is a software engineer on MacOS project."
    )
    maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]

    cognify_result: dict = await cognee.cognify()
    dataset_id = list(cognify_result.keys())[0]

    graph_engine = await get_graph_engine()
    initial_nodes, initial_edges = await graph_engine.get_graph_data()
    assert len(initial_nodes) == 22 and len(initial_edges) == 26, (
        "Number of nodes and edges is not correct."
    )

    # Group initial nodes by the vector collection that indexes them, so we
    # can later verify vector cleanup per collection.
    initial_nodes_by_vector_collection = {}
    for node in initial_nodes:
        node_data = node[1]
        collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
        initial_nodes_by_vector_collection.setdefault(collection_name, []).append(node)

    initial_node_ids = {node[0] for node in initial_nodes}

    await datasets.delete_data(dataset_id, johns_data_id, user)  # type: ignore

    nodes, edges = await graph_engine.get_graph_data()
    assert len(nodes) == 16 and len(edges) == 17, "Nodes and edges are not deleted."
    assert not any(
        node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes
    ), "Nodes are not deleted."

    after_first_delete_node_ids = {node[0] for node in nodes}

    removed_node_ids = initial_node_ids - after_first_delete_node_ids

    # Every node removed from the graph must also be gone from its vector
    # collection. (Loop variable renamed so it no longer shadows
    # ``initial_nodes``.)
    for collection_name, collection_nodes in initial_nodes_by_vector_collection.items():
        query_node_ids = [node[0] for node in collection_nodes if node[0] in removed_node_ids]

        if query_node_ids:
            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
            assert len(vector_items) == 0, "Vector items are not deleted."

    await datasets.delete_data(dataset_id, maries_data_id, user)  # type: ignore

    # After both new documents are deleted, only the 7 legacy nodes and
    # 7 legacy edges may remain.
    final_nodes, final_edges = await graph_engine.get_graph_data()
    assert len(final_nodes) == 7 and len(final_edges) == 7, "Nodes and edges are not deleted."

    # Legacy nodes must still be present in their vector collections.
    old_nodes_by_vector_collection = {}
    for node in old_nodes:
        collection_name = node.type + "_" + node.metadata["index_fields"][0]
        old_nodes_by_vector_collection.setdefault(collection_name, []).append(node)

    for collection_name, legacy_nodes in old_nodes_by_vector_collection.items():
        query_node_ids = [str(node.id) for node in legacy_nodes]

        if query_node_ids:
            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
            assert len(vector_items) == len(legacy_nodes), (
                "Legacy vector items should be preserved."
            )

    query_edge_ids = list({str(generate_edge_id(edge[2])) for edge in old_edges})

    vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
    assert len(vector_items) == len(query_edge_ids), (
        "Legacy edge-type vector items should be preserved."
    )


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
|
||||
372
cognee/tests/test_delete_default_graph_with_legacy_data_2.py
Normal file
372
cognee/tests/test_delete_default_graph_with_legacy_data_2.py
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
import os
|
||||
import pathlib
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.datasets import datasets
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.modules.data.methods import create_authorized_dataset
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.engine.models import Entity, EntityType
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.engine.utils import generate_edge_id, generate_node_id
|
||||
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
||||
from cognee.modules.pipelines.models import DataItemStatus
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||
from cognee.tasks.storage import index_data_points, index_graph_edges
|
||||
|
||||
|
||||
def get_nodes_and_edges():
    """Construct the legacy knowledge-graph fixture.

    Builds one text document, its single chunk, and four entities / entity
    types extracted from it.

    Returns:
        tuple: ``(nodes, edges)`` where the document is always the FIRST node
        (callers rely on ``nodes[0]`` being the document) and each edge is a
        ``(source_id, target_id, relationship_name, properties)`` tuple in the
        shape the graph engine expects.
    """
    # The chunk text doubles as the seed for the chunk's deterministic UUID,
    # so it is bound once and reused for both.
    chunk_text = "Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n "

    legacy_document = TextDocument(
        id=uuid5(NAMESPACE_OID, "text_test.txt"),
        name="text_test.txt",
        raw_data_location="git/cognee/examples/database_examples/data_storage/data/text_test.txt",
        external_metadata="{}",
        mime_type="text/plain",
    )
    legacy_chunk = DocumentChunk(
        id=uuid5(NAMESPACE_OID, chunk_text),
        text=chunk_text,
        chunk_size=187,
        chunk_index=0,
        cut_type="paragraph_end",
        is_part_of=legacy_document,
    )

    graph_database_type = EntityType(
        id=uuid5(NAMESPACE_OID, "graph_database"),
        name="graph database",
        description="graph database",
    )
    neptune_analytics = Entity(
        id=generate_node_id("neptune analytics"),
        name="neptune analytics",
        description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
    )
    neptune_database = Entity(
        id=generate_node_id("amazon neptune database"),
        name="amazon neptune database",
        description="A popular managed graph database that complements Neptune Analytics.",
    )

    storage_type = EntityType(
        id=generate_node_id("storage"),
        name="storage",
        description="storage",
    )
    s3_entity = Entity(
        id=generate_node_id("amazon s3"),
        name="amazon s3",
        description="A storage service provided by Amazon Web Services that allows storing graph data.",
    )

    nodes_data = [
        legacy_document,
        legacy_chunk,
        graph_database_type,
        neptune_analytics,
        neptune_database,
        storage_type,
        s3_entity,
    ]

    def edge(source_id, target_id, relationship_name):
        # The relationship name is carried both positionally and in the
        # properties payload, mirroring what the graph engine stores.
        return (
            source_id,
            target_id,
            relationship_name,
            {"relationship_name": relationship_name},
        )

    edges_data = [
        edge(legacy_chunk.id, s3_entity.id, "contains"),
        edge(s3_entity.id, storage_type.id, "is_a"),
        edge(legacy_chunk.id, neptune_database.id, "contains"),
        edge(neptune_database.id, graph_database_type.id, "is_a"),
        edge(legacy_chunk.id, legacy_document.id, "is_part_of"),
        edge(legacy_chunk.id, neptune_analytics.id, "contains"),
        edge(neptune_analytics.id, graph_database_type.id, "is_a"),
    ]

    return nodes_data, edges_data
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
async def main(mock_create_structured_output: AsyncMock):
    """End-to-end deletion test: deleting a LEGACY document from the default graph.

    Seeds the stores with a legacy document recorded in the legacy ledger,
    ingests and cognifies two new documents with a mocked LLM, deletes one new
    document, then deletes the LEGACY document itself and verifies its graph
    nodes/edges and vector entries are removed while the remaining (Marie's)
    data survives.

    Fixes vs. previous revision:
    - removed a dead, mis-built grouping dict that was populated from the
      pre-delete node list (copy-paste slip) and never read;
    - removed the commented-out ``maries_data_id`` assignment;
    - loop variables no longer shadow ``initial_nodes`` / ``old_nodes``;
    - the final edge-type assertion message now describes what is asserted.
    """
    data_directory_path = os.path.join(
        pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_2"
    )
    cognee.config.data_root_directory(data_directory_path)

    cognee_directory_path = os.path.join(
        pathlib.Path(__file__).parent,
        ".cognee_system/test_delete_default_graph_with_legacy_graph_2",
    )
    cognee.config.system_root_directory(cognee_directory_path)

    # Start from a clean slate.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    await setup()

    vector_engine = get_vector_engine()

    # Sanity check: no vector collections may exist before ingestion.
    for collection in (
        "EdgeType_relationship_name",
        "Entity_name",
        "DocumentChunk_text",
        "TextSummary_text",
        "TextDocument_text",
    ):
        assert not await vector_engine.has_collection(collection)

    user = await get_default_user()

    # Seed graph + vector stores with legacy data and record it in the legacy
    # ledger so deletion can handle it despite the old data layout.
    graph_engine = await get_graph_engine()
    old_nodes, old_edges = get_nodes_and_edges()
    old_document = old_nodes[0]

    await graph_engine.add_nodes(old_nodes)
    await graph_engine.add_edges(old_edges)

    await index_data_points(old_nodes)
    await index_graph_edges(old_edges)

    await record_data_in_legacy_ledger(old_nodes, old_edges, user)

    db_engine = get_relational_engine()

    dataset = await create_authorized_dataset("main_dataset", user)

    # Register the legacy document in the relational store as if it had
    # already been processed by the cognify pipeline.
    async with db_engine.get_async_session() as session:
        old_data = Data(
            id=old_document.id,
            name=old_document.name,
            extension="txt",
            raw_data_location=old_document.raw_data_location,
            external_metadata=old_document.external_metadata,
            mime_type=old_document.mime_type,
            owner_id=user.id,
            pipeline_status={
                "cognify_pipeline": {
                    str(dataset.id): DataItemStatus.DATA_ITEM_PROCESSING_COMPLETED,
                }
            },
        )
        session.add(old_data)

        dataset.data.append(old_data)
        session.add(dataset)

        await session.commit()

    def mock_llm_output(text_input: str, system_prompt: str, response_model):
        # Deterministic LLM stand-in: routes on the input text and on the
        # structured-output model the caller requested. Unmatched inputs
        # fall through and return None (unchanged behavior).
        if text_input == "test":  # LLM connection test
            return "test"

        if "John" in text_input and response_model == SummarizedContent:
            return SummarizedContent(
                summary="Summary of John's work.", description="Summary of John's work."
            )

        if "Marie" in text_input and response_model == SummarizedContent:
            return SummarizedContent(
                summary="Summary of Marie's work.", description="Summary of Marie's work."
            )

        if "Marie" in text_input and response_model == KnowledgeGraph:
            return KnowledgeGraph(
                nodes=[
                    Node(id="Marie", name="Marie", type="Person", description="Marie is a person"),
                    Node(
                        id="Apple",
                        name="Apple",
                        type="Company",
                        description="Apple is a company",
                    ),
                    Node(
                        id="MacOS",
                        name="MacOS",
                        type="Product",
                        description="MacOS is Apple's operating system",
                    ),
                ],
                edges=[
                    Edge(
                        source_node_id="Marie",
                        target_node_id="Apple",
                        relationship_name="works_for",
                    ),
                    Edge(
                        source_node_id="Marie", target_node_id="MacOS", relationship_name="works_on"
                    ),
                ],
            )

        if "John" in text_input and response_model == KnowledgeGraph:
            return KnowledgeGraph(
                nodes=[
                    Node(id="John", name="John", type="Person", description="John is a person"),
                    Node(
                        id="Apple",
                        name="Apple",
                        type="Company",
                        description="Apple is a company",
                    ),
                    Node(
                        id="Food for Hungry",
                        name="Food for Hungry",
                        type="Non-profit organization",
                        description="Food for Hungry is a non-profit organization",
                    ),
                ],
                edges=[
                    Edge(
                        source_node_id="John", target_node_id="Apple", relationship_name="works_for"
                    ),
                    Edge(
                        source_node_id="John",
                        target_node_id="Food for Hungry",
                        relationship_name="works_for",
                    ),
                ],
            )

    mock_create_structured_output.side_effect = mock_llm_output

    add_john_result = await cognee.add(
        "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
    )
    johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]

    # Marie's data is ingested but never deleted in this scenario; it must
    # survive to the end of the test.
    await cognee.add(
        "Marie works for Apple as well. She is a software engineer on MacOS project."
    )

    cognify_result: dict = await cognee.cognify()
    dataset_id = list(cognify_result.keys())[0]

    graph_engine = await get_graph_engine()
    initial_nodes, initial_edges = await graph_engine.get_graph_data()
    assert len(initial_nodes) == 22 and len(initial_edges) == 26, (
        "Number of nodes and edges is not correct."
    )

    # Group initial nodes by the vector collection that indexes them, so we
    # can later verify vector cleanup per collection.
    initial_nodes_by_vector_collection = {}
    for node in initial_nodes:
        node_data = node[1]
        collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
        initial_nodes_by_vector_collection.setdefault(collection_name, []).append(node)

    initial_node_ids = {node[0] for node in initial_nodes}

    await datasets.delete_data(dataset_id, johns_data_id, user)  # type: ignore

    nodes, edges = await graph_engine.get_graph_data()
    assert len(nodes) == 16 and len(edges) == 17, "Nodes and edges are not deleted."
    assert not any(
        node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes
    ), "Nodes are not deleted."

    after_first_delete_node_ids = {node[0] for node in nodes}

    removed_node_ids = initial_node_ids - after_first_delete_node_ids

    # Every node removed from the graph must also be gone from its vector
    # collection. (Loop variable renamed so it no longer shadows
    # ``initial_nodes``.)
    for collection_name, collection_nodes in initial_nodes_by_vector_collection.items():
        query_node_ids = [node[0] for node in collection_nodes if node[0] in removed_node_ids]

        if query_node_ids:
            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
            assert len(vector_items) == 0, "Vector items are not deleted."

    # Delete the legacy document itself; its ledger-tracked nodes and edges
    # must be removed while Marie's data remains.
    await datasets.delete_data(dataset_id, old_document.id, user)  # type: ignore

    final_nodes, final_edges = await graph_engine.get_graph_data()
    assert len(final_nodes) == 9 and len(final_edges) == 10, "Nodes and edges are not deleted."

    # Legacy nodes must be gone from their vector collections too.
    old_nodes_by_vector_collection = {}
    for node in old_nodes:
        collection_name = node.type + "_" + node.metadata["index_fields"][0]
        old_nodes_by_vector_collection.setdefault(collection_name, []).append(node)

    for collection_name, legacy_nodes in old_nodes_by_vector_collection.items():
        query_node_ids = [str(node.id) for node in legacy_nodes]

        if query_node_ids:
            vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
            assert len(vector_items) == 0, "Vector items are not deleted."

    query_edge_ids = list({str(generate_edge_id(edge[2])) for edge in old_edges})

    # NOTE(review): edge-type entries are keyed by relationship name, which is
    # shared with Marie's remaining data — presumably that is why they are
    # expected to still exist here; confirm against deletion semantics.
    vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
    assert len(vector_items) == len(query_edge_ids), (
        "Shared edge-type vector items should still exist."
    )


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
|
||||
Loading…
Add table
Reference in a new issue