Edited test_relation_db_migration.py to include schema_only ingestion testcase
This commit is contained in:
parent
67f948a145
commit
656894370e
1 changed files with 75 additions and 0 deletions
|
|
@ -197,6 +197,79 @@ async def relational_db_migration():
|
||||||
print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
|
print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_schema_only_migration():
|
||||||
|
# 1. Setup test DB and extract schema
|
||||||
|
migration_engine = await setup_test_db()
|
||||||
|
schema = await migration_engine.extract_schema()
|
||||||
|
|
||||||
|
# 2. Setup graph engine
|
||||||
|
graph_engine = await get_graph_engine()
|
||||||
|
|
||||||
|
# 4. Migrate schema only
|
||||||
|
await migrate_relational_database(graph_engine, schema=schema, schema_only=True)
|
||||||
|
|
||||||
|
# 5. Verify number of tables through search
|
||||||
|
search_results = await cognee.search(
|
||||||
|
query_text="How many tables are there in this database",
|
||||||
|
query_type=cognee.SearchType.GRAPH_COMPLETION,
|
||||||
|
)
|
||||||
|
assert any("11" in r for r in search_results), (
|
||||||
|
"Number of tables in the database reported in search_results is either None or not equal to 11"
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()
|
||||||
|
|
||||||
|
edge_counts = {
|
||||||
|
"is_part_of": 0,
|
||||||
|
"has_relationship": 0,
|
||||||
|
"foreign_key": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if graph_db_provider == "neo4j":
|
||||||
|
for rel_type in edge_counts.keys():
|
||||||
|
query_str = f"""
|
||||||
|
MATCH ()-[r:{rel_type}]->()
|
||||||
|
RETURN count(r) as c
|
||||||
|
"""
|
||||||
|
rows = await graph_engine.query(query_str)
|
||||||
|
edge_counts[rel_type] = rows[0]["c"]
|
||||||
|
|
||||||
|
elif graph_db_provider == "kuzu":
|
||||||
|
for rel_type in edge_counts.keys():
|
||||||
|
query_str = f"""
|
||||||
|
MATCH ()-[r:EDGE]->()
|
||||||
|
WHERE r.relationship_name = '{rel_type}'
|
||||||
|
RETURN count(r) as c
|
||||||
|
"""
|
||||||
|
rows = await graph_engine.query(query_str)
|
||||||
|
edge_counts[rel_type] = rows[0][0]
|
||||||
|
|
||||||
|
elif graph_db_provider == "networkx":
|
||||||
|
nodes, edges = await graph_engine.get_graph_data()
|
||||||
|
for _, _, key, _ in edges:
|
||||||
|
if key in edge_counts:
|
||||||
|
edge_counts[key] += 1
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
||||||
|
|
||||||
|
# 7. Assert counts match expected values
|
||||||
|
expected_counts = {
|
||||||
|
"is_part_of": 11,
|
||||||
|
"has_relationship": 22,
|
||||||
|
"foreign_key": 11,
|
||||||
|
}
|
||||||
|
|
||||||
|
for rel_type, expected in expected_counts.items():
|
||||||
|
actual = edge_counts[rel_type]
|
||||||
|
assert actual == expected, (
|
||||||
|
f"Expected {expected} edges for relationship '{rel_type}', but found {actual}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Schema-only migration edge counts validated successfully!")
|
||||||
|
print(f"Edge counts: {edge_counts}")
|
||||||
|
|
||||||
|
|
||||||
async def test_migration_sqlite():
|
async def test_migration_sqlite():
|
||||||
database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/")
|
database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/")
|
||||||
|
|
||||||
|
|
@ -209,6 +282,7 @@ async def test_migration_sqlite():
|
||||||
)
|
)
|
||||||
|
|
||||||
await relational_db_migration()
|
await relational_db_migration()
|
||||||
|
await test_schema_only_migration()
|
||||||
|
|
||||||
|
|
||||||
async def test_migration_postgres():
|
async def test_migration_postgres():
|
||||||
|
|
@ -224,6 +298,7 @@ async def test_migration_postgres():
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
await relational_db_migration()
|
await relational_db_migration()
|
||||||
|
await test_schema_only_migration()
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue