diff --git a/cognee/modules/graph/legacy/has_edges_in_legacy_ledger.py b/cognee/modules/graph/legacy/has_edges_in_legacy_ledger.py index 7857a8e25..ec2f094d5 100644 --- a/cognee/modules/graph/legacy/has_edges_in_legacy_ledger.py +++ b/cognee/modules/graph/legacy/has_edges_in_legacy_ledger.py @@ -1,21 +1,20 @@ -from uuid import UUID from typing import List from sqlalchemy import and_, or_, select from sqlalchemy.ext.asyncio import AsyncSession from cognee.infrastructure.databases.relational import with_async_session -from cognee.modules.graph.models import Edge, Node +from cognee.modules.graph.models import Edge from .GraphRelationshipLedger import GraphRelationshipLedger @with_async_session -async def has_edges_in_legacy_ledger(edges: List[Edge], user_id: UUID, session: AsyncSession): +async def has_edges_in_legacy_ledger(edges: List[Edge], session: AsyncSession): if len(edges) == 0: return [] query = select(GraphRelationshipLedger).where( and_( - GraphRelationshipLedger.user_id == user_id, + GraphRelationshipLedger.node_label.is_(None), or_( *[ GraphRelationshipLedger.creator_function.ilike(f"%{edge.relationship_name}") @@ -30,20 +29,3 @@ async def has_edges_in_legacy_ledger(edges: List[Edge], user_id: UUID, session: legacy_edge_names = set([edge.creator_function.split(".")[1] for edge in legacy_edges]) return [edge.relationship_name in legacy_edge_names for edge in edges] - - -@with_async_session -async def get_node_ids(edges: List[Edge], session: AsyncSession): - node_slugs = [] - - for edge in edges: - node_slugs.append(edge.source_node_id) - node_slugs.append(edge.destination_node_id) - - query = select(Node).where(Node.slug.in_(node_slugs)) - - nodes = (await session.scalars(query)).all() - - node_ids = {node.slug: node.id for node in nodes} - - return node_ids diff --git a/cognee/modules/graph/legacy/has_nodes_in_legacy_ledger.py b/cognee/modules/graph/legacy/has_nodes_in_legacy_ledger.py index a84b57d75..d18e582b5 100644 --- a/cognee/modules/graph/legacy/has_nodes_in_legacy_ledger.py +++ b/cognee/modules/graph/legacy/has_nodes_in_legacy_ledger.py @@ -1,36 +1,65 @@ -from typing import List from uuid import UUID -from sqlalchemy import and_, or_, select +from typing import List, Tuple +from sqlalchemy import and_, select from sqlalchemy.ext.asyncio import AsyncSession +from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.relational import with_async_session +from cognee.infrastructure.environment.config.is_backend_access_control_enabled import ( + is_backend_access_control_enabled, +) from cognee.modules.graph.models import Node from .GraphRelationshipLedger import GraphRelationshipLedger @with_async_session -async def has_nodes_in_legacy_ledger(nodes: List[Node], user_id: UUID, session: AsyncSession): +async def has_nodes_in_legacy_ledger(nodes: List[Node], session: AsyncSession): node_ids = [node.slug for node in nodes] - query = select( - GraphRelationshipLedger.source_node_id, - GraphRelationshipLedger.destination_node_id, - ).where( - and_( - GraphRelationshipLedger.user_id == user_id, - or_( - GraphRelationshipLedger.source_node_id.in_(node_ids), - GraphRelationshipLedger.destination_node_id.in_(node_ids), - ), + query = ( + select( + GraphRelationshipLedger.node_label, + GraphRelationshipLedger.source_node_id, ) + .where( + and_( + GraphRelationshipLedger.node_label.is_not(None), + GraphRelationshipLedger.deleted_at.is_(None), + GraphRelationshipLedger.source_node_id.in_(node_ids), + GraphRelationshipLedger.source_node_id + == GraphRelationshipLedger.destination_node_id, + ) + ) + .distinct() ) - legacy_nodes = await session.execute(query) - entries = legacy_nodes.all() + legacy_nodes = (await session.execute(query)).all() - found_ids = set() - for entry in entries: - found_ids.add(entry.source_node_id) - found_ids.add(entry.destination_node_id) + if len(legacy_nodes) == 0: + return [False for __ in nodes] - return [node_id in found_ids for node_id in node_ids] + if is_backend_access_control_enabled(): + confirmed_nodes = await confirm_nodes_in_graph(legacy_nodes) + return [node_id in confirmed_nodes for node_id in node_ids] + else: + found_ids = set() + for __, node_id in legacy_nodes: + found_ids.add(node_id) + + return [node_id in found_ids for node_id in node_ids] + + +async def confirm_nodes_in_graph( + legacy_nodes: List[Tuple[str, UUID]], +): + graph_engine = await get_graph_engine() + + graph_nodes = await graph_engine.get_nodes([str(node[1]) for node in legacy_nodes]) + graph_nodes_by_id = {node["id"]: node for node in graph_nodes} + + confirmed_nodes = set() + for __, node_id in legacy_nodes: + if str(node_id) in graph_nodes_by_id: + confirmed_nodes.add(node_id) + + return confirmed_nodes diff --git a/cognee/modules/graph/legacy/record_data_in_legacy_ledger.py b/cognee/modules/graph/legacy/record_data_in_legacy_ledger.py index d3efbb99b..87c42290d 100644 --- a/cognee/modules/graph/legacy/record_data_in_legacy_ledger.py +++ b/cognee/modules/graph/legacy/record_data_in_legacy_ledger.py @@ -4,7 +4,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from cognee.infrastructure.databases.relational import with_async_session from cognee.infrastructure.engine.models.DataPoint import DataPoint -from cognee.modules.users.models.User import User from .GraphRelationshipLedger import GraphRelationshipLedger @@ -12,23 +11,21 @@ from .GraphRelationshipLedger import GraphRelationshipLedger async def record_data_in_legacy_ledger( nodes: List[DataPoint], edges: List[Tuple[UUID, UUID, str, Dict]], - user: User, session: AsyncSession, ) -> None: relationships = [ GraphRelationshipLedger( source_node_id=node.id, destination_node_id=node.id, - creator_function="add_nodes", - user_id=user.id, + node_label=getattr(node, "name", getattr(node, "text", node.id)), + creator_function="add_data_points.nodes", ) for node in nodes ] + [ GraphRelationshipLedger( source_node_id=edge[0], destination_node_id=edge[1], - creator_function=f"add_edges.{edge[2]}", - user_id=user.id, + creator_function=f"add_data_points.{edge[2]}", ) for edge in edges ] diff --git a/cognee/modules/graph/methods/delete_data_nodes_and_edges.py b/cognee/modules/graph/methods/delete_data_nodes_and_edges.py index b2501e43b..274deb658 100644 --- a/cognee/modules/graph/methods/delete_data_nodes_and_edges.py +++ b/cognee/modules/graph/methods/delete_data_nodes_and_edges.py @@ -25,10 +25,10 @@ async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id: if len(affected_nodes) == 0: return - is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id) + is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes) affected_relationships = await get_data_related_edges(dataset_id, data_id) - is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id) + is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships) non_legacy_nodes = [ node for index, node in enumerate(affected_nodes) if not is_legacy_node[index] @@ -71,10 +71,10 @@ async def delete_data_nodes_and_edges(dataset_id: UUID, data_id: UUID, user_id: if len(affected_nodes) == 0: return - is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id) + is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes) affected_relationships = await get_global_data_related_edges(data_id) - is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id) + is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships) non_legacy_nodes = [ node for index, node in enumerate(affected_nodes) if not is_legacy_node[index] diff --git a/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py b/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py index 0ab427da8..0ccca6ab7 100644 --- a/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py +++ b/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py @@ -25,10 +25,10 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> Non if len(affected_nodes) == 0: return - is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id) + is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes) affected_relationships = await get_dataset_related_edges(dataset_id) - is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id) + is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships) non_legacy_nodes = [ node for index, node in enumerate(affected_nodes) if not is_legacy_node[index] @@ -71,10 +71,10 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> Non if len(affected_nodes) == 0: return - is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id) + is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes) affected_relationships = await get_global_dataset_related_edges(dataset_id) - is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id) + is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships) non_legacy_nodes = [ node for index, node in enumerate(affected_nodes) if not is_legacy_node[index] diff --git a/cognee/tests/test_delete_default_graph_with_legacy_data_1.py b/cognee/tests/test_delete_default_graph_with_legacy_data_1.py index 9fae87874..e1e923c24 100644 --- a/cognee/tests/test_delete_default_graph_with_legacy_data_1.py +++ b/cognee/tests/test_delete_default_graph_with_legacy_data_1.py @@ -18,6 +18,7 @@ from cognee.modules.engine.models import Entity, EntityType from cognee.modules.data.processing.document_types import TextDocument from cognee.modules.engine.operations.setup import setup from cognee.modules.engine.utils import generate_node_id +from cognee.modules.engine.utils.generate_node_name import generate_node_name from cognee.modules.graph.legacy.record_data_in_legacy_ledger import record_data_in_legacy_ledger from cognee.modules.graph.utils.deduplicate_nodes_and_edges import deduplicate_nodes_and_edges from cognee.modules.graph.utils.get_graph_from_model import get_graph_from_model @@ -45,6 +46,36 @@ from cognee.tests.utils.filter_overlapping_relationships import filter_overlappi from cognee.tests.utils.get_contains_edge_text import get_contains_edge_text +async def assert_relationships_vector_index_present(formatted_relationships, legacy_relationships): + """Helper to check both formatted (new) and unformatted (legacy) relationships.""" + if formatted_relationships: + await assert_edges_vector_index_present(formatted_relationships, convert_to_new_format=True) + if legacy_relationships: + await assert_edges_vector_index_present(legacy_relationships, convert_to_new_format=False) + + +def build_relationships(chunk, document, summary, graph): + """Build all relationships for a chunk including structural and extracted ones.""" + return [ + (chunk.id, document.id, "is_part_of", {"relationship_name": "is_part_of"}), + (summary.id, chunk.id, "made_from", {"relationship_name": "made_from"}), + ] + extract_relationships(chunk, graph) + + +def build_contains_relationships(chunk_id, entities, entity_names): + """Build contains relationships for specific entities.""" + return [ + ( + chunk_id, + entity.id, + get_contains_edge_text(entity.name, entity.description), + {"relationship_name": get_contains_edge_text(entity.name, entity.description)}, + ) + for entity in entities + if isinstance(entity, Entity) and entity.name in entity_names + ] + + def create_legacy_data_points(): document = TextDocument( id=uuid5(NAMESPACE_OID, "text_test.txt"), @@ -56,51 +87,43 @@ def create_legacy_data_points(): document_chunk = DocumentChunk( id=uuid5( NAMESPACE_OID, - "Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ", + "Apple announced their new vector embeddings visualization tool called Embedding Atlas.", ), - text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ", + text="Apple announced their new vector embeddings visualization tool called Embedding Atlas.", chunk_size=187, chunk_index=0, cut_type="paragraph_end", is_part_of=document, ) - graph_database = EntityType( - id=uuid5(NAMESPACE_OID, "graph_database"), - name="graph database", - description="graph database", + company = EntityType( + id=generate_node_id("Company"), + name=generate_node_name("Company"), + description=generate_node_name("Company"), ) - neptune_analytics_entity = Entity( - id=generate_node_id("neptune analytics"), - name="neptune analytics", - description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.", - is_a=graph_database, + apple = Entity( + id=generate_node_id("Apple"), + name=generate_node_name("Apple"), + description="Apple is a company", + is_a=company, ) - neptune_database_entity = Entity( - id=generate_node_id("amazon neptune database"), - name="amazon neptune database", - description="A popular managed graph database that complements Neptune Analytics.", - is_a=graph_database, + product = EntityType( + id=generate_node_id("Product"), + name=generate_node_name("Product"), + description=generate_node_name("Product"), ) - - storage = EntityType( - id=generate_node_id("storage"), - name="storage", - description="storage", - ) - storage_entity = Entity( - id=generate_node_id("amazon s3"), - name="amazon s3", - description="A storage service provided by Amazon Web Services that allows storing graph data.", - is_a=storage, + embedding_atlas = Entity( + id=generate_node_id("Embedding Atlas"), + name=generate_node_name("Embedding Atlas"), + description="Embedding Atlas", + is_a=product, ) entities = [ - graph_database, - neptune_analytics_entity, - neptune_database_entity, - storage, - storage_entity, + company, + product, + apple, + embedding_atlas, ] document_chunk.contains = entities @@ -143,7 +166,7 @@ async def main(mock_create_structured_output: AsyncMock): assert not await vector_engine.has_collection("TextDocument_text") # Add legacy data to the system - __, legacy_data_points, legacy_relationships = await create_mocked_legacy_data(user) + __, all_legacy_data_points, all_legacy_relationships = await create_mocked_legacy_data(user) def mock_llm_output(text_input: str, system_prompt: str, response_model): if text_input == "test": # LLM connection test @@ -262,132 +285,153 @@ async def main(mock_create_structured_output: AsyncMock): ) maries_summary = extract_summary(maries_chunk, mock_llm_output("Marie", "", SummarizedContent)) # type: ignore - johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore - maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore - (overlapping_entities, johns_entities, maries_entities) = filter_overlapping_entities( - johns_entities, maries_entities - ) + all_johns_entities = extract_entities(mock_llm_output("John", "", KnowledgeGraph)) # type: ignore + all_maries_entities = extract_entities(mock_llm_output("Marie", "", KnowledgeGraph)) # type: ignore - johns_data = [ + expected_johns_data = [ johns_document, johns_chunk, johns_summary, - *johns_entities, + *all_johns_entities, ] - maries_data = [ + expected_maries_data = [ maries_document, maries_chunk, maries_summary, - *maries_entities, + *all_maries_entities, ] - expected_data_points = johns_data + maries_data + overlapping_entities + legacy_data_points + expected_data_points = expected_johns_data + expected_maries_data + all_legacy_data_points # Assert data points presence in the graph, vector collections and nodes table await assert_graph_nodes_present(expected_data_points) await assert_nodes_vector_index_present(expected_data_points) - johns_relationships = extract_relationships( + all_johns_relationships = build_relationships( johns_chunk, + johns_document, + johns_summary, mock_llm_output("John", "", KnowledgeGraph), # type: ignore ) - maries_relationships = extract_relationships( + all_maries_relationships = build_relationships( maries_chunk, + maries_document, + maries_summary, mock_llm_output("Marie", "", KnowledgeGraph), # type: ignore ) - (overlapping_relationships, johns_relationships, maries_relationships, legacy_relationships) = ( - filter_overlapping_relationships( - johns_relationships, maries_relationships, legacy_relationships - ) - ) - - johns_relationships = [ - (johns_chunk.id, johns_document.id, "is_part_of"), - (johns_summary.id, johns_chunk.id, "made_from"), - *johns_relationships, - ] - maries_relationships = [ - (maries_chunk.id, maries_document.id, "is_part_of"), - (maries_summary.id, maries_chunk.id, "made_from"), - *maries_relationships, - ] expected_relationships = ( - johns_relationships - + maries_relationships - + overlapping_relationships - + legacy_relationships + all_johns_relationships + all_maries_relationships + all_legacy_relationships ) await assert_graph_edges_present(expected_relationships) - - await assert_edges_vector_index_present(expected_relationships) + await assert_relationships_vector_index_present( + all_johns_relationships + all_maries_relationships, all_legacy_relationships + ) # Delete John's data await datasets.delete_data(dataset_id, johns_data_id, user) # type: ignore - # Assert data points presence in the graph, vector collections and nodes table - await assert_graph_nodes_present(maries_data + overlapping_entities + legacy_data_points) - await assert_nodes_vector_index_present(maries_data + overlapping_entities + legacy_data_points) + expected_data_points = [ + maries_document, + maries_chunk, + maries_summary, + *all_maries_entities, + *all_legacy_data_points, + ] - await assert_graph_nodes_not_present(johns_data) - await assert_nodes_vector_index_not_present(johns_data) + # Assert data points presence in the graph, vector collections and nodes table + await assert_graph_nodes_present(expected_data_points) + await assert_nodes_vector_index_present(expected_data_points) + + (__, strictly_johns_entities, __, __) = filter_overlapping_entities( + all_johns_entities, all_maries_entities, all_legacy_data_points + ) + + not_expected_data_points = [ + johns_document, + johns_chunk, + johns_summary, + *strictly_johns_entities, + ] + + await assert_graph_nodes_not_present(not_expected_data_points) + await assert_nodes_vector_index_not_present(not_expected_data_points) # Assert relationships presence in the graph, vector collections and nodes table - await assert_graph_edges_present( - maries_relationships + overlapping_relationships + legacy_relationships + await assert_graph_edges_present(all_maries_relationships + all_legacy_relationships) + await assert_relationships_vector_index_present( + all_maries_relationships, all_legacy_relationships ) - await assert_edges_vector_index_present(maries_relationships + legacy_relationships) - await assert_graph_edges_not_present(johns_relationships) + (__, strictly_johns_relationships, __, __) = filter_overlapping_relationships( + all_johns_relationships, + all_maries_relationships, + all_legacy_relationships, + ) - johns_contains_relationships = [ - ( - johns_chunk.id, - entity.id, - get_contains_edge_text(entity.name, entity.description), - { - "relationship_name": get_contains_edge_text(entity.name, entity.description), - }, - ) - for entity in johns_entities - if isinstance(entity, Entity) - ] - # We check only by relationship name and we need edges that are created by John's data and no other. - await assert_edges_vector_index_not_present(johns_contains_relationships) + await assert_graph_edges_not_present(strictly_johns_relationships) + + # Check that John's unique contains relationships are not in vector index + not_expected_relationships = build_contains_relationships( + johns_chunk.id, + all_johns_entities, + [generate_node_name("John"), generate_node_name("Food for Hungry")], + ) + await assert_edges_vector_index_not_present(not_expected_relationships) # Delete Marie's data await datasets.delete_data(dataset_id, maries_data_id, user) # type: ignore # Assert data points presence in the graph, vector collections and nodes table - await assert_graph_nodes_present(legacy_data_points) - await assert_nodes_vector_index_present(legacy_data_points) + await assert_graph_nodes_present(all_legacy_data_points) + await assert_nodes_vector_index_present(all_legacy_data_points) - await assert_graph_nodes_not_present(johns_data + maries_data + overlapping_entities) - await assert_nodes_vector_index_not_present(johns_data + maries_data + overlapping_entities) - - # Assert relationships presence in the graph, vector collections and nodes table - await assert_graph_edges_present(legacy_relationships) - await assert_edges_vector_index_present(legacy_relationships) - - await assert_graph_edges_not_present( - johns_relationships + maries_relationships + overlapping_relationships + (__, strictly_johns_entities, strictly_maries_entities, __) = filter_overlapping_entities( + all_johns_entities, all_maries_entities, all_legacy_data_points ) - maries_contains_relationships = [ - ( - maries_chunk.id, - entity.id, - get_contains_edge_text(entity.name, entity.description), - { - "relationship_name": get_contains_edge_text(entity.name, entity.description), - }, - ) - for entity in maries_entities - if isinstance(entity, Entity) + not_expected_data_points = [ + johns_document, + johns_chunk, + johns_summary, + *strictly_johns_entities, + maries_document, + maries_chunk, + maries_summary, + *strictly_maries_entities, ] - # We check only by relationship name and we need edges that are created by legacy data and no other. - await assert_edges_vector_index_not_present(maries_contains_relationships) + + await assert_graph_nodes_not_present(not_expected_data_points) + await assert_nodes_vector_index_not_present(not_expected_data_points) + + # Assert relationships presence in the graph, vector collections and nodes table + await assert_graph_edges_present(all_legacy_relationships) + await assert_relationships_vector_index_present([], all_legacy_relationships) + + (__, strictly_johns_relationships, strictly_maries_relationships, __) = ( + filter_overlapping_relationships( + all_maries_relationships, + all_johns_relationships, + all_legacy_relationships, + ) + ) + + await assert_graph_edges_not_present( + strictly_johns_relationships + strictly_maries_relationships + ) + + # Check that John's and Marie's unique contains relationships are not in vector index + not_expected_relationships = build_contains_relationships( + johns_chunk.id, + all_johns_entities, + [generate_node_name("John"), generate_node_name("Food for Hungry")], + ) + build_contains_relationships( + maries_chunk.id, + all_maries_entities, + [generate_node_name("Marie"), generate_node_name("MacOS")], + ) + await assert_edges_vector_index_not_present(not_expected_relationships) async def create_mocked_legacy_data(user): @@ -423,31 +467,11 @@ async def create_mocked_legacy_data(user): await graph_engine.add_nodes(graph_nodes) await graph_engine.add_edges(graph_edges) - nodes_by_id = {node.id: node for node in graph_nodes} - - def format_relationship_name(relationship): - if relationship[2] == "contains": - node = nodes_by_id[relationship[1]] - return get_contains_edge_text(node.name, node.description) - return relationship[2] - await index_data_points(graph_nodes) - await index_graph_edges( - [ - ( - edge[0], - edge[1], - format_relationship_name(edge), - { - **(edge[3] or {}), - "relationship_name": format_relationship_name(edge), - }, - ) - for edge in graph_edges - ] # type: ignore - ) + # Legacy relationships should NOT be formatted - index them as-is + await index_graph_edges(graph_edges) - await record_data_in_legacy_ledger(graph_nodes, graph_edges, user) + await record_data_in_legacy_ledger(graph_nodes, graph_edges) db_engine = get_relational_engine() diff --git a/cognee/tests/test_delete_default_graph_with_legacy_data_2.py b/cognee/tests/test_delete_default_graph_with_legacy_data_2.py index 10b1d7596..281cb64d2 100644 --- a/cognee/tests/test_delete_default_graph_with_legacy_data_2.py +++ b/cognee/tests/test_delete_default_graph_with_legacy_data_2.py @@ -435,7 +435,7 @@ async def create_mocked_legacy_data(user): ] # type: ignore ) - await record_data_in_legacy_ledger(graph_nodes, graph_edges, user) + await record_data_in_legacy_ledger(graph_nodes, graph_edges) db_engine = get_relational_engine() diff --git a/cognee/tests/utils/assert_edges_vector_index_present.py b/cognee/tests/utils/assert_edges_vector_index_present.py index 31e5ea0d0..3744b7d63 100644 --- a/cognee/tests/utils/assert_edges_vector_index_present.py +++ b/cognee/tests/utils/assert_edges_vector_index_present.py @@ -21,7 +21,9 @@ def format_relationship(relationship: Tuple[UUID, UUID, str, Dict], node: Dict): return {str(generate_edge_id(relationship[2])): relationship[2]} -async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID, str, Dict]]): +async def assert_edges_vector_index_present( + relationships: List[Tuple[UUID, UUID, str, Dict]], convert_to_new_format: bool = True +): vector_engine = get_vector_engine() graph_engine = await get_graph_engine() @@ -33,7 +35,11 @@ async def assert_edges_vector_index_present(relationships: List[Tuple[UUID, UUID for relationship in relationships: query_edge_ids = { **query_edge_ids, - **format_relationship(relationship, nodes_by_id[str(relationship[1])]), + **( + format_relationship(relationship, nodes_by_id[str(relationship[1])]) + if convert_to_new_format + else {str(generate_edge_id(relationship[2])): relationship[2]} + ), } vector_items = await vector_engine.retrieve( diff --git a/cognee/tests/utils/assert_graph_edges_not_present.py b/cognee/tests/utils/assert_graph_edges_not_present.py index 03e22ae62..bdb26b59b 100644 --- a/cognee/tests/utils/assert_graph_edges_not_present.py +++ b/cognee/tests/utils/assert_graph_edges_not_present.py @@ -1,11 +1,10 @@ from uuid import UUID -from typing import List, Tuple +from typing import Dict, List, Tuple from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.engine.models.DataPoint import DataPoint -async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str]]): +async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, str, Dict]]): graph_engine = await get_graph_engine() nodes, edges = await graph_engine.get_graph_data() @@ -15,7 +14,11 @@ async def assert_graph_edges_not_present(relationships: List[Tuple[UUID, UUID, s for relationship in relationships: relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}" - relationship_name = relationship[2] - assert relationship_id not in edge_ids, ( - f"Edge '{relationship_name}' still present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database." - ) + + if relationship_id in edge_ids: + relationship_name = relationship[2] + source_node = nodes_by_id[str(relationship[0])] + destination_node = nodes_by_id[str(relationship[1])] + assert False, ( + f"Edge '{relationship_name}' still present between '{source_node['name'] if 'node' in source_node else source_node['id']}' and '{destination_node['name'] if 'node' in destination_node else destination_node['id']}' in graph database." + ) diff --git a/cognee/tests/utils/assert_graph_edges_present.py b/cognee/tests/utils/assert_graph_edges_present.py index ae14c75cf..ea4af5ca2 100644 --- a/cognee/tests/utils/assert_graph_edges_present.py +++ b/cognee/tests/utils/assert_graph_edges_present.py @@ -1,11 +1,10 @@ from uuid import UUID -from typing import List, Tuple +from typing import Dict, List, Tuple from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.engine.models.DataPoint import DataPoint -async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]]): +async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str, Dict]]): graph_engine = await get_graph_engine() nodes, edges = await graph_engine.get_graph_data() @@ -16,6 +15,10 @@ async def assert_graph_edges_present(relationships: List[Tuple[UUID, UUID, str]] for relationship in relationships: relationship_id = f"{str(relationship[0])}_{relationship[2]}_{str(relationship[1])}" relationship_name = relationship[2] + source_node = nodes_by_id.get(str(relationship[0]), {}) + target_node = nodes_by_id.get(str(relationship[1]), {}) + source_name = source_node.get("name") or source_node.get("text") or str(relationship[0]) + target_name = target_node.get("name") or target_node.get("text") or str(relationship[1]) assert relationship_id in edge_ids, ( - f"Edge '{relationship_name}' not present between '{nodes_by_id[str(relationship[0])]['name']}' and '{nodes_by_id[str(relationship[1])]['name']}' in graph database." + f"Edge '{relationship_name}' not present between '{source_name}' and '{target_name}' in graph database." )