Feat/cog 946 abstract eval dataset (#418)

* Accept the QA eval dataset as an argument, with hotpotqa and 2wikimultihop as built-in options. Add JSON schema validation for datasets.

* Load a dataset file by filename; move the dataset utilities into a separate module (evals/qa_dataset_utils.py)

* Use requests.get instead of wget
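
A minimal sketch of the streaming-download idiom this switch adopts (the actual helper lives in evals/qa_dataset_utils.py below and checks the status code explicitly instead of raising; the function and parameter names here are illustrative only):

import requests

def download_file(url: str, target_path: str) -> None:
    # Stream the body so large dataset files are never held in memory in one piece.
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(target_path, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)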
alekszievr 2025-01-14 11:33:55 +01:00 committed by GitHub
parent 12031e6c43
commit a4ad1702ed
2 changed files with 93 additions and 26 deletions


@@ -1,11 +1,7 @@
import argparse
import asyncio
import json
import statistics
from pathlib import Path
import deepeval.metrics
import wget
from deepeval.dataset import EvaluationDataset
from deepeval.test_case import LLMTestCase
from tqdm import tqdm
@@ -13,9 +9,9 @@ from tqdm import tqdm
import cognee
import evals.deepeval_metrics
from cognee.api.v1.search import SearchType
from cognee.base_config import get_base_config
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from evals.qa_dataset_utils import load_qa_dataset
async def answer_without_cognee(instance):
@@ -40,12 +36,8 @@ async def answer_with_cognee(instance):
await cognee.prune.prune_system(metadata=True)
for title, sentences in instance["context"]:
await cognee.add("\n".join(sentences), dataset_name="HotPotQA")
for n in range(1, 4):
print(n)
await cognee.cognify("HotPotQA")
await cognee.add("\n".join(sentences), dataset_name="QA")
await cognee.cognify("QA")
search_results = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"])
search_results_second = await cognee.search(
@@ -85,20 +77,10 @@ async def eval_answers(instances, answers, eval_metric):
return eval_results
async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
base_config = get_base_config()
data_root_dir = base_config.data_root_directory
if not Path(data_root_dir).exists():
Path(data_root_dir).mkdir()
filepath = data_root_dir / Path("hotpot_dev_fullwiki_v1.json")
if not filepath.exists():
url = "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json"
wget.download(url, out=data_root_dir)
with open(filepath, "r") as file:
dataset = json.load(file)
async def eval_on_QA_dataset(
dataset_name_or_filename: str, answer_provider, num_samples, eval_metric
):
dataset = load_qa_dataset(dataset_name_or_filename)
instances = dataset if not num_samples else dataset[:num_samples]
answers = []
@@ -117,6 +99,7 @@ async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, help="Which dataset to evaluate on")
parser.add_argument("--with_cognee", action="store_true")
parser.add_argument("--num_samples", type=int, default=500)
parser.add_argument(
@@ -142,5 +125,7 @@ if __name__ == "__main__":
else:
answer_provider = answer_without_cognee
avg_score = asyncio.run(eval_on_hotpotQA(answer_provider, args.num_samples, metric))
avg_score = asyncio.run(
eval_on_QA_dataset(args.dataset, answer_provider, args.num_samples, metric)
)
print(f"Average {args.metric}: {avg_score}")

evals/qa_dataset_utils.py (new file, 82 additions)

@@ -0,0 +1,82 @@
from cognee.root_dir import get_absolute_path
import json
import requests
from jsonschema import ValidationError, validate
from pathlib import Path

qa_datasets = {
    "hotpotqa": {
        "filename": "hotpot_dev_fullwiki_v1.json",
        "URL": "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json",
    },
    "2wikimultihop": {
        "filename": "data/dev.json",
        "URL": "https://www.dropbox.com/scl/fi/heid2pkiswhfaqr5g0piw/data.zip?rlkey=ira57daau8lxfj022xvk1irju&e=1",
    },
}
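
# Illustration only, not part of this commit: supporting a further dataset would just mean
# adding an entry here, provided its JSON validates against qa_json_schema below, e.g.
#   qa_datasets["my_dataset"] = {
#       "filename": "my_dataset.json",
#       "URL": "https://example.com/my_dataset.json",
#   }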
qa_json_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "answer": {"type": "string"},
            "question": {"type": "string"},
            "context": {"type": "array"},
        },
        "required": ["answer", "question", "context"],
        "additionalProperties": True,
    },
}
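
# Illustration only, not part of this commit: a minimal instance that passes the schema.
# In HotpotQA each context element is a [title, sentences] pair, which is why the schema
# constrains "context" to be an array without fixing the element shape.
#
#   example_dataset = [
#       {
#           "question": "Who wrote Hamlet?",
#           "answer": "William Shakespeare",
#           "context": [["Hamlet", ["Hamlet is a tragedy written by William Shakespeare."]]],
#       }
#   ]
#   validate(instance=example_dataset, schema=qa_json_schema)  # raises ValidationError if malformed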
def download_qa_dataset(dataset_name: str, filepath: Path):
    if dataset_name not in qa_datasets:
        raise ValueError(f"{dataset_name} is not a supported dataset.")

    url = qa_datasets[dataset_name]["URL"]

    if dataset_name == "2wikimultihop":
        raise Exception(
            "Please download 2wikimultihop dataset (data.zip) manually from \
            https://www.dropbox.com/scl/fi/heid2pkiswhfaqr5g0piw/data.zip?rlkey=ira57daau8lxfj022xvk1irju&e=1 \
            and unzip it."
        )

    response = requests.get(url, stream=True)
    if response.status_code == 200:
        with open(filepath, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
        print(f"Dataset {dataset_name} downloaded and saved to {filepath}")
    else:
        print(f"Failed to download {dataset_name}. Status code: {response.status_code}")
def load_qa_dataset(dataset_name_or_filename: str):
    if dataset_name_or_filename in qa_datasets:
        dataset_name = dataset_name_or_filename
        filename = qa_datasets[dataset_name]["filename"]

        data_root_dir = get_absolute_path("../.data")
        if not Path(data_root_dir).exists():
            Path(data_root_dir).mkdir()

        filepath = data_root_dir / Path(filename)
        if not filepath.exists():
            download_qa_dataset(dataset_name, filepath)
    else:
        filename = dataset_name_or_filename
        filepath = Path(filename)

    with open(filepath, "r") as file:
        dataset = json.load(file)

    try:
        validate(instance=dataset, schema=qa_json_schema)
    except ValidationError as e:
        print("File is not a valid QA dataset:", e.message)

    return dataset
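
A short usage sketch of the new utility as the eval script consumes it; the slicing mirrors the num_samples handling above, and the local filename is a placeholder:

from evals.qa_dataset_utils import load_qa_dataset

dataset = load_qa_dataset("hotpotqa")            # built-in name: downloads on first use
# dataset = load_qa_dataset("my_qa_pairs.json")  # or any local file matching the schema
instances = dataset[:50]
for instance in instances:
    print(instance["question"], "->", instance["answer"])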