Merge branch 'dev' into COG-975
This commit is contained in:
commit
aef7822758
11 changed files with 536 additions and 126 deletions
9
cognee/infrastructure/llm/prompts/llm_judge_prompts.py
Normal file
9
cognee/infrastructure/llm/prompts/llm_judge_prompts.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
# LLM-as-a-judge metrics as described here: https://arxiv.org/abs/2404.16130
|
||||
|
||||
llm_judge_prompts = {
|
||||
"correctness": "Determine whether the actual output is factually correct based on the expected output.",
|
||||
"comprehensiveness": "Determine how much detail the answer provides to cover all the aspects and details of the question.",
|
||||
"diversity": "Determine how varied and rich the answer is in providing different perspectives and insights on the question.",
|
||||
"empowerment": "Determine how well the answer helps the reader understand and make informed judgements about the topic.",
|
||||
"directness": "Determine how specifically and clearly the answer addresses the question.",
|
||||
}
|
||||
|
|
@ -2,14 +2,57 @@ from deepeval.metrics import BaseMetric, GEval
|
|||
from deepeval.test_case import LLMTestCase, LLMTestCaseParams
|
||||
|
||||
from evals.official_hotpot_metrics import exact_match_score, f1_score
|
||||
from cognee.infrastructure.llm.prompts.llm_judge_prompts import llm_judge_prompts
|
||||
|
||||
correctness_metric = GEval(
|
||||
name="Correctness",
|
||||
model="gpt-4o-mini",
|
||||
evaluation_params=[LLMTestCaseParams.ACTUAL_OUTPUT, LLMTestCaseParams.EXPECTED_OUTPUT],
|
||||
evaluation_steps=[
|
||||
"Determine whether the actual output is factually correct based on the expected output."
|
||||
evaluation_steps=[llm_judge_prompts["correctness"]],
|
||||
)
|
||||
|
||||
comprehensiveness_metric = GEval(
|
||||
name="Comprehensiveness",
|
||||
model="gpt-4o-mini",
|
||||
evaluation_params=[
|
||||
LLMTestCaseParams.INPUT,
|
||||
LLMTestCaseParams.ACTUAL_OUTPUT,
|
||||
LLMTestCaseParams.EXPECTED_OUTPUT,
|
||||
],
|
||||
evaluation_steps=[llm_judge_prompts["comprehensiveness"]],
|
||||
)
|
||||
|
||||
diversity_metric = GEval(
|
||||
name="Diversity",
|
||||
model="gpt-4o-mini",
|
||||
evaluation_params=[
|
||||
LLMTestCaseParams.INPUT,
|
||||
LLMTestCaseParams.ACTUAL_OUTPUT,
|
||||
LLMTestCaseParams.EXPECTED_OUTPUT,
|
||||
],
|
||||
evaluation_steps=[llm_judge_prompts["diversity"]],
|
||||
)
|
||||
|
||||
empowerment_metric = GEval(
|
||||
name="Empowerment",
|
||||
model="gpt-4o-mini",
|
||||
evaluation_params=[
|
||||
LLMTestCaseParams.INPUT,
|
||||
LLMTestCaseParams.ACTUAL_OUTPUT,
|
||||
LLMTestCaseParams.EXPECTED_OUTPUT,
|
||||
],
|
||||
evaluation_steps=[llm_judge_prompts["empowerment"]],
|
||||
)
|
||||
|
||||
directness_metric = GEval(
|
||||
name="Directness",
|
||||
model="gpt-4o-mini",
|
||||
evaluation_params=[
|
||||
LLMTestCaseParams.INPUT,
|
||||
LLMTestCaseParams.ACTUAL_OUTPUT,
|
||||
LLMTestCaseParams.EXPECTED_OUTPUT,
|
||||
],
|
||||
evaluation_steps=[llm_judge_prompts["directness"]],
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,61 +1,25 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import statistics
|
||||
from pathlib import Path
|
||||
|
||||
import deepeval.metrics
|
||||
import wget
|
||||
from deepeval.dataset import EvaluationDataset
|
||||
from deepeval.test_case import LLMTestCase
|
||||
from tqdm import tqdm
|
||||
|
||||
import cognee
|
||||
import evals.deepeval_metrics
|
||||
from cognee.api.v1.search import SearchType
|
||||
from cognee.base_config import get_base_config
|
||||
import logging
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
|
||||
from evals.qa_dataset_utils import load_qa_dataset
|
||||
from evals.qa_metrics_utils import get_metric
|
||||
from evals.qa_context_provider_utils import qa_context_providers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def answer_without_cognee(instance):
|
||||
args = {
|
||||
"question": instance["question"],
|
||||
"context": instance["context"],
|
||||
}
|
||||
user_prompt = render_prompt("context_for_question.txt", args)
|
||||
system_prompt = read_query_prompt("answer_hotpot_question.txt")
|
||||
|
||||
llm_client = get_llm_client()
|
||||
answer_prediction = await llm_client.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
return answer_prediction
|
||||
|
||||
|
||||
async def answer_with_cognee(instance):
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
for title, sentences in instance["context"]:
|
||||
await cognee.add("\n".join(sentences), dataset_name="HotPotQA")
|
||||
|
||||
for n in range(1, 4):
|
||||
print(n)
|
||||
|
||||
await cognee.cognify("HotPotQA")
|
||||
|
||||
search_results = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"])
|
||||
search_results_second = await cognee.search(
|
||||
SearchType.SUMMARIES, query_text=instance["question"]
|
||||
)
|
||||
search_results = search_results + search_results_second
|
||||
async def answer_qa_instance(instance, context_provider):
|
||||
context = await context_provider(instance)
|
||||
|
||||
args = {
|
||||
"question": instance["question"],
|
||||
"context": search_results,
|
||||
"context": context,
|
||||
}
|
||||
user_prompt = render_prompt("context_for_question.txt", args)
|
||||
system_prompt = read_query_prompt("answer_hotpot_using_cognee_search.txt")
|
||||
|
|
@ -70,7 +34,7 @@ async def answer_with_cognee(instance):
|
|||
return answer_prediction
|
||||
|
||||
|
||||
async def eval_answers(instances, answers, eval_metric):
|
||||
async def deepeval_answers(instances, answers, eval_metric):
|
||||
test_cases = []
|
||||
|
||||
for instance, answer in zip(instances, answers):
|
||||
|
|
@ -85,28 +49,13 @@ async def eval_answers(instances, answers, eval_metric):
|
|||
return eval_results
|
||||
|
||||
|
||||
async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
|
||||
base_config = get_base_config()
|
||||
data_root_dir = base_config.data_root_directory
|
||||
|
||||
if not Path(data_root_dir).exists():
|
||||
Path(data_root_dir).mkdir()
|
||||
|
||||
filepath = data_root_dir / Path("hotpot_dev_fullwiki_v1.json")
|
||||
if not filepath.exists():
|
||||
url = "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json"
|
||||
wget.download(url, out=data_root_dir)
|
||||
|
||||
with open(filepath, "r") as file:
|
||||
dataset = json.load(file)
|
||||
|
||||
instances = dataset if not num_samples else dataset[:num_samples]
|
||||
async def deepeval_on_instances(instances, context_provider, eval_metric):
|
||||
answers = []
|
||||
for instance in tqdm(instances, desc="Getting answers"):
|
||||
answer = await answer_provider(instance)
|
||||
answer = await answer_qa_instance(instance, context_provider)
|
||||
answers.append(answer)
|
||||
|
||||
eval_results = await eval_answers(instances, answers, eval_metric)
|
||||
eval_results = await deepeval_answers(instances, answers, eval_metric)
|
||||
avg_score = statistics.mean(
|
||||
[result.metrics_data[0].score for result in eval_results.test_results]
|
||||
)
|
||||
|
|
@ -114,33 +63,37 @@ async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
|
|||
return avg_score
|
||||
|
||||
|
||||
async def eval_on_QA_dataset(
|
||||
dataset_name_or_filename: str, context_provider_name, num_samples, eval_metric_name
|
||||
):
|
||||
dataset = load_qa_dataset(dataset_name_or_filename)
|
||||
context_provider = qa_context_providers[context_provider_name]
|
||||
eval_metric = get_metric(eval_metric_name)
|
||||
instances = dataset if not num_samples else dataset[:num_samples]
|
||||
|
||||
if eval_metric_name.startswith("promptfoo"):
|
||||
return await eval_metric.measure(instances, context_provider)
|
||||
else:
|
||||
return await deepeval_on_instances(instances, context_provider, eval_metric)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--with_cognee", action="store_true")
|
||||
parser.add_argument("--num_samples", type=int, default=500)
|
||||
parser.add_argument("--dataset", type=str, required=True, help="Which dataset to evaluate on")
|
||||
parser.add_argument(
|
||||
"--metric",
|
||||
"--rag_option",
|
||||
type=str,
|
||||
default="correctness_metric",
|
||||
help="Valid options are Deepeval metrics (e.g. AnswerRelevancyMetric) \
|
||||
and metrics defined in evals/deepeval_metrics.py, e.g. f1_score_metric",
|
||||
choices=qa_context_providers.keys(),
|
||||
required=True,
|
||||
help="RAG option to use for providing context",
|
||||
)
|
||||
parser.add_argument("--num_samples", type=int, default=500)
|
||||
parser.add_argument("--metric_name", type=str, default="Correctness")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
metric_cls = getattr(deepeval.metrics, args.metric)
|
||||
metric = metric_cls()
|
||||
except AttributeError:
|
||||
metric = getattr(evals.deepeval_metrics, args.metric)
|
||||
if isinstance(metric, type):
|
||||
metric = metric()
|
||||
|
||||
if args.with_cognee:
|
||||
answer_provider = answer_with_cognee
|
||||
else:
|
||||
answer_provider = answer_without_cognee
|
||||
|
||||
avg_score = asyncio.run(eval_on_hotpotQA(answer_provider, args.num_samples, metric))
|
||||
print(f"Average {args.metric}: {avg_score}")
|
||||
avg_score = asyncio.run(
|
||||
eval_on_QA_dataset(args.dataset, args.rag_option, args.num_samples, args.metric_name)
|
||||
)
|
||||
logger.info(f"Average {args.metric_name}: {avg_score}")
|
||||
|
|
|
|||
7
evals/promptfoo_config_template.yaml
Normal file
7
evals/promptfoo_config_template.yaml
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# yaml-language-server: $schema=https://promptfoo.dev/config-schema.json
|
||||
|
||||
# Learn more about building a configuration: https://promptfoo.dev/docs/configuration/guide
|
||||
|
||||
description: "My eval"
|
||||
providers:
|
||||
- id: openai:gpt-4o-mini
|
||||
53
evals/promptfoo_metrics.py
Normal file
53
evals/promptfoo_metrics.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
from evals.promptfoo_wrapper import PromptfooWrapper
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
import shutil
|
||||
|
||||
|
||||
class PromptfooMetric:
|
||||
def __init__(self, judge_prompt):
|
||||
promptfoo_path = shutil.which("promptfoo")
|
||||
self.wrapper = PromptfooWrapper(promptfoo_path=promptfoo_path)
|
||||
self.judge_prompt = judge_prompt
|
||||
|
||||
async def measure(self, instances, context_provider):
|
||||
with open(os.path.join(os.getcwd(), "evals/promptfoo_config_template.yaml"), "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
config["defaultTest"] = [{"assert": {"type": "llm_rubric", "value": self.judge_prompt}}]
|
||||
|
||||
# Fill config file with test cases
|
||||
tests = []
|
||||
for instance in instances:
|
||||
context = await context_provider(instance)
|
||||
test = {
|
||||
"vars": {
|
||||
"name": instance["question"][:15],
|
||||
"question": instance["question"],
|
||||
"context": context,
|
||||
}
|
||||
}
|
||||
tests.append(test)
|
||||
config["tests"] = tests
|
||||
|
||||
# Write the updated YAML back, preserving formatting and structure
|
||||
updated_yaml_file_path = os.path.join(os.getcwd(), "config_with_context.yaml")
|
||||
with open(updated_yaml_file_path, "w") as file:
|
||||
yaml.dump(config, file)
|
||||
|
||||
self.wrapper.run_eval(
|
||||
prompt_file=os.path.join(os.getcwd(), "evals/promptfooprompt.json"),
|
||||
config_file=os.path.join(os.getcwd(), "config_with_context.yaml"),
|
||||
out_format="json",
|
||||
)
|
||||
|
||||
file_path = os.path.join(os.getcwd(), "benchmark_results.json")
|
||||
|
||||
# Read and parse the JSON file
|
||||
with open(file_path, "r") as file:
|
||||
results = json.load(file)
|
||||
|
||||
self.score = results["results"]["prompts"][0]["metrics"]["score"]
|
||||
|
||||
return self.score
|
||||
157
evals/promptfoo_wrapper.py
Normal file
157
evals/promptfoo_wrapper.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
import subprocess
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Dict, Generator
|
||||
import shutil
|
||||
import platform
|
||||
from dotenv import load_dotenv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class PromptfooWrapper:
|
||||
"""
|
||||
A Python wrapper class around the promptfoo CLI tool, allowing you to:
|
||||
- Evaluate prompts against different language models.
|
||||
- Compare responses from multiple models.
|
||||
- Pass configuration and prompt files.
|
||||
- Retrieve the outputs in a structured format, including binary output if needed.
|
||||
|
||||
This class assumes you have the promptfoo CLI installed and accessible in your environment.
|
||||
For more details on promptfoo, see: https://github.com/promptfoo/promptfoo
|
||||
"""
|
||||
|
||||
def __init__(self, promptfoo_path: str = ""):
|
||||
"""
|
||||
Initialize the wrapper with the path to the promptfoo executable.
|
||||
|
||||
:param promptfoo_path: Path to the promptfoo binary (default: 'promptfoo')
|
||||
"""
|
||||
self.promptfoo_path = promptfoo_path
|
||||
logger.debug(f"Initialized PromptfooWrapper with binary at: {self.promptfoo_path}")
|
||||
|
||||
def _validate_path(self, file_path: Optional[str]) -> None:
|
||||
"""
|
||||
Validate that a file path is accessible if provided.
|
||||
Raise FileNotFoundError if it does not exist.
|
||||
"""
|
||||
if file_path and not os.path.isfile(file_path):
|
||||
logger.error(f"File not found: {file_path}")
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
def _get_node_bin_dir(self) -> str:
|
||||
"""
|
||||
Determine the Node.js binary directory dynamically for macOS and Linux.
|
||||
"""
|
||||
node_executable = shutil.which("node")
|
||||
if not node_executable:
|
||||
logger.error("Node.js is not installed or not found in the system PATH.")
|
||||
raise EnvironmentError("Node.js is not installed or not in PATH.")
|
||||
|
||||
# Determine the Node.js binary directory
|
||||
node_bin_dir = os.path.dirname(node_executable)
|
||||
|
||||
# Special handling for macOS, where Homebrew installs Node in /usr/local or /opt/homebrew
|
||||
if platform.system() == "Darwin": # macOS
|
||||
logger.debug("Running on macOS")
|
||||
brew_prefix = os.popen("brew --prefix node").read().strip()
|
||||
if brew_prefix and os.path.exists(brew_prefix):
|
||||
node_bin_dir = os.path.join(brew_prefix, "bin")
|
||||
logger.debug(f"Detected Node.js binary directory using Homebrew: {node_bin_dir}")
|
||||
|
||||
# For Linux, Node.js installed via package managers should work out of the box
|
||||
logger.debug(f"Detected Node.js binary directory: {node_bin_dir}")
|
||||
return node_bin_dir
|
||||
|
||||
def _run_command(
|
||||
self,
|
||||
cmd: List[str],
|
||||
filename,
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""
|
||||
Run a given command using subprocess and parse the output.
|
||||
"""
|
||||
logger.debug(f"Running command: {' '.join(cmd)}")
|
||||
|
||||
# Make a copy of the current environment
|
||||
env = os.environ.copy()
|
||||
|
||||
try:
|
||||
node_bin_dir = self._get_node_bin_dir()
|
||||
print(node_bin_dir)
|
||||
env["PATH"] = f"{node_bin_dir}:{env['PATH']}"
|
||||
|
||||
except EnvironmentError as e:
|
||||
logger.error(f"Failed to set Node.js binary directory: {e}")
|
||||
raise
|
||||
|
||||
# Add node's bin directory to the PATH
|
||||
# node_bin_dir = "/Users/vasilije/Library/Application Support/JetBrains/PyCharm2024.2/node/versions/20.15.0/bin"
|
||||
# # env["PATH"] = f"{node_bin_dir}:{env['PATH']}"
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=False, env=env)
|
||||
|
||||
print(result.stderr)
|
||||
with open(filename, "r", encoding="utf-8") as file:
|
||||
read_data = json.load(file)
|
||||
print(f"{filename} created and written.")
|
||||
|
||||
# Log raw stdout for debugging
|
||||
logger.debug(f"Raw command output:\n{result.stdout}")
|
||||
|
||||
# Use the parse_promptfoo_output function to yield parsed results
|
||||
return read_data
|
||||
|
||||
def run_eval(
|
||||
self,
|
||||
prompt_file: Optional[str] = None,
|
||||
config_file: Optional[str] = None,
|
||||
eval_file: Optional[str] = None,
|
||||
out_format: str = "json",
|
||||
extra_args: Optional[List[str]] = None,
|
||||
binary_output: bool = False,
|
||||
) -> Dict:
|
||||
"""
|
||||
Run the `promptfoo eval` command with the provided parameters and return parsed results.
|
||||
|
||||
:param prompt_file: Path to a file containing one or more prompts.
|
||||
:param config_file: Path to a config file specifying models, scoring methods, etc.
|
||||
:param eval_file: Path to an eval file with test data.
|
||||
:param out_format: Output format, e.g., 'json', 'yaml', or 'table'.
|
||||
:param extra_args: Additional command-line arguments for fine-tuning evaluation.
|
||||
:param binary_output: If True, interpret output as binary data instead of text.
|
||||
:return: List of parsed results (each result is a dictionary).
|
||||
"""
|
||||
self._validate_path(prompt_file)
|
||||
self._validate_path(config_file)
|
||||
self._validate_path(eval_file)
|
||||
|
||||
filename = "benchmark_results"
|
||||
|
||||
filename = os.path.join(os.getcwd(), f"{filename}.json")
|
||||
# Create an empty JSON file
|
||||
with open(filename, "w") as file:
|
||||
json.dump({}, file)
|
||||
|
||||
cmd = [self.promptfoo_path, "eval"]
|
||||
if prompt_file:
|
||||
cmd.extend(["--prompts", prompt_file])
|
||||
if config_file:
|
||||
cmd.extend(["--config", config_file])
|
||||
if eval_file:
|
||||
cmd.extend(["--eval", eval_file])
|
||||
cmd.extend(["--output", filename])
|
||||
if extra_args:
|
||||
cmd.extend(extra_args)
|
||||
|
||||
# Log the constructed command for debugging
|
||||
logger.debug(f"Constructed command: {' '.join(cmd)}")
|
||||
|
||||
# Collect results from the generator
|
||||
results = self._run_command(cmd, filename=filename)
|
||||
logger.debug(f"Parsed results: {json.dumps(results, indent=4)}")
|
||||
return results
|
||||
10
evals/promptfooprompt.json
Normal file
10
evals/promptfooprompt.json
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Answer the question using the provided context. Be as brief as possible."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "The question is: `{{ question }}` \n And here is the context: `{{ context }}`"
|
||||
}
|
||||
]
|
||||
59
evals/qa_context_provider_utils.py
Normal file
59
evals/qa_context_provider_utils.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import cognee
|
||||
from cognee.api.v1.search import SearchType
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.tasks.completion.graph_query_completion import retrieved_edges_to_string
|
||||
|
||||
|
||||
async def get_raw_context(instance: dict) -> str:
|
||||
return instance["context"]
|
||||
|
||||
|
||||
async def cognify_instance(instance: dict):
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
for title, sentences in instance["context"]:
|
||||
await cognee.add("\n".join(sentences), dataset_name="QA")
|
||||
await cognee.cognify("QA")
|
||||
|
||||
|
||||
async def get_context_with_cognee(instance: dict) -> str:
|
||||
await cognify_instance(instance)
|
||||
|
||||
insights = await cognee.search(SearchType.INSIGHTS, query_text=instance["question"])
|
||||
summaries = await cognee.search(SearchType.SUMMARIES, query_text=instance["question"])
|
||||
search_results = insights + summaries
|
||||
|
||||
search_results_str = "\n".join([context_item["text"] for context_item in search_results])
|
||||
|
||||
return search_results_str
|
||||
|
||||
|
||||
async def get_context_with_simple_rag(instance: dict) -> str:
|
||||
await cognify_instance(instance)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
found_chunks = await vector_engine.search("document_chunk_text", instance["question"], limit=5)
|
||||
|
||||
search_results_str = "\n".join([context_item.payload["text"] for context_item in found_chunks])
|
||||
|
||||
return search_results_str
|
||||
|
||||
|
||||
async def get_context_with_brute_force_triplet_search(instance: dict) -> str:
|
||||
await cognify_instance(instance)
|
||||
|
||||
found_triplets = await brute_force_triplet_search(instance["question"], top_k=5)
|
||||
|
||||
search_results_str = retrieved_edges_to_string(found_triplets)
|
||||
|
||||
return search_results_str
|
||||
|
||||
|
||||
qa_context_providers = {
|
||||
"no_rag": get_raw_context,
|
||||
"cognee": get_context_with_cognee,
|
||||
"simple_rag": get_context_with_simple_rag,
|
||||
"brute_force": get_context_with_brute_force_triplet_search,
|
||||
}
|
||||
82
evals/qa_dataset_utils.py
Normal file
82
evals/qa_dataset_utils.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
from cognee.root_dir import get_absolute_path
|
||||
import json
|
||||
import requests
|
||||
from jsonschema import ValidationError, validate
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
qa_datasets = {
|
||||
"hotpotqa": {
|
||||
"filename": "hotpot_dev_fullwiki_v1.json",
|
||||
"URL": "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json",
|
||||
},
|
||||
"2wikimultihop": {
|
||||
"filename": "data/dev.json",
|
||||
"URL": "https://www.dropbox.com/scl/fi/heid2pkiswhfaqr5g0piw/data.zip?rlkey=ira57daau8lxfj022xvk1irju&e=1",
|
||||
},
|
||||
}
|
||||
|
||||
qa_json_schema = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"answer": {"type": "string"},
|
||||
"question": {"type": "string"},
|
||||
"context": {"type": "array"},
|
||||
},
|
||||
"required": ["answer", "question", "context"],
|
||||
"additionalProperties": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def download_qa_dataset(dataset_name: str, filepath: Path):
|
||||
if dataset_name not in qa_datasets:
|
||||
raise ValueError(f"{dataset_name} is not a supported dataset.")
|
||||
|
||||
url = qa_datasets[dataset_name]["URL"]
|
||||
|
||||
if dataset_name == "2wikimultihop":
|
||||
raise Exception(
|
||||
"Please download 2wikimultihop dataset (data.zip) manually from \
|
||||
https://www.dropbox.com/scl/fi/heid2pkiswhfaqr5g0piw/data.zip?rlkey=ira57daau8lxfj022xvk1irju&e=1 \
|
||||
and unzip it."
|
||||
)
|
||||
|
||||
response = requests.get(url, stream=True)
|
||||
|
||||
if response.status_code == 200:
|
||||
with open(filepath, "wb") as file:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
file.write(chunk)
|
||||
print(f"Dataset {dataset_name} downloaded and saved to {filepath}")
|
||||
else:
|
||||
print(f"Failed to download {dataset_name}. Status code: {response.status_code}")
|
||||
|
||||
|
||||
def load_qa_dataset(dataset_name_or_filename: str) -> list[dict]:
|
||||
if dataset_name_or_filename in qa_datasets:
|
||||
dataset_name = dataset_name_or_filename
|
||||
filename = qa_datasets[dataset_name]["filename"]
|
||||
|
||||
data_root_dir = get_absolute_path("../.data")
|
||||
if not Path(data_root_dir).exists():
|
||||
Path(data_root_dir).mkdir()
|
||||
|
||||
filepath = data_root_dir / Path(filename)
|
||||
if not filepath.exists():
|
||||
download_qa_dataset(dataset_name, filepath)
|
||||
else:
|
||||
filename = dataset_name_or_filename
|
||||
filepath = Path(filename)
|
||||
|
||||
with open(filepath, "r") as file:
|
||||
dataset = json.load(file)
|
||||
|
||||
try:
|
||||
validate(instance=dataset, schema=qa_json_schema)
|
||||
except ValidationError as e:
|
||||
raise ValidationError(f"Invalid QA dataset: {e.message}")
|
||||
|
||||
return dataset
|
||||
51
evals/qa_metrics_utils.py
Normal file
51
evals/qa_metrics_utils.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
from evals.deepeval_metrics import (
|
||||
correctness_metric,
|
||||
comprehensiveness_metric,
|
||||
diversity_metric,
|
||||
empowerment_metric,
|
||||
directness_metric,
|
||||
f1_score_metric,
|
||||
em_score_metric,
|
||||
)
|
||||
from evals.promptfoo_metrics import PromptfooMetric
|
||||
from deepeval.metrics import AnswerRelevancyMetric
|
||||
import deepeval.metrics
|
||||
from cognee.infrastructure.llm.prompts.llm_judge_prompts import llm_judge_prompts
|
||||
|
||||
native_deepeval_metrics = {"AnswerRelevancy": AnswerRelevancyMetric}
|
||||
|
||||
custom_deepeval_metrics = {
|
||||
"Correctness": correctness_metric,
|
||||
"Comprehensiveness": comprehensiveness_metric,
|
||||
"Diversity": diversity_metric,
|
||||
"Empowerment": empowerment_metric,
|
||||
"Directness": directness_metric,
|
||||
"F1": f1_score_metric,
|
||||
"EM": em_score_metric,
|
||||
}
|
||||
|
||||
promptfoo_metrics = {
|
||||
"promptfoo.correctness": PromptfooMetric(llm_judge_prompts["correctness"]),
|
||||
"promptfoo.comprehensiveness": PromptfooMetric(llm_judge_prompts["comprehensiveness"]),
|
||||
"promptfoo.diversity": PromptfooMetric(llm_judge_prompts["diversity"]),
|
||||
"promptfoo.empowerment": PromptfooMetric(llm_judge_prompts["empowerment"]),
|
||||
"promptfoo.directness": PromptfooMetric(llm_judge_prompts["directness"]),
|
||||
}
|
||||
|
||||
qa_metrics = native_deepeval_metrics | custom_deepeval_metrics | promptfoo_metrics
|
||||
|
||||
|
||||
def get_metric(metric_name: str):
|
||||
if metric_name in qa_metrics:
|
||||
metric = qa_metrics[metric_name]
|
||||
else:
|
||||
try:
|
||||
metric_cls = getattr(deepeval.metrics, metric_name)
|
||||
metric = metric_cls()
|
||||
except AttributeError:
|
||||
raise Exception(f"Metric {metric_name} not supported")
|
||||
|
||||
if isinstance(metric, type):
|
||||
metric = metric()
|
||||
|
||||
return metric
|
||||
|
|
@ -1,5 +1,10 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"source": "[](https://colab.research.google.com/drive/1EpokQ8Y_5jIJ7HdixZms81Oqgh2sp7-E?usp=sharing)"
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -45,16 +50,14 @@
|
|||
"### 1. Setting Up the Environment\n",
|
||||
"\n",
|
||||
"Start by importing the required libraries and defining the environment:"
|
||||
],
|
||||
"id": "d0d7a82d729bbef6"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"source": "!pip install llama-index-graph-rag-cognee==0.1.1",
|
||||
"id": "598b52e384086512"
|
||||
"source": "!pip install llama-index-graph-rag-cognee==0.1.2"
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -69,8 +72,7 @@
|
|||
"\n",
|
||||
"if \"OPENAI_API_KEY\" not in os.environ:\n",
|
||||
" os.environ[\"OPENAI_API_KEY\"] = \"\""
|
||||
],
|
||||
"id": "892a1b1198ec662f"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -81,8 +83,7 @@
|
|||
"### 2. Preparing the Dataset\n",
|
||||
"\n",
|
||||
"We’ll use a brief profile of an individual as our sample dataset:"
|
||||
],
|
||||
"id": "a1f16f5ca5249ebb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -98,8 +99,7 @@
|
|||
" text=\"David Thompson, Creative Graphic Designer with over 8 years of experience in visual design and branding.\"\n",
|
||||
" ),\n",
|
||||
" ]"
|
||||
],
|
||||
"id": "198022c34636a3a0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -108,8 +108,7 @@
|
|||
"### 3. Initializing CogneeGraphRAG\n",
|
||||
"\n",
|
||||
"Instantiate the Cognee framework with configurations for LLM, graph, and database providers:"
|
||||
],
|
||||
"id": "781ae78e52ff49a"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -126,8 +125,7 @@
|
|||
" relational_db_provider=\"sqlite\",\n",
|
||||
" relational_db_name=\"cognee_db\",\n",
|
||||
")"
|
||||
],
|
||||
"id": "17e466821ab88d50"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -136,16 +134,14 @@
|
|||
"### 4. Adding Data to Cognee\n",
|
||||
"\n",
|
||||
"Load the dataset into the cognee framework:"
|
||||
],
|
||||
"id": "2a55d5be9de0ce81"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"source": "await cogneeRAG.add(documents, \"test\")",
|
||||
"id": "238b716429aba541"
|
||||
"source": "await cogneeRAG.add(documents, \"test\")"
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -156,16 +152,14 @@
|
|||
"### 5. Processing Data into a Knowledge Graph\n",
|
||||
"\n",
|
||||
"Transform the data into a structured knowledge graph:"
|
||||
],
|
||||
"id": "23e5316aa7e5dbc7"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"source": "await cogneeRAG.process_data(\"test\")",
|
||||
"id": "c3b3063d428b07a2"
|
||||
"source": "await cogneeRAG.process_data(\"test\")"
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -176,8 +170,7 @@
|
|||
"### 6. Performing Searches\n",
|
||||
"\n",
|
||||
"### Answer prompt based on knowledge graph approach:"
|
||||
],
|
||||
"id": "e32327de54e98dc8"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -190,14 +183,12 @@
|
|||
"print(\"\\n\\nAnswer based on knowledge graph:\\n\")\n",
|
||||
"for result in search_results:\n",
|
||||
" print(f\"{result}\\n\")"
|
||||
],
|
||||
"id": "fddbf5916d1e50e5"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"source": "### Answer prompt based on RAG approach:",
|
||||
"id": "9246aed7f69ceb7e"
|
||||
"source": "### Answer prompt based on RAG approach:"
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -210,14 +201,12 @@
|
|||
"print(\"\\n\\nAnswer based on RAG:\\n\")\n",
|
||||
"for result in search_results:\n",
|
||||
" print(f\"{result}\\n\")"
|
||||
],
|
||||
"id": "fe77c7a7c57fe4e4"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "markdown",
|
||||
"source": "In conclusion, the results demonstrate a significant advantage of the knowledge graph-based approach (Graphrag) over the RAG approach. Graphrag successfully identified all the mentioned individuals across multiple documents, showcasing its ability to aggregate and infer information from a global context. In contrast, the RAG approach was limited to identifying individuals within a single document due to its chunking-based processing constraints. This highlights Graphrag's superior capability in comprehensively resolving queries that span across a broader corpus of interconnected data.",
|
||||
"id": "89cc99628392eb99"
|
||||
"source": "In conclusion, the results demonstrate a significant advantage of the knowledge graph-based approach (Graphrag) over the RAG approach. Graphrag successfully identified all the mentioned individuals across multiple documents, showcasing its ability to aggregate and infer information from a global context. In contrast, the RAG approach was limited to identifying individuals within a single document due to its chunking-based processing constraints. This highlights Graphrag's superior capability in comprehensively resolving queries that span across a broader corpus of interconnected data."
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -226,8 +215,7 @@
|
|||
"### 7. Finding Related Nodes\n",
|
||||
"\n",
|
||||
"Explore relationships in the knowledge graph:"
|
||||
],
|
||||
"id": "44c9b67c09763610"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -240,8 +228,7 @@
|
|||
"print(\"\\n\\nRelated nodes are:\\n\")\n",
|
||||
"for node in related_nodes:\n",
|
||||
" print(f\"{node}\\n\")"
|
||||
],
|
||||
"id": "efbc1511586f46fe"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
|
|
@ -274,9 +261,8 @@
|
|||
"\n",
|
||||
"Try running it yourself\n",
|
||||
"\n",
|
||||
"Join cognee community"
|
||||
],
|
||||
"id": "d0f82c2c6eb7793"
|
||||
"[join the cognee community](https://discord.gg/tV7pr5XSj7)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {},
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue