adds context e2e test

This commit is contained in:
hajdul88 2025-12-11 18:36:10 +01:00
parent 714fa1f165
commit e766a2d78c

View file

@ -62,11 +62,18 @@ async def setup_search_db_environment():
if has_collection: if has_collection:
collection = await vector_engine.get_collection("Triplet_text") collection = await vector_engine.get_collection("Triplet_text")
count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown" count = await collection.count_rows() if hasattr(collection, 'count_rows') else "unknown"
logger.info(f"Triplet_text collection row count: {count}") logger.info(f"Triplet_text collection row count: {count}")
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
query = "Next to which country is Germany located?"
gk_retriever_context = await GraphCompletionCotRetriever().get_context(query=query)
gk_cot_retriever_context = await GraphCompletionCotRetriever().get_context(query=query)
gk_ext_retriever_context = await GraphCompletionContextExtensionRetriever().get_context(query=query)
gk_sum_retriever_context = await GraphSummaryCompletionRetriever().get_context(query=query)
triplet_retriever_context = await TripletRetriever().get_context(query=query)
yield { yield {
"dataset_name": dataset_name, "dataset_name": dataset_name,
@ -74,6 +81,11 @@ async def setup_search_db_environment():
"explanation_file_path_quantum": explanation_file_path_quantum, "explanation_file_path_quantum": explanation_file_path_quantum,
"graph_engine": graph_engine, "graph_engine": graph_engine,
"vector_engine": vector_engine, "vector_engine": vector_engine,
"gk_retriever_context": gk_retriever_context,
"gk_cot_retriever_context": gk_cot_retriever_context,
"gk_ext_retriever_context": gk_ext_retriever_context,
"gk_sum_retriever_context": gk_sum_retriever_context,
"triplet_retriever_context": triplet_retriever_context,
} }
logger.info("Fixture teardown: pruning data and system") logger.info("Fixture teardown: pruning data and system")
@ -130,3 +142,40 @@ async def test_graph_vector_engine_consistency(setup_search_db_environment):
assert len(edges) == len(collection), ( assert len(edges) == len(collection), (
f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection" f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
) )
@pytest.mark.asyncio
async def test_retriever_contexts(setup_search_db_environment):
"""Test that all retrievers return valid contexts with expected content."""
context_gk = setup_search_db_environment["gk_retriever_context"]
context_gk_cot = setup_search_db_environment["gk_cot_retriever_context"]
context_gk_ext = setup_search_db_environment["gk_ext_retriever_context"]
context_gk_sum = setup_search_db_environment["gk_sum_retriever_context"]
context_triplet = setup_search_db_environment["triplet_retriever_context"]
# Test graph-based retrievers (should return lists)
for name, context in [
("GraphCompletionRetriever", context_gk),
("GraphCompletionCotRetriever", context_gk_cot),
("GraphCompletionContextExtensionRetriever", context_gk_ext),
("GraphSummaryCompletionRetriever", context_gk_sum),
]:
assert isinstance(context, list), f"{name}: Context should be a list"
assert len(context) > 0, f"{name}: Context should not be empty"
context_text = await resolve_edges_to_text(context)
lower = context_text.lower()
assert "germany" in lower or "netherlands" in lower, (
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
)
# Test triplet retriever (should return string)
assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
lower_triplet = context_triplet.lower()
assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
)