fix: fixes hotpot and twowiki tests (that are using url to download dataset)

This commit is contained in:
hajdul88 2025-12-12 15:57:49 +01:00
parent 499c717f85
commit fc89f71e7c
2 changed files with 58 additions and 4 deletions

View file

@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]} {"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
""" """
MOCK_HOTPOT_CORPUS = [
{
"_id": "1",
"question": "Next to which country is Germany located?",
"answer": "Netherlands",
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
"level": "easy",
"type": "comparison",
"context": [
["Germany", ["Germany is in Europe."]],
["Netherlands", ["The Netherlands borders Germany."]],
],
"supporting_facts": [["Netherlands", 0]],
}
]
ADAPTER_CLASSES = [ ADAPTER_CLASSES = [
HotpotQAAdapter, HotpotQAAdapter,
@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
adapter = AdapterClass() adapter = AdapterClass()
result = adapter.load_corpus() result = adapter.load_corpus()
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
adapter = AdapterClass()
result = adapter.load_corpus()
else: else:
adapter = AdapterClass() adapter = AdapterClass()
result = adapter.load_corpus() result = adapter.load_corpus()
@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass):
): ):
adapter = AdapterClass() adapter = AdapterClass()
corpus_list, qa_pairs = adapter.load_corpus(limit=limit) corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
adapter = AdapterClass()
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
else: else:
adapter = AdapterClass() adapter = AdapterClass()
corpus_list, qa_pairs = adapter.load_corpus(limit=limit) corpus_list, qa_pairs = adapter.load_corpus(limit=limit)

View file

@ -2,15 +2,38 @@ import pytest
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter
benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"] benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
MOCK_HOTPOT_CORPUS = [
{
"_id": "1",
"question": "Next to which country is Germany located?",
"answer": "Netherlands",
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
"level": "easy",
"type": "comparison",
"context": [
["Germany", ["Germany is in Europe."]],
["Netherlands", ["The Netherlands borders Germany."]],
],
"supporting_facts": [["Netherlands", 0]],
}
]
@pytest.mark.parametrize("benchmark", benchmark_options) @pytest.mark.parametrize("benchmark", benchmark_options)
def test_corpus_builder_load_corpus(benchmark): def test_corpus_builder_load_corpus(benchmark):
limit = 2 limit = 2
corpus_builder = CorpusBuilderExecutor(benchmark, "Default") if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
raw_corpus, questions = corpus_builder.load_corpus(limit=limit) with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
else:
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}" assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
assert len(questions) <= 2, ( assert len(questions) <= 2, (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark):
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock) @patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark): async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
limit = 2 limit = 2
corpus_builder = CorpusBuilderExecutor(benchmark, "Default") if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
questions = await corpus_builder.build_corpus(limit=limit) with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
questions = await corpus_builder.build_corpus(limit=limit)
else:
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
questions = await corpus_builder.build_corpus(limit=limit)
assert len(questions) <= 2, ( assert len(questions) <= 2, (
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
) )