From daf7d4ae2605db732292b7db7949c618a2c80b3c Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Thu, 13 Mar 2025 14:23:13 +0100 Subject: [PATCH] feat: COG-1526 instance filter in eval (#627) ## Description - Added _filter_instances to BaseBenchmarkAdapter supporting filtering by IDs, indices, or JSON files. - Updated HotpotQAAdapter and MusiqueQAAdapter to use the base class filtering. - Added instance_filter parameter to corpus builder pipeline. - Extracted _get_raw_corpus method in both adapters for better code organization. ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit - **New Features** - Corpus loading and building now support a flexible filtering option, allowing users to apply custom criteria to tailor the retrieved data. - **Refactor** - The extraction process has been reorganized to separately handle text content and associated metadata, enhancing clarity and overall workflow efficiency. 
--- .../base_benchmark_adapter.py | 51 ++++++++++- .../benchmark_adapters/dummy_adapter.py | 13 ++- .../benchmark_adapters/hotpot_qa_adapter.py | 67 ++++++++------ .../benchmark_adapters/musique_adapter.py | 87 ++++++++++--------- .../corpus_builder/corpus_builder_executor.py | 12 ++- .../corpus_builder/run_corpus_builder.py | 5 +- 6 files changed, 158 insertions(+), 77 deletions(-) diff --git a/cognee/eval_framework/benchmark_adapters/base_benchmark_adapter.py b/cognee/eval_framework/benchmark_adapters/base_benchmark_adapter.py index 9efa09e94..28a73126e 100644 --- a/cognee/eval_framework/benchmark_adapters/base_benchmark_adapter.py +++ b/cognee/eval_framework/benchmark_adapters/base_benchmark_adapter.py @@ -1,10 +1,55 @@ from abc import ABC, abstractmethod -from typing import List, Optional +from typing import List, Optional, Any, Union, Tuple +import os +import json +import logging + +logger = logging.getLogger(__name__) class BaseBenchmarkAdapter(ABC): + def _filter_instances( + self, + instances: List[dict[str, Any]], + instance_filter: Union[str, List[str], List[int]], + id_key: str = "id", + ) -> List[dict[str, Any]]: + """Filter instances by IDs or indices, or load filter from a JSON file.""" + if isinstance(instance_filter, str): + logger.info(f"Loading instance filter from file: {instance_filter}") + if not os.path.isfile(instance_filter): + raise FileNotFoundError(f"Filter file not found: {instance_filter}") + + with open(instance_filter, "r", encoding="utf-8") as f: + try: + instance_filter = json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in filter file: {e}") + + if all(isinstance(fid, str) for fid in instance_filter): + logger.info(f"Filtering by {len(instance_filter)} string IDs using key '{id_key}'") + filtered = [inst for inst in instances if inst.get(id_key) in instance_filter] + if not filtered: + logger.warning(f"No instances found with the provided IDs using key '{id_key}'") + return filtered + + if 
all(isinstance(fid, int) for fid in instance_filter): + logger.info(f"Filtering by {len(instance_filter)} integer indices") + filtered = [instances[i] for i in instance_filter if 0 <= i < len(instances)] + if not filtered: + logger.warning("No instances found at the provided indices") + return filtered + + raise ValueError( + "instance_filter must be a list of string ids, integer indices, or a JSON file path." + ) + @abstractmethod def load_corpus( - self, limit: Optional[int] = None, seed: int = 42, load_golden_context: bool = False - ) -> List[str]: + self, + limit: Optional[int] = None, + seed: int = 42, + load_golden_context: bool = False, + instance_filter: Optional[Union[str, List[str], List[int]]] = None, + ) -> Tuple[List[str], List[dict[str, Any]]]: pass diff --git a/cognee/eval_framework/benchmark_adapters/dummy_adapter.py b/cognee/eval_framework/benchmark_adapters/dummy_adapter.py index 9bf945d06..479f60a9f 100644 --- a/cognee/eval_framework/benchmark_adapters/dummy_adapter.py +++ b/cognee/eval_framework/benchmark_adapters/dummy_adapter.py @@ -1,12 +1,16 @@ -from typing import Optional, Any +from typing import Optional, Any, List, Union, Tuple from cognee.eval_framework.benchmark_adapters.base_benchmark_adapter import BaseBenchmarkAdapter class DummyAdapter(BaseBenchmarkAdapter): def load_corpus( - self, limit: Optional[int] = None, seed: int = 42, load_golden_context: bool = False - ) -> tuple[list[str], list[dict[str, Any]]]: + self, + limit: Optional[int] = None, + seed: int = 42, + load_golden_context: bool = False, + instance_filter: Optional[Union[str, List[str], List[int]]] = None, + ) -> Tuple[List[str], List[dict[str, Any]]]: corpus_list = [ "The cognee is an AI memory engine that supports different vector and graph databases", "Neo4j is a graph database supported by cognee", @@ -22,4 +26,7 @@ class DummyAdapter(BaseBenchmarkAdapter): question_answer_pairs = [qa_pair] + # Instance filtering is not applicable for the dummy adapter as it always 
returns the same data + # but we include the parameter for API consistency + return corpus_list, question_answer_pairs diff --git a/cognee/eval_framework/benchmark_adapters/hotpot_qa_adapter.py b/cognee/eval_framework/benchmark_adapters/hotpot_qa_adapter.py index d8e5a03c2..828af37ff 100644 --- a/cognee/eval_framework/benchmark_adapters/hotpot_qa_adapter.py +++ b/cognee/eval_framework/benchmark_adapters/hotpot_qa_adapter.py @@ -2,7 +2,7 @@ import requests import os import json import random -from typing import Optional, Any, List, Tuple +from typing import Optional, Any, List, Union, Tuple from cognee.eval_framework.benchmark_adapters.base_benchmark_adapter import BaseBenchmarkAdapter @@ -37,17 +37,33 @@ class HotpotQAAdapter(BaseBenchmarkAdapter): return "\n".join(golden_contexts) - def _process_item( + def _get_raw_corpus(self) -> List[dict[str, Any]]: + """Loads the raw corpus data from file or URL and returns it as a list of dictionaries.""" + filename = self.dataset_info["filename"] + + if os.path.exists(filename): + with open(filename, "r", encoding="utf-8") as f: + raw_corpus = json.load(f) + else: + response = requests.get(self.dataset_info["url"]) + response.raise_for_status() + raw_corpus = response.json() + + with open(filename, "w", encoding="utf-8") as f: + json.dump(raw_corpus, f, ensure_ascii=False, indent=4) + + return raw_corpus + + def _get_corpus_entries(self, item: dict[str, Any]) -> List[str]: + """Extracts corpus entries from the context of an item.""" + return [" ".join(sentences) for title, sentences in item["context"]] + + def _get_question_answer_pair( self, item: dict[str, Any], - corpus_list: List[str], - question_answer_pairs: List[dict[str, Any]], load_golden_context: bool = False, - ) -> None: - """Processes a single item and adds it to the corpus and QA pairs.""" - for title, sentences in item["context"]: - corpus_list.append(" ".join(sentences)) - + ) -> dict[str, Any]: + """Extracts a question-answer pair from an item.""" qa_pair = 
{ "question": item["question"], "answer": item["answer"].lower(), @@ -57,33 +73,30 @@ class HotpotQAAdapter(BaseBenchmarkAdapter): if load_golden_context: qa_pair["golden_context"] = self._get_golden_context(item) - question_answer_pairs.append(qa_pair) + return qa_pair def load_corpus( - self, limit: Optional[int] = None, seed: int = 42, load_golden_context: bool = False + self, + limit: Optional[int] = None, + seed: int = 42, + load_golden_context: bool = False, + instance_filter: Optional[Union[str, List[str], List[int]]] = None, ) -> Tuple[List[str], List[dict[str, Any]]]: - """Loads and processes the HotpotQA corpus, optionally with golden context.""" - filename = self.dataset_info["filename"] + """Loads and processes the HotpotQA corpus, optionally with filtering and golden context.""" + raw_corpus = self._get_raw_corpus() - if os.path.exists(filename): - with open(filename, "r", encoding="utf-8") as f: - corpus_json = json.load(f) - else: - response = requests.get(self.dataset_info["url"]) - response.raise_for_status() - corpus_json = response.json() + if instance_filter is not None: + raw_corpus = self._filter_instances(raw_corpus, instance_filter, id_key="_id") - with open(filename, "w", encoding="utf-8") as f: - json.dump(corpus_json, f, ensure_ascii=False, indent=4) - - if limit is not None and 0 < limit < len(corpus_json): + if limit is not None and 0 < limit < len(raw_corpus): random.seed(seed) - corpus_json = random.sample(corpus_json, limit) + raw_corpus = random.sample(raw_corpus, limit) corpus_list = [] question_answer_pairs = [] - for item in corpus_json: - self._process_item(item, corpus_list, question_answer_pairs, load_golden_context) + for item in raw_corpus: + corpus_list.extend(self._get_corpus_entries(item)) + question_answer_pairs.append(self._get_question_answer_pair(item, load_golden_context)) return corpus_list, question_answer_pairs diff --git a/cognee/eval_framework/benchmark_adapters/musique_adapter.py 
b/cognee/eval_framework/benchmark_adapters/musique_adapter.py index 3be44edf8..b4a45339d 100644 --- a/cognee/eval_framework/benchmark_adapters/musique_adapter.py +++ b/cognee/eval_framework/benchmark_adapters/musique_adapter.py @@ -1,7 +1,7 @@ import os import json import random -from typing import Optional, Any, List +from typing import Optional, Any, List, Union, Tuple import zipfile import gdown @@ -38,41 +38,8 @@ class MusiqueQAAdapter(BaseBenchmarkAdapter): return "\n".join(golden_context) - def _process_item( - self, - item: dict[str, Any], - corpus_list: List[str], - question_answer_pairs: List[dict[str, Any]], - load_golden_context: bool = False, - ) -> None: - """Processes a single item and adds it to the corpus and QA pairs.""" - # Add paragraphs to corpus - paragraphs = item.get("paragraphs", []) - for paragraph in paragraphs: - corpus_list.append(paragraph["paragraph_text"]) - - # Create QA pair - qa_pair = { - "id": item.get("id", ""), - "question": item.get("question", ""), - "answer": item.get("answer", "").lower() - if isinstance(item.get("answer"), str) - else item.get("answer"), - } - - if load_golden_context: - qa_pair["golden_context"] = self._get_golden_context(item) - - question_answer_pairs.append(qa_pair) - - def load_corpus( - self, - limit: Optional[int] = None, - seed: int = 42, - load_golden_context: bool = False, - auto_download: bool = True, - ) -> tuple[list[str], list[dict[str, Any]]]: - """Loads and processes the Musique QA dataset.""" + def _get_raw_corpus(self, auto_download: bool = True) -> List[dict[str, Any]]: + """Loads the raw corpus data from file or downloads it if needed.""" target_filename = self.dataset_info["filename"] if not os.path.exists(target_filename): @@ -87,15 +54,55 @@ class MusiqueQAAdapter(BaseBenchmarkAdapter): with open(target_filename, "r", encoding="utf-8") as f: data = [json.loads(line) for line in f] - if limit is not None and 0 < limit < len(data): + return data + + def _get_corpus_entries(self, item: 
dict[str, Any]) -> List[str]: + """Extracts corpus entries from the paragraphs of an item.""" + return [paragraph["paragraph_text"] for paragraph in item.get("paragraphs", [])] + + def _get_question_answer_pair( + self, + item: dict[str, Any], + load_golden_context: bool = False, + ) -> dict[str, Any]: + """Extracts a question-answer pair from an item.""" + qa_pair = { + "id": item.get("id", ""), + "question": item.get("question", ""), + "answer": item.get("answer", "").lower() + if isinstance(item.get("answer"), str) + else item.get("answer"), + } + + if load_golden_context: + qa_pair["golden_context"] = self._get_golden_context(item) + + return qa_pair + + def load_corpus( + self, + limit: Optional[int] = None, + seed: int = 42, + load_golden_context: bool = False, + auto_download: bool = True, + instance_filter: Optional[Union[str, List[str], List[int]]] = None, + ) -> Tuple[List[str], List[dict[str, Any]]]: + """Loads and processes the Musique QA dataset with optional filtering.""" + raw_corpus = self._get_raw_corpus(auto_download) + + if instance_filter is not None: + raw_corpus = self._filter_instances(raw_corpus, instance_filter, id_key="id") + + if limit is not None and 0 < limit < len(raw_corpus): random.seed(seed) - data = random.sample(data, limit) + raw_corpus = random.sample(raw_corpus, limit) corpus_list = [] question_answer_pairs = [] - for item in data: - self._process_item(item, corpus_list, question_answer_pairs, load_golden_context) + for item in raw_corpus: + corpus_list.extend(self._get_corpus_entries(item)) + question_answer_pairs.append(self._get_question_answer_pair(item, load_golden_context)) return corpus_list, question_answer_pairs diff --git a/cognee/eval_framework/corpus_builder/corpus_builder_executor.py b/cognee/eval_framework/corpus_builder/corpus_builder_executor.py index 2a6ff63ce..1f3ecc2cb 100644 --- a/cognee/eval_framework/corpus_builder/corpus_builder_executor.py +++ 
b/cognee/eval_framework/corpus_builder/corpus_builder_executor.py @@ -29,10 +29,13 @@ class CorpusBuilderExecutor: self.task_getter = task_getter def load_corpus( - self, limit: Optional[int] = None, load_golden_context: bool = False + self, + limit: Optional[int] = None, + load_golden_context: bool = False, + instance_filter: Optional[Union[str, List[str], List[int]]] = None, ) -> Tuple[List[Dict], List[str]]: self.raw_corpus, self.questions = self.adapter.load_corpus( - limit=limit, load_golden_context=load_golden_context + limit=limit, load_golden_context=load_golden_context, instance_filter=instance_filter ) return self.raw_corpus, self.questions @@ -42,8 +45,11 @@ class CorpusBuilderExecutor: chunk_size=1024, chunker=TextChunker, load_golden_context: bool = False, + instance_filter: Optional[Union[str, List[str], List[int]]] = None, ) -> List[str]: - self.load_corpus(limit=limit, load_golden_context=load_golden_context) + self.load_corpus( + limit=limit, load_golden_context=load_golden_context, instance_filter=instance_filter + ) await self.run_cognee(chunk_size=chunk_size, chunker=chunker) return self.questions diff --git a/cognee/eval_framework/corpus_builder/run_corpus_builder.py b/cognee/eval_framework/corpus_builder/run_corpus_builder.py index f443cfcac..c1af75981 100644 --- a/cognee/eval_framework/corpus_builder/run_corpus_builder.py +++ b/cognee/eval_framework/corpus_builder/run_corpus_builder.py @@ -33,7 +33,9 @@ async def create_and_insert_questions_table(questions_payload): await session.commit() -async def run_corpus_builder(params: dict, chunk_size=1024, chunker=TextChunker) -> List[dict]: +async def run_corpus_builder( + params: dict, chunk_size=1024, chunker=TextChunker, instance_filter=None +) -> List[dict]: if params.get("building_corpus_from_scratch"): logging.info("Corpus Builder started...") @@ -51,6 +53,7 @@ async def run_corpus_builder(params: dict, chunk_size=1024, chunker=TextChunker) chunker=chunker, chunk_size=chunk_size, 
load_golden_context=params.get("evaluating_contexts"), + instance_filter=instance_filter, ) with open(params["questions_path"], "w", encoding="utf-8") as f: json.dump(questions, f, ensure_ascii=False, indent=4)