fix: lint errors and improve delete tests

This commit is contained in:
Boris Arzentar 2025-11-19 16:06:28 +01:00
parent 43459eeeac
commit e6baaf83e6
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
5 changed files with 97 additions and 145 deletions

View file

@ -54,7 +54,7 @@ async def legacy_delete(data: Data, mode: str = "soft"):
async def delete_document_subgraph(document_id: UUID, mode: str = "soft"):
"""Delete a document and all its related nodes in the correct order."""
graph_db = await get_graph_engine()
subgraph = await graph_db.get_document_subgraph(document_id)
subgraph = await graph_db.get_document_subgraph(str(document_id))
if not subgraph:
raise DocumentSubgraphNotFoundError(f"Document not found with id: {document_id}")

View file

@ -19,6 +19,8 @@ from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.engine.operations.setup import setup
from cognee.modules.engine.utils import generate_node_id
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
from cognee.modules.graph.utils.deduplicate_nodes_and_edges import deduplicate_nodes_and_edges
from cognee.modules.graph.utils.get_graph_from_model import get_graph_from_model
from cognee.modules.pipelines.models import DataItemStatus
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
@ -43,7 +45,7 @@ from cognee.tests.utils.filter_overlapping_relationships import filter_overlappi
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
def create_nodes_and_edges():
def create_legacy_data_points():
document = TextDocument(
id=uuid5(NAMESPACE_OID, "text_test.txt"),
name="text_test.txt",
@ -72,11 +74,13 @@ def create_nodes_and_edges():
id=generate_node_id("neptune analytics"),
name="neptune analytics",
description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
is_a=graph_database,
)
neptune_database_entity = Entity(
id=generate_node_id("amazon neptune database"),
name="amazon neptune database",
description="A popular managed graph database that complements Neptune Analytics.",
is_a=graph_database,
)
storage = EntityType(
@ -88,11 +92,10 @@ def create_nodes_and_edges():
id=generate_node_id("amazon s3"),
name="amazon s3",
description="A storage service provided by Amazon Web Services that allows storing graph data.",
is_a=storage,
)
nodes_data = [
document,
document_chunk,
entities = [
graph_database,
neptune_analytics_entity,
neptune_database_entity,
@ -100,66 +103,14 @@ def create_nodes_and_edges():
storage_entity,
]
edges_data = [
(
document_chunk.id,
storage_entity.id,
"contains",
{
"relationship_name": "contains",
},
),
(
storage_entity.id,
storage.id,
"is_a",
{
"relationship_name": "is_a",
},
),
(
document_chunk.id,
neptune_database_entity.id,
"contains",
{
"relationship_name": "contains",
},
),
(
neptune_database_entity.id,
graph_database.id,
"is_a",
{
"relationship_name": "is_a",
},
),
(
document_chunk.id,
document.id,
"is_part_of",
{
"relationship_name": "is_part_of",
},
),
(
document_chunk.id,
neptune_analytics_entity.id,
"contains",
{
"relationship_name": "contains",
},
),
(
neptune_analytics_entity.id,
graph_database.id,
"is_a",
{
"relationship_name": "is_a",
},
),
document_chunk.contains = entities
data_points = [
document,
document_chunk,
]
return nodes_data, edges_data
return data_points
@pytest.mark.asyncio
@ -441,13 +392,38 @@ async def main(mock_create_structured_output: AsyncMock):
async def create_mocked_legacy_data(user):
graph_engine = await get_graph_engine()
legacy_nodes, legacy_edges = create_nodes_and_edges()
legacy_document = legacy_nodes[0]
legacy_data_points = create_legacy_data_points()
legacy_document = legacy_data_points[0]
await graph_engine.add_nodes(legacy_nodes)
await graph_engine.add_edges(legacy_edges)
nodes = []
edges = []
nodes_by_id = {node.id: node for node in legacy_nodes}
added_nodes = {}
added_edges = {}
visited_properties = {}
results = await asyncio.gather(
*[
get_graph_from_model(
data_point,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
for data_point in legacy_data_points
]
)
for result_nodes, result_edges in results:
nodes.extend(result_nodes)
edges.extend(result_edges)
graph_nodes, graph_edges = deduplicate_nodes_and_edges(nodes, edges)
await graph_engine.add_nodes(graph_nodes)
await graph_engine.add_edges(graph_edges)
nodes_by_id = {node.id: node for node in graph_nodes}
def format_relationship_name(relationship):
if relationship[2] == "contains":
@ -455,7 +431,7 @@ async def create_mocked_legacy_data(user):
return get_contains_edge_text(node.name, node.description)
return relationship[2]
await index_data_points(legacy_nodes)
await index_data_points(graph_nodes)
await index_graph_edges(
[
(
@ -467,11 +443,11 @@ async def create_mocked_legacy_data(user):
"relationship_name": format_relationship_name(edge),
},
)
for edge in legacy_edges
for edge in graph_edges
] # type: ignore
)
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
await record_data_in_legacy_ledger(graph_nodes, graph_edges, user)
db_engine = get_relational_engine()
@ -499,7 +475,7 @@ async def create_mocked_legacy_data(user):
await session.commit()
return legacy_document, legacy_nodes, legacy_edges
return legacy_document, graph_nodes, graph_edges
if __name__ == "__main__":

View file

@ -20,6 +20,8 @@ from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.engine.operations.setup import setup
from cognee.modules.engine.utils import generate_node_id
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
from cognee.modules.graph.utils.deduplicate_nodes_and_edges import deduplicate_nodes_and_edges
from cognee.modules.graph.utils.get_graph_from_model import get_graph_from_model
from cognee.modules.pipelines.models import DataItemStatus
from cognee.modules.users.methods import get_default_user
from cognee.shared.data_models import KnowledgeGraph, Node, Edge, SummarizedContent
@ -44,7 +46,7 @@ from cognee.tests.utils.filter_overlapping_relationships import filter_overlappi
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
def create_nodes_and_edges():
def create_legacy_data_points():
document = TextDocument(
id=uuid5(NAMESPACE_OID, "text_test.txt"),
name="text_test.txt",
@ -73,11 +75,13 @@ def create_nodes_and_edges():
id=generate_node_id("neptune analytics"),
name="neptune analytics",
description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
is_a=graph_database,
)
neptune_database_entity = Entity(
id=generate_node_id("amazon neptune database"),
name="amazon neptune database",
description="A popular managed graph database that complements Neptune Analytics.",
is_a=graph_database,
)
storage = EntityType(
@ -89,11 +93,10 @@ def create_nodes_and_edges():
id=generate_node_id("amazon s3"),
name="amazon s3",
description="A storage service provided by Amazon Web Services that allows storing graph data.",
is_a=storage,
)
nodes_data = [
document,
document_chunk,
entities = [
graph_database,
neptune_analytics_entity,
neptune_database_entity,
@ -101,66 +104,14 @@ def create_nodes_and_edges():
storage_entity,
]
edges_data = [
(
document_chunk.id,
storage_entity.id,
"contains",
{
"relationship_name": "contains",
},
),
(
storage_entity.id,
storage.id,
"is_a",
{
"relationship_name": "is_a",
},
),
(
document_chunk.id,
neptune_database_entity.id,
"contains",
{
"relationship_name": "contains",
},
),
(
neptune_database_entity.id,
graph_database.id,
"is_a",
{
"relationship_name": "is_a",
},
),
(
document_chunk.id,
document.id,
"is_part_of",
{
"relationship_name": "is_part_of",
},
),
(
document_chunk.id,
neptune_analytics_entity.id,
"contains",
{
"relationship_name": "contains",
},
),
(
neptune_analytics_entity.id,
graph_database.id,
"is_a",
{
"relationship_name": "is_a",
},
),
document_chunk.contains = entities
data_points = [
document,
document_chunk,
]
return nodes_data, edges_data
return data_points
@pytest.mark.asyncio
@ -436,13 +387,38 @@ async def main(mock_create_structured_output: AsyncMock):
async def create_mocked_legacy_data(user):
graph_engine = await get_graph_engine()
legacy_nodes, legacy_edges = create_nodes_and_edges()
legacy_document = legacy_nodes[0]
legacy_data_points = create_legacy_data_points()
legacy_document = legacy_data_points[0]
await graph_engine.add_nodes(legacy_nodes)
await graph_engine.add_edges(legacy_edges)
nodes = []
edges = []
nodes_by_id = {node.id: node for node in legacy_nodes}
added_nodes = {}
added_edges = {}
visited_properties = {}
results = await asyncio.gather(
*[
get_graph_from_model(
data_point,
added_nodes=added_nodes,
added_edges=added_edges,
visited_properties=visited_properties,
)
for data_point in legacy_data_points
]
)
for result_nodes, result_edges in results:
nodes.extend(result_nodes)
edges.extend(result_edges)
graph_nodes, graph_edges = deduplicate_nodes_and_edges(nodes, edges)
await graph_engine.add_nodes(graph_nodes)
await graph_engine.add_edges(graph_edges)
nodes_by_id = {node.id: node for node in graph_nodes}
def format_relationship_name(relationship):
if relationship[2] == "contains":
@ -450,7 +426,7 @@ async def create_mocked_legacy_data(user):
return get_contains_edge_text(node.name, node.description)
return relationship[2]
await index_data_points(legacy_nodes)
await index_data_points(graph_nodes)
await index_graph_edges(
[
(
@ -462,11 +438,11 @@ async def create_mocked_legacy_data(user):
"relationship_name": format_relationship_name(edge),
},
)
for edge in legacy_edges
for edge in graph_edges
] # type: ignore
)
await record_data_in_legacy_ledger(legacy_nodes, legacy_edges, user)
await record_data_in_legacy_ledger(graph_nodes, graph_edges, user)
db_engine = get_relational_engine()
@ -494,7 +470,7 @@ async def create_mocked_legacy_data(user):
await session.commit()
return legacy_document, legacy_nodes, legacy_edges
return legacy_document, graph_nodes, graph_edges
if __name__ == "__main__":

View file

@ -4,7 +4,7 @@ def filter_overlapping_entities(*entity_groups):
for group in entity_groups:
for entity in group:
if not entity.id in entity_count:
if entity.id not in entity_count:
entity_count[entity.id] = 1
else:
entity_count[entity.id] += 1

View file

@ -9,7 +9,7 @@ def filter_overlapping_relationships(*relationship_groups):
for relationship in group:
relationship_id = f"{relationship[0]}_{relationship[2]}_{relationship[1]}"
if not relationship_id in relationship_count:
if relationship_id not in relationship_count:
relationship_count[relationship_id] = 1
else:
relationship_count[relationship_id] += 1