From e766a2d78c3fdc35c6a9f6fcdde8b0acec0e6e0e Mon Sep 17 00:00:00 2001
From: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Date: Thu, 11 Dec 2025 18:36:10 +0100
Subject: [PATCH] adds context e2e test

---
Review notes (placed below the three-dash marker, so "git am" does not
copy them into the commit message):

* gk_retriever_context is now produced by GraphCompletionRetriever. It was
  produced by GraphCompletionCotRetriever, which duplicated
  gk_cot_retriever_context, while test_retriever_contexts reports that
  value under the label "GraphCompletionRetriever".
* Line breaks and indentation were restored from a flattened copy of this
  patch. Blank context lines were not recoverable; their placement in the
  first hunk is a best guess chosen to match the hunk header counts.
  Run "git apply --check" against the target tree before applying.
* The blob ids on the "index" line predate the retriever fix.

 cognee/tests/test_search_db.py | 51 +++++++++++++++++++++++++++++++++-
 1 file changed, 50 insertions(+), 1 deletion(-)

diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py
index d3b6ba0a2..2a79a16b2 100644
--- a/cognee/tests/test_search_db.py
+++ b/cognee/tests/test_search_db.py
@@ -62,11 +62,18 @@ async def setup_search_db_environment():
 
     if has_collection:
         collection = await vector_engine.get_collection("Triplet_text")
-        count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown"
+        count = await collection.count_rows() if hasattr(collection, 'count_rows') else "unknown"
         logger.info(f"Triplet_text collection row count: {count}")
 
     graph_engine = await get_graph_engine()
 
+    vector_engine = get_vector_engine()
+    query = "Next to which country is Germany located?"
+    gk_retriever_context = await GraphCompletionRetriever().get_context(query=query)
+    gk_cot_retriever_context = await GraphCompletionCotRetriever().get_context(query=query)
+    gk_ext_retriever_context = await GraphCompletionContextExtensionRetriever().get_context(query=query)
+    gk_sum_retriever_context = await GraphSummaryCompletionRetriever().get_context(query=query)
+    triplet_retriever_context = await TripletRetriever().get_context(query=query)
 
     yield {
         "dataset_name": dataset_name,
@@ -74,6 +81,11 @@ async def setup_search_db_environment():
         "explanation_file_path_quantum": explanation_file_path_quantum,
         "graph_engine": graph_engine,
         "vector_engine": vector_engine,
+        "gk_retriever_context": gk_retriever_context,
+        "gk_cot_retriever_context": gk_cot_retriever_context,
+        "gk_ext_retriever_context": gk_ext_retriever_context,
+        "gk_sum_retriever_context": gk_sum_retriever_context,
+        "triplet_retriever_context": triplet_retriever_context,
     }
 
     logger.info("Fixture teardown: pruning data and system")
@@ -130,3 +142,40 @@ async def test_graph_vector_engine_consistency(setup_search_db_environment):
     assert len(edges) == len(collection), (
         f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
     )
+
+
+@pytest.mark.asyncio
+async def test_retriever_contexts(setup_search_db_environment):
+    """Test that all retrievers return valid contexts with expected content."""
+
+
+    context_gk = setup_search_db_environment["gk_retriever_context"]
+    context_gk_cot = setup_search_db_environment["gk_cot_retriever_context"]
+    context_gk_ext = setup_search_db_environment["gk_ext_retriever_context"]
+    context_gk_sum = setup_search_db_environment["gk_sum_retriever_context"]
+    context_triplet = setup_search_db_environment["triplet_retriever_context"]
+
+    # Test graph-based retrievers (should return lists)
+    for name, context in [
+        ("GraphCompletionRetriever", context_gk),
+        ("GraphCompletionCotRetriever", context_gk_cot),
+        ("GraphCompletionContextExtensionRetriever", context_gk_ext),
+        ("GraphSummaryCompletionRetriever", context_gk_sum),
+    ]:
+        assert isinstance(context, list), f"{name}: Context should be a list"
+        assert len(context) > 0, f"{name}: Context should not be empty"
+
+        context_text = await resolve_edges_to_text(context)
+        lower = context_text.lower()
+        assert "germany" in lower or "netherlands" in lower, (
+            f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
+        )
+
+    # Test triplet retriever (should return string)
+    assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
+    assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
+    lower_triplet = context_triplet.lower()
+    assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
+        f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
+    )
+