feat: Add context evaluation to eval framework [COG-1366] (#586)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Introduced a class-based retrieval mechanism to enhance answer generation with improved context extraction and completion.
  - Added a new evaluation metric for contextual relevancy and an option to enable context evaluation during the evaluation process.
- **Refactor**
  - Transitioned from a function-based answer resolver to a more modular retriever approach to improve extensibility.
- **Tests**
  - Updated tests to align with the new answer generation and evaluation process.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: lxobr <122801072+lxobr@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Daniel Molnar <soobrosa@gmail.com>
Co-authored-by: Boris <boris@topoteretes.com>
This commit is contained in:
parent
f033f733b5
commit
433264d4e4
7 changed files with 44 additions and 29 deletions
|
|
@ -1,21 +1,17 @@
|
||||||
import cognee
|
from typing import List, Dict
|
||||||
from typing import List, Dict, Callable, Awaitable
|
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
|
||||||
from cognee.api.v1.search import SearchType
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
from cognee.modules.retrieval.graph_summary_completion_retriever import (
|
||||||
|
GraphSummaryCompletionRetriever,
|
||||||
|
)
|
||||||
|
|
||||||
question_answering_engine_options: Dict[str, Callable[[str, str], Awaitable[List[str]]]] = {
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||||
"cognee_graph_completion": lambda query, system_prompt_path: cognee.search(
|
|
||||||
query_type=SearchType.GRAPH_COMPLETION,
|
|
||||||
query_text=query,
|
retriever_options: Dict[str, BaseRetriever] = {
|
||||||
system_prompt_path=system_prompt_path,
|
"cognee_graph_completion": GraphCompletionRetriever,
|
||||||
),
|
"cognee_completion": CompletionRetriever,
|
||||||
"cognee_completion": lambda query, system_prompt_path: cognee.search(
|
"graph_summary_completion": GraphSummaryCompletionRetriever,
|
||||||
query_type=SearchType.COMPLETION, query_text=query, system_prompt_path=system_prompt_path
|
|
||||||
),
|
|
||||||
"graph_summary_completion": lambda query, system_prompt_path: cognee.search(
|
|
||||||
query_type=SearchType.GRAPH_SUMMARY_COMPLETION,
|
|
||||||
query_text=query,
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -23,20 +19,22 @@ class AnswerGeneratorExecutor:
|
||||||
async def question_answering_non_parallel(
|
async def question_answering_non_parallel(
|
||||||
self,
|
self,
|
||||||
questions: List[Dict[str, str]],
|
questions: List[Dict[str, str]],
|
||||||
answer_resolver: Callable[[str], Awaitable[List[str]]],
|
retriever: BaseRetriever,
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
answers = []
|
answers = []
|
||||||
for instance in questions:
|
for instance in questions:
|
||||||
query_text = instance["question"]
|
query_text = instance["question"]
|
||||||
correct_answer = instance["answer"]
|
correct_answer = instance["answer"]
|
||||||
|
|
||||||
search_results = await answer_resolver(query_text)
|
retrieval_context = await retriever.get_context(query_text)
|
||||||
|
search_results = await retriever.get_completion(query_text, retrieval_context)
|
||||||
|
|
||||||
answers.append(
|
answers.append(
|
||||||
{
|
{
|
||||||
"question": query_text,
|
"question": query_text,
|
||||||
"answer": search_results[0],
|
"answer": search_results[0],
|
||||||
"golden_answer": correct_answer,
|
"golden_answer": correct_answer,
|
||||||
|
"retrieval_context": retrieval_context,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import json
|
||||||
from typing import List
|
from typing import List
|
||||||
from cognee.eval_framework.answer_generation.answer_generation_executor import (
|
from cognee.eval_framework.answer_generation.answer_generation_executor import (
|
||||||
AnswerGeneratorExecutor,
|
AnswerGeneratorExecutor,
|
||||||
question_answering_engine_options,
|
retriever_options,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.files.storage import LocalStorage
|
from cognee.infrastructure.files.storage import LocalStorage
|
||||||
from cognee.infrastructure.databases.relational.get_relational_engine import (
|
from cognee.infrastructure.databases.relational.get_relational_engine import (
|
||||||
|
|
@ -48,9 +48,7 @@ async def run_question_answering(
|
||||||
answer_generator = AnswerGeneratorExecutor()
|
answer_generator = AnswerGeneratorExecutor()
|
||||||
answers = await answer_generator.question_answering_non_parallel(
|
answers = await answer_generator.question_answering_non_parallel(
|
||||||
questions=questions,
|
questions=questions,
|
||||||
answer_resolver=lambda query: question_answering_engine_options[params["qa_engine"]](
|
retriever=retriever_options[params["qa_engine"]](system_prompt_path=system_prompt),
|
||||||
query, system_prompt
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
with open(params["answers_path"], "w", encoding="utf-8") as f:
|
with open(params["answers_path"], "w", encoding="utf-8") as f:
|
||||||
json.dump(answers, f, ensure_ascii=False, indent=4)
|
json.dump(answers, f, ensure_ascii=False, indent=4)
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ class EvalConfig(BaseSettings):
|
||||||
|
|
||||||
# Evaluation params
|
# Evaluation params
|
||||||
evaluating_answers: bool = True
|
evaluating_answers: bool = True
|
||||||
|
evaluating_contexts: bool = True
|
||||||
evaluation_engine: str = "DeepEval" # Options: 'DeepEval' (uses deepeval_model), 'DirectLLM' (uses default llm from .env)
|
evaluation_engine: str = "DeepEval" # Options: 'DeepEval' (uses deepeval_model), 'DirectLLM' (uses default llm from .env)
|
||||||
evaluation_metrics: List[str] = [
|
evaluation_metrics: List[str] = [
|
||||||
"correctness",
|
"correctness",
|
||||||
|
|
@ -51,6 +52,7 @@ class EvalConfig(BaseSettings):
|
||||||
"answering_questions": self.answering_questions,
|
"answering_questions": self.answering_questions,
|
||||||
"qa_engine": self.qa_engine,
|
"qa_engine": self.qa_engine,
|
||||||
"evaluating_answers": self.evaluating_answers,
|
"evaluating_answers": self.evaluating_answers,
|
||||||
|
"evaluating_contexts": self.evaluating_contexts, # Controls whether context evaluation should be performed
|
||||||
"evaluation_engine": self.evaluation_engine,
|
"evaluation_engine": self.evaluation_engine,
|
||||||
"evaluation_metrics": self.evaluation_metrics,
|
"evaluation_metrics": self.evaluation_metrics,
|
||||||
"calculate_metrics": self.calculate_metrics,
|
"calculate_metrics": self.calculate_metrics,
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from cognee.eval_framework.evaluation.base_eval_adapter import BaseEvalAdapter
|
||||||
from cognee.eval_framework.evaluation.metrics.exact_match import ExactMatchMetric
|
from cognee.eval_framework.evaluation.metrics.exact_match import ExactMatchMetric
|
||||||
from cognee.eval_framework.evaluation.metrics.f1 import F1ScoreMetric
|
from cognee.eval_framework.evaluation.metrics.f1 import F1ScoreMetric
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
from deepeval.metrics import ContextualRelevancyMetric
|
||||||
|
|
||||||
|
|
||||||
class DeepEvalAdapter(BaseEvalAdapter):
|
class DeepEvalAdapter(BaseEvalAdapter):
|
||||||
|
|
@ -13,6 +14,7 @@ class DeepEvalAdapter(BaseEvalAdapter):
|
||||||
"correctness": self.g_eval_correctness(),
|
"correctness": self.g_eval_correctness(),
|
||||||
"EM": ExactMatchMetric(),
|
"EM": ExactMatchMetric(),
|
||||||
"f1": F1ScoreMetric(),
|
"f1": F1ScoreMetric(),
|
||||||
|
"contextual_relevancy": ContextualRelevancyMetric(),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def evaluate_answers(
|
async def evaluate_answers(
|
||||||
|
|
@ -29,6 +31,7 @@ class DeepEvalAdapter(BaseEvalAdapter):
|
||||||
input=answer["question"],
|
input=answer["question"],
|
||||||
actual_output=answer["answer"],
|
actual_output=answer["answer"],
|
||||||
expected_output=answer["golden_answer"],
|
expected_output=answer["golden_answer"],
|
||||||
|
retrieval_context=[answer["retrieval_context"]],
|
||||||
)
|
)
|
||||||
metric_results = {}
|
metric_results = {}
|
||||||
for metric in evaluator_metrics:
|
for metric in evaluator_metrics:
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,11 @@ from cognee.eval_framework.evaluation.evaluator_adapters import EvaluatorAdapter
|
||||||
|
|
||||||
|
|
||||||
class EvaluationExecutor:
|
class EvaluationExecutor:
|
||||||
def __init__(self, evaluator_engine: Union[str, EvaluatorAdapter, Any] = "DeepEval") -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
evaluator_engine: Union[str, EvaluatorAdapter, Any] = "DeepEval",
|
||||||
|
evaluate_contexts: bool = False,
|
||||||
|
) -> None:
|
||||||
if isinstance(evaluator_engine, str):
|
if isinstance(evaluator_engine, str):
|
||||||
try:
|
try:
|
||||||
adapter_enum = EvaluatorAdapter(evaluator_engine)
|
adapter_enum = EvaluatorAdapter(evaluator_engine)
|
||||||
|
|
@ -14,7 +18,10 @@ class EvaluationExecutor:
|
||||||
self.eval_adapter = evaluator_engine.adapter_class()
|
self.eval_adapter = evaluator_engine.adapter_class()
|
||||||
else:
|
else:
|
||||||
self.eval_adapter = evaluator_engine
|
self.eval_adapter = evaluator_engine
|
||||||
|
self.evaluate_contexts = evaluate_contexts
|
||||||
|
|
||||||
async def execute(self, answers: List[Dict[str, str]], evaluator_metrics: Any) -> Any:
|
async def execute(self, answers: List[Dict[str, str]], evaluator_metrics: Any) -> Any:
|
||||||
|
if self.evaluate_contexts:
|
||||||
|
evaluator_metrics.append("contextual_relevancy")
|
||||||
metrics = await self.eval_adapter.evaluate_answers(answers, evaluator_metrics)
|
metrics = await self.eval_adapter.evaluate_answers(answers, evaluator_metrics)
|
||||||
return metrics
|
return metrics
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,10 @@ async def execute_evaluation(params: dict) -> None:
|
||||||
raise ValueError(f"Error decoding JSON from {params['answers_path']}: {e}")
|
raise ValueError(f"Error decoding JSON from {params['answers_path']}: {e}")
|
||||||
|
|
||||||
logging.info(f"Loaded {len(answers)} answers from {params['answers_path']}")
|
logging.info(f"Loaded {len(answers)} answers from {params['answers_path']}")
|
||||||
evaluator = EvaluationExecutor(evaluator_engine=params["evaluation_engine"])
|
evaluator = EvaluationExecutor(
|
||||||
|
evaluator_engine=params["evaluation_engine"],
|
||||||
|
evaluate_contexts=params["evaluating_contexts"],
|
||||||
|
)
|
||||||
metrics = await evaluator.execute(
|
metrics = await evaluator.execute(
|
||||||
answers=answers, evaluator_metrics=params["evaluation_metrics"]
|
answers=answers, evaluator_metrics=params["evaluation_metrics"]
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -11,14 +11,18 @@ async def test_answer_generation():
|
||||||
limit = 1
|
limit = 1
|
||||||
corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)
|
corpus_list, qa_pairs = DummyAdapter().load_corpus(limit=limit)
|
||||||
|
|
||||||
mock_answer_resolver = AsyncMock()
|
mock_retriever = AsyncMock()
|
||||||
mock_answer_resolver.side_effect = lambda query: ["mock_answer"]
|
mock_retriever.get_context = AsyncMock(return_value="Mocked retrieval context")
|
||||||
|
mock_retriever.get_completion = AsyncMock(return_value=["Mocked answer"])
|
||||||
|
|
||||||
answer_generator = AnswerGeneratorExecutor()
|
answer_generator = AnswerGeneratorExecutor()
|
||||||
answers = await answer_generator.question_answering_non_parallel(
|
answers = await answer_generator.question_answering_non_parallel(
|
||||||
questions=qa_pairs, answer_resolver=mock_answer_resolver
|
questions=qa_pairs,
|
||||||
|
retriever=mock_retriever,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mock_retriever.get_context.assert_any_await(qa_pairs[0]["question"])
|
||||||
|
|
||||||
assert len(answers) == len(qa_pairs)
|
assert len(answers) == len(qa_pairs)
|
||||||
assert answers[0]["question"] == qa_pairs[0]["question"], (
|
assert answers[0]["question"] == qa_pairs[0]["question"], (
|
||||||
"AnswerGeneratorExecutor is passing the question incorrectly"
|
"AnswerGeneratorExecutor is passing the question incorrectly"
|
||||||
|
|
@ -26,6 +30,6 @@ async def test_answer_generation():
|
||||||
assert answers[0]["golden_answer"] == qa_pairs[0]["answer"], (
|
assert answers[0]["golden_answer"] == qa_pairs[0]["answer"], (
|
||||||
"AnswerGeneratorExecutor is passing the golden answer incorrectly"
|
"AnswerGeneratorExecutor is passing the golden answer incorrectly"
|
||||||
)
|
)
|
||||||
assert answers[0]["answer"] == "mock_answer", (
|
assert answers[0]["answer"] == "Mocked answer", (
|
||||||
"AnswerGeneratorExecutor is passing the generated answer incorrectly"
|
"AnswerGeneratorExecutor is passing the generated answer incorrectly"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue