fix: add detailed tests for delete

This commit is contained in:
Boris Arzentar 2025-11-17 22:04:30 +01:00
parent 77b3e731d8
commit a89dad328e
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
20 changed files with 1011 additions and 266 deletions

View file

@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from cognee.base_config import get_base_config
from cognee.modules.data.methods import create_dataset
from cognee.modules.data.methods import create_authorized_dataset
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.graph.config import get_graph_config
@ -66,7 +66,7 @@ async def get_or_create_dataset_database(
async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist
if isinstance(dataset, str):
dataset = await create_dataset(dataset, user, session)
dataset = await create_authorized_dataset(dataset, user)
# Try to fetch an existing row first
stmt = select(DatasetDatabase).where(

View file

@ -10,6 +10,7 @@ from cognee.context_global_variables import set_database_global_context_variable
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.methods import create_authorized_dataset
from cognee.modules.engine.operations.setup import setup
from cognee.modules.engine.utils import generate_node_id
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
@ -53,11 +54,11 @@ async def main():
works_for: List[Organization]
metadata: dict = {"index_fields": ["name"]}
companyA = ForProfit(name="Company A")
companyB = NonProfit(name="Company B")
companyA = ForProfit(id=generate_node_id("Company A"), name="Company A")
companyB = NonProfit(id=generate_node_id("Company B"), name="Company B")
person1 = Person(name="John", works_for=[companyA, companyB])
person2 = Person(name="Jane", works_for=[companyB])
person1 = Person(id=generate_node_id("John"), name="John", works_for=[companyA, companyB])
person2 = Person(id=generate_node_id("Jane"), name="Jane", works_for=[companyB])
user: User = await get_default_user() # type: ignore
@ -93,15 +94,59 @@ async def main():
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
# Initial check
assert len(nodes) == 4 and len(edges) == 3, (
"Nodes and edges are not correctly added to the graph."
)
nodes_by_id = {node[0]: node[1] for node in nodes}
assert str(generate_node_id("John")) in nodes_by_id, "John node not present in the graph."
assert str(generate_node_id("Jane")) in nodes_by_id, "Jane node not present in the graph."
assert str(generate_node_id("Company A")) in nodes_by_id, (
"Company A node not present in the graph."
)
assert str(generate_node_id("Company B")) in nodes_by_id, (
"Company B node not present in the graph."
)
edges_by_ids = {f"{edge[0]}_{edge[2]}_{edge[1]}": edge[3] for edge in edges}
assert (
f"{str(generate_node_id('John'))}_works_for_{str(generate_node_id('Company A'))}"
in edges_by_ids
), "Edge between John and Company A not present in the graph."
assert (
f"{str(generate_node_id('John'))}_works_for_{str(generate_node_id('Company B'))}"
in edges_by_ids
), "Edge between John and Company B not present in the graph."
assert (
f"{str(generate_node_id('Jane'))}_works_for_{str(generate_node_id('Company B'))}"
in edges_by_ids
), "Edge between Jane and Company B not present in the graph."
# First data deletion
await datasets.delete_data(dataset.id, data1.id, user)
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 2 and len(edges) == 1, "Nodes and edges are not deleted properly."
nodes_by_id = {node[0]: node[1] for node in nodes}
assert str(generate_node_id("Jane")) in nodes_by_id, "Jane node not present in the graph."
assert str(generate_node_id("Company B")) in nodes_by_id, (
"Company B node not present in the graph."
)
edges_by_ids = {f"{edge[0]}_{edge[2]}_{edge[1]}": edge[3] for edge in edges}
assert (
f"{str(generate_node_id('Jane'))}_works_for_{str(generate_node_id('Company B'))}"
in edges_by_ids
), "Edge between Jane and Company B not present in the graph."
# Second data deletion
await datasets.delete_data(dataset.id, data2.id, user)
nodes, edges = await graph_engine.get_graph_data()

View file

@ -1,17 +1,40 @@
import os
import pathlib
import pytest
import pathlib
from uuid import NAMESPACE_OID, uuid5
from unittest.mock import AsyncMock, patch
import cognee
from cognee.api.v1.datasets import datasets
from cognee.context_global_variables import set_database_global_context_variables
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm import LLMGateway
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.data.processing.document_types.TextDocument import TextDocument
from cognee.modules.engine.models import Entity
from cognee.modules.engine.operations.setup import setup
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
from cognee.shared.logging_utils import get_logger
from cognee.tests.utils.assert_edges_vector_index_not_present import (
assert_edges_vector_index_not_present,
)
from cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present
from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present
from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present
from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present
from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present
from cognee.tests.utils.assert_nodes_vector_index_not_present import (
assert_nodes_vector_index_not_present,
)
from cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present
from cognee.tests.utils.extract_entities import extract_entities
from cognee.tests.utils.extract_relationships import extract_relationships
from cognee.tests.utils.extract_summary import extract_summary
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
from cognee.tests.utils.isolate_relationships import isolate_relationships
logger = get_logger()
@ -107,92 +130,159 @@ async def main(mock_create_structured_output: AsyncMock):
mock_create_structured_output.side_effect = mock_llm_output
user = await get_default_user()
await set_database_global_context_variables("main_dataset", user.id)
vector_engine = get_vector_engine()
assert not await vector_engine.has_collection("EdgeType_relationship_name")
assert not await vector_engine.has_collection("Entity_name")
assert not await vector_engine.has_collection("DocumentChunk_text")
assert not await vector_engine.has_collection("TextSummary_text")
assert not await vector_engine.has_collection("TextDocument_text")
assert not await vector_engine.has_collection("EdgeType_relationship_name")
add_john_result = await cognee.add(
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
)
johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
add_john_result = await cognee.add(johns_text)
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
add_marie_result = await cognee.add(
"Marie works for Apple as well. She is a software engineer on MacOS project."
)
maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project."
add_marie_result = await cognee.add(maries_text)
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
cognify_result: dict = await cognee.cognify()
dataset_id = list(cognify_result.keys())[0]
graph_engine = await get_graph_engine()
initial_nodes, initial_edges = await graph_engine.get_graph_data()
assert len(initial_nodes) == 15 and len(initial_edges) == 19, (
"Number of nodes and edges is not correct."
johns_document = TextDocument(
id=johns_data_id,
name="John's Work",
raw_data_location="johns_data_location",
external_metadata="",
)
johns_chunk = DocumentChunk(
id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"),
text=johns_text,
chunk_size=14,
chunk_index=0,
cut_type="sentence_end",
is_part_of=johns_document,
)
johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore
maries_document = TextDocument(
id=maries_data_id,
name="Maries's Work",
raw_data_location="maries_data_location",
external_metadata="",
)
maries_chunk = DocumentChunk(
id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"),
text=maries_text,
chunk_size=14,
chunk_index=0,
cut_type="sentence_end",
is_part_of=maries_document,
)
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
johns_entities, maries_entities
)
initial_nodes_by_vector_collection = {}
johns_data = [
johns_document,
johns_chunk,
johns_summary,
*johns_entities,
]
maries_data = [
maries_document,
maries_chunk,
maries_summary,
*maries_entities,
]
for node in initial_nodes:
node_data = node[1]
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
if collection_name not in initial_nodes_by_vector_collection:
initial_nodes_by_vector_collection[collection_name] = []
initial_nodes_by_vector_collection[collection_name].append(node)
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(johns_data + maries_data + overlapping_entities)
await assert_nodes_vector_index_present(johns_data + maries_data + overlapping_entities)
initial_node_ids = set([node[0] for node in initial_nodes])
johns_relationships = extract_relationships(
johns_chunk,
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
)
maries_relationships = extract_relationships(
maries_chunk,
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
)
(overlapping_relationships, johns_relationships, maries_relationships) = (
filter_overlapping_relationships(johns_relationships, maries_relationships)
)
user = await get_default_user()
johns_relationships = [
(johns_chunk.id, johns_document.id, "is_part_of"),
(johns_summary.id, johns_chunk.id, "made_from"),
*johns_relationships,
]
johns_edge_text_relationships = [
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in johns_entities
if isinstance(entity, Entity)
]
maries_relationships = [
(maries_chunk.id, maries_document.id, "is_part_of"),
(maries_summary.id, maries_chunk.id, "made_from"),
*maries_relationships,
]
maries_edge_text_relationships = [
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in maries_entities
if isinstance(entity, Entity)
]
expected_relationships = johns_relationships + maries_relationships + overlapping_relationships
await assert_graph_edges_present(expected_relationships)
await assert_edges_vector_index_present(
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
)
# Delete John's data from cognee
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted."
assert not any(
node[1]["name"] == "john" or node[1]["name"] == "food for hungry"
for node in nodes
if "name" in node[1]
), "Nodes are not deleted."
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(maries_data + overlapping_entities)
await assert_nodes_vector_index_present(maries_data + overlapping_entities)
after_first_delete_node_ids = set([node[0] for node in nodes])
await assert_graph_nodes_not_present(johns_data)
await assert_nodes_vector_index_not_present(johns_data)
after_delete_nodes_by_vector_collection = {}
for node in initial_nodes:
node_data = node[1]
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
if collection_name not in after_delete_nodes_by_vector_collection:
after_delete_nodes_by_vector_collection[collection_name] = []
after_delete_nodes_by_vector_collection[collection_name].append(node)
# Assert relationships presence in the graph, vector collections and nodes table
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
vector_engine = get_vector_engine()
await assert_graph_edges_not_present(johns_relationships)
removed_node_ids = initial_node_ids - after_first_delete_node_ids
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids]
if query_node_ids:
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
strictly_johns_relationships = isolate_relationships(johns_relationships, maries_relationships)
# We check only by relationship name, so we need edges that were created by John's data alone.
await assert_edges_vector_index_not_present(
strictly_johns_relationships + johns_edge_text_relationships
)
# Delete Marie's data from cognee
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
final_nodes, final_edges = await graph_engine.get_graph_data()
assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted."
await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities)
await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities)
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
query_node_ids = [node[0] for node in initial_nodes]
# Assert relationships presence in the graph, vector collections and nodes table
await assert_graph_edges_not_present(
johns_relationships + maries_relationships + overlapping_relationships
)
if query_node_ids:
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
query_edge_ids = [edge[0] for edge in initial_edges]
vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
await assert_edges_vector_index_not_present(maries_relationships)
if __name__ == "__main__":

View file

@ -1,12 +1,12 @@
import os
import json
import pytest
import pathlib
from uuid import NAMESPACE_OID, uuid5
import pytest
from unittest.mock import AsyncMock, patch
import cognee
from cognee.api.v1.datasets import datasets
from cognee.context_global_variables import set_database_global_context_variables
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
@ -17,17 +17,34 @@ from cognee.modules.data.models import Data
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.engine.operations.setup import setup
from cognee.modules.engine.utils import generate_edge_id, generate_node_id
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
from cognee.modules.engine.utils import generate_node_id
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
from cognee.modules.pipelines.models import DataItemStatus
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
from cognee.tasks.storage import index_data_points, index_graph_edges
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
from cognee.tests.utils.assert_edges_vector_index_not_present import (
assert_edges_vector_index_not_present,
)
from cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present
from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present
from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present
from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present
from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present
from cognee.tests.utils.assert_nodes_vector_index_not_present import (
assert_nodes_vector_index_not_present,
)
from cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present
from cognee.tests.utils.extract_entities import extract_entities
from cognee.tests.utils.extract_relationships import extract_relationships
from cognee.tests.utils.extract_summary import extract_summary
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
from cognee.tests.utils.isolate_relationships import isolate_relationships
async def get_nodes_and_edges():
def create_nodes_and_edges():
document = TextDocument(
id=uuid5(NAMESPACE_OID, "text_test.txt"),
name="text_test.txt",
@ -73,15 +90,8 @@ async def get_nodes_and_edges():
name="amazon s3",
description="A storage service provided by Amazon Web Services that allows storing graph data.",
)
document_chunk.contains = [
graph_database,
neptune_analytics_entity,
neptune_database_entity,
storage,
storage_entity,
]
data_points = [
nodes_data = [
document,
document_chunk,
graph_database,
@ -91,39 +101,71 @@ async def get_nodes_and_edges():
storage_entity,
]
nodes = []
edges = []
edges_data = [
(
document_chunk.id,
storage_entity.id,
"contains",
{
"relationship_name": "contains",
},
),
(
storage_entity.id,
storage.id,
"is_a",
{
"relationship_name": "is_a",
},
),
(
document_chunk.id,
neptune_database_entity.id,
"contains",
{
"relationship_name": "contains",
},
),
(
neptune_database_entity.id,
graph_database.id,
"is_a",
{
"relationship_name": "is_a",
},
),
(
document_chunk.id,
document.id,
"is_part_of",
{
"relationship_name": "is_part_of",
},
),
(
document_chunk.id,
neptune_analytics_entity.id,
"contains",
{
"relationship_name": "contains",
},
),
(
neptune_analytics_entity.id,
graph_database.id,
"is_a",
{
"relationship_name": "is_a",
},
),
]
added_nodes = {}
added_edges = {}
visited_properties = {}
results = await asyncio.gather(
*[
get_graph_from_model(
data_point,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
for data_point in data_points
]
)
for result_nodes, result_edges in results:
nodes.extend(result_nodes)
edges.extend(result_edges)
nodes, edges = deduplicate_nodes_and_edges(nodes, edges)
return nodes, edges
return nodes_data, edges_data
@pytest.mark.asyncio
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
async def main(mock_create_structured_output: AsyncMock):
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "False"
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_1"
)
@ -139,6 +181,9 @@ async def main(mock_create_structured_output: AsyncMock):
await cognee.prune.prune_system(metadata=True)
await setup()
user = await get_default_user()
await set_database_global_context_variables("main_dataset", user.id)
vector_engine = get_vector_engine()
assert not await vector_engine.has_collection("EdgeType_relationship_name")
@ -147,9 +192,8 @@ async def main(mock_create_structured_output: AsyncMock):
assert not await vector_engine.has_collection("TextSummary_text")
assert not await vector_engine.has_collection("TextDocument_text")
user = await get_default_user()
old_nodes, old_edges = await add_mocked_legacy_data(user)
# Add legacy data to the system
__, legacy_data_points, legacy_relationships = await create_mocked_legacy_data(user)
def mock_llm_output(text_input: str, system_prompt: str, response_model):
if text_input == "test": # LLM connection test
@ -225,109 +269,188 @@ async def main(mock_create_structured_output: AsyncMock):
mock_create_structured_output.side_effect = mock_llm_output
add_john_result = await cognee.add(
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
)
johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
add_john_result = await cognee.add(johns_text)
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
add_marie_result = await cognee.add(
"Marie works for Apple as well. She is a software engineer on MacOS project."
)
maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project."
add_marie_result = await cognee.add(maries_text)
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
cognify_result: dict = await cognee.cognify()
dataset_id = list(cognify_result.keys())[0]
graph_engine = await get_graph_engine()
initial_nodes, initial_edges = await graph_engine.get_graph_data()
assert len(initial_nodes) == 22 and len(initial_edges) == 25, (
"Number of nodes and edges is not correct."
johns_document = TextDocument(
id=johns_data_id,
name="John's Work",
raw_data_location="johns_data_location",
external_metadata="",
)
johns_chunk = DocumentChunk(
id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"),
text=johns_text,
chunk_size=14,
chunk_index=0,
cut_type="sentence_end",
is_part_of=johns_document,
)
johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore
maries_document = TextDocument(
id=maries_data_id,
name="Maries's Work",
raw_data_location="maries_data_location",
external_metadata="",
)
maries_chunk = DocumentChunk(
id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"),
text=maries_text,
chunk_size=14,
chunk_index=0,
cut_type="sentence_end",
is_part_of=maries_document,
)
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
johns_entities, maries_entities
)
initial_nodes_by_vector_collection = {}
johns_data = [
johns_document,
johns_chunk,
johns_summary,
*johns_entities,
]
maries_data = [
maries_document,
maries_chunk,
maries_summary,
*maries_entities,
]
for node in initial_nodes:
node_data = node[1]
node_metadata = node_data["metadata"]
node_metadata = json.loads(node_metadata) if type(node_metadata) is str else node_metadata
collection_name = node_data["type"] + "_" + node_metadata["index_fields"][0]
if collection_name not in initial_nodes_by_vector_collection:
initial_nodes_by_vector_collection[collection_name] = []
initial_nodes_by_vector_collection[collection_name].append(node)
expected_data_points = johns_data + maries_data + overlapping_entities + legacy_data_points
initial_node_ids = set([node[0] for node in initial_nodes])
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(expected_data_points)
await assert_nodes_vector_index_present(expected_data_points)
johns_relationships = extract_relationships(
johns_chunk,
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
)
maries_relationships = extract_relationships(
maries_chunk,
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
)
(overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = (
filter_overlapping_relationships(
johns_relationships, maries_relationships, legacy_relationships
)
)
johns_relationships = [
(johns_chunk.id, johns_document.id, "is_part_of"),
(johns_summary.id, johns_chunk.id, "made_from"),
*johns_relationships,
]
johns_edge_text_relationships = [
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in johns_entities
if isinstance(entity, Entity)
]
maries_relationships = [
(maries_chunk.id, maries_document.id, "is_part_of"),
(maries_summary.id, maries_chunk.id, "made_from"),
*maries_relationships,
]
maries_edge_text_relationships = [
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in maries_entities
if isinstance(entity, Entity)
]
expected_relationships = (
johns_relationships
+ maries_relationships
+ overlapping_relationships
+ legacy_relationships
)
await assert_graph_edges_present(expected_relationships)
await assert_edges_vector_index_present(
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
)
# Delete John's data
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 16 and len(edges) == 16, "Nodes and edges are not deleted."
assert not any(
node[1]["name"] == "john" or node[1]["name"] == "food for hungry"
for node in nodes
if "name" in node[1]
), "Nodes are not deleted."
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points)
await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points)
after_first_delete_node_ids = set([node[0] for node in nodes])
await assert_graph_nodes_not_present(johns_data)
await assert_nodes_vector_index_not_present(johns_data)
after_delete_nodes_by_vector_collection = {}
for node in initial_nodes:
node_data = node[1]
node_metadata = node_data["metadata"]
node_metadata = json.loads(node_metadata) if type(node_metadata) is str else node_metadata
collection_name = node_data["type"] + "_" + node_metadata["index_fields"][0]
if collection_name not in after_delete_nodes_by_vector_collection:
after_delete_nodes_by_vector_collection[collection_name] = []
after_delete_nodes_by_vector_collection[collection_name].append(node)
# Assert relationships presence in the graph, vector collections and nodes table
await assert_graph_edges_present(
maries_relationships + overlapping_relationships + legacy_relationships
)
await assert_edges_vector_index_present(
maries_relationships + maries_edge_text_relationships + legacy_relationships
)
vector_engine = get_vector_engine()
await assert_graph_edges_not_present(johns_relationships)
removed_node_ids = initial_node_ids - after_first_delete_node_ids
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids]
if query_node_ids:
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
strictly_johns_relationships = isolate_relationships(
johns_relationships, maries_relationships, legacy_relationships
)
# We check only by relationship name, so we need edges that were created by John's data alone.
await assert_edges_vector_index_not_present(
strictly_johns_relationships + johns_edge_text_relationships
)
# Delete Marie's data
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
final_nodes, final_edges = await graph_engine.get_graph_data()
assert len(final_nodes) == 7 and len(final_edges) == 6, "Nodes and edges are not deleted."
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(legacy_data_points)
await assert_nodes_vector_index_present(legacy_data_points)
old_nodes_by_vector_collection = {}
for node in old_nodes:
node_metadata = node.metadata
collection_name = node.type + "_" + node_metadata["index_fields"][0]
if collection_name not in old_nodes_by_vector_collection:
old_nodes_by_vector_collection[collection_name] = []
old_nodes_by_vector_collection[collection_name].append(node)
await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities)
await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities)
for collection_name, old_nodes in old_nodes_by_vector_collection.items():
query_node_ids = [str(node.id) for node in old_nodes]
# Assert relationships presence in the graph, vector collections and nodes table
await assert_graph_edges_present(legacy_relationships)
await assert_edges_vector_index_present(legacy_relationships)
if query_node_ids:
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
assert len(vector_items) == len(old_nodes), "Vector items are not deleted."
await assert_graph_edges_not_present(
johns_relationships + maries_relationships + overlapping_relationships
)
query_edge_ids = list(set([str(generate_edge_id(edge[2])) for edge in old_edges]))
vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
assert len(vector_items) == len(query_edge_ids), "Vector items are not deleted."
strictly_maries_relationships = isolate_relationships(
maries_relationships, legacy_relationships
)
# We check only by relationship name, so we need edges that were created by Marie's data alone.
if strictly_maries_relationships:
await assert_edges_vector_index_not_present(strictly_maries_relationships)
async def add_mocked_legacy_data(user):
async def create_mocked_legacy_data(user):
graph_engine = await get_graph_engine()
old_nodes, old_edges = await get_nodes_and_edges()
old_document = old_nodes[0]
legacy_nodes, legacy_edges = create_nodes_and_edges()
legacy_document = legacy_nodes[0]
await graph_engine.add_nodes(old_nodes)
await graph_engine.add_edges(old_edges)
await graph_engine.add_nodes(legacy_nodes)
await graph_engine.add_edges(legacy_edges)
await index_data_points(old_nodes)
await index_graph_edges(old_edges)
await index_data_points(legacy_nodes)
await index_graph_edges(legacy_edges)
await record_data_in_legacy_ledger(old_nodes, old_edges, user)
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
db_engine = get_relational_engine()
@ -335,12 +458,12 @@ async def add_mocked_legacy_data(user):
async with db_engine.get_async_session() as session:
old_data = Data(
id=old_document.id,
name=old_document.name,
id=legacy_document.id,
name=legacy_document.name,
extension="txt",
raw_data_location=old_document.raw_data_location,
external_metadata=old_document.external_metadata,
mime_type=old_document.mime_type,
raw_data_location=legacy_document.raw_data_location,
external_metadata=legacy_document.external_metadata,
mime_type=legacy_document.mime_type,
owner_id=user.id,
pipeline_status={
"cognify_pipeline": {
@ -355,7 +478,7 @@ async def add_mocked_legacy_data(user):
await session.commit()
return old_nodes, old_edges
return legacy_document, legacy_nodes, legacy_edges
if __name__ == "__main__":

View file

@ -1,11 +1,12 @@
import os
import pytest
import pathlib
from uuid import NAMESPACE_OID, uuid5
import pytest
from unittest.mock import AsyncMock, patch
import cognee
from cognee.api.v1.datasets import datasets
from cognee.context_global_variables import set_database_global_context_variables
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
@ -16,15 +17,34 @@ from cognee.modules.data.models import Data
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.engine.operations.setup import setup
from cognee.modules.engine.utils import generate_edge_id, generate_node_id
from cognee.modules.engine.utils import generate_node_id
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
from cognee.modules.pipelines.models import DataItemStatus
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
from cognee.tasks.storage import index_data_points, index_graph_edges
from cognee.tests.utils.assert_edges_vector_index_not_present import (
assert_edges_vector_index_not_present,
)
from cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present
from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present
from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present
from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present
from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present
from cognee.tests.utils.assert_nodes_vector_index_not_present import (
assert_nodes_vector_index_not_present,
)
from cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present
from cognee.tests.utils.extract_entities import extract_entities
from cognee.tests.utils.extract_relationships import extract_relationships
from cognee.tests.utils.extract_summary import extract_summary
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
from cognee.tests.utils.isolate_relationships import isolate_relationships
def get_nodes_and_edges():
def create_nodes_and_edges():
document = TextDocument(
id=uuid5(NAMESPACE_OID, "text_test.txt"),
name="text_test.txt",
@ -146,8 +166,6 @@ def get_nodes_and_edges():
@pytest.mark.asyncio
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
async def main(mock_create_structured_output: AsyncMock):
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "False"
data_directory_path = os.path.join(
pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_2"
)
@ -163,6 +181,9 @@ async def main(mock_create_structured_output: AsyncMock):
await cognee.prune.prune_system(metadata=True)
await setup()
user = await get_default_user()
await set_database_global_context_variables("main_dataset", user.id)
vector_engine = get_vector_engine()
assert not await vector_engine.has_collection("EdgeType_relationship_name")
@ -171,9 +192,10 @@ async def main(mock_create_structured_output: AsyncMock):
assert not await vector_engine.has_collection("TextSummary_text")
assert not await vector_engine.has_collection("TextDocument_text")
user = await get_default_user()
old_document, old_nodes, old_edges = await add_mocked_legacy_data(user)
# Add legacy data to the system
legacy_document, legacy_data_points, legacy_relationships = await create_mocked_legacy_data(
user
)
def mock_llm_output(text_input: str, system_prompt: str, response_model):
if text_input == "test": # LLM connection test
@ -249,100 +271,186 @@ async def main(mock_create_structured_output: AsyncMock):
mock_create_structured_output.side_effect = mock_llm_output
add_john_result = await cognee.add(
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
)
johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
add_john_result = await cognee.add(johns_text)
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.")
maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project."
add_marie_result = await cognee.add(maries_text)
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
cognify_result: dict = await cognee.cognify()
dataset_id = list(cognify_result.keys())[0]
graph_engine = await get_graph_engine()
initial_nodes, initial_edges = await graph_engine.get_graph_data()
assert len(initial_nodes) == 22 and len(initial_edges) == 26, (
"Number of nodes and edges is not correct."
johns_document = TextDocument(
id=johns_data_id,
name="John's Work",
raw_data_location="johns_data_location",
external_metadata="",
)
johns_chunk = DocumentChunk(
id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"),
text=johns_text,
chunk_size=14,
chunk_index=0,
cut_type="sentence_end",
is_part_of=johns_document,
)
johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore
maries_document = TextDocument(
id=maries_data_id,
name="Maries's Work",
raw_data_location="maries_data_location",
external_metadata="",
)
maries_chunk = DocumentChunk(
id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"),
text=maries_text,
chunk_size=14,
chunk_index=0,
cut_type="sentence_end",
is_part_of=maries_document,
)
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
johns_entities, maries_entities
)
initial_nodes_by_vector_collection = {}
johns_data = [
johns_document,
johns_chunk,
johns_summary,
*johns_entities,
]
maries_data = [
maries_document,
maries_chunk,
maries_summary,
*maries_entities,
]
for node in initial_nodes:
node_data = node[1]
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
if collection_name not in initial_nodes_by_vector_collection:
initial_nodes_by_vector_collection[collection_name] = []
initial_nodes_by_vector_collection[collection_name].append(node)
expected_data_points = johns_data + maries_data + overlapping_entities + legacy_data_points
initial_node_ids = set([node[0] for node in initial_nodes])
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(expected_data_points)
await assert_nodes_vector_index_present(expected_data_points)
johns_relationships = extract_relationships(
johns_chunk,
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
)
maries_relationships = extract_relationships(
maries_chunk,
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
)
(overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = (
filter_overlapping_relationships(
johns_relationships, maries_relationships, legacy_relationships
)
)
johns_relationships = [
(johns_chunk.id, johns_document.id, "is_part_of"),
(johns_summary.id, johns_chunk.id, "made_from"),
*johns_relationships,
]
johns_edge_text_relationships = [
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in johns_entities
if isinstance(entity, Entity)
]
maries_relationships = [
(maries_chunk.id, maries_document.id, "is_part_of"),
(maries_summary.id, maries_chunk.id, "made_from"),
*maries_relationships,
]
maries_edge_text_relationships = [
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
for entity in maries_entities
if isinstance(entity, Entity)
]
expected_relationships = (
johns_relationships
+ maries_relationships
+ overlapping_relationships
+ legacy_relationships
)
await assert_graph_edges_present(expected_relationships)
await assert_edges_vector_index_present(
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
)
# Delete John's data
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 16 and len(edges) == 17, "Nodes and edges are not deleted."
assert not any(
node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes
), "Nodes are not deleted."
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points)
await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points)
after_first_delete_node_ids = set([node[0] for node in nodes])
await assert_graph_nodes_not_present(johns_data)
await assert_nodes_vector_index_not_present(johns_data)
after_delete_nodes_by_vector_collection = {}
for node in initial_nodes:
node_data = node[1]
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
if collection_name not in after_delete_nodes_by_vector_collection:
after_delete_nodes_by_vector_collection[collection_name] = []
after_delete_nodes_by_vector_collection[collection_name].append(node)
# Assert relationships presence in the graph, vector collections and nodes table
await assert_graph_edges_present(
maries_relationships + overlapping_relationships + legacy_relationships
)
await assert_edges_vector_index_present(
maries_relationships + maries_edge_text_relationships + legacy_relationships
)
vector_engine = get_vector_engine()
await assert_graph_edges_not_present(johns_relationships)
removed_node_ids = initial_node_ids - after_first_delete_node_ids
strictly_johns_relationships = isolate_relationships(
johns_relationships, maries_relationships, legacy_relationships
)
# We check only by relationship name and we need edges that are created by John's data and no other.
await assert_edges_vector_index_not_present(
strictly_johns_relationships + johns_edge_text_relationships
)
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids]
# Delete legacy data
await datasets.delete_data(dataset_id, legacy_document.id, user) # type: ignore
if query_node_ids:
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(maries_data + overlapping_entities)
await assert_nodes_vector_index_present(maries_data + overlapping_entities)
# Delete old document
await datasets.delete_data(dataset_id, old_document.id, user) # type: ignore
await assert_graph_nodes_not_present(johns_data + legacy_data_points)
await assert_nodes_vector_index_not_present(johns_data + legacy_data_points)
final_nodes, final_edges = await graph_engine.get_graph_data()
assert len(final_nodes) == 9 and len(final_edges) == 10, "Nodes and edges are not deleted."
# Assert relationships presence in the graph, vector collections and nodes table
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
old_nodes_by_vector_collection = {}
for node in old_nodes:
collection_name = node.type + "_" + node.metadata["index_fields"][0]
if collection_name not in old_nodes_by_vector_collection:
old_nodes_by_vector_collection[collection_name] = []
old_nodes_by_vector_collection[collection_name].append(node)
await assert_graph_edges_not_present(johns_relationships + legacy_relationships)
for collection_name, old_nodes in old_nodes_by_vector_collection.items():
query_node_ids = [str(node.id) for node in old_nodes]
if query_node_ids:
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
assert len(vector_items) == 0, "Vector items are not deleted."
query_edge_ids = list(set([str(generate_edge_id(edge[2])) for edge in old_edges]))
vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
assert len(vector_items) == len(query_edge_ids), "Vector items are not deleted."
strictly_legacy_relationships = isolate_relationships(
legacy_relationships, maries_relationships
)
# We check only by relationship name and we need edges that are created by legacy data and no other.
if strictly_legacy_relationships:
await assert_edges_vector_index_not_present(strictly_legacy_relationships)
async def add_mocked_legacy_data(user):
async def create_mocked_legacy_data(user):
graph_engine = await get_graph_engine()
old_nodes, old_edges = get_nodes_and_edges()
old_document = old_nodes[0]
legacy_nodes, legacy_edges = create_nodes_and_edges()
legacy_document = legacy_nodes[0]
await graph_engine.add_nodes(old_nodes)
await graph_engine.add_edges(old_edges)
await graph_engine.add_nodes(legacy_nodes)
await graph_engine.add_edges(legacy_edges)
await index_data_points(old_nodes)
await index_graph_edges(old_edges)
await index_data_points(legacy_nodes)
await index_graph_edges(legacy_edges)
await record_data_in_legacy_ledger(old_nodes, old_edges, user)
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
db_engine = get_relational_engine()
@ -350,12 +458,12 @@ async def add_mocked_legacy_data(user):
async with db_engine.get_async_session() as session:
old_data = Data(
id=old_document.id,
name=old_document.name,
id=legacy_document.id,
name=legacy_document.name,
extension="txt",
raw_data_location=old_document.raw_data_location,
external_metadata=old_document.external_metadata,
mime_type=old_document.mime_type,
raw_data_location=legacy_document.raw_data_location,
external_metadata=legacy_document.external_metadata,
mime_type=legacy_document.mime_type,
owner_id=user.id,
pipeline_status={
"cognify_pipeline": {
@ -370,7 +478,7 @@ async def add_mocked_legacy_data(user):
await session.commit()
return old_document, old_nodes, old_edges
return legacy_document, legacy_nodes, legacy_edges
if __name__ == "__main__":

View file

@ -0,0 +1,23 @@
from uuid import UUID
from typing import List, Tuple
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.engine.utils import generate_edge_id
async def assert_edges_vector_index_not_present(relationships: List[Tuple[UUID, UUID, str]]):
    """Assert that none of the given relationships remain in the edge vector index.

    Each relationship is a (source_id, target_id, relationship_name) tuple; only
    the relationship name participates in edge vectorization, so lookups key on
    the id derived from it.
    """
    engine = get_vector_engine()
    # Map vectorized edge id -> human-readable relationship name for error reporting.
    expected_absent = {}
    for _, _, name in relationships:
        expected_absent[str(generate_edge_id(name))] = name
    retrieved = await engine.retrieve("EdgeType_relationship_name", list(expected_absent))
    found_ids = {str(item.id) for item in retrieved}
    for edge_id, name in expected_absent.items():
        assert edge_id not in found_ids, (
            f"Relationship '{name}' still present in the vector store."
        )

View file

@ -0,0 +1,28 @@
from uuid import UUID
from typing import List, Tuple
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.engine.utils import generate_edge_id
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str]]):
    """Assert every given relationship is indexed in the edge vector collection.

    Checks both that an entry exists for each relationship name and that the
    stored payload text matches that name exactly.
    """
    engine = get_vector_engine()
    # Map vectorized edge id -> relationship name so failures can name the edge.
    expected = {}
    for _, _, name in relationships:
        expected[str(generate_edge_id(name))] = name
    retrieved = await engine.retrieve("EdgeType_relationship_name", list(expected))
    items_by_id = {str(item.id): item for item in retrieved}
    for edge_id, name in expected.items():
        assert edge_id in items_by_id, (
            f"Relationship '{name}' not found in vector store."
        )
        item = items_by_id[edge_id]
        assert item.payload["text"] == name, (
            f"Vectorized edge '{item.payload['text']}' does not match the relationship text '{name}'."
        )

View file

@ -0,0 +1,21 @@
from uuid import UUID
from typing import List, Tuple
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine.models.DataPoint import DataPoint
async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str]]):
    """Assert that none of the given (source, target, name) edges exist in the graph.

    Edge identity is the string "<source>_<name>_<target>", matching how edges
    come back from `graph_engine.get_graph_data()`.
    """
    graph_engine = await get_graph_engine()
    nodes, edges = await graph_engine.get_graph_data()
    nodes_by_id = {str(node[0]): node[1] for node in nodes}
    edge_ids = {f"{str(edge[0])}_{edge[2]}_{str(edge[1])}" for edge in edges}
    for source_id, target_id, relationship_name in relationships:
        relationship_id = f"{str(source_id)}_{relationship_name}_{str(target_id)}"
        # Fall back to the raw ids so a missing endpoint node cannot raise
        # KeyError while rendering the failure message and mask the real error.
        source_name = nodes_by_id.get(str(source_id), {}).get("name", str(source_id))
        target_name = nodes_by_id.get(str(target_id), {}).get("name", str(target_id))
        assert relationship_id not in edge_ids, (
            f"Edge '{relationship_name}' still present between '{source_name}' and '{target_name}' in graph database."
        )

View file

@ -0,0 +1,21 @@
from uuid import UUID
from typing import List, Tuple
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine.models.DataPoint import DataPoint
async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]]):
    """Assert that every given (source, target, name) edge exists in the graph.

    Edge identity is the string "<source>_<name>_<target>", matching how edges
    come back from `graph_engine.get_graph_data()`.
    """
    graph_engine = await get_graph_engine()
    nodes, edges = await graph_engine.get_graph_data()
    nodes_by_id = {str(node[0]): node[1] for node in nodes}
    edge_ids = {f"{str(edge[0])}_{edge[2]}_{str(edge[1])}" for edge in edges}
    for source_id, target_id, relationship_name in relationships:
        relationship_id = f"{str(source_id)}_{relationship_name}_{str(target_id)}"
        # A missing edge often means its endpoint nodes are gone too; use
        # .get(...) fallbacks so building the failure message cannot raise
        # KeyError and hide the actual assertion failure.
        source_name = nodes_by_id.get(str(source_id), {}).get("name", str(source_id))
        target_name = nodes_by_id.get(str(target_id), {}).get("name", str(target_id))
        assert relationship_id in edge_ids, (
            f"Edge '{relationship_name}' not present between '{source_name}' and '{target_name}' in graph database."
        )

View file

@ -0,0 +1,16 @@
from typing import List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine.models.DataPoint import DataPoint
async def assert_graph_nodes_not_present(data_points: List[DataPoint]):
    """Assert that none of the given data points exist as nodes in the graph."""
    graph_engine = await get_graph_engine()
    nodes, __ = await graph_engine.get_graph_data()
    # Normalize ids to strings: depending on the graph adapter, node ids may
    # come back as UUID objects, which would never equal str(data_point.id)
    # and make this absence check pass vacuously.
    node_ids = {str(node[0]) for node in nodes}
    for data_point in data_points:
        node_name = getattr(data_point, "label", getattr(data_point, "name", data_point.id))
        assert str(data_point.id) not in node_ids, (
            f"Node '{node_name}' is present in graph database."
        )

View file

@ -0,0 +1,14 @@
from typing import List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine.models.DataPoint import DataPoint
async def assert_graph_nodes_present(data_points: List[DataPoint]):
    """Assert that every given data point exists as a node in the graph."""
    graph_engine = await get_graph_engine()
    nodes, __ = await graph_engine.get_graph_data()
    # Normalize ids to strings so the comparison with str(data_point.id) works
    # whether the graph adapter returns node ids as UUIDs or as strings.
    node_ids = {str(node[0]) for node in nodes}
    for data_point in data_points:
        node_name = getattr(data_point, "label", getattr(data_point, "name", data_point.id))
        assert str(data_point.id) in node_ids, f"Node '{node_name}' not found in graph database."

View file

@ -0,0 +1,28 @@
from typing import List
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine.models.DataPoint import DataPoint
async def assert_nodes_vector_index_not_present(data_points: List[DataPoint]):
    """Assert that none of the given data points remain in their vector collections.

    Collection name is derived as "<type>_<first index field>" from each data
    point's metadata.
    """
    vector_engine = get_vector_engine()
    # Group data points by the vector collection they were indexed into.
    grouped = {}
    for data_point in data_points:
        metadata = data_point.metadata or {}
        collection_name = data_point.type + "_" + metadata["index_fields"][0]
        grouped.setdefault(collection_name, []).append(data_point)
    for collection_name, points in grouped.items():
        wanted_ids = {str(point.id) for point in points}
        vector_items = await vector_engine.retrieve(collection_name, list(wanted_ids))
        for vector_item in vector_items:
            assert str(vector_item.id) not in wanted_ids, (
                f"{vector_item.payload['text']} is still present in the vector store."
            )

View file

@ -0,0 +1,28 @@
from typing import List
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine.models.DataPoint import DataPoint
async def assert_nodes_vector_index_present(data_points: List[DataPoint]):
    """Assert that every given data point has an entry in its vector collection.

    Collection name is derived as "<type>_<first index field>" from each data
    point's metadata.
    """
    vector_engine = get_vector_engine()
    data_points_by_vector_collection = {}
    for data_point in data_points:
        node_metadata = data_point.metadata or {}
        collection_name = data_point.type + "_" + node_metadata["index_fields"][0]
        data_points_by_vector_collection.setdefault(collection_name, []).append(data_point)
    for collection_name, collection_data_points in data_points_by_vector_collection.items():
        query_data_point_ids = [str(data_point.id) for data_point in collection_data_points]
        vector_items = await vector_engine.retrieve(collection_name, query_data_point_ids)
        retrieved_ids = {str(vector_item.id) for vector_item in vector_items}
        # The previous check only verified that each *retrieved* item belonged
        # to the query set — trivially true for an id-based retrieve, so a
        # missing vector entry went undetected. Check the required direction:
        # every requested data point must actually have been retrieved.
        for data_point in collection_data_points:
            node_name = getattr(data_point, "label", getattr(data_point, "name", data_point.id))
            assert str(data_point.id) in retrieved_ids, (
                f"{node_name} is not present in the vector store."
            )

View file

@ -0,0 +1,45 @@
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.engine.utils import generate_node_id, generate_node_name
from cognee.shared.data_models import KnowledgeGraph
def extract_entities(graph: KnowledgeGraph, cache: dict = {}):
    """Build Entity and EntityType data points for every node in *graph*.

    NOTE: the mutable default ``cache`` is intentional — it is shared across
    calls, so extracting the same node twice yields the *same* object, which
    the overlap-filtering test helpers rely on.

    Returns the entities followed by their entity types.
    """
    entities = []
    entity_types = []
    for node in graph.nodes:
        node_id = generate_node_id(node.id)
        if node_id in cache:
            entity = cache[node_id]
        else:
            entity = Entity(
                id=node_id,
                name=generate_node_name(node.id),
                type=node.type,
                description=node.description,
                ontology_valid=False,
            )
            cache[node_id] = entity
        entities.append(entity)

        type_node_id = generate_node_id(node.type)
        if type_node_id in cache:
            type_node = cache[type_node_id]
        else:
            type_node_name = generate_node_name(node.type)
            type_node = EntityType(
                id=type_node_id,
                name=type_node_name,
                type=type_node_name,
                description=type_node_name,
                ontology_valid=False,
            )
            cache[type_node_id] = type_node
        entity_types.append(type_node)
    return entities + entity_types

View file

@ -0,0 +1,55 @@
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.engine.utils import generate_edge_id, generate_node_id
def extract_relationships(document_chunk: DocumentChunk, graph: KnowledgeGraph, cache: dict = {}):
    """Build (source_id, target_id, relationship_name) tuples for *graph*.

    Emits one tuple per graph edge, plus an "is_a" edge from each node to its
    type node and a "contains" edge from *document_chunk* to each node.

    NOTE: the mutable default ``cache`` is intentional — it is shared across
    calls so repeated extractions return identical tuples for identical edges.

    NOTE(review): the first loop derives node ids with ``generate_edge_id``
    while the later loops use ``generate_node_id`` — confirm both generators
    produce identical ids for the same input, otherwise these edges will not
    line up with the extracted nodes.
    """
    relationships = []
    for edge in graph.edges:
        edge_key = f"{edge.source_node_id}_{edge.relationship_name}_{edge.target_node_id}"
        if edge_key not in cache:
            cache[edge_key] = (
                generate_edge_id(edge.source_node_id),
                generate_edge_id(edge.target_node_id),
                edge.relationship_name,
            )
        relationships.append(cache[edge_key])

    for node in graph.nodes:
        node_id = generate_node_id(node.id)
        type_node_id = generate_node_id(node.type)

        is_a_key = f"{str(node_id)}_is_a_{str(type_node_id)}"
        if is_a_key not in cache:
            cache[is_a_key] = (node_id, type_node_id, "is_a")
        relationships.append(cache[is_a_key])

        contains_key = f"{str(document_chunk.id)}_contains_{str(node_id)}"
        if contains_key not in cache:
            cache[contains_key] = (document_chunk.id, node_id, "contains")
        relationships.append(cache[contains_key])

    return relationships

View file

@ -0,0 +1,12 @@
from uuid import uuid5
from cognee.modules.chunking.models import DocumentChunk
from cognee.shared.data_models import SummarizedContent
from cognee.tasks.summarization.models import TextSummary
def extract_summary(document_chunk: DocumentChunk, summary: SummarizedContent = SummarizedContent) -> TextSummary:
    """Wrap an LLM summary into a TextSummary data point tied to *document_chunk*.

    The id is deterministic (uuid5 of the chunk id), so re-summarizing the same
    chunk yields the same TextSummary id.

    NOTE(review): the default for ``summary`` is the SummarizedContent *class*,
    not an instance — calling this without an explicit summary would read the
    class attribute ``summary`` rather than a value. All visible callers pass
    an instance; consider making the parameter required.
    """
    return TextSummary(
        id=uuid5(document_chunk.id, "TextSummary"),
        text=summary.summary,
        made_from=document_chunk,
    )

View file

@ -0,0 +1,26 @@
def filter_overlapping_entities(*entity_groups):
    """Split entity groups into shared entities and group-unique entities.

    Args:
        *entity_groups: lists of objects with an ``id`` attribute.

    Returns:
        (overlapping_entities, group_1, group_2, ...): entities whose id occurs
        more than once across all groups are collected into
        ``overlapping_entities`` (deduplicated by equality); every remaining
        entity stays in its own group, in original order.
    """
    # Count occurrences of each entity id across all groups.
    entity_count = {}
    for group in entity_groups:
        for entity in group:
            entity_count[entity.id] = entity_count.get(entity.id, 0) + 1

    overlapping_entities = []
    grouped_entities = [[] for _ in entity_groups]
    for group, unique_entities in zip(entity_groups, grouped_entities):
        for entity in group:
            if entity_count[entity.id] == 1:
                unique_entities.append(entity)
            elif entity not in overlapping_entities:
                overlapping_entities.append(entity)
    return overlapping_entities, *grouped_entities

View file

@ -0,0 +1,33 @@
from cognee.modules.engine.utils import generate_node_id
def filter_overlapping_relationships(*relationship_groups):
    """Split relationship groups into shared and group-unique relationships.

    Relationships are (source_id, target_id, relationship_name) tuples; two
    relationships are the same when source, target and name all match.

    Returns:
        (overlapping_relationships, group_1, group_2, ...): relationships
        occurring more than once across all groups go into
        ``overlapping_relationships`` (deduplicated); the rest stay in their
        own group, in original order.
    """

    def relationship_key(relationship):
        # Identity of an edge: endpoints plus relationship name.
        return f"{relationship[0]}_{relationship[2]}_{relationship[1]}"

    relationship_count = {}
    for group in relationship_groups:
        for relationship in group:
            key = relationship_key(relationship)
            relationship_count[key] = relationship_count.get(key, 0) + 1

    overlapping_relationships = []
    grouped_relationships = [[] for _ in relationship_groups]
    for group, unique_relationships in zip(relationship_groups, grouped_relationships):
        for relationship in group:
            if relationship_count[relationship_key(relationship)] == 1:
                unique_relationships.append(relationship)
            elif relationship not in overlapping_relationships:
                overlapping_relationships.append(relationship)
    return overlapping_relationships, *grouped_relationships

View file

@ -0,0 +1,9 @@
def get_contains_edge_text(entity_name: str, entity_description: str) -> str:
    """Render the vector-index text for a chunk --contains--> entity edge."""
    return (
        f"relationship_name: contains; "
        f"entity_name: {entity_name}; "
        f"entity_description: {entity_description}"
    )

View file

@ -0,0 +1,20 @@
def isolate_relationships(source_relationships, *other_relationships):
    """Return the source relationships whose name occurs in no other group.

    Only the relationship name (index 2) is compared; source and target ids
    are ignored.
    """
    names_elsewhere = {
        relationship[2]
        for relationships in other_relationships
        for relationship in relationships
    }
    return [
        relationship
        for relationship in source_relationships
        if relationship[2] not in names_elsewhere
    ]