cognee/cognee/tests/unit/eval_framework/answer_generation_test.py
alekszievr 433264d4e4
feat: Add context evaluation to eval framework [COG-1366] (#586)

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


## Summary by CodeRabbit

- **New Features**
  - Introduced a class-based retrieval mechanism to enhance answer generation with improved context extraction and completion (a minimal sketch of the retriever interface follows below).
  - Added a new evaluation metric for contextual relevancy and an option to enable context evaluation during the evaluation process.

- **Refactor**
  - Transitioned from a function-based answer resolver to a more modular retriever approach to improve extensibility.

- **Tests**
  - Updated tests to align with the new answer generation and evaluation process.
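The "class-based retrieval mechanism" above is what the updated test mocks out: an object exposing async `get_context` and `get_completion` methods that `AnswerGeneratorExecutor` drives. A minimal sketch, assuming only the method names visible in the test; the class name, signatures, and internals below are illustrative, not the actual cognee retriever:

```python
# Illustrative only: get_context / get_completion come from the test below;
# the class name, signatures, and internals are assumptions.
class ExampleRetriever:
    async def get_context(self, query: str) -> str:
        # A real retriever would look up supporting context for the question
        # (e.g. from a graph or vector store) here.
        return f"context for: {query}"

    async def get_completion(self, query: str, context: str | None = None) -> list[str]:
        # A real retriever would call an LLM with the query plus retrieved
        # context and return the generated answer(s) as a list.
        return [f"answer to: {query}"]
```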

---------

Co-authored-by: lxobr <122801072+lxobr@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Daniel Molnar <soobrosa@gmail.com>
Co-authored-by: Boris <boris@topoteretes.com>
2025-03-05 16:40:24 +01:00

35 lines
1.3 KiB
Python

import pytest
from cognee.eval_framework.answer_generation.answer_generation_executor import (
    AnswerGeneratorExecutor,
)
from cognee.eval_framework.benchmark_adapters.dummy_adapter import DummyAdapter
from unittest.mock import AsyncMock


@pytest.mark.asyncio
async def test_answer_generation():
    limit = 1
    corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)

    mock_retriever = AsyncMock()
    mock_retriever.get_context = AsyncMock(return_value="Mocked retrieval context")
    mock_retriever.get_completion = AsyncMock(return_value=["Mocked answer"])

    answer_generator = AnswerGeneratorExecutor()
    answers = await answer_generator.question_answering_non_parallel(
        questions=qa_pairs,
        retriever=mock_retriever,
    )

    mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
    assert len(answers) == len(qa_pairs)
    assert answers[0]["question"] == qa_pairs[0]["question"], (
        "AnswerGeneratorExecutor is passing the question incorrectly"
    )
    assert answers[0]["golden_answer"] == qa_pairs[0]["answer"], (
        "AnswerGeneratorExecutor is passing the golden answer incorrectly"
    )
    assert answers[0]["answer"] == "Mocked answer", (
        "AnswerGeneratorExecutor is passing the generated answer incorrectly"
    )
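For orientation, the mocks and assertions above pin down roughly this per-question flow inside `question_answering_non_parallel`. The snippet below is a hedged reconstruction inferred from the test alone, not the actual `AnswerGeneratorExecutor` implementation; argument names and any extra behaviour in the real method may differ:

```python
# Reconstruction inferred from the test's mocks and assertions; illustrative only.
async def question_answering_non_parallel(questions, retriever):
    answers = []
    for qa in questions:
        # The test asserts get_context is awaited with the question text.
        context = await retriever.get_context(qa["question"])
        # get_completion is mocked to return a list; the first element becomes
        # the answer. The exact arguments passed here are an assumption.
        completions = await retriever.get_completion(qa["question"], context)
        answers.append(
            {
                "question": qa["question"],      # must echo the input question
                "answer": completions[0],        # must equal "Mocked answer" in the test
                "golden_answer": qa["answer"],   # must echo the benchmark answer
            }
        )
    return answers
```

To run the test itself: `pytest cognee/tests/unit/eval_framework/answer_generation_test.py` (requires `pytest-asyncio` for the `asyncio` marker).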