cognee/cognee/tests/test_relational_db_migration.py
Igor Ilic a5bd504daa
Relational DB migration test search (#1752)
<!-- .github/pull_request_template.md -->

## Description
Add a deterministic Cognee search test that runs after the relational DB
migration. The test gathers all relationships between Customers and their
Invoices from the migrated relational DB and then checks that Cognee
search returns the same results.

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-11-12 21:32:22 +01:00

367 lines
13 KiB
Python

import pathlib
import os
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import (
get_migration_relational_engine,
create_db_and_tables as create_relational_db_and_tables,
)
from cognee.infrastructure.databases.vector.pgvector import (
create_db_and_tables as create_pgvector_db_and_tables,
)
from cognee.tasks.ingestion import migrate_relational_database
from cognee.modules.search.types import SearchType
import cognee
def nodes_dict(nodes):
    """Build a lookup table that maps each node ID to its data.

    Args:
        nodes: Iterable of ``(node_id, data)`` pairs.

    Returns:
        A dict keyed by node ID. If an ID occurs more than once, the data of
        its last occurrence is kept.
    """
    lookup = {}
    for node_id, payload in nodes:
        lookup[node_id] = payload
    return lookup
def normalize_node_name(node_name: str) -> str:
    """Normalize the casing of the prefix in a ``prefix:suffix`` node name.

    The text before the first ``:`` is passed through ``str.capitalize`` (first
    character upper-cased, the rest lower-cased); everything after the first
    ``:`` is kept as is.

    Args:
        node_name: Node name to normalize.

    Returns:
        The normalized name, or ``node_name`` unchanged when it is empty/None
        or contains no ``:``.
    """
    # Nothing to normalize for empty names or names without a prefix separator.
    if not node_name or ":" not in node_name:
        return node_name

    head, _, tail = node_name.partition(":")
    return f"{head.capitalize()}:{tail}"
async def setup_test_db():
    """Reset Cognee state and prepare the databases used by the migration tests.

    Returns:
        The engine returned by ``get_migration_relational_engine()``.
    """
    # Backend access control has to be disabled to migrate relational data.
    os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"

    # Prune stored data first, then system state (including metadata).
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    # Create the relational and pgvector databases and tables.
    await create_relational_db_and_tables()
    await create_pgvector_db_and_tables()

    return get_migration_relational_engine()
async def _collect_relationship_edges(graph_engine, graph_db_provider, relationship_label):
    """Return the ``(source_name, target_name)`` pairs of all edges with the given label.

    Node names are passed through ``normalize_node_name``. Edges where either
    endpoint has no name are skipped for every provider.

    Raises:
        ValueError: If ``graph_db_provider`` is not ``neo4j``, ``kuzu`` or ``networkx``.
    """
    found_edges = set()

    def record(source_name, target_name):
        # Only keep edges whose both endpoints carry a usable name.
        if source_name and target_name:
            found_edges.add((source_name, target_name))

    if graph_db_provider == "neo4j":
        query_str = (
            "\n"
            f"MATCH (n)-[r:{relationship_label}]->(m)\n"
            "RETURN n, r, m\n"
        )
        rows = await graph_engine.query(query_str)
        for row in rows:
            record(
                normalize_node_name(row["n"].get("name", "")),
                normalize_node_name(row["m"].get("name", "")),
            )
    elif graph_db_provider == "kuzu":
        query_str = (
            "\n"
            "MATCH (n:Node)-[r:EDGE]->(m:Node)\n"
            f"WHERE r.relationship_name = '{relationship_label}'\n"
            "RETURN r, n, m\n"
        )
        rows = await graph_engine.query(query_str)
        for row in rows:
            # Columns follow the RETURN order: r, n, m.
            record(
                normalize_node_name(row[1].get("name", "")),
                normalize_node_name(row[2].get("name", "")),
            )
    elif graph_db_provider == "networkx":
        nodes, edges = await graph_engine.get_graph_data()
        node_map = nodes_dict(nodes)
        for src, tgt, key, _edge_data in edges:
            if key == relationship_label:
                record(
                    normalize_node_name(node_map[src].get("name")),
                    normalize_node_name(node_map[tgt].get("name")),
                )
    else:
        raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")

    return found_edges


async def _count_nodes_and_edges(graph_engine, graph_db_provider):
    """Return ``(node_count, edge_count)`` for the whole graph.

    Raises:
        ValueError: If ``graph_db_provider`` is not ``neo4j``, ``kuzu`` or ``networkx``.
    """
    if graph_db_provider == "neo4j":
        query_str = (
            "\n"
            "MATCH (n)\n"
            "WITH count(n) AS node_count\n"
            "MATCH ()-[r]->()\n"
            "RETURN node_count, count(r) AS edge_count\n"
        )
        rows = await graph_engine.query(query_str)
        return rows[0]["node_count"], rows[0]["edge_count"]

    if graph_db_provider == "kuzu":
        rows_n = await graph_engine.query("MATCH (n:Node) RETURN count(n) as c")
        rows_e = await graph_engine.query(
            "MATCH (n:Node)-[r:EDGE]->(m:Node) RETURN count(r) as c"
        )
        return rows_n[0][0], rows_e[0][0]

    if graph_db_provider == "networkx":
        nodes, edges = await graph_engine.get_graph_data()
        return len(nodes), len(edges)

    raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")


async def relational_db_migration():
    """Migrate the configured relational database into the graph and validate the result.

    After resetting state with ``setup_test_db`` and running
    ``migrate_relational_database`` on the extracted schema, the function:

    1. runs a graph-completion search about the artist AC/DC,
    2. asserts that "AC/DC" appears in the search results,
    3. verifies the 'reports to' hierarchy edges in the graph,
    4. verifies the total node and edge counts for known migration dialects.

    The graph backend is selected by the ``GRAPH_DATABASE_PROVIDER`` environment
    variable (default ``networkx``).

    Raises:
        AssertionError: If any of the checks fails.
        ValueError: If the graph database provider is not supported.
    """
    migration_engine = await setup_test_db()
    schema = await migration_engine.extract_schema()

    graph_engine = await get_graph_engine()
    await migrate_relational_database(graph_engine, schema=schema)

    # 1. Search the graph
    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
    )
    print("Search results:", search_results)

    # 2. Assert that the search results contain "AC/DC"
    assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!"

    # The relationship label differs between the PostgreSQL dialect and the others.
    migration_db_provider = migration_engine.engine.dialect.name
    relationship_label = "reports_to" if migration_db_provider == "postgresql" else "ReportsTo"

    # 3. Directly verify the 'reports to' hierarchy
    graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()
    found_edges = await _collect_relationship_edges(
        graph_engine, graph_db_provider, relationship_label
    )
    # Every name that takes part in at least one collected edge.
    distinct_node_names = {name for edge in found_edges for name in edge}

    assert len(distinct_node_names) == 12, (
        f"Expected 12 distinct node references, found {len(distinct_node_names)}"
    )
    assert len(found_edges) == 15, f"Expected 15 {relationship_label} edges, got {len(found_edges)}"

    expected_edges = {
        ("Employee:5", "Employee:2"),
        ("Employee:2", "Employee:1"),
        ("Employee:4", "Employee:2"),
        ("Employee:6", "Employee:1"),
        ("Employee:8", "Employee:6"),
        ("Employee:7", "Employee:6"),
        ("Employee:3", "Employee:2"),
    }
    for e in expected_edges:
        assert e in found_edges, f"Edge {e} not found in the actual '{relationship_label}' edges!"

    # 4. Verify the total number of nodes and edges in the graph
    # NOTE: Because of the different size of the postgres and sqlite databases,
    # different number of nodes and edges are expected
    expected_graph_size = {
        "sqlite": (543, 1317),
        "postgresql": (522, 961),
    }
    expected_size = expected_graph_size.get(migration_db_provider)
    if expected_size is None:
        # No reference counts exist for this dialect, so there is nothing to compare against.
        print(
            f"Skipping node & edge count validation for migration provider '{migration_db_provider}'."
        )
    else:
        expected_node_count, expected_edge_count = expected_size
        node_count, edge_count = await _count_nodes_and_edges(graph_engine, graph_db_provider)
        assert node_count == expected_node_count, (
            f"Expected {expected_node_count} nodes, got {node_count}"
        )
        assert edge_count == expected_edge_count, (
            f"Expected {expected_edge_count} edges, got {edge_count}"
        )
        print(f"Node & edge count validated: node_count={node_count}, edge_count={edge_count}.")

    print(f"All checks passed for {graph_db_provider} provider with '{relationship_label}' edges!")
async def test_schema_only_migration():
    """Migrate only the relational schema into the graph and validate the outcome.

    Checks that a graph-completion search mentions "11" (the expected number of
    tables) and that the graph holds the expected number of ``is_part_of``,
    ``has_relationship`` and ``foreign_key`` edges. The graph backend is taken
    from the ``GRAPH_DATABASE_PROVIDER`` environment variable (default
    ``networkx``).

    Raises:
        AssertionError: If the search answer or any edge count is unexpected.
        ValueError: If the graph database provider is not supported.
    """
    # 1. Set up the test DB and extract its schema
    relational_engine = await setup_test_db()
    db_schema = await relational_engine.extract_schema()

    # 2. Get the graph engine and migrate the schema only
    engine = await get_graph_engine()
    await migrate_relational_database(engine, schema=db_schema, schema_only=True)

    # 3. Verify the number of tables through search
    answers = await cognee.search(
        query_text="How many tables are there in this database",
        query_type=cognee.SearchType.GRAPH_COMPLETION,
        top_k=30,
    )
    assert any("11" in answer for answer in answers), (
        "Number of tables in the database reported in search_results is either None or not equal to 11"
    )

    # 4. Count the edges of each tracked relationship type
    graph_db_provider = os.getenv("GRAPH_DATABASE_PROVIDER", "networkx").lower()
    tracked_relationships = ("is_part_of", "has_relationship", "foreign_key")
    edge_counts = dict.fromkeys(tracked_relationships, 0)

    if graph_db_provider == "neo4j":
        for rel_type in tracked_relationships:
            count_query = (
                "\n"
                f"MATCH ()-[r:{rel_type}]->()\n"
                "RETURN count(r) as c\n"
            )
            rows = await engine.query(count_query)
            edge_counts[rel_type] = rows[0]["c"]
    elif graph_db_provider == "kuzu":
        for rel_type in tracked_relationships:
            count_query = (
                "\n"
                "MATCH ()-[r:EDGE]->()\n"
                f"WHERE r.relationship_name = '{rel_type}'\n"
                "RETURN count(r) as c\n"
            )
            rows = await engine.query(count_query)
            edge_counts[rel_type] = rows[0][0]
    elif graph_db_provider == "networkx":
        _nodes, graph_edges = await engine.get_graph_data()
        for _src, _tgt, relationship, _props in graph_edges:
            if relationship in edge_counts:
                edge_counts[relationship] += 1
    else:
        raise ValueError(f"Unsupported graph database provider: {graph_db_provider}")

    # 5. Assert that the counts match the expected values
    expected_counts = {
        "is_part_of": 11,
        "has_relationship": 22,
        "foreign_key": 11,
    }
    for rel_type in expected_counts:
        expected = expected_counts[rel_type]
        actual = edge_counts[rel_type]
        assert actual == expected, (
            f"Expected {expected} edges for relationship '{rel_type}', but found {actual}"
        )

    print("Schema-only migration edge counts validated successfully!")
    print(f"Edge counts: {edge_counts}")
async def test_search_result_quality():
    """Check that Cognee search returns the same invoice IDs per customer as the relational DB.

    For every customer in the migration database, the invoice IDs are read
    directly with SQL and compared (as sorted integers) with the list parsed
    from the first Cognee graph-completion search result for that customer.

    NOTE(review): the query relies on ``GROUP_CONCAT(expr, separator)``; in this
    file the function is only called from ``test_migration_sqlite`` - confirm
    dialect support before calling it for another migration provider.

    Raises:
        AssertionError: If the search returns nothing or the invoice IDs differ.
        ValueError, SyntaxError: If the first search result is not a Python literal.
    """
    # Imported once up front instead of on every loop iteration.
    import ast

    from sqlalchemy import text

    # Get relational database with original data
    migration_engine = get_migration_relational_engine()

    async with migration_engine.engine.connect() as conn:
        result = await conn.execute(
            text(
                "\n"
                "SELECT\n"
                "c.CustomerId,\n"
                "c.FirstName,\n"
                "c.LastName,\n"
                "GROUP_CONCAT(i.InvoiceId, ',') AS invoice_ids\n"
                "FROM Customer AS c\n"
                "LEFT JOIN Invoice AS i ON c.CustomerId = i.CustomerId\n"
                "GROUP BY c.CustomerId, c.FirstName, c.LastName\n"
            )
        )
        # Fetch every row while the connection is still open, so the searches
        # below do not depend on the connection or result staying alive.
        customer_rows = result.fetchall()

    for row in customer_rows:
        # Get expected invoice IDs from relational DB for each Customer
        customer_id = row.CustomerId
        invoice_ids = row.invoice_ids.split(",") if row.invoice_ids else []
        print(f"Relational DB Customer {customer_id}: {invoice_ids}")

        # Use Cognee search to get invoice IDs for the same Customer but by providing Customer name
        search_results = await cognee.search(
            query_type=SearchType.GRAPH_COMPLETION,
            query_text=f"List me all the invoices of Customer:{row.FirstName} {row.LastName}.",
            top_k=50,
            system_prompt="Just return me the invoiceID as a number without any text. This is an example output: ['1', '2', '3']. Where 1, 2, 3 are invoiceIDs of an invoice",
        )
        print(f"Cognee search result: {search_results}")

        # Fail with a readable message instead of an IndexError on an empty result.
        assert search_results, f"Cognee search returned no results for Customer:{customer_id}"

        # The first result is expected to be the string form of a Python list.
        found_ids = ast.literal_eval(search_results[0])

        # Transform both lists to int for comparison, sorting and type consistency
        found_ids = sorted(int(x) for x in found_ids)
        expected_ids = sorted(int(x) for x in invoice_ids)
        assert found_ids == expected_ids, (
            f"Search results {found_ids} do not match expected invoice IDs {expected_ids} for Customer:{customer_id}"
        )
async def test_migration_sqlite():
    """Run the migration checks against the SQLite database in ``test_data/``.

    Points the migration DB config at ``migration_database.sqlite`` next to
    this file, then runs the full migration test, the search-result quality
    test and the schema-only migration test, in that order.
    """
    test_data_dir = os.path.join(pathlib.Path(__file__).parent, "test_data/")
    sqlite_config = {
        "migration_db_path": test_data_dir,
        "migration_db_name": "migration_database.sqlite",
        "migration_db_provider": "sqlite",
    }
    cognee.config.set_migration_db_config(sqlite_config)

    # The stages run strictly one after another.
    for stage in (
        relational_db_migration,
        test_search_result_quality,
        test_schema_only_migration,
    ):
        await stage()
async def test_migration_postgres():
    """Run the migration checks against a PostgreSQL migration database.

    To run the test manually, first run the ``Chinook_PostgreSql.sql`` script
    from the ``test_data`` directory. Runs the full migration test followed by
    the schema-only migration test.
    """
    postgres_config = dict(
        migration_db_name="test_migration_db",
        migration_db_host="127.0.0.1",
        migration_db_port="5432",
        migration_db_username="cognee",
        migration_db_password="cognee",
        migration_db_provider="postgres",
    )
    cognee.config.set_migration_db_config(postgres_config)

    for stage in (relational_db_migration, test_schema_only_migration):
        await stage()
async def main():
    """Run the SQLite migration test and then the PostgreSQL migration test."""
    stages = (
        ("Starting SQLite database migration test...", test_migration_sqlite),
        ("Starting PostgreSQL database migration test...", test_migration_postgres),
    )
    for banner, run_stage in stages:
        print(banner)
        await run_stage()
if __name__ == "__main__":
    # Script entry point: run both migration tests on a fresh event loop.
    from asyncio import run as run_event_loop

    run_event_loop(main())