fix: legacy delete backwards compatibility for neo4j

This commit is contained in:
Boris Arzentar 2025-11-23 22:33:10 +01:00
parent cb380e51e9
commit bc4eb9f6ce
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
10 changed files with 250 additions and 206 deletions

View file

@ -1,21 +1,20 @@
from uuid import UUID
from typing import List
from sqlalchemy import and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Edge, Node
from cognee.modules.graph.models import Edge
from .GraphRelationshipLedger import GraphRelationshipLedger
@with_async_session
async def has_edges_in_legacy_ledger(edges: List[Edge], user_id: UUID, session: AsyncSession):
async def has_edges_in_legacy_ledger(edges: List[Edge], session: AsyncSession):
if len(edges) == 0:
return []
query = select(GraphRelationshipLedger).where(
and_(
GraphRelationshipLedger.user_id == user_id,
GraphRelationshipLedger.node_label.is_(None),
or_(
*[
GraphRelationshipLedger.creator_function.ilike(f"%{edge.relationship_name}")
@ -30,20 +29,3 @@ async def has_edges_in_legacy_ledger(edges: List[Edge], user_id: UUID, session:
legacy_edge_names = set([edge.creator_function.split(".")[1] for edge in legacy_edges])
return [edge.relationship_name in legacy_edge_names for edge in edges]
@with_async_session
async def get_node_ids(edges: List[Edge], session: AsyncSession):
node_slugs = []
for edge in edges:
node_slugs.append(edge.source_node_id)
node_slugs.append(edge.destination_node_id)
query = select(Node).where(Node.slug.in_(node_slugs))
nodes = (await session.scalars(query)).all()
node_ids = {node.slug: node.id for node in nodes}
return node_ids

View file

@ -1,36 +1,65 @@
from typing import List
from uuid import UUID
from sqlalchemy import and_, or_, select
from typing import List, Tuple
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import with_async_session
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
is_backend_access_control_enabled,
)
from cognee.modules.graph.models import Node
from .GraphRelationshipLedger import GraphRelationshipLedger
@with_async_session
async def has_nodes_in_legacy_ledger(nodes: List[Node], user_id: UUID, session: AsyncSession):
async def has_nodes_in_legacy_ledger(nodes: List[Node], session: AsyncSession):
node_ids = [node.slug for node in nodes]
query = select(
GraphRelationshipLedger.source_node_id,
GraphRelationshipLedger.destination_node_id,
).where(
and_(
GraphRelationshipLedger.user_id == user_id,
or_(
GraphRelationshipLedger.source_node_id.in_(node_ids),
GraphRelationshipLedger.destination_node_id.in_(node_ids),
),
query = (
select(
GraphRelationshipLedger.node_label,
GraphRelationshipLedger.source_node_id,
)
.where(
and_(
GraphRelationshipLedger.node_label.is_not(None),
GraphRelationshipLedger.deleted_at.is_(None),
GraphRelationshipLedger.source_node_id.in_(node_ids),
GraphRelationshipLedger.source_node_id
== GraphRelationshipLedger.destination_node_id,
)
)
.distinct()
)
legacy_nodes = await session.execute(query)
entries = legacy_nodes.all()
legacy_nodes = (await session.execute(query)).all()
found_ids = set()
for entry in entries:
found_ids.add(entry.source_node_id)
found_ids.add(entry.destination_node_id)
if len(legacy_nodes) == 0:
return [False for __ in nodes]
return [node_id in found_ids for node_id in node_ids]
if is_backend_access_control_enabled():
confirmed_nodes = await confirm_nodes_in_graph(legacy_nodes)
return [node_id in confirmed_nodes for node_id in node_ids]
else:
found_ids = set()
for __, node_id in legacy_nodes:
found_ids.add(node_id)
return [node_id in found_ids for node_id in node_ids]
async def confirm_nodes_in_graph(
legacy_nodes: List[Tuple[str, UUID]],
):
graph_engine = await get_graph_engine()
graph_nodes = await graph_engine.get_nodes([str(node[1]) for node in legacy_nodes])
graph_nodes_by_id = {node["id"]: node for node in graph_nodes}
confirmed_nodes = set()
for __, node_id in legacy_nodes:
if str(node_id) in graph_nodes_by_id:
confirmed_nodes.add(node_id)
return confirmed_nodes

View file

@ -4,7 +4,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from cognee.infrastructure.databases.relational import with_async_session
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.users.models.User import User
from .GraphRelationshipLedger import GraphRelationshipLedger
@ -12,23 +11,21 @@ from .GraphRelationshipLedger import GraphRelationshipLedger
async def record_data_in_legacy_ledger(
nodes: List[DataPoint],
edges: List[Tuple[UUID, UUID, str, Dict]],
user: User,
session: AsyncSession,
) -> None:
relationships = [
GraphRelationshipLedger(
source_node_id=node.id,
destination_node_id=node.id,
creator_function="add_nodes",
user_id=user.id,
node_label=getattr(node, "name", getattr(node, "text", node.id)),
creator_function="add_data_points.nodes",
)
for node in nodes
] + [
GraphRelationshipLedger(
source_node_id=edge[0],
destination_node_id=edge[1],
creator_function=f"add_edges.{edge[2]}",
user_id=user.id,
creator_function=f"add_data_points.{edge[2]}",
)
for edge in edges
]

View file

@ -25,10 +25,10 @@ async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id:
if len(affected_nodes) == 0:
return
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
affected_relationships = await get_data_related_edges(dataset_id, data_id)
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
non_legacy_nodes = [
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
@ -71,10 +71,10 @@ async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id:
if len(affected_nodes) == 0:
return
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
affected_relationships = await get_global_data_related_edges(data_id)
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
non_legacy_nodes = [
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]

View file

@ -25,10 +25,10 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> Non
if len(affected_nodes) == 0:
return
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
affected_relationships = await get_dataset_related_edges(dataset_id)
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
non_legacy_nodes = [
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
@ -71,10 +71,10 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> Non
if len(affected_nodes) == 0:
return
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
affected_relationships = await get_global_dataset_related_edges(dataset_id)
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
non_legacy_nodes = [
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]

View file

@ -18,6 +18,7 @@ from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.data.processing.document_types import TextDocument
from cognee.modules.engine.operations.setup import setup
from cognee.modules.engine.utils import generate_node_id
from cognee.modules.engine.utils.generate_node_name import generate_node_name
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
from cognee.modules.graph.utils.deduplicate_nodes_and_edges import deduplicate_nodes_and_edges
from cognee.modules.graph.utils.get_graph_from_model import get_graph_from_model
@ -45,6 +46,36 @@ from cognee.tests.utils.filter_overlapping_relationships import filter_overlappi
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
async def assert_relationships_vector_index_present(formatted_relationships, legacy_relationships):
"""Helper to check both formatted (new) and unformatted (legacy) relationships."""
if formatted_relationships:
await assert_edges_vector_index_present(formatted_relationships, convert_to_new_format=True)
if legacy_relationships:
await assert_edges_vector_index_present(legacy_relationships, convert_to_new_format=False)
def build_relationships(chunk, document, summary, graph):
"""Build all relationships for a chunk including structural and extracted ones."""
return [
(chunk.id, document.id, "is_part_of", {"relationship_name": "is_part_of"}),
(summary.id, chunk.id, "made_from", {"relationship_name": "made_from"}),
] + extract_relationships(chunk, graph)
def build_contains_relationships(chunk_id, entities, entity_names):
"""Build contains relationships for specific entities."""
return [
(
chunk_id,
entity.id,
get_contains_edge_text(entity.name, entity.description),
{"relationship_name": get_contains_edge_text(entity.name, entity.description)},
)
for entity in entities
if isinstance(entity, Entity) and entity.name in entity_names
]
def create_legacy_data_points():
document = TextDocument(
id=uuid5(NAMESPACE_OID, "text_test.txt"),
@ -56,51 +87,43 @@ def create_legacy_data_points():
document_chunk = DocumentChunk(
id=uuid5(
NAMESPACE_OID,
"Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
"Apple announced their new vector embeddings visualization tool called Embedding Atlas.",
),
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
text="Apple announced their new vector embeddings visualization tool called Embedding Atlas.",
chunk_size=187,
chunk_index=0,
cut_type="paragraph_end",
is_part_of=document,
)
graph_database = EntityType(
id=uuid5(NAMESPACE_OID, "graph_database"),
name="graph database",
description="graph database",
company = EntityType(
id=generate_node_id("Company"),
name=generate_node_name("Company"),
description=generate_node_name("Company"),
)
neptune_analytics_entity = Entity(
id=generate_node_id("neptune analytics"),
name="neptune analytics",
description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
is_a=graph_database,
apple = Entity(
id=generate_node_id("Apple"),
name=generate_node_name("Apple"),
description="Apple is a company",
is_a=company,
)
neptune_database_entity = Entity(
id=generate_node_id("amazon neptune database"),
name="amazon neptune database",
description="A popular managed graph database that complements Neptune Analytics.",
is_a=graph_database,
product = EntityType(
id=generate_node_id("Product"),
name=generate_node_name("Product"),
description=generate_node_name("Product"),
)
storage = EntityType(
id=generate_node_id("storage"),
name="storage",
description="storage",
)
storage_entity = Entity(
id=generate_node_id("amazon s3"),
name="amazon s3",
description="A storage service provided by Amazon Web Services that allows storing graph data.",
is_a=storage,
embedding_atlas = Entity(
id=generate_node_id("Embedding Atlas"),
name=generate_node_name("Embedding Atlas"),
description="Embedding Atlas",
is_a=product,
)
entities = [
graph_database,
neptune_analytics_entity,
neptune_database_entity,
storage,
storage_entity,
company,
product,
apple,
embedding_atlas,
]
document_chunk.contains = entities
@ -143,7 +166,7 @@ async def main(mock_create_structured_output: AsyncMock):
assert not await vector_engine.has_collection("TextDocument_text")
# Add legacy data to the system
__, legacy_data_points, legacy_relationships = await create_mocked_legacy_data(user)
__, all_legacy_data_points, all_legacy_relationships = await create_mocked_legacy_data(user)
def mock_llm_output(text_input: str, system_prompt: str, response_model):
if text_input == "test": # LLM connection test
@ -262,132 +285,153 @@ async def main(mock_create_structured_output: AsyncMock):
)
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
johns_entities, maries_entities
)
all_johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
all_maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
johns_data = [
expected_johns_data = [
johns_document,
johns_chunk,
johns_summary,
*johns_entities,
*all_johns_entities,
]
maries_data = [
expected_maries_data = [
maries_document,
maries_chunk,
maries_summary,
*maries_entities,
*all_maries_entities,
]
expected_data_points = johns_data + maries_data + overlapping_entities + legacy_data_points
expected_data_points = expected_johns_data + expected_maries_data + all_legacy_data_points
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(expected_data_points)
await assert_nodes_vector_index_present(expected_data_points)
johns_relationships = extract_relationships(
all_johns_relationships = build_relationships(
johns_chunk,
johns_document,
johns_summary,
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
)
maries_relationships = extract_relationships(
all_maries_relationships = build_relationships(
maries_chunk,
maries_document,
maries_summary,
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
)
(overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = (
filter_overlapping_relationships(
johns_relationships, maries_relationships, legacy_relationships
)
)
johns_relationships = [
(johns_chunk.id, johns_document.id, "is_part_of"),
(johns_summary.id, johns_chunk.id, "made_from"),
*johns_relationships,
]
maries_relationships = [
(maries_chunk.id, maries_document.id, "is_part_of"),
(maries_summary.id, maries_chunk.id, "made_from"),
*maries_relationships,
]
expected_relationships = (
johns_relationships
+ maries_relationships
+ overlapping_relationships
+ legacy_relationships
all_johns_relationships + all_maries_relationships + all_legacy_relationships
)
await assert_graph_edges_present(expected_relationships)
await assert_edges_vector_index_present(expected_relationships)
await assert_relationships_vector_index_present(
all_johns_relationships + all_maries_relationships, all_legacy_relationships
)
# Delete John's data
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points)
await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points)
expected_data_points = [
maries_document,
maries_chunk,
maries_summary,
*all_maries_entities,
*all_legacy_data_points,
]
await assert_graph_nodes_not_present(johns_data)
await assert_nodes_vector_index_not_present(johns_data)
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(expected_data_points)
await assert_nodes_vector_index_present(expected_data_points)
(__, strictly_johns_entities, __, __) = filter_overlapping_entities(
all_johns_entities, all_maries_entities, all_legacy_data_points
)
not_expected_data_points = [
johns_document,
johns_chunk,
johns_summary,
*strictly_johns_entities,
]
await assert_graph_nodes_not_present(not_expected_data_points)
await assert_nodes_vector_index_not_present(not_expected_data_points)
# Assert relationships presence in the graph, vector collections and nodes table
await assert_graph_edges_present(
maries_relationships + overlapping_relationships + legacy_relationships
await assert_graph_edges_present(all_maries_relationships + all_legacy_relationships)
await assert_relationships_vector_index_present(
all_maries_relationships, all_legacy_relationships
)
await assert_edges_vector_index_present(maries_relationships + legacy_relationships)
await assert_graph_edges_not_present(johns_relationships)
(__, strictly_johns_relationships, __, __) = filter_overlapping_relationships(
all_johns_relationships,
all_maries_relationships,
all_legacy_relationships,
)
johns_contains_relationships = [
(
johns_chunk.id,
entity.id,
get_contains_edge_text(entity.name, entity.description),
{
"relationship_name": get_contains_edge_text(entity.name, entity.description),
},
)
for entity in johns_entities
if isinstance(entity, Entity)
]
# We check only by relationship name and we need edges that are created by John's data and no other.
await assert_edges_vector_index_not_present(johns_contains_relationships)
await assert_graph_edges_not_present(strictly_johns_relationships)
# Check that John's unique contains relationships are not in vector index
not_expected_relationships = build_contains_relationships(
johns_chunk.id,
all_johns_entities,
[generate_node_name("John"), generate_node_name("Food for Hungry")],
)
await assert_edges_vector_index_not_present(not_expected_relationships)
# Delete Marie's data
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
# Assert data points presence in the graph, vector collections and nodes table
await assert_graph_nodes_present(legacy_data_points)
await assert_nodes_vector_index_present(legacy_data_points)
await assert_graph_nodes_present(all_legacy_data_points)
await assert_nodes_vector_index_present(all_legacy_data_points)
await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities)
await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities)
# Assert relationships presence in the graph, vector collections and nodes table
await assert_graph_edges_present(legacy_relationships)
await assert_edges_vector_index_present(legacy_relationships)
await assert_graph_edges_not_present(
johns_relationships + maries_relationships + overlapping_relationships
(__, strictly_johns_entities, strictly_maries_entities, __) = filter_overlapping_entities(
all_johns_entities, all_maries_entities, all_legacy_data_points
)
maries_contains_relationships = [
(
maries_chunk.id,
entity.id,
get_contains_edge_text(entity.name, entity.description),
{
"relationship_name": get_contains_edge_text(entity.name, entity.description),
},
)
for entity in maries_entities
if isinstance(entity, Entity)
not_expected_data_points = [
johns_document,
johns_chunk,
johns_summary,
*strictly_johns_entities,
maries_document,
maries_chunk,
maries_summary,
*strictly_maries_entities,
]
# We check only by relationship name and we need edges that are created by legacy data and no other.
await assert_edges_vector_index_not_present(maries_contains_relationships)
await assert_graph_nodes_not_present(not_expected_data_points)
await assert_nodes_vector_index_not_present(not_expected_data_points)
# Assert relationships presence in the graph, vector collections and nodes table
await assert_graph_edges_present(all_legacy_relationships)
await assert_relationships_vector_index_present([], all_legacy_relationships)
(__, strictly_johns_relationships, strictly_maries_relationships, __) = (
filter_overlapping_relationships(
all_maries_relationships,
all_johns_relationships,
all_legacy_relationships,
)
)
await assert_graph_edges_not_present(
strictly_johns_relationships + strictly_maries_relationships
)
# Check that John's and Marie's unique contains relationships are not in vector index
not_expected_relationships = build_contains_relationships(
johns_chunk.id,
all_johns_entities,
[generate_node_name("John"), generate_node_name("Food for Hungry")],
) + build_contains_relationships(
maries_chunk.id,
all_maries_entities,
[generate_node_name("Marie"), generate_node_name("MacOS")],
)
await assert_edges_vector_index_not_present(not_expected_relationships)
async def create_mocked_legacy_data(user):
@ -423,31 +467,11 @@ async def create_mocked_legacy_data(user):
await graph_engine.add_nodes(graph_nodes)
await graph_engine.add_edges(graph_edges)
nodes_by_id = {node.id: node for node in graph_nodes}
def format_relationship_name(relationship):
if relationship[2] == "contains":
node = nodes_by_id[relationship[1]]
return get_contains_edge_text(node.name, node.description)
return relationship[2]
await index_data_points(graph_nodes)
await index_graph_edges(
[
(
edge[0],
edge[1],
format_relationship_name(edge),
{
**(edge[3] or {}),
"relationship_name": format_relationship_name(edge),
},
)
for edge in graph_edges
] # type: ignore
)
# Legacy relationships should NOT be formatted - index them as-is
await index_graph_edges(graph_edges)
await record_data_in_legacy_ledger(graph_nodes, graph_edges, user)
await record_data_in_legacy_ledger(graph_nodes, graph_edges)
db_engine = get_relational_engine()

View file

@ -435,7 +435,7 @@ async def create_mocked_legacy_data(user):
] # type: ignore
)
await record_data_in_legacy_ledger(graph_nodes, graph_edges, user)
await record_data_in_legacy_ledger(graph_nodes, graph_edges)
db_engine = get_relational_engine()

View file

@ -21,7 +21,9 @@ def format_relationship(relationship: Tuple[UUID, UUID, str, Dict], node: Dict):
return {str(generate_edge_id(relationship[2])): relationship[2]}
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
async def assert_edges_vector_index_present(
relationships: List[Tuple[UUID, UUID, str, Dict]], convert_to_new_format: bool = True
):
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
@ -33,7 +35,11 @@ async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID
for relationship in relationships:
query_edge_ids = {
**query_edge_ids,
**format_relationship(relationship, nodes_by_id[str(relationship[1])]),
**(
format_relationship(relationship, nodes_by_id[str(relationship[1])])
if convert_to_new_format
else {str(generate_edge_id(relationship[2])): relationship[2]}
),
}
vector_items = await vector_engine.retrieve(

View file

@ -1,11 +1,10 @@
from uuid import UUID
from typing import List, Tuple
from typing import Dict, List, Tuple
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine.models.DataPoint import DataPoint
async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str]]):
async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
@ -15,7 +14,11 @@ async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, s
for relationship in relationships:
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
relationship_name = relationship[2]
assert relationship_id not in edge_ids, (
f"Edge '{relationship_name}' still present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database."
)
if relationship_id in edge_ids:
relationship_name = relationship[2]
source_node = nodes_by_id[str(relationship[0])]
destination_node = nodes_by_id[str(relationship[1])]
assert False, (
f"Edge '{relationship_name}' still present between '{source_node['name'] if 'node' in source_node else source_node['id']}' and '{destination_node['name'] if 'node' in destination_node else destination_node['id']}' in graph database."
)

View file

@ -1,11 +1,10 @@
from uuid import UUID
from typing import List, Tuple
from typing import Dict, List, Tuple
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine.models.DataPoint import DataPoint
async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]]):
async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
@ -16,6 +15,10 @@ async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]]
for relationship in relationships:
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
relationship_name = relationship[2]
source_node = nodes_by_id.get(str(relationship[0]), {})
target_node = nodes_by_id.get(str(relationship[1]), {})
source_name = source_node.get("name") or source_node.get("text") or str(relationship[0])
target_name = target_node.get("name") or target_node.get("text") or str(relationship[1])
assert relationship_id in edge_ids, (
f"Edge '{relationship_name}' not present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database."
f"Edge '{relationship_name}' not present between '{source_name}' and '{target_name}' in graph database."
)