Calculate official hotpot EM and F1 scores (#292)
parent 6d85165189
commit 4f2745504c
3 changed files with 161 additions and 13 deletions
@@ -1,14 +1,72 @@
-from deepeval.metrics import GEval
-from deepeval.test_case import LLMTestCaseParams
+from deepeval.metrics import BaseMetric, GEval
+from deepeval.test_case import LLMTestCase, LLMTestCaseParams
+
+from evals.official_hotpot_metrics import exact_match_score, f1_score
 
 correctness_metric = GEval(
     name="Correctness",
     model="gpt-4o-mini",
     evaluation_params=[
         LLMTestCaseParams.ACTUAL_OUTPUT,
         LLMTestCaseParams.EXPECTED_OUTPUT
     ],
     evaluation_steps=[
         "Determine whether the actual output is factually correct based on the expected output."
     ]
 )
+
+
+class f1_score_metric(BaseMetric):
+
+    """F1 score taken directly from the official hotpot benchmark
+    implementation and wrapped into a deepeval metric."""
+
+    def __init__(self, threshold: float = 0.5):
+        self.threshold = threshold
+
+    def measure(self, test_case: LLMTestCase):
+        f1, precision, recall = f1_score(
+            prediction=test_case.actual_output,
+            ground_truth=test_case.expected_output,
+        )
+        self.score = f1
+        self.success = self.score >= self.threshold
+        return self.score
+
+    # Reusing regular measure as async F1 score is not implemented
+    async def a_measure(self, test_case: LLMTestCase):
+        return self.measure(test_case)
+
+    def is_successful(self):
+        return self.success
+
+    @property
+    def __name__(self):
+        return "Official hotpot F1 score"
+
+
+class em_score_metric(BaseMetric):
+
+    """Exact Match score taken directly from the official hotpot benchmark
+    implementation and wrapped into a deepeval metric."""
+
+    def __init__(self, threshold: float = 0.5):
+        self.threshold = threshold
+
+    def measure(self, test_case: LLMTestCase):
+        self.score = exact_match_score(
+            prediction=test_case.actual_output,
+            ground_truth=test_case.expected_output,
+        )
+        self.success = self.score >= self.threshold
+        return self.score
+
+    # Reusing regular measure as async EM score is not implemented
+    async def a_measure(self, test_case: LLMTestCase):
+        return self.measure(test_case)
+
+    def is_successful(self):
+        return self.success
+
+    @property
+    def __name__(self):
+        return "Official hotpot EM score"
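For orientation, a minimal usage sketch of the two wrappers added above, assuming they live in evals/deepeval_metrics.py as the --metric help text further down suggests; the test case values are made up for illustration:

from deepeval.test_case import LLMTestCase

from evals.deepeval_metrics import em_score_metric, f1_score_metric

# Made-up prediction/gold pair purely for illustration.
test_case = LLMTestCase(
    input="Which magazine was started first, Arthur's Magazine or First for Women?",
    actual_output="Arthur's Magazine was started first.",
    expected_output="Arthur's Magazine",
)

f1_metric = f1_score_metric(threshold=0.5)
em_metric = em_score_metric(threshold=0.5)

print(f1_metric.measure(test_case))  # token-level F1 after normalization, roughly 0.57 here
print(em_metric.measure(test_case))  # False: the normalized answers are not identical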
@@ -111,7 +111,9 @@ if __name__ == "__main__":
 
     parser.add_argument("--with_cognee", action="store_true")
     parser.add_argument("--num_samples", type=int, default=500)
-    parser.add_argument("--metric", type=str, default="correctness_metric")
+    parser.add_argument("--metric", type=str, default="correctness_metric",
+                        help="Valid options are Deepeval metrics (e.g. AnswerRelevancyMetric) \
+                             and metrics defined in evals/deepeval_metrics.py, e.g. f1_score_metric")
 
     args = parser.parse_args()
 
@@ -120,6 +122,8 @@ if __name__ == "__main__":
         metric = metric_cls()
     except AttributeError:
         metric = getattr(evals.deepeval_metrics, args.metric)
+        if isinstance(metric, type):
+            metric = metric()
 
     if args.with_cognee:
         answer_provider = answer_with_cognee
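The two hunks above wire the wrappers into the eval script's --metric flag: the name is first looked up among deepeval's built-in metrics and, failing that, among the metrics in evals/deepeval_metrics.py, and the new isinstance check instantiates the fallback when the lookup returns a class rather than a module-level instance such as correctness_metric. A standalone sketch of that resolution logic follows; the first lines of the try block are assumed, since the hunk only shows its tail, and resolve_metric is a hypothetical helper name:

import deepeval.metrics

import evals.deepeval_metrics


def resolve_metric(name: str):
    """Hypothetical helper mirroring the --metric handling above (a sketch, not the script's code)."""
    try:
        # Assumed: the unshown part of the try block looks the name up on deepeval.metrics.
        metric_cls = getattr(deepeval.metrics, name)
        metric = metric_cls()
    except AttributeError:
        # Fall back to evals/deepeval_metrics.py, e.g. "f1_score_metric" or "correctness_metric".
        metric = getattr(evals.deepeval_metrics, name)
        if isinstance(metric, type):
            metric = metric()  # classes are instantiated; ready-made instances are used as-is
    return metric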
evals/official_hotpot_metrics.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+"""
+These are the official evaluation metrics for HotpotQA taken from https://hotpotqa.github.io/
+"""
+
+import re
+import string
+import sys
+from collections import Counter
+
+import ujson as json
+
+
+def normalize_answer(s):
+
+    def remove_articles(text):
+        return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+    def white_space_fix(text):
+        return ' '.join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return ''.join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
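As a quick illustration of the normalization above (lower-casing, stripping punctuation and English articles, collapsing whitespace), assuming the module path from the import added earlier:

from evals.official_hotpot_metrics import normalize_answer

# "The", the exclamation mark, and the doubled space are all stripped; case is folded.
assert normalize_answer("The  Eiffel Tower!") == "eiffel tower"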
+
+
+def f1_score(prediction, ground_truth):
+    normalized_prediction = normalize_answer(prediction)
+    normalized_ground_truth = normalize_answer(ground_truth)
+
+    ZERO_METRIC = (0, 0, 0)
+
+    if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+        return ZERO_METRIC
+    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+        return ZERO_METRIC
+
+    prediction_tokens = normalized_prediction.split()
+    ground_truth_tokens = normalized_ground_truth.split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return ZERO_METRIC
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1, precision, recall
+
+
+def exact_match_score(prediction, ground_truth):
+    return (normalize_answer(prediction) == normalize_answer(ground_truth))
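A small worked example of the two answer-level metrics defined above; the prediction/gold pair is made up for illustration:

from evals.official_hotpot_metrics import exact_match_score, f1_score

# "Barack Obama" and "Obama" normalize to ["barack", "obama"] and ["obama"]: one shared token.
f1, precision, recall = f1_score(prediction="Barack Obama", ground_truth="Obama")
print(precision, recall, round(f1, 2))  # 0.5 1.0 0.67

print(exact_match_score("Barack Obama", "Obama"))  # False
print(exact_match_score("the Obama", "Obama"))     # True: "the" is removed by normalize_answer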
+
+
+def update_answer(metrics, prediction, gold):
+    em = exact_match_score(prediction, gold)
+    f1, prec, recall = f1_score(prediction, gold)
+    metrics['em'] += float(em)
+    metrics['f1'] += f1
+    metrics['prec'] += prec
+    metrics['recall'] += recall
+    return em, prec, recall
+
+
+def update_sp(metrics, prediction, gold):
+    cur_sp_pred = set(map(tuple, prediction))
+    gold_sp_pred = set(map(tuple, gold))
+    tp, fp, fn = 0, 0, 0
+    for e in cur_sp_pred:
+        if e in gold_sp_pred:
+            tp += 1
+        else:
+            fp += 1
+    for e in gold_sp_pred:
+        if e not in cur_sp_pred:
+            fn += 1
+    prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
+    recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
+    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
+    em = 1.0 if fp + fn == 0 else 0.0
+    metrics['sp_em'] += em
+    metrics['sp_f1'] += f1
+    metrics['sp_prec'] += prec
+    metrics['sp_recall'] += recall
+    return em, prec, recall
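To show how these helpers are meant to be used in aggregate, as in the official HotpotQA evaluation script, here is a minimal sketch that accumulates answer EM/F1 over a few made-up prediction/gold pairs and averages them:

from evals.official_hotpot_metrics import update_answer

# Made-up predictions and gold answers purely for illustration.
pairs = [
    ("Arthur's Magazine", "Arthur's Magazine"),
    ("the Eiffel Tower", "Eiffel Tower"),
    ("Paris, France", "London"),
]

metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0}
for prediction, gold in pairs:
    update_answer(metrics, prediction, gold)

n = len(pairs)
averaged = {key: value / n for key, value in metrics.items()}
print(averaged)  # dataset-level EM, F1, precision and recall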