test: relational db migration (#695)
<!-- .github/pull_request_template.md -->

## Description

Adds a test for relational database migration.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
This commit is contained in:
parent
534e7fb22c
commit
0ca8ef2448
3 changed files with 16176 additions and 0 deletions
15970
cognee/tests/test_data/Chinook_PostgreSql.sql
Normal file
15970
cognee/tests/test_data/Chinook_PostgreSql.sql
Normal file
File diff suppressed because it is too large
Load diff
BIN
cognee/tests/test_data/migration_database.sqlite
Normal file
BIN
cognee/tests/test_data/migration_database.sqlite
Normal file
Binary file not shown.
206
cognee/tests/test_relational_db_migration.py
Normal file
206
cognee/tests/test_relational_db_migration.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
import json
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import os
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.relational import (
|
||||
get_migration_relational_engine,
|
||||
create_db_and_tables as create_relational_db_and_tables,
|
||||
)
|
||||
from cognee.infrastructure.databases.vector.pgvector import (
|
||||
create_db_and_tables as create_pgvector_db_and_tables,
|
||||
)
|
||||
from cognee.tasks.ingestion import migrate_relational_database
|
||||
from cognee.modules.search.types import SearchType
|
||||
import cognee
|
||||
|
||||
|
||||
def nodes_dict(nodes):
    """Build an ``{node_id: data}`` mapping from an iterable of (id, data) pairs."""
    return dict(nodes)
|
||||
|
||||
|
||||
def normalize_node_name(node_name: str) -> str:
    """Capitalize the text before the first ':' in *node_name*.

    Names without a ':' — including empty strings — are returned unchanged,
    so e.g. ``"employee:5"`` becomes ``"Employee:5"`` while ``"plain"`` stays
    ``"plain"``.
    """
    if not node_name:
        return node_name
    head, sep, tail = node_name.partition(":")
    if not sep:
        # No colon present: nothing to normalize.
        return node_name
    return f"{head.capitalize()}{sep}{tail}"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
async def setup_test_db():
    """Reset cognee state and create fresh target stores for the migration test.

    Wipes previously ingested data and system metadata, recreates the
    relational and pgvector tables, and returns the engine connected to the
    migration *source* database.
    """
    # Start from a clean slate: drop leftover data and system metadata.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Recreate the destination stores the migration writes into.
    await create_relational_db_and_tables()
    await create_pgvector_db_and_tables()

    return get_migration_relational_engine()
|
||||
|
||||
|
||||
async def _collect_relationship_edges(graph_engine, graph_db_provider, relationship_label):
    """Collect (source_name, target_name) pairs for edges labeled *relationship_label*.

    Returns a tuple ``(found_edges, distinct_node_names)`` where both are sets.
    Fails the test for an unsupported graph database provider.
    """
    distinct_node_names = set()
    found_edges = set()

    if graph_db_provider == "neo4j":
        query_str = f"""
        MATCH (n)-[r:{relationship_label}]->(m)
        RETURN n, r, m
        """
        rows = await graph_engine.query(query_str)
        for row in rows:
            n_data = row["n"]
            m_data = row["m"]

            source_name = normalize_node_name(n_data.get("name", ""))
            target_name = normalize_node_name(m_data.get("name", ""))

            found_edges.add((source_name, target_name))
            distinct_node_names.update([source_name, target_name])

    elif graph_db_provider == "kuzu":
        # Kuzu stores all edges under one EDGE table; filter by relationship_name.
        query_str = f"""
        MATCH (n:Node)-[r:EDGE]->(m:Node)
        WHERE r.relationship_name = '{relationship_label}'
        RETURN r, n, m
        """
        rows = await graph_engine.query(query_str)
        for row in rows:
            n_data = row[1]
            m_data = row[2]

            # Node attributes are serialized as a JSON string in 'properties'.
            source_props = {}
            if "properties" in n_data and n_data["properties"]:
                source_props = json.loads(n_data["properties"])
            target_props = {}
            if "properties" in m_data and m_data["properties"]:
                target_props = json.loads(m_data["properties"])

            source_name = normalize_node_name(source_props.get("name", f"id:{n_data['id']}"))
            target_name = normalize_node_name(target_props.get("name", f"id:{m_data['id']}"))

            found_edges.add((source_name, target_name))
            distinct_node_names.update([source_name, target_name])

    elif graph_db_provider == "networkx":
        nodes, edges = await graph_engine.get_graph_data()
        node_map = nodes_dict(nodes)
        for src, tgt, key, edge_data in edges:
            if key == relationship_label:
                src_name = normalize_node_name(node_map[src].get("name"))
                tgt_name = normalize_node_name(node_map[tgt].get("name"))
                if src_name and tgt_name:
                    found_edges.add((src_name, tgt_name))
                    distinct_node_names.update([src_name, tgt_name])
    else:
        pytest.fail(f"Unsupported graph database provider: {graph_db_provider}")

    return found_edges, distinct_node_names


async def _count_graph_elements(graph_engine, graph_db_provider):
    """Return ``(node_count, edge_count)`` totals for the whole graph.

    Fails the test for an unsupported graph database provider.
    """
    if graph_db_provider == "neo4j":
        query_str = """
        MATCH (n)
        WITH count(n) AS node_count
        MATCH ()-[r]->()
        RETURN node_count, count(r) AS edge_count
        """
        rows = await graph_engine.query(query_str)
        return rows[0]["node_count"], rows[0]["edge_count"]

    if graph_db_provider == "kuzu":
        rows_n = await graph_engine.query("MATCH (n:Node) RETURN count(n) as c")
        rows_e = await graph_engine.query("MATCH (n:Node)-[r:EDGE]->(m:Node) RETURN count(r) as c")
        return rows_n[0]["c"], rows_e[0]["c"]

    if graph_db_provider == "networkx":
        nodes, edges = await graph_engine.get_graph_data()
        return len(nodes), len(edges)

    pytest.fail(f"Unsupported graph database provider: {graph_db_provider}")


@pytest.mark.asyncio
async def test_relational_db_migration(setup_test_db):
    """End-to-end check that a relational database migrates into the graph store.

    Verifies that:
      1. graph search surfaces migrated content ("AC/DC"),
      2. the employee 'reports to' hierarchy arrives intact,
      3. total node/edge counts match the size of the source database.
    """
    relational_engine = setup_test_db
    schema = await relational_engine.extract_schema()

    graph_engine = await get_graph_engine()
    await migrate_relational_database(graph_engine, schema=schema)

    # 1. Search the graph
    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
    )
    print("Search results:", search_results)

    # 2. Assert that the search results contain "AC/DC"
    assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!"

    relational_db_provider = os.getenv("MIGRATION_DB_PROVIDER", "sqlite").lower()
    # Postgres lower-cases unquoted identifiers, so the migrated relationship
    # label differs between source databases.
    if relational_db_provider == "postgres":
        relationship_label = "reports_to"
    else:
        relationship_label = "ReportsTo"

    # 3. Directly verify the 'reports to' hierarchy
    graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()

    found_edges, distinct_node_names = await _collect_relationship_edges(
        graph_engine, graph_db_provider, relationship_label
    )

    assert len(distinct_node_names) == 8, (
        f"Expected 8 distinct node references, found {len(distinct_node_names)}"
    )
    assert len(found_edges) == 7, f"Expected 7 {relationship_label} edges, got {len(found_edges)}"

    expected_edges = {
        ("Employee:5", "Employee:2"),
        ("Employee:2", "Employee:1"),
        ("Employee:4", "Employee:2"),
        ("Employee:6", "Employee:1"),
        ("Employee:8", "Employee:6"),
        ("Employee:7", "Employee:6"),
        ("Employee:3", "Employee:2"),
    }
    for e in expected_edges:
        assert e in found_edges, f"Edge {e} not found in the actual '{relationship_label}' edges!"

    # 4. Verify the total number of nodes and edges in the graph
    # NOTE: Because of the different size of the postgres and sqlite databases,
    # different number of nodes and edges are expected
    expected_totals = {
        "sqlite": (227, 580),
        "postgres": (115, 356),
    }
    if relational_db_provider not in expected_totals:
        # Previously an unknown provider left node_count/edge_count unbound
        # and crashed with NameError; fail explicitly instead.
        pytest.fail(f"Unsupported migration database provider: {relational_db_provider}")

    node_count, edge_count = await _count_graph_elements(graph_engine, graph_db_provider)
    expected_nodes, expected_edge_total = expected_totals[relational_db_provider]

    assert node_count == expected_nodes, f"Expected {expected_nodes} nodes, got {node_count}"
    assert edge_count == expected_edge_total, (
        f"Expected {expected_edge_total} edges, got {edge_count}"
    )

    print(f"Node & edge count validated: node_count={node_count}, edge_count={edge_count}.")

    print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
|
||||
Loading…
Add table
Reference in a new issue