fix: improve delete tests

This commit is contained in:
Boris Arzentar 2025-11-19 14:01:45 +01:00
parent a89dad328e
commit 43459eeeac
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
8 changed files with 169 additions and 131 deletions

View file

@ -4,9 +4,11 @@ from cognee.api.v1.exceptions.exceptions import DocumentSubgraphNotFoundError
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.models import Data
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
logger = get_logger()
@ -74,6 +76,21 @@ async def delete_document_subgraph(document_id: UUID, mode: str = "soft"):
if nodes:
for node in nodes:
node_id = node["id"]
if key == "chunks":
chunk_connections = await graph_db.get_connections(node_id)
deleted_node_ids.extend(
[
str(
generate_edge_id(
get_contains_edge_text(node["name"], node["description"])
)
)
for (__, edge, node) in chunk_connections
if "relationship_name: contains;" in edge["relationship_name"]
]
)
await graph_db.delete_node(node_id)
deleted_node_ids.append(node_id)

View file

@ -27,45 +27,28 @@ async def upsert_edges(
edges_to_add = []
for edge in edges:
edge_text = (
edge[3]["edge_text"] if edge[2] == "contains" and "edge_text" in edge[3] else edge[2]
)
edges_to_add.append(
{
"id": uuid5(
NAMESPACE_OID,
str(user_id) + str(dataset_id) + str(edge[0]) + str(edge[2]) + str(edge[1]),
str(user_id) + str(dataset_id) + str(edge[0]) + str(edge_text) + str(edge[1]),
),
"slug": generate_edge_id(edge[2]),
"slug": generate_edge_id(edge_text),
"user_id": user_id,
"data_id": data_id,
"dataset_id": dataset_id,
"source_node_id": edge[0],
"destination_node_id": edge[1],
"relationship_name": edge[2],
"relationship_name": edge_text,
"label": edge[2],
"attributes": jsonable_encoder(edge[3]),
}
)
if len(edge) == 4 and "edge_text" in edge[3]:
edge_text = edge[3]["edge_text"]
edges_to_add.append(
{
"id": uuid5(
NAMESPACE_OID,
str(user_id) + str(dataset_id) + str(edge_text),
),
"slug": generate_edge_id(edge_text),
"user_id": user_id,
"data_id": data_id,
"dataset_id": dataset_id,
"source_node_id": edge[0],
"destination_node_id": edge[1],
"relationship_name": edge_text,
"label": edge_text,
"attributes": jsonable_encoder(edge[3]),
}
)
upsert_statement = (
insert(Edge).values(edges_to_add).on_conflict_do_nothing(index_elements=["id"])
)

View file

@ -34,7 +34,6 @@ from cognee.tests.utils.extract_summary import extract_summary
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
from cognee.tests.utils.isolate_relationships import isolate_relationships
logger = get_logger()
@ -225,29 +224,17 @@ async def main(mock_create_structured_output: AsyncMock):
(johns_summary.id, johns_chunk.id, "made_from"),
*johns_relationships,
]
johns_edge_text_relationships = [
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in johns_entities
if isinstance(entity, Entity)
]
maries_relationships = [
(maries_chunk.id, maries_document.id, "is_part_of"),
(maries_summary.id, maries_chunk.id, "made_from"),
*maries_relationships,
]
maries_edge_text_relationships = [
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in maries_entities
if isinstance(entity, Entity)
]
expected_relationships = johns_relationships + maries_relationships + overlapping_relationships
await assert_graph_edges_present(expected_relationships)
await assert_edges_vector_index_present(
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
)
await assert_edges_vector_index_present(expected_relationships)
# Delete John's data from cognee
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
@ -261,15 +248,24 @@ async def main(mock_create_structured_output: AsyncMock):
# Assert relationships presence in the graph, vector collections and nodes table
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
await assert_edges_vector_index_present(maries_relationships)
await assert_graph_edges_not_present(johns_relationships)
strictly_johns_relationships = isolate_relationships(johns_relationships, maries_relationships)
johns_contains_relationships = [
(
johns_chunk.id,
entity.id,
get_contains_edge_text(entity.name, entity.description),
{
"relationship_name": get_contains_edge_text(entity.name, entity.description),
},
)
for entity in johns_entities
if isinstance(entity, Entity)
]
# We check only by relationship name and we need edges that are created by John's data and no other.
await assert_edges_vector_index_not_present(
strictly_johns_relationships + johns_edge_text_relationships
)
await assert_edges_vector_index_not_present(johns_contains_relationships)
# Delete Marie's data from cognee
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore

View file

@ -41,7 +41,6 @@ from cognee.tests.utils.extract_summary import extract_summary
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
from cognee.tests.utils.isolate_relationships import isolate_relationships
def create_nodes_and_edges():
@ -356,21 +355,11 @@ async def main(mock_create_structured_output: AsyncMock):
(johns_summary.id, johns_chunk.id, "made_from"),
*johns_relationships,
]
johns_edge_text_relationships = [
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in johns_entities
if isinstance(entity, Entity)
]
maries_relationships = [
(maries_chunk.id, maries_document.id, "is_part_of"),
(maries_summary.id, maries_chunk.id, "made_from"),
*maries_relationships,
]
maries_edge_text_relationships = [
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in maries_entities
if isinstance(entity, Entity)
]
expected_relationships = (
johns_relationships
@ -381,9 +370,7 @@ async def main(mock_create_structured_output: AsyncMock):
await assert_graph_edges_present(expected_relationships)
await assert_edges_vector_index_present(
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
)
await assert_edges_vector_index_present(expected_relationships)
# Delete John's data
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
@ -399,19 +386,24 @@ async def main(mock_create_structured_output: AsyncMock):
await assert_graph_edges_present(
maries_relationships + overlapping_relationships + legacy_relationships
)
await assert_edges_vector_index_present(
maries_relationships + maries_edge_text_relationships + legacy_relationships
)
await assert_edges_vector_index_present(maries_relationships + legacy_relationships)
await assert_graph_edges_not_present(johns_relationships)
strictly_johns_relationships = isolate_relationships(
johns_relationships, maries_relationships, legacy_relationships
)
johns_contains_relationships = [
(
johns_chunk.id,
entity.id,
get_contains_edge_text(entity.name, entity.description),
{
"relationship_name": get_contains_edge_text(entity.name, entity.description),
},
)
for entity in johns_entities
if isinstance(entity, Entity)
]
# We check only by relationship name and we need edges that are created by John's data and no other.
await assert_edges_vector_index_not_present(
strictly_johns_relationships + johns_edge_text_relationships
)
await assert_edges_vector_index_not_present(johns_contains_relationships)
# Delete Marie's data
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
@ -431,12 +423,20 @@ async def main(mock_create_structured_output: AsyncMock):
johns_relationships + maries_relationships + overlapping_relationships
)
strictly_maries_relationships = isolate_relationships(
maries_relationships, legacy_relationships
)
maries_contains_relationships = [
(
maries_chunk.id,
entity.id,
get_contains_edge_text(entity.name, entity.description),
{
"relationship_name": get_contains_edge_text(entity.name, entity.description),
},
)
for entity in maries_entities
if isinstance(entity, Entity)
]
# We check only by relationship name and we need edges that are created by Marie's data and no other.
if strictly_maries_relationships:
await assert_edges_vector_index_not_present(strictly_maries_relationships)
await assert_edges_vector_index_not_present(maries_contains_relationships)
async def create_mocked_legacy_data(user):
@ -447,8 +447,29 @@ async def create_mocked_legacy_data(user):
await graph_engine.add_nodes(legacy_nodes)
await graph_engine.add_edges(legacy_edges)
nodes_by_id = {node.id: node for node in legacy_nodes}
def format_relationship_name(relationship):
    """Return the text under which a legacy edge is indexed.

    Every relationship other than "contains" keeps its own name (the third
    item of the tuple). A "contains" relationship is expanded with
    ``get_contains_edge_text`` using the name and description of the node
    found in ``nodes_by_id`` under the relationship's second item.
    """
    relationship_name = relationship[2]
    if relationship_name != "contains":
        return relationship_name

    # The second item of the tuple is the key used to look the node up.
    target_node = nodes_by_id[relationship[1]]
    return get_contains_edge_text(target_node.name, target_node.description)
await index_data_points(legacy_nodes)
await index_graph_edges(legacy_edges)
await index_graph_edges(
[
(
edge[0],
edge[1],
format_relationship_name(edge),
{
**(edge[3] or {}),
"relationship_name": format_relationship_name(edge),
},
)
for edge in legacy_edges
] # type: ignore
)
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)

View file

@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, patch
import cognee
from cognee.api.v1.datasets import datasets
from cognee.api.v1.visualize.visualize import visualize_graph
from cognee.context_global_variables import set_database_global_context_variables
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector import get_vector_engine
@ -41,7 +42,6 @@ from cognee.tests.utils.extract_summary import extract_summary
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
from cognee.tests.utils.isolate_relationships import isolate_relationships
def create_nodes_and_edges():
@ -358,21 +358,11 @@ async def main(mock_create_structured_output: AsyncMock):
(johns_summary.id, johns_chunk.id, "made_from"),
*johns_relationships,
]
johns_edge_text_relationships = [
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in johns_entities
if isinstance(entity, Entity)
]
maries_relationships = [
(maries_chunk.id, maries_document.id, "is_part_of"),
(maries_summary.id, maries_chunk.id, "made_from"),
*maries_relationships,
]
maries_edge_text_relationships = [
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in maries_entities
if isinstance(entity, Entity)
]
expected_relationships = (
johns_relationships
@ -383,9 +373,7 @@ async def main(mock_create_structured_output: AsyncMock):
await assert_graph_edges_present(expected_relationships)
await assert_edges_vector_index_present(
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
)
await assert_edges_vector_index_present(expected_relationships)
# Delete John's data
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
@ -401,23 +389,35 @@ async def main(mock_create_structured_output: AsyncMock):
await assert_graph_edges_present(
maries_relationships + overlapping_relationships + legacy_relationships
)
await assert_edges_vector_index_present(
maries_relationships + maries_edge_text_relationships + legacy_relationships
)
await assert_edges_vector_index_present(maries_relationships + legacy_relationships)
await assert_graph_edges_not_present(johns_relationships)
strictly_johns_relationships = isolate_relationships(
johns_relationships, maries_relationships, legacy_relationships
)
johns_contains_relationships = [
(
johns_chunk.id,
entity.id,
get_contains_edge_text(entity.name, entity.description),
{
"relationship_name": get_contains_edge_text(entity.name, entity.description),
},
)
for entity in johns_entities
if isinstance(entity, Entity)
]
# We check only by relationship name and we need edges that are created by John's data and no other.
await assert_edges_vector_index_not_present(
strictly_johns_relationships + johns_edge_text_relationships
)
await assert_edges_vector_index_not_present(johns_contains_relationships)
# Delete legacy data
await datasets.delete_data(dataset_id, legacy_document.id, user) # type: ignore
graph_file_path = os.path.join(
pathlib.Path(__file__).parent,
".artifacts/graph_visualization.html",
)
await visualize_graph(graph_file_path)
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(maries_data + overlapping_entities)
await assert_nodes_vector_index_present(maries_data + overlapping_entities)
@ -427,16 +427,11 @@ async def main(mock_create_structured_output: AsyncMock):
# Assert relationships presence in the graph, vector collections and nodes table
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
await assert_edges_vector_index_present(maries_relationships)
await assert_graph_edges_not_present(johns_relationships + legacy_relationships)
strictly_legacy_relationships = isolate_relationships(
legacy_relationships, maries_relationships
)
# We check only by relationship name and we need edges that are created by legacy data and no other.
if strictly_legacy_relationships:
await assert_edges_vector_index_not_present(strictly_legacy_relationships)
# Vector index didn't change after deleting legacy data
async def create_mocked_legacy_data(user):
@ -447,8 +442,29 @@ async def create_mocked_legacy_data(user):
await graph_engine.add_nodes(legacy_nodes)
await graph_engine.add_edges(legacy_edges)
nodes_by_id = {node.id: node for node in legacy_nodes}
def format_relationship_name(relationship):
    """Return the text under which a legacy edge is indexed.

    Every relationship other than "contains" keeps its own name (the third
    item of the tuple). A "contains" relationship is expanded with
    ``get_contains_edge_text`` using the name and description of the node
    found in ``nodes_by_id`` under the relationship's second item.
    """
    relationship_name = relationship[2]
    if relationship_name != "contains":
        return relationship_name

    # The second item of the tuple is the key used to look the node up.
    target_node = nodes_by_id[relationship[1]]
    return get_contains_edge_text(target_node.name, target_node.description)
await index_data_points(legacy_nodes)
await index_graph_edges(legacy_edges)
await index_graph_edges(
[
(
edge[0],
edge[1],
format_relationship_name(edge),
{
**(edge[3] or {}),
"relationship_name": format_relationship_name(edge),
},
)
for edge in legacy_edges
] # type: ignore
)
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)

View file

@ -1,10 +1,10 @@
from uuid import UUID
from typing import List, Tuple
from typing import Dict, List, Tuple
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.engine.utils import generate_edge_id
async def assert_edges_vector_index_not_present(relationships: List[Tuple[UUID, UUID, str]]):
async def assert_edges_vector_index_not_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
vector_engine = get_vector_engine()
query_edge_ids = {

View file

@ -1,15 +1,40 @@
from uuid import UUID
from typing import List, Tuple
from typing import Dict, List, Tuple
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.engine.utils import generate_edge_id
from cognee.modules.engine.utils import generate_edge_id, generate_node_name
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str]]):
def format_relationship(relationship: Tuple[UUID, UUID, str, Dict], node: Dict):
    """Build the single-entry ``{edge id: edge text}`` mapping for a relationship.

    For a "contains" relationship the edge text comes from
    ``get_contains_edge_text``, fed with the node's name passed through
    ``generate_node_name`` and the node's description. For any other
    relationship the edge text is the relationship name itself.

    Args:
        relationship: Edge tuple; only its third item (the name) is read.
        node: Mapping with "name" and "description" keys. It is consulted
            only for "contains" relationships.

    Returns:
        A dict with one entry mapping ``str(generate_edge_id(text))`` to
        ``text``.
    """
    edge_text = relationship[2]

    if edge_text == "contains":
        node_name = generate_node_name(node["name"])
        edge_text = get_contains_edge_text(node_name, node["description"])

    return {str(generate_edge_id(edge_text)): edge_text}
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
vector_engine = get_vector_engine()
query_edge_ids = {
str(generate_edge_id(relationship[2])): relationship[2] for relationship in relationships
}
graph_engine = await get_graph_engine()
nodes, _ = await graph_engine.get_graph_data()
nodes_by_id = {str(node[0]): node[1] for node in nodes}
query_edge_ids = {}
for relationship in relationships:
query_edge_ids = {
**query_edge_ids,
**format_relationship(relationship, nodes_by_id[str(relationship[1])]),
}
vector_items = await vector_engine.retrieve(
"EdgeType_relationship_name", list(query_edge_ids.keys())

View file

@ -1,20 +0,0 @@
def isolate_relationships(source_relationships, *other_relationships):
    """Return the source relationships whose name occurs in no other group.

    A relationship is any indexable record whose third item (index 2) is its
    name. Only that name is compared; the other items are ignored.

    Args:
        source_relationships: Iterable of relationships to filter. It is
            traversed exactly once, so a one-shot iterator works as well as
            a list.
        *other_relationships: Any number of iterables of relationships. A
            name found in any of them disqualifies every source relationship
            carrying the same name.

    Returns:
        A new list holding, in their original order, the source
        relationships whose name is absent from every group in
        ``other_relationships``.
    """
    # Collect the names once so each membership test below is constant time.
    names_used_elsewhere = {
        relationship[2]
        for relationships in other_relationships
        for relationship in relationships
    }

    return [
        relationship
        for relationship in source_relationships
        if relationship[2] not in names_used_elsewhere
    ]