From 94bc0ef47f7aa8b07c2c3f40d0c9779432dbb7c8 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Thu, 11 Sep 2025 23:16:25 +0200 Subject: [PATCH] tests: Resolve failing search tests --- cognee/tests/test_relational_db_migration.py | 8 +++----- cognee/tests/test_search_db.py | 17 ++++++++++------- ...pletion_retriever_context_extension_test.py | 18 ++++++++++++------ .../graph_completion_retriever_cot_test.py | 18 ++++++++++++------ .../retrieval/insights_retriever_test.py | 6 ++---- 5 files changed, 39 insertions(+), 28 deletions(-) diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py index 49508144f..68b46dbf5 100644 --- a/cognee/tests/test_relational_db_migration.py +++ b/cognee/tests/test_relational_db_migration.py @@ -45,15 +45,13 @@ async def relational_db_migration(): await migrate_relational_database(graph_engine, schema=schema) # 1. Search the graph - search_results: List[SearchResult] = await cognee.search( + search_results = await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC" - ) # type: ignore + ) print("Search results:", search_results) # 2. Assert that the search results contain "AC/DC" - assert any("AC/DC" in r.search_result for r in search_results), ( - "AC/DC not found in search results!" - ) + assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!" migration_db_provider = migration_engine.engine.dialect.name if migration_db_provider == "postgresql": diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 62b07f31a..cb4636470 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -144,13 +144,16 @@ async def main(): ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext), ("GRAPH_SUMMARY_COMPLETION", completion_sum), ]: - for search_result in search_results: - completion = search_result.search_result - assert isinstance(completion, str), f"{name}: should return a string" - assert completion.strip(), f"{name}: string should not be empty" - assert "netherlands" in completion.lower(), ( - f"{name}: expected 'netherlands' in result, got: {completion!r}" - ) + assert isinstance(search_results, list), f"{name}: should return a list" + assert len(search_results) == 1, ( + f"{name}: expected single-element list, got {len(search_results)}" + ) + text = search_results[0] + assert isinstance(text, str), f"{name}: element should be a string" + assert text.strip(), f"{name}: string should not be empty" + assert "netherlands" in text.lower(), ( + f"{name}: expected 'netherlands' in result, got: {text!r}" + ) graph_engine = await get_graph_engine() graph = await graph_engine.get_graph_data() diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 02e3f73e2..74def2ae7 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -59,8 +59,10 @@ class TestGraphCompletionWithContextExtensionRetriever: answer = await retriever.get_completion("Who works at Canva?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) @pytest.mark.asyncio async def test_graph_completion_extension_context_complex(self): @@ -140,8 +142,10 @@ class TestGraphCompletionWithContextExtensionRetriever: answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) @pytest.mark.asyncio async def test_get_graph_completion_extension_context_on_empty_graph(self): @@ -171,5 +175,7 @@ class TestGraphCompletionWithContextExtensionRetriever: answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 54fa12f41..9a789a1bd 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -55,8 +55,10 @@ class TestGraphCompletionCoTRetriever: answer = await retriever.get_completion("Who works at Canva?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) @pytest.mark.asyncio async def test_graph_completion_cot_context_complex(self): @@ -133,8 +135,10 @@ class TestGraphCompletionCoTRetriever: answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) @pytest.mark.asyncio async def test_get_graph_completion_cot_context_on_empty_graph(self): @@ -164,5 +168,7 @@ class TestGraphCompletionCoTRetriever: answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) diff --git a/cognee/tests/unit/modules/retrieval/insights_retriever_test.py b/cognee/tests/unit/modules/retrieval/insights_retriever_test.py index a3d9da63a..21dbc98dd 100644 --- a/cognee/tests/unit/modules/retrieval/insights_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/insights_retriever_test.py @@ -82,7 +82,7 @@ class TestInsightsRetriever: context = await retriever.get_context("Mike") - assert context[0].node1.attributes["name"] == "Mike Broski", "Failed to get Mike Broski" + assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski" @pytest.mark.asyncio async def test_insights_context_complex(self): @@ -222,9 +222,7 @@ class TestInsightsRetriever: context = await retriever.get_context("Christina") - assert context[0].node1.attributes["name"] == "Christina Mayer", ( - "Failed to get Christina Mayer" - ) + assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer" @pytest.mark.asyncio async def test_insights_context_on_empty_graph(self):