chore: adds tests for missing retrievers
This commit is contained in:
parent
baa158b690
commit
e677723952
1 changed files with 59 additions and 0 deletions
|
|
@ -16,6 +16,10 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
|
|||
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||
GraphSummaryCompletionRetriever,
|
||||
)
|
||||
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
|
||||
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
|
||||
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
|
||||
from cognee.modules.retrieval.triplet_retriever import TripletRetriever
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
|
|
@ -168,6 +172,10 @@ async def test_search_db():
|
|||
context_gk_cot = await GraphCompletionCotRetriever().get_context(query=query)
|
||||
context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(query=query)
|
||||
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(query=query)
|
||||
context_chunks = await ChunksRetriever(top_k=5).get_context(query=query)
|
||||
context_summaries = await SummariesRetriever(top_k=5).get_context(query=query)
|
||||
context_rag = await CompletionRetriever(top_k=3).get_context(query=query)
|
||||
context_temporal = await TemporalRetriever(top_k=5).get_context(query=query)
|
||||
context_triplet = await TripletRetriever().get_context(query=query)
|
||||
|
||||
for name, context in [
|
||||
|
|
@ -192,6 +200,26 @@ async def test_search_db():
|
|||
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
|
||||
)
|
||||
|
||||
assert isinstance(context_chunks, list), "ChunksRetriever: Context should be a list"
|
||||
assert context_chunks, "ChunksRetriever: Context should not be empty"
|
||||
chunks_text = "\n".join(str(item.get("text", "")) for item in context_chunks).lower()
|
||||
assert "germany" in chunks_text or "netherlands" in chunks_text, (
|
||||
"ChunksRetriever: Context did not contain 'germany' or 'netherlands'; "
|
||||
f"got: {context_chunks!r}"
|
||||
)
|
||||
|
||||
assert isinstance(context_summaries, list), "SummariesRetriever: Context should be a list"
|
||||
assert context_summaries, "SummariesRetriever: Context should not be empty"
|
||||
assert any(str(item.get("text", "")).strip() for item in context_summaries), (
|
||||
"SummariesRetriever: Expected at least one non-empty 'text' field in summary payloads"
|
||||
)
|
||||
|
||||
assert isinstance(context_rag, str), "CompletionRetriever: Context should be a string"
|
||||
assert context_rag.strip(), "CompletionRetriever: Context should not be empty"
|
||||
|
||||
assert isinstance(context_temporal, str), "TemporalRetriever: Context should be a string"
|
||||
assert context_temporal.strip(), "TemporalRetriever: Context should not be empty"
|
||||
|
||||
# --- Retriever triplets + vector distance validation ---
|
||||
triplets_gk = await GraphCompletionRetriever().get_triplets(query=query)
|
||||
triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets(query=query)
|
||||
|
|
@ -255,6 +283,26 @@ async def test_search_db():
|
|||
query_text="Next to which country is Germany located?",
|
||||
save_interaction=True,
|
||||
)
|
||||
completion_chunks = await cognee.search(
|
||||
query_type=SearchType.CHUNKS,
|
||||
query_text="Germany",
|
||||
save_interaction=False,
|
||||
)
|
||||
completion_summaries = await cognee.search(
|
||||
query_type=SearchType.SUMMARIES,
|
||||
query_text="Germany",
|
||||
save_interaction=False,
|
||||
)
|
||||
completion_rag = await cognee.search(
|
||||
query_type=SearchType.RAG_COMPLETION,
|
||||
query_text="Next to which country is Germany located?",
|
||||
save_interaction=False,
|
||||
)
|
||||
completion_temporal = await cognee.search(
|
||||
query_type=SearchType.TEMPORAL,
|
||||
query_text="Next to which country is Germany located?",
|
||||
save_interaction=False,
|
||||
)
|
||||
|
||||
await cognee.search(
|
||||
query_type=SearchType.FEEDBACK,
|
||||
|
|
@ -270,6 +318,8 @@ async def test_search_db():
|
|||
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
|
||||
("GRAPH_SUMMARY_COMPLETION", completion_sum),
|
||||
("TRIPLET_COMPLETION", completion_triplet),
|
||||
("RAG_COMPLETION", completion_rag),
|
||||
("TEMPORAL", completion_temporal),
|
||||
]:
|
||||
assert isinstance(search_results, list), f"{name}: should return a list"
|
||||
assert len(search_results) == 1, (
|
||||
|
|
@ -286,6 +336,15 @@ async def test_search_db():
|
|||
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
||||
)
|
||||
|
||||
for name, search_results in [
|
||||
("CHUNKS", completion_chunks),
|
||||
("SUMMARIES", completion_summaries),
|
||||
]:
|
||||
assert isinstance(search_results, list), f"{name}: should return a list"
|
||||
assert search_results, f"{name}: should not be empty"
|
||||
assert isinstance(search_results[0], dict), f"{name}: expected dict payloads"
|
||||
assert str(search_results[0].get("text", "")).strip(), f"{name}: missing non-empty 'text'"
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = await graph_engine.get_graph_data()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue