cognee/cognee/tests/test_relational_db_migration.py
Igor Ilic f9f18d1b0c
feat: Add columns as nodes in relational db migration (#826)
<!-- .github/pull_request_template.md -->

## Description
Add the ability to map column values from relational databases to graph nodes.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-05-15 07:31:31 -04:00

239 lines
8.5 KiB
Python

import json
import pathlib
import os
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import (
get_migration_relational_engine,
create_db_and_tables as create_relational_db_and_tables,
)
from cognee.infrastructure.databases.vector.pgvector import (
create_db_and_tables as create_pgvector_db_and_tables,
)
from cognee.tasks.ingestion import migrate_relational_database
from cognee.modules.search.types import SearchType
import cognee
def nodes_dict(nodes):
    """Build a lookup table from an iterable of ``(node_id, data)`` pairs.

    If the same ``node_id`` appears more than once, the last pair wins.
    """
    lookup = {}
    for node_id, payload in nodes:
        lookup[node_id] = payload
    return lookup
def normalize_node_name(node_name: str) -> str:
    """Capitalize the part of *node_name* that precedes its first colon.

    Names that are empty/``None`` or contain no colon are returned unchanged.
    Only the prefix is touched (``str.capitalize`` also lower-cases the rest
    of the prefix); everything after the first colon is preserved verbatim.
    """
    # Guard clause: nothing to normalize without a "prefix:suffix" shape.
    if not node_name or ":" not in node_name:
        return node_name

    head, separator, tail = node_name.partition(":")
    return head.capitalize() + separator + tail
async def setup_test_db():
    """Reset cognee state, create the database tables, and return the migration engine.

    Calls, in order: ``cognee.prune.prune_data()``,
    ``cognee.prune.prune_system(metadata=True)``, the relational
    ``create_db_and_tables`` and the pgvector ``create_db_and_tables``.

    Returns:
        The engine produced by ``get_migration_relational_engine()``.
    """
    # Start from a clean slate before any tables are (re)created.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Relational tables first, then the pgvector ones (same order as before).
    for create_tables in (create_relational_db_and_tables, create_pgvector_db_and_tables):
        await create_tables()

    return get_migration_relational_engine()
async def _collect_relationship_edges(graph_engine, graph_db_provider, relationship_label):
    """Return the ``(source_name, target_name)`` pairs linked by *relationship_label*.

    Node names are passed through ``normalize_node_name``; pairs where either
    normalized name is empty/``None`` are skipped for every provider.

    Raises:
        ValueError: if *graph_db_provider* is not "neo4j", "kuzu" or "networkx".
    """
    if graph_db_provider == "neo4j":
        rows = await graph_engine.query(
            f"""
            MATCH (n)-[r:{relationship_label}]->(m)
            RETURN n, r, m
            """
        )
        raw_pairs = [(row["n"].get("name", ""), row["m"].get("name", "")) for row in rows]
    elif graph_db_provider == "kuzu":
        rows = await graph_engine.query(
            f"""
            MATCH (n:Node)-[r:EDGE]->(m:Node)
            WHERE r.relationship_name = '{relationship_label}'
            RETURN r, n, m
            """
        )
        # Rows are indexed positionally; 1 and 2 line up with n and m in "RETURN r, n, m".
        raw_pairs = [(row[1].get("name", ""), row[2].get("name", "")) for row in rows]
    elif graph_db_provider == "networkx":
        nodes, edges = await graph_engine.get_graph_data()
        node_map = nodes_dict(nodes)
        # Each edge is (source_id, target_id, relationship key, edge data).
        raw_pairs = [
            (node_map[src].get("name"), node_map[tgt].get("name"))
            for src, tgt, key, _edge_data in edges
            if key == relationship_label
        ]
    else:
        raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")

    found_edges = set()
    for raw_source, raw_target in raw_pairs:
        source_name = normalize_node_name(raw_source)
        target_name = normalize_node_name(raw_target)
        if source_name and target_name:
            found_edges.add((source_name, target_name))
    return found_edges


async def _count_nodes_and_edges(graph_engine, graph_db_provider):
    """Return ``(node_count, edge_count)`` for the whole graph.

    Raises:
        ValueError: if *graph_db_provider* is not "neo4j", "kuzu" or "networkx".
    """
    if graph_db_provider == "neo4j":
        rows = await graph_engine.query(
            """
            MATCH (n)
            WITH count(n) AS node_count
            MATCH ()-[r]->()
            RETURN node_count, count(r) AS edge_count
            """
        )
        return rows[0]["node_count"], rows[0]["edge_count"]

    if graph_db_provider == "kuzu":
        node_rows = await graph_engine.query("MATCH (n:Node) RETURN count(n) as c")
        edge_rows = await graph_engine.query(
            "MATCH (n:Node)-[r:EDGE]->(m:Node) RETURN count(r) as c"
        )
        return node_rows[0][0], edge_rows[0][0]

    if graph_db_provider == "networkx":
        nodes, edges = await graph_engine.get_graph_data()
        return len(nodes), len(edges)

    raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")


async def relational_db_migration():
    """Migrate the configured relational database into the graph and verify the result.

    Steps:
        1. Reset state via ``setup_test_db``, extract the schema and run
           ``migrate_relational_database``.
        2. Run a GRAPH_COMPLETION search and check that "AC/DC" appears.
        3. Check the employee "reports to" edges (label depends on the
           migration dialect) for the provider named by the
           ``GRAPH_DATABASE_PROVIDER`` environment variable (default "networkx").
        4. Check the total node and edge counts for the migration dialect.

    Raises:
        AssertionError: if any of the checks fails.
        ValueError: if the graph provider is not neo4j/kuzu/networkx, or the
            migration dialect is not sqlite/postgresql.
    """
    migration_engine = await setup_test_db()
    schema = await migration_engine.extract_schema()

    graph_engine = await get_graph_engine()
    await migrate_relational_database(graph_engine, schema=schema)

    # 1. Search the graph
    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
    )
    print("Search results:", search_results)

    # 2. Assert that the search results contain "AC/DC"
    assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!"

    migration_db_provider = migration_engine.engine.dialect.name
    # The relationship label is spelled differently for the postgres database.
    relationship_label = "reports_to" if migration_db_provider == "postgresql" else "ReportsTo"

    # 3. Directly verify the 'reports to' hierarchy
    graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()
    found_edges = await _collect_relationship_edges(
        graph_engine, graph_db_provider, relationship_label
    )
    distinct_node_names = {name for edge in found_edges for name in edge}

    assert len(distinct_node_names) == 12, (
        f"Expected 12 distinct node references, found {len(distinct_node_names)}"
    )
    assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}"

    expected_edges = {
        ("Employee:5", "Employee:2"),
        ("Employee:2", "Employee:1"),
        ("Employee:4", "Employee:2"),
        ("Employee:6", "Employee:1"),
        ("Employee:8", "Employee:6"),
        ("Employee:7", "Employee:6"),
        ("Employee:3", "Employee:2"),
    }
    for e in expected_edges:
        assert e in found_edges, f"Edge {e} not found in the actual '{relationship_label}' edges!"

    # 4. Verify the total number of nodes and edges in the graph
    # NOTE: Because of the different size of the postgres and sqlite databases,
    # different number of nodes and edges are expected
    expected_graph_size = {
        "sqlite": (543, 1317),
        "postgresql": (522, 961),
    }
    if migration_db_provider not in expected_graph_size:
        # Previously this case fell through and crashed on an unbound node_count.
        raise ValueError(f"Unsupported migration database provider: {migration_db_provider}")
    expected_node_count, expected_edge_count = expected_graph_size[migration_db_provider]

    node_count, edge_count = await _count_nodes_and_edges(graph_engine, graph_db_provider)
    assert node_count == expected_node_count, (
        f"Expected {expected_node_count} nodes, got {node_count}"
    )
    assert edge_count == expected_edge_count, (
        f"Expected {expected_edge_count} edges, got {edge_count}"
    )

    print(f"Node & edge count validated: node_count={node_count}, edge_count={edge_count}.")
    print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
async def test_migration_sqlite():
    """Run the migration checks against ``migration_database.sqlite``.

    The database is looked up in the ``test_data/`` directory next to this file.
    """
    test_data_dir = os.path.join(pathlib.Path(__file__).parent, "test_data/")
    sqlite_config = {
        "migration_db_path": test_data_dir,
        "migration_db_name": "migration_database.sqlite",
        "migration_db_provider": "sqlite",
    }
    cognee.config.set_migration_db_config(sqlite_config)

    await relational_db_migration()
async def test_migration_postgres():
    """Run the migration checks against the ``test_migration_db`` PostgreSQL database.

    The connection settings target 127.0.0.1:5432 with the ``cognee`` user.
    """
    # To run test manually you first need to run the Chinook_PostgreSql.sql script
    # in the test_data directory
    postgres_config = {
        "migration_db_name": "test_migration_db",
        "migration_db_host": "127.0.0.1",
        "migration_db_port": "5432",
        "migration_db_username": "cognee",
        "migration_db_password": "cognee",
        "migration_db_provider": "postgres",
    }
    cognee.config.set_migration_db_config(postgres_config)

    await relational_db_migration()
async def main():
    """Run the SQLite migration test, then the PostgreSQL one, announcing each."""
    stages = (
        ("Starting SQLite database migration test...", test_migration_sqlite),
        ("Starting PostgreSQL database migration test...", test_migration_postgres),
    )
    # Stages run strictly one after another, in the order listed above.
    for banner, run_stage in stages:
        print(banner)
        await run_stage()
# Script entry point: run both migration tests when this file is executed directly.
if __name__ == "__main__":
    from asyncio import run as run_event_loop

    run_event_loop(main())