diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py index 4557e9e2f..ae06e7c5d 100644 --- a/cognee/tests/test_relational_db_migration.py +++ b/cognee/tests/test_relational_db_migration.py @@ -1,6 +1,5 @@ import pathlib import os -from typing import List from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.relational import ( get_migration_relational_engine, @@ -10,7 +9,7 @@ from cognee.infrastructure.databases.vector.pgvector import ( create_db_and_tables as create_pgvector_db_and_tables, ) from cognee.tasks.ingestion import migrate_relational_database -from cognee.modules.search.types import SearchResult, SearchType +from cognee.modules.search.types import SearchType import cognee @@ -274,6 +273,55 @@ async def test_schema_only_migration(): print(f"Edge counts: {edge_counts}") +async def test_search_result_quality(): + from cognee.infrastructure.databases.relational import ( + get_migration_relational_engine, + ) + + # Get relational database with original data + migration_engine = get_migration_relational_engine() + from sqlalchemy import text + + async with migration_engine.engine.connect() as conn: + result = await conn.execute( + text(""" + SELECT + c.CustomerId, + c.FirstName, + c.LastName, + GROUP_CONCAT(i.InvoiceId, ',') AS invoice_ids + FROM Customer AS c + LEFT JOIN Invoice AS i ON c.CustomerId = i.CustomerId + GROUP BY c.CustomerId, c.FirstName, c.LastName + """) + ) + + for row in result: + # Get expected invoice IDs from relational DB for each Customer + customer_id = row.CustomerId + invoice_ids = row.invoice_ids.split(",") if row.invoice_ids else [] + print(f"Relational DB Customer {customer_id}: {invoice_ids}") + + # Use Cognee search to get invoice IDs for the same Customer but by providing Customer name + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=f"List me all the invoices of Customer:{row.FirstName} {row.LastName}.", + top_k=50, + system_prompt="Just return me the invoiceID as a number without any text. This is an example output: ['1', '2', '3']. Where 1, 2, 3 are invoiceIDs of an invoice", + ) + print(f"Cognee search result: {search_results}") + + import ast + + lst = ast.literal_eval(search_results[0]) # converts string -> Python list + # Transfrom both lists to int for comparison, sorting and type consistency + lst = sorted([int(x) for x in lst]) + invoice_ids = sorted([int(x) for x in invoice_ids]) + assert lst == invoice_ids, ( + f"Search results {lst} do not match expected invoice IDs {invoice_ids} for Customer:{customer_id}" + ) + + async def test_migration_sqlite(): database_to_migrate_path = os.path.join(pathlib.Path(__file__).parent, "test_data/") @@ -286,6 +334,7 @@ async def test_migration_sqlite(): ) await relational_db_migration() + await test_search_result_quality() await test_schema_only_migration()