<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Introduced a class-based retrieval mechanism to enhance answer generation with improved context extraction and completion.
  - Added a new evaluation metric for contextual relevancy and an option to enable context evaluation during the evaluation process.
- **Refactor**
  - Transitioned from a function-based answer resolver to a more modular retriever approach to improve extensibility.
- **Tests**
  - Updated tests to align with the new answer generation and evaluation process.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: lxobr <122801072+lxobr@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Daniel Molnar <soobrosa@gmail.com>
Co-authored-by: Boris <boris@topoteretes.com>
35 lines
1.3 KiB
Python
35 lines
1.3 KiB
Python
import pytest
|
|
from cognee.eval_framework.answer_generation.answer_generation_executor import (
|
|
AnswerGeneratorExecutor,
|
|
)
|
|
from cognee.eval_framework.benchmark_adapters.dummy_adapter import DummyAdapter
|
|
from unittest.mock import AsyncMock
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_answer_generation():
    """Verify the executor queries the retriever and assembles answer records.

    Question/answer pairs come from ``DummyAdapter`` and the retriever is an
    ``AsyncMock`` with canned return values. The test checks that the first
    question was awaited on ``get_context``, that one answer is produced per
    pair, and that the first answer record carries the question, the golden
    answer, and the mocked completion text.
    """
    max_items = 1
    _corpus, expected_pairs = DummyAdapter().load_corpus(limit=max_items)

    # Stand-in retriever: both coroutines return fixed values so the
    # assertions below can compare against known strings.
    retriever_stub = AsyncMock()
    retriever_stub.get_context = AsyncMock(return_value="Mocked retrieval context")
    retriever_stub.get_completion = AsyncMock(return_value=["Mocked answer"])

    executor = AnswerGeneratorExecutor()
    produced = await executor.question_answering_non_parallel(
        questions=expected_pairs,
        retriever=retriever_stub,
    )

    first_pair = expected_pairs[0]
    retriever_stub.get_context.assert_any_await(first_pair["question"])

    assert len(produced) == len(expected_pairs)

    # Indexed only after the length check so an empty result fails on the
    # assertion above rather than on the lookup.
    first_record = produced[0]
    assert first_record["question"] == first_pair["question"], (
        "AnswerGeneratorExecutor is passing the question incorrectly"
    )
    assert first_record["golden_answer"] == first_pair["answer"], (
        "AnswerGeneratorExecutor is passing the golden answer incorrectly"
    )
    assert first_record["answer"] == "Mocked answer", (
        "AnswerGeneratorExecutor is passing the generated answer incorrectly"
    )
|