feat: Add context evaluation to eval framework [COG-1366] (#586)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->
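Moves the eval framework's answer generation from ad-hoc `answer_resolver` functions to retriever classes (`CompletionRetriever`, `GraphCompletionRetriever`, `GraphSummaryCompletionRetriever`), stores the retrieval context alongside each generated answer, and adds an opt-in `contextual_relevancy` metric (DeepEval's `ContextualRelevancyMetric`) controlled by the new `evaluating_contexts` config flag.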

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Introduced a class-based retrieval mechanism to enhance answer generation with improved context extraction and completion.
  - Added a new evaluation metric for contextual relevancy and an option to enable context evaluation during the evaluation process.

- **Refactor**
  - Transitioned from a function-based answer resolver to a more modular retriever approach to improve extensibility.

- **Tests**
  - Updated tests to align with the new answer generation and evaluation process.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: lxobr <122801072+lxobr@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Daniel Molnar <soobrosa@gmail.com>
Co-authored-by: Boris <boris@topoteretes.com>
commit 433264d4e4 (parent f033f733b5)
alekszievr, 2025-03-05 16:40:24 +01:00, committed by GitHub
7 changed files with 44 additions and 29 deletions


@@ -1,21 +1,17 @@
-import cognee
-from typing import List, Dict, Callable, Awaitable
-from cognee.api.v1.search import SearchType
+from typing import List, Dict
+from cognee.modules.retrieval.completion_retriever import CompletionRetriever
+from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
+from cognee.modules.retrieval.graph_summary_completion_retriever import (
+    GraphSummaryCompletionRetriever,
+)
 
-question_answering_engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = {
-    "cognee_graph_completion": lambda query, system_prompt_path: cognee.search(
-        query_type=SearchType.GRAPH_COMPLETION,
-        query_text=query,
-        system_prompt_path=system_prompt_path,
-    ),
-    "cognee_completion": lambda query, system_prompt_path: cognee.search(
-        query_type=SearchType.COMPLETION, query_text=query, system_prompt_path=system_prompt_path
-    ),
-    "graph_summary_completion": lambda query, system_prompt_path: cognee.search(
-        query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
-        query_text=query,
-        system_prompt_path=system_prompt_path,
-    ),
+from cognee.modules.retrieval.base_retriever import BaseRetriever
+retriever_options: Dict[str, BaseRetriever] = {
+    "cognee_graph_completion": GraphCompletionRetriever,
+    "cognee_completion": CompletionRetriever,
+    "graph_summary_completion": GraphSummaryCompletionRetriever,
 }

@@ -23,20 +19,22 @@ class AnswerGeneratorExecutor:
     async def question_answering_non_parallel(
         self,
         questions: List[Dict[str, str]],
-        answer_resolver: Callable[[str], Awaitable[List[str]]],
+        retriever: BaseRetriever,
     ) -> List[Dict[str, str]]:
         answers = []
         for instance in questions:
             query_text = instance["question"]
             correct_answer = instance["answer"]
-            search_results = await answer_resolver(query_text)
+            retrieval_context = await retriever.get_context(query_text)
+            search_results = await retriever.get_completion(query_text, retrieval_context)
             answers.append(
                 {
                     "question": query_text,
                     "answer": search_results[0],
                     "golden_answer": correct_answer,
+                    "retrieval_context": retrieval_context,
                 }
             )
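For orientation, a minimal sketch of driving the reworked executor with one of the registered retrievers, mirroring what `run_question_answering` does in the next file. It assumes a configured cognee environment with data already added and cognified; the question list and prompt path are placeholders.

```python
import asyncio

from cognee.eval_framework.answer_generation.answer_generation_executor import (
    AnswerGeneratorExecutor,
    retriever_options,
)


async def main() -> None:
    # Placeholder QA pair; a real run loads these from a benchmark adapter.
    questions = [{"question": "What does the eval framework measure?", "answer": "Answer quality."}]

    # Pick a retriever class by name and instantiate it, as run_question_answering does.
    retriever = retriever_options["cognee_completion"](system_prompt_path="path/to/system_prompt.txt")

    answers = await AnswerGeneratorExecutor().question_answering_non_parallel(
        questions=questions,
        retriever=retriever,
    )
    # Each record now carries the retrieval context next to the generated answer.
    print(answers[0]["retrieval_context"])


asyncio.run(main())
```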


@@ -3,7 +3,7 @@ import json
 from typing import List
 from cognee.eval_framework.answer_generation.answer_generation_executor import (
     AnswerGeneratorExecutor,
-    question_answering_engine_options,
+    retriever_options,
 )
 from cognee.infrastructure.files.storage import LocalStorage
 from cognee.infrastructure.databases.relational.get_relational_engine import (

@@ -48,9 +48,7 @@ async def run_question_answering(
     answer_generator = AnswerGeneratorExecutor()
     answers = await answer_generator.question_answering_non_parallel(
         questions=questions,
-        answer_resolver=lambda query: question_answering_engine_options[params["qa_engine"]](
-            query, system_prompt
-        ),
+        retriever=retriever_options[params["qa_engine"]](system_prompt_path=system_prompt),
     )
     with open(params["answers_path"], "w", encoding="utf-8") as f:
         json.dump(answers, f, ensure_ascii=False, indent=4)
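For clarity, this is the shape of one record that now gets dumped to `params["answers_path"]` (field names come from `AnswerGeneratorExecutor` above; the values are illustrative).

```python
# One entry of the answers JSON written by run_question_answering.
example_record = {
    "question": "What does the eval framework measure?",         # benchmark question
    "answer": "It scores generated answers and their context.",  # retriever completion
    "golden_answer": "Answer and context quality.",              # expected answer
    "retrieval_context": "Context returned by retriever.get_context().",  # new field
}
```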


@@ -18,6 +18,7 @@ class EvalConfig(BaseSettings):
 
     # Evaluation params
     evaluating_answers: bool = True
+    evaluating_contexts: bool = True
     evaluation_engine: str = "DeepEval"  # Options: 'DeepEval' (uses deepeval_model), 'DirectLLM' (uses default llm from .env)
     evaluation_metrics: List[str] = [
         "correctness",

@@ -51,6 +52,7 @@ class EvalConfig(BaseSettings):
             "answering_questions": self.answering_questions,
             "qa_engine": self.qa_engine,
             "evaluating_answers": self.evaluating_answers,
+            "evaluating_contexts": self.evaluating_contexts,  # Controls whether context evaluation should be performed
             "evaluation_engine": self.evaluation_engine,
             "evaluation_metrics": self.evaluation_metrics,
             "calculate_metrics": self.calculate_metrics,


@@ -5,6 +5,7 @@ from cognee.eval_framework.evaluation.base_eval_adapter import BaseEvalAdapter
 from cognee.eval_framework.evaluation.metrics.exact_match import ExactMatchMetric
 from cognee.eval_framework.evaluation.metrics.f1 import F1ScoreMetric
 from typing import Any, Dict, List
+from deepeval.metrics import ContextualRelevancyMetric
 
 
 class DeepEvalAdapter(BaseEvalAdapter):

@@ -13,6 +14,7 @@ class DeepEvalAdapter(BaseEvalAdapter):
             "correctness": self.g_eval_correctness(),
             "EM": ExactMatchMetric(),
             "f1": F1ScoreMetric(),
+            "contextual_relevancy": ContextualRelevancyMetric(),
         }
 
     async def evaluate_answers(

@@ -29,6 +31,7 @@ class DeepEvalAdapter(BaseEvalAdapter):
                 input=answer["question"],
                 actual_output=answer["answer"],
                 expected_output=answer["golden_answer"],
+                retrieval_context=[answer["retrieval_context"]],
             )
             metric_results = {}
             for metric in evaluator_metrics:
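As a standalone illustration of what the adapter now assembles per answer, a sketch using deepeval directly; running `measure()` requires an evaluation LLM to be configured, and all texts below are placeholders.

```python
from deepeval.metrics import ContextualRelevancyMetric
from deepeval.test_case import LLMTestCase

# Mirror of the test case built in evaluate_answers, including the new
# retrieval_context field (deepeval expects a list of context strings).
test_case = LLMTestCase(
    input="What does the eval framework measure?",
    actual_output="It scores generated answers and their retrieval context.",
    expected_output="Answer and context quality.",
    retrieval_context=["Context returned by the retriever for this question."],
)

metric = ContextualRelevancyMetric()
metric.measure(test_case)  # judges whether the retrieved context is relevant to the input
print(metric.score, metric.reason)
```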


@@ -3,7 +3,11 @@ from cognee.eval_framework.evaluation.evaluator_adapters import EvaluatorAdapter
 
 
 class EvaluationExecutor:
-    def __init__(self, evaluator_engine: Union[str, EvaluatorAdapter, Any] = "DeepEval") -> None:
+    def __init__(
+        self,
+        evaluator_engine: Union[str, EvaluatorAdapter, Any] = "DeepEval",
+        evaluate_contexts: bool = False,
+    ) -> None:
         if isinstance(evaluator_engine, str):
             try:
                 adapter_enum = EvaluatorAdapter(evaluator_engine)

@@ -14,7 +18,10 @@ class EvaluationExecutor:
             self.eval_adapter = evaluator_engine.adapter_class()
         else:
             self.eval_adapter = evaluator_engine
+        self.evaluate_contexts = evaluate_contexts
 
     async def execute(self, answers: List[Dict[str, str]], evaluator_metrics: Any) -> Any:
+        if self.evaluate_contexts:
+            evaluator_metrics.append("contextual_relevancy")
         metrics = await self.eval_adapter.evaluate_answers(answers, evaluator_metrics)
         return metrics
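A usage sketch of the extended executor (module path assumed, answers illustrative): with `evaluate_contexts=True`, `"contextual_relevancy"` is appended to the metric list before the adapter evaluates, so each answer dict must carry a `retrieval_context` field.

```python
import asyncio

# Assumed module path for the executor shown above.
from cognee.eval_framework.evaluation.evaluation_executor import EvaluationExecutor

answers = [
    {
        "question": "What does the eval framework measure?",
        "answer": "Answer and context quality.",
        "golden_answer": "Answer and context quality.",
        "retrieval_context": "Context returned by the retriever for this question.",
    }
]

evaluator = EvaluationExecutor(evaluator_engine="DeepEval", evaluate_contexts=True)
# Running the evaluation requires the DeepEval-configured LLM; the metric list
# gets "contextual_relevancy" appended because evaluate_contexts is enabled.
metrics = asyncio.run(evaluator.execute(answers=answers, evaluator_metrics=["correctness"]))
```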


@@ -42,7 +42,10 @@ async def execute_evaluation(params: dict) -> None:
         raise ValueError(f"Error decoding JSON from {params['answers_path']}: {e}")
     logging.info(f"Loaded {len(answers)} answers from {params['answers_path']}")
 
-    evaluator = EvaluationExecutor(evaluator_engine=params["evaluation_engine"])
+    evaluator = EvaluationExecutor(
+        evaluator_engine=params["evaluation_engine"],
+        evaluate_contexts=params["evaluating_contexts"],
+    )
     metrics = await evaluator.execute(
         answers=answers, evaluator_metrics=params["evaluation_metrics"]
     )


@@ -11,14 +11,18 @@ async def test_answer_generation():
     limit = 1
     corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)
 
-    mock_answer_resolver = AsyncMock()
-    mock_answer_resolver.side_effect = lambda query: ["mock_answer"]
+    mock_retriever = AsyncMock()
+    mock_retriever.get_context = AsyncMock(return_value="Mocked retrieval context")
+    mock_retriever.get_completion = AsyncMock(return_value=["Mocked answer"])
 
     answer_generator = AnswerGeneratorExecutor()
     answers = await answer_generator.question_answering_non_parallel(
-        questions=qa_pairs, answer_resolver=mock_answer_resolver
+        questions=qa_pairs,
+        retriever=mock_retriever,
     )
 
+    mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
     assert len(answers) == len(qa_pairs)
     assert answers[0]["question"] == qa_pairs[0]["question"], (
         "AnswerGeneratorExecutor is passing the question incorrectly"

@@ -26,6 +30,6 @@ async def test_answer_generation():
     assert answers[0]["golden_answer"] == qa_pairs[0]["answer"], (
         "AnswerGeneratorExecutor is passing the golden answer incorrectly"
     )
-    assert answers[0]["answer"] == "mock_answer", (
+    assert answers[0]["answer"] == "Mocked answer", (
         "AnswerGeneratorExecutor is passing the generated answer incorrectly"
     )