adds context e2e test
This commit is contained in:
parent
714fa1f165
commit
e766a2d78c
1 changed files with 50 additions and 1 deletions
|
|
@ -62,11 +62,18 @@ async def setup_search_db_environment():
|
|||
|
||||
if has_collection:
|
||||
collection = await vector_engine.get_collection("Triplet_text")
|
||||
count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown"
|
||||
count = await collection.count_rows() if hasattr(collection, 'count_rows') else "unknown"
|
||||
logger.info(f"Triplet_text collection row count: {count}")
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
query = "Next to which country is Germany located?"
|
||||
gk_retriever_context = await GraphCompletionCotRetriever().get_context(query=query)
|
||||
gk_cot_retriever_context = await GraphCompletionCotRetriever().get_context(query=query)
|
||||
gk_ext_retriever_context = await GraphCompletionContextExtensionRetriever().get_context(query=query)
|
||||
gk_sum_retriever_context = await GraphSummaryCompletionRetriever().get_context(query=query)
|
||||
triplet_retriever_context = await TripletRetriever().get_context(query=query)
|
||||
|
||||
yield {
|
||||
"dataset_name": dataset_name,
|
||||
|
|
@ -74,6 +81,11 @@ async def setup_search_db_environment():
|
|||
"explanation_file_path_quantum": explanation_file_path_quantum,
|
||||
"graph_engine": graph_engine,
|
||||
"vector_engine": vector_engine,
|
||||
"gk_retriever_context": gk_retriever_context,
|
||||
"gk_cot_retriever_context": gk_cot_retriever_context,
|
||||
"gk_ext_retriever_context": gk_ext_retriever_context,
|
||||
"gk_sum_retriever_context": gk_sum_retriever_context,
|
||||
"triplet_retriever_context": triplet_retriever_context,
|
||||
}
|
||||
|
||||
logger.info("Fixture teardown: pruning data and system")
|
||||
|
|
@ -130,3 +142,40 @@ async def test_graph_vector_engine_consistency(setup_search_db_environment):
|
|||
assert len(edges) == len(collection), (
|
||||
f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retriever_contexts(setup_search_db_environment):
|
||||
"""Test that all retrievers return valid contexts with expected content."""
|
||||
|
||||
|
||||
context_gk = setup_search_db_environment["gk_retriever_context"]
|
||||
context_gk_cot = setup_search_db_environment["gk_cot_retriever_context"]
|
||||
context_gk_ext = setup_search_db_environment["gk_ext_retriever_context"]
|
||||
context_gk_sum = setup_search_db_environment["gk_sum_retriever_context"]
|
||||
context_triplet = setup_search_db_environment["triplet_retriever_context"]
|
||||
|
||||
# Test graph-based retrievers (should return lists)
|
||||
for name, context in [
|
||||
("GraphCompletionRetriever", context_gk),
|
||||
("GraphCompletionCotRetriever", context_gk_cot),
|
||||
("GraphCompletionContextExtensionRetriever", context_gk_ext),
|
||||
("GraphSummaryCompletionRetriever", context_gk_sum),
|
||||
]:
|
||||
assert isinstance(context, list), f"{name}: Context should be a list"
|
||||
assert len(context) > 0, f"{name}: Context should not be empty"
|
||||
|
||||
context_text = await resolve_edges_to_text(context)
|
||||
lower = context_text.lower()
|
||||
assert "germany" in lower or "netherlands" in lower, (
|
||||
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
|
||||
)
|
||||
|
||||
# Test triplet retriever (should return string)
|
||||
assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string"
|
||||
assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty"
|
||||
lower_triplet = context_triplet.lower()
|
||||
assert "germany" in lower_triplet or "netherlands" in lower_triplet, (
|
||||
f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}"
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue