Edited test_relation_db_migration.py to include schema_only ingestion testcase
This commit is contained in:
parent
67f948a145
commit
656894370e
1 changed files with 75 additions and 0 deletions
|
|
@ -197,6 +197,79 @@ async def relational_db_migration():
|
|||
print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
|
||||
|
||||
|
||||
async def test_schema_only_migration():
|
||||
# 1. Setup test DB and extract schema
|
||||
migration_engine = await setup_test_db()
|
||||
schema = await migration_engine.extract_schema()
|
||||
|
||||
# 2. Setup graph engine
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
# 4. Migrate schema only
|
||||
await migrate_relational_database(graph_engine, schema=schema, schema_only=True)
|
||||
|
||||
# 5. Verify number of tables through search
|
||||
search_results = await cognee.search(
|
||||
query_text="How many tables are there in this database",
|
||||
query_type=cognee.SearchType.GRAPH_COMPLETION,
|
||||
)
|
||||
assert any("11" in r for r in search_results), (
|
||||
"Number of tables in the database reported in search_results is either None or not equal to 11"
|
||||
)
|
||||
|
||||
graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()
|
||||
|
||||
edge_counts = {
|
||||
"is_part_of": 0,
|
||||
"has_relationship": 0,
|
||||
"foreign_key": 0,
|
||||
}
|
||||
|
||||
if graph_db_provider == "neo4j":
|
||||
for rel_type in edge_counts.keys():
|
||||
query_str = f"""
|
||||
MATCH ()-[r:{rel_type}]->()
|
||||
RETURN count(r) as c
|
||||
"""
|
||||
rows = await graph_engine.query(query_str)
|
||||
edge_counts[rel_type] = rows[0]["c"]
|
||||
|
||||
elif graph_db_provider == "kuzu":
|
||||
for rel_type in edge_counts.keys():
|
||||
query_str = f"""
|
||||
MATCH ()-[r:EDGE]->()
|
||||
WHERE r.relationship_name = '{rel_type}'
|
||||
RETURN count(r) as c
|
||||
"""
|
||||
rows = await graph_engine.query(query_str)
|
||||
edge_counts[rel_type] = rows[0][0]
|
||||
|
||||
elif graph_db_provider == "networkx":
|
||||
nodes, edges = await graph_engine.get_graph_data()
|
||||
for _, _, key, _ in edges:
|
||||
if key in edge_counts:
|
||||
edge_counts[key] += 1
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")
|
||||
|
||||
# 7. Assert counts match expected values
|
||||
expected_counts = {
|
||||
"is_part_of": 11,
|
||||
"has_relationship": 22,
|
||||
"foreign_key": 11,
|
||||
}
|
||||
|
||||
for rel_type, expected in expected_counts.items():
|
||||
actual = edge_counts[rel_type]
|
||||
assert actual == expected, (
|
||||
f"Expected {expected} edges for relationship '{rel_type}', but found {actual}"
|
||||
)
|
||||
|
||||
print("Schema-only migration edge counts validated successfully!")
|
||||
print(f"Edge counts: {edge_counts}")
|
||||
|
||||
|
||||
async def test_migration_sqlite():
|
||||
database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/")
|
||||
|
||||
|
|
@ -209,6 +282,7 @@ async def test_migration_sqlite():
|
|||
)
|
||||
|
||||
await relational_db_migration()
|
||||
await test_schema_only_migration()
|
||||
|
||||
|
||||
async def test_migration_postgres():
|
||||
|
|
@ -224,6 +298,7 @@ async def test_migration_postgres():
|
|||
}
|
||||
)
|
||||
await relational_db_migration()
|
||||
await test_schema_only_migration()
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue