fix: legacy delete backwards compatibility for neo4j
This commit is contained in:
parent
cb380e51e9
commit
bc4eb9f6ce
10 changed files with 250 additions and 206 deletions
|
|
@ -1,21 +1,20 @@
|
|||
from uuid import UUID
|
||||
from typing import List
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Edge, Node
|
||||
from cognee.modules.graph.models import Edge
|
||||
from .GraphRelationshipLedger import GraphRelationshipLedger
|
||||
|
||||
|
||||
@with_async_session
|
||||
async def has_edges_in_legacy_ledger(edges: List[Edge], user_id: UUID, session: AsyncSession):
|
||||
async def has_edges_in_legacy_ledger(edges: List[Edge], session: AsyncSession):
|
||||
if len(edges) == 0:
|
||||
return []
|
||||
|
||||
query = select(GraphRelationshipLedger).where(
|
||||
and_(
|
||||
GraphRelationshipLedger.user_id == user_id,
|
||||
GraphRelationshipLedger.node_label.is_(None),
|
||||
or_(
|
||||
*[
|
||||
GraphRelationshipLedger.creator_function.ilike(f"%{edge.relationship_name}")
|
||||
|
|
@ -30,20 +29,3 @@ async def has_edges_in_legacy_ledger(edges: List[Edge], user_id: UUID, session:
|
|||
legacy_edge_names = set([edge.creator_function.split(".")[1] for edge in legacy_edges])
|
||||
|
||||
return [edge.relationship_name in legacy_edge_names for edge in edges]
|
||||
|
||||
|
||||
@with_async_session
|
||||
async def get_node_ids(edges: List[Edge], session: AsyncSession):
|
||||
node_slugs = []
|
||||
|
||||
for edge in edges:
|
||||
node_slugs.append(edge.source_node_id)
|
||||
node_slugs.append(edge.destination_node_id)
|
||||
|
||||
query = select(Node).where(Node.slug.in_(node_slugs))
|
||||
|
||||
nodes = (await session.scalars(query)).all()
|
||||
|
||||
node_ids = {node.slug: node.id for node in nodes}
|
||||
|
||||
return node_ids
|
||||
|
|
|
|||
|
|
@ -1,36 +1,65 @@
|
|||
from typing import List
|
||||
from uuid import UUID
|
||||
from sqlalchemy import and_, or_, select
|
||||
from typing import List, Tuple
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
|
||||
is_backend_access_control_enabled,
|
||||
)
|
||||
from cognee.modules.graph.models import Node
|
||||
from .GraphRelationshipLedger import GraphRelationshipLedger
|
||||
|
||||
|
||||
@with_async_session
|
||||
async def has_nodes_in_legacy_ledger(nodes: List[Node], user_id: UUID, session: AsyncSession):
|
||||
async def has_nodes_in_legacy_ledger(nodes: List[Node], session: AsyncSession):
|
||||
node_ids = [node.slug for node in nodes]
|
||||
|
||||
query = select(
|
||||
GraphRelationshipLedger.source_node_id,
|
||||
GraphRelationshipLedger.destination_node_id,
|
||||
).where(
|
||||
and_(
|
||||
GraphRelationshipLedger.user_id == user_id,
|
||||
or_(
|
||||
GraphRelationshipLedger.source_node_id.in_(node_ids),
|
||||
GraphRelationshipLedger.destination_node_id.in_(node_ids),
|
||||
),
|
||||
query = (
|
||||
select(
|
||||
GraphRelationshipLedger.node_label,
|
||||
GraphRelationshipLedger.source_node_id,
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
GraphRelationshipLedger.node_label.is_not(None),
|
||||
GraphRelationshipLedger.deleted_at.is_(None),
|
||||
GraphRelationshipLedger.source_node_id.in_(node_ids),
|
||||
GraphRelationshipLedger.source_node_id
|
||||
== GraphRelationshipLedger.destination_node_id,
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
legacy_nodes = await session.execute(query)
|
||||
entries = legacy_nodes.all()
|
||||
legacy_nodes = (await session.execute(query)).all()
|
||||
|
||||
found_ids = set()
|
||||
for entry in entries:
|
||||
found_ids.add(entry.source_node_id)
|
||||
found_ids.add(entry.destination_node_id)
|
||||
if len(legacy_nodes) == 0:
|
||||
return [False for __ in nodes]
|
||||
|
||||
return [node_id in found_ids for node_id in node_ids]
|
||||
if is_backend_access_control_enabled():
|
||||
confirmed_nodes = await confirm_nodes_in_graph(legacy_nodes)
|
||||
return [node_id in confirmed_nodes for node_id in node_ids]
|
||||
else:
|
||||
found_ids = set()
|
||||
for __, node_id in legacy_nodes:
|
||||
found_ids.add(node_id)
|
||||
|
||||
return [node_id in found_ids for node_id in node_ids]
|
||||
|
||||
|
||||
async def confirm_nodes_in_graph(
    legacy_nodes: List[Tuple[str, UUID]],
):
    """Return the ledger node ids that the graph engine also reports.

    Args:
        legacy_nodes: Pairs whose second element is a node id; the first
            element is ignored here.

    Returns:
        A set holding every node id from ``legacy_nodes`` whose string form
        equals the ``"id"`` of a node returned by ``graph_engine.get_nodes``.
    """
    engine = await get_graph_engine()

    # The graph engine is queried with stringified ids.
    requested_ids = [str(pair[1]) for pair in legacy_nodes]
    graph_nodes = await engine.get_nodes(requested_ids)

    # Only membership is needed, so keep just the ids that came back.
    present_ids = {graph_node["id"] for graph_node in graph_nodes}

    return {node_id for __, node_id in legacy_nodes if str(node_id) in present_ids}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.modules.users.models.User import User
|
||||
from .GraphRelationshipLedger import GraphRelationshipLedger
|
||||
|
||||
|
||||
|
|
@ -12,23 +11,21 @@ from .GraphRelationshipLedger import GraphRelationshipLedger
|
|||
async def record_data_in_legacy_ledger(
|
||||
nodes: List[DataPoint],
|
||||
edges: List[Tuple[UUID, UUID, str, Dict]],
|
||||
user: User,
|
||||
session: AsyncSession,
|
||||
) -> None:
|
||||
relationships = [
|
||||
GraphRelationshipLedger(
|
||||
source_node_id=node.id,
|
||||
destination_node_id=node.id,
|
||||
creator_function="add_nodes",
|
||||
user_id=user.id,
|
||||
node_label=getattr(node, "name", getattr(node, "text", node.id)),
|
||||
creator_function="add_data_points.nodes",
|
||||
)
|
||||
for node in nodes
|
||||
] + [
|
||||
GraphRelationshipLedger(
|
||||
source_node_id=edge[0],
|
||||
destination_node_id=edge[1],
|
||||
creator_function=f"add_edges.{edge[2]}",
|
||||
user_id=user.id,
|
||||
creator_function=f"add_data_points.{edge[2]}",
|
||||
)
|
||||
for edge in edges
|
||||
]
|
||||
|
|
|
|||
|
|
@ -25,10 +25,10 @@ async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id:
|
|||
if len(affected_nodes) == 0:
|
||||
return
|
||||
|
||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
|
||||
|
||||
affected_relationships = await get_data_related_edges(dataset_id, data_id)
|
||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
|
||||
|
||||
non_legacy_nodes = [
|
||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||
|
|
@ -71,10 +71,10 @@ async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id:
|
|||
if len(affected_nodes) == 0:
|
||||
return
|
||||
|
||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
|
||||
|
||||
affected_relationships = await get_global_data_related_edges(data_id)
|
||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
|
||||
|
||||
non_legacy_nodes = [
|
||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||
|
|
|
|||
|
|
@ -25,10 +25,10 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> Non
|
|||
if len(affected_nodes) == 0:
|
||||
return
|
||||
|
||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
|
||||
|
||||
affected_relationships = await get_dataset_related_edges(dataset_id)
|
||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
|
||||
|
||||
non_legacy_nodes = [
|
||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||
|
|
@ -71,10 +71,10 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> Non
|
|||
if len(affected_nodes) == 0:
|
||||
return
|
||||
|
||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
|
||||
|
||||
affected_relationships = await get_global_dataset_related_edges(dataset_id)
|
||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
|
||||
|
||||
non_legacy_nodes = [
|
||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from cognee.modules.engine.models import Entity, EntityType
|
|||
from cognee.modules.data.processing.document_types import TextDocument
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.modules.engine.utils import generate_node_id
|
||||
from cognee.modules.engine.utils.generate_node_name import generate_node_name
|
||||
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
||||
from cognee.modules.graph.utils.deduplicate_nodes_and_edges import deduplicate_nodes_and_edges
|
||||
from cognee.modules.graph.utils.get_graph_from_model import get_graph_from_model
|
||||
|
|
@ -45,6 +46,36 @@ from cognee.tests.utils.filter_overlapping_relationships import filter_overlappi
|
|||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||
|
||||
|
||||
async def assert_relationships_vector_index_present(formatted_relationships, legacy_relationships):
    """Check both formatted (new) and unformatted (legacy) relationships.

    Each non-empty group is forwarded to ``assert_edges_vector_index_present``:
    formatted relationships with ``convert_to_new_format=True`` and legacy
    relationships with ``convert_to_new_format=False``. Empty groups are skipped.
    """
    groups = (
        (formatted_relationships, True),
        (legacy_relationships, False),
    )
    for relationships, is_new_format in groups:
        if not relationships:
            continue
        await assert_edges_vector_index_present(
            relationships, convert_to_new_format=is_new_format
        )
|
||||
|
||||
|
||||
def build_relationships(chunk, document, summary, graph):
    """Collect every relationship expected for a chunk.

    The result starts with the two structural edges (chunk ``is_part_of``
    document, summary ``made_from`` chunk) and is followed by whatever
    ``extract_relationships(chunk, graph)`` returns.
    """
    structural_edges = [
        (chunk.id, document.id, "is_part_of", {"relationship_name": "is_part_of"}),
        (summary.id, chunk.id, "made_from", {"relationship_name": "made_from"}),
    ]
    extracted_edges = extract_relationships(chunk, graph)
    return structural_edges + extracted_edges
|
||||
|
||||
|
||||
def build_contains_relationships(chunk_id, entities, entity_names):
    """Build the contains relationships from a chunk to selected entities.

    Only items that are ``Entity`` instances and whose ``name`` appears in
    ``entity_names`` produce an edge. Each edge is a 4-tuple of
    ``(chunk_id, entity.id, edge_text, {"relationship_name": edge_text})``
    where the edge text comes from ``get_contains_edge_text``.
    """
    contains_edges = []
    for candidate in entities:
        # Skip anything that is not an Entity or was not asked for by name.
        if not isinstance(candidate, Entity) or candidate.name not in entity_names:
            continue
        contains_edges.append(
            (
                chunk_id,
                candidate.id,
                get_contains_edge_text(candidate.name, candidate.description),
                {
                    "relationship_name": get_contains_edge_text(
                        candidate.name, candidate.description
                    )
                },
            )
        )
    return contains_edges
|
||||
|
||||
|
||||
def create_legacy_data_points():
|
||||
document = TextDocument(
|
||||
id=uuid5(NAMESPACE_OID, "text_test.txt"),
|
||||
|
|
@ -56,51 +87,43 @@ def create_legacy_data_points():
|
|||
document_chunk = DocumentChunk(
|
||||
id=uuid5(
|
||||
NAMESPACE_OID,
|
||||
"Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
|
||||
"Apple announced their new vector embeddings visualization tool called Embedding Atlas.",
|
||||
),
|
||||
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
|
||||
text="Apple announced their new vector embeddings visualization tool called Embedding Atlas.",
|
||||
chunk_size=187,
|
||||
chunk_index=0,
|
||||
cut_type="paragraph_end",
|
||||
is_part_of=document,
|
||||
)
|
||||
|
||||
graph_database = EntityType(
|
||||
id=uuid5(NAMESPACE_OID, "graph_database"),
|
||||
name="graph database",
|
||||
description="graph database",
|
||||
company = EntityType(
|
||||
id=generate_node_id("Company"),
|
||||
name=generate_node_name("Company"),
|
||||
description=generate_node_name("Company"),
|
||||
)
|
||||
neptune_analytics_entity = Entity(
|
||||
id=generate_node_id("neptune analytics"),
|
||||
name="neptune analytics",
|
||||
description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
|
||||
is_a=graph_database,
|
||||
apple = Entity(
|
||||
id=generate_node_id("Apple"),
|
||||
name=generate_node_name("Apple"),
|
||||
description="Apple is a company",
|
||||
is_a=company,
|
||||
)
|
||||
neptune_database_entity = Entity(
|
||||
id=generate_node_id("amazon neptune database"),
|
||||
name="amazon neptune database",
|
||||
description="A popular managed graph database that complements Neptune Analytics.",
|
||||
is_a=graph_database,
|
||||
product = EntityType(
|
||||
id=generate_node_id("Product"),
|
||||
name=generate_node_name("Product"),
|
||||
description=generate_node_name("Product"),
|
||||
)
|
||||
|
||||
storage = EntityType(
|
||||
id=generate_node_id("storage"),
|
||||
name="storage",
|
||||
description="storage",
|
||||
)
|
||||
storage_entity = Entity(
|
||||
id=generate_node_id("amazon s3"),
|
||||
name="amazon s3",
|
||||
description="A storage service provided by Amazon Web Services that allows storing graph data.",
|
||||
is_a=storage,
|
||||
embedding_atlas = Entity(
|
||||
id=generate_node_id("Embedding Atlas"),
|
||||
name=generate_node_name("Embedding Atlas"),
|
||||
description="Embedding Atlas",
|
||||
is_a=product,
|
||||
)
|
||||
|
||||
entities = [
|
||||
graph_database,
|
||||
neptune_analytics_entity,
|
||||
neptune_database_entity,
|
||||
storage,
|
||||
storage_entity,
|
||||
company,
|
||||
product,
|
||||
apple,
|
||||
embedding_atlas,
|
||||
]
|
||||
|
||||
document_chunk.contains = entities
|
||||
|
|
@ -143,7 +166,7 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
assert not await vector_engine.has_collection("TextDocument_text")
|
||||
|
||||
# Add legacy data to the system
|
||||
__, legacy_data_points, legacy_relationships = await create_mocked_legacy_data(user)
|
||||
__, all_legacy_data_points, all_legacy_relationships = await create_mocked_legacy_data(user)
|
||||
|
||||
def mock_llm_output(text_input: str, system_prompt: str, response_model):
|
||||
if text_input == "test": # LLM connection test
|
||||
|
|
@ -262,132 +285,153 @@ async def main(mock_create_structured_output: AsyncMock):
|
|||
)
|
||||
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
|
||||
|
||||
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
|
||||
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
|
||||
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
|
||||
johns_entities, maries_entities
|
||||
)
|
||||
all_johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
|
||||
all_maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
|
||||
|
||||
johns_data = [
|
||||
expected_johns_data = [
|
||||
johns_document,
|
||||
johns_chunk,
|
||||
johns_summary,
|
||||
*johns_entities,
|
||||
*all_johns_entities,
|
||||
]
|
||||
maries_data = [
|
||||
expected_maries_data = [
|
||||
maries_document,
|
||||
maries_chunk,
|
||||
maries_summary,
|
||||
*maries_entities,
|
||||
*all_maries_entities,
|
||||
]
|
||||
|
||||
expected_data_points = johns_data + maries_data + overlapping_entities + legacy_data_points
|
||||
expected_data_points = expected_johns_data + expected_maries_data + all_legacy_data_points
|
||||
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(expected_data_points)
|
||||
await assert_nodes_vector_index_present(expected_data_points)
|
||||
|
||||
johns_relationships = extract_relationships(
|
||||
all_johns_relationships = build_relationships(
|
||||
johns_chunk,
|
||||
johns_document,
|
||||
johns_summary,
|
||||
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
|
||||
)
|
||||
maries_relationships = extract_relationships(
|
||||
all_maries_relationships = build_relationships(
|
||||
maries_chunk,
|
||||
maries_document,
|
||||
maries_summary,
|
||||
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
|
||||
)
|
||||
(overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = (
|
||||
filter_overlapping_relationships(
|
||||
johns_relationships, maries_relationships, legacy_relationships
|
||||
)
|
||||
)
|
||||
|
||||
johns_relationships = [
|
||||
(johns_chunk.id, johns_document.id, "is_part_of"),
|
||||
(johns_summary.id, johns_chunk.id, "made_from"),
|
||||
*johns_relationships,
|
||||
]
|
||||
maries_relationships = [
|
||||
(maries_chunk.id, maries_document.id, "is_part_of"),
|
||||
(maries_summary.id, maries_chunk.id, "made_from"),
|
||||
*maries_relationships,
|
||||
]
|
||||
|
||||
expected_relationships = (
|
||||
johns_relationships
|
||||
+ maries_relationships
|
||||
+ overlapping_relationships
|
||||
+ legacy_relationships
|
||||
all_johns_relationships + all_maries_relationships + all_legacy_relationships
|
||||
)
|
||||
|
||||
await assert_graph_edges_present(expected_relationships)
|
||||
|
||||
await assert_edges_vector_index_present(expected_relationships)
|
||||
await assert_relationships_vector_index_present(
|
||||
all_johns_relationships + all_maries_relationships, all_legacy_relationships
|
||||
)
|
||||
|
||||
# Delete John's data
|
||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points)
|
||||
await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points)
|
||||
expected_data_points = [
|
||||
maries_document,
|
||||
maries_chunk,
|
||||
maries_summary,
|
||||
*all_maries_entities,
|
||||
*all_legacy_data_points,
|
||||
]
|
||||
|
||||
await assert_graph_nodes_not_present(johns_data)
|
||||
await assert_nodes_vector_index_not_present(johns_data)
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(expected_data_points)
|
||||
await assert_nodes_vector_index_present(expected_data_points)
|
||||
|
||||
(__, strictly_johns_entities, __, __) = filter_overlapping_entities(
|
||||
all_johns_entities, all_maries_entities, all_legacy_data_points
|
||||
)
|
||||
|
||||
not_expected_data_points = [
|
||||
johns_document,
|
||||
johns_chunk,
|
||||
johns_summary,
|
||||
*strictly_johns_entities,
|
||||
]
|
||||
|
||||
await assert_graph_nodes_not_present(not_expected_data_points)
|
||||
await assert_nodes_vector_index_not_present(not_expected_data_points)
|
||||
|
||||
# Assert relationships presence in the graph, vector collections and nodes table
|
||||
await assert_graph_edges_present(
|
||||
maries_relationships + overlapping_relationships + legacy_relationships
|
||||
await assert_graph_edges_present(all_maries_relationships + all_legacy_relationships)
|
||||
await assert_relationships_vector_index_present(
|
||||
all_maries_relationships, all_legacy_relationships
|
||||
)
|
||||
await assert_edges_vector_index_present(maries_relationships + legacy_relationships)
|
||||
|
||||
await assert_graph_edges_not_present(johns_relationships)
|
||||
(__, strictly_johns_relationships, __, __) = filter_overlapping_relationships(
|
||||
all_johns_relationships,
|
||||
all_maries_relationships,
|
||||
all_legacy_relationships,
|
||||
)
|
||||
|
||||
johns_contains_relationships = [
|
||||
(
|
||||
johns_chunk.id,
|
||||
entity.id,
|
||||
get_contains_edge_text(entity.name, entity.description),
|
||||
{
|
||||
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
||||
},
|
||||
)
|
||||
for entity in johns_entities
|
||||
if isinstance(entity, Entity)
|
||||
]
|
||||
# We check only by relationship name and we need edges that are created by John's data and no other.
|
||||
await assert_edges_vector_index_not_present(johns_contains_relationships)
|
||||
await assert_graph_edges_not_present(strictly_johns_relationships)
|
||||
|
||||
# Check that John's unique contains relationships are not in vector index
|
||||
not_expected_relationships = build_contains_relationships(
|
||||
johns_chunk.id,
|
||||
all_johns_entities,
|
||||
[generate_node_name("John"), generate_node_name("Food for Hungry")],
|
||||
)
|
||||
await assert_edges_vector_index_not_present(not_expected_relationships)
|
||||
|
||||
# Delete Marie's data
|
||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||
|
||||
# Assert data points presence in the graph, vector collections and nodes table
|
||||
await assert_graph_nodes_present(legacy_data_points)
|
||||
await assert_nodes_vector_index_present(legacy_data_points)
|
||||
await assert_graph_nodes_present(all_legacy_data_points)
|
||||
await assert_nodes_vector_index_present(all_legacy_data_points)
|
||||
|
||||
await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities)
|
||||
await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities)
|
||||
|
||||
# Assert relationships presence in the graph, vector collections and nodes table
|
||||
await assert_graph_edges_present(legacy_relationships)
|
||||
await assert_edges_vector_index_present(legacy_relationships)
|
||||
|
||||
await assert_graph_edges_not_present(
|
||||
johns_relationships + maries_relationships + overlapping_relationships
|
||||
(__, strictly_johns_entities, strictly_maries_entities, __) = filter_overlapping_entities(
|
||||
all_johns_entities, all_maries_entities, all_legacy_data_points
|
||||
)
|
||||
|
||||
maries_contains_relationships = [
|
||||
(
|
||||
maries_chunk.id,
|
||||
entity.id,
|
||||
get_contains_edge_text(entity.name, entity.description),
|
||||
{
|
||||
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
||||
},
|
||||
)
|
||||
for entity in maries_entities
|
||||
if isinstance(entity, Entity)
|
||||
not_expected_data_points = [
|
||||
johns_document,
|
||||
johns_chunk,
|
||||
johns_summary,
|
||||
*strictly_johns_entities,
|
||||
maries_document,
|
||||
maries_chunk,
|
||||
maries_summary,
|
||||
*strictly_maries_entities,
|
||||
]
|
||||
# We check only by relationship name and we need edges that are created by legacy data and no other.
|
||||
await assert_edges_vector_index_not_present(maries_contains_relationships)
|
||||
|
||||
await assert_graph_nodes_not_present(not_expected_data_points)
|
||||
await assert_nodes_vector_index_not_present(not_expected_data_points)
|
||||
|
||||
# Assert relationships presence in the graph, vector collections and nodes table
|
||||
await assert_graph_edges_present(all_legacy_relationships)
|
||||
await assert_relationships_vector_index_present([], all_legacy_relationships)
|
||||
|
||||
(__, strictly_johns_relationships, strictly_maries_relationships, __) = (
|
||||
filter_overlapping_relationships(
|
||||
all_maries_relationships,
|
||||
all_johns_relationships,
|
||||
all_legacy_relationships,
|
||||
)
|
||||
)
|
||||
|
||||
await assert_graph_edges_not_present(
|
||||
strictly_johns_relationships + strictly_maries_relationships
|
||||
)
|
||||
|
||||
# Check that John's and Marie's unique contains relationships are not in vector index
|
||||
not_expected_relationships = build_contains_relationships(
|
||||
johns_chunk.id,
|
||||
all_johns_entities,
|
||||
[generate_node_name("John"), generate_node_name("Food for Hungry")],
|
||||
) + build_contains_relationships(
|
||||
maries_chunk.id,
|
||||
all_maries_entities,
|
||||
[generate_node_name("Marie"), generate_node_name("MacOS")],
|
||||
)
|
||||
await assert_edges_vector_index_not_present(not_expected_relationships)
|
||||
|
||||
|
||||
async def create_mocked_legacy_data(user):
|
||||
|
|
@ -423,31 +467,11 @@ async def create_mocked_legacy_data(user):
|
|||
await graph_engine.add_nodes(graph_nodes)
|
||||
await graph_engine.add_edges(graph_edges)
|
||||
|
||||
nodes_by_id = {node.id: node for node in graph_nodes}
|
||||
|
||||
def format_relationship_name(relationship):
|
||||
if relationship[2] == "contains":
|
||||
node = nodes_by_id[relationship[1]]
|
||||
return get_contains_edge_text(node.name, node.description)
|
||||
return relationship[2]
|
||||
|
||||
await index_data_points(graph_nodes)
|
||||
await index_graph_edges(
|
||||
[
|
||||
(
|
||||
edge[0],
|
||||
edge[1],
|
||||
format_relationship_name(edge),
|
||||
{
|
||||
**(edge[3] or {}),
|
||||
"relationship_name": format_relationship_name(edge),
|
||||
},
|
||||
)
|
||||
for edge in graph_edges
|
||||
] # type: ignore
|
||||
)
|
||||
# Legacy relationships should NOT be formatted - index them as-is
|
||||
await index_graph_edges(graph_edges)
|
||||
|
||||
await record_data_in_legacy_ledger(graph_nodes, graph_edges, user)
|
||||
await record_data_in_legacy_ledger(graph_nodes, graph_edges)
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
|
|
|
|||
|
|
@ -435,7 +435,7 @@ async def create_mocked_legacy_data(user):
|
|||
] # type: ignore
|
||||
)
|
||||
|
||||
await record_data_in_legacy_ledger(graph_nodes, graph_edges, user)
|
||||
await record_data_in_legacy_ledger(graph_nodes, graph_edges)
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,9 @@ def format_relationship(relationship: Tuple[UUID, UUID, str, Dict], node: Dict):
|
|||
return {str(generate_edge_id(relationship[2])): relationship[2]}
|
||||
|
||||
|
||||
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
|
||||
async def assert_edges_vector_index_present(
|
||||
relationships: List[Tuple[UUID, UUID, str, Dict]], convert_to_new_format: bool = True
|
||||
):
|
||||
vector_engine = get_vector_engine()
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
|
@ -33,7 +35,11 @@ async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID
|
|||
for relationship in relationships:
|
||||
query_edge_ids = {
|
||||
**query_edge_ids,
|
||||
**format_relationship(relationship, nodes_by_id[str(relationship[1])]),
|
||||
**(
|
||||
format_relationship(relationship, nodes_by_id[str(relationship[1])])
|
||||
if convert_to_new_format
|
||||
else {str(generate_edge_id(relationship[2])): relationship[2]}
|
||||
),
|
||||
}
|
||||
|
||||
vector_items = await vector_engine.retrieve(
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
from uuid import UUID
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
|
||||
|
||||
async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||
async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
|
||||
|
|
@ -15,7 +14,11 @@ async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, s
|
|||
|
||||
for relationship in relationships:
|
||||
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
|
||||
relationship_name = relationship[2]
|
||||
assert relationship_id not in edge_ids, (
|
||||
f"Edge '{relationship_name}' still present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database."
|
||||
)
|
||||
|
||||
if relationship_id in edge_ids:
|
||||
relationship_name = relationship[2]
|
||||
source_node = nodes_by_id[str(relationship[0])]
|
||||
destination_node = nodes_by_id[str(relationship[1])]
|
||||
assert False, (
|
||||
f"Edge '{relationship_name}' still present between '{source_node['name'] if 'node' in source_node else source_node['id']}' and '{destination_node['name'] if 'node' in destination_node else destination_node['id']}' in graph database."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
from uuid import UUID
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
|
||||
|
||||
async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]]):
|
||||
async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
|
||||
|
|
@ -16,6 +15,10 @@ async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]]
|
|||
for relationship in relationships:
|
||||
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
|
||||
relationship_name = relationship[2]
|
||||
source_node = nodes_by_id.get(str(relationship[0]), {})
|
||||
target_node = nodes_by_id.get(str(relationship[1]), {})
|
||||
source_name = source_node.get("name") or source_node.get("text") or str(relationship[0])
|
||||
target_name = target_node.get("name") or target_node.get("text") or str(relationship[1])
|
||||
assert relationship_id in edge_ids, (
|
||||
f"Edge '{relationship_name}' not present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database."
|
||||
f"Edge '{relationship_name}' not present between '{source_name}' and '{target_name}' in graph database."
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue