feat: COG-1526 instance filter in eval (#627)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->
- Added _filter_instances to BaseBenchmarkAdapter supporting filtering
by IDs, indices, or JSON files.
- Updated HotpotQAAdapter and MusiqueQAAdapter to use the base class
filtering.
- Added instance_filter parameter to corpus builder pipeline.
- Extracted _get_raw_corpus method in both adapters for better code
organization.
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Corpus loading and building now support a flexible filtering option,
allowing users to apply custom criteria to tailor the retrieved data.

- **Refactor**
- The extraction process has been reorganized to separately handle text
content and associated metadata, enhancing clarity and overall workflow
efficiency.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
lxobr 2025-03-13 14:23:13 +01:00 committed by GitHub
parent 88ed411f03
commit daf7d4ae26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 158 additions and 77 deletions

View file

@ -1,10 +1,55 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import List, Optional, Any, Union, Tuple
import os
import json
import logging
logger = logging.getLogger(__name__)
class BaseBenchmarkAdapter(ABC): class BaseBenchmarkAdapter(ABC):
def _filter_instances(
self,
instances: List[dict[str, Any]],
instance_filter: Union[str, List[str], List[int]],
id_key: str = "id",
) -> List[dict[str, Any]]:
"""Filter instances by IDs or indices, or load filter from a JSON file."""
if isinstance(instance_filter, str):
logger.info(f"Loading instance filter from file: {instance_filter}")
if not os.path.isfile(instance_filter):
raise FileNotFoundError(f"Filter file not found: {instance_filter}")
with open(instance_filter, "r", encoding="utf-8") as f:
try:
instance_filter = json.load(f)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in filter file: {e}")
if all(isinstance(fid, str) for fid in instance_filter):
logger.info(f"Filtering by {len(instance_filter)} string IDs using key '{id_key}'")
filtered = [inst for inst in instances if inst.get(id_key) in instance_filter]
if not filtered:
logger.warning(f"No instances found with the provided IDs using key '{id_key}'")
return filtered
if all(isinstance(fid, int) for fid in instance_filter):
logger.info(f"Filtering by {len(instance_filter)} integer indices")
filtered = [instances[i] for i in instance_filter if 0 <= i < len(instances)]
if not filtered:
logger.warning("No instances found at the provided indices")
return filtered
raise ValueError(
"instance_filter must be a list of string ids, integer indices, or a JSON file path."
)
@abstractmethod @abstractmethod
def load_corpus( def load_corpus(
self, limit: Optional[int] = None, seed: int = 42, load_golden_context: bool = False self,
) -> List[str]: limit: Optional[int] = None,
seed: int = 42,
load_golden_context: bool = False,
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> Tuple[List[str], List[dict[str, Any]]]:
pass pass

View file

@ -1,12 +1,16 @@
from typing import Optional, Any from typing import Optional, Any, List, Union, Tuple
from cognee.eval_framework.benchmark_adapters.base_benchmark_adapter import BaseBenchmarkAdapter from cognee.eval_framework.benchmark_adapters.base_benchmark_adapter import BaseBenchmarkAdapter
class DummyAdapter(BaseBenchmarkAdapter): class DummyAdapter(BaseBenchmarkAdapter):
def load_corpus( def load_corpus(
self, limit: Optional[int] = None, seed: int = 42, load_golden_context: bool = False self,
) -> tuple[list[str], list[dict[str, Any]]]: limit: Optional[int] = None,
seed: int = 42,
load_golden_context: bool = False,
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> Tuple[List[str], List[dict[str, Any]]]:
corpus_list = [ corpus_list = [
"The cognee is an AI memory engine that supports different vector and graph databases", "The cognee is an AI memory engine that supports different vector and graph databases",
"Neo4j is a graph database supported by cognee", "Neo4j is a graph database supported by cognee",
@ -22,4 +26,7 @@ class DummyAdapter(BaseBenchmarkAdapter):
question_answer_pairs = [qa_pair] question_answer_pairs = [qa_pair]
# Instance filtering is not applicable for the dummy adapter as it always returns the same data
# but we include the parameter for API consistency
return corpus_list, question_answer_pairs return corpus_list, question_answer_pairs

View file

@ -2,7 +2,7 @@ import requests
import os import os
import json import json
import random import random
from typing import Optional, Any, List, Tuple from typing import Optional, Any, List, Union, Tuple
from cognee.eval_framework.benchmark_adapters.base_benchmark_adapter import BaseBenchmarkAdapter from cognee.eval_framework.benchmark_adapters.base_benchmark_adapter import BaseBenchmarkAdapter
@ -37,17 +37,33 @@ class HotpotQAAdapter(BaseBenchmarkAdapter):
return "\n".join(golden_contexts) return "\n".join(golden_contexts)
def _process_item( def _get_raw_corpus(self) -> List[dict[str, Any]]:
"""Loads the raw corpus data from file or URL and returns it as a list of dictionaries."""
filename = self.dataset_info["filename"]
if os.path.exists(filename):
with open(filename, "r", encoding="utf-8") as f:
raw_corpus = json.load(f)
else:
response = requests.get(self.dataset_info["url"])
response.raise_for_status()
raw_corpus = response.json()
with open(filename, "w", encoding="utf-8") as f:
json.dump(raw_corpus, f, ensure_ascii=False, indent=4)
return raw_corpus
def _get_corpus_entries(self, item: dict[str, Any]) -> List[str]:
"""Extracts corpus entries from the context of an item."""
return [" ".join(sentences) for title, sentences in item["context"]]
def _get_question_answer_pair(
self, self,
item: dict[str, Any], item: dict[str, Any],
corpus_list: List[str],
question_answer_pairs: List[dict[str, Any]],
load_golden_context: bool = False, load_golden_context: bool = False,
) -> None: ) -> dict[str, Any]:
"""Processes a single item and adds it to the corpus and QA pairs.""" """Extracts a question-answer pair from an item."""
for title, sentences in item["context"]:
corpus_list.append(" ".join(sentences))
qa_pair = { qa_pair = {
"question": item["question"], "question": item["question"],
"answer": item["answer"].lower(), "answer": item["answer"].lower(),
@ -57,33 +73,30 @@ class HotpotQAAdapter(BaseBenchmarkAdapter):
if load_golden_context: if load_golden_context:
qa_pair["golden_context"] = self._get_golden_context(item) qa_pair["golden_context"] = self._get_golden_context(item)
question_answer_pairs.append(qa_pair) return qa_pair
def load_corpus( def load_corpus(
self, limit: Optional[int] = None, seed: int = 42, load_golden_context: bool = False self,
limit: Optional[int] = None,
seed: int = 42,
load_golden_context: bool = False,
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> Tuple[List[str], List[dict[str, Any]]]: ) -> Tuple[List[str], List[dict[str, Any]]]:
"""Loads and processes the HotpotQA corpus, optionally with golden context.""" """Loads and processes the HotpotQA corpus, optionally with filtering and golden context."""
filename = self.dataset_info["filename"] raw_corpus = self._get_raw_corpus()
if os.path.exists(filename): if instance_filter is not None:
with open(filename, "r", encoding="utf-8") as f: raw_corpus = self._filter_instances(raw_corpus, instance_filter, id_key="_id")
corpus_json = json.load(f)
else:
response = requests.get(self.dataset_info["url"])
response.raise_for_status()
corpus_json = response.json()
with open(filename, "w", encoding="utf-8") as f: if limit is not None and 0 < limit < len(raw_corpus):
json.dump(corpus_json, f, ensure_ascii=False, indent=4)
if limit is not None and 0 < limit < len(corpus_json):
random.seed(seed) random.seed(seed)
corpus_json = random.sample(corpus_json, limit) raw_corpus = random.sample(raw_corpus, limit)
corpus_list = [] corpus_list = []
question_answer_pairs = [] question_answer_pairs = []
for item in corpus_json: for item in raw_corpus:
self._process_item(item, corpus_list, question_answer_pairs, load_golden_context) corpus_list.extend(self._get_corpus_entries(item))
question_answer_pairs.append(self._get_question_answer_pair(item, load_golden_context))
return corpus_list, question_answer_pairs return corpus_list, question_answer_pairs

View file

@ -1,7 +1,7 @@
import os import os
import json import json
import random import random
from typing import Optional, Any, List from typing import Optional, Any, List, Union, Tuple
import zipfile import zipfile
import gdown import gdown
@ -38,41 +38,8 @@ class MusiqueQAAdapter(BaseBenchmarkAdapter):
return "\n".join(golden_context) return "\n".join(golden_context)
def _process_item( def _get_raw_corpus(self, auto_download: bool = True) -> List[dict[str, Any]]:
self, """Loads the raw corpus data from file or downloads it if needed."""
item: dict[str, Any],
corpus_list: List[str],
question_answer_pairs: List[dict[str, Any]],
load_golden_context: bool = False,
) -> None:
"""Processes a single item and adds it to the corpus and QA pairs."""
# Add paragraphs to corpus
paragraphs = item.get("paragraphs", [])
for paragraph in paragraphs:
corpus_list.append(paragraph["paragraph_text"])
# Create QA pair
qa_pair = {
"id": item.get("id", ""),
"question": item.get("question", ""),
"answer": item.get("answer", "").lower()
if isinstance(item.get("answer"), str)
else item.get("answer"),
}
if load_golden_context:
qa_pair["golden_context"] = self._get_golden_context(item)
question_answer_pairs.append(qa_pair)
def load_corpus(
self,
limit: Optional[int] = None,
seed: int = 42,
load_golden_context: bool = False,
auto_download: bool = True,
) -> tuple[list[str], list[dict[str, Any]]]:
"""Loads and processes the Musique QA dataset."""
target_filename = self.dataset_info["filename"] target_filename = self.dataset_info["filename"]
if not os.path.exists(target_filename): if not os.path.exists(target_filename):
@ -87,15 +54,55 @@ class MusiqueQAAdapter(BaseBenchmarkAdapter):
with open(target_filename, "r", encoding="utf-8") as f: with open(target_filename, "r", encoding="utf-8") as f:
data = [json.loads(line) for line in f] data = [json.loads(line) for line in f]
if limit is not None and 0 < limit < len(data): return data
def _get_corpus_entries(self, item: dict[str, Any]) -> List[str]:
"""Extracts corpus entries from the paragraphs of an item."""
return [paragraph["paragraph_text"] for paragraph in item.get("paragraphs", [])]
def _get_question_answer_pair(
self,
item: dict[str, Any],
load_golden_context: bool = False,
) -> dict[str, Any]:
"""Extracts a question-answer pair from an item."""
qa_pair = {
"id": item.get("id", ""),
"question": item.get("question", ""),
"answer": item.get("answer", "").lower()
if isinstance(item.get("answer"), str)
else item.get("answer"),
}
if load_golden_context:
qa_pair["golden_context"] = self._get_golden_context(item)
return qa_pair
def load_corpus(
self,
limit: Optional[int] = None,
seed: int = 42,
load_golden_context: bool = False,
auto_download: bool = True,
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> Tuple[List[str], List[dict[str, Any]]]:
"""Loads and processes the Musique QA dataset with optional filtering."""
raw_corpus = self._get_raw_corpus(auto_download)
if instance_filter is not None:
raw_corpus = self._filter_instances(raw_corpus, instance_filter, id_key="id")
if limit is not None and 0 < limit < len(raw_corpus):
random.seed(seed) random.seed(seed)
data = random.sample(data, limit) raw_corpus = random.sample(raw_corpus, limit)
corpus_list = [] corpus_list = []
question_answer_pairs = [] question_answer_pairs = []
for item in data: for item in raw_corpus:
self._process_item(item, corpus_list, question_answer_pairs, load_golden_context) corpus_list.extend(self._get_corpus_entries(item))
question_answer_pairs.append(self._get_question_answer_pair(item, load_golden_context))
return corpus_list, question_answer_pairs return corpus_list, question_answer_pairs

View file

@ -29,10 +29,13 @@ class CorpusBuilderExecutor:
self.task_getter = task_getter self.task_getter = task_getter
def load_corpus( def load_corpus(
self, limit: Optional[int] = None, load_golden_context: bool = False self,
limit: Optional[int] = None,
load_golden_context: bool = False,
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> Tuple[List[Dict], List[str]]: ) -> Tuple[List[Dict], List[str]]:
self.raw_corpus, self.questions = self.adapter.load_corpus( self.raw_corpus, self.questions = self.adapter.load_corpus(
limit=limit, load_golden_context=load_golden_context limit=limit, load_golden_context=load_golden_context, instance_filter=instance_filter
) )
return self.raw_corpus, self.questions return self.raw_corpus, self.questions
@ -42,8 +45,11 @@ class CorpusBuilderExecutor:
chunk_size=1024, chunk_size=1024,
chunker=TextChunker, chunker=TextChunker,
load_golden_context: bool = False, load_golden_context: bool = False,
instance_filter: Optional[Union[str, List[str], List[int]]] = None,
) -> List[str]: ) -> List[str]:
self.load_corpus(limit=limit, load_golden_context=load_golden_context) self.load_corpus(
limit=limit, load_golden_context=load_golden_context, instance_filter=instance_filter
)
await self.run_cognee(chunk_size=chunk_size, chunker=chunker) await self.run_cognee(chunk_size=chunk_size, chunker=chunker)
return self.questions return self.questions

View file

@ -33,7 +33,9 @@ async def create_and_insert_questions_table(questions_payload):
await session.commit() await session.commit()
async def run_corpus_builder(params: dict, chunk_size=1024, chunker=TextChunker) -> List[dict]: async def run_corpus_builder(
params: dict, chunk_size=1024, chunker=TextChunker, instance_filter=None
) -> List[dict]:
if params.get("building_corpus_from_scratch"): if params.get("building_corpus_from_scratch"):
logging.info("Corpus Builder started...") logging.info("Corpus Builder started...")
@ -51,6 +53,7 @@ async def run_corpus_builder(params: dict, chunk_size=1024, chunker=TextChunker)
chunker=chunker, chunker=chunker,
chunk_size=chunk_size, chunk_size=chunk_size,
load_golden_context=params.get("evaluating_contexts"), load_golden_context=params.get("evaluating_contexts"),
instance_filter=instance_filter,
) )
with open(params["questions_path"], "w", encoding="utf-8") as f: with open(params["questions_path"], "w", encoding="utf-8") as f:
json.dump(questions, f, ensure_ascii=False, indent=4) json.dump(questions, f, ensure_ascii=False, indent=4)