Edited test_relation_db_migration.py to include schema_only ingestion testcase

This commit is contained in:
Geoff-Robin 2025-09-20 11:05:39 +05:30 committed by Igor Ilic
parent 67f948a145
commit 656894370e

View file

@ -197,6 +197,79 @@ async def relational_db_migration():
print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
async def test_schema_only_migration():
    """Run a schema-only migration and validate the resulting graph.

    The test extracts the schema from the engine returned by
    ``setup_test_db()``, migrates it into the graph engine with
    ``schema_only=True``, and then checks two things:

    * a ``GRAPH_COMPLETION`` search about the number of tables yields at
      least one result containing ``"11"``;
    * the graph holds the expected number of ``is_part_of``,
      ``has_relationship`` and ``foreign_key`` edges.

    The graph backend is selected by the ``GRAPH_DATABASE_PROVIDER``
    environment variable (lower-cased, ``"networkx"`` when unset).

    Raises:
        AssertionError: if the search result or any edge count is off.
        ValueError: if the provider is not neo4j, kuzu or networkx.
    """
    # Step 1: prepare the relational test database and read its schema.
    relational_engine = await setup_test_db()
    db_schema = await relational_engine.extract_schema()

    # Step 2: obtain the graph engine that receives the migrated schema.
    graph = await get_graph_engine()

    # Step 3: migrate the schema only.
    await migrate_relational_database(graph, schema=db_schema, schema_only=True)

    # Step 4: ask the search layer how many tables it sees.
    answers = await cognee.search(
        query_text="How many tables are there in this database",
        query_type=cognee.SearchType.GRAPH_COMPLETION,
    )
    # NOTE(review): this is a containment check; assuming each result is a
    # string, an answer such as "110" would also satisfy it - confirm that
    # this looseness is acceptable.
    assert any("11" in answer for answer in answers), (
        "Number of tables in the database reported in search_results is either None or not equal to 11"
    )

    # Step 5: count edges per relationship label for the active backend.
    provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()
    tracked_labels = ("is_part_of", "has_relationship", "foreign_key")
    edge_counts = dict.fromkeys(tracked_labels, 0)

    if provider not in ("neo4j", "kuzu", "networkx"):
        raise ValueError(f"Unsupported graph database provider: {provider}")

    if provider == "networkx":
        # Each edge is unpacked as a 4-tuple whose third item is the label.
        _, all_edges = await graph.get_graph_data()
        for _source, _target, label, _properties in all_edges:
            if label in edge_counts:
                edge_counts[label] += 1
    elif provider == "kuzu":
        # Here the label is matched via the `relationship_name` property
        # of generic EDGE relationships; the count is read by position.
        for label in tracked_labels:
            cypher = (
                "\n"
                "MATCH ()-[r:EDGE]->()\n"
                f"WHERE r.relationship_name = '{label}'\n"
                "RETURN count(r) as c\n"
            )
            result = await graph.query(cypher)
            edge_counts[label] = result[0][0]
    else:
        # neo4j: the label is used as the relationship type and the
        # count is read from column "c".
        for label in tracked_labels:
            cypher = (
                "\n"
                f"MATCH ()-[r:{label}]->()\n"
                "RETURN count(r) as c\n"
            )
            result = await graph.query(cypher)
            edge_counts[label] = result[0]["c"]

    # Step 6: compare the observed counts with the expected ones.
    expected_counts = {
        "is_part_of": 11,
        "has_relationship": 22,
        "foreign_key": 11,
    }
    for label, wanted in expected_counts.items():
        found = edge_counts[label]
        assert found == wanted, (
            f"Expected {wanted} edges for relationship '{label}', but found {found}"
        )

    print("Schema-only migration edge counts validated successfully!")
    print(f"Edge counts: {edge_counts}")
async def test_migration_sqlite():
database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/")
@ -209,6 +282,7 @@ async def test_migration_sqlite():
)
await relational_db_migration()
await test_schema_only_migration()
async def test_migration_postgres():
@ -224,6 +298,7 @@ async def test_migration_postgres():
}
)
await relational_db_migration()
await test_schema_only_migration()
async def main():