chore: adds tests for missing retrievers

This commit is contained in:
hajdul88 2025-12-12 10:46:25 +01:00
parent baa158b690
commit e677723952

View file

@ -16,6 +16,10 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
@ -168,6 +172,10 @@ async def test_search_db():
context_gk_cot = await GraphCompletionCotRetriever().get_context(query=query)
context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(query=query)
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(query=query)
context_chunks = await ChunksRetriever(top_k=5).get_context(query=query)
context_summaries = await SummariesRetriever(top_k=5).get_context(query=query)
context_rag = await CompletionRetriever(top_k=3).get_context(query=query)
context_temporal = await TemporalRetriever(top_k=5).get_context(query=query)
context_triplet = await TripletRetriever().get_context(query=query)
for name, context in [
@ -192,6 +200,26 @@ async def test_search_db():
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
)
assert isinstance(context_chunks, list), "ChunksRetriever: Context should be a list"
assert context_chunks, "ChunksRetriever: Context should not be empty"
chunks_text = "\n".join(str(item.get("text", "")) for item in context_chunks).lower()
assert "germany" in chunks_text or "netherlands" in chunks_text, (
"ChunksRetriever: Context did not contain 'germany' or 'netherlands'; "
f"got: {context_chunks!r}"
)
assert isinstance(context_summaries, list), "SummariesRetriever: Context should be a list"
assert context_summaries, "SummariesRetriever: Context should not be empty"
assert any(str(item.get("text", "")).strip() for item in context_summaries), (
"SummariesRetriever: Expected at least one non-empty 'text' field in summary payloads"
)
assert isinstance(context_rag, str), "CompletionRetriever: Context should be a string"
assert context_rag.strip(), "CompletionRetriever: Context should not be empty"
assert isinstance(context_temporal, str), "TemporalRetriever: Context should be a string"
assert context_temporal.strip(), "TemporalRetriever: Context should not be empty"
# --- Retriever triplets + vector distance validation ---
triplets_gk = await GraphCompletionRetriever().get_triplets(query=query)
triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(query=query)
@ -255,6 +283,26 @@ async def test_search_db():
query_text="Next to which country is Germany located?",
save_interaction=True,
)
completion_chunks = await cognee.search(
query_type=SearchType.CHUNKS,
query_text="Germany",
save_interaction=False,
)
completion_summaries = await cognee.search(
query_type=SearchType.SUMMARIES,
query_text="Germany",
save_interaction=False,
)
completion_rag = await cognee.search(
query_type=SearchType.RAG_COMPLETION,
query_text="Next to which country is Germany located?",
save_interaction=False,
)
completion_temporal = await cognee.search(
query_type=SearchType.TEMPORAL,
query_text="Next to which country is Germany located?",
save_interaction=False,
)
await cognee.search(
query_type=SearchType.FEEDBACK,
@ -270,6 +318,8 @@ async def test_search_db():
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
("GRAPH_SUMMARY_COMPLETION", completion_sum),
("TRIPLET_COMPLETION", completion_triplet),
("RAG_COMPLETION", completion_rag),
("TEMPORAL", completion_temporal),
]:
assert isinstance(search_results, list), f"{name}: should return a list"
assert len(search_results) == 1, (
@ -286,6 +336,15 @@ async def test_search_db():
f"{name}: expected 'netherlands' in result, got: {text!r}"
)
for name, search_results in [
("CHUNKS", completion_chunks),
("SUMMARIES", completion_summaries),
]:
assert isinstance(search_results, list), f"{name}: should return a list"
assert search_results, f"{name}: should not be empty"
assert isinstance(search_results[0], dict), f"{name}: expected dict payloads"
assert str(search_results[0].get("text", "")).strip(), f"{name}: missing non-empty 'text'"
graph_engine = await get_graph_engine()
graph = await graph_engine.get_graph_data()