feat: COG-1526 instance filter in eval (#627)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> - Added _filter_instances to BaseBenchmarkAdapter supporting filtering by IDs, indices, or JSON files. - Updated HotpotQAAdapter and MusiqueQAAdapter to use the base class filtering. - Added instance_filter parameter to corpus builder pipeline. - Extracted _get_raw_corpus method in both adapters for better code organization ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Corpus loading and building now support a flexible filtering option, allowing users to apply custom criteria to tailor the retrieved data. - **Refactor** - The extraction process has been reorganized to separately handle text content and associated metadata, enhancing clarity and overall workflow efficiency. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
88ed411f03
commit
daf7d4ae26
6 changed files with 158 additions and 77 deletions
|
|
@ -1,10 +1,55 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Any, Union, Tuple
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseBenchmarkAdapter(ABC):
|
||||
def _filter_instances(
|
||||
self,
|
||||
instances: List[dict[str, Any]],
|
||||
instance_filter: Union[str, List[str], List[int]],
|
||||
id_key: str = "id",
|
||||
) -> List[dict[str, Any]]:
|
||||
"""Filter instances by IDs or indices, or load filter from a JSON file."""
|
||||
if isinstance(instance_filter, str):
|
||||
logger.info(f"Loading instance filter from file: {instance_filter}")
|
||||
if not os.path.isfile(instance_filter):
|
||||
raise FileNotFoundError(f"Filter file not found: {instance_filter}")
|
||||
|
||||
with open(instance_filter, "r", encoding="utf-8") as f:
|
||||
try:
|
||||
instance_filter = json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in filter file: {e}")
|
||||
|
||||
if all(isinstance(fid, str) for fid in instance_filter):
|
||||
logger.info(f"Filtering by {len(instance_filter)} string IDs using key '{id_key}'")
|
||||
filtered = [inst for inst in instances if inst.get(id_key) in instance_filter]
|
||||
if not filtered:
|
||||
logger.warning(f"No instances found with the provided IDs using key '{id_key}'")
|
||||
return filtered
|
||||
|
||||
if all(isinstance(fid, int) for fid in instance_filter):
|
||||
logger.info(f"Filtering by {len(instance_filter)} integer indices")
|
||||
filtered = [instances[i] for i in instance_filter if 0 <= i < len(instances)]
|
||||
if not filtered:
|
||||
logger.warning("No instances found at the provided indices")
|
||||
return filtered
|
||||
|
||||
raise ValueError(
|
||||
"instance_filter must be a list of string ids, integer indices, or a JSON file path."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def load_corpus(
|
||||
self, limit: Optional[int] = None, seed: int = 42, load_golden_context: bool = False
|
||||
) -> List[str]:
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
seed: int = 42,
|
||||
load_golden_context: bool = False,
|
||||
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
|
||||
) -> Tuple[List[str], List[dict[str, Any]]]:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
from typing import Optional, Any
|
||||
from typing import Optional, Any, List, Union, Tuple
|
||||
|
||||
from cognee.eval_framework.benchmark_adapters.base_benchmark_adapter import BaseBenchmarkAdapter
|
||||
|
||||
|
||||
class DummyAdapter(BaseBenchmarkAdapter):
|
||||
def load_corpus(
|
||||
self, limit: Optional[int] = None, seed: int = 42, load_golden_context: bool = False
|
||||
) -> tuple[list[str], list[dict[str, Any]]]:
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
seed: int = 42,
|
||||
load_golden_context: bool = False,
|
||||
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
|
||||
) -> Tuple[List[str], List[dict[str, Any]]]:
|
||||
corpus_list = [
|
||||
"The cognee is an AI memory engine that supports different vector and graph databases",
|
||||
"Neo4j is a graph database supported by cognee",
|
||||
|
|
@ -22,4 +26,7 @@ class DummyAdapter(BaseBenchmarkAdapter):
|
|||
|
||||
question_answer_pairs = [qa_pair]
|
||||
|
||||
# Instance filtering is not applicable for the dummy adapter as it always returns the same data
|
||||
# but we include the parameter for API consistency
|
||||
|
||||
return corpus_list, question_answer_pairs
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import requests
|
|||
import os
|
||||
import json
|
||||
import random
|
||||
from typing import Optional, Any, List, Tuple
|
||||
from typing import Optional, Any, List, Union, Tuple
|
||||
from cognee.eval_framework.benchmark_adapters.base_benchmark_adapter import BaseBenchmarkAdapter
|
||||
|
||||
|
||||
|
|
@ -37,17 +37,33 @@ class HotpotQAAdapter(BaseBenchmarkAdapter):
|
|||
|
||||
return "\n".join(golden_contexts)
|
||||
|
||||
def _process_item(
|
||||
def _get_raw_corpus(self) -> List[dict[str, Any]]:
|
||||
"""Loads the raw corpus data from file or URL and returns it as a list of dictionaries."""
|
||||
filename = self.dataset_info["filename"]
|
||||
|
||||
if os.path.exists(filename):
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
raw_corpus = json.load(f)
|
||||
else:
|
||||
response = requests.get(self.dataset_info["url"])
|
||||
response.raise_for_status()
|
||||
raw_corpus = response.json()
|
||||
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
json.dump(raw_corpus, f, ensure_ascii=False, indent=4)
|
||||
|
||||
return raw_corpus
|
||||
|
||||
def _get_corpus_entries(self, item: dict[str, Any]) -> List[str]:
|
||||
"""Extracts corpus entries from the context of an item."""
|
||||
return [" ".join(sentences) for title, sentences in item["context"]]
|
||||
|
||||
def _get_question_answer_pair(
|
||||
self,
|
||||
item: dict[str, Any],
|
||||
corpus_list: List[str],
|
||||
question_answer_pairs: List[dict[str, Any]],
|
||||
load_golden_context: bool = False,
|
||||
) -> None:
|
||||
"""Processes a single item and adds it to the corpus and QA pairs."""
|
||||
for title, sentences in item["context"]:
|
||||
corpus_list.append(" ".join(sentences))
|
||||
|
||||
) -> dict[str, Any]:
|
||||
"""Extracts a question-answer pair from an item."""
|
||||
qa_pair = {
|
||||
"question": item["question"],
|
||||
"answer": item["answer"].lower(),
|
||||
|
|
@ -57,33 +73,30 @@ class HotpotQAAdapter(BaseBenchmarkAdapter):
|
|||
if load_golden_context:
|
||||
qa_pair["golden_context"] = self._get_golden_context(item)
|
||||
|
||||
question_answer_pairs.append(qa_pair)
|
||||
return qa_pair
|
||||
|
||||
def load_corpus(
|
||||
self, limit: Optional[int] = None, seed: int = 42, load_golden_context: bool = False
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
seed: int = 42,
|
||||
load_golden_context: bool = False,
|
||||
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
|
||||
) -> Tuple[List[str], List[dict[str, Any]]]:
|
||||
"""Loads and processes the HotpotQA corpus, optionally with golden context."""
|
||||
filename = self.dataset_info["filename"]
|
||||
"""Loads and processes the HotpotQA corpus, optionally with filtering and golden context."""
|
||||
raw_corpus = self._get_raw_corpus()
|
||||
|
||||
if os.path.exists(filename):
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
corpus_json = json.load(f)
|
||||
else:
|
||||
response = requests.get(self.dataset_info["url"])
|
||||
response.raise_for_status()
|
||||
corpus_json = response.json()
|
||||
if instance_filter is not None:
|
||||
raw_corpus = self._filter_instances(raw_corpus, instance_filter, id_key="_id")
|
||||
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
json.dump(corpus_json, f, ensure_ascii=False, indent=4)
|
||||
|
||||
if limit is not None and 0 < limit < len(corpus_json):
|
||||
if limit is not None and 0 < limit < len(raw_corpus):
|
||||
random.seed(seed)
|
||||
corpus_json = random.sample(corpus_json, limit)
|
||||
raw_corpus = random.sample(raw_corpus, limit)
|
||||
|
||||
corpus_list = []
|
||||
question_answer_pairs = []
|
||||
|
||||
for item in corpus_json:
|
||||
self._process_item(item, corpus_list, question_answer_pairs, load_golden_context)
|
||||
for item in raw_corpus:
|
||||
corpus_list.extend(self._get_corpus_entries(item))
|
||||
question_answer_pairs.append(self._get_question_answer_pair(item, load_golden_context))
|
||||
|
||||
return corpus_list, question_answer_pairs
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import json
|
||||
import random
|
||||
from typing import Optional, Any, List
|
||||
from typing import Optional, Any, List, Union, Tuple
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
|
|
@ -38,41 +38,8 @@ class MusiqueQAAdapter(BaseBenchmarkAdapter):
|
|||
|
||||
return "\n".join(golden_context)
|
||||
|
||||
def _process_item(
|
||||
self,
|
||||
item: dict[str, Any],
|
||||
corpus_list: List[str],
|
||||
question_answer_pairs: List[dict[str, Any]],
|
||||
load_golden_context: bool = False,
|
||||
) -> None:
|
||||
"""Processes a single item and adds it to the corpus and QA pairs."""
|
||||
# Add paragraphs to corpus
|
||||
paragraphs = item.get("paragraphs", [])
|
||||
for paragraph in paragraphs:
|
||||
corpus_list.append(paragraph["paragraph_text"])
|
||||
|
||||
# Create QA pair
|
||||
qa_pair = {
|
||||
"id": item.get("id", ""),
|
||||
"question": item.get("question", ""),
|
||||
"answer": item.get("answer", "").lower()
|
||||
if isinstance(item.get("answer"), str)
|
||||
else item.get("answer"),
|
||||
}
|
||||
|
||||
if load_golden_context:
|
||||
qa_pair["golden_context"] = self._get_golden_context(item)
|
||||
|
||||
question_answer_pairs.append(qa_pair)
|
||||
|
||||
def load_corpus(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
seed: int = 42,
|
||||
load_golden_context: bool = False,
|
||||
auto_download: bool = True,
|
||||
) -> tuple[list[str], list[dict[str, Any]]]:
|
||||
"""Loads and processes the Musique QA dataset."""
|
||||
def _get_raw_corpus(self, auto_download: bool = True) -> List[dict[str, Any]]:
|
||||
"""Loads the raw corpus data from file or downloads it if needed."""
|
||||
target_filename = self.dataset_info["filename"]
|
||||
|
||||
if not os.path.exists(target_filename):
|
||||
|
|
@ -87,15 +54,55 @@ class MusiqueQAAdapter(BaseBenchmarkAdapter):
|
|||
with open(target_filename, "r", encoding="utf-8") as f:
|
||||
data = [json.loads(line) for line in f]
|
||||
|
||||
if limit is not None and 0 < limit < len(data):
|
||||
return data
|
||||
|
||||
def _get_corpus_entries(self, item: dict[str, Any]) -> List[str]:
|
||||
"""Extracts corpus entries from the paragraphs of an item."""
|
||||
return [paragraph["paragraph_text"] for paragraph in item.get("paragraphs", [])]
|
||||
|
||||
def _get_question_answer_pair(
|
||||
self,
|
||||
item: dict[str, Any],
|
||||
load_golden_context: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Extracts a question-answer pair from an item."""
|
||||
qa_pair = {
|
||||
"id": item.get("id", ""),
|
||||
"question": item.get("question", ""),
|
||||
"answer": item.get("answer", "").lower()
|
||||
if isinstance(item.get("answer"), str)
|
||||
else item.get("answer"),
|
||||
}
|
||||
|
||||
if load_golden_context:
|
||||
qa_pair["golden_context"] = self._get_golden_context(item)
|
||||
|
||||
return qa_pair
|
||||
|
||||
def load_corpus(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
seed: int = 42,
|
||||
load_golden_context: bool = False,
|
||||
auto_download: bool = True,
|
||||
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
|
||||
) -> Tuple[List[str], List[dict[str, Any]]]:
|
||||
"""Loads and processes the Musique QA dataset with optional filtering."""
|
||||
raw_corpus = self._get_raw_corpus(auto_download)
|
||||
|
||||
if instance_filter is not None:
|
||||
raw_corpus = self._filter_instances(raw_corpus, instance_filter, id_key="id")
|
||||
|
||||
if limit is not None and 0 < limit < len(raw_corpus):
|
||||
random.seed(seed)
|
||||
data = random.sample(data, limit)
|
||||
raw_corpus = random.sample(raw_corpus, limit)
|
||||
|
||||
corpus_list = []
|
||||
question_answer_pairs = []
|
||||
|
||||
for item in data:
|
||||
self._process_item(item, corpus_list, question_answer_pairs, load_golden_context)
|
||||
for item in raw_corpus:
|
||||
corpus_list.extend(self._get_corpus_entries(item))
|
||||
question_answer_pairs.append(self._get_question_answer_pair(item, load_golden_context))
|
||||
|
||||
return corpus_list, question_answer_pairs
|
||||
|
||||
|
|
|
|||
|
|
@ -29,10 +29,13 @@ class CorpusBuilderExecutor:
|
|||
self.task_getter = task_getter
|
||||
|
||||
def load_corpus(
|
||||
self, limit: Optional[int] = None, load_golden_context: bool = False
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
load_golden_context: bool = False,
|
||||
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
|
||||
) -> Tuple[List[Dict], List[str]]:
|
||||
self.raw_corpus, self.questions = self.adapter.load_corpus(
|
||||
limit=limit, load_golden_context=load_golden_context
|
||||
limit=limit, load_golden_context=load_golden_context, instance_filter=instance_filter
|
||||
)
|
||||
return self.raw_corpus, self.questions
|
||||
|
||||
|
|
@ -42,8 +45,11 @@ class CorpusBuilderExecutor:
|
|||
chunk_size=1024,
|
||||
chunker=TextChunker,
|
||||
load_golden_context: bool = False,
|
||||
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
|
||||
) -> List[str]:
|
||||
self.load_corpus(limit=limit, load_golden_context=load_golden_context)
|
||||
self.load_corpus(
|
||||
limit=limit, load_golden_context=load_golden_context, instance_filter=instance_filter
|
||||
)
|
||||
await self.run_cognee(chunk_size=chunk_size, chunker=chunker)
|
||||
return self.questions
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,9 @@ async def create_and_insert_questions_table(questions_payload):
|
|||
await session.commit()
|
||||
|
||||
|
||||
async def run_corpus_builder(params: dict, chunk_size=1024, chunker=TextChunker) -> List[dict]:
|
||||
async def run_corpus_builder(
|
||||
params: dict, chunk_size=1024, chunker=TextChunker, instance_filter=None
|
||||
) -> List[dict]:
|
||||
if params.get("building_corpus_from_scratch"):
|
||||
logging.info("Corpus Builder started...")
|
||||
|
||||
|
|
@ -51,6 +53,7 @@ async def run_corpus_builder(params: dict, chunk_size=1024, chunker=TextChunker)
|
|||
chunker=chunker,
|
||||
chunk_size=chunk_size,
|
||||
load_golden_context=params.get("evaluating_contexts"),
|
||||
instance_filter=instance_filter,
|
||||
)
|
||||
with open(params["questions_path"], "w", encoding="utf-8") as f:
|
||||
json.dump(questions, f, ensure_ascii=False, indent=4)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue