fix: improve delete tests
This commit is contained in:
parent
a89dad328e
commit
43459eeeac
8 changed files with 169 additions and 131 deletions
|
|
@ -4,9 +4,11 @@ from cognee.api.v1.exceptions.exceptions import DocumentSubgraphNotFoundError
|
|||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -74,6 +76,21 @@ async def delete_document_subgraph(document_id: UUID, mode: str = "soft"):
|
|||
if nodes:
|
||||
for node in nodes:
|
||||
node_id = node["id"]
|
||||
|
||||
if key == "chunks":
|
||||
chunk_connections = await graph_db.get_connections(node_id)
|
||||
deleted_node_ids.extend(
|
||||
[
|
||||
str(
|
||||
generate_edge_id(
|
||||
get_contains_edge_text(node["name"], node["description"])
|
||||
)
|
||||
)
|
||||
for (__, edge, node) in chunk_connections
|
||||
if "relationship_name: contains;" in edge["relationship_name"]
|
||||
]
|
||||
)
|
||||
|
||||
await graph_db.delete_node(node_id)
|
||||
deleted_node_ids.append(node_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,45 +27,28 @@ async def upsert_edges(
|
|||
edges_to_add = []
|
||||
|
||||
for edge in edges:
|
||||
edge_text = (
|
||||
edge[3]["edge_text"] if edge[2] == "contains" and "edge_text" in edge[3] else edge[2]
|
||||
)
|
||||
|
||||
edges_to_add.append(
|
||||
{
|
||||
"id": uuid5(
|
||||
NAMESPACE_OID,
|
||||
str(user_id) + str(dataset_id) + str(edge[0]) + str(edge[2]) + str(edge[1]),
|
||||
str(user_id) + str(dataset_id) + str(edge[0]) + str(edge_text) + str(edge[1]),
|
||||
),
|
||||
"slug": generate_edge_id(edge[2]),
|
||||
"slug": generate_edge_id(edge_text),
|
||||
"user_id": user_id,
|
||||
"data_id": data_id,
|
||||
"dataset_id": dataset_id,
|
||||
"source_node_id": edge[0],
|
||||
"destination_node_id": edge[1],
|
||||
"relationship_name": edge[2],
|
||||
"relationship_name": edge_text,
|
||||
"label": edge[2],
|
||||
"attributes": jsonable_encoder(edge[3]),
|
||||
}
|
||||
)
|
||||
|
||||
if len(edge) == 4 and "edge_text" in edge[3]:
|
||||
edge_text = edge[3]["edge_text"]
|
||||
|
||||
edges_to_add.append(
|
||||
{
|
||||
"id": uuid5(
|
||||
NAMESPACE_OID,
|
||||
str(user_id) + str(dataset_id) + str(edge_text),
|
||||
),
|
||||
"slug": generate_edge_id(edge_text),
|
||||
"user_id": user_id,
|
||||
"data_id": data_id,
|
||||
"dataset_id": dataset_id,
|
||||
"source_node_id": edge[0],
|
||||
"destination_node_id": edge[1],
|
||||
"relationship_name": edge_text,
|
||||
"label": edge_text,
|
||||
"attributes": jsonable_encoder(edge[3]),
|
||||
}
|
||||
)
|
||||
|
||||
upsert_statement = (
|
||||
insert(Edge).values(edges_to_add).on_conflict_do_nothing(index_elements=["id"])
|
||||
)
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ from cognee.tests.utils.extract_summary import extract_summary
|
|||
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -225,29 +224,17 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||
*johns_relationships,
|
||||
]
|
||||
johns_edge_text_relationships = [
|
||||
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in johns_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
maries_relationships = [
|
||||
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||
*maries_relationships,
|
||||
]
|
||||
maries_edge_text_relationships = [
|
||||
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in maries_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
|
||||
expected_relationships = johns_relationships + maries_relationships + overlapping_relationships
|
||||
|
||||
await assert_graph_edges_present(expected_relationships)
|
||||
|
||||
await assert_edges_vector_index_present(
|
||||
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
||||
)
|
||||
await assert_edges_vector_index_present(expected_relationships)
|
||||
|
||||
# Delete John's data from cognee
|
||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||
|
|
@ -261,15 +248,24 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
|
||||
# Assert relationships presence in the graph, vector collections and nodes table
|
||||
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
|
||||
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
|
||||
await assert_edges_vector_index_present(maries_relationships)
|
||||
|
||||
await assert_graph_edges_not_present(johns_relationships)
|
||||
|
||||
strictly_johns_relationships = isolate_relationships(johns_relationships, maries_relationships)
|
||||
johns_contains_relationships = [
|
||||
(
|
||||
johns_chunk.id,
|
||||
entity.id,
|
||||
get_contains_edge_text(entity.name, entity.description),
|
||||
{
|
||||
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
||||
},
|
||||
)
|
||||
for entity in johns_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||
await assert_edges_vector_index_not_present(
|
||||
strictly_johns_relationships + johns_edge_text_relationships
|
||||
)
|
||||
await assert_edges_vector_index_not_present(johns_contains_relationships)
|
||||
|
||||
# Delete Marie's data from cognee
|
||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ from cognee.tests.utils.extract_summary import extract_summary
|
|||
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
||||
|
||||
|
||||
def create_nodes_and_edges():
|
||||
|
|
@ -356,21 +355,11 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||
*johns_relationships,
|
||||
]
|
||||
johns_edge_text_relationships = [
|
||||
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in johns_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
maries_relationships = [
|
||||
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||
*maries_relationships,
|
||||
]
|
||||
maries_edge_text_relationships = [
|
||||
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in maries_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
|
||||
expected_relationships = (
|
||||
johns_relationships
|
||||
|
|
@ -381,9 +370,7 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
|
||||
await assert_graph_edges_present(expected_relationships)
|
||||
|
||||
await assert_edges_vector_index_present(
|
||||
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
||||
)
|
||||
await assert_edges_vector_index_present(expected_relationships)
|
||||
|
||||
# Delete John's data
|
||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||
|
|
@ -399,19 +386,24 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
await assert_graph_edges_present(
|
||||
maries_relationships + overlapping_relationships + legacy_relationships
|
||||
)
|
||||
await assert_edges_vector_index_present(
|
||||
maries_relationships + maries_edge_text_relationships + legacy_relationships
|
||||
)
|
||||
await assert_edges_vector_index_present(maries_relationships + legacy_relationships)
|
||||
|
||||
await assert_graph_edges_not_present(johns_relationships)
|
||||
|
||||
strictly_johns_relationships = isolate_relationships(
|
||||
johns_relationships, maries_relationships, legacy_relationships
|
||||
)
|
||||
johns_contains_relationships = [
|
||||
(
|
||||
johns_chunk.id,
|
||||
entity.id,
|
||||
get_contains_edge_text(entity.name, entity.description),
|
||||
{
|
||||
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
||||
},
|
||||
)
|
||||
for entity in johns_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||
await assert_edges_vector_index_not_present(
|
||||
strictly_johns_relationships + johns_edge_text_relationships
|
||||
)
|
||||
await assert_edges_vector_index_not_present(johns_contains_relationships)
|
||||
|
||||
# Delete Marie's data
|
||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||
|
|
@ -431,12 +423,20 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
johns_relationships + maries_relationships + overlapping_relationships
|
||||
)
|
||||
|
||||
strictly_maries_relationships = isolate_relationships(
|
||||
maries_relationships, legacy_relationships
|
||||
)
|
||||
maries_contains_relationships = [
|
||||
(
|
||||
maries_chunk.id,
|
||||
entity.id,
|
||||
get_contains_edge_text(entity.name, entity.description),
|
||||
{
|
||||
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
||||
},
|
||||
)
|
||||
for entity in maries_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
# We check only by relationship name and we need edges that are created by Marie's data and no other.
|
||||
if strictly_maries_relationships:
|
||||
await assert_edges_vector_index_not_present(strictly_maries_relationships)
|
||||
await assert_edges_vector_index_not_present(maries_contains_relationships)
|
||||
|
||||
|
||||
async def create_mocked_legacy_data(user):
|
||||
|
|
@ -447,8 +447,29 @@ async def create_mocked_legacy_data(user):
|
|||
await graph_engine.add_nodes(legacy_nodes)
|
||||
await graph_engine.add_edges(legacy_edges)
|
||||
|
||||
nodes_by_id = {node.id: node for node in legacy_nodes}
|
||||
|
||||
def format_relationship_name(relationship):
|
||||
if relationship[2] == "contains":
|
||||
node = nodes_by_id[relationship[1]]
|
||||
return get_contains_edge_text(node.name, node.description)
|
||||
return relationship[2]
|
||||
|
||||
await index_data_points(legacy_nodes)
|
||||
await index_graph_edges(legacy_edges)
|
||||
await index_graph_edges(
|
||||
[
|
||||
(
|
||||
edge[0],
|
||||
edge[1],
|
||||
format_relationship_name(edge),
|
||||
{
|
||||
**(edge[3] or {}),
|
||||
"relationship_name": format_relationship_name(edge),
|
||||
},
|
||||
)
|
||||
for edge in legacy_edges
|
||||
] # type: ignore
|
||||
)
|
||||
|
||||
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import cognee
|
||||
from cognee.api.v1.datasets import datasets
|
||||
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
|
|
@ -41,7 +42,6 @@ from cognee.tests.utils.extract_summary import extract_summary
|
|||
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
||||
|
||||
|
||||
def create_nodes_and_edges():
|
||||
|
|
@ -358,21 +358,11 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||
*johns_relationships,
|
||||
]
|
||||
johns_edge_text_relationships = [
|
||||
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in johns_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
maries_relationships = [
|
||||
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||
*maries_relationships,
|
||||
]
|
||||
maries_edge_text_relationships = [
|
||||
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in maries_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
|
||||
expected_relationships = (
|
||||
johns_relationships
|
||||
|
|
@ -383,9 +373,7 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
|
||||
await assert_graph_edges_present(expected_relationships)
|
||||
|
||||
await assert_edges_vector_index_present(
|
||||
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
||||
)
|
||||
await assert_edges_vector_index_present(expected_relationships)
|
||||
|
||||
# Delete John's data
|
||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||
|
|
@ -401,23 +389,35 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
await assert_graph_edges_present(
|
||||
maries_relationships + overlapping_relationships + legacy_relationships
|
||||
)
|
||||
await assert_edges_vector_index_present(
|
||||
maries_relationships + maries_edge_text_relationships + legacy_relationships
|
||||
)
|
||||
await assert_edges_vector_index_present(maries_relationships + legacy_relationships)
|
||||
|
||||
await assert_graph_edges_not_present(johns_relationships)
|
||||
|
||||
strictly_johns_relationships = isolate_relationships(
|
||||
johns_relationships, maries_relationships, legacy_relationships
|
||||
)
|
||||
johns_contains_relationships = [
|
||||
(
|
||||
johns_chunk.id,
|
||||
entity.id,
|
||||
get_contains_edge_text(entity.name, entity.description),
|
||||
{
|
||||
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
||||
},
|
||||
)
|
||||
for entity in johns_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
|
||||
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||
await assert_edges_vector_index_not_present(
|
||||
strictly_johns_relationships + johns_edge_text_relationships
|
||||
)
|
||||
await assert_edges_vector_index_not_present(johns_contains_relationships)
|
||||
|
||||
# Delete legacy data
|
||||
await datasets.delete_data(dataset_id, legacy_document.id, user) # type: ignore
|
||||
|
||||
graph_file_path = os.path.join(
|
||||
pathlib.Path(__file__).parent,
|
||||
".artifacts/graph_visualization.html",
|
||||
)
|
||||
await visualize_graph(graph_file_path)
|
||||
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(maries_data + overlapping_entities)
|
||||
await assert_nodes_vector_index_present(maries_data + overlapping_entities)
|
||||
|
|
@ -427,16 +427,11 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
|
||||
# Assert relationships presence in the graph, vector collections and nodes table
|
||||
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
|
||||
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
|
||||
await assert_edges_vector_index_present(maries_relationships)
|
||||
|
||||
await assert_graph_edges_not_present(johns_relationships + legacy_relationships)
|
||||
|
||||
strictly_legacy_relationships = isolate_relationships(
|
||||
legacy_relationships, maries_relationships
|
||||
)
|
||||
# We check only by relationship name and we need edges that are created by legacy data and no other.
|
||||
if strictly_legacy_relationships:
|
||||
await assert_edges_vector_index_not_present(strictly_legacy_relationships)
|
||||
# Vector index didn't change after deleting legacy data
|
||||
|
||||
|
||||
async def create_mocked_legacy_data(user):
|
||||
|
|
@ -447,8 +442,29 @@ async def create_mocked_legacy_data(user):
|
|||
await graph_engine.add_nodes(legacy_nodes)
|
||||
await graph_engine.add_edges(legacy_edges)
|
||||
|
||||
nodes_by_id = {node.id: node for node in legacy_nodes}
|
||||
|
||||
def format_relationship_name(relationship):
|
||||
if relationship[2] == "contains":
|
||||
node = nodes_by_id[relationship[1]]
|
||||
return get_contains_edge_text(node.name, node.description)
|
||||
return relationship[2]
|
||||
|
||||
await index_data_points(legacy_nodes)
|
||||
await index_graph_edges(legacy_edges)
|
||||
await index_graph_edges(
|
||||
[
|
||||
(
|
||||
edge[0],
|
||||
edge[1],
|
||||
format_relationship_name(edge),
|
||||
{
|
||||
**(edge[3] or {}),
|
||||
"relationship_name": format_relationship_name(edge),
|
||||
},
|
||||
)
|
||||
for edge in legacy_edges
|
||||
] # type: ignore
|
||||
)
|
||||
|
||||
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from uuid import UUID
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.engine.utils import generate_edge_id
|
||||
|
||||
|
||||
async def assert_edges_vector_index_not_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||
async def assert_edges_vector_index_not_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
query_edge_ids = {
|
||||
|
|
|
|||
|
|
@ -1,15 +1,40 @@
|
|||
from uuid import UUID
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.engine.utils import generate_edge_id
|
||||
from cognee.modules.engine.utils import generate_edge_id, generate_node_name
|
||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||
|
||||
|
||||
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||
def format_relationship(relationship: Tuple[UUID, UUID, str, Dict], node: Dict):
    """Build a single-entry ``{edge_id: relationship_name}`` mapping.

    The relationship name is taken from ``relationship[2]``. When that name
    is ``"contains"`` it is replaced by the text produced by
    ``get_contains_edge_text`` from the node's ``"name"`` (passed through
    ``generate_node_name``) and its ``"description"``.

    Args:
        relationship: Edge tuple; only index 2 (the relationship name) is read.
        node: Mapping for the node; ``"name"`` and ``"description"`` are read
            only for ``"contains"`` relationships.

    Returns:
        A dict with one item: ``str(generate_edge_id(name))`` mapped to the
        (possibly rewritten) relationship name.
    """
    relationship_name = relationship[2]

    if relationship_name == "contains":
        # "contains" edges are keyed by a text derived from the node itself
        # rather than by the bare relationship name.
        relationship_name = get_contains_edge_text(
            generate_node_name(node["name"]),
            node["description"],
        )

    edge_key = str(generate_edge_id(relationship_name))
    return {edge_key: relationship_name}
|
||||
|
||||
|
||||
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
query_edge_ids = {
|
||||
str(generate_edge_id(relationship[2])): relationship[2] for relationship in relationships
|
||||
}
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, _ = await graph_engine.get_graph_data()
|
||||
|
||||
nodes_by_id = {str(node[0]): node[1] for node in nodes}
|
||||
|
||||
query_edge_ids = {}
|
||||
for relationship in relationships:
|
||||
query_edge_ids = {
|
||||
**query_edge_ids,
|
||||
**format_relationship(relationship, nodes_by_id[str(relationship[1])]),
|
||||
}
|
||||
|
||||
vector_items = await vector_engine.retrieve(
|
||||
"EdgeType_relationship_name", list(query_edge_ids.keys())
|
||||
|
|
|
|||
|
|
@ -1,20 +0,0 @@
|
|||
def isolate_relationships(source_relationships, *other_relationships):
    """Return the source relationships whose name appears in no other group.

    Relationships are compared only by their name, i.e. the element at
    index 2 of each relationship tuple; node ids are ignored.

    Args:
        source_relationships: Iterable of relationship tuples to filter.
        *other_relationships: Any number of iterables of relationship tuples
            whose names must be excluded from the result.

    Returns:
        A new list with the items of ``source_relationships``, in their
        original order, whose name is not used by any relationship in
        ``other_relationships``.
    """
    # Collect every name used outside the source once, so each membership
    # check below is a constant-time set lookup.
    names_used_elsewhere = {
        relationship[2]
        for relationships in other_relationships
        for relationship in relationships
    }

    return [
        relationship
        for relationship in source_relationships
        if relationship[2] not in names_used_elsewhere
    ]
|
||||
Loading…
Add table
Reference in a new issue