Eval function now takes eval_metric as input and works with deepeval metrics such as AnswerRelevancyMetric

Rita Aleksziev 2024-11-27 16:14:05 +01:00
parent f47b185a9e
commit 4aa634d5e1

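For context, a minimal usage sketch of the reworked entry point. The module name eval_on_hotpot and the sample count are assumptions for illustration; eval_on_hotpotQA, answer_with_cognee, and deepeval's AnswerRelevancyMetric come from the commit itself. Any metric class exposed by deepeval.metrics can now be passed in the same way.

# Usage sketch (not part of the commit): run the HotpotQA eval with a stock
# deepeval metric instead of the default custom correctness metric.
# The module name eval_on_hotpot is an assumption for illustration.
import asyncio

from deepeval.metrics import AnswerRelevancyMetric

from eval_on_hotpot import answer_with_cognee, eval_on_hotpotQA

# Metrics from deepeval.metrics are classes and get instantiated here,
# mirroring the try-branch in __main__ in the diff below.
metric = AnswerRelevancyMetric()
avg_score = asyncio.run(eval_on_hotpotQA(answer_with_cognee, num_samples=50, eval_metric=metric))
print(f"Average AnswerRelevancyMetric: {avg_score}")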

@@ -4,13 +4,14 @@ import json
 import statistics
 from pathlib import Path
+import deepeval.metrics
 import wget
 from deepeval.dataset import EvaluationDataset
-from deepeval.metrics import GEval
-from deepeval.test_case import LLMTestCase, LLMTestCaseParams
+from deepeval.test_case import LLMTestCase
 from tqdm import tqdm
 import cognee
+import evals.deepeval_metrics
 from cognee.api.v1.search import SearchType
 from cognee.base_config import get_base_config
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
@@ -34,7 +35,6 @@ async def answer_without_cognee(instance):
     return answer_prediction
 async def answer_with_cognee(instance):
     await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)
     for (title, sentences) in instance["context"]:
@@ -60,20 +60,23 @@ async def answer_with_cognee(instance):
     )
     return answer_prediction
-correctness_metric = GEval(
-    name="Correctness",
-    model="gpt-4o-mini",
-    evaluation_params=[
-        LLMTestCaseParams.ACTUAL_OUTPUT,
-        LLMTestCaseParams.EXPECTED_OUTPUT
-    ],
-    evaluation_steps=[
-        "Determine whether the actual output is factually correct based on the expected output."
-    ]
-)
-async def eval_correctness(with_cognee=True, num_samples=None):
+async def eval_answers(instances, answers, eval_metric):
+    test_cases = []
+    for i in range(len(answers)):
+        instance = instances[i]
+        answer = answers[i]
+        test_case = LLMTestCase(
+            input=instance["question"],
+            actual_output=answer,
+            expected_output=instance["answer"]
+        )
+        test_cases.append(test_case)
+    evalset = EvaluationDataset(test_cases)
+    evalresults = evalset.evaluate([eval_metric])
+    return evalresults
+async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
     base_config = get_base_config()
     data_root_dir = base_config.data_root_directory
     filepath = data_root_dir / Path("hotpot_dev_fullwiki_v1.json")
@@ -82,29 +85,32 @@ async def eval_correctness(with_cognee=True, num_samples=None):
         wget.download(url, out=data_root_dir)
     with open(filepath, "r") as file:
         dataset = json.load(file)
-    test_cases = []
     if not num_samples:
         num_samples = len(dataset)
-    for instance in tqdm(dataset[:num_samples], desc="Evaluating correctness"):
-        if with_cognee:
-            answer = await answer_with_cognee(instance)
-        else:
-            answer = await answer_without_cognee(instance)
-        test_case = LLMTestCase(
-            input=instance["question"],
-            actual_output=answer,
-            expected_output=instance["answer"]
-        )
-        test_cases.append(test_case)
-    evalset = EvaluationDataset(test_cases)
-    evalresults = evalset.evaluate([correctness_metric])
-    avg_correctness = statistics.mean([result.metrics_data[0].score for result in evalresults.test_results])
-    return avg_correctness
+    instances = dataset[:num_samples]
+    answers = []
+    for instance in tqdm(instances, desc="Getting answers"):
+        answer = await answer_provider(instance)
+        answers.append(answer)
+    evalresults = await eval_answers(instances, answers, eval_metric)
+    avg_score = statistics.mean([result.metrics_data[0].score for result in evalresults.test_results])
+    return avg_score
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--with_cognee", action="store_true")
     parser.add_argument("--num_samples", type=int, default=500)
+    parser.add_argument("--metric", type=str, default="correctness_metric")
     args = parser.parse_args()
-    avg_correctness = asyncio.run(eval_correctness(args.with_cognee, args.num_samples))
-    print(f"Average correctness: {avg_correctness}")
+    try:
+        metric_cls = getattr(deepeval.metrics, args.metric)
+        metric = metric_cls()
+    except AttributeError:
+        metric = getattr(evals.deepeval_metrics, args.metric)
+    if args.with_cognee:
+        answer_provider = answer_with_cognee
+    else:
+        answer_provider = answer_without_cognee
+    avg_score = asyncio.run(eval_on_hotpotQA(answer_provider, args.num_samples, metric))
+    print(f"Average {args.metric}: {avg_score}")