fix: fixes hotpot and twowiki tests (that are using url to download dataset)
This commit is contained in:
parent
499c717f85
commit
fc89f71e7c
2 changed files with 58 additions and 4 deletions
|
|
@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\
|
|||
{"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]}
|
||||
"""
|
||||
|
||||
MOCK_HOTPOT_CORPUS = [
|
||||
{
|
||||
"_id": "1",
|
||||
"question": "Next to which country is Germany located?",
|
||||
"answer": "Netherlands",
|
||||
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
|
||||
"level": "easy",
|
||||
"type": "comparison",
|
||||
"context": [
|
||||
["Germany", ["Germany is in Europe."]],
|
||||
["Netherlands", ["The Netherlands borders Germany."]],
|
||||
],
|
||||
"supporting_facts": [["Netherlands", 0]],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
ADAPTER_CLASSES = [
|
||||
HotpotQAAdapter,
|
||||
|
|
@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass):
|
|||
adapter = AdapterClass()
|
||||
result = adapter.load_corpus()
|
||||
|
||||
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
||||
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||
adapter = AdapterClass()
|
||||
result = adapter.load_corpus()
|
||||
|
||||
else:
|
||||
adapter = AdapterClass()
|
||||
result = adapter.load_corpus()
|
||||
|
|
@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass):
|
|||
):
|
||||
adapter = AdapterClass()
|
||||
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||
elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter):
|
||||
with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||
adapter = AdapterClass()
|
||||
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||
else:
|
||||
adapter = AdapterClass()
|
||||
corpus_list, qa_pairs = adapter.load_corpus(limit=limit)
|
||||
|
|
|
|||
|
|
@ -2,15 +2,38 @@ import pytest
|
|||
from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter
|
||||
|
||||
benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"]
|
||||
|
||||
MOCK_HOTPOT_CORPUS = [
|
||||
{
|
||||
"_id": "1",
|
||||
"question": "Next to which country is Germany located?",
|
||||
"answer": "Netherlands",
|
||||
# HotpotQA uses "level"; TwoWikiMultiHop uses "type".
|
||||
"level": "easy",
|
||||
"type": "comparison",
|
||||
"context": [
|
||||
["Germany", ["Germany is in Europe."]],
|
||||
["Netherlands", ["The Netherlands borders Germany."]],
|
||||
],
|
||||
"supporting_facts": [["Netherlands", 0]],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("benchmark", benchmark_options)
|
||||
def test_corpus_builder_load_corpus(benchmark):
|
||||
limit = 2
|
||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
|
||||
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||
else:
|
||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||
raw_corpus, questions = corpus_builder.load_corpus(limit=limit)
|
||||
|
||||
assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}"
|
||||
assert len(questions) <= 2, (
|
||||
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
||||
|
|
@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark):
|
|||
@patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock)
|
||||
async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark):
|
||||
limit = 2
|
||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||
questions = await corpus_builder.build_corpus(limit=limit)
|
||||
if benchmark in ("HotPotQA", "TwoWikiMultiHop"):
|
||||
with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS):
|
||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||
questions = await corpus_builder.build_corpus(limit=limit)
|
||||
else:
|
||||
corpus_builder = CorpusBuilderExecutor(benchmark, "Default")
|
||||
questions = await corpus_builder.build_corpus(limit=limit)
|
||||
|
||||
assert len(questions) <= 2, (
|
||||
f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}"
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue