diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py index 68b46dbf5..cb360f1c2 100644 --- a/cognee/tests/test_relational_db_migration.py +++ b/cognee/tests/test_relational_db_migration.py @@ -197,6 +197,79 @@ async def relational_db_migration(): print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!") +async def test_schema_only_migration(): + # 1. Setup test DB and extract schema + migration_engine = await setup_test_db() + schema = await migration_engine.extract_schema() + + # 2. Setup graph engine + graph_engine = await get_graph_engine() + + # 4. Migrate schema only + await migrate_relational_database(graph_engine, schema=schema, schema_only=True) + + # 5. Verify number of tables through search + search_results = await cognee.search( + query_text="How many tables are there in this database", + query_type=cognee.SearchType.GRAPH_COMPLETION, + ) + assert any("11" in r for r in search_results), ( + "Number of tables in the database reported in search_results is either None or not equal to 11" + ) + + graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower() + + edge_counts = { + "is_part_of": 0, + "has_relationship": 0, + "foreign_key": 0, + } + + if graph_db_provider == "neo4j": + for rel_type in edge_counts.keys(): + query_str = f""" + MATCH ()-[r:{rel_type}]->() + RETURN count(r) as c + """ + rows = await graph_engine.query(query_str) + edge_counts[rel_type] = rows[0]["c"] + + elif graph_db_provider == "kuzu": + for rel_type in edge_counts.keys(): + query_str = f""" + MATCH ()-[r:EDGE]->() + WHERE r.relationship_name = '{rel_type}' + RETURN count(r) as c + """ + rows = await graph_engine.query(query_str) + edge_counts[rel_type] = rows[0][0] + + elif graph_db_provider == "networkx": + nodes, edges = await graph_engine.get_graph_data() + for _, _, key, _ in edges: + if key in edge_counts: + edge_counts[key] += 1 + + else: + raise ValueError(f"Unsupported graph database provider: {graph_db_provider}") + + # 7. Assert counts match expected values + expected_counts = { + "is_part_of": 11, + "has_relationship": 22, + "foreign_key": 11, + } + + for rel_type, expected in expected_counts.items(): + actual = edge_counts[rel_type] + assert actual == expected, ( + f"Expected {expected} edges for relationship '{rel_type}', but found {actual}" + ) + + print("Schema-only migration edge counts validated successfully!") + print(f"Edge counts: {edge_counts}") + + async def test_migration_sqlite(): database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/") @@ -209,6 +282,7 @@ async def test_migration_sqlite(): ) await relational_db_migration() + await test_schema_only_migration() async def test_migration_postgres(): @@ -224,6 +298,7 @@ async def test_migration_postgres(): } ) await relational_db_migration() + await test_schema_only_migration() async def main():