{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation on the HotpotQA dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from evals.eval_on_hotpot import eval_on_hotpotQA\n",
"from evals.eval_on_hotpot import answer_with_cognee\n",
"from evals.eval_on_hotpot import answer_without_cognee\n",
"from evals.eval_on_hotpot import eval_answers\n",
"from cognee.base_config import get_base_config\n",
"from pathlib import Path\n",
"from tqdm import tqdm\n",
"import wget\n",
"import json\n",
"import statistics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Getting the answers for the first num_samples questions of the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"answer_provider = answer_with_cognee # For native LLM answers use answer_without_cognee\n",
"num_samples = 10 # With cognee, it takes ~1m10s per sample\n",
"\n",
"base_config = get_base_config()\n",
"data_root_dir = base_config.data_root_directory\n",
"\n",
"if not Path(data_root_dir).exists():\n",
"    Path(data_root_dir).mkdir()\n",
"\n",
"filepath = data_root_dir / Path(\"hotpot_dev_fullwiki_v1.json\")\n",
"if not filepath.exists():\n",
"    url = \"http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json\"\n",
"    wget.download(url, out=data_root_dir)\n",
"\n",
"with open(filepath, \"r\") as file:\n",
"    dataset = json.load(file)\n",
"\n",
"instances = dataset if not num_samples else dataset[:num_samples]\n",
"answers = []\n",
"for instance in tqdm(instances, desc=\"Getting answers\"):\n",
"    answer = await answer_provider(instance)\n",
"    answers.append(answer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calculating the official HotpotQA benchmark metrics: F1 score and EM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from evals.deepeval_metrics import f1_score_metric\n",
"from evals.deepeval_metrics import em_score_metric"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f1_metric = f1_score_metric()\n",
"eval_results = await eval_answers(instances, answers, f1_metric)\n",
"avg_f1_score = statistics.mean(\n",
"    [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(\"F1 score: \", avg_f1_score)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"em_metric = em_score_metric()\n",
"eval_results = await eval_answers(instances, answers, em_metric)\n",
"avg_em_score = statistics.mean(\n",
"    [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(\"EM score: \", avg_em_score)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calculating a custom metric called Correctness\n",
"##### Correctness is judged by an LLM"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from evals.deepeval_metrics import correctness_metric"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eval_results = await eval_answers(\n",
"    instances, answers, correctness_metric\n",
") # correctness_metric is already an instance, so it is passed without instantiation\n",
"avg_correctness_score = statistics.mean(\n",
"    [result.metrics_data[0].score for result in eval_results.test_results]\n",
")\n",
"print(\"Correctness score: \", avg_correctness_score)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using a metric from Deepeval"
|
|
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from deepeval.metrics import AnswerRelevancyMetric"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"relevancy_metric = AnswerRelevancyMetric()\n",
"eval_results = await eval_answers(\n",
"    instances, answers, relevancy_metric\n",
") # note that instantiation is not needed for correctness metric as it is already an instance\n",
|
|
"avg_relevancy_score = statistics.mean(\n",
|
|
" [result.metrics_data[0].score for result in eval_results.test_results]\n",
|
|
")\n",
|
|
"print(\"Relevancy score: \", avg_relevancy_score)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Answering and eval in one step"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"answer_provider = answer_without_cognee\n",
|
|
"f1_metric = f1_score_metric()\n",
|
|
"f1_score = await eval_on_hotpotQA(\n",
|
|
" answer_provider, num_samples=10, eval_metric=f1_metric\n",
|
|
") # takes ~1m10s per sample\n",
|
|
"print(\"F1 score: \", f1_score)"
|
|
]
|
|
},
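{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### The one-step helper also accepts other providers and metrics; the cell below is a sketch that simply reuses the names already imported above (answer_with_cognee, em_score_metric), so scores and runtime will differ"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: one-step evaluation of cognee answers scored with EM, mirroring the call above\n",
"em_metric = em_score_metric()\n",
"em_score = await eval_on_hotpotQA(\n",
"    answer_with_cognee, num_samples=10, eval_metric=em_metric\n",
") # with cognee, expect ~1m10s per sample\n",
"print(\"EM score: \", em_score)"
]
},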
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.20"
}
},
"nbformat": 4,
"nbformat_minor": 2
}