"""
|
|
These are the official evaluation metrics for HotpotQA taken from https://hotpotqa.github.io/
|
|
"""
|
|
|
|
import re
import string
import sys
from collections import Counter

import ujson as json


def normalize_answer(s):
    """Lower text and remove punctuation, articles, and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    ZERO_METRIC = (0, 0, 0)

    # Yes/no/noanswer answers must match exactly; token overlap would otherwise
    # give partial credit for predicting the wrong answer type.
    if (
        normalized_prediction in ["yes", "no", "noanswer"]
        and normalized_prediction != normalized_ground_truth
    ):
        return ZERO_METRIC
    if (
        normalized_ground_truth in ["yes", "no", "noanswer"]
        and normalized_prediction != normalized_ground_truth
    ):
        return ZERO_METRIC

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return ZERO_METRIC
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def update_answer(metrics, prediction, gold):
    # Accumulate answer EM/F1/precision/recall into the running metrics dict.
    em = exact_match_score(prediction, gold)
    f1, prec, recall = f1_score(prediction, gold)
    metrics["em"] += float(em)
    metrics["f1"] += f1
    metrics["prec"] += prec
    metrics["recall"] += recall
    return em, prec, recall


def update_sp(metrics, prediction, gold):
    # Supporting facts are (title, sentence_id) pairs; compare predicted and gold
    # facts as sets to compute precision, recall, F1, and exact match.
    cur_sp_pred = set(map(tuple, prediction))
    gold_sp_pred = set(map(tuple, gold))
    tp, fp, fn = 0, 0, 0
    for e in cur_sp_pred:
        if e in gold_sp_pred:
            tp += 1
        else:
            fp += 1
    for e in gold_sp_pred:
        if e not in cur_sp_pred:
            fn += 1
    prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
    em = 1.0 if fp + fn == 0 else 0.0
    metrics["sp_em"] += em
    metrics["sp_f1"] += f1
    metrics["sp_prec"] += prec
    metrics["sp_recall"] += recall
    return em, prec, recall
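

# --- Illustrative usage (not part of the original HotpotQA script) ---
# A minimal sketch of how the metric functions above might be driven; the
# example prediction/gold values and the helper name `_example_usage` are
# assumptions added for demonstration, not code from the official evaluator.
def _example_usage():
    metrics = {
        "em": 0, "f1": 0, "prec": 0, "recall": 0,
        "sp_em": 0, "sp_f1": 0, "sp_prec": 0, "sp_recall": 0,
    }
    # Answer metrics compare a predicted answer string against the gold answer.
    update_answer(metrics, "the Eiffel Tower", "Eiffel Tower")
    # Supporting-fact metrics compare lists of [title, sentence_id] pairs.
    update_sp(metrics, [["Eiffel Tower", 0]], [["Eiffel Tower", 0], ["Paris", 2]])
    return metrics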