feat: Add context evaluation to eval framework [COG-1366] (#586)
## Description

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

## Summary by CodeRabbit
- **New Features**
  - Introduced a class-based retrieval mechanism to enhance answer generation with improved context extraction and completion.
  - Added a new evaluation metric for contextual relevancy and an option to enable context evaluation during the evaluation process.
- **Refactor**
  - Transitioned from a function-based answer resolver to a more modular retriever approach to improve extensibility.
- **Tests**
  - Updated tests to align with the new answer generation and evaluation process.

---------

Co-authored-by: lxobr <122801072+lxobr@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Daniel Molnar <soobrosa@gmail.com>
Co-authored-by: Boris <boris@topoteretes.com>
parent f033f733b5 · commit 433264d4e4

7 changed files with 44 additions and 29 deletions
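For orientation before the hunks below, here is a hedged end-to-end sketch of how the pieces introduced in this PR fit together. The `EvaluationExecutor` import path, the `questions` and `system_prompt` inputs, the chosen `qa_engine` key, and the metric list are illustrative assumptions, not code from this PR:

```python
# Hedged sketch; the EvaluationExecutor module path and the inputs below are assumed.
from cognee.eval_framework.answer_generation.answer_generation_executor import (
    AnswerGeneratorExecutor,
    retriever_options,
)
from cognee.eval_framework.evaluation.evaluation_executor import EvaluationExecutor  # path assumed


async def answer_and_evaluate(questions: list[dict], system_prompt: str):
    # Pick a retriever class by qa_engine name and instantiate it with the system prompt,
    # mirroring run_question_answering in this PR.
    retriever = retriever_options["cognee_graph_completion"](system_prompt_path=system_prompt)

    # The executor now takes a retriever instead of an answer_resolver callable and
    # records the retrieval context alongside each generated answer.
    answers = await AnswerGeneratorExecutor().question_answering_non_parallel(
        questions=questions,
        retriever=retriever,
    )

    # evaluate_contexts=True makes execute() append "contextual_relevancy" to the metrics.
    evaluator = EvaluationExecutor(evaluator_engine="DeepEval", evaluate_contexts=True)
    return await evaluator.execute(answers=answers, evaluator_metrics=["correctness", "EM", "f1"])
```

The key design change is that the retriever, not a lambda, owns both context extraction and completion, so each answer record can carry a `retrieval_context` field for later context evaluation.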
```diff
@@ -1,21 +1,17 @@
-import cognee
-from typing import List, Dict, Callable, Awaitable
-from cognee.api.v1.search import SearchType
+from typing import List, Dict
+from cognee.modules.retrieval.completion_retriever import CompletionRetriever
+from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
+from cognee.modules.retrieval.graph_summary_completion_retriever import (
+    GraphSummaryCompletionRetriever,
+)
+from cognee.modules.retrieval.base_retriever import BaseRetriever
 
 
-question_answering_engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = {
-    "cognee_graph_completion": lambda query, system_prompt_path: cognee.search(
-        query_type=SearchType.GRAPH_COMPLETION,
-        query_text=query,
-        system_prompt_path=system_prompt_path,
-    ),
-    "cognee_completion": lambda query, system_prompt_path: cognee.search(
-        query_type=SearchType.COMPLETION, query_text=query, system_prompt_path=system_prompt_path
-    ),
-    "graph_summary_completion": lambda query, system_prompt_path: cognee.search(
-        query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
-        query_text=query,
-        system_prompt_path=system_prompt_path,
-    ),
+retriever_options: Dict[str, BaseRetriever] = {
+    "cognee_graph_completion": GraphCompletionRetriever,
+    "cognee_completion": CompletionRetriever,
+    "graph_summary_completion": GraphSummaryCompletionRetriever,
 }
@@ -23,20 +19,22 @@ class AnswerGeneratorExecutor:
     async def question_answering_non_parallel(
         self,
         questions: List[Dict[str, str]],
-        answer_resolver: Callable[[str], Awaitable[List[str]]],
+        retriever: BaseRetriever,
     ) -> List[Dict[str, str]]:
         answers = []
         for instance in questions:
             query_text = instance["question"]
             correct_answer = instance["answer"]
 
-            search_results = await answer_resolver(query_text)
+            retrieval_context = await retriever.get_context(query_text)
+            search_results = await retriever.get_completion(query_text, retrieval_context)
 
             answers.append(
                 {
                     "question": query_text,
                     "answer": search_results[0],
                     "golden_answer": correct_answer,
+                    "retrieval_context": retrieval_context,
                 }
             )
```
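The executor above now depends only on two awaitable methods on the retriever. A hedged sketch of the interface it assumes follows; the real `BaseRetriever` in `cognee.modules.retrieval.base_retriever` may differ in method signatures and additional members:

```python
# Sketch of the retriever contract the executor relies on; shape assumed, not copied from cognee.
from abc import ABC, abstractmethod
from typing import Any, List, Optional


class BaseRetriever(ABC):
    @abstractmethod
    async def get_context(self, query: str) -> Any:
        """Fetch the retrieval context used to ground the completion."""

    @abstractmethod
    async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
        """Produce one or more completions for the query, optionally grounded in context."""
```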
```diff
@@ -3,7 +3,7 @@ import json
 from typing import List
 from cognee.eval_framework.answer_generation.answer_generation_executor import (
     AnswerGeneratorExecutor,
-    question_answering_engine_options,
+    retriever_options,
 )
 from cognee.infrastructure.files.storage import LocalStorage
 from cognee.infrastructure.databases.relational.get_relational_engine import (
@@ -48,9 +48,7 @@ async def run_question_answering(
     answer_generator = AnswerGeneratorExecutor()
     answers = await answer_generator.question_answering_non_parallel(
         questions=questions,
-        answer_resolver=lambda query: question_answering_engine_options[params["qa_engine"]](
-            query, system_prompt
-        ),
+        retriever=retriever_options[params["qa_engine"]](system_prompt_path=system_prompt),
     )
     with open(params["answers_path"], "w", encoding="utf-8") as f:
         json.dump(answers, f, ensure_ascii=False, indent=4)
```
```diff
@@ -18,6 +18,7 @@ class EvalConfig(BaseSettings):
 
     # Evaluation params
     evaluating_answers: bool = True
+    evaluating_contexts: bool = True
     evaluation_engine: str = "DeepEval"  # Options: 'DeepEval' (uses deepeval_model), 'DirectLLM' (uses default llm from .env)
     evaluation_metrics: List[str] = [
         "correctness",
@@ -51,6 +52,7 @@ class EvalConfig(BaseSettings):
             "answering_questions": self.answering_questions,
             "qa_engine": self.qa_engine,
             "evaluating_answers": self.evaluating_answers,
+            "evaluating_contexts": self.evaluating_contexts,  # Controls whether context evaluation should be performed
             "evaluation_engine": self.evaluation_engine,
             "evaluation_metrics": self.evaluation_metrics,
             "calculate_metrics": self.calculate_metrics,
```
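Because `EvalConfig` extends pydantic's `BaseSettings`, the new flag should also be settable from the environment. A minimal sketch under that assumption; the import path and the env-var mapping (case-insensitive field name, no `env_prefix`) are assumptions, not shown in this PR:

```python
# Hedged sketch: overriding the new flag via the environment, assuming pydantic-settings
# defaults (case-insensitive field-name matching, no env_prefix) and this import path.
import os

os.environ["EVALUATING_CONTEXTS"] = "false"

from cognee.eval_framework.eval_config import EvalConfig  # import path assumed

config = EvalConfig()
assert config.evaluating_contexts is False  # "false" is coerced to a bool by pydantic
```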
```diff
@@ -5,6 +5,7 @@ from cognee.eval_framework.evaluation.base_eval_adapter import BaseEvalAdapter
 from cognee.eval_framework.evaluation.metrics.exact_match import ExactMatchMetric
 from cognee.eval_framework.evaluation.metrics.f1 import F1ScoreMetric
 from typing import Any, Dict, List
+from deepeval.metrics import ContextualRelevancyMetric
 
 
 class DeepEvalAdapter(BaseEvalAdapter):
@@ -13,6 +14,7 @@ class DeepEvalAdapter(BaseEvalAdapter):
             "correctness": self.g_eval_correctness(),
             "EM": ExactMatchMetric(),
             "f1": F1ScoreMetric(),
+            "contextual_relevancy": ContextualRelevancyMetric(),
         }
 
     async def evaluate_answers(
@@ -29,6 +31,7 @@ class DeepEvalAdapter(BaseEvalAdapter):
                 input=answer["question"],
                 actual_output=answer["answer"],
                 expected_output=answer["golden_answer"],
+                retrieval_context=[answer["retrieval_context"]],
             )
             metric_results = {}
             for metric in evaluator_metrics:
```
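For context on the new metric, a standalone sketch of deepeval's `ContextualRelevancyMetric` scoring a test case that carries `retrieval_context`. This is illustrative only, not code from this PR: the sample strings are made up, and running it requires an evaluation model configured for deepeval (for example an OpenAI key):

```python
# Standalone illustration of the metric the adapter now registers; needs a deepeval
# evaluation model configured, and the sample strings below are invented.
from deepeval.metrics import ContextualRelevancyMetric
from deepeval.test_case import LLMTestCase

test_case = LLMTestCase(
    input="What does cognee's eval framework measure?",
    actual_output="It measures answer correctness and, now, context relevancy.",
    expected_output="Answer quality metrics plus contextual relevancy.",
    retrieval_context=["The eval framework scores answers and the retrieved context."],
)

metric = ContextualRelevancyMetric()
metric.measure(test_case)           # scores how relevant retrieval_context is to the input
print(metric.score, metric.reason)  # score in [0, 1] plus the judge model's reasoning
```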
```diff
@@ -3,7 +3,11 @@ from cognee.eval_framework.evaluation.evaluator_adapters import EvaluatorAdapter
 
 
 class EvaluationExecutor:
-    def __init__(self, evaluator_engine: Union[str, EvaluatorAdapter, Any] = "DeepEval") -> None:
+    def __init__(
+        self,
+        evaluator_engine: Union[str, EvaluatorAdapter, Any] = "DeepEval",
+        evaluate_contexts: bool = False,
+    ) -> None:
         if isinstance(evaluator_engine, str):
             try:
                 adapter_enum = EvaluatorAdapter(evaluator_engine)
@@ -14,7 +18,10 @@ class EvaluationExecutor:
             self.eval_adapter = evaluator_engine.adapter_class()
         else:
             self.eval_adapter = evaluator_engine
+        self.evaluate_contexts = evaluate_contexts
 
     async def execute(self, answers: List[Dict[str, str]], evaluator_metrics: Any) -> Any:
+        if self.evaluate_contexts:
+            evaluator_metrics.append("contextual_relevancy")
         metrics = await self.eval_adapter.evaluate_answers(answers, evaluator_metrics)
         return metrics
```
```diff
@@ -42,7 +42,10 @@ async def execute_evaluation(params: dict) -> None:
         raise ValueError(f"Error decoding JSON from {params['answers_path']}: {e}")
 
     logging.info(f"Loaded {len(answers)} answers from {params['answers_path']}")
-    evaluator = EvaluationExecutor(evaluator_engine=params["evaluation_engine"])
+    evaluator = EvaluationExecutor(
+        evaluator_engine=params["evaluation_engine"],
+        evaluate_contexts=params["evaluating_contexts"],
+    )
     metrics = await evaluator.execute(
         answers=answers, evaluator_metrics=params["evaluation_metrics"]
     )
```
```diff
@@ -11,14 +11,18 @@ async def test_answer_generation():
     limit = 1
     corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)
 
-    mock_answer_resolver = AsyncMock()
-    mock_answer_resolver.side_effect = lambda query: ["mock_answer"]
+    mock_retriever = AsyncMock()
+    mock_retriever.get_context = AsyncMock(return_value="Mocked retrieval context")
+    mock_retriever.get_completion = AsyncMock(return_value=["Mocked answer"])
 
     answer_generator = AnswerGeneratorExecutor()
     answers = await answer_generator.question_answering_non_parallel(
-        questions=qa_pairs, answer_resolver=mock_answer_resolver
+        questions=qa_pairs,
+        retriever=mock_retriever,
     )
 
+    mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
+
     assert len(answers) == len(qa_pairs)
     assert answers[0]["question"] == qa_pairs[0]["question"], (
         "AnswerGeneratorExecutor is passing the question incorrectly"
@@ -26,6 +30,6 @@ async def test_answer_generation():
     assert answers[0]["golden_answer"] == qa_pairs[0]["answer"], (
         "AnswerGeneratorExecutor is passing the golden answer incorrectly"
     )
-    assert answers[0]["answer"] == "mock_answer", (
+    assert answers[0]["answer"] == "Mocked answer", (
         "AnswerGeneratorExecutor is passing the generated answer incorrectly"
     )
```