fix: add detailed tests for delete
This commit is contained in:
parent
77b3e731d8
commit
a89dad328e
20 changed files with 1011 additions and 266 deletions
|
|
@ -6,7 +6,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.modules.data.methods import create_dataset
|
||||
from cognee.modules.data.methods import create_authorized_dataset
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_config
|
||||
|
|
@ -66,7 +66,7 @@ async def get_or_create_dataset_database(
|
|||
async with db_engine.get_async_session() as session:
|
||||
# Create dataset if it doesn't exist
|
||||
if isinstance(dataset, str):
|
||||
dataset = await create_dataset(dataset, user, session)
|
||||
dataset = await create_authorized_dataset(dataset, user)
|
||||
|
||||
# Try to fetch an existing row first
|
||||
stmt = select(DatasetDatabase).where(
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from cognee.context_global_variables import set_database_global_context_variable
|
|||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.data.methods import create_authorized_dataset
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.engine.utils import generate_node_id
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -53,11 +54,11 @@ async def main():
|
|||
works_for: List[Organization]
|
||||
metadata: dict = {"index_fields": ["name"]}
|
||||
|
||||
companyA = ForProfit(name="Company A")
|
||||
companyB = NonProfit(name="Company B")
|
||||
companyA = ForProfit(id=generate_node_id("Company A"), name="Company A")
|
||||
companyB = NonProfit(id=generate_node_id("Company B"), name="Company B")
|
||||
|
||||
person1 = Person(name="John", works_for=[companyA, companyB])
|
||||
person2 = Person(name="Jane", works_for=[companyB])
|
||||
person1 = Person(id=generate_node_id("John"), name="John", works_for=[companyA, companyB])
|
||||
person2 = Person(id=generate_node_id("Jane"), name="Jane", works_for=[companyB])
|
||||
|
||||
user: User = await get_default_user() # type: ignore
|
||||
|
||||
|
|
@ -93,15 +94,59 @@ async def main():
|
|||
graph_engine = await get_graph_engine()
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
|
||||
# Initial check
|
||||
assert len(nodes) == 4 and len(edges) == 3, (
|
||||
"Nodes and edges are not correctly added to the graph."
|
||||
)
|
||||
|
||||
nodes_by_id = {node[0]: node[1] for node in nodes}
|
||||
|
||||
assert str(generate_node_id("John")) in nodes_by_id, "John node not present in the graph."
|
||||
assert str(generate_node_id("Jane")) in nodes_by_id, "Jane node not present in the graph."
|
||||
assert str(generate_node_id("Company A")) in nodes_by_id, (
|
||||
"Company A node not present in the graph."
|
||||
)
|
||||
assert str(generate_node_id("Company B")) in nodes_by_id, (
|
||||
"Company B node not present in the graph."
|
||||
)
|
||||
|
||||
edges_by_ids = {f"{edge[0]}_{edge[2]}_{edge[1]}": edge[3] for edge in edges}
|
||||
|
||||
assert (
|
||||
f"{str(generate_node_id('John'))}_works_for_{str(generate_node_id('Company A'))}"
|
||||
in edges_by_ids
|
||||
), "Edge between John and Company A not present in the graph."
|
||||
assert (
|
||||
f"{str(generate_node_id('John'))}_works_for_{str(generate_node_id('Company B'))}"
|
||||
in edges_by_ids
|
||||
), "Edge between John and Company A not present in the graph."
|
||||
assert (
|
||||
f"{str(generate_node_id('Jane'))}_works_for_{str(generate_node_id('Company B'))}"
|
||||
in edges_by_ids
|
||||
), "Edge between John and Company A not present in the graph."
|
||||
|
||||
# Second data deletion
|
||||
await datasets.delete_data(dataset.id, data1.id, user)
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 2 and len(edges) == 1, "Nodes and edges are not deleted properly."
|
||||
|
||||
nodes_by_id = {node[0]: node[1] for node in nodes}
|
||||
|
||||
assert str(generate_node_id("Jane")) in nodes_by_id, "Jane node not present in the graph."
|
||||
assert str(generate_node_id("Company B")) in nodes_by_id, (
|
||||
"Company B node not present in the graph."
|
||||
)
|
||||
|
||||
edges_by_ids = {f"{edge[0]}_{edge[2]}_{edge[1]}": edge[3] for edge in edges}
|
||||
|
||||
assert (
|
||||
f"{str(generate_node_id('Jane'))}_works_for_{str(generate_node_id('Company B'))}"
|
||||
in edges_by_ids
|
||||
), "Edge between John and Company A not present in the graph."
|
||||
|
||||
# Second data deletion
|
||||
await datasets.delete_data(dataset.id, data2.id, user)
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
|
|
|
|||
|
|
@ -1,17 +1,40 @@
|
|||
import os
|
||||
import pathlib
|
||||
import pytest
|
||||
import pathlib
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.datasets import datasets
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.llm import LLMGateway
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
from cognee.modules.data.processing.document_types.TextDocument import TextDocument
|
||||
from cognee.modules.engine.models import Entity
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.tests.utils.assert_edges_vector_index_not_present import (
|
||||
assert_edges_vector_index_not_present,
|
||||
)
|
||||
from cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present
|
||||
from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present
|
||||
from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present
|
||||
from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present
|
||||
from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present
|
||||
from cognee.tests.utils.assert_nodes_vector_index_not_present import (
|
||||
assert_nodes_vector_index_not_present,
|
||||
)
|
||||
from cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present
|
||||
from cognee.tests.utils.extract_entities import extract_entities
|
||||
from cognee.tests.utils.extract_relationships import extract_relationships
|
||||
from cognee.tests.utils.extract_summary import extract_summary
|
||||
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -107,92 +130,159 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
|
||||
mock_create_structured_output.side_effect = mock_llm_output
|
||||
|
||||
user = await get_default_user()
|
||||
|
||||
await set_database_global_context_variables("main_dataset", user.id)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
||||
assert not await vector_engine.has_collection("Entity_name")
|
||||
assert not await vector_engine.has_collection("DocumentChunk_text")
|
||||
assert not await vector_engine.has_collection("TextSummary_text")
|
||||
assert not await vector_engine.has_collection("TextDocument_text")
|
||||
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
||||
|
||||
add_john_result = await cognee.add(
|
||||
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||
)
|
||||
johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||
add_john_result = await cognee.add(johns_text)
|
||||
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
|
||||
|
||||
add_marie_result = await cognee.add(
|
||||
"Marie works for Apple as well. She is a software engineer on MacOS project."
|
||||
)
|
||||
maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project."
|
||||
add_marie_result = await cognee.add(maries_text)
|
||||
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
|
||||
|
||||
cognify_result: dict = await cognee.cognify()
|
||||
dataset_id = list(cognify_result.keys())[0]
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
initial_nodes, initial_edges = await graph_engine.get_graph_data()
|
||||
assert len(initial_nodes) == 15 and len(initial_edges) == 19, (
|
||||
"Number of nodes and edges is not correct."
|
||||
johns_document = TextDocument(
|
||||
id=johns_data_id,
|
||||
name="John's Work",
|
||||
raw_data_location="johns_data_location",
|
||||
external_metadata="",
|
||||
)
|
||||
johns_chunk = DocumentChunk(
|
||||
id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"),
|
||||
text=johns_text,
|
||||
chunk_size=14,
|
||||
chunk_index=0,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=johns_document,
|
||||
)
|
||||
johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore
|
||||
|
||||
maries_document = TextDocument(
|
||||
id=maries_data_id,
|
||||
name="Maries's Work",
|
||||
raw_data_location="maries_data_location",
|
||||
external_metadata="",
|
||||
)
|
||||
maries_chunk = DocumentChunk(
|
||||
id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"),
|
||||
text=maries_text,
|
||||
chunk_size=14,
|
||||
chunk_index=0,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=maries_document,
|
||||
)
|
||||
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
|
||||
|
||||
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
|
||||
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
|
||||
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
|
||||
johns_entities, maries_entities
|
||||
)
|
||||
|
||||
initial_nodes_by_vector_collection = {}
|
||||
johns_data = [
|
||||
johns_document,
|
||||
johns_chunk,
|
||||
johns_summary,
|
||||
*johns_entities,
|
||||
]
|
||||
maries_data = [
|
||||
maries_document,
|
||||
maries_chunk,
|
||||
maries_summary,
|
||||
*maries_entities,
|
||||
]
|
||||
|
||||
for node in initial_nodes:
|
||||
node_data = node[1]
|
||||
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
|
||||
if collection_name not in initial_nodes_by_vector_collection:
|
||||
initial_nodes_by_vector_collection[collection_name] = []
|
||||
initial_nodes_by_vector_collection[collection_name].append(node)
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(johns_data + maries_data + overlapping_entities)
|
||||
await assert_nodes_vector_index_present(johns_data + maries_data + overlapping_entities)
|
||||
|
||||
initial_node_ids = set([node[0] for node in initial_nodes])
|
||||
johns_relationships = extract_relationships(
|
||||
johns_chunk,
|
||||
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
|
||||
)
|
||||
maries_relationships = extract_relationships(
|
||||
maries_chunk,
|
||||
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
|
||||
)
|
||||
(overlapping_relationships, johns_relationships, maries_relationships) = (
|
||||
filter_overlapping_relationships(johns_relationships, maries_relationships)
|
||||
)
|
||||
|
||||
user = await get_default_user()
|
||||
johns_relationships = [
|
||||
(johns_chunk.id, johns_document.id, "is_part_of"),
|
||||
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||
*johns_relationships,
|
||||
]
|
||||
johns_edge_text_relationships = [
|
||||
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in johns_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
maries_relationships = [
|
||||
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||
*maries_relationships,
|
||||
]
|
||||
maries_edge_text_relationships = [
|
||||
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in maries_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
|
||||
expected_relationships = johns_relationships + maries_relationships + overlapping_relationships
|
||||
|
||||
await assert_graph_edges_present(expected_relationships)
|
||||
|
||||
await assert_edges_vector_index_present(
|
||||
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
||||
)
|
||||
|
||||
# Delete John's data from cognee
|
||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 9 and len(edges) == 10, "Nodes and edges are not deleted."
|
||||
assert not any(
|
||||
node[1]["name"] == "john" or node[1]["name"] == "food for hungry"
|
||||
for node in nodes
|
||||
if "name" in node[1]
|
||||
), "Nodes are not deleted."
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(maries_data + overlapping_entities)
|
||||
await assert_nodes_vector_index_present(maries_data + overlapping_entities)
|
||||
|
||||
after_first_delete_node_ids = set([node[0] for node in nodes])
|
||||
await assert_graph_nodes_not_present(johns_data)
|
||||
await assert_nodes_vector_index_not_present(johns_data)
|
||||
|
||||
after_delete_nodes_by_vector_collection = {}
|
||||
for node in initial_nodes:
|
||||
node_data = node[1]
|
||||
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
|
||||
if collection_name not in after_delete_nodes_by_vector_collection:
|
||||
after_delete_nodes_by_vector_collection[collection_name] = []
|
||||
after_delete_nodes_by_vector_collection[collection_name].append(node)
|
||||
# Assert relationships presence in the graph, vector collections and nodes table
|
||||
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
|
||||
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
await assert_graph_edges_not_present(johns_relationships)
|
||||
|
||||
removed_node_ids = initial_node_ids - after_first_delete_node_ids
|
||||
|
||||
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
|
||||
query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids]
|
||||
|
||||
if query_node_ids:
|
||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
||||
strictly_johns_relationships = isolate_relationships(johns_relationships, maries_relationships)
|
||||
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||
await assert_edges_vector_index_not_present(
|
||||
strictly_johns_relationships + johns_edge_text_relationships
|
||||
)
|
||||
|
||||
# Delete Marie's data from cognee
|
||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||
|
||||
final_nodes, final_edges = await graph_engine.get_graph_data()
|
||||
assert len(final_nodes) == 0 and len(final_edges) == 0, "Nodes and edges are not deleted."
|
||||
await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities)
|
||||
await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities)
|
||||
|
||||
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
|
||||
query_node_ids = [node[0] for node in initial_nodes]
|
||||
# Assert relationships presence in the graph, vector collections and nodes table
|
||||
await assert_graph_edges_not_present(
|
||||
johns_relationships + maries_relationships + overlapping_relationships
|
||||
)
|
||||
|
||||
if query_node_ids:
|
||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
||||
|
||||
query_edge_ids = [edge[0] for edge in initial_edges]
|
||||
|
||||
vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
|
||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
||||
await assert_edges_vector_index_not_present(maries_relationships)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import os
|
||||
import json
|
||||
import pytest
|
||||
import pathlib
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.datasets import datasets
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
|
@ -17,17 +17,34 @@ from cognee.modules.data.models import Data
|
|||
from cognee.modules.engine.models import Entity, EntityType
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.engine.utils import generate_edge_id, generate_node_id
|
||||
from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
|
||||
from cognee.modules.engine.utils import generate_node_id
|
||||
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
||||
from cognee.modules.pipelines.models import DataItemStatus
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||
from cognee.tasks.storage import index_data_points, index_graph_edges
|
||||
|
||||
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
||||
from cognee.tests.utils.assert_edges_vector_index_not_present import (
|
||||
assert_edges_vector_index_not_present,
|
||||
)
|
||||
from cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present
|
||||
from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present
|
||||
from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present
|
||||
from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present
|
||||
from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present
|
||||
from cognee.tests.utils.assert_nodes_vector_index_not_present import (
|
||||
assert_nodes_vector_index_not_present,
|
||||
)
|
||||
from cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present
|
||||
from cognee.tests.utils.extract_entities import extract_entities
|
||||
from cognee.tests.utils.extract_relationships import extract_relationships
|
||||
from cognee.tests.utils.extract_summary import extract_summary
|
||||
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
||||
|
||||
|
||||
async def get_nodes_and_edges():
|
||||
def create_nodes_and_edges():
|
||||
document = TextDocument(
|
||||
id=uuid5(NAMESPACE_OID, "text_test.txt"),
|
||||
name="text_test.txt",
|
||||
|
|
@ -73,15 +90,8 @@ async def get_nodes_and_edges():
|
|||
name="amazon s3",
|
||||
description="A storage service provided by Amazon Web Services that allows storing graph data.",
|
||||
)
|
||||
document_chunk.contains = [
|
||||
graph_database,
|
||||
neptune_analytics_entity,
|
||||
neptune_database_entity,
|
||||
storage,
|
||||
storage_entity,
|
||||
]
|
||||
|
||||
data_points = [
|
||||
nodes_data = [
|
||||
document,
|
||||
document_chunk,
|
||||
graph_database,
|
||||
|
|
@ -91,39 +101,71 @@ async def get_nodes_and_edges():
|
|||
storage_entity,
|
||||
]
|
||||
|
||||
nodes = []
|
||||
edges = []
|
||||
edges_data = [
|
||||
(
|
||||
document_chunk.id,
|
||||
storage_entity.id,
|
||||
"contains",
|
||||
{
|
||||
"relationship_name": "contains",
|
||||
},
|
||||
),
|
||||
(
|
||||
storage_entity.id,
|
||||
storage.id,
|
||||
"is_a",
|
||||
{
|
||||
"relationship_name": "is_a",
|
||||
},
|
||||
),
|
||||
(
|
||||
document_chunk.id,
|
||||
neptune_database_entity.id,
|
||||
"contains",
|
||||
{
|
||||
"relationship_name": "contains",
|
||||
},
|
||||
),
|
||||
(
|
||||
neptune_database_entity.id,
|
||||
graph_database.id,
|
||||
"is_a",
|
||||
{
|
||||
"relationship_name": "is_a",
|
||||
},
|
||||
),
|
||||
(
|
||||
document_chunk.id,
|
||||
document.id,
|
||||
"is_part_of",
|
||||
{
|
||||
"relationship_name": "is_part_of",
|
||||
},
|
||||
),
|
||||
(
|
||||
document_chunk.id,
|
||||
neptune_analytics_entity.id,
|
||||
"contains",
|
||||
{
|
||||
"relationship_name": "contains",
|
||||
},
|
||||
),
|
||||
(
|
||||
neptune_analytics_entity.id,
|
||||
graph_database.id,
|
||||
"is_a",
|
||||
{
|
||||
"relationship_name": "is_a",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
added_nodes = {}
|
||||
added_edges = {}
|
||||
visited_properties = {}
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
get_graph_from_model(
|
||||
data_point,
|
||||
added_nodes=added_nodes,
|
||||
added_edges=added_edges,
|
||||
visited_properties=visited_properties,
|
||||
)
|
||||
for data_point in data_points
|
||||
]
|
||||
)
|
||||
|
||||
for result_nodes, result_edges in results:
|
||||
nodes.extend(result_nodes)
|
||||
edges.extend(result_edges)
|
||||
|
||||
nodes, edges = deduplicate_nodes_and_edges(nodes, edges)
|
||||
|
||||
return nodes, edges
|
||||
return nodes_data, edges_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
|
||||
async def main(mock_create_structured_output: AsyncMock):
|
||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "False"
|
||||
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_1"
|
||||
)
|
||||
|
|
@ -139,6 +181,9 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
user = await get_default_user()
|
||||
await set_database_global_context_variables("main_dataset", user.id)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
||||
|
|
@ -147,9 +192,8 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
assert not await vector_engine.has_collection("TextSummary_text")
|
||||
assert not await vector_engine.has_collection("TextDocument_text")
|
||||
|
||||
user = await get_default_user()
|
||||
|
||||
old_nodes, old_edges = await add_mocked_legacy_data(user)
|
||||
# Add legacy data to the system
|
||||
__, legacy_data_points, legacy_relationships = await create_mocked_legacy_data(user)
|
||||
|
||||
def mock_llm_output(text_input: str, system_prompt: str, response_model):
|
||||
if text_input == "test": # LLM connection test
|
||||
|
|
@ -225,109 +269,188 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
|
||||
mock_create_structured_output.side_effect = mock_llm_output
|
||||
|
||||
add_john_result = await cognee.add(
|
||||
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||
)
|
||||
johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||
add_john_result = await cognee.add(johns_text)
|
||||
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
|
||||
|
||||
add_marie_result = await cognee.add(
|
||||
"Marie works for Apple as well. She is a software engineer on MacOS project."
|
||||
)
|
||||
maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project."
|
||||
add_marie_result = await cognee.add(maries_text)
|
||||
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
|
||||
|
||||
cognify_result: dict = await cognee.cognify()
|
||||
dataset_id = list(cognify_result.keys())[0]
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
initial_nodes, initial_edges = await graph_engine.get_graph_data()
|
||||
assert len(initial_nodes) == 22 and len(initial_edges) == 25, (
|
||||
"Number of nodes and edges is not correct."
|
||||
johns_document = TextDocument(
|
||||
id=johns_data_id,
|
||||
name="John's Work",
|
||||
raw_data_location="johns_data_location",
|
||||
external_metadata="",
|
||||
)
|
||||
johns_chunk = DocumentChunk(
|
||||
id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"),
|
||||
text=johns_text,
|
||||
chunk_size=14,
|
||||
chunk_index=0,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=johns_document,
|
||||
)
|
||||
johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore
|
||||
|
||||
maries_document = TextDocument(
|
||||
id=maries_data_id,
|
||||
name="Maries's Work",
|
||||
raw_data_location="maries_data_location",
|
||||
external_metadata="",
|
||||
)
|
||||
maries_chunk = DocumentChunk(
|
||||
id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"),
|
||||
text=maries_text,
|
||||
chunk_size=14,
|
||||
chunk_index=0,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=maries_document,
|
||||
)
|
||||
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
|
||||
|
||||
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
|
||||
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
|
||||
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
|
||||
johns_entities, maries_entities
|
||||
)
|
||||
|
||||
initial_nodes_by_vector_collection = {}
|
||||
johns_data = [
|
||||
johns_document,
|
||||
johns_chunk,
|
||||
johns_summary,
|
||||
*johns_entities,
|
||||
]
|
||||
maries_data = [
|
||||
maries_document,
|
||||
maries_chunk,
|
||||
maries_summary,
|
||||
*maries_entities,
|
||||
]
|
||||
|
||||
for node in initial_nodes:
|
||||
node_data = node[1]
|
||||
node_metadata = node_data["metadata"]
|
||||
node_metadata = json.loads(node_metadata) if type(node_metadata) is str else node_metadata
|
||||
collection_name = node_data["type"] + "_" + node_metadata["index_fields"][0]
|
||||
if collection_name not in initial_nodes_by_vector_collection:
|
||||
initial_nodes_by_vector_collection[collection_name] = []
|
||||
initial_nodes_by_vector_collection[collection_name].append(node)
|
||||
expected_data_points = johns_data + maries_data + overlapping_entities + legacy_data_points
|
||||
|
||||
initial_node_ids = set([node[0] for node in initial_nodes])
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(expected_data_points)
|
||||
await assert_nodes_vector_index_present(expected_data_points)
|
||||
|
||||
johns_relationships = extract_relationships(
|
||||
johns_chunk,
|
||||
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
|
||||
)
|
||||
maries_relationships = extract_relationships(
|
||||
maries_chunk,
|
||||
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
|
||||
)
|
||||
(overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = (
|
||||
filter_overlapping_relationships(
|
||||
johns_relationships, maries_relationships, legacy_relationships
|
||||
)
|
||||
)
|
||||
|
||||
johns_relationships = [
|
||||
(johns_chunk.id, johns_document.id, "is_part_of"),
|
||||
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||
*johns_relationships,
|
||||
]
|
||||
johns_edge_text_relationships = [
|
||||
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in johns_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
maries_relationships = [
|
||||
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||
*maries_relationships,
|
||||
]
|
||||
maries_edge_text_relationships = [
|
||||
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in maries_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
|
||||
expected_relationships = (
|
||||
johns_relationships
|
||||
+ maries_relationships
|
||||
+ overlapping_relationships
|
||||
+ legacy_relationships
|
||||
)
|
||||
|
||||
await assert_graph_edges_present(expected_relationships)
|
||||
|
||||
await assert_edges_vector_index_present(
|
||||
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
||||
)
|
||||
|
||||
# Delete John's data
|
||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 16 and len(edges) == 16, "Nodes and edges are not deleted."
|
||||
assert not any(
|
||||
node[1]["name"] == "john" or node[1]["name"] == "food for hungry"
|
||||
for node in nodes
|
||||
if "name" in node[1]
|
||||
), "Nodes are not deleted."
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points)
|
||||
await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points)
|
||||
|
||||
after_first_delete_node_ids = set([node[0] for node in nodes])
|
||||
await assert_graph_nodes_not_present(johns_data)
|
||||
await assert_nodes_vector_index_not_present(johns_data)
|
||||
|
||||
after_delete_nodes_by_vector_collection = {}
|
||||
for node in initial_nodes:
|
||||
node_data = node[1]
|
||||
node_metadata = node_data["metadata"]
|
||||
node_metadata = json.loads(node_metadata) if type(node_metadata) is str else node_metadata
|
||||
collection_name = node_data["type"] + "_" + node_metadata["index_fields"][0]
|
||||
if collection_name not in after_delete_nodes_by_vector_collection:
|
||||
after_delete_nodes_by_vector_collection[collection_name] = []
|
||||
after_delete_nodes_by_vector_collection[collection_name].append(node)
|
||||
# Assert relationships presence in the graph, vector collections and nodes table
|
||||
await assert_graph_edges_present(
|
||||
maries_relationships + overlapping_relationships + legacy_relationships
|
||||
)
|
||||
await assert_edges_vector_index_present(
|
||||
maries_relationships + maries_edge_text_relationships + legacy_relationships
|
||||
)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
await assert_graph_edges_not_present(johns_relationships)
|
||||
|
||||
removed_node_ids = initial_node_ids - after_first_delete_node_ids
|
||||
|
||||
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
|
||||
query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids]
|
||||
|
||||
if query_node_ids:
|
||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
||||
strictly_johns_relationships = isolate_relationships(
|
||||
johns_relationships, maries_relationships, legacy_relationships
|
||||
)
|
||||
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||
await assert_edges_vector_index_not_present(
|
||||
strictly_johns_relationships + johns_edge_text_relationships
|
||||
)
|
||||
|
||||
# Delete Marie's data
|
||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||
|
||||
final_nodes, final_edges = await graph_engine.get_graph_data()
|
||||
assert len(final_nodes) == 7 and len(final_edges) == 6, "Nodes and edges are not deleted."
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(legacy_data_points)
|
||||
await assert_nodes_vector_index_present(legacy_data_points)
|
||||
|
||||
old_nodes_by_vector_collection = {}
|
||||
for node in old_nodes:
|
||||
node_metadata = node.metadata
|
||||
collection_name = node.type + "_" + node_metadata["index_fields"][0]
|
||||
if collection_name not in old_nodes_by_vector_collection:
|
||||
old_nodes_by_vector_collection[collection_name] = []
|
||||
old_nodes_by_vector_collection[collection_name].append(node)
|
||||
await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities)
|
||||
await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities)
|
||||
|
||||
for collection_name, old_nodes in old_nodes_by_vector_collection.items():
|
||||
query_node_ids = [str(node.id) for node in old_nodes]
|
||||
# Assert relationships presence in the graph, vector collections and nodes table
|
||||
await assert_graph_edges_present(legacy_relationships)
|
||||
await assert_edges_vector_index_present(legacy_relationships)
|
||||
|
||||
if query_node_ids:
|
||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
||||
assert len(vector_items) == len(old_nodes), "Vector items are not deleted."
|
||||
await assert_graph_edges_not_present(
|
||||
johns_relationships + maries_relationships + overlapping_relationships
|
||||
)
|
||||
|
||||
query_edge_ids = list(set([str(generate_edge_id(edge[2])) for edge in old_edges]))
|
||||
|
||||
vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
|
||||
assert len(vector_items) == len(query_edge_ids), "Vector items are not deleted."
|
||||
strictly_maries_relationships = isolate_relationships(
|
||||
maries_relationships, legacy_relationships
|
||||
)
|
||||
# We check only by relationship name and we need edges that are created by legacy data and no other.
|
||||
if strictly_maries_relationships:
|
||||
await assert_edges_vector_index_not_present(strictly_maries_relationships)
|
||||
|
||||
|
||||
async def add_mocked_legacy_data(user):
|
||||
async def create_mocked_legacy_data(user):
|
||||
graph_engine = await get_graph_engine()
|
||||
old_nodes, old_edges = await get_nodes_and_edges()
|
||||
old_document = old_nodes[0]
|
||||
legacy_nodes, legacy_edges = create_nodes_and_edges()
|
||||
legacy_document = legacy_nodes[0]
|
||||
|
||||
await graph_engine.add_nodes(old_nodes)
|
||||
await graph_engine.add_edges(old_edges)
|
||||
await graph_engine.add_nodes(legacy_nodes)
|
||||
await graph_engine.add_edges(legacy_edges)
|
||||
|
||||
await index_data_points(old_nodes)
|
||||
await index_graph_edges(old_edges)
|
||||
await index_data_points(legacy_nodes)
|
||||
await index_graph_edges(legacy_edges)
|
||||
|
||||
await record_data_in_legacy_ledger(old_nodes, old_edges, user)
|
||||
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
|
|
@ -335,12 +458,12 @@ async def add_mocked_legacy_data(user):
|
|||
|
||||
async with db_engine.get_async_session() as session:
|
||||
old_data = Data(
|
||||
id=old_document.id,
|
||||
name=old_document.name,
|
||||
id=legacy_document.id,
|
||||
name=legacy_document.name,
|
||||
extension="txt",
|
||||
raw_data_location=old_document.raw_data_location,
|
||||
external_metadata=old_document.external_metadata,
|
||||
mime_type=old_document.mime_type,
|
||||
raw_data_location=legacy_document.raw_data_location,
|
||||
external_metadata=legacy_document.external_metadata,
|
||||
mime_type=legacy_document.mime_type,
|
||||
owner_id=user.id,
|
||||
pipeline_status={
|
||||
"cognify_pipeline": {
|
||||
|
|
@ -355,7 +478,7 @@ async def add_mocked_legacy_data(user):
|
|||
|
||||
await session.commit()
|
||||
|
||||
return old_nodes, old_edges
|
||||
return legacy_document, legacy_nodes, legacy_edges
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import os
|
||||
import pytest
|
||||
import pathlib
|
||||
from uuid import NAMESPACE_OID, uuid5
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import cognee
|
||||
from cognee.api.v1.datasets import datasets
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
|
@ -16,15 +17,34 @@ from cognee.modules.data.models import Data
|
|||
from cognee.modules.engine.models import Entity, EntityType
|
||||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.engine.utils import generate_edge_id, generate_node_id
|
||||
from cognee.modules.engine.utils import generate_node_id
|
||||
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
||||
from cognee.modules.pipelines.models import DataItemStatus
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
|
||||
from cognee.tasks.storage import index_data_points, index_graph_edges
|
||||
from cognee.tests.utils.assert_edges_vector_index_not_present import (
|
||||
assert_edges_vector_index_not_present,
|
||||
)
|
||||
from cognee.tests.utils.assert_edges_vector_index_present import assert_edges_vector_index_present
|
||||
from cognee.tests.utils.assert_graph_edges_not_present import assert_graph_edges_not_present
|
||||
from cognee.tests.utils.assert_graph_edges_present import assert_graph_edges_present
|
||||
from cognee.tests.utils.assert_graph_nodes_not_present import assert_graph_nodes_not_present
|
||||
from cognee.tests.utils.assert_graph_nodes_present import assert_graph_nodes_present
|
||||
from cognee.tests.utils.assert_nodes_vector_index_not_present import (
|
||||
assert_nodes_vector_index_not_present,
|
||||
)
|
||||
from cognee.tests.utils.assert_nodes_vector_index_present import assert_nodes_vector_index_present
|
||||
from cognee.tests.utils.extract_entities import extract_entities
|
||||
from cognee.tests.utils.extract_relationships import extract_relationships
|
||||
from cognee.tests.utils.extract_summary import extract_summary
|
||||
from cognee.tests.utils.filter_overlapping_entities import filter_overlapping_entities
|
||||
from cognee.tests.utils.filter_overlapping_relationships import filter_overlapping_relationships
|
||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||
from cognee.tests.utils.isolate_relationships import isolate_relationships
|
||||
|
||||
|
||||
def get_nodes_and_edges():
|
||||
def create_nodes_and_edges():
|
||||
document = TextDocument(
|
||||
id=uuid5(NAMESPACE_OID, "text_test.txt"),
|
||||
name="text_test.txt",
|
||||
|
|
@ -146,8 +166,6 @@ def get_nodes_and_edges():
|
|||
@pytest.mark.asyncio
|
||||
@patch.object(LLMGateway, "acreate_structured_output", new_callable=AsyncMock)
|
||||
async def main(mock_create_structured_output: AsyncMock):
|
||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "False"
|
||||
|
||||
data_directory_path = os.path.join(
|
||||
pathlib.Path(__file__).parent, ".data_storage/test_delete_default_graph_with_legacy_graph_2"
|
||||
)
|
||||
|
|
@ -163,6 +181,9 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
await cognee.prune.prune_system(metadata=True)
|
||||
await setup()
|
||||
|
||||
user = await get_default_user()
|
||||
await set_database_global_context_variables("main_dataset", user.id)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
assert not await vector_engine.has_collection("EdgeType_relationship_name")
|
||||
|
|
@ -171,9 +192,10 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
assert not await vector_engine.has_collection("TextSummary_text")
|
||||
assert not await vector_engine.has_collection("TextDocument_text")
|
||||
|
||||
user = await get_default_user()
|
||||
|
||||
old_document, old_nodes, old_edges = await add_mocked_legacy_data(user)
|
||||
# Add legacy data to the system
|
||||
legacy_document, legacy_data_points, legacy_relationships = await create_mocked_legacy_data(
|
||||
user
|
||||
)
|
||||
|
||||
def mock_llm_output(text_input: str, system_prompt: str, response_model):
|
||||
if text_input == "test": # LLM connection test
|
||||
|
|
@ -249,100 +271,186 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
|
||||
mock_create_structured_output.side_effect = mock_llm_output
|
||||
|
||||
add_john_result = await cognee.add(
|
||||
"John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||
)
|
||||
johns_text = "John works for Apple. He is also affiliated with a non-profit organization called 'Food for Hungry'"
|
||||
add_john_result = await cognee.add(johns_text)
|
||||
johns_data_id = add_john_result.data_ingestion_info[0]["data_id"]
|
||||
|
||||
await cognee.add("Marie works for Apple as well. She is a software engineer on MacOS project.")
|
||||
maries_text = "Marie works for Apple as well. She is a software engineer on MacOS project."
|
||||
add_marie_result = await cognee.add(maries_text)
|
||||
maries_data_id = add_marie_result.data_ingestion_info[0]["data_id"]
|
||||
|
||||
cognify_result: dict = await cognee.cognify()
|
||||
dataset_id = list(cognify_result.keys())[0]
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
initial_nodes, initial_edges = await graph_engine.get_graph_data()
|
||||
assert len(initial_nodes) == 22 and len(initial_edges) == 26, (
|
||||
"Number of nodes and edges is not correct."
|
||||
johns_document = TextDocument(
|
||||
id=johns_data_id,
|
||||
name="John's Work",
|
||||
raw_data_location="johns_data_location",
|
||||
external_metadata="",
|
||||
)
|
||||
johns_chunk = DocumentChunk(
|
||||
id=uuid5(NAMESPACE_OID, f"{str(johns_data_id)}-0"),
|
||||
text=johns_text,
|
||||
chunk_size=14,
|
||||
chunk_index=0,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=johns_document,
|
||||
)
|
||||
johns_summary = extract_summary(johns_chunk, mock_llm_output("John", "", SummarizedContent)) # type: ignore
|
||||
|
||||
maries_document = TextDocument(
|
||||
id=maries_data_id,
|
||||
name="Maries's Work",
|
||||
raw_data_location="maries_data_location",
|
||||
external_metadata="",
|
||||
)
|
||||
maries_chunk = DocumentChunk(
|
||||
id=uuid5(NAMESPACE_OID, f"{str(maries_data_id)}-0"),
|
||||
text=maries_text,
|
||||
chunk_size=14,
|
||||
chunk_index=0,
|
||||
cut_type="sentence_end",
|
||||
is_part_of=maries_document,
|
||||
)
|
||||
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
|
||||
|
||||
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
|
||||
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
|
||||
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
|
||||
johns_entities, maries_entities
|
||||
)
|
||||
|
||||
initial_nodes_by_vector_collection = {}
|
||||
johns_data = [
|
||||
johns_document,
|
||||
johns_chunk,
|
||||
johns_summary,
|
||||
*johns_entities,
|
||||
]
|
||||
maries_data = [
|
||||
maries_document,
|
||||
maries_chunk,
|
||||
maries_summary,
|
||||
*maries_entities,
|
||||
]
|
||||
|
||||
for node in initial_nodes:
|
||||
node_data = node[1]
|
||||
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
|
||||
if collection_name not in initial_nodes_by_vector_collection:
|
||||
initial_nodes_by_vector_collection[collection_name] = []
|
||||
initial_nodes_by_vector_collection[collection_name].append(node)
|
||||
expected_data_points = johns_data + maries_data + overlapping_entities + legacy_data_points
|
||||
|
||||
initial_node_ids = set([node[0] for node in initial_nodes])
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(expected_data_points)
|
||||
await assert_nodes_vector_index_present(expected_data_points)
|
||||
|
||||
johns_relationships = extract_relationships(
|
||||
johns_chunk,
|
||||
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
|
||||
)
|
||||
maries_relationships = extract_relationships(
|
||||
maries_chunk,
|
||||
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
|
||||
)
|
||||
(overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = (
|
||||
filter_overlapping_relationships(
|
||||
johns_relationships, maries_relationships, legacy_relationships
|
||||
)
|
||||
)
|
||||
|
||||
johns_relationships = [
|
||||
(johns_chunk.id, johns_document.id, "is_part_of"),
|
||||
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||
*johns_relationships,
|
||||
]
|
||||
johns_edge_text_relationships = [
|
||||
(johns_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in johns_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
maries_relationships = [
|
||||
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||
*maries_relationships,
|
||||
]
|
||||
maries_edge_text_relationships = [
|
||||
(maries_chunk.id, entity.id, get_contains_edge_text(entity.name, entity.description))
|
||||
for entity in maries_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
|
||||
expected_relationships = (
|
||||
johns_relationships
|
||||
+ maries_relationships
|
||||
+ overlapping_relationships
|
||||
+ legacy_relationships
|
||||
)
|
||||
|
||||
await assert_graph_edges_present(expected_relationships)
|
||||
|
||||
await assert_edges_vector_index_present(
|
||||
expected_relationships + johns_edge_text_relationships + maries_edge_text_relationships
|
||||
)
|
||||
|
||||
# Delete John's data
|
||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
assert len(nodes) == 16 and len(edges) == 17, "Nodes and edges are not deleted."
|
||||
assert not any(
|
||||
node[1]["name"] == "john" or node[1]["name"] == "food for hungry" for node in nodes
|
||||
), "Nodes are not deleted."
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points)
|
||||
await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points)
|
||||
|
||||
after_first_delete_node_ids = set([node[0] for node in nodes])
|
||||
await assert_graph_nodes_not_present(johns_data)
|
||||
await assert_nodes_vector_index_not_present(johns_data)
|
||||
|
||||
after_delete_nodes_by_vector_collection = {}
|
||||
for node in initial_nodes:
|
||||
node_data = node[1]
|
||||
collection_name = node_data["type"] + "_" + node_data["metadata"]["index_fields"][0]
|
||||
if collection_name not in after_delete_nodes_by_vector_collection:
|
||||
after_delete_nodes_by_vector_collection[collection_name] = []
|
||||
after_delete_nodes_by_vector_collection[collection_name].append(node)
|
||||
# Assert relationships presence in the graph, vector collections and nodes table
|
||||
await assert_graph_edges_present(
|
||||
maries_relationships + overlapping_relationships + legacy_relationships
|
||||
)
|
||||
await assert_edges_vector_index_present(
|
||||
maries_relationships + maries_edge_text_relationships + legacy_relationships
|
||||
)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
await assert_graph_edges_not_present(johns_relationships)
|
||||
|
||||
removed_node_ids = initial_node_ids - after_first_delete_node_ids
|
||||
strictly_johns_relationships = isolate_relationships(
|
||||
johns_relationships, maries_relationships, legacy_relationships
|
||||
)
|
||||
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||
await assert_edges_vector_index_not_present(
|
||||
strictly_johns_relationships + johns_edge_text_relationships
|
||||
)
|
||||
|
||||
for collection_name, initial_nodes in initial_nodes_by_vector_collection.items():
|
||||
query_node_ids = [node[0] for node in initial_nodes if node[0] in removed_node_ids]
|
||||
# Delete legacy data
|
||||
await datasets.delete_data(dataset_id, legacy_document.id, user) # type: ignore
|
||||
|
||||
if query_node_ids:
|
||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(maries_data + overlapping_entities)
|
||||
await assert_nodes_vector_index_present(maries_data + overlapping_entities)
|
||||
|
||||
# Delete old document
|
||||
await datasets.delete_data(dataset_id, old_document.id, user) # type: ignore
|
||||
await assert_graph_nodes_not_present(johns_data + legacy_data_points)
|
||||
await assert_nodes_vector_index_not_present(johns_data + legacy_data_points)
|
||||
|
||||
final_nodes, final_edges = await graph_engine.get_graph_data()
|
||||
assert len(final_nodes) == 9 and len(final_edges) == 10, "Nodes and edges are not deleted."
|
||||
# Assert relationships presence in the graph, vector collections and nodes table
|
||||
await assert_graph_edges_present(maries_relationships + overlapping_relationships)
|
||||
await assert_edges_vector_index_present(maries_relationships + maries_edge_text_relationships)
|
||||
|
||||
old_nodes_by_vector_collection = {}
|
||||
for node in old_nodes:
|
||||
collection_name = node.type + "_" + node.metadata["index_fields"][0]
|
||||
if collection_name not in old_nodes_by_vector_collection:
|
||||
old_nodes_by_vector_collection[collection_name] = []
|
||||
old_nodes_by_vector_collection[collection_name].append(node)
|
||||
await assert_graph_edges_not_present(johns_relationships + legacy_relationships)
|
||||
|
||||
for collection_name, old_nodes in old_nodes_by_vector_collection.items():
|
||||
query_node_ids = [str(node.id) for node in old_nodes]
|
||||
|
||||
if query_node_ids:
|
||||
vector_items = await vector_engine.retrieve(collection_name, query_node_ids)
|
||||
assert len(vector_items) == 0, "Vector items are not deleted."
|
||||
|
||||
query_edge_ids = list(set([str(generate_edge_id(edge[2])) for edge in old_edges]))
|
||||
|
||||
vector_items = await vector_engine.retrieve("EdgeType_relationship_name", query_edge_ids)
|
||||
assert len(vector_items) == len(query_edge_ids), "Vector items are not deleted."
|
||||
strictly_legacy_relationships = isolate_relationships(
|
||||
legacy_relationships, maries_relationships
|
||||
)
|
||||
# We check only by relationship name and we need edges that are created by legacy data and no other.
|
||||
if strictly_legacy_relationships:
|
||||
await assert_edges_vector_index_not_present(strictly_legacy_relationships)
|
||||
|
||||
|
||||
async def add_mocked_legacy_data(user):
|
||||
async def create_mocked_legacy_data(user):
|
||||
graph_engine = await get_graph_engine()
|
||||
old_nodes, old_edges = get_nodes_and_edges()
|
||||
old_document = old_nodes[0]
|
||||
legacy_nodes, legacy_edges = create_nodes_and_edges()
|
||||
legacy_document = legacy_nodes[0]
|
||||
|
||||
await graph_engine.add_nodes(old_nodes)
|
||||
await graph_engine.add_edges(old_edges)
|
||||
await graph_engine.add_nodes(legacy_nodes)
|
||||
await graph_engine.add_edges(legacy_edges)
|
||||
|
||||
await index_data_points(old_nodes)
|
||||
await index_graph_edges(old_edges)
|
||||
await index_data_points(legacy_nodes)
|
||||
await index_graph_edges(legacy_edges)
|
||||
|
||||
await record_data_in_legacy_ledger(old_nodes, old_edges, user)
|
||||
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
|
|
@ -350,12 +458,12 @@ async def add_mocked_legacy_data(user):
|
|||
|
||||
async with db_engine.get_async_session() as session:
|
||||
old_data = Data(
|
||||
id=old_document.id,
|
||||
name=old_document.name,
|
||||
id=legacy_document.id,
|
||||
name=legacy_document.name,
|
||||
extension="txt",
|
||||
raw_data_location=old_document.raw_data_location,
|
||||
external_metadata=old_document.external_metadata,
|
||||
mime_type=old_document.mime_type,
|
||||
raw_data_location=legacy_document.raw_data_location,
|
||||
external_metadata=legacy_document.external_metadata,
|
||||
mime_type=legacy_document.mime_type,
|
||||
owner_id=user.id,
|
||||
pipeline_status={
|
||||
"cognify_pipeline": {
|
||||
|
|
@ -370,7 +478,7 @@ async def add_mocked_legacy_data(user):
|
|||
|
||||
await session.commit()
|
||||
|
||||
return old_document, old_nodes, old_edges
|
||||
return legacy_document, legacy_nodes, legacy_edges
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
23
cognee/tests/utils/assert_edges_vector_index_not_present.py
Normal file
23
cognee/tests/utils/assert_edges_vector_index_not_present.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from uuid import UUID
|
||||
from typing import List, Tuple
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.engine.utils import generate_edge_id
|
||||
|
||||
|
||||
async def assert_edges_vector_index_not_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
query_edge_ids = {
|
||||
str(generate_edge_id(relationship[2])): relationship[2] for relationship in relationships
|
||||
}
|
||||
|
||||
vector_items = await vector_engine.retrieve(
|
||||
"EdgeType_relationship_name", list(query_edge_ids.keys())
|
||||
)
|
||||
|
||||
vector_items_by_id = {str(vector_item.id): vector_item for vector_item in vector_items}
|
||||
|
||||
for relationship_id, relationship_name in query_edge_ids.items():
|
||||
assert relationship_id not in vector_items_by_id, (
|
||||
f"Relationship '{relationship_name}' still present in the vector store."
|
||||
)
|
||||
28
cognee/tests/utils/assert_edges_vector_index_present.py
Normal file
28
cognee/tests/utils/assert_edges_vector_index_present.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from uuid import UUID
|
||||
from typing import List, Tuple
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.engine.utils import generate_edge_id
|
||||
|
||||
|
||||
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
query_edge_ids = {
|
||||
str(generate_edge_id(relationship[2])): relationship[2] for relationship in relationships
|
||||
}
|
||||
|
||||
vector_items = await vector_engine.retrieve(
|
||||
"EdgeType_relationship_name", list(query_edge_ids.keys())
|
||||
)
|
||||
|
||||
vector_items_by_id = {str(vector_item.id): vector_item for vector_item in vector_items}
|
||||
|
||||
for relationship_id, relationship_name in query_edge_ids.items():
|
||||
assert relationship_id in vector_items_by_id, (
|
||||
f"Relationship '{relationship_name}' not found in vector store."
|
||||
)
|
||||
|
||||
vector_relationship = vector_items_by_id[relationship_id]
|
||||
assert vector_relationship.payload["text"] == relationship_name, (
|
||||
f"Vectorized edge '{vector_relationship.payload['text']}' does not match the relationship text '{relationship_name}'."
|
||||
)
|
||||
21
cognee/tests/utils/assert_graph_edges_not_present.py
Normal file
21
cognee/tests/utils/assert_graph_edges_not_present.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from uuid import UUID
|
||||
from typing import List, Tuple
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
|
||||
|
||||
async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
|
||||
nodes_by_id = {str(node[0]): node[1] for node in nodes}
|
||||
|
||||
edge_ids = set([f"{str(edge[0])}_{edge[2]}_{str(edge[1])}" for edge in edges])
|
||||
|
||||
for relationship in relationships:
|
||||
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
|
||||
relationship_name = relationship[2]
|
||||
assert relationship_id not in edge_ids, (
|
||||
f"Edge '{relationship_name}' still present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database."
|
||||
)
|
||||
21
cognee/tests/utils/assert_graph_edges_present.py
Normal file
21
cognee/tests/utils/assert_graph_edges_present.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from uuid import UUID
|
||||
from typing import List, Tuple
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
|
||||
|
||||
async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
|
||||
nodes_by_id = {str(node[0]): node[1] for node in nodes}
|
||||
|
||||
edge_ids = set([f"{str(edge[0])}_{edge[2]}_{str(edge[1])}" for edge in edges])
|
||||
|
||||
for relationship in relationships:
|
||||
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
|
||||
relationship_name = relationship[2]
|
||||
assert relationship_id in edge_ids, (
|
||||
f"Edge '{relationship_name}' not present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database."
|
||||
)
|
||||
16
cognee/tests/utils/assert_graph_nodes_not_present.py
Normal file
16
cognee/tests/utils/assert_graph_nodes_not_present.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from typing import List
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
|
||||
|
||||
async def assert_graph_nodes_not_present(data_points: List[DataPoint]):
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, __ = await graph_engine.get_graph_data()
|
||||
|
||||
node_ids = set(node[0] for node in nodes)
|
||||
|
||||
for data_point in data_points:
|
||||
node_name = getattr(data_point, "label", getattr(data_point, "name", data_point.id))
|
||||
assert str(data_point.id) not in node_ids, (
|
||||
f"Node '{node_name}' is present in graph database."
|
||||
)
|
||||
14
cognee/tests/utils/assert_graph_nodes_present.py
Normal file
14
cognee/tests/utils/assert_graph_nodes_present.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from typing import List
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
|
||||
|
||||
async def assert_graph_nodes_present(data_points: List[DataPoint]):
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, __ = await graph_engine.get_graph_data()
|
||||
|
||||
node_ids = set(node[0] for node in nodes)
|
||||
|
||||
for data_point in data_points:
|
||||
node_name = getattr(data_point, "label", getattr(data_point, "name", data_point.id))
|
||||
assert str(data_point.id) in node_ids, f"Node '{node_name}' not found in graph database."
|
||||
28
cognee/tests/utils/assert_nodes_vector_index_not_present.py
Normal file
28
cognee/tests/utils/assert_nodes_vector_index_not_present.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from typing import List
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
|
||||
|
||||
async def assert_nodes_vector_index_not_present(data_points: List[DataPoint]):
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
data_points_by_vector_collection = {}
|
||||
|
||||
for data_point in data_points:
|
||||
node_metadata = data_point.metadata or {}
|
||||
collection_name = data_point.type + "_" + node_metadata["index_fields"][0]
|
||||
|
||||
if collection_name not in data_points_by_vector_collection:
|
||||
data_points_by_vector_collection[collection_name] = []
|
||||
|
||||
data_points_by_vector_collection[collection_name].append(data_point)
|
||||
|
||||
for collection_name, collection_data_points in data_points_by_vector_collection.items():
|
||||
query_data_point_ids = set([str(data_point.id) for data_point in collection_data_points])
|
||||
|
||||
vector_items = await vector_engine.retrieve(collection_name, list(query_data_point_ids))
|
||||
|
||||
for vector_item in vector_items:
|
||||
assert str(vector_item.id) not in query_data_point_ids, (
|
||||
f"{vector_item.payload['text']} is still present in the vector store."
|
||||
)
|
||||
28
cognee/tests/utils/assert_nodes_vector_index_present.py
Normal file
28
cognee/tests/utils/assert_nodes_vector_index_present.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from typing import List
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
|
||||
|
||||
async def assert_nodes_vector_index_present(data_points: List[DataPoint]):
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
data_points_by_vector_collection = {}
|
||||
|
||||
for data_point in data_points:
|
||||
node_metadata = data_point.metadata or {}
|
||||
collection_name = data_point.type + "_" + node_metadata["index_fields"][0]
|
||||
|
||||
if collection_name not in data_points_by_vector_collection:
|
||||
data_points_by_vector_collection[collection_name] = []
|
||||
|
||||
data_points_by_vector_collection[collection_name].append(data_point)
|
||||
|
||||
for collection_name, collection_data_points in data_points_by_vector_collection.items():
|
||||
query_data_point_ids = set([str(data_point.id) for data_point in collection_data_points])
|
||||
|
||||
vector_items = await vector_engine.retrieve(collection_name, list(query_data_point_ids))
|
||||
|
||||
for vector_item in vector_items:
|
||||
assert str(vector_item.id) in query_data_point_ids, (
|
||||
f"{vector_item.payload['text']} is not present in the vector store."
|
||||
)
|
||||
45
cognee/tests/utils/extract_entities.py
Normal file
45
cognee/tests/utils/extract_entities.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
from cognee.modules.engine.models import Entity, EntityType
|
||||
from cognee.modules.engine.utils import generate_node_id, generate_node_name
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
|
||||
|
||||
def extract_entities(graph: KnowledgeGraph, cache: dict = {}):
|
||||
entities = []
|
||||
entity_types = []
|
||||
|
||||
for node in graph.nodes:
|
||||
node_id = generate_node_id(node.id)
|
||||
|
||||
if node_id not in cache:
|
||||
entity = Entity(
|
||||
id=node_id,
|
||||
name=generate_node_name(node.id),
|
||||
type=node.type,
|
||||
description=node.description,
|
||||
ontology_valid=False,
|
||||
)
|
||||
cache[node_id] = entity
|
||||
else:
|
||||
entity = cache[node_id]
|
||||
|
||||
entities.append(entity)
|
||||
|
||||
node_type = node.type
|
||||
type_node_id = generate_node_id(node_type)
|
||||
if type_node_id not in cache:
|
||||
type_node_name = generate_node_name(node_type)
|
||||
|
||||
type_node = EntityType(
|
||||
id=type_node_id,
|
||||
name=type_node_name,
|
||||
type=type_node_name,
|
||||
description=type_node_name,
|
||||
ontology_valid=False,
|
||||
)
|
||||
cache[type_node_id] = type_node
|
||||
else:
|
||||
type_node = cache[type_node_id]
|
||||
|
||||
entity_types.append(type_node)
|
||||
|
||||
return entities + entity_types
|
||||
55
cognee/tests/utils/extract_relationships.py
Normal file
55
cognee/tests/utils/extract_relationships.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
from cognee.modules.engine.utils import generate_edge_id, generate_node_id
|
||||
|
||||
|
||||
def extract_relationships(document_chunk: DocumentChunk, graph: KnowledgeGraph, cache: dict = {}):
|
||||
relationships = []
|
||||
|
||||
for edge in graph.edges:
|
||||
edge_id = f"{edge.source_node_id}_{edge.relationship_name}_{edge.target_node_id}"
|
||||
|
||||
if edge_id not in cache:
|
||||
relationship = (
|
||||
generate_edge_id(edge.source_node_id),
|
||||
generate_edge_id(edge.target_node_id),
|
||||
edge.relationship_name,
|
||||
)
|
||||
cache[edge_id] = relationship
|
||||
else:
|
||||
relationship = cache[edge_id]
|
||||
|
||||
relationships.append(relationship)
|
||||
|
||||
for node in graph.nodes:
|
||||
node_id = generate_node_id(node.id)
|
||||
type_node_id = generate_node_id(node.type)
|
||||
type_edge_id = f"{str(node_id)}_is_a_{str(type_node_id)}"
|
||||
|
||||
if type_edge_id not in cache:
|
||||
relationship = (
|
||||
node_id,
|
||||
type_node_id,
|
||||
"is_a",
|
||||
)
|
||||
cache[type_edge_id] = relationship
|
||||
else:
|
||||
relationship = cache[type_edge_id]
|
||||
|
||||
relationships.append(relationship)
|
||||
|
||||
chunk_edge_id = f"{str(document_chunk.id)}_contains_{str(node_id)}"
|
||||
|
||||
if chunk_edge_id not in cache:
|
||||
relationship = (
|
||||
document_chunk.id,
|
||||
node_id,
|
||||
"contains",
|
||||
)
|
||||
cache[chunk_edge_id] = relationship
|
||||
else:
|
||||
relationship = cache[chunk_edge_id]
|
||||
|
||||
relationships.append(relationship)
|
||||
|
||||
return relationships
|
||||
12
cognee/tests/utils/extract_summary.py
Normal file
12
cognee/tests/utils/extract_summary.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from uuid import uuid5
|
||||
from cognee.modules.chunking.models import DocumentChunk
|
||||
from cognee.shared.data_models import SummarizedContent
|
||||
from cognee.tasks.summarization.models import TextSummary
|
||||
|
||||
|
||||
def extract_summary(document_chunk: DocumentChunk, summary: SummarizedContent) -> TextSummary:
    """Wrap a chunk-level summarization result into a ``TextSummary`` data point.

    Fix: the original signature read ``summary=SummarizedContent`` — the class
    itself as a *default value* where a type *annotation* was intended. Any
    call omitting ``summary`` would have crashed on ``summary.summary``, so
    making the parameter required does not break working callers.

    Args:
        document_chunk: The chunk the summary was produced from.
        summary: The summarization result; only its ``summary`` text is used.

    Returns:
        A ``TextSummary`` linked back to ``document_chunk``.
    """
    return TextSummary(
        # uuid5 keeps the id deterministic per chunk (namespace = chunk id),
        # so re-summarizing the same chunk yields the same TextSummary id.
        id=uuid5(document_chunk.id, "TextSummary"),
        text=summary.summary,
        made_from=document_chunk,
    )
||||
26
cognee/tests/utils/filter_overlapping_entities.py
Normal file
26
cognee/tests/utils/filter_overlapping_entities.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
def filter_overlapping_entities(*entity_groups):
    """Split entity groups into per-group unique entities plus shared ones.

    Args:
        *entity_groups: iterables of entity objects exposing an ``id`` attribute.

    Returns:
        A tuple ``(overlapping_entities, group_1, group_2, ...)`` where
        ``overlapping_entities`` holds entities whose id occurs more than once
        across all groups, and each ``group_i`` keeps (in original order) only
        the entities whose id is unique across the whole input.
    """
    # First pass: count occurrences of each entity id across every group.
    entity_count = {}
    for group in entity_groups:
        for entity in group:
            entity_count[entity.id] = entity_count.get(entity.id, 0) + 1

    overlapping_entities = []
    grouped_entities = []

    # Second pass: route each entity to its group's unique list or to the
    # shared list (deduplicated by equality, preserving first-seen order).
    for group in entity_groups:
        unique_in_group = []

        for entity in group:
            if entity_count[entity.id] == 1:
                unique_in_group.append(entity)
            elif entity not in overlapping_entities:
                overlapping_entities.append(entity)

        grouped_entities.append(unique_in_group)

    return overlapping_entities, *grouped_entities
||||
33
cognee/tests/utils/filter_overlapping_relationships.py
Normal file
33
cognee/tests/utils/filter_overlapping_relationships.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
from cognee.modules.engine.utils import generate_node_id
|
||||
|
||||
|
||||
def filter_overlapping_relationships(*relationship_groups):
    """Split relationship groups into per-group unique tuples plus shared ones.

    Args:
        *relationship_groups: iterables of ``(source, target, name)`` tuples.

    Returns:
        A tuple ``(overlapping_relationships, group_1, group_2, ...)`` where
        ``overlapping_relationships`` holds relationships occurring more than
        once across all groups, and each ``group_i`` keeps (in original order)
        only the relationships unique across the whole input.
    """

    def _key(relationship):
        # Canonical identity for a relationship: source, name, target.
        # Centralized here so the counting and filtering passes cannot drift.
        return f"{relationship[0]}_{relationship[2]}_{relationship[1]}"

    # First pass: count occurrences of each relationship key across all groups.
    relationship_count = {}
    for group in relationship_groups:
        for relationship in group:
            key = _key(relationship)
            relationship_count[key] = relationship_count.get(key, 0) + 1

    overlapping_relationships = []
    grouped_relationships = []

    # Second pass: route each relationship to its group's unique list or to
    # the shared list (deduplicated by equality, preserving first-seen order).
    for group in relationship_groups:
        unique_in_group = []

        for relationship in group:
            if relationship_count[_key(relationship)] == 1:
                unique_in_group.append(relationship)
            elif relationship not in overlapping_relationships:
                overlapping_relationships.append(relationship)

        grouped_relationships.append(unique_in_group)

    return overlapping_relationships, *grouped_relationships
||||
9
cognee/tests/utils/get_contains_edge_text.py
Normal file
9
cognee/tests/utils/get_contains_edge_text.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
def get_contains_edge_text(entity_name: str, entity_description: str) -> str:
    """Render the textual form of a "contains" edge for the given entity."""
    return (
        "relationship_name: contains; "
        f"entity_name: {entity_name}; "
        f"entity_description: {entity_description}"
    )
||||
20
cognee/tests/utils/isolate_relationships.py
Normal file
20
cognee/tests/utils/isolate_relationships.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
def isolate_relationships(source_relationships, *other_relationships):
    """Return source relationships whose third element is absent from all other groups.

    The third tuple element (presumably the relationship name) is used as the
    identity. A name is considered duplicated once it has been seen twice in
    total, where every name present in ``source_relationships`` starts with a
    count of one — so a single occurrence in any other group flags it.
    """
    # Seed counts with the source names (each counted once, regardless of
    # how often it repeats inside the source group itself).
    seen_counts = {relationship[2]: 1 for relationship in source_relationships}
    duplicated_names = set()

    for group in other_relationships:
        for relationship in group:
            name = relationship[2]
            seen_counts[name] = seen_counts.get(name, 0) + 1

            if seen_counts[name] == 2:
                duplicated_names.add(name)

    # Keep only the source relationships whose name was never flagged.
    return [
        relationship
        for relationship in source_relationships
        if relationship[2] not in duplicated_names
    ]
Loading…
Add table
Reference in a new issue