fix: add detailed tests for delete
This commit is contained in:
parent
77b3e731d8
commit
a89dad328e
20 changed files with 1011 additions and 266 deletions
|
|
@ -6,7 +6,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
from cognee.base_config import get_base_config
|
||||||
from cognee.modules.data.methods import create_dataset
|
from cognee.modules.data.methods import create_authorized_dataset
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||||
|
|
@ -66,7 +66,7 @@ async def get_or_create_dataset_database(
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
# Create dataset if it doesn't exist
|
# Create dataset if it doesn't exist
|
||||||
if isinstance(dataset, str):
|
if isinstance(dataset, str):
|
||||||
dataset = await create_dataset(dataset, user, session)
|
dataset = await create_authorized_dataset(dataset, user)
|
||||||
|
|
||||||
# Try to fetch an existing row first
|
# Try to fetch an existing row first
|
||||||
stmt = select(DatasetDatabase).where(
|
stmt = select(DatasetDatabase).where(
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from cognee.context_global_variables import set_database_global_context_variable
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.data.methods import create_authorized_dataset
|
from cognee.modules.data.methods import create_authorized_dataset
|
||||||
from cognee.modules.engine.operations.setup import setup
|
from cognee.modules.engine.operations.setup import setup
|
||||||
|
from cognee.modules.engine.utils import generate_node_id
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
@ -53,11 +54,11 @@ async def main():
|
||||||
works_for: List[Organization]
|
works_for: List[Organization]
|
||||||
metadata: dict = {"index_fields": ["name"]}
|
metadata: dict = {"index_fields": ["name"]}
|
||||||
|
|
||||||
companyA = ForProfit(name="Company A")
|
companyA = ForProfit(id=generate_node_id("Company A"), name="Company A")
|
||||||
companyB = NonProfit(name="Company B")
|
companyB = NonProfit(id=generate_node_id("Company B"), name="Company B")
|
||||||
|
|
||||||
person1 = Person(name="John", works_for=[companyA, companyB])
|
person1 = Person(id=generate_node_id("John"), name="John", works_for=[companyA, companyB])
|
||||||
person2 = Person(name="Jane", works_for=[companyB])
|
person2 = Person(id=generate_node_id("Jane"), name="Jane", works_for=[companyB])
|
||||||
|
|
||||||
user: User = await get_default_user() # type: ignore
|
user: User = await get_default_user() # type: ignore
|
||||||
|
|
||||||
|
|
@ -93,15 +94,59 @@ async def main():
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
nodes, edges = await graph_engine.get_graph_data()
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
|
# Initial check
|
||||||
assert len(nodes) == 4 and len(edges) == 3, (
|
assert len(nodes) == 4 and len(edges) == 3, (
|
||||||
"Nodes and edges are not correctly added to the graph."
|
"Nodes and edges are not correctly added to the graph."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
nodes_by_id = {node[0]: node[1] for node in nodes}
|
||||||
|
|
||||||
|
assert str(generate_node_id("John")) in nodes_by_id, "John node not present in the graph."
|
||||||
|
assert str(generate_node_id("Jane")) in nodes_by_id, "Jane node not present in the graph."
|
||||||
|
assert str(generate_node_id("Company A")) in nodes_by_id, (
|
||||||
|
"Company A node not present in the graph."
|
||||||
|
)
|
||||||
|
assert str(generate_node_id("Company B")) in nodes_by_id, (
|
||||||
|
"Company B node not present in the graph."
|
||||||
|
)
|
||||||
|
|
||||||
|
edges_by_ids = {f"{edge[0]}_{edge[2]}_{edge[1]}": edge[3] for edge in edges}
|
||||||
|
|
||||||
|
assert (
|
||||||
|
f"{str(generate_node_id('John'))}_works_for_{str(generate_node_id('Company A'))}"
|
||||||
|
in edges_by_ids
|
||||||
|
), "Edge between John and Company A not present in the graph."
|
||||||
|
assert (
|
||||||
|
f"{str(generate_node_id('John'))}_works_for_{str(generate_node_id('Company B'))}"
|
||||||
|
in edges_by_ids
|
||||||
|
), "Edge between John and Company A not present in the graph."
|
||||||
|
assert (
|
||||||
|
f"{str(generate_node_id('Jane'))}_works_for_{str(generate_node_id('Company B'))}"
|
||||||
|
in edges_by_ids
|
||||||
|
), "Edge between John and Company A not present in the graph."
|
||||||
|
|
||||||
|
# Second data deletion
|
||||||
await datasets.delete_data(dataset.id, data1.id, user)
|
await datasets.delete_data(dataset.id, data1.id, user)
|
||||||
|
|
||||||
nodes, edges = await graph_engine.get_graph_data()
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
assert len(nodes) == 2 and len(edges) == 1, "Nodes and edges are not deleted properly."
|
assert len(nodes) == 2 and len(edges) == 1, "Nodes and edges are not deleted properly."
|
||||||
|
|
||||||
|
nodes_by_id = {node[0]: node[1] for node in nodes}
|
||||||
|
|
||||||
|
assert str(generate_node_id("Jane")) in nodes_by_id, "Jane node not present in the graph."
|
||||||
|
assert str(generate_node_id("Company B")) in nodes_by_id, (
|
||||||
|
"Company B node not present in the graph."
|
||||||
|
)
|
||||||
|
|
||||||
|
edges_by_ids = {f"{edge[0]}_{edge[2]}_{edge[1]}": edge[3] for edge in edges}
|
||||||
|
|
||||||
|
assert (
|
||||||
|
f"{str(generate_node_id('Jane'))}_works_for_{str(generate_node_id('Company B'))}"
|
||||||
|
in edges_by_ids
|
||||||
|
), "Edge between John and Company A not present in the graph."
|
||||||
|
|
||||||
|
# Second data deletion
|
||||||
await datasets.delete_data(dataset.id, data2.id, user)
|
await datasets.delete_data(dataset.id, data2.id, user)
|
||||||
|
|
||||||
nodes, edges = await graph_engine.get_graph_data()
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,40 @@
|
||||||
import os
|
import os
|
||||||
import pathlib
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import pathlib
|
||||||
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.api.v1.datasets import datasets
|
from cognee.api.v1.datasets import datasets
|
||||||
|
from cognee.context_global_variables import set_database_global_context_variables
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
||||||
from cognee.infrastructure.llm import LLMGateway
|
from cognee.infrastructure.llm import LLMGateway
|
||||||
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||||
|
from cognee.modules.data.processing.document_types.TextDocument import TextDocument
|
||||||
|
from cognee.modules.engine.models import Entity
|
||||||
from cognee.modules.engine.operations.setup import setup
|
from cognee.modules.engine.operations.setup import setup
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.tests.utils.assert_edges_vector_index_not_present import (
|
||||||
|
assert_edges_vector_index_not_present,
|
||||||
|
)
|
||||||
|
from cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present
|
||||||
|
from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present
|
||||||
|
from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present
|
||||||
|
from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present
|
||||||
|
from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present
|
||||||
|
from cognee.tests.utils.assert_nodes_vector_index_not_present import (
|
||||||
|
assert_nodes_vector_index_not_present,
|
||||||
|
)
|
||||||
|
from cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present
|
||||||
|
from cognee.tests.utils.extract_entities import extract_entities
|
||||||
|
from cognee.tests.utils.extract_relationships import extract_relationships
|
||||||
|
from cognee.tests.utils.extract_summary import extract_summary
|
||||||
|
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||||
|
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||||
|
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||||
|
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -107,92 +130,159 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
|
|
||||||
mock_create_structured_output.side_effect = mock_llm_output
|
mock_create_structured_output.side_effect = mock_llm_output
|
||||||
|
|
||||||
|
user = await get_default_user()
|
||||||
|
|
||||||
|
await set_database_global_context_variables("main_dataset", user.id)
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
|
||||||
assert not await vector_engine.has_collection("Entity_name")
|
assert not await vector_engine.has_collection("Entity_name")
|
||||||
assert not await vector_engine.has_collection("DocumentChunk_text")
|
assert not await vector_engine.has_collection("DocumentChunk_text")
|
||||||
assert not await vector_engine.has_collection("TextSummary_text")
|
assert not await vector_engine.has_collection("TextSummary_text")
|
||||||
assert not await vector_engine.has_collection("TextDocument_text")
|
assert not await vector_engine.has_collection("TextDocument_text")
|
||||||
|
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
||||||
|
|
||||||
add_john_result = await cognee.add(
|
johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||||
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
add_john_result = await cognee.add(johns_text)
|
||||||
)
|
|
||||||
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
|
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
|
||||||
|
|
||||||
add_marie_result = await cognee.add(
|
maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project."
|
||||||
"Marie works for Apple as well. She is a software engineer on MacOS project."
|
add_marie_result = await cognee.add(maries_text)
|
||||||
)
|
|
||||||
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
|
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
|
||||||
|
|
||||||
cognify_result: dict = await cognee.cognify()
|
cognify_result: dict = await cognee.cognify()
|
||||||
dataset_id = list(cognify_result.keys())[0]
|
dataset_id = list(cognify_result.keys())[0]
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
johns_document = TextDocument(
|
||||||
initial_nodes, initial_edges = await graph_engine.get_graph_data()
|
id=johns_data_id,
|
||||||
assert len(initial_nodes) == 15 and len(initial_edges) == 19, (
|
name="John's Work",
|
||||||
"Number of nodes and edges is not correct."
|
raw_data_location="johns_data_location",
|
||||||
|
external_metadata="",
|
||||||
|
)
|
||||||
|
johns_chunk = DocumentChunk(
|
||||||
|
id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"),
|
||||||
|
text=johns_text,
|
||||||
|
chunk_size=14,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=johns_document,
|
||||||
|
)
|
||||||
|
johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore
|
||||||
|
|
||||||
|
maries_document = TextDocument(
|
||||||
|
id=maries_data_id,
|
||||||
|
name="Maries's Work",
|
||||||
|
raw_data_location="maries_data_location",
|
||||||
|
external_metadata="",
|
||||||
|
)
|
||||||
|
maries_chunk = DocumentChunk(
|
||||||
|
id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"),
|
||||||
|
text=maries_text,
|
||||||
|
chunk_size=14,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=maries_document,
|
||||||
|
)
|
||||||
|
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
|
||||||
|
|
||||||
|
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
|
||||||
|
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
|
||||||
|
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
|
||||||
|
johns_entities, maries_entities
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_nodes_by_vector_collection = {}
|
johns_data = [
|
||||||
|
johns_document,
|
||||||
|
johns_chunk,
|
||||||
|
johns_summary,
|
||||||
|
*johns_entities,
|
||||||
|
]
|
||||||
|
maries_data = [
|
||||||
|
maries_document,
|
||||||
|
maries_chunk,
|
||||||
|
maries_summary,
|
||||||
|
*maries_entities,
|
||||||
|
]
|
||||||
|
|
||||||
for node in initial_nodes:
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
node_data = node[1]
|
await assert_graph_nodes_present(johns_data + maries_data + overlapping_entities)
|
||||||
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
|
await assert_nodes_vector_index_present(johns_data + maries_data + overlapping_entities)
|
||||||
if collection_name not in initial_nodes_by_vector_collection:
|
|
||||||
initial_nodes_by_vector_collection[collection_name] = []
|
|
||||||
initial_nodes_by_vector_collection[collection_name].append(node)
|
|
||||||
|
|
||||||
initial_node_ids = set([node[0] for node in initial_nodes])
|
johns_relationships = extract_relationships(
|
||||||
|
johns_chunk,
|
||||||
|
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
|
||||||
|
)
|
||||||
|
maries_relationships = extract_relationships(
|
||||||
|
maries_chunk,
|
||||||
|
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
|
||||||
|
)
|
||||||
|
(overlapping_relationships, johns_relationships, maries_relationships) = (
|
||||||
|
filter_overlapping_relationships(johns_relationships, maries_relationships)
|
||||||
|
)
|
||||||
|
|
||||||
user = await get_default_user()
|
johns_relationships = [
|
||||||
|
(johns_chunk.id, johns_document.id, "is_part_of"),
|
||||||
|
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||||
|
*johns_relationships,
|
||||||
|
]
|
||||||
|
johns_edge_text_relationships = [
|
||||||
|
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||||
|
for entity in johns_entities
|
||||||
|
if isinstance(entity, Entity)
|
||||||
|
]
|
||||||
|
maries_relationships = [
|
||||||
|
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||||
|
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||||
|
*maries_relationships,
|
||||||
|
]
|
||||||
|
maries_edge_text_relationships = [
|
||||||
|
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||||
|
for entity in maries_entities
|
||||||
|
if isinstance(entity, Entity)
|
||||||
|
]
|
||||||
|
|
||||||
|
expected_relationships = johns_relationships + maries_relationships + overlapping_relationships
|
||||||
|
|
||||||
|
await assert_graph_edges_present(expected_relationships)
|
||||||
|
|
||||||
|
await assert_edges_vector_index_present(
|
||||||
|
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete John's data from cognee
|
||||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||||
|
|
||||||
nodes, edges = await graph_engine.get_graph_data()
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted."
|
await assert_graph_nodes_present(maries_data + overlapping_entities)
|
||||||
assert not any(
|
await assert_nodes_vector_index_present(maries_data + overlapping_entities)
|
||||||
node[1]["name"] == "john" or node[1]["name"] == "food for hungry"
|
|
||||||
for node in nodes
|
|
||||||
if "name" in node[1]
|
|
||||||
), "Nodes are not deleted."
|
|
||||||
|
|
||||||
after_first_delete_node_ids = set([node[0] for node in nodes])
|
await assert_graph_nodes_not_present(johns_data)
|
||||||
|
await assert_nodes_vector_index_not_present(johns_data)
|
||||||
|
|
||||||
after_delete_nodes_by_vector_collection = {}
|
# Assert relationships presence in the graph, vector collections and nodes table
|
||||||
for node in initial_nodes:
|
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
|
||||||
node_data = node[1]
|
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
|
||||||
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
|
|
||||||
if collection_name not in after_delete_nodes_by_vector_collection:
|
|
||||||
after_delete_nodes_by_vector_collection[collection_name] = []
|
|
||||||
after_delete_nodes_by_vector_collection[collection_name].append(node)
|
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
await assert_graph_edges_not_present(johns_relationships)
|
||||||
|
|
||||||
removed_node_ids = initial_node_ids - after_first_delete_node_ids
|
strictly_johns_relationships = isolate_relationships(johns_relationships, maries_relationships)
|
||||||
|
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||||
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
|
await assert_edges_vector_index_not_present(
|
||||||
query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids]
|
strictly_johns_relationships + johns_edge_text_relationships
|
||||||
|
)
|
||||||
if query_node_ids:
|
|
||||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
|
||||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
|
||||||
|
|
||||||
|
# Delete Marie's data from cognee
|
||||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||||
|
|
||||||
final_nodes, final_edges = await graph_engine.get_graph_data()
|
await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities)
|
||||||
assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted."
|
await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities)
|
||||||
|
|
||||||
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
|
# Assert relationships presence in the graph, vector collections and nodes table
|
||||||
query_node_ids = [node[0] for node in initial_nodes]
|
await assert_graph_edges_not_present(
|
||||||
|
johns_relationships + maries_relationships + overlapping_relationships
|
||||||
|
)
|
||||||
|
|
||||||
if query_node_ids:
|
await assert_edges_vector_index_not_present(maries_relationships)
|
||||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
|
||||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
|
||||||
|
|
||||||
query_edge_ids = [edge[0] for edge in initial_edges]
|
|
||||||
|
|
||||||
vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
|
|
||||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import pytest
|
||||||
import pathlib
|
import pathlib
|
||||||
from uuid import NAMESPACE_OID, uuid5
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.api.v1.datasets import datasets
|
from cognee.api.v1.datasets import datasets
|
||||||
|
from cognee.context_global_variables import set_database_global_context_variables
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
|
@ -17,17 +17,34 @@ from cognee.modules.data.models import Data
|
||||||
from cognee.modules.engine.models import Entity, EntityType
|
from cognee.modules.engine.models import Entity, EntityType
|
||||||
from cognee.modules.data.processing.document_types import TextDocument
|
from cognee.modules.data.processing.document_types import TextDocument
|
||||||
from cognee.modules.engine.operations.setup import setup
|
from cognee.modules.engine.operations.setup import setup
|
||||||
from cognee.modules.engine.utils import generate_edge_id, generate_node_id
|
from cognee.modules.engine.utils import generate_node_id
|
||||||
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
||||||
from cognee.modules.pipelines.models import DataItemStatus
|
from cognee.modules.pipelines.models import DataItemStatus
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||||
from cognee.tasks.storage import index_data_points, index_graph_edges
|
from cognee.tasks.storage import index_data_points, index_graph_edges
|
||||||
|
from cognee.tests.utils.assert_edges_vector_index_not_present import (
|
||||||
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
assert_edges_vector_index_not_present,
|
||||||
|
)
|
||||||
|
from cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present
|
||||||
|
from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present
|
||||||
|
from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present
|
||||||
|
from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present
|
||||||
|
from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present
|
||||||
|
from cognee.tests.utils.assert_nodes_vector_index_not_present import (
|
||||||
|
assert_nodes_vector_index_not_present,
|
||||||
|
)
|
||||||
|
from cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present
|
||||||
|
from cognee.tests.utils.extract_entities import extract_entities
|
||||||
|
from cognee.tests.utils.extract_relationships import extract_relationships
|
||||||
|
from cognee.tests.utils.extract_summary import extract_summary
|
||||||
|
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||||
|
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||||
|
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||||
|
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
||||||
|
|
||||||
|
|
||||||
async def get_nodes_and_edges():
|
def create_nodes_and_edges():
|
||||||
document = TextDocument(
|
document = TextDocument(
|
||||||
id=uuid5(NAMESPACE_OID, "text_test.txt"),
|
id=uuid5(NAMESPACE_OID, "text_test.txt"),
|
||||||
name="text_test.txt",
|
name="text_test.txt",
|
||||||
|
|
@ -73,15 +90,8 @@ async def get_nodes_and_edges():
|
||||||
name="amazon s3",
|
name="amazon s3",
|
||||||
description="A storage service provided by Amazon Web Services that allows storing graph data.",
|
description="A storage service provided by Amazon Web Services that allows storing graph data.",
|
||||||
)
|
)
|
||||||
document_chunk.contains = [
|
|
||||||
graph_database,
|
|
||||||
neptune_analytics_entity,
|
|
||||||
neptune_database_entity,
|
|
||||||
storage,
|
|
||||||
storage_entity,
|
|
||||||
]
|
|
||||||
|
|
||||||
data_points = [
|
nodes_data = [
|
||||||
document,
|
document,
|
||||||
document_chunk,
|
document_chunk,
|
||||||
graph_database,
|
graph_database,
|
||||||
|
|
@ -91,39 +101,71 @@ async def get_nodes_and_edges():
|
||||||
storage_entity,
|
storage_entity,
|
||||||
]
|
]
|
||||||
|
|
||||||
nodes = []
|
edges_data = [
|
||||||
edges = []
|
(
|
||||||
|
document_chunk.id,
|
||||||
|
storage_entity.id,
|
||||||
|
"contains",
|
||||||
|
{
|
||||||
|
"relationship_name": "contains",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
storage_entity.id,
|
||||||
|
storage.id,
|
||||||
|
"is_a",
|
||||||
|
{
|
||||||
|
"relationship_name": "is_a",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
document_chunk.id,
|
||||||
|
neptune_database_entity.id,
|
||||||
|
"contains",
|
||||||
|
{
|
||||||
|
"relationship_name": "contains",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
neptune_database_entity.id,
|
||||||
|
graph_database.id,
|
||||||
|
"is_a",
|
||||||
|
{
|
||||||
|
"relationship_name": "is_a",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
document_chunk.id,
|
||||||
|
document.id,
|
||||||
|
"is_part_of",
|
||||||
|
{
|
||||||
|
"relationship_name": "is_part_of",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
document_chunk.id,
|
||||||
|
neptune_analytics_entity.id,
|
||||||
|
"contains",
|
||||||
|
{
|
||||||
|
"relationship_name": "contains",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
neptune_analytics_entity.id,
|
||||||
|
graph_database.id,
|
||||||
|
"is_a",
|
||||||
|
{
|
||||||
|
"relationship_name": "is_a",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
added_nodes = {}
|
return nodes_data, edges_data
|
||||||
added_edges = {}
|
|
||||||
visited_properties = {}
|
|
||||||
|
|
||||||
results = await asyncio.gather(
|
|
||||||
*[
|
|
||||||
get_graph_from_model(
|
|
||||||
data_point,
|
|
||||||
added_nodes=added_nodes,
|
|
||||||
added_edges=added_edges,
|
|
||||||
visited_properties=visited_properties,
|
|
||||||
)
|
|
||||||
for data_point in data_points
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
for result_nodes, result_edges in results:
|
|
||||||
nodes.extend(result_nodes)
|
|
||||||
edges.extend(result_edges)
|
|
||||||
|
|
||||||
nodes, edges = deduplicate_nodes_and_edges(nodes, edges)
|
|
||||||
|
|
||||||
return nodes, edges
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
|
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
|
||||||
async def main(mock_create_structured_output: AsyncMock):
|
async def main(mock_create_structured_output: AsyncMock):
|
||||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "False"
|
|
||||||
|
|
||||||
data_directory_path = os.path.join(
|
data_directory_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_1"
|
pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_1"
|
||||||
)
|
)
|
||||||
|
|
@ -139,6 +181,9 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
await setup()
|
await setup()
|
||||||
|
|
||||||
|
user = await get_default_user()
|
||||||
|
await set_database_global_context_variables("main_dataset", user.id)
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
||||||
|
|
@ -147,9 +192,8 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
assert not await vector_engine.has_collection("TextSummary_text")
|
assert not await vector_engine.has_collection("TextSummary_text")
|
||||||
assert not await vector_engine.has_collection("TextDocument_text")
|
assert not await vector_engine.has_collection("TextDocument_text")
|
||||||
|
|
||||||
user = await get_default_user()
|
# Add legacy data to the system
|
||||||
|
__, legacy_data_points, legacy_relationships = await create_mocked_legacy_data(user)
|
||||||
old_nodes, old_edges = await add_mocked_legacy_data(user)
|
|
||||||
|
|
||||||
def mock_llm_output(text_input: str, system_prompt: str, response_model):
|
def mock_llm_output(text_input: str, system_prompt: str, response_model):
|
||||||
if text_input == "test": # LLM connection test
|
if text_input == "test": # LLM connection test
|
||||||
|
|
@ -225,109 +269,188 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
|
|
||||||
mock_create_structured_output.side_effect = mock_llm_output
|
mock_create_structured_output.side_effect = mock_llm_output
|
||||||
|
|
||||||
add_john_result = await cognee.add(
|
johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||||
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
add_john_result = await cognee.add(johns_text)
|
||||||
)
|
|
||||||
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
|
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
|
||||||
|
|
||||||
add_marie_result = await cognee.add(
|
maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project."
|
||||||
"Marie works for Apple as well. She is a software engineer on MacOS project."
|
add_marie_result = await cognee.add(maries_text)
|
||||||
)
|
|
||||||
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
|
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
|
||||||
|
|
||||||
cognify_result: dict = await cognee.cognify()
|
cognify_result: dict = await cognee.cognify()
|
||||||
dataset_id = list(cognify_result.keys())[0]
|
dataset_id = list(cognify_result.keys())[0]
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
johns_document = TextDocument(
|
||||||
initial_nodes, initial_edges = await graph_engine.get_graph_data()
|
id=johns_data_id,
|
||||||
assert len(initial_nodes) == 22 and len(initial_edges) == 25, (
|
name="John's Work",
|
||||||
"Number of nodes and edges is not correct."
|
raw_data_location="johns_data_location",
|
||||||
|
external_metadata="",
|
||||||
|
)
|
||||||
|
johns_chunk = DocumentChunk(
|
||||||
|
id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"),
|
||||||
|
text=johns_text,
|
||||||
|
chunk_size=14,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=johns_document,
|
||||||
|
)
|
||||||
|
johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore
|
||||||
|
|
||||||
|
maries_document = TextDocument(
|
||||||
|
id=maries_data_id,
|
||||||
|
name="Maries's Work",
|
||||||
|
raw_data_location="maries_data_location",
|
||||||
|
external_metadata="",
|
||||||
|
)
|
||||||
|
maries_chunk = DocumentChunk(
|
||||||
|
id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"),
|
||||||
|
text=maries_text,
|
||||||
|
chunk_size=14,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=maries_document,
|
||||||
|
)
|
||||||
|
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
|
||||||
|
|
||||||
|
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
|
||||||
|
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
|
||||||
|
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
|
||||||
|
johns_entities, maries_entities
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_nodes_by_vector_collection = {}
|
johns_data = [
|
||||||
|
johns_document,
|
||||||
|
johns_chunk,
|
||||||
|
johns_summary,
|
||||||
|
*johns_entities,
|
||||||
|
]
|
||||||
|
maries_data = [
|
||||||
|
maries_document,
|
||||||
|
maries_chunk,
|
||||||
|
maries_summary,
|
||||||
|
*maries_entities,
|
||||||
|
]
|
||||||
|
|
||||||
for node in initial_nodes:
|
expected_data_points = johns_data + maries_data + overlapping_entities + legacy_data_points
|
||||||
node_data = node[1]
|
|
||||||
node_metadata = node_data["metadata"]
|
|
||||||
node_metadata = json.loads(node_metadata) if type(node_metadata) is str else node_metadata
|
|
||||||
collection_name = node_data["type"] + "_" + node_metadata["index_fields"][0]
|
|
||||||
if collection_name not in initial_nodes_by_vector_collection:
|
|
||||||
initial_nodes_by_vector_collection[collection_name] = []
|
|
||||||
initial_nodes_by_vector_collection[collection_name].append(node)
|
|
||||||
|
|
||||||
initial_node_ids = set([node[0] for node in initial_nodes])
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
|
await assert_graph_nodes_present(expected_data_points)
|
||||||
|
await assert_nodes_vector_index_present(expected_data_points)
|
||||||
|
|
||||||
|
johns_relationships = extract_relationships(
|
||||||
|
johns_chunk,
|
||||||
|
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
|
||||||
|
)
|
||||||
|
maries_relationships = extract_relationships(
|
||||||
|
maries_chunk,
|
||||||
|
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
|
||||||
|
)
|
||||||
|
(overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = (
|
||||||
|
filter_overlapping_relationships(
|
||||||
|
johns_relationships, maries_relationships, legacy_relationships
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
johns_relationships = [
|
||||||
|
(johns_chunk.id, johns_document.id, "is_part_of"),
|
||||||
|
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||||
|
*johns_relationships,
|
||||||
|
]
|
||||||
|
johns_edge_text_relationships = [
|
||||||
|
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||||
|
for entity in johns_entities
|
||||||
|
if isinstance(entity, Entity)
|
||||||
|
]
|
||||||
|
maries_relationships = [
|
||||||
|
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||||
|
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||||
|
*maries_relationships,
|
||||||
|
]
|
||||||
|
maries_edge_text_relationships = [
|
||||||
|
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||||
|
for entity in maries_entities
|
||||||
|
if isinstance(entity, Entity)
|
||||||
|
]
|
||||||
|
|
||||||
|
expected_relationships = (
|
||||||
|
johns_relationships
|
||||||
|
+ maries_relationships
|
||||||
|
+ overlapping_relationships
|
||||||
|
+ legacy_relationships
|
||||||
|
)
|
||||||
|
|
||||||
|
await assert_graph_edges_present(expected_relationships)
|
||||||
|
|
||||||
|
await assert_edges_vector_index_present(
|
||||||
|
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete John's data
|
||||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||||
|
|
||||||
nodes, edges = await graph_engine.get_graph_data()
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
assert len(nodes) == 16 and len(edges) == 16, "Nodes and edges are not deleted."
|
await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points)
|
||||||
assert not any(
|
await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points)
|
||||||
node[1]["name"] == "john" or node[1]["name"] == "food for hungry"
|
|
||||||
for node in nodes
|
|
||||||
if "name" in node[1]
|
|
||||||
), "Nodes are not deleted."
|
|
||||||
|
|
||||||
after_first_delete_node_ids = set([node[0] for node in nodes])
|
await assert_graph_nodes_not_present(johns_data)
|
||||||
|
await assert_nodes_vector_index_not_present(johns_data)
|
||||||
|
|
||||||
after_delete_nodes_by_vector_collection = {}
|
# Assert relationships presence in the graph, vector collections and nodes table
|
||||||
for node in initial_nodes:
|
await assert_graph_edges_present(
|
||||||
node_data = node[1]
|
maries_relationships + overlapping_relationships + legacy_relationships
|
||||||
node_metadata = node_data["metadata"]
|
)
|
||||||
node_metadata = json.loads(node_metadata) if type(node_metadata) is str else node_metadata
|
await assert_edges_vector_index_present(
|
||||||
collection_name = node_data["type"] + "_" + node_metadata["index_fields"][0]
|
maries_relationships + maries_edge_text_relationships + legacy_relationships
|
||||||
if collection_name not in after_delete_nodes_by_vector_collection:
|
)
|
||||||
after_delete_nodes_by_vector_collection[collection_name] = []
|
|
||||||
after_delete_nodes_by_vector_collection[collection_name].append(node)
|
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
await assert_graph_edges_not_present(johns_relationships)
|
||||||
|
|
||||||
removed_node_ids = initial_node_ids - after_first_delete_node_ids
|
strictly_johns_relationships = isolate_relationships(
|
||||||
|
johns_relationships, maries_relationships, legacy_relationships
|
||||||
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
|
)
|
||||||
query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids]
|
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||||
|
await assert_edges_vector_index_not_present(
|
||||||
if query_node_ids:
|
strictly_johns_relationships + johns_edge_text_relationships
|
||||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
)
|
||||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
|
||||||
|
|
||||||
|
# Delete Marie's data
|
||||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||||
|
|
||||||
final_nodes, final_edges = await graph_engine.get_graph_data()
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
assert len(final_nodes) == 7 and len(final_edges) == 6, "Nodes and edges are not deleted."
|
await assert_graph_nodes_present(legacy_data_points)
|
||||||
|
await assert_nodes_vector_index_present(legacy_data_points)
|
||||||
|
|
||||||
old_nodes_by_vector_collection = {}
|
await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities)
|
||||||
for node in old_nodes:
|
await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities)
|
||||||
node_metadata = node.metadata
|
|
||||||
collection_name = node.type + "_" + node_metadata["index_fields"][0]
|
|
||||||
if collection_name not in old_nodes_by_vector_collection:
|
|
||||||
old_nodes_by_vector_collection[collection_name] = []
|
|
||||||
old_nodes_by_vector_collection[collection_name].append(node)
|
|
||||||
|
|
||||||
for collection_name, old_nodes in old_nodes_by_vector_collection.items():
|
# Assert relationships presence in the graph, vector collections and nodes table
|
||||||
query_node_ids = [str(node.id) for node in old_nodes]
|
await assert_graph_edges_present(legacy_relationships)
|
||||||
|
await assert_edges_vector_index_present(legacy_relationships)
|
||||||
|
|
||||||
if query_node_ids:
|
await assert_graph_edges_not_present(
|
||||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
johns_relationships + maries_relationships + overlapping_relationships
|
||||||
assert len(vector_items) == len(old_nodes), "Vector items are not deleted."
|
)
|
||||||
|
|
||||||
query_edge_ids = list(set([str(generate_edge_id(edge[2])) for edge in old_edges]))
|
strictly_maries_relationships = isolate_relationships(
|
||||||
|
maries_relationships, legacy_relationships
|
||||||
vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
|
)
|
||||||
assert len(vector_items) == len(query_edge_ids), "Vector items are not deleted."
|
# We check only by relationship name and we need edges that are created by legacy data and no other.
|
||||||
|
if strictly_maries_relationships:
|
||||||
|
await assert_edges_vector_index_not_present(strictly_maries_relationships)
|
||||||
|
|
||||||
|
|
||||||
async def add_mocked_legacy_data(user):
|
async def create_mocked_legacy_data(user):
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
old_nodes, old_edges = await get_nodes_and_edges()
|
legacy_nodes, legacy_edges = create_nodes_and_edges()
|
||||||
old_document = old_nodes[0]
|
legacy_document = legacy_nodes[0]
|
||||||
|
|
||||||
await graph_engine.add_nodes(old_nodes)
|
await graph_engine.add_nodes(legacy_nodes)
|
||||||
await graph_engine.add_edges(old_edges)
|
await graph_engine.add_edges(legacy_edges)
|
||||||
|
|
||||||
await index_data_points(old_nodes)
|
await index_data_points(legacy_nodes)
|
||||||
await index_graph_edges(old_edges)
|
await index_graph_edges(legacy_edges)
|
||||||
|
|
||||||
await record_data_in_legacy_ledger(old_nodes, old_edges, user)
|
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
|
|
@ -335,12 +458,12 @@ async def add_mocked_legacy_data(user):
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
old_data = Data(
|
old_data = Data(
|
||||||
id=old_document.id,
|
id=legacy_document.id,
|
||||||
name=old_document.name,
|
name=legacy_document.name,
|
||||||
extension="txt",
|
extension="txt",
|
||||||
raw_data_location=old_document.raw_data_location,
|
raw_data_location=legacy_document.raw_data_location,
|
||||||
external_metadata=old_document.external_metadata,
|
external_metadata=legacy_document.external_metadata,
|
||||||
mime_type=old_document.mime_type,
|
mime_type=legacy_document.mime_type,
|
||||||
owner_id=user.id,
|
owner_id=user.id,
|
||||||
pipeline_status={
|
pipeline_status={
|
||||||
"cognify_pipeline": {
|
"cognify_pipeline": {
|
||||||
|
|
@ -355,7 +478,7 @@ async def add_mocked_legacy_data(user):
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
return old_nodes, old_edges
|
return legacy_document, legacy_nodes, legacy_edges
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,12 @@
|
||||||
import os
|
import os
|
||||||
|
import pytest
|
||||||
import pathlib
|
import pathlib
|
||||||
from uuid import NAMESPACE_OID, uuid5
|
from uuid import NAMESPACE_OID, uuid5
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.api.v1.datasets import datasets
|
from cognee.api.v1.datasets import datasets
|
||||||
|
from cognee.context_global_variables import set_database_global_context_variables
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
|
@ -16,15 +17,34 @@ from cognee.modules.data.models import Data
|
||||||
from cognee.modules.engine.models import Entity, EntityType
|
from cognee.modules.engine.models import Entity, EntityType
|
||||||
from cognee.modules.data.processing.document_types import TextDocument
|
from cognee.modules.data.processing.document_types import TextDocument
|
||||||
from cognee.modules.engine.operations.setup import setup
|
from cognee.modules.engine.operations.setup import setup
|
||||||
from cognee.modules.engine.utils import generate_edge_id, generate_node_id
|
from cognee.modules.engine.utils import generate_node_id
|
||||||
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
||||||
from cognee.modules.pipelines.models import DataItemStatus
|
from cognee.modules.pipelines.models import DataItemStatus
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||||
from cognee.tasks.storage import index_data_points, index_graph_edges
|
from cognee.tasks.storage import index_data_points, index_graph_edges
|
||||||
|
from cognee.tests.utils.assert_edges_vector_index_not_present import (
|
||||||
|
assert_edges_vector_index_not_present,
|
||||||
|
)
|
||||||
|
from cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present
|
||||||
|
from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present
|
||||||
|
from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present
|
||||||
|
from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present
|
||||||
|
from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present
|
||||||
|
from cognee.tests.utils.assert_nodes_vector_index_not_present import (
|
||||||
|
assert_nodes_vector_index_not_present,
|
||||||
|
)
|
||||||
|
from cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present
|
||||||
|
from cognee.tests.utils.extract_entities import extract_entities
|
||||||
|
from cognee.tests.utils.extract_relationships import extract_relationships
|
||||||
|
from cognee.tests.utils.extract_summary import extract_summary
|
||||||
|
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||||
|
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||||
|
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||||
|
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
||||||
|
|
||||||
|
|
||||||
def get_nodes_and_edges():
|
def create_nodes_and_edges():
|
||||||
document = TextDocument(
|
document = TextDocument(
|
||||||
id=uuid5(NAMESPACE_OID, "text_test.txt"),
|
id=uuid5(NAMESPACE_OID, "text_test.txt"),
|
||||||
name="text_test.txt",
|
name="text_test.txt",
|
||||||
|
|
@ -146,8 +166,6 @@ def get_nodes_and_edges():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
|
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
|
||||||
async def main(mock_create_structured_output: AsyncMock):
|
async def main(mock_create_structured_output: AsyncMock):
|
||||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "False"
|
|
||||||
|
|
||||||
data_directory_path = os.path.join(
|
data_directory_path = os.path.join(
|
||||||
pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_2"
|
pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_2"
|
||||||
)
|
)
|
||||||
|
|
@ -163,6 +181,9 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
await cognee.prune.prune_system(metadata=True)
|
await cognee.prune.prune_system(metadata=True)
|
||||||
await setup()
|
await setup()
|
||||||
|
|
||||||
|
user = await get_default_user()
|
||||||
|
await set_database_global_context_variables("main_dataset", user.id)
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
||||||
|
|
@ -171,9 +192,10 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
assert not await vector_engine.has_collection("TextSummary_text")
|
assert not await vector_engine.has_collection("TextSummary_text")
|
||||||
assert not await vector_engine.has_collection("TextDocument_text")
|
assert not await vector_engine.has_collection("TextDocument_text")
|
||||||
|
|
||||||
user = await get_default_user()
|
# Add legacy data to the system
|
||||||
|
legacy_document, legacy_data_points, legacy_relationships = await create_mocked_legacy_data(
|
||||||
old_document, old_nodes, old_edges = await add_mocked_legacy_data(user)
|
user
|
||||||
|
)
|
||||||
|
|
||||||
def mock_llm_output(text_input: str, system_prompt: str, response_model):
|
def mock_llm_output(text_input: str, system_prompt: str, response_model):
|
||||||
if text_input == "test": # LLM connection test
|
if text_input == "test": # LLM connection test
|
||||||
|
|
@ -249,100 +271,186 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
|
|
||||||
mock_create_structured_output.side_effect = mock_llm_output
|
mock_create_structured_output.side_effect = mock_llm_output
|
||||||
|
|
||||||
add_john_result = await cognee.add(
|
johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||||
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
add_john_result = await cognee.add(johns_text)
|
||||||
)
|
|
||||||
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
|
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
|
||||||
|
|
||||||
await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.")
|
maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project."
|
||||||
|
add_marie_result = await cognee.add(maries_text)
|
||||||
|
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
|
||||||
|
|
||||||
cognify_result: dict = await cognee.cognify()
|
cognify_result: dict = await cognee.cognify()
|
||||||
dataset_id = list(cognify_result.keys())[0]
|
dataset_id = list(cognify_result.keys())[0]
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
johns_document = TextDocument(
|
||||||
initial_nodes, initial_edges = await graph_engine.get_graph_data()
|
id=johns_data_id,
|
||||||
assert len(initial_nodes) == 22 and len(initial_edges) == 26, (
|
name="John's Work",
|
||||||
"Number of nodes and edges is not correct."
|
raw_data_location="johns_data_location",
|
||||||
|
external_metadata="",
|
||||||
|
)
|
||||||
|
johns_chunk = DocumentChunk(
|
||||||
|
id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"),
|
||||||
|
text=johns_text,
|
||||||
|
chunk_size=14,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=johns_document,
|
||||||
|
)
|
||||||
|
johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore
|
||||||
|
|
||||||
|
maries_document = TextDocument(
|
||||||
|
id=maries_data_id,
|
||||||
|
name="Maries's Work",
|
||||||
|
raw_data_location="maries_data_location",
|
||||||
|
external_metadata="",
|
||||||
|
)
|
||||||
|
maries_chunk = DocumentChunk(
|
||||||
|
id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"),
|
||||||
|
text=maries_text,
|
||||||
|
chunk_size=14,
|
||||||
|
chunk_index=0,
|
||||||
|
cut_type="sentence_end",
|
||||||
|
is_part_of=maries_document,
|
||||||
|
)
|
||||||
|
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
|
||||||
|
|
||||||
|
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
|
||||||
|
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
|
||||||
|
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
|
||||||
|
johns_entities, maries_entities
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_nodes_by_vector_collection = {}
|
johns_data = [
|
||||||
|
johns_document,
|
||||||
|
johns_chunk,
|
||||||
|
johns_summary,
|
||||||
|
*johns_entities,
|
||||||
|
]
|
||||||
|
maries_data = [
|
||||||
|
maries_document,
|
||||||
|
maries_chunk,
|
||||||
|
maries_summary,
|
||||||
|
*maries_entities,
|
||||||
|
]
|
||||||
|
|
||||||
for node in initial_nodes:
|
expected_data_points = johns_data + maries_data + overlapping_entities + legacy_data_points
|
||||||
node_data = node[1]
|
|
||||||
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
|
|
||||||
if collection_name not in initial_nodes_by_vector_collection:
|
|
||||||
initial_nodes_by_vector_collection[collection_name] = []
|
|
||||||
initial_nodes_by_vector_collection[collection_name].append(node)
|
|
||||||
|
|
||||||
initial_node_ids = set([node[0] for node in initial_nodes])
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
|
await assert_graph_nodes_present(expected_data_points)
|
||||||
|
await assert_nodes_vector_index_present(expected_data_points)
|
||||||
|
|
||||||
|
johns_relationships = extract_relationships(
|
||||||
|
johns_chunk,
|
||||||
|
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
|
||||||
|
)
|
||||||
|
maries_relationships = extract_relationships(
|
||||||
|
maries_chunk,
|
||||||
|
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
|
||||||
|
)
|
||||||
|
(overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = (
|
||||||
|
filter_overlapping_relationships(
|
||||||
|
johns_relationships, maries_relationships, legacy_relationships
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
johns_relationships = [
|
||||||
|
(johns_chunk.id, johns_document.id, "is_part_of"),
|
||||||
|
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||||
|
*johns_relationships,
|
||||||
|
]
|
||||||
|
johns_edge_text_relationships = [
|
||||||
|
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||||
|
for entity in johns_entities
|
||||||
|
if isinstance(entity, Entity)
|
||||||
|
]
|
||||||
|
maries_relationships = [
|
||||||
|
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||||
|
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||||
|
*maries_relationships,
|
||||||
|
]
|
||||||
|
maries_edge_text_relationships = [
|
||||||
|
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||||
|
for entity in maries_entities
|
||||||
|
if isinstance(entity, Entity)
|
||||||
|
]
|
||||||
|
|
||||||
|
expected_relationships = (
|
||||||
|
johns_relationships
|
||||||
|
+ maries_relationships
|
||||||
|
+ overlapping_relationships
|
||||||
|
+ legacy_relationships
|
||||||
|
)
|
||||||
|
|
||||||
|
await assert_graph_edges_present(expected_relationships)
|
||||||
|
|
||||||
|
await assert_edges_vector_index_present(
|
||||||
|
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete John's data
|
||||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||||
|
|
||||||
nodes, edges = await graph_engine.get_graph_data()
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
assert len(nodes) == 16 and len(edges) == 17, "Nodes and edges are not deleted."
|
await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points)
|
||||||
assert not any(
|
await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points)
|
||||||
node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes
|
|
||||||
), "Nodes are not deleted."
|
|
||||||
|
|
||||||
after_first_delete_node_ids = set([node[0] for node in nodes])
|
await assert_graph_nodes_not_present(johns_data)
|
||||||
|
await assert_nodes_vector_index_not_present(johns_data)
|
||||||
|
|
||||||
after_delete_nodes_by_vector_collection = {}
|
# Assert relationships presence in the graph, vector collections and nodes table
|
||||||
for node in initial_nodes:
|
await assert_graph_edges_present(
|
||||||
node_data = node[1]
|
maries_relationships + overlapping_relationships + legacy_relationships
|
||||||
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
|
)
|
||||||
if collection_name not in after_delete_nodes_by_vector_collection:
|
await assert_edges_vector_index_present(
|
||||||
after_delete_nodes_by_vector_collection[collection_name] = []
|
maries_relationships + maries_edge_text_relationships + legacy_relationships
|
||||||
after_delete_nodes_by_vector_collection[collection_name].append(node)
|
)
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
await assert_graph_edges_not_present(johns_relationships)
|
||||||
|
|
||||||
removed_node_ids = initial_node_ids - after_first_delete_node_ids
|
strictly_johns_relationships = isolate_relationships(
|
||||||
|
johns_relationships, maries_relationships, legacy_relationships
|
||||||
|
)
|
||||||
|
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||||
|
await assert_edges_vector_index_not_present(
|
||||||
|
strictly_johns_relationships + johns_edge_text_relationships
|
||||||
|
)
|
||||||
|
|
||||||
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
|
# Delete legacy data
|
||||||
query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids]
|
await datasets.delete_data(dataset_id, legacy_document.id, user) # type: ignore
|
||||||
|
|
||||||
if query_node_ids:
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
await assert_graph_nodes_present(maries_data + overlapping_entities)
|
||||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
await assert_nodes_vector_index_present(maries_data + overlapping_entities)
|
||||||
|
|
||||||
# Delete old document
|
await assert_graph_nodes_not_present(johns_data + legacy_data_points)
|
||||||
await datasets.delete_data(dataset_id, old_document.id, user) # type: ignore
|
await assert_nodes_vector_index_not_present(johns_data + legacy_data_points)
|
||||||
|
|
||||||
final_nodes, final_edges = await graph_engine.get_graph_data()
|
# Assert relationships presence in the graph, vector collections and nodes table
|
||||||
assert len(final_nodes) == 9 and len(final_edges) == 10, "Nodes and edges are not deleted."
|
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
|
||||||
|
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
|
||||||
|
|
||||||
old_nodes_by_vector_collection = {}
|
await assert_graph_edges_not_present(johns_relationships + legacy_relationships)
|
||||||
for node in old_nodes:
|
|
||||||
collection_name = node.type + "_" + node.metadata["index_fields"][0]
|
|
||||||
if collection_name not in old_nodes_by_vector_collection:
|
|
||||||
old_nodes_by_vector_collection[collection_name] = []
|
|
||||||
old_nodes_by_vector_collection[collection_name].append(node)
|
|
||||||
|
|
||||||
for collection_name, old_nodes in old_nodes_by_vector_collection.items():
|
strictly_legacy_relationships = isolate_relationships(
|
||||||
query_node_ids = [str(node.id) for node in old_nodes]
|
legacy_relationships, maries_relationships
|
||||||
|
)
|
||||||
if query_node_ids:
|
# We check only by relationship name and we need edges that are created by legacy data and no other.
|
||||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
if strictly_legacy_relationships:
|
||||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
await assert_edges_vector_index_not_present(strictly_legacy_relationships)
|
||||||
|
|
||||||
query_edge_ids = list(set([str(generate_edge_id(edge[2])) for edge in old_edges]))
|
|
||||||
|
|
||||||
vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
|
|
||||||
assert len(vector_items) == len(query_edge_ids), "Vector items are not deleted."
|
|
||||||
|
|
||||||
|
|
||||||
async def add_mocked_legacy_data(user):
|
async def create_mocked_legacy_data(user):
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
old_nodes, old_edges = get_nodes_and_edges()
|
legacy_nodes, legacy_edges = create_nodes_and_edges()
|
||||||
old_document = old_nodes[0]
|
legacy_document = legacy_nodes[0]
|
||||||
|
|
||||||
await graph_engine.add_nodes(old_nodes)
|
await graph_engine.add_nodes(legacy_nodes)
|
||||||
await graph_engine.add_edges(old_edges)
|
await graph_engine.add_edges(legacy_edges)
|
||||||
|
|
||||||
await index_data_points(old_nodes)
|
await index_data_points(legacy_nodes)
|
||||||
await index_graph_edges(old_edges)
|
await index_graph_edges(legacy_edges)
|
||||||
|
|
||||||
await record_data_in_legacy_ledger(old_nodes, old_edges, user)
|
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
|
|
@ -350,12 +458,12 @@ async def add_mocked_legacy_data(user):
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
old_data = Data(
|
old_data = Data(
|
||||||
id=old_document.id,
|
id=legacy_document.id,
|
||||||
name=old_document.name,
|
name=legacy_document.name,
|
||||||
extension="txt",
|
extension="txt",
|
||||||
raw_data_location=old_document.raw_data_location,
|
raw_data_location=legacy_document.raw_data_location,
|
||||||
external_metadata=old_document.external_metadata,
|
external_metadata=legacy_document.external_metadata,
|
||||||
mime_type=old_document.mime_type,
|
mime_type=legacy_document.mime_type,
|
||||||
owner_id=user.id,
|
owner_id=user.id,
|
||||||
pipeline_status={
|
pipeline_status={
|
||||||
"cognify_pipeline": {
|
"cognify_pipeline": {
|
||||||
|
|
@ -370,7 +478,7 @@ async def add_mocked_legacy_data(user):
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
return old_document, old_nodes, old_edges
|
return legacy_document, legacy_nodes, legacy_edges
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
23
cognee/tests/utils/assert_edges_vector_index_not_present.py
Normal file
23
cognee/tests/utils/assert_edges_vector_index_not_present.py
Normal file
|
|
@ -0,0 +1,23 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import List, Tuple
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.modules.engine.utils import generate_edge_id
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_edges_vector_index_not_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
|
query_edge_ids = {
|
||||||
|
str(generate_edge_id(relationship[2])): relationship[2] for relationship in relationships
|
||||||
|
}
|
||||||
|
|
||||||
|
vector_items = await vector_engine.retrieve(
|
||||||
|
"EdgeType_relationship_name", list(query_edge_ids.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_items_by_id = {str(vector_item.id): vector_item for vector_item in vector_items}
|
||||||
|
|
||||||
|
for relationship_id, relationship_name in query_edge_ids.items():
|
||||||
|
assert relationship_id not in vector_items_by_id, (
|
||||||
|
f"Relationship '{relationship_name}' still present in the vector store."
|
||||||
|
)
|
||||||
28
cognee/tests/utils/assert_edges_vector_index_present.py
Normal file
28
cognee/tests/utils/assert_edges_vector_index_present.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import List, Tuple
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.modules.engine.utils import generate_edge_id
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
|
query_edge_ids = {
|
||||||
|
str(generate_edge_id(relationship[2])): relationship[2] for relationship in relationships
|
||||||
|
}
|
||||||
|
|
||||||
|
vector_items = await vector_engine.retrieve(
|
||||||
|
"EdgeType_relationship_name", list(query_edge_ids.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_items_by_id = {str(vector_item.id): vector_item for vector_item in vector_items}
|
||||||
|
|
||||||
|
for relationship_id, relationship_name in query_edge_ids.items():
|
||||||
|
assert relationship_id in vector_items_by_id, (
|
||||||
|
f"Relationship '{relationship_name}' not found in vector store."
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_relationship = vector_items_by_id[relationship_id]
|
||||||
|
assert vector_relationship.payload["text"] == relationship_name, (
|
||||||
|
f"Vectorized edge '{vector_relationship.payload['text']}' does not match the relationship text '{relationship_name}'."
|
||||||
|
)
|
||||||
21
cognee/tests/utils/assert_graph_edges_not_present.py
Normal file
21
cognee/tests/utils/assert_graph_edges_not_present.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
|
nodes_by_id = {str(node[0]): node[1] for node in nodes}
|
||||||
|
|
||||||
|
edge_ids = set([f"{str(edge[0])}_{edge[2]}_{str(edge[1])}" for edge in edges])
|
||||||
|
|
||||||
|
for relationship in relationships:
|
||||||
|
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
|
||||||
|
relationship_name = relationship[2]
|
||||||
|
assert relationship_id not in edge_ids, (
|
||||||
|
f"Edge '{relationship_name}' still present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database."
|
||||||
|
)
|
||||||
21
cognee/tests/utils/assert_graph_edges_present.py
Normal file
21
cognee/tests/utils/assert_graph_edges_present.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
|
nodes_by_id = {str(node[0]): node[1] for node in nodes}
|
||||||
|
|
||||||
|
edge_ids = set([f"{str(edge[0])}_{edge[2]}_{str(edge[1])}" for edge in edges])
|
||||||
|
|
||||||
|
for relationship in relationships:
|
||||||
|
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
|
||||||
|
relationship_name = relationship[2]
|
||||||
|
assert relationship_id in edge_ids, (
|
||||||
|
f"Edge '{relationship_name}' not present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database."
|
||||||
|
)
|
||||||
16
cognee/tests/utils/assert_graph_nodes_not_present.py
Normal file
16
cognee/tests/utils/assert_graph_nodes_not_present.py
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
from typing import List
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_graph_nodes_not_present(data_points: List[DataPoint]):
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
nodes, __ = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
|
node_ids = set(node[0] for node in nodes)
|
||||||
|
|
||||||
|
for data_point in data_points:
|
||||||
|
node_name = getattr(data_point, "label", getattr(data_point, "name", data_point.id))
|
||||||
|
assert str(data_point.id) not in node_ids, (
|
||||||
|
f"Node '{node_name}' is present in graph database."
|
||||||
|
)
|
||||||
14
cognee/tests/utils/assert_graph_nodes_present.py
Normal file
14
cognee/tests/utils/assert_graph_nodes_present.py
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
from typing import List
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
|
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_graph_nodes_present(data_points: List[DataPoint]):
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
nodes, __ = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
|
node_ids = set(node[0] for node in nodes)
|
||||||
|
|
||||||
|
for data_point in data_points:
|
||||||
|
node_name = getattr(data_point, "label", getattr(data_point, "name", data_point.id))
|
||||||
|
assert str(data_point.id) in node_ids, f"Node '{node_name}' not found in graph database."
|
||||||
28
cognee/tests/utils/assert_nodes_vector_index_not_present.py
Normal file
28
cognee/tests/utils/assert_nodes_vector_index_not_present.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
from typing import List
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_nodes_vector_index_not_present(data_points: List[DataPoint]):
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
|
data_points_by_vector_collection = {}
|
||||||
|
|
||||||
|
for data_point in data_points:
|
||||||
|
node_metadata = data_point.metadata or {}
|
||||||
|
collection_name = data_point.type + "_" + node_metadata["index_fields"][0]
|
||||||
|
|
||||||
|
if collection_name not in data_points_by_vector_collection:
|
||||||
|
data_points_by_vector_collection[collection_name] = []
|
||||||
|
|
||||||
|
data_points_by_vector_collection[collection_name].append(data_point)
|
||||||
|
|
||||||
|
for collection_name, collection_data_points in data_points_by_vector_collection.items():
|
||||||
|
query_data_point_ids = set([str(data_point.id) for data_point in collection_data_points])
|
||||||
|
|
||||||
|
vector_items = await vector_engine.retrieve(collection_name, list(query_data_point_ids))
|
||||||
|
|
||||||
|
for vector_item in vector_items:
|
||||||
|
assert str(vector_item.id) not in query_data_point_ids, (
|
||||||
|
f"{vector_item.payload['text']} is still present in the vector store."
|
||||||
|
)
|
||||||
28
cognee/tests/utils/assert_nodes_vector_index_present.py
Normal file
28
cognee/tests/utils/assert_nodes_vector_index_present.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
from typing import List
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_nodes_vector_index_present(data_points: List[DataPoint]):
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
|
data_points_by_vector_collection = {}
|
||||||
|
|
||||||
|
for data_point in data_points:
|
||||||
|
node_metadata = data_point.metadata or {}
|
||||||
|
collection_name = data_point.type + "_" + node_metadata["index_fields"][0]
|
||||||
|
|
||||||
|
if collection_name not in data_points_by_vector_collection:
|
||||||
|
data_points_by_vector_collection[collection_name] = []
|
||||||
|
|
||||||
|
data_points_by_vector_collection[collection_name].append(data_point)
|
||||||
|
|
||||||
|
for collection_name, collection_data_points in data_points_by_vector_collection.items():
|
||||||
|
query_data_point_ids = set([str(data_point.id) for data_point in collection_data_points])
|
||||||
|
|
||||||
|
vector_items = await vector_engine.retrieve(collection_name, list(query_data_point_ids))
|
||||||
|
|
||||||
|
for vector_item in vector_items:
|
||||||
|
assert str(vector_item.id) in query_data_point_ids, (
|
||||||
|
f"{vector_item.payload['text']} is not present in the vector store."
|
||||||
|
)
|
||||||
45
cognee/tests/utils/extract_entities.py
Normal file
45
cognee/tests/utils/extract_entities.py
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
from cognee.modules.engine.models import Entity, EntityType
|
||||||
|
from cognee.modules.engine.utils import generate_node_id, generate_node_name
|
||||||
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
|
|
||||||
|
|
||||||
|
def extract_entities(graph: KnowledgeGraph, cache: dict = {}):
|
||||||
|
entities = []
|
||||||
|
entity_types = []
|
||||||
|
|
||||||
|
for node in graph.nodes:
|
||||||
|
node_id = generate_node_id(node.id)
|
||||||
|
|
||||||
|
if node_id not in cache:
|
||||||
|
entity = Entity(
|
||||||
|
id=node_id,
|
||||||
|
name=generate_node_name(node.id),
|
||||||
|
type=node.type,
|
||||||
|
description=node.description,
|
||||||
|
ontology_valid=False,
|
||||||
|
)
|
||||||
|
cache[node_id] = entity
|
||||||
|
else:
|
||||||
|
entity = cache[node_id]
|
||||||
|
|
||||||
|
entities.append(entity)
|
||||||
|
|
||||||
|
node_type = node.type
|
||||||
|
type_node_id = generate_node_id(node_type)
|
||||||
|
if type_node_id not in cache:
|
||||||
|
type_node_name = generate_node_name(node_type)
|
||||||
|
|
||||||
|
type_node = EntityType(
|
||||||
|
id=type_node_id,
|
||||||
|
name=type_node_name,
|
||||||
|
type=type_node_name,
|
||||||
|
description=type_node_name,
|
||||||
|
ontology_valid=False,
|
||||||
|
)
|
||||||
|
cache[type_node_id] = type_node
|
||||||
|
else:
|
||||||
|
type_node = cache[type_node_id]
|
||||||
|
|
||||||
|
entity_types.append(type_node)
|
||||||
|
|
||||||
|
return entities + entity_types
|
||||||
55
cognee/tests/utils/extract_relationships.py
Normal file
55
cognee/tests/utils/extract_relationships.py
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
|
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||||
|
from cognee.modules.engine.utils import generate_edge_id, generate_node_id
|
||||||
|
|
||||||
|
|
||||||
|
def extract_relationships(document_chunk: DocumentChunk, graph: KnowledgeGraph, cache: dict = {}):
|
||||||
|
relationships = []
|
||||||
|
|
||||||
|
for edge in graph.edges:
|
||||||
|
edge_id = f"{edge.source_node_id}_{edge.relationship_name}_{edge.target_node_id}"
|
||||||
|
|
||||||
|
if edge_id not in cache:
|
||||||
|
relationship = (
|
||||||
|
generate_edge_id(edge.source_node_id),
|
||||||
|
generate_edge_id(edge.target_node_id),
|
||||||
|
edge.relationship_name,
|
||||||
|
)
|
||||||
|
cache[edge_id] = relationship
|
||||||
|
else:
|
||||||
|
relationship = cache[edge_id]
|
||||||
|
|
||||||
|
relationships.append(relationship)
|
||||||
|
|
||||||
|
for node in graph.nodes:
|
||||||
|
node_id = generate_node_id(node.id)
|
||||||
|
type_node_id = generate_node_id(node.type)
|
||||||
|
type_edge_id = f"{str(node_id)}_is_a_{str(type_node_id)}"
|
||||||
|
|
||||||
|
if type_edge_id not in cache:
|
||||||
|
relationship = (
|
||||||
|
node_id,
|
||||||
|
type_node_id,
|
||||||
|
"is_a",
|
||||||
|
)
|
||||||
|
cache[type_edge_id] = relationship
|
||||||
|
else:
|
||||||
|
relationship = cache[type_edge_id]
|
||||||
|
|
||||||
|
relationships.append(relationship)
|
||||||
|
|
||||||
|
chunk_edge_id = f"{str(document_chunk.id)}_contains_{str(node_id)}"
|
||||||
|
|
||||||
|
if chunk_edge_id not in cache:
|
||||||
|
relationship = (
|
||||||
|
document_chunk.id,
|
||||||
|
node_id,
|
||||||
|
"contains",
|
||||||
|
)
|
||||||
|
cache[chunk_edge_id] = relationship
|
||||||
|
else:
|
||||||
|
relationship = cache[chunk_edge_id]
|
||||||
|
|
||||||
|
relationships.append(relationship)
|
||||||
|
|
||||||
|
return relationships
|
||||||
12
cognee/tests/utils/extract_summary.py
Normal file
12
cognee/tests/utils/extract_summary.py
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
from uuid import uuid5
|
||||||
|
from cognee.modules.chunking.models import DocumentChunk
|
||||||
|
from cognee.shared.data_models import SummarizedContent
|
||||||
|
from cognee.tasks.summarization.models import TextSummary
|
||||||
|
|
||||||
|
|
||||||
|
def extract_summary(document_chunk: DocumentChunk, summary=SummarizedContent) -> TextSummary:
|
||||||
|
return TextSummary(
|
||||||
|
id=uuid5(document_chunk.id, "TextSummary"),
|
||||||
|
text=summary.summary,
|
||||||
|
made_from=document_chunk,
|
||||||
|
)
|
||||||
26
cognee/tests/utils/filter_overlapping_entities.py
Normal file
26
cognee/tests/utils/filter_overlapping_entities.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
def filter_overlapping_entities(*entity_groups):
|
||||||
|
entity_count = {}
|
||||||
|
overlapping_entities = []
|
||||||
|
|
||||||
|
for group in entity_groups:
|
||||||
|
for entity in group:
|
||||||
|
if not entity.id in entity_count:
|
||||||
|
entity_count[entity.id] = 1
|
||||||
|
else:
|
||||||
|
entity_count[entity.id] += 1
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
grouped_entities = []
|
||||||
|
for group in entity_groups:
|
||||||
|
grouped_entities.append([])
|
||||||
|
|
||||||
|
for entity in group:
|
||||||
|
if entity_count[entity.id] == 1:
|
||||||
|
grouped_entities[index].append(entity)
|
||||||
|
else:
|
||||||
|
if entity not in overlapping_entities:
|
||||||
|
overlapping_entities.append(entity)
|
||||||
|
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
return overlapping_entities, *grouped_entities
|
||||||
33
cognee/tests/utils/filter_overlapping_relationships.py
Normal file
33
cognee/tests/utils/filter_overlapping_relationships.py
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
from cognee.modules.engine.utils import generate_node_id
|
||||||
|
|
||||||
|
|
||||||
|
def filter_overlapping_relationships(*relationship_groups):
|
||||||
|
relationship_count = {}
|
||||||
|
overlapping_relationships = []
|
||||||
|
|
||||||
|
for group in relationship_groups:
|
||||||
|
for relationship in group:
|
||||||
|
relationship_id = f"{relationship[0]}_{relationship[2]}_{relationship[1]}"
|
||||||
|
|
||||||
|
if not relationship_id in relationship_count:
|
||||||
|
relationship_count[relationship_id] = 1
|
||||||
|
else:
|
||||||
|
relationship_count[relationship_id] += 1
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
grouped_relationships = []
|
||||||
|
for group in relationship_groups:
|
||||||
|
grouped_relationships.append([])
|
||||||
|
|
||||||
|
for relationship in group:
|
||||||
|
relationship_id = f"{relationship[0]}_{relationship[2]}_{relationship[1]}"
|
||||||
|
|
||||||
|
if relationship_count[relationship_id] == 1:
|
||||||
|
grouped_relationships[index].append(relationship)
|
||||||
|
else:
|
||||||
|
if relationship not in overlapping_relationships:
|
||||||
|
overlapping_relationships.append(relationship)
|
||||||
|
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
return overlapping_relationships, *grouped_relationships
|
||||||
9
cognee/tests/utils/get_contains_edge_text.py
Normal file
9
cognee/tests/utils/get_contains_edge_text.py
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
def get_contains_edge_text(entity_name: str, entity_description: str) -> str:
|
||||||
|
edge_text = "; ".join(
|
||||||
|
[
|
||||||
|
"relationship_name: contains",
|
||||||
|
f"entity_name: {entity_name}",
|
||||||
|
f"entity_description: {entity_description}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return edge_text
|
||||||
20
cognee/tests/utils/isolate_relationships.py
Normal file
20
cognee/tests/utils/isolate_relationships.py
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
def isolate_relationships(source_relationships, *other_relationships):
|
||||||
|
final_relationships = []
|
||||||
|
cache = {relationship[2]: 1 for relationship in source_relationships}
|
||||||
|
duplicated_relationships = {}
|
||||||
|
|
||||||
|
for relationships in other_relationships:
|
||||||
|
for relationship in relationships:
|
||||||
|
if relationship[2] not in cache:
|
||||||
|
cache[relationship[2]] = 0
|
||||||
|
|
||||||
|
cache[relationship[2]] += 1
|
||||||
|
|
||||||
|
if cache[relationship[2]] == 2:
|
||||||
|
duplicated_relationships[relationship[2]] = True
|
||||||
|
|
||||||
|
for relationship in source_relationships:
|
||||||
|
if relationship[2] not in duplicated_relationships:
|
||||||
|
final_relationships.append(relationship)
|
||||||
|
|
||||||
|
return final_relationships
|
||||||
Loading…
Add table
Reference in a new issue