diff --git a/evals/llm_as_a_judge.py b/evals/llm_as_a_judge.py
index 239c7aea9..8dd1518a7 100644
--- a/evals/llm_as_a_judge.py
+++ b/evals/llm_as_a_judge.py
@@ -4,13 +4,14 @@ import json
 import statistics
 from pathlib import Path
 
+import deepeval.metrics
 import wget
 from deepeval.dataset import EvaluationDataset
-from deepeval.metrics import GEval
-from deepeval.test_case import LLMTestCase, LLMTestCaseParams
+from deepeval.test_case import LLMTestCase
 from tqdm import tqdm
 
 import cognee
+import evals.deepeval_metrics
 from cognee.api.v1.search import SearchType
 from cognee.base_config import get_base_config
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
@@ -34,7 +35,6 @@ async def answer_without_cognee(instance):
     return answer_prediction
 
 async def answer_with_cognee(instance):
-    await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)
 
     for (title, sentences) in instance["context"]:
@@ -60,20 +60,23 @@ async def answer_with_cognee(instance):
     )
     return answer_prediction
 
-correctness_metric = GEval(
-    name="Correctness",
-    model="gpt-4o-mini",
-    evaluation_params=[
-        LLMTestCaseParams.ACTUAL_OUTPUT,
-        LLMTestCaseParams.EXPECTED_OUTPUT
-    ],
-    evaluation_steps=[
-        "Determine whether the actual output is factually correct based on the expected output."
-    ]
-    )
-
+async def eval_answers(instances, answers, eval_metric):
+    test_cases = []
+    for i in range(len(answers)):
+        instance = instances[i]
+        answer = answers[i]
+        test_case = LLMTestCase(
+            input=instance["question"],
+            actual_output=answer,
+            expected_output=instance["answer"]
+        )
+        test_cases.append(test_case)
+    evalset = EvaluationDataset(test_cases)
+    evalresults = evalset.evaluate([eval_metric])
+    return evalresults
+
 
-async def eval_correctness(with_cognee=True, num_samples=None):
+async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
     base_config = get_base_config()
     data_root_dir = base_config.data_root_directory
     filepath = data_root_dir / Path("hotpot_dev_fullwiki_v1.json")
@@ -82,29 +85,32 @@ async def eval_correctness(with_cognee=True, num_samples=None):
         wget.download(url, out=data_root_dir)
     with open(filepath, "r") as file:
         dataset = json.load(file)
-    test_cases = []
     if not num_samples:
         num_samples = len(dataset)
-    for instance in tqdm(dataset[:num_samples], desc="Evaluating correctness"):
-        if with_cognee:
-            answer = await answer_with_cognee(instance)
-        else:
-            answer = await answer_without_cognee(instance)
-        test_case = LLMTestCase(
-            input=instance["question"],
-            actual_output=answer,
-            expected_output=instance["answer"]
-        )
-        test_cases.append(test_case)
-    evalset = EvaluationDataset(test_cases)
-    evalresults = evalset.evaluate([correctness_metric])
-    avg_correctness = statistics.mean([result.metrics_data[0].score for result in evalresults.test_results])
-    return avg_correctness
+    instances = dataset[:num_samples]
+    answers = []
+    for instance in tqdm(instances, desc="Getting answers"):
+        answer = await answer_provider(instance)
+        answers.append(answer)
+    evalresults = await eval_answers(instances, answers, eval_metric)
+    avg_score = statistics.mean([result.metrics_data[0].score for result in evalresults.test_results])
+    return avg_score
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--with_cognee", action="store_true")
     parser.add_argument("--num_samples", type=int, default=500)
+    parser.add_argument("--metric", type=str, default="correctness_metric")
     args = parser.parse_args()
-    avg_correctness = asyncio.run(eval_correctness(args.with_cognee, args.num_samples))
-    print(f"Average correctness: {avg_correctness}")
\ No newline at end of file
+
+    try:
+        metric_cls = getattr(deepeval.metrics, args.metric)
+        metric = metric_cls()
+    except AttributeError:
+        metric = getattr(evals.deepeval_metrics, args.metric)
+    if args.with_cognee:
+        answer_provider = answer_with_cognee
+    else:
+        answer_provider = answer_without_cognee
+    avg_score = asyncio.run(eval_on_hotpotQA(answer_provider, args.num_samples, metric))
+    print(f"Average {args.metric}: {avg_score}")
\ No newline at end of file
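
Note: the `--metric` fallback in `__main__` resolves names against the new `evals.deepeval_metrics` module, which is not part of this diff. A minimal sketch of what that module presumably contains, reusing the `GEval` correctness metric this patch removes from `llm_as_a_judge.py` (the module path and the `correctness_metric` name come from the diff; the exact file contents are an assumption):

```python
# evals/deepeval_metrics.py -- hypothetical sketch, inferred from the diff
from deepeval.metrics import GEval
from deepeval.test_case import LLMTestCaseParams

# A module-level *instance*, so the getattr() fallback in __main__ returns a
# ready-to-use metric without calling it (unlike the deepeval.metrics branch,
# where getattr() yields a metric class that is then instantiated).
correctness_metric = GEval(
    name="Correctness",
    model="gpt-4o-mini",
    evaluation_params=[
        LLMTestCaseParams.ACTUAL_OUTPUT,
        LLMTestCaseParams.EXPECTED_OUTPUT,
    ],
    evaluation_steps=[
        "Determine whether the actual output is factually correct based on the expected output."
    ],
)
```

With that module in place, either a built-in deepeval metric class or the custom instance can be selected by name, e.g. `python evals/llm_as_a_judge.py --with_cognee --num_samples 50 --metric correctness_metric` (the sample size here is illustrative).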