feat: COG-1526 instance filter in eval (#627)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->
- Added _filter_instances to BaseBenchmarkAdapter supporting filtering
by IDs, indices, or JSON files.
- Updated HotpotQAAdapter and MusiqueQAAdapter to use the base class
filtering.
- Added instance_filter parameter to corpus builder pipeline.
- Extracted _get_raw_corpus method in both adapters for better code
organization.
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Corpus loading and building now support a flexible filtering option,
allowing users to apply custom criteria to tailor the retrieved data.

- **Refactor**
- The extraction process has been reorganized to separately handle text
content and associated metadata, enhancing clarity and overall workflow
efficiency.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
lxobr 2025-03-13 14:23:13 +01:00 committed by GitHub
parent 88ed411f03
commit daf7d4ae26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 158 additions and 77 deletions

View file

@ -1,10 +1,55 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List, Optional, Any, Union, Tuple
import os
import json
import logging
logger = logging.getLogger(__name__)
class BaseBenchmarkAdapter(ABC):
def _filter_instances(
self,
instances: List[dict[str, Any]],
instance_filter: Union[str, List[str], List[int]],
id_key: str = "id",
) -> List[dict[str, Any]]:
"""Filter instances by IDs or indices, or load filter from a JSON file."""
if isinstance(instance_filter, str):
logger.info(f"Loading instance filter from file: {instance_filter}")
if not os.path.isfile(instance_filter):
raise FileNotFoundError(f"Filter file not found: {instance_filter}")
with open(instance_filter, "r", encoding="utf-8") as f:
try:
instance_filter = json.load(f)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in filter file: {e}")
if all(isinstance(fid, str) for fid in instance_filter):
logger.info(f"Filtering by {len(instance_filter)} string IDs using key '{id_key}'")
filtered = [inst for inst in instances if inst.get(id_key) in instance_filter]
if not filtered:
logger.warning(f"No instances found with the provided IDs using key '{id_key}'")
return filtered
if all(isinstance(fid, int) for fid in instance_filter):
logger.info(f"Filtering by {len(instance_filter)} integer indices")
filtered = [instances[i] for i in instance_filter if 0 <= i < len(instances)]
if not filtered:
logger.warning("No instances found at the provided indices")
return filtered
raise ValueError(
"instance_filter must be a list of string ids, integer indices, or a JSON file path."
)
@abstractmethod
def load_corpus(
    self,
    limit: Optional[int] = None,
    seed: int = 42,
    load_golden_context: bool = False,
    instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> Tuple[List[str], List[dict[str, Any]]]:
    """Load a benchmark corpus.

    Implementations must return a (corpus texts, question/answer pair dicts)
    tuple, optionally sampled to `limit` items (seeded by `seed`), optionally
    enriched with golden context, and optionally filtered via
    `instance_filter` (see _filter_instances).
    """
    # Note: the diff rendering interleaved the pre-change signature here;
    # this is the final post-change abstract signature.
    pass

View file

@ -1,12 +1,16 @@
from typing import Optional, Any
from typing import Optional, Any, List, Union, Tuple
from cognee.eval_framework.benchmark_adapters.base_benchmark_adapter import BaseBenchmarkAdapter
class DummyAdapter(BaseBenchmarkAdapter):
def load_corpus(
self, limit: Optional[int] = None, seed: int = 42, load_golden_context: bool = False
) -> tuple[list[str], list[dict[str, Any]]]:
self,
limit: Optional[int] = None,
seed: int = 42,
load_golden_context: bool = False,
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> Tuple[List[str], List[dict[str, Any]]]:
corpus_list = [
"The cognee is an AI memory engine that supports different vector and graph databases",
"Neo4j is a graph database supported by cognee",
@ -22,4 +26,7 @@ class DummyAdapter(BaseBenchmarkAdapter):
question_answer_pairs = [qa_pair]
# Instance filtering is not applicable for the dummy adapter as it always returns the same data
# but we include the parameter for API consistency
return corpus_list, question_answer_pairs

View file

@ -2,7 +2,7 @@ import requests
import os
import json
import random
from typing import Optional, Any, List, Tuple
from typing import Optional, Any, List, Union, Tuple
from cognee.eval_framework.benchmark_adapters.base_benchmark_adapter import BaseBenchmarkAdapter
@ -37,17 +37,33 @@ class HotpotQAAdapter(BaseBenchmarkAdapter):
return "\n".join(golden_contexts)
def _process_item(
def _get_raw_corpus(self) -> List[dict[str, Any]]:
"""Loads the raw corpus data from file or URL and returns it as a list of dictionaries."""
filename = self.dataset_info["filename"]
if os.path.exists(filename):
with open(filename, "r", encoding="utf-8") as f:
raw_corpus = json.load(f)
else:
response = requests.get(self.dataset_info["url"])
response.raise_for_status()
raw_corpus = response.json()
with open(filename, "w", encoding="utf-8") as f:
json.dump(raw_corpus, f, ensure_ascii=False, indent=4)
return raw_corpus
def _get_corpus_entries(self, item: dict[str, Any]) -> List[str]:
"""Extracts corpus entries from the context of an item."""
return [" ".join(sentences) for title, sentences in item["context"]]
def _get_question_answer_pair(
self,
item: dict[str, Any],
corpus_list: List[str],
question_answer_pairs: List[dict[str, Any]],
load_golden_context: bool = False,
) -> None:
"""Processes a single item and adds it to the corpus and QA pairs."""
for title, sentences in item["context"]:
corpus_list.append(" ".join(sentences))
) -> dict[str, Any]:
"""Extracts a question-answer pair from an item."""
qa_pair = {
"question": item["question"],
"answer": item["answer"].lower(),
@ -57,33 +73,30 @@ class HotpotQAAdapter(BaseBenchmarkAdapter):
if load_golden_context:
qa_pair["golden_context"] = self._get_golden_context(item)
question_answer_pairs.append(qa_pair)
return qa_pair
def load_corpus(
    self,
    limit: Optional[int] = None,
    seed: int = 42,
    load_golden_context: bool = False,
    instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> Tuple[List[str], List[dict[str, Any]]]:
    """Load and process the HotpotQA corpus, optionally with filtering and golden context.

    Args:
        limit: If given and smaller than the corpus, sample this many items.
        seed: RNG seed that makes the sampling reproducible.
        load_golden_context: Whether to attach golden context to each QA pair.
        instance_filter: IDs/indices (or a JSON file path) forwarded to
            _filter_instances; HotpotQA ids live under the "_id" key.

    Returns:
        A tuple of (corpus texts, question/answer pair dicts).
    """
    # The diff rendering interleaved pre-change lines (inline file I/O and
    # _process_item calls) here; this is the final post-change implementation.
    raw_corpus = self._get_raw_corpus()

    # Filter before sampling so `limit` draws from the filtered population.
    if instance_filter is not None:
        raw_corpus = self._filter_instances(raw_corpus, instance_filter, id_key="_id")

    if limit is not None and 0 < limit < len(raw_corpus):
        random.seed(seed)
        raw_corpus = random.sample(raw_corpus, limit)

    corpus_list = []
    question_answer_pairs = []
    for item in raw_corpus:
        corpus_list.extend(self._get_corpus_entries(item))
        question_answer_pairs.append(self._get_question_answer_pair(item, load_golden_context))
    return corpus_list, question_answer_pairs

View file

@ -1,7 +1,7 @@
import os
import json
import random
from typing import Optional, Any, List
from typing import Optional, Any, List, Union, Tuple
import zipfile
import gdown
@ -38,41 +38,8 @@ class MusiqueQAAdapter(BaseBenchmarkAdapter):
return "\n".join(golden_context)
def _process_item(
self,
item: dict[str, Any],
corpus_list: List[str],
question_answer_pairs: List[dict[str, Any]],
load_golden_context: bool = False,
) -> None:
"""Processes a single item and adds it to the corpus and QA pairs."""
# Add paragraphs to corpus
paragraphs = item.get("paragraphs", [])
for paragraph in paragraphs:
corpus_list.append(paragraph["paragraph_text"])
# Create QA pair
qa_pair = {
"id": item.get("id", ""),
"question": item.get("question", ""),
"answer": item.get("answer", "").lower()
if isinstance(item.get("answer"), str)
else item.get("answer"),
}
if load_golden_context:
qa_pair["golden_context"] = self._get_golden_context(item)
question_answer_pairs.append(qa_pair)
def load_corpus(
self,
limit: Optional[int] = None,
seed: int = 42,
load_golden_context: bool = False,
auto_download: bool = True,
) -> tuple[list[str], list[dict[str, Any]]]:
"""Loads and processes the Musique QA dataset."""
def _get_raw_corpus(self, auto_download: bool = True) -> List[dict[str, Any]]:
"""Loads the raw corpus data from file or downloads it if needed."""
target_filename = self.dataset_info["filename"]
if not os.path.exists(target_filename):
@ -87,15 +54,55 @@ class MusiqueQAAdapter(BaseBenchmarkAdapter):
with open(target_filename, "r", encoding="utf-8") as f:
data = [json.loads(line) for line in f]
if limit is not None and 0 < limit < len(data):
return data
def _get_corpus_entries(self, item: dict[str, Any]) -> List[str]:
"""Extracts corpus entries from the paragraphs of an item."""
return [paragraph["paragraph_text"] for paragraph in item.get("paragraphs", [])]
def _get_question_answer_pair(
self,
item: dict[str, Any],
load_golden_context: bool = False,
) -> dict[str, Any]:
"""Extracts a question-answer pair from an item."""
qa_pair = {
"id": item.get("id", ""),
"question": item.get("question", ""),
"answer": item.get("answer", "").lower()
if isinstance(item.get("answer"), str)
else item.get("answer"),
}
if load_golden_context:
qa_pair["golden_context"] = self._get_golden_context(item)
return qa_pair
def load_corpus(
    self,
    limit: Optional[int] = None,
    seed: int = 42,
    load_golden_context: bool = False,
    auto_download: bool = True,
    instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> Tuple[List[str], List[dict[str, Any]]]:
    """Load and process the Musique QA dataset with optional filtering.

    Args:
        limit: If given and smaller than the corpus, sample this many items.
        seed: RNG seed that makes the sampling reproducible.
        load_golden_context: Whether to attach golden context to each QA pair.
        auto_download: Forwarded to _get_raw_corpus to allow fetching the
            dataset when it is not on disk.
        instance_filter: IDs/indices (or a JSON file path) forwarded to
            _filter_instances; Musique ids live under the "id" key.

    Returns:
        A tuple of (corpus texts, question/answer pair dicts).
    """
    # The diff rendering interleaved pre-change lines (the old `data`
    # variable and _process_item loop) here; this is the final version.
    raw_corpus = self._get_raw_corpus(auto_download)

    # Filter before sampling so `limit` draws from the filtered population.
    if instance_filter is not None:
        raw_corpus = self._filter_instances(raw_corpus, instance_filter, id_key="id")

    if limit is not None and 0 < limit < len(raw_corpus):
        random.seed(seed)
        raw_corpus = random.sample(raw_corpus, limit)

    corpus_list = []
    question_answer_pairs = []
    for item in raw_corpus:
        corpus_list.extend(self._get_corpus_entries(item))
        question_answer_pairs.append(self._get_question_answer_pair(item, load_golden_context))
    return corpus_list, question_answer_pairs

View file

@ -29,10 +29,13 @@ class CorpusBuilderExecutor:
self.task_getter = task_getter
def load_corpus(
    self,
    limit: Optional[int] = None,
    load_golden_context: bool = False,
    instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> Tuple[List[Dict], List[str]]:
    """Delegate corpus loading to the configured adapter and cache the result.

    Stores the adapter's output on self.raw_corpus / self.questions and
    returns them as a (raw_corpus, questions) tuple.
    The `instance_filter` is passed straight through to the adapter.
    """
    # NOTE(review): the declared return type Tuple[List[Dict], List[str]]
    # looks swapped relative to adapters returning (List[str], List[dict]);
    # kept as-is for interface compatibility — confirm against callers.
    # The diff rendering interleaved the pre-change signature and call here;
    # this is the final post-change implementation.
    self.raw_corpus, self.questions = self.adapter.load_corpus(
        limit=limit, load_golden_context=load_golden_context, instance_filter=instance_filter
    )
    return self.raw_corpus, self.questions
@ -42,8 +45,11 @@ class CorpusBuilderExecutor:
chunk_size=1024,
chunker=TextChunker,
load_golden_context: bool = False,
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> List[str]:
self.load_corpus(limit=limit, load_golden_context=load_golden_context)
self.load_corpus(
limit=limit, load_golden_context=load_golden_context, instance_filter=instance_filter
)
await self.run_cognee(chunk_size=chunk_size, chunker=chunker)
return self.questions

View file

@ -33,7 +33,9 @@ async def create_and_insert_questions_table(questions_payload):
await session.commit()
async def run_corpus_builder(params: dict, chunk_size=1024, chunker=TextChunker) -> List[dict]:
async def run_corpus_builder(
params: dict, chunk_size=1024, chunker=TextChunker, instance_filter=None
) -> List[dict]:
if params.get("building_corpus_from_scratch"):
logging.info("Corpus Builder started...")
@ -51,6 +53,7 @@ async def run_corpus_builder(params: dict, chunk_size=1024, chunker=TextChunker)
chunker=chunker,
chunk_size=chunk_size,
load_golden_context=params.get("evaluating_contexts"),
instance_filter=instance_filter,
)
with open(params["questions_path"], "w", encoding="utf-8") as f:
json.dump(questions, f, ensure_ascii=False, indent=4)