fix: legacy delete backwards compatibility for neo4j
This commit is contained in:
parent
cb380e51e9
commit
bc4eb9f6ce
10 changed files with 250 additions and 206 deletions
|
|
@ -1,21 +1,20 @@
|
||||||
from uuid import UUID
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from sqlalchemy import and_, or_, select
|
from sqlalchemy import and_, or_, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import with_async_session
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
from cognee.modules.graph.models import Edge, Node
|
from cognee.modules.graph.models import Edge
|
||||||
from .GraphRelationshipLedger import GraphRelationshipLedger
|
from .GraphRelationshipLedger import GraphRelationshipLedger
|
||||||
|
|
||||||
|
|
||||||
@with_async_session
|
@with_async_session
|
||||||
async def has_edges_in_legacy_ledger(edges: List[Edge], user_id: UUID, session: AsyncSession):
|
async def has_edges_in_legacy_ledger(edges: List[Edge], session: AsyncSession):
|
||||||
if len(edges) == 0:
|
if len(edges) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
query = select(GraphRelationshipLedger).where(
|
query = select(GraphRelationshipLedger).where(
|
||||||
and_(
|
and_(
|
||||||
GraphRelationshipLedger.user_id == user_id,
|
GraphRelationshipLedger.node_label.is_(None),
|
||||||
or_(
|
or_(
|
||||||
*[
|
*[
|
||||||
GraphRelationshipLedger.creator_function.ilike(f"%{edge.relationship_name}")
|
GraphRelationshipLedger.creator_function.ilike(f"%{edge.relationship_name}")
|
||||||
|
|
@ -30,20 +29,3 @@ async def has_edges_in_legacy_ledger(edges: List[Edge], user_id: UUID, session:
|
||||||
legacy_edge_names = set([edge.creator_function.split(".")[1] for edge in legacy_edges])
|
legacy_edge_names = set([edge.creator_function.split(".")[1] for edge in legacy_edges])
|
||||||
|
|
||||||
return [edge.relationship_name in legacy_edge_names for edge in edges]
|
return [edge.relationship_name in legacy_edge_names for edge in edges]
|
||||||
|
|
||||||
|
|
||||||
@with_async_session
|
|
||||||
async def get_node_ids(edges: List[Edge], session: AsyncSession):
|
|
||||||
node_slugs = []
|
|
||||||
|
|
||||||
for edge in edges:
|
|
||||||
node_slugs.append(edge.source_node_id)
|
|
||||||
node_slugs.append(edge.destination_node_id)
|
|
||||||
|
|
||||||
query = select(Node).where(Node.slug.in_(node_slugs))
|
|
||||||
|
|
||||||
nodes = (await session.scalars(query)).all()
|
|
||||||
|
|
||||||
node_ids = {node.slug: node.id for node in nodes}
|
|
||||||
|
|
||||||
return node_ids
|
|
||||||
|
|
|
||||||
|
|
@ -1,36 +1,65 @@
|
||||||
from typing import List
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from sqlalchemy import and_, or_, select
|
from typing import List, Tuple
|
||||||
|
from sqlalchemy import and_, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.databases.relational import with_async_session
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
|
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
|
||||||
|
is_backend_access_control_enabled,
|
||||||
|
)
|
||||||
from cognee.modules.graph.models import Node
|
from cognee.modules.graph.models import Node
|
||||||
from .GraphRelationshipLedger import GraphRelationshipLedger
|
from .GraphRelationshipLedger import GraphRelationshipLedger
|
||||||
|
|
||||||
|
|
||||||
@with_async_session
|
@with_async_session
|
||||||
async def has_nodes_in_legacy_ledger(nodes: List[Node], user_id: UUID, session: AsyncSession):
|
async def has_nodes_in_legacy_ledger(nodes: List[Node], session: AsyncSession):
|
||||||
node_ids = [node.slug for node in nodes]
|
node_ids = [node.slug for node in nodes]
|
||||||
|
|
||||||
query = select(
|
query = (
|
||||||
GraphRelationshipLedger.source_node_id,
|
select(
|
||||||
GraphRelationshipLedger.destination_node_id,
|
GraphRelationshipLedger.node_label,
|
||||||
).where(
|
GraphRelationshipLedger.source_node_id,
|
||||||
and_(
|
|
||||||
GraphRelationshipLedger.user_id == user_id,
|
|
||||||
or_(
|
|
||||||
GraphRelationshipLedger.source_node_id.in_(node_ids),
|
|
||||||
GraphRelationshipLedger.destination_node_id.in_(node_ids),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
GraphRelationshipLedger.node_label.is_not(None),
|
||||||
|
GraphRelationshipLedger.deleted_at.is_(None),
|
||||||
|
GraphRelationshipLedger.source_node_id.in_(node_ids),
|
||||||
|
GraphRelationshipLedger.source_node_id
|
||||||
|
== GraphRelationshipLedger.destination_node_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.distinct()
|
||||||
)
|
)
|
||||||
|
|
||||||
legacy_nodes = await session.execute(query)
|
legacy_nodes = (await session.execute(query)).all()
|
||||||
entries = legacy_nodes.all()
|
|
||||||
|
|
||||||
found_ids = set()
|
if len(legacy_nodes) == 0:
|
||||||
for entry in entries:
|
return [False for __ in nodes]
|
||||||
found_ids.add(entry.source_node_id)
|
|
||||||
found_ids.add(entry.destination_node_id)
|
|
||||||
|
|
||||||
return [node_id in found_ids for node_id in node_ids]
|
if is_backend_access_control_enabled():
|
||||||
|
confirmed_nodes = await confirm_nodes_in_graph(legacy_nodes)
|
||||||
|
return [node_id in confirmed_nodes for node_id in node_ids]
|
||||||
|
else:
|
||||||
|
found_ids = set()
|
||||||
|
for __, node_id in legacy_nodes:
|
||||||
|
found_ids.add(node_id)
|
||||||
|
|
||||||
|
return [node_id in found_ids for node_id in node_ids]
|
||||||
|
|
||||||
|
|
||||||
|
async def confirm_nodes_in_graph(
|
||||||
|
legacy_nodes: List[Tuple[str, UUID]],
|
||||||
|
):
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
|
graph_nodes = await graph_engine.get_nodes([str(node[1]) for node in legacy_nodes])
|
||||||
|
graph_nodes_by_id = {node["id"]: node for node in graph_nodes}
|
||||||
|
|
||||||
|
confirmed_nodes = set()
|
||||||
|
for __, node_id in legacy_nodes:
|
||||||
|
if str(node_id) in graph_nodes_by_id:
|
||||||
|
confirmed_nodes.add(node_id)
|
||||||
|
|
||||||
|
return confirmed_nodes
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import with_async_session
|
from cognee.infrastructure.databases.relational import with_async_session
|
||||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||||
from cognee.modules.users.models.User import User
|
|
||||||
from .GraphRelationshipLedger import GraphRelationshipLedger
|
from .GraphRelationshipLedger import GraphRelationshipLedger
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -12,23 +11,21 @@ from .GraphRelationshipLedger import GraphRelationshipLedger
|
||||||
async def record_data_in_legacy_ledger(
|
async def record_data_in_legacy_ledger(
|
||||||
nodes: List[DataPoint],
|
nodes: List[DataPoint],
|
||||||
edges: List[Tuple[UUID, UUID, str, Dict]],
|
edges: List[Tuple[UUID, UUID, str, Dict]],
|
||||||
user: User,
|
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
) -> None:
|
) -> None:
|
||||||
relationships = [
|
relationships = [
|
||||||
GraphRelationshipLedger(
|
GraphRelationshipLedger(
|
||||||
source_node_id=node.id,
|
source_node_id=node.id,
|
||||||
destination_node_id=node.id,
|
destination_node_id=node.id,
|
||||||
creator_function="add_nodes",
|
node_label=getattr(node, "name", getattr(node, "text", node.id)),
|
||||||
user_id=user.id,
|
creator_function="add_data_points.nodes",
|
||||||
)
|
)
|
||||||
for node in nodes
|
for node in nodes
|
||||||
] + [
|
] + [
|
||||||
GraphRelationshipLedger(
|
GraphRelationshipLedger(
|
||||||
source_node_id=edge[0],
|
source_node_id=edge[0],
|
||||||
destination_node_id=edge[1],
|
destination_node_id=edge[1],
|
||||||
creator_function=f"add_edges.{edge[2]}",
|
creator_function=f"add_data_points.{edge[2]}",
|
||||||
user_id=user.id,
|
|
||||||
)
|
)
|
||||||
for edge in edges
|
for edge in edges
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -25,10 +25,10 @@ async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id:
|
||||||
if len(affected_nodes) == 0:
|
if len(affected_nodes) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
|
||||||
|
|
||||||
affected_relationships = await get_data_related_edges(dataset_id, data_id)
|
affected_relationships = await get_data_related_edges(dataset_id, data_id)
|
||||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
|
||||||
|
|
||||||
non_legacy_nodes = [
|
non_legacy_nodes = [
|
||||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||||
|
|
@ -71,10 +71,10 @@ async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id:
|
||||||
if len(affected_nodes) == 0:
|
if len(affected_nodes) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
|
||||||
|
|
||||||
affected_relationships = await get_global_data_related_edges(data_id)
|
affected_relationships = await get_global_data_related_edges(data_id)
|
||||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
|
||||||
|
|
||||||
non_legacy_nodes = [
|
non_legacy_nodes = [
|
||||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||||
|
|
|
||||||
|
|
@ -25,10 +25,10 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> Non
|
||||||
if len(affected_nodes) == 0:
|
if len(affected_nodes) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
|
||||||
|
|
||||||
affected_relationships = await get_dataset_related_edges(dataset_id)
|
affected_relationships = await get_dataset_related_edges(dataset_id)
|
||||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
|
||||||
|
|
||||||
non_legacy_nodes = [
|
non_legacy_nodes = [
|
||||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||||
|
|
@ -71,10 +71,10 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> Non
|
||||||
if len(affected_nodes) == 0:
|
if len(affected_nodes) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes)
|
||||||
|
|
||||||
affected_relationships = await get_global_dataset_related_edges(dataset_id)
|
affected_relationships = await get_global_dataset_related_edges(dataset_id)
|
||||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships)
|
||||||
|
|
||||||
non_legacy_nodes = [
|
non_legacy_nodes = [
|
||||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from cognee.modules.engine.models import Entity, EntityType
|
||||||
from cognee.modules.data.processing.document_types import TextDocument
|
from cognee.modules.data.processing.document_types import TextDocument
|
||||||
from cognee.modules.engine.operations.setup import setup
|
from cognee.modules.engine.operations.setup import setup
|
||||||
from cognee.modules.engine.utils import generate_node_id
|
from cognee.modules.engine.utils import generate_node_id
|
||||||
|
from cognee.modules.engine.utils.generate_node_name import generate_node_name
|
||||||
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger
|
||||||
from cognee.modules.graph.utils.deduplicate_nodes_and_edges import deduplicate_nodes_and_edges
|
from cognee.modules.graph.utils.deduplicate_nodes_and_edges import deduplicate_nodes_and_edges
|
||||||
from cognee.modules.graph.utils.get_graph_from_model import get_graph_from_model
|
from cognee.modules.graph.utils.get_graph_from_model import get_graph_from_model
|
||||||
|
|
@ -45,6 +46,36 @@ from cognee.tests.utils.filter_overlapping_relationships import filter_overlappi
|
||||||
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_relationships_vector_index_present(formatted_relationships, legacy_relationships):
|
||||||
|
"""Helper to check both formatted (new) and unformatted (legacy) relationships."""
|
||||||
|
if formatted_relationships:
|
||||||
|
await assert_edges_vector_index_present(formatted_relationships, convert_to_new_format=True)
|
||||||
|
if legacy_relationships:
|
||||||
|
await assert_edges_vector_index_present(legacy_relationships, convert_to_new_format=False)
|
||||||
|
|
||||||
|
|
||||||
|
def build_relationships(chunk, document, summary, graph):
|
||||||
|
"""Build all relationships for a chunk including structural and extracted ones."""
|
||||||
|
return [
|
||||||
|
(chunk.id, document.id, "is_part_of", {"relationship_name": "is_part_of"}),
|
||||||
|
(summary.id, chunk.id, "made_from", {"relationship_name": "made_from"}),
|
||||||
|
] + extract_relationships(chunk, graph)
|
||||||
|
|
||||||
|
|
||||||
|
def build_contains_relationships(chunk_id, entities, entity_names):
|
||||||
|
"""Build contains relationships for specific entities."""
|
||||||
|
return [
|
||||||
|
(
|
||||||
|
chunk_id,
|
||||||
|
entity.id,
|
||||||
|
get_contains_edge_text(entity.name, entity.description),
|
||||||
|
{"relationship_name": get_contains_edge_text(entity.name, entity.description)},
|
||||||
|
)
|
||||||
|
for entity in entities
|
||||||
|
if isinstance(entity, Entity) and entity.name in entity_names
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def create_legacy_data_points():
|
def create_legacy_data_points():
|
||||||
document = TextDocument(
|
document = TextDocument(
|
||||||
id=uuid5(NAMESPACE_OID, "text_test.txt"),
|
id=uuid5(NAMESPACE_OID, "text_test.txt"),
|
||||||
|
|
@ -56,51 +87,43 @@ def create_legacy_data_points():
|
||||||
document_chunk = DocumentChunk(
|
document_chunk = DocumentChunk(
|
||||||
id=uuid5(
|
id=uuid5(
|
||||||
NAMESPACE_OID,
|
NAMESPACE_OID,
|
||||||
"Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
|
"Apple announced their new vector embeddings visualization tool called Embedding Atlas.",
|
||||||
),
|
),
|
||||||
text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
|
text="Apple announced their new vector embeddings visualization tool called Embedding Atlas.",
|
||||||
chunk_size=187,
|
chunk_size=187,
|
||||||
chunk_index=0,
|
chunk_index=0,
|
||||||
cut_type="paragraph_end",
|
cut_type="paragraph_end",
|
||||||
is_part_of=document,
|
is_part_of=document,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_database = EntityType(
|
company = EntityType(
|
||||||
id=uuid5(NAMESPACE_OID, "graph_database"),
|
id=generate_node_id("Company"),
|
||||||
name="graph database",
|
name=generate_node_name("Company"),
|
||||||
description="graph database",
|
description=generate_node_name("Company"),
|
||||||
)
|
)
|
||||||
neptune_analytics_entity = Entity(
|
apple = Entity(
|
||||||
id=generate_node_id("neptune analytics"),
|
id=generate_node_id("Apple"),
|
||||||
name="neptune analytics",
|
name=generate_node_name("Apple"),
|
||||||
description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
|
description="Apple is a company",
|
||||||
is_a=graph_database,
|
is_a=company,
|
||||||
)
|
)
|
||||||
neptune_database_entity = Entity(
|
product = EntityType(
|
||||||
id=generate_node_id("amazon neptune database"),
|
id=generate_node_id("Product"),
|
||||||
name="amazon neptune database",
|
name=generate_node_name("Product"),
|
||||||
description="A popular managed graph database that complements Neptune Analytics.",
|
description=generate_node_name("Product"),
|
||||||
is_a=graph_database,
|
|
||||||
)
|
)
|
||||||
|
embedding_atlas = Entity(
|
||||||
storage = EntityType(
|
id=generate_node_id("Embedding Atlas"),
|
||||||
id=generate_node_id("storage"),
|
name=generate_node_name("Embedding Atlas"),
|
||||||
name="storage",
|
description="Embedding Atlas",
|
||||||
description="storage",
|
is_a=product,
|
||||||
)
|
|
||||||
storage_entity = Entity(
|
|
||||||
id=generate_node_id("amazon s3"),
|
|
||||||
name="amazon s3",
|
|
||||||
description="A storage service provided by Amazon Web Services that allows storing graph data.",
|
|
||||||
is_a=storage,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
entities = [
|
entities = [
|
||||||
graph_database,
|
company,
|
||||||
neptune_analytics_entity,
|
product,
|
||||||
neptune_database_entity,
|
apple,
|
||||||
storage,
|
embedding_atlas,
|
||||||
storage_entity,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
document_chunk.contains = entities
|
document_chunk.contains = entities
|
||||||
|
|
@ -143,7 +166,7 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
assert not await vector_engine.has_collection("TextDocument_text")
|
assert not await vector_engine.has_collection("TextDocument_text")
|
||||||
|
|
||||||
# Add legacy data to the system
|
# Add legacy data to the system
|
||||||
__, legacy_data_points, legacy_relationships = await create_mocked_legacy_data(user)
|
__, all_legacy_data_points, all_legacy_relationships = await create_mocked_legacy_data(user)
|
||||||
|
|
||||||
def mock_llm_output(text_input: str, system_prompt: str, response_model):
|
def mock_llm_output(text_input: str, system_prompt: str, response_model):
|
||||||
if text_input == "test": # LLM connection test
|
if text_input == "test": # LLM connection test
|
||||||
|
|
@ -262,132 +285,153 @@ async def main(mock_create_structured_output: AsyncMock):
|
||||||
)
|
)
|
||||||
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
|
maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore
|
||||||
|
|
||||||
johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
|
all_johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore
|
||||||
maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
|
all_maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore
|
||||||
(overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities(
|
|
||||||
johns_entities, maries_entities
|
|
||||||
)
|
|
||||||
|
|
||||||
johns_data = [
|
expected_johns_data = [
|
||||||
johns_document,
|
johns_document,
|
||||||
johns_chunk,
|
johns_chunk,
|
||||||
johns_summary,
|
johns_summary,
|
||||||
*johns_entities,
|
*all_johns_entities,
|
||||||
]
|
]
|
||||||
maries_data = [
|
expected_maries_data = [
|
||||||
maries_document,
|
maries_document,
|
||||||
maries_chunk,
|
maries_chunk,
|
||||||
maries_summary,
|
maries_summary,
|
||||||
*maries_entities,
|
*all_maries_entities,
|
||||||
]
|
]
|
||||||
|
|
||||||
expected_data_points = johns_data + maries_data + overlapping_entities + legacy_data_points
|
expected_data_points = expected_johns_data + expected_maries_data + all_legacy_data_points
|
||||||
|
|
||||||
# Assert data points presence in the graph, vector collections and nodes table
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
await assert_graph_nodes_present(expected_data_points)
|
await assert_graph_nodes_present(expected_data_points)
|
||||||
await assert_nodes_vector_index_present(expected_data_points)
|
await assert_nodes_vector_index_present(expected_data_points)
|
||||||
|
|
||||||
johns_relationships = extract_relationships(
|
all_johns_relationships = build_relationships(
|
||||||
johns_chunk,
|
johns_chunk,
|
||||||
|
johns_document,
|
||||||
|
johns_summary,
|
||||||
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
|
mock_llm_output("John", "", KnowledgeGraph), # type: ignore
|
||||||
)
|
)
|
||||||
maries_relationships = extract_relationships(
|
all_maries_relationships = build_relationships(
|
||||||
maries_chunk,
|
maries_chunk,
|
||||||
|
maries_document,
|
||||||
|
maries_summary,
|
||||||
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
|
mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore
|
||||||
)
|
)
|
||||||
(overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = (
|
|
||||||
filter_overlapping_relationships(
|
|
||||||
johns_relationships, maries_relationships, legacy_relationships
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
johns_relationships = [
|
|
||||||
(johns_chunk.id, johns_document.id, "is_part_of"),
|
|
||||||
(johns_summary.id, johns_chunk.id, "made_from"),
|
|
||||||
*johns_relationships,
|
|
||||||
]
|
|
||||||
maries_relationships = [
|
|
||||||
(maries_chunk.id, maries_document.id, "is_part_of"),
|
|
||||||
(maries_summary.id, maries_chunk.id, "made_from"),
|
|
||||||
*maries_relationships,
|
|
||||||
]
|
|
||||||
|
|
||||||
expected_relationships = (
|
expected_relationships = (
|
||||||
johns_relationships
|
all_johns_relationships + all_maries_relationships + all_legacy_relationships
|
||||||
+ maries_relationships
|
|
||||||
+ overlapping_relationships
|
|
||||||
+ legacy_relationships
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await assert_graph_edges_present(expected_relationships)
|
await assert_graph_edges_present(expected_relationships)
|
||||||
|
await assert_relationships_vector_index_present(
|
||||||
await assert_edges_vector_index_present(expected_relationships)
|
all_johns_relationships + all_maries_relationships, all_legacy_relationships
|
||||||
|
)
|
||||||
|
|
||||||
# Delete John's data
|
# Delete John's data
|
||||||
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore
|
||||||
|
|
||||||
# Assert data points presence in the graph, vector collections and nodes table
|
expected_data_points = [
|
||||||
await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points)
|
maries_document,
|
||||||
await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points)
|
maries_chunk,
|
||||||
|
maries_summary,
|
||||||
|
*all_maries_entities,
|
||||||
|
*all_legacy_data_points,
|
||||||
|
]
|
||||||
|
|
||||||
await assert_graph_nodes_not_present(johns_data)
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
await assert_nodes_vector_index_not_present(johns_data)
|
await assert_graph_nodes_present(expected_data_points)
|
||||||
|
await assert_nodes_vector_index_present(expected_data_points)
|
||||||
|
|
||||||
|
(__, strictly_johns_entities, __, __) = filter_overlapping_entities(
|
||||||
|
all_johns_entities, all_maries_entities, all_legacy_data_points
|
||||||
|
)
|
||||||
|
|
||||||
|
not_expected_data_points = [
|
||||||
|
johns_document,
|
||||||
|
johns_chunk,
|
||||||
|
johns_summary,
|
||||||
|
*strictly_johns_entities,
|
||||||
|
]
|
||||||
|
|
||||||
|
await assert_graph_nodes_not_present(not_expected_data_points)
|
||||||
|
await assert_nodes_vector_index_not_present(not_expected_data_points)
|
||||||
|
|
||||||
# Assert relationships presence in the graph, vector collections and nodes table
|
# Assert relationships presence in the graph, vector collections and nodes table
|
||||||
await assert_graph_edges_present(
|
await assert_graph_edges_present(all_maries_relationships + all_legacy_relationships)
|
||||||
maries_relationships + overlapping_relationships + legacy_relationships
|
await assert_relationships_vector_index_present(
|
||||||
|
all_maries_relationships, all_legacy_relationships
|
||||||
)
|
)
|
||||||
await assert_edges_vector_index_present(maries_relationships + legacy_relationships)
|
|
||||||
|
|
||||||
await assert_graph_edges_not_present(johns_relationships)
|
(__, strictly_johns_relationships, __, __) = filter_overlapping_relationships(
|
||||||
|
all_johns_relationships,
|
||||||
|
all_maries_relationships,
|
||||||
|
all_legacy_relationships,
|
||||||
|
)
|
||||||
|
|
||||||
johns_contains_relationships = [
|
await assert_graph_edges_not_present(strictly_johns_relationships)
|
||||||
(
|
|
||||||
johns_chunk.id,
|
# Check that John's unique contains relationships are not in vector index
|
||||||
entity.id,
|
not_expected_relationships = build_contains_relationships(
|
||||||
get_contains_edge_text(entity.name, entity.description),
|
johns_chunk.id,
|
||||||
{
|
all_johns_entities,
|
||||||
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
[generate_node_name("John"), generate_node_name("Food for Hungry")],
|
||||||
},
|
)
|
||||||
)
|
await assert_edges_vector_index_not_present(not_expected_relationships)
|
||||||
for entity in johns_entities
|
|
||||||
if isinstance(entity, Entity)
|
|
||||||
]
|
|
||||||
# We check only by relationship name and we need edges that are created by John's data and no other.
|
|
||||||
await assert_edges_vector_index_not_present(johns_contains_relationships)
|
|
||||||
|
|
||||||
# Delete Marie's data
|
# Delete Marie's data
|
||||||
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore
|
||||||
|
|
||||||
# Assert data points presence in the graph, vector collections and nodes table
|
# Assert data points presence in the graph, vector collections and nodes table
|
||||||
await assert_graph_nodes_present(legacy_data_points)
|
await assert_graph_nodes_present(all_legacy_data_points)
|
||||||
await assert_nodes_vector_index_present(legacy_data_points)
|
await assert_nodes_vector_index_present(all_legacy_data_points)
|
||||||
|
|
||||||
await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities)
|
(__, strictly_johns_entities, strictly_maries_entities, __) = filter_overlapping_entities(
|
||||||
await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities)
|
all_johns_entities, all_maries_entities, all_legacy_data_points
|
||||||
|
|
||||||
# Assert relationships presence in the graph, vector collections and nodes table
|
|
||||||
await assert_graph_edges_present(legacy_relationships)
|
|
||||||
await assert_edges_vector_index_present(legacy_relationships)
|
|
||||||
|
|
||||||
await assert_graph_edges_not_present(
|
|
||||||
johns_relationships + maries_relationships + overlapping_relationships
|
|
||||||
)
|
)
|
||||||
|
|
||||||
maries_contains_relationships = [
|
not_expected_data_points = [
|
||||||
(
|
johns_document,
|
||||||
maries_chunk.id,
|
johns_chunk,
|
||||||
entity.id,
|
johns_summary,
|
||||||
get_contains_edge_text(entity.name, entity.description),
|
*strictly_johns_entities,
|
||||||
{
|
maries_document,
|
||||||
"relationship_name": get_contains_edge_text(entity.name, entity.description),
|
maries_chunk,
|
||||||
},
|
maries_summary,
|
||||||
)
|
*strictly_maries_entities,
|
||||||
for entity in maries_entities
|
|
||||||
if isinstance(entity, Entity)
|
|
||||||
]
|
]
|
||||||
# We check only by relationship name and we need edges that are created by legacy data and no other.
|
|
||||||
await assert_edges_vector_index_not_present(maries_contains_relationships)
|
await assert_graph_nodes_not_present(not_expected_data_points)
|
||||||
|
await assert_nodes_vector_index_not_present(not_expected_data_points)
|
||||||
|
|
||||||
|
# Assert relationships presence in the graph, vector collections and nodes table
|
||||||
|
await assert_graph_edges_present(all_legacy_relationships)
|
||||||
|
await assert_relationships_vector_index_present([], all_legacy_relationships)
|
||||||
|
|
||||||
|
(__, strictly_johns_relationships, strictly_maries_relationships, __) = (
|
||||||
|
filter_overlapping_relationships(
|
||||||
|
all_maries_relationships,
|
||||||
|
all_johns_relationships,
|
||||||
|
all_legacy_relationships,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await assert_graph_edges_not_present(
|
||||||
|
strictly_johns_relationships + strictly_maries_relationships
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that John's and Marie's unique contains relationships are not in vector index
|
||||||
|
not_expected_relationships = build_contains_relationships(
|
||||||
|
johns_chunk.id,
|
||||||
|
all_johns_entities,
|
||||||
|
[generate_node_name("John"), generate_node_name("Food for Hungry")],
|
||||||
|
) + build_contains_relationships(
|
||||||
|
maries_chunk.id,
|
||||||
|
all_maries_entities,
|
||||||
|
[generate_node_name("Marie"), generate_node_name("MacOS")],
|
||||||
|
)
|
||||||
|
await assert_edges_vector_index_not_present(not_expected_relationships)
|
||||||
|
|
||||||
|
|
||||||
async def create_mocked_legacy_data(user):
|
async def create_mocked_legacy_data(user):
|
||||||
|
|
@ -423,31 +467,11 @@ async def create_mocked_legacy_data(user):
|
||||||
await graph_engine.add_nodes(graph_nodes)
|
await graph_engine.add_nodes(graph_nodes)
|
||||||
await graph_engine.add_edges(graph_edges)
|
await graph_engine.add_edges(graph_edges)
|
||||||
|
|
||||||
nodes_by_id = {node.id: node for node in graph_nodes}
|
|
||||||
|
|
||||||
def format_relationship_name(relationship):
|
|
||||||
if relationship[2] == "contains":
|
|
||||||
node = nodes_by_id[relationship[1]]
|
|
||||||
return get_contains_edge_text(node.name, node.description)
|
|
||||||
return relationship[2]
|
|
||||||
|
|
||||||
await index_data_points(graph_nodes)
|
await index_data_points(graph_nodes)
|
||||||
await index_graph_edges(
|
# Legacy relationships should NOT be formatted - index them as-is
|
||||||
[
|
await index_graph_edges(graph_edges)
|
||||||
(
|
|
||||||
edge[0],
|
|
||||||
edge[1],
|
|
||||||
format_relationship_name(edge),
|
|
||||||
{
|
|
||||||
**(edge[3] or {}),
|
|
||||||
"relationship_name": format_relationship_name(edge),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
for edge in graph_edges
|
|
||||||
] # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
await record_data_in_legacy_ledger(graph_nodes, graph_edges, user)
|
await record_data_in_legacy_ledger(graph_nodes, graph_edges)
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -435,7 +435,7 @@ async def create_mocked_legacy_data(user):
|
||||||
] # type: ignore
|
] # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
await record_data_in_legacy_ledger(graph_nodes, graph_edges, user)
|
await record_data_in_legacy_ledger(graph_nodes, graph_edges)
|
||||||
|
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,9 @@ def format_relationship(relationship: Tuple[UUID, UUID, str, Dict], node: Dict):
|
||||||
return {str(generate_edge_id(relationship[2])): relationship[2]}
|
return {str(generate_edge_id(relationship[2])): relationship[2]}
|
||||||
|
|
||||||
|
|
||||||
async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
|
async def assert_edges_vector_index_present(
|
||||||
|
relationships: List[Tuple[UUID, UUID, str, Dict]], convert_to_new_format: bool = True
|
||||||
|
):
|
||||||
vector_engine = get_vector_engine()
|
vector_engine = get_vector_engine()
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
|
|
@ -33,7 +35,11 @@ async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID
|
||||||
for relationship in relationships:
|
for relationship in relationships:
|
||||||
query_edge_ids = {
|
query_edge_ids = {
|
||||||
**query_edge_ids,
|
**query_edge_ids,
|
||||||
**format_relationship(relationship, nodes_by_id[str(relationship[1])]),
|
**(
|
||||||
|
format_relationship(relationship, nodes_by_id[str(relationship[1])])
|
||||||
|
if convert_to_new_format
|
||||||
|
else {str(generate_edge_id(relationship[2])): relationship[2]}
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
vector_items = await vector_engine.retrieve(
|
vector_items = await vector_engine.retrieve(
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
|
||||||
|
|
||||||
|
|
||||||
async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str]]):
|
async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
nodes, edges = await graph_engine.get_graph_data()
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
|
|
@ -15,7 +14,11 @@ async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, s
|
||||||
|
|
||||||
for relationship in relationships:
|
for relationship in relationships:
|
||||||
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
|
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
|
||||||
relationship_name = relationship[2]
|
|
||||||
assert relationship_id not in edge_ids, (
|
if relationship_id in edge_ids:
|
||||||
f"Edge '{relationship_name}' still present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database."
|
relationship_name = relationship[2]
|
||||||
)
|
source_node = nodes_by_id[str(relationship[0])]
|
||||||
|
destination_node = nodes_by_id[str(relationship[1])]
|
||||||
|
assert False, (
|
||||||
|
f"Edge '{relationship_name}' still present between '{source_node['name'] if 'node' in source_node else source_node['id']}' and '{destination_node['name'] if 'node' in destination_node else destination_node['id']}' in graph database."
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
|
||||||
|
|
||||||
|
|
||||||
async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]]):
|
async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str, Dict]]):
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
nodes, edges = await graph_engine.get_graph_data()
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
|
||||||
|
|
@ -16,6 +15,10 @@ async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]]
|
||||||
for relationship in relationships:
|
for relationship in relationships:
|
||||||
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
|
relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}"
|
||||||
relationship_name = relationship[2]
|
relationship_name = relationship[2]
|
||||||
|
source_node = nodes_by_id.get(str(relationship[0]), {})
|
||||||
|
target_node = nodes_by_id.get(str(relationship[1]), {})
|
||||||
|
source_name = source_node.get("name") or source_node.get("text") or str(relationship[0])
|
||||||
|
target_name = target_node.get("name") or target_node.get("text") or str(relationship[1])
|
||||||
assert relationship_id in edge_ids, (
|
assert relationship_id in edge_ids, (
|
||||||
f"Edge '{relationship_name}' not present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database."
|
f"Edge '{relationship_name}' not present between '{source_name}' and '{target_name}' in graph database."
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue