From e677723952a3bffcf875a44548ab31823565a63e Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Fri, 12 Dec 2025 10:46:25 +0100 Subject: [PATCH] chore: adds tests for missing retrievers --- cognee/tests/test_search_db.py | 59 ++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 22395f6db..e279be3bf 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -16,6 +16,10 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet from cognee.modules.retrieval.graph_summary_completion_retriever import ( GraphSummaryCompletionRetriever, ) +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever from cognee.modules.retrieval.triplet_retriever import TripletRetriever from cognee.shared.logging_utils import get_logger from cognee.modules.search.types import SearchType @@ -168,6 +172,10 @@ async def test_search_db(): context_gk_cot = await GraphCompletionCotRetriever().get_context(query=query) context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(query=query) context_gk_sum = await GraphSummaryCompletionRetriever().get_context(query=query) + context_chunks = await ChunksRetriever(top_k=5).get_context(query=query) + context_summaries = await SummariesRetriever(top_k=5).get_context(query=query) + context_rag = await CompletionRetriever(top_k=3).get_context(query=query) + context_temporal = await TemporalRetriever(top_k=5).get_context(query=query) context_triplet = await TripletRetriever().get_context(query=query) for name, context in [ @@ -192,6 +200,26 @@ async def test_search_db(): f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}" ) + assert isinstance(context_chunks, list), "ChunksRetriever: Context should be a list" + assert context_chunks, "ChunksRetriever: Context should not be empty" + chunks_text = "\n".join(str(item.get("text", "")) for item in context_chunks).lower() + assert "germany" in chunks_text or "netherlands" in chunks_text, ( + "ChunksRetriever: Context did not contain 'germany' or 'netherlands'; " + f"got: {context_chunks!r}" + ) + + assert isinstance(context_summaries, list), "SummariesRetriever: Context should be a list" + assert context_summaries, "SummariesRetriever: Context should not be empty" + assert any(str(item.get("text", "")).strip() for item in context_summaries), ( + "SummariesRetriever: Expected at least one non-empty 'text' field in summary payloads" + ) + + assert isinstance(context_rag, str), "CompletionRetriever: Context should be a string" + assert context_rag.strip(), "CompletionRetriever: Context should not be empty" + + assert isinstance(context_temporal, str), "TemporalRetriever: Context should be a string" + assert context_temporal.strip(), "TemporalRetriever: Context should not be empty" + # --- Retriever triplets + vector distance validation --- triplets_gk = await GraphCompletionRetriever().get_triplets(query=query) triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(query=query) @@ -255,6 +283,26 @@ async def test_search_db(): query_text="Next to which country is Germany located?", save_interaction=True, ) + completion_chunks = await cognee.search( + query_type=SearchType.CHUNKS, + query_text="Germany", + save_interaction=False, + ) + completion_summaries = await cognee.search( + query_type=SearchType.SUMMARIES, + query_text="Germany", + save_interaction=False, + ) + completion_rag = await cognee.search( + query_type=SearchType.RAG_COMPLETION, + query_text="Next to which country is Germany located?", + save_interaction=False, + ) + completion_temporal = await cognee.search( + query_type=SearchType.TEMPORAL, + query_text="Next to which country is Germany located?", + save_interaction=False, + ) await cognee.search( query_type=SearchType.FEEDBACK, @@ -270,6 +318,8 @@ async def test_search_db(): ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext), ("GRAPH_SUMMARY_COMPLETION", completion_sum), ("TRIPLET_COMPLETION", completion_triplet), + ("RAG_COMPLETION", completion_rag), + ("TEMPORAL", completion_temporal), ]: assert isinstance(search_results, list), f"{name}: should return a list" assert len(search_results) == 1, ( @@ -286,6 +336,15 @@ async def test_search_db(): f"{name}: expected 'netherlands' in result, got: {text!r}" ) + for name, search_results in [ + ("CHUNKS", completion_chunks), + ("SUMMARIES", completion_summaries), + ]: + assert isinstance(search_results, list), f"{name}: should return a list" + assert search_results, f"{name}: should not be empty" + assert isinstance(search_results[0], dict), f"{name}: expected dict payloads" + assert str(search_results[0].get("text", "")).strip(), f"{name}: missing non-empty 'text'" + graph_engine = await get_graph_engine() graph = await graph_engine.get_graph_data()