fix: improve delete tests
This commit is contained in:
parent
a89dad328e
commit
43459eeeac
8 changed files with 169 additions and 131 deletions
|
|
@ -4,9 +4,11 @@ from cognee.api.v1.exceptions.exceptions import DocumentSubgraphNotFoundError
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.data.models import Data
|
from cognee.modules.data.models import Data
|
||||||
|
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||||
|
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
@ -74,6 +76,21 @@ async def delete_document_subgraph(document_id: UUID, mode: str = "soft"):
|
||||||
if nodes:
|
if nodes:
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
node_id = node["id"]
|
node_id = node["id"]
|
||||||
|
|
||||||
|
if key == "chunks":
|
||||||
|
chunk_connections = await graph_db.get_connections(node_id)
|
||||||
|
deleted_node_ids.extend(
|
||||||
|
[
|
||||||
|
str(
|
||||||
|
generate_edge_id(
|
||||||
|
get_contains_edge_text(node["name"], node["description"])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for (__, edge, node) in chunk_connections
|
||||||
|
if "relationship_name: contains;" in edge["relationship_name"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
await graph_db.delete_node(node_id)
|
await graph_db.delete_node(node_id)
|
||||||
deleted_node_ids.append(node_id)
|
deleted_node_ids.append(node_id)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,45 +27,28 @@ async def upsert_edges(
|
||||||
edges_to_add = []
|
edges_to_add = []
|
||||||
|
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
|
edge_text = (
|
||||||
|
edge[3]["edge_text"] if edge[2] == "contains" and "edge_text" in edge[3] else edge[2]
|
||||||
|
)
|
||||||
|
|
||||||
edges_to_add.append(
|
edges_to_add.append(
|
||||||
{
|
{
|
||||||
"id": uuid5(
|
"id": uuid5(
|
||||||
NAMESPACE_OID,
|
NAMESPACE_OID,
|
||||||
str(user_id) + str(dataset_id) + str(edge[0]) + str(edge[2]) + str(edge[1]),
|
str(user_id) + str(dataset_id) + str(edge[0]) + str(edge_text) + str(edge[1]),
|
||||||
),
|
),
|
||||||
"slug": generate_edge_id(edge[2]),
|
"slug": generate_edge_id(edge_text),
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"data_id": data_id,
|
"data_id": data_id,
|
||||||
"dataset_id": dataset_id,
|
"dataset_id": dataset_id,
|
||||||
"source_node_id": edge[0],
|
"source_node_id": edge[0],
|
||||||
"destination_node_id": edge[1],
|
"destination_node_id": edge[1],
|
||||||
"relationship_name": edge[2],
|
"relationship_name": edge_text,
|
||||||
"label": edge[2],
|
"label": edge[2],
|
||||||
"attributes": jsonable_encoder(edge[3]),
|
"attributes": jsonable_encoder(edge[3]),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(edge) == 4 and "edge_text" in edge[3]:
|
|
||||||
edge_text = edge[3]["edge_text"]
|
|
||||||
|
|
||||||
edges_to_add.append(
|
|
||||||
{
|
|
||||||
"id": uuid5(
|
|
||||||
NAMESPACE_OID,
|
|
||||||
str(user_id) + str(dataset_id) + str(edge_text),
|
|
||||||
),
|
|
||||||
"slug": generate_edge_id(edge_text),
|
|
||||||
"user_id": user_id,
|
|
||||||
"data_id": data_id,
|
|
||||||
"dataset_id": dataset_id,
|
|
||||||
"source_node_id": edge[0],
|
|
||||||
"destination_node_id": edge[1],
|
|
||||||
"relationship_name": edge_text,
|
|
||||||
"label": edge_text,
|
|
||||||
"attributes": jsonable_encoder(edge[3]),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
upsert_statement = (
|
upsert_statement = (
|
||||||
insert(Edge).values(edges_to_add).on_conflict_do_nothing(index_elements=["id"])
|
insert(Edge).values(edges_to_add).on_conflict_do_nothing(index_elements=["id"])
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,6 @@ from cognee.tests.utils.extract_summary import extract_summary
|
||||||
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||||
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||||
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -225,29 +224,17 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
(johns_summary.id, johns_chunk.id, "made_from"),
|
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||||
*johns_relationships,
|
*johns_relationships,
|
||||||
]
|
]
|
||||||
johns_edge_text_relationships = [
|
|
||||||
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
|
||||||
for entity in johns_entities
|
|
||||||
if isinstance(entity, Entity)
|
|
||||||
]
|
|
||||||
maries_relationships = [
|
maries_relationships = [
|
||||||
(maries_chunk.id, maries_document.id, "is_part_of"),
|
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||||
(maries_summary.id, maries_chunk.id, "made_from"),
|
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||||
*maries_relationships,
|
*maries_relationships,
|
||||||
]
|
]
|
||||||
maries_edge_text_relationships = [
|
|
||||||
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
|
||||||
for entity in maries_entities
|
|
||||||
if isinstance(entity, Entity)
|
|
||||||
]
|
|
||||||
|
|
||||||
expected_relationships = johns_relationships + maries_relationships + overlapping_relationships
|
expected_relationships = johns_relationships + maries_relationships + overlapping_relationships
|
||||||
|
|
||||||
await assert_graph_edges_present(expected_relationships)
|
await assert_graph_edges_present(expected_relationships)
|
||||||
|
|
||||||
await assert_edges_vector_index_present(
|
await assert_edges_vector_index_present(expected_relationships)
|
||||||
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delete John's data from cognee
|
# Delete John's data from cognee
|
||||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||||
|
|
@ -261,15 +248,24 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
|
|
||||||
# Assert relationships presence in the graph, vector collections and nodes table
|
# Assert relationships presence in the graph, vector collections and nodes table
|
||||||
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
|
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
|
||||||
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
|
await assert_edges_vector_index_present(maries_relationships)
|
||||||
|
|
||||||
await assert_graph_edges_not_present(johns_relationships)
|
await assert_graph_edges_not_present(johns_relationships)
|
||||||
|
|
||||||
strictly_johns_relationships = isolate_relationships(johns_relationships, maries_relationships)
|
johns_contains_relationships = [
|
||||||
|
(
|
||||||
|
johns_chunk.id,
|
||||||
|
entity.id,
|
||||||
|
get_contains_edge_text(entity.name, entity.description),
|
||||||
|
{
|
||||||
|
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for entity in johns_entities
|
||||||
|
if isinstance(entity, Entity)
|
||||||
|
]
|
||||||
# We check only by relationship name and we need edges that are created by John's data and no other.
|
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||||
await assert_edges_vector_index_not_present(
|
await assert_edges_vector_index_not_present(johns_contains_relationships)
|
||||||
strictly_johns_relationships + johns_edge_text_relationships
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delete Marie's data from cognee
|
# Delete Marie's data from cognee
|
||||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,6 @@ from cognee.tests.utils.extract_summary import extract_summary
|
||||||
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||||
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||||
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
|
||||||
|
|
||||||
|
|
||||||
def create_nodes_and_edges():
|
def create_nodes_and_edges():
|
||||||
|
|
@ -356,21 +355,11 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
(johns_summary.id, johns_chunk.id, "made_from"),
|
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||||
*johns_relationships,
|
*johns_relationships,
|
||||||
]
|
]
|
||||||
johns_edge_text_relationships = [
|
|
||||||
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
|
||||||
for entity in johns_entities
|
|
||||||
if isinstance(entity, Entity)
|
|
||||||
]
|
|
||||||
maries_relationships = [
|
maries_relationships = [
|
||||||
(maries_chunk.id, maries_document.id, "is_part_of"),
|
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||||
(maries_summary.id, maries_chunk.id, "made_from"),
|
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||||
*maries_relationships,
|
*maries_relationships,
|
||||||
]
|
]
|
||||||
maries_edge_text_relationships = [
|
|
||||||
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
|
||||||
for entity in maries_entities
|
|
||||||
if isinstance(entity, Entity)
|
|
||||||
]
|
|
||||||
|
|
||||||
expected_relationships = (
|
expected_relationships = (
|
||||||
johns_relationships
|
johns_relationships
|
||||||
|
|
@ -381,9 +370,7 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
|
|
||||||
await assert_graph_edges_present(expected_relationships)
|
await assert_graph_edges_present(expected_relationships)
|
||||||
|
|
||||||
await assert_edges_vector_index_present(
|
await assert_edges_vector_index_present(expected_relationships)
|
||||||
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delete John's data
|
# Delete John's data
|
||||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||||
|
|
@ -399,19 +386,24 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
await assert_graph_edges_present(
|
await assert_graph_edges_present(
|
||||||
maries_relationships + overlapping_relationships + legacy_relationships
|
maries_relationships + overlapping_relationships + legacy_relationships
|
||||||
)
|
)
|
||||||
await assert_edges_vector_index_present(
|
await assert_edges_vector_index_present(maries_relationships + legacy_relationships)
|
||||||
maries_relationships + maries_edge_text_relationships + legacy_relationships
|
|
||||||
)
|
|
||||||
|
|
||||||
await assert_graph_edges_not_present(johns_relationships)
|
await assert_graph_edges_not_present(johns_relationships)
|
||||||
|
|
||||||
strictly_johns_relationships = isolate_relationships(
|
johns_contains_relationships = [
|
||||||
johns_relationships, maries_relationships, legacy_relationships
|
(
|
||||||
)
|
johns_chunk.id,
|
||||||
|
entity.id,
|
||||||
|
get_contains_edge_text(entity.name, entity.description),
|
||||||
|
{
|
||||||
|
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for entity in johns_entities
|
||||||
|
if isinstance(entity, Entity)
|
||||||
|
]
|
||||||
# We check only by relationship name and we need edges that are created by John's data and no other.
|
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||||
await assert_edges_vector_index_not_present(
|
await assert_edges_vector_index_not_present(johns_contains_relationships)
|
||||||
strictly_johns_relationships + johns_edge_text_relationships
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delete Marie's data
|
# Delete Marie's data
|
||||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||||
|
|
@ -431,12 +423,20 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
johns_relationships + maries_relationships + overlapping_relationships
|
johns_relationships + maries_relationships + overlapping_relationships
|
||||||
)
|
)
|
||||||
|
|
||||||
strictly_maries_relationships = isolate_relationships(
|
maries_contains_relationships = [
|
||||||
maries_relationships, legacy_relationships
|
(
|
||||||
)
|
maries_chunk.id,
|
||||||
|
entity.id,
|
||||||
|
get_contains_edge_text(entity.name, entity.description),
|
||||||
|
{
|
||||||
|
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for entity in maries_entities
|
||||||
|
if isinstance(entity, Entity)
|
||||||
|
]
|
||||||
# We check only by relationship name and we need edges that are created by legacy data and no other.
|
# We check only by relationship name and we need edges that are created by legacy data and no other.
|
||||||
if strictly_maries_relationships:
|
await assert_edges_vector_index_not_present(maries_contains_relationships)
|
||||||
await assert_edges_vector_index_not_present(strictly_maries_relationships)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_mocked_legacy_data(user):
|
async def create_mocked_legacy_data(user):
|
||||||
|
|
@ -447,8 +447,29 @@ async def create_mocked_legacy_data(user):
|
||||||
await graph_engine.add_nodes(legacy_nodes)
|
await graph_engine.add_nodes(legacy_nodes)
|
||||||
await graph_engine.add_edges(legacy_edges)
|
await graph_engine.add_edges(legacy_edges)
|
||||||
|
|
||||||
|
nodes_by_id = {node.id: node for node in legacy_nodes}
|
||||||
|
|
||||||
|
def format_relationship_name(relationship):
|
||||||
|
if relationship[2] == "contains":
|
||||||
|
node = nodes_by_id[relationship[1]]
|
||||||
|
return get_contains_edge_text(node.name, node.description)
|
||||||
|
return relationship[2]
|
||||||
|
|
||||||
await index_data_points(legacy_nodes)
|
await index_data_points(legacy_nodes)
|
||||||
await index_graph_edges(legacy_edges)
|
await index_graph_edges(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
edge[0],
|
||||||
|
edge[1],
|
||||||
|
format_relationship_name(edge),
|
||||||
|
{
|
||||||
|
**(edge[3] or {}),
|
||||||
|
"relationship_name": format_relationship_name(edge),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for edge in legacy_edges
|
||||||
|
] # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
|
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.api.v1.datasets import datasets
|
from cognee.api.v1.datasets import datasets
|
||||||
|
from cognee.api.v1.visualize.visualize import visualize_graph
|
||||||
from cognee.context_global_variables import set_database_global_context_variables
|
from cognee.context_global_variables import set_database_global_context_variables
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
|
@ -41,7 +42,6 @@ from cognee.tests.utils.extract_summary import extract_summary
|
||||||
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||||
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||||
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
|
||||||
|
|
||||||
|
|
||||||
def create_nodes_and_edges():
|
def create_nodes_and_edges():
|
||||||
|
|
@ -358,21 +358,11 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
(johns_summary.id, johns_chunk.id, "made_from"),
|
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||||
*johns_relationships,
|
*johns_relationships,
|
||||||
]
|
]
|
||||||
johns_edge_text_relationships = [
|
|
||||||
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
|
||||||
for entity in johns_entities
|
|
||||||
if isinstance(entity, Entity)
|
|
||||||
]
|
|
||||||
maries_relationships = [
|
maries_relationships = [
|
||||||
(maries_chunk.id, maries_document.id, "is_part_of"),
|
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||||
(maries_summary.id, maries_chunk.id, "made_from"),
|
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||||
*maries_relationships,
|
*maries_relationships,
|
||||||
]
|
]
|
||||||
maries_edge_text_relationships = [
|
|
||||||
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
|
||||||
for entity in maries_entities
|
|
||||||
if isinstance(entity, Entity)
|
|
||||||
]
|
|
||||||
|
|
||||||
expected_relationships = (
|
expected_relationships = (
|
||||||
johns_relationships
|
johns_relationships
|
||||||
|
|
@ -383,9 +373,7 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
|
|
||||||
await assert_graph_edges_present(expected_relationships)
|
await assert_graph_edges_present(expected_relationships)
|
||||||
|
|
||||||
await assert_edges_vector_index_present(
|
await assert_edges_vector_index_present(expected_relationships)
|
||||||
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delete John's data
|
# Delete John's data
|
||||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||||
|
|
@ -401,23 +389,35 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
await assert_graph_edges_present(
|
await assert_graph_edges_present(
|
||||||
maries_relationships + overlapping_relationships + legacy_relationships
|
maries_relationships + overlapping_relationships + legacy_relationships
|
||||||
)
|
)
|
||||||
await assert_edges_vector_index_present(
|
await assert_edges_vector_index_present(maries_relationships + legacy_relationships)
|
||||||
maries_relationships + maries_edge_text_relationships + legacy_relationships
|
|
||||||
)
|
|
||||||
|
|
||||||
await assert_graph_edges_not_present(johns_relationships)
|
await assert_graph_edges_not_present(johns_relationships)
|
||||||
|
|
||||||
strictly_johns_relationships = isolate_relationships(
|
johns_contains_relationships = [
|
||||||
johns_relationships, maries_relationships, legacy_relationships
|
(
|
||||||
)
|
johns_chunk.id,
|
||||||
|
entity.id,
|
||||||
|
get_contains_edge_text(entity.name, entity.description),
|
||||||
|
{
|
||||||
|
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for entity in johns_entities
|
||||||
|
if isinstance(entity, Entity)
|
||||||
|
]
|
||||||
|
|
||||||
# We check only by relationship name and we need edges that are created by John's data and no other.
|
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||||
await assert_edges_vector_index_not_present(
|
await assert_edges_vector_index_not_present(johns_contains_relationships)
|
||||||
strictly_johns_relationships + johns_edge_text_relationships
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delete legacy data
|
# Delete legacy data
|
||||||
await datasets.delete_data(dataset_id, legacy_document.id, user) # type: ignore
|
await datasets.delete_data(dataset_id, legacy_document.id, user) # type: ignore
|
||||||
|
|
||||||
|
graph_file_path = os.path.join(
|
||||||
|
pathlib.Path(__file__).parent,
|
||||||
|
".artifacts/graph_visualization.html",
|
||||||
|
)
|
||||||
|
await visualize_graph(graph_file_path)
|
||||||
|
|
||||||
# Assert data points presence in the graph, vector collections and nodes table
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
await assert_graph_nodes_present(maries_data + overlapping_entities)
|
await assert_graph_nodes_present(maries_data + overlapping_entities)
|
||||||
await assert_nodes_vector_index_present(maries_data + overlapping_entities)
|
await assert_nodes_vector_index_present(maries_data + overlapping_entities)
|
||||||
|
|
@ -427,16 +427,11 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
|
|
||||||
# Assert relationships presence in the graph, vector collections and nodes table
|
# Assert relationships presence in the graph, vector collections and nodes table
|
||||||
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
|
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
|
||||||
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
|
await assert_edges_vector_index_present(maries_relationships)
|
||||||
|
|
||||||
await assert_graph_edges_not_present(johns_relationships + legacy_relationships)
|
await assert_graph_edges_not_present(johns_relationships + legacy_relationships)
|
||||||
|
|
||||||
strictly_legacy_relationships = isolate_relationships(
|
# Vector index didn't change after deleting legacy data
|
||||||
legacy_relationships, maries_relationships
|
|
||||||
)
|
|
||||||
# We check only by relationship name and we need edges that are created by legacy data and no other.
|
|
||||||
if strictly_legacy_relationships:
|
|
||||||
await assert_edges_vector_index_not_present(strictly_legacy_relationships)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_mocked_legacy_data(user):
|
async def create_mocked_legacy_data(user):
|
||||||
|
|
@ -447,8 +442,29 @@ async def create_mocked_legacy_data(user):
|
||||||
await graph_engine.add_nodes(legacy_nodes)
|
await graph_engine.add_nodes(legacy_nodes)
|
||||||
await graph_engine.add_edges(legacy_edges)
|
await graph_engine.add_edges(legacy_edges)
|
||||||
|
|
||||||
|
nodes_by_id = {node.id: node for node in legacy_nodes}
|
||||||
|
|
||||||
|
def format_relationship_name(relationship):
|
||||||
|
if relationship[2] == "contains":
|
||||||
|
node = nodes_by_id[relationship[1]]
|
||||||
|
return get_contains_edge_text(node.name, node.description)
|
||||||
|
return relationship[2]
|
||||||
|
|
||||||
await index_data_points(legacy_nodes)
|
await index_data_points(legacy_nodes)
|
||||||
await index_graph_edges(legacy_edges)
|
await index_graph_edges(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
edge[0],
|
||||||
|
edge[1],
|
||||||
|
format_relationship_name(edge),
|
||||||
|
{
|
||||||
|
**(edge[3] or {}),
|
||||||
|
"relationship_name": format_relationship_name(edge),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for edge in legacy_edges
|
||||||
|
] # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
|
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.engine.utils import generate_edge_id
|
from cognee.modules.engine.utils import generate_edge_id
|
||||||
|
|
||||||
|
|
||||||
async def assert_edges_vector_index_not_present(relationships: List[Tuple[UUID, UUID, str]]):
|
async def assert_edges_vector_index_not_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
query_edge_ids = {
|
query_edge_ids = {
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,40 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.modules.engine.utils import generate_edge_id
|
from cognee.modules.engine.utils import generate_edge_id, generate_node_name
|
||||||
|
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||||
|
|
||||||
|
|
||||||
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str]]):
|
def format_relationship(relationship: Tuple[UUID, UUID, str, Dict], node: Dict):
|
||||||
|
if relationship[2] == "contains":
|
||||||
|
relationship_name = get_contains_edge_text(
|
||||||
|
generate_node_name(node["name"]),
|
||||||
|
node["description"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
str(generate_edge_id(relationship_name)): relationship_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {str(generate_edge_id(relationship[2])): relationship[2]}
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
query_edge_ids = {
|
graph_engine = await get_graph_engine()
|
||||||
str(generate_edge_id(relationship[2])): relationship[2] for relationship in relationships
|
nodes, _ = await graph_engine.get_graph_data()
|
||||||
}
|
|
||||||
|
nodes_by_id = {str(node[0]): node[1] for node in nodes}
|
||||||
|
|
||||||
|
query_edge_ids = {}
|
||||||
|
for relationship in relationships:
|
||||||
|
query_edge_ids = {
|
||||||
|
**query_edge_ids,
|
||||||
|
**format_relationship(relationship, nodes_by_id[str(relationship[1])]),
|
||||||
|
}
|
||||||
|
|
||||||
vector_items = await vector_engine.retrieve(
|
vector_items = await vector_engine.retrieve(
|
||||||
"EdgeType_relationship_name", list(query_edge_ids.keys())
|
"EdgeType_relationship_name", list(query_edge_ids.keys())
|
||||||
|
|
|
||||||
|
|
@ -1,20 +0,0 @@
|
||||||
def isolate_relationships(source_relationships, *other_relationships):
|
|
||||||
final_relationships = []
|
|
||||||
cache = {relationship[2]: 1 for relationship in source_relationships}
|
|
||||||
duplicated_relationships = {}
|
|
||||||
|
|
||||||
for relationships in other_relationships:
|
|
||||||
for relationship in relationships:
|
|
||||||
if relationship[2] not in cache:
|
|
||||||
cache[relationship[2]] = 0
|
|
||||||
|
|
||||||
cache[relationship[2]] += 1
|
|
||||||
|
|
||||||
if cache[relationship[2]] == 2:
|
|
||||||
duplicated_relationships[relationship[2]] = True
|
|
||||||
|
|
||||||
for relationship in source_relationships:
|
|
||||||
if relationship[2] not in duplicated_relationships:
|
|
||||||
final_relationships.append(relationship)
|
|
||||||
|
|
||||||
return final_relationships
|
|
||||||
Loading…
Add table
Reference in a new issue