Feat/cog 946 abstract eval dataset (#418)

* QA eval dataset as argument, with hotpot and 2wikimultihop as options. Json schema validation for datasets.

* Load dataset file by filename, outsource utilities

* Use requests.get instead of wget
This commit is contained in:
alekszievr 2025-01-14 11:33:55 +01:00 committed by GitHub
parent 12031e6c43
commit a4ad1702ed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 93 additions and 26 deletions

View file

@ -1,11 +1,7 @@
import argparse import argparse
import asyncio import asyncio
import json
import statistics import statistics
from pathlib import Path
import deepeval.metrics import deepeval.metrics
import wget
from deepeval.dataset import EvaluationDataset from deepeval.dataset import EvaluationDataset
from deepeval.test_case import LLMTestCase from deepeval.test_case import LLMTestCase
from tqdm import tqdm from tqdm import tqdm
@ -13,9 +9,9 @@ from tqdm import tqdm
import cognee import cognee
import evals.deepeval_metrics import evals.deepeval_metrics
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
from cognee.base_config import get_base_config
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
from evals.qa_dataset_utils import load_qa_dataset
async def answer_without_cognee(instance): async def answer_without_cognee(instance):
@ -40,12 +36,8 @@ async def answer_with_cognee(instance):
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)
for title, sentences in instance["context"]: for title, sentences in instance["context"]:
await cognee.add("\n".join(sentences), dataset_name="HotPotQA") await cognee.add("\n".join(sentences), dataset_name="QA")
await cognee.cognify("QA")
for n in range(1, 4):
print(n)
await cognee.cognify("HotPotQA")
search_results = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"]) search_results = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"])
search_results_second = await cognee.search( search_results_second = await cognee.search(
@ -85,20 +77,10 @@ async def eval_answers(instances, answers, eval_metric):
return eval_results return eval_results
async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric): async def eval_on_QA_dataset(
base_config = get_base_config() dataset_name_or_filename: str, answer_provider, num_samples, eval_metric
data_root_dir = base_config.data_root_directory ):
dataset = load_qa_dataset(dataset_name_or_filename)
if not Path(data_root_dir).exists():
Path(data_root_dir).mkdir()
filepath = data_root_dir / Path("hotpot_dev_fullwiki_v1.json")
if not filepath.exists():
url = "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json"
wget.download(url, out=data_root_dir)
with open(filepath, "r") as file:
dataset = json.load(file)
instances = dataset if not num_samples else dataset[:num_samples] instances = dataset if not num_samples else dataset[:num_samples]
answers = [] answers = []
@ -117,6 +99,7 @@ async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, help="Which dataset to evaluate on")
parser.add_argument("--with_cognee", action="store_true") parser.add_argument("--with_cognee", action="store_true")
parser.add_argument("--num_samples", type=int, default=500) parser.add_argument("--num_samples", type=int, default=500)
parser.add_argument( parser.add_argument(
@ -142,5 +125,7 @@ if __name__ == "__main__":
else: else:
answer_provider = answer_without_cognee answer_provider = answer_without_cognee
avg_score = asyncio.run(eval_on_hotpotQA(answer_provider, args.num_samples, metric)) avg_score = asyncio.run(
eval_on_QA_dataset(args.dataset, answer_provider, args.num_samples, metric)
)
print(f"Average {args.metric}: {avg_score}") print(f"Average {args.metric}: {avg_score}")

82
evals/qa_dataset_utils.py Normal file
View file

@ -0,0 +1,82 @@
from cognee.root_dir import get_absolute_path
import json
import requests
from jsonschema import ValidationError, validate
from pathlib import Path
# Registry of built-in QA datasets: maps a short dataset name to the local
# filename it is stored under and the URL it can be fetched from.
qa_datasets = {
    "hotpotqa": {
        "filename": "hotpot_dev_fullwiki_v1.json",
        "URL": "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json",
    },
    "2wikimultihop": {
        # NOTE: this archive must be downloaded and unzipped manually (the
        # downloader raises for it); the file lives inside the "data/" dir.
        "filename": "data/dev.json",
        "URL": "https://www.dropbox.com/scl/fi/heid2pkiswhfaqr5g0piw/data.zip?rlkey=ira57daau8lxfj022xvk1irju&e=1",
    },
}

# JSON schema every QA dataset must satisfy: a list of objects, each with at
# least "answer" (str), "question" (str) and "context" (list) keys; extra
# keys are allowed ("additionalProperties": True).
qa_json_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "answer": {"type": "string"},
            "question": {"type": "string"},
            "context": {"type": "array"},
        },
        "required": ["answer", "question", "context"],
        "additionalProperties": True,
    },
}
def download_qa_dataset(dataset_name: str, filepath: Path) -> None:
    """Download a supported QA dataset and save it to ``filepath``.

    Args:
        dataset_name: Key into ``qa_datasets`` (e.g. "hotpotqa").
        filepath: Destination path for the downloaded file.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
        Exception: For "2wikimultihop", which cannot be fetched automatically
            and must be downloaded and unzipped by hand.
        RuntimeError: If the HTTP request returns a non-200 status code.
    """
    if dataset_name not in qa_datasets:
        raise ValueError(f"{dataset_name} is not a supported dataset.")

    if dataset_name == "2wikimultihop":
        # Automatic download is not supported for this dataset.
        raise Exception(
            "Please download 2wikimultihop dataset (data.zip) manually from \
https://www.dropbox.com/scl/fi/heid2pkiswhfaqr5g0piw/data.zip?rlkey=ira57daau8lxfj022xvk1irju&e=1 \
and unzip it."
        )

    url = qa_datasets[dataset_name]["URL"]

    # Stream the response to disk in chunks so large dataset files never have
    # to fit in memory; time out rather than hanging forever on a dead server.
    response = requests.get(url, stream=True, timeout=60)
    if response.status_code == 200:
        with open(filepath, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
        print(f"Dataset {dataset_name} downloaded and saved to {filepath}")
    else:
        # Fail loudly: silently continuing would surface later as a confusing
        # FileNotFoundError when the caller tries to open the missing file.
        raise RuntimeError(
            f"Failed to download {dataset_name}. Status code: {response.status_code}"
        )
def load_qa_dataset(dataset_name_or_filename: str):
    """Load a QA dataset by registry name or by path to a local JSON file.

    If the argument names a dataset in ``qa_datasets``, the file is resolved
    under the ".data" directory and downloaded first if missing. Otherwise the
    argument is treated as a path to a JSON file on disk.

    The loaded data is validated against ``qa_json_schema``; validation
    failures are reported on stdout but the dataset is still returned
    (best-effort, matching the eval scripts' tolerant behavior).

    Args:
        dataset_name_or_filename: Registry name (e.g. "hotpotqa") or a
            filesystem path to a QA dataset JSON file.

    Returns:
        The parsed dataset (a list of question/answer/context records).
    """
    if dataset_name_or_filename in qa_datasets:
        dataset_name = dataset_name_or_filename
        filename = qa_datasets[dataset_name]["filename"]

        data_root_dir = get_absolute_path("../.data")
        # parents/exist_ok make directory creation race-free and safe even
        # when intermediate directories are missing.
        Path(data_root_dir).mkdir(parents=True, exist_ok=True)

        filepath = Path(data_root_dir) / filename
        if not filepath.exists():
            download_qa_dataset(dataset_name, filepath)
    else:
        # Not a known dataset name: interpret the argument as a file path.
        filepath = Path(dataset_name_or_filename)

    with open(filepath, "r", encoding="utf-8") as file:
        dataset = json.load(file)

    try:
        validate(instance=dataset, schema=qa_json_schema)
    except ValidationError as e:
        # Warn but still return the data, so callers can work with files
        # that only partially conform to the schema.
        print("File is not a valid QA dataset:", e.message)

    return dataset